use std::any::Any;
use std::fmt::Display;
use openraft::error::InstallSnapshotError;
use openraft::error::NetworkError;
use openraft::error::RPCError;
use openraft::error::RaftError;
use openraft::error::RemoteError;
use openraft::network::RPCOption;
use openraft::network::RaftNetwork;
use openraft::network::RaftNetworkFactory;
use openraft::raft::AppendEntriesRequest;
use openraft::raft::AppendEntriesResponse;
use openraft::raft::InstallSnapshotRequest;
use openraft::raft::InstallSnapshotResponse;
use openraft::raft::VoteRequest;
use openraft::raft::VoteResponse;
use openraft::AnyError;
use serde::de::DeserializeOwned;
use toy_rpc_ha421::Client;
use super::raft::RaftClientStub;
use crate::Node;
use crate::NodeId;
use crate::TypeConfig;
use rustls::client::danger::ServerCertVerified;
use rustls::client::danger::ServerCertVerifier;
use rustls::{pki_types::CertificateDer, ClientConfig, RootCertStore};
use crate::RSQliteNodeTlsConfig;
use std::net::{IpAddr, ToSocketAddrs};
use std::sync::Arc;
#[derive(Debug)]
struct AllowAnyCertVerifier;
impl ServerCertVerifier for AllowAnyCertVerifier {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA1,
rustls::SignatureScheme::ECDSA_SHA1_Legacy,
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::ED448,
]
}
}
/// Factory that creates one [`NetworkConnection`] per Raft peer.
pub struct Network {
/// TLS settings for outgoing connections. When `Some`, peers are dialed with
/// TLS; when `None`, they are dialed over a plain `ws://` WebSocket.
pub tls_config: Option<RSQliteNodeTlsConfig>,
}
impl RaftNetworkFactory<TypeConfig> for Network {
type Network = NetworkConnection;
#[tracing::instrument(level = "debug", skip_all)]
async fn new_client(&mut self, target: NodeId, node: &Node) -> Self::Network {
if let Some(tls_config) = self.tls_config.as_ref() {
let addr = node.rpc_addr.clone();
let parts: Vec<&str> = addr.split(':').collect();
let host = parts[0];
let port: u16 = parts[1].parse().unwrap();
let (addr, domain) = match host.parse::<IpAddr>() {
Ok(_) => (host.to_string(), host.to_string()),
Err(_) => match (host, port).to_socket_addrs() {
Ok(mut addrs) => match addrs.next() {
Some(addr) => (addr.to_string(), host.to_string()),
None => {
tracing::error!("No address found for {}", host);
(host.to_string(), host.to_string())
}
},
Err(e) => {
tracing::error!("DNS resolution error for {}({})", host, e);
(host.to_string(), host.to_string())
}
},
};
let addr = format!("{}:{}", addr, port);
if tls_config.accept_invalid_certificates {
let root_certs = RootCertStore::empty();
let mut config= ClientConfig::builder()
.with_root_certificates(root_certs)
.with_no_client_auth();
config
.dangerous()
.set_certificate_verifier(Arc::new(AllowAnyCertVerifier));
let client = Client::dial_with_tls_config(&addr, &domain, config)
.await
.ok();
tracing::debug!("new_client: is_none: {}", client.is_none());
NetworkConnection {
addr,
domain: domain.to_string(),
client,
target,
tls_config: self.tls_config.clone(),
}
} else {
let root_certs = RootCertStore::empty();
let config = ClientConfig::builder()
.with_root_certificates(root_certs)
.with_no_client_auth();
let client = Client::dial_with_tls_config(&addr, &domain, config)
.await
.ok();
tracing::debug!("new_client: is_none: {}", client.is_none());
NetworkConnection {
addr,
domain: domain.to_string(),
client,
target,
tls_config: self.tls_config.clone(),
}
}
} else {
let addr = format!("ws://{}", node.rpc_addr);
let client = Client::dial_websocket(&addr).await.ok();
tracing::debug!("new_client: is_none: {}", client.is_none());
NetworkConnection {
addr,
client,
target,
domain: String::default(),
tls_config: self.tls_config.clone(),
}
}
}
}
/// A connection to a single Raft peer, holding what is needed to dial it again.
pub struct NetworkConnection {
/// Address handed to the dial call: `ws://{rpc_addr}` for plain WebSocket
/// connections, or the `host:port` string built for TLS connections.
addr: String,
/// Server name passed to the TLS dial; empty for non-TLS connections.
domain: String,
/// The RPC client, or `None` if the last dial attempt failed.
client: Option<Client >,
/// Id of the peer; attached to remote errors reported for this connection.
target: NodeId,
/// Copy of the factory's TLS settings, used when the connection is re-dialed.
tls_config: Option<RSQliteNodeTlsConfig>,
}
impl NetworkConnection {
    /// Returns the RPC client for this peer, dialing it first if no client is
    /// currently held.
    ///
    /// The dial uses TLS when `tls_config` is set (with certificate checks
    /// disabled if `accept_invalid_certificates` is true) and a plain
    /// WebSocket otherwise.
    ///
    /// # Errors
    ///
    /// Returns `RPCError::Network` wrapping the dial error when the peer
    /// cannot be reached, so the cause is visible to the caller and in logs.
    async fn c<E: std::error::Error + DeserializeOwned>(
        &mut self,
    ) -> Result<&Client, RPCError<NodeId, Node, E>> {
        if self.client.is_none() {
            let dialed = match self.tls_config.as_ref() {
                Some(tls_config) => {
                    let mut config = ClientConfig::builder()
                        .with_root_certificates(RootCertStore::empty())
                        .with_no_client_auth();
                    if tls_config.accept_invalid_certificates {
                        config
                            .dangerous()
                            .set_certificate_verifier(Arc::new(AllowAnyCertVerifier));
                    }
                    Client::dial_with_tls_config(&self.addr, &self.domain, config).await
                }
                None => Client::dial_websocket(&self.addr).await,
            };
            match dialed {
                Ok(client) => self.client = Some(client),
                // Keep the real cause instead of replacing it with an empty error.
                Err(e) => return Err(RPCError::Network(NetworkError::new(&e))),
            }
        }
        self.client
            .as_ref()
            .ok_or_else(|| RPCError::Network(NetworkError::from(AnyError::default())))
    }
}
/// Newtype around a boxed error so that it can be passed where a concrete,
/// sized `std::error::Error` type is required.
#[derive(Debug)]
struct ErrWrap(Box<dyn std::error::Error>);

impl Display for ErrWrap {
    /// Formats exactly as the wrapped error does, forwarding the formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(inner) = self;
        Display::fmt(inner, f)
    }
}

impl std::error::Error for ErrWrap {}
fn to_error<E: std::error::Error + 'static + Clone>(
e: toy_rpc_ha421::Error,
target: NodeId,
) -> RPCError<NodeId, Node, E> {
match e {
toy_rpc_ha421::Error::IoError(e) => RPCError::Network(NetworkError::new(&e)),
toy_rpc_ha421::Error::ParseError(e) => RPCError::Network(NetworkError::new(&ErrWrap(e))),
toy_rpc_ha421::Error::Internal(e) => {
let any: &dyn Any = &e;
let error: &E = any.downcast_ref().unwrap();
RPCError::RemoteError(RemoteError::new(target, error.clone()))
}
e @ (toy_rpc_ha421::Error::InvalidArgument
| toy_rpc_ha421::Error::ServiceNotFound
| toy_rpc_ha421::Error::MethodNotFound
| toy_rpc_ha421::Error::ExecutionError(_)
| toy_rpc_ha421::Error::Canceled(_)
| toy_rpc_ha421::Error::Timeout(_)
| toy_rpc_ha421::Error::MaxRetriesReached(_)) => RPCError::Network(NetworkError::new(&e)),
}
}
#[allow(clippy::blocks_in_conditions)]
impl RaftNetwork<TypeConfig> for NetworkConnection {
    /// Sends an AppendEntries RPC to the peer and returns its response.
    ///
    /// Fails with the connection error if no client can be obtained, or with
    /// the converted RPC error if the call itself fails.
    #[tracing::instrument(level = "debug", skip_all, err(Debug))]
    async fn append_entries(
        &mut self,
        req: AppendEntriesRequest<TypeConfig>,
        _option: RPCOption,
    ) -> Result<AppendEntriesResponse<NodeId>, RPCError<NodeId, Node, RaftError<NodeId>>> {
        tracing::debug!(req = debug(&req), "append_entries");
        let target = self.target;
        let client = self.c().await?;
        tracing::debug!("got connection");
        let stub = client.raft();
        tracing::debug!("got raft");
        let reply = stub.append(req).await;
        reply.map_err(|e| to_error(e, target))
    }

    /// Sends an InstallSnapshot RPC to the peer and returns its response.
    ///
    /// Fails with the connection error if no client can be obtained, or with
    /// the converted RPC error if the call itself fails.
    #[tracing::instrument(level = "debug", skip_all, err(Debug))]
    async fn install_snapshot(
        &mut self,
        req: InstallSnapshotRequest<TypeConfig>,
        _option: RPCOption,
    ) -> Result<
        InstallSnapshotResponse<NodeId>,
        RPCError<NodeId, Node, RaftError<NodeId, InstallSnapshotError>>,
    > {
        tracing::debug!(req = debug(&req), "install_snapshot");
        let target = self.target;
        let client = self.c().await?;
        let reply = client.raft().snapshot(req).await;
        reply.map_err(|e| to_error(e, target))
    }

    /// Sends a RequestVote RPC to the peer and returns its response.
    ///
    /// Fails with the connection error if no client can be obtained, or with
    /// the converted RPC error if the call itself fails.
    #[tracing::instrument(level = "debug", skip_all, err(Debug))]
    async fn vote(
        &mut self,
        req: VoteRequest<NodeId>,
        _option: RPCOption,
    ) -> Result<VoteResponse<NodeId>, RPCError<NodeId, Node, RaftError<NodeId>>> {
        tracing::debug!(req = debug(&req), "vote");
        let target = self.target;
        let client = self.c().await?;
        let reply = client.raft().vote(req).await;
        reply.map_err(|e| to_error(e, target))
    }
}