use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use nodedb_types::config::tuning::ClusterTransportTuning;
use tracing::info;
use crate::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig, RetryPolicy};
use crate::error::{ClusterError, Result};
use crate::transport::auth_context::AuthContext;
use crate::transport::config;
use crate::transport::credentials::{self, TransportCredentials};
use crate::transport::server::{NoopIdentityStore, PeerIdentityStore};
/// Cluster transport for a single node: a bound listener plus the client
/// configuration, peer bookkeeping, and authentication state that the rest of
/// the transport module works with.
///
/// Instances are created through the `new` / `with_*` constructors, all of
/// which funnel into the private `build` function.
pub struct NexarTransport {
/// Id of the local node, exactly as passed to the constructor.
pub(super) node_id: u64,
/// Listener bound to the requested address; its endpoint is also what
/// `connect_raw` uses to open outbound connections.
pub(super) listener: nexar::TransportListener,
/// Client-side configuration produced in `build` from the credentials and
/// tuning; `connect_raw` clones it for each outbound connection.
pub(super) client_config: quinn::ClientConfig,
/// Connections keyed by peer id. `peer_snapshot` reports a peer as
/// `connected` when its id is present in this map.
pub(super) peers: RwLock<HashMap<u64, quinn::Connection>>,
/// Known socket address for each peer id; the source of the entries returned
/// by `peer_snapshot`.
pub(super) peer_addrs: RwLock<HashMap<u64, SocketAddr>>,
/// RPC timeout chosen at construction time.
/// NOTE(review): how and where it is enforced is not visible in this file.
pub(super) rpc_timeout: Duration,
/// Circuit breaker, created in `build` with `CircuitBreakerConfig::default()`.
pub(super) circuit_breaker: Arc<CircuitBreaker>,
/// Retry policy, initialised to `RetryPolicy::default()` in `build`.
pub(super) retry_policy: RetryPolicy,
/// Authentication context derived from the node id and the credentials.
pub(super) auth: Arc<AuthContext>,
/// SPKI pin taken from the mTLS credentials; `None` for insecure transport.
local_spki_pin: Option<[u8; 32]>,
/// Wire version recorded per `stable_id`.
/// NOTE(review): `stable_id` is presumably `quinn::Connection::stable_id()` —
/// confirm against the callers that populate this map.
pub(super) agreed_versions: RwLock<HashMap<usize, crate::wire_version::WireVersion>>,
}
impl NexarTransport {
pub fn new(node_id: u64, listen_addr: SocketAddr, creds: TransportCredentials) -> Result<Self> {
Self::with_timeout(node_id, listen_addr, config::DEFAULT_RPC_TIMEOUT, creds)
}
pub fn with_timeout(
node_id: u64,
listen_addr: SocketAddr,
rpc_timeout: Duration,
creds: TransportCredentials,
) -> Result<Self> {
let defaults = ClusterTransportTuning::default();
Self::build(
node_id,
listen_addr,
rpc_timeout,
&defaults,
creds,
Arc::new(NoopIdentityStore),
)
}
pub fn with_tuning(
node_id: u64,
listen_addr: SocketAddr,
tuning: &ClusterTransportTuning,
creds: TransportCredentials,
) -> Result<Self> {
let rpc_timeout = Duration::from_secs(tuning.rpc_timeout_secs);
Self::build(
node_id,
listen_addr,
rpc_timeout,
tuning,
creds,
Arc::new(NoopIdentityStore),
)
}
pub fn with_tuning_and_identity(
node_id: u64,
listen_addr: SocketAddr,
tuning: &ClusterTransportTuning,
creds: TransportCredentials,
identity_store: Arc<dyn PeerIdentityStore>,
) -> Result<Self> {
let rpc_timeout = Duration::from_secs(tuning.rpc_timeout_secs);
Self::build(
node_id,
listen_addr,
rpc_timeout,
tuning,
creds,
identity_store,
)
}
fn build(
node_id: u64,
listen_addr: SocketAddr,
rpc_timeout: Duration,
tuning: &ClusterTransportTuning,
creds: TransportCredentials,
identity_store: Arc<dyn PeerIdentityStore>,
) -> Result<Self> {
let (server_config, client_config) = match &creds {
TransportCredentials::Mtls(tls) => (
config::make_raft_server_config_mtls(tls, tuning, Arc::clone(&identity_store))?,
config::make_raft_client_config_mtls(tls, tuning, Arc::clone(&identity_store))?,
),
TransportCredentials::Insecure => {
credentials::announce_insecure_transport(node_id);
(
config::make_raft_server_config(tuning)?,
config::make_raft_client_config(tuning)?,
)
}
};
let local_spki_pin = match &creds {
TransportCredentials::Mtls(tls) => Some(tls.spki_pin),
TransportCredentials::Insecure => None,
};
let auth = Arc::new(AuthContext::from_credentials(node_id, &creds));
let listener = nexar::TransportListener::bind_with_config(listen_addr, server_config)
.map_err(|e| ClusterError::Transport {
detail: format!("bind {listen_addr}: {e}"),
})?;
info!(
node_id,
addr = %listener.local_addr(),
rpc_timeout_ms = rpc_timeout.as_millis() as u64,
mtls = !creds.is_insecure(),
"raft transport bound"
);
Ok(Self {
node_id,
listener,
client_config,
peers: RwLock::new(HashMap::new()),
peer_addrs: RwLock::new(HashMap::new()),
rpc_timeout,
circuit_breaker: Arc::new(CircuitBreaker::new(CircuitBreakerConfig::default())),
retry_policy: RetryPolicy::default(),
auth,
local_spki_pin,
agreed_versions: RwLock::new(HashMap::new()),
})
}
pub(super) fn auth(&self) -> &Arc<AuthContext> {
&self.auth
}
pub fn circuit_breaker(&self) -> &Arc<CircuitBreaker> {
&self.circuit_breaker
}
pub fn local_addr(&self) -> SocketAddr {
self.listener.local_addr()
}
pub fn node_id(&self) -> u64 {
self.node_id
}
pub fn mac_key(&self) -> crate::rpc_codec::MacKey {
self.auth.mac_key.clone()
}
pub fn local_spki_pin(&self) -> Option<[u8; 32]> {
self.local_spki_pin
}
pub fn agreed_version_for(&self, stable_id: usize) -> Option<crate::wire_version::WireVersion> {
let versions = self
.agreed_versions
.read()
.unwrap_or_else(|p| p.into_inner());
versions.get(&stable_id).copied()
}
pub(super) fn store_agreed_version(
&self,
stable_id: usize,
version: crate::wire_version::WireVersion,
) {
let mut versions = self
.agreed_versions
.write()
.unwrap_or_else(|p| p.into_inner());
versions.insert(stable_id, version);
}
pub(super) fn evict_agreed_version(&self, stable_id: usize) {
let mut versions = self
.agreed_versions
.write()
.unwrap_or_else(|p| p.into_inner());
versions.remove(&stable_id);
}
#[doc(hidden)]
pub async fn accept_raw(&self) -> crate::error::Result<quinn::Connection> {
self.listener
.accept()
.await
.map_err(|e| crate::error::ClusterError::Transport {
detail: format!("accept_raw: {e}"),
})
}
#[doc(hidden)]
pub async fn connect_raw(
&self,
addr: std::net::SocketAddr,
) -> crate::error::Result<quinn::Connection> {
self.listener
.endpoint()
.connect_with(
self.client_config.clone(),
addr,
crate::transport::config::SNI_HOSTNAME,
)
.map_err(|e| crate::error::ClusterError::Transport {
detail: format!("connect_raw to {addr}: {e}"),
})?
.await
.map_err(|e| crate::error::ClusterError::Transport {
detail: format!("connect_raw handshake with {addr}: {e}"),
})
}
pub fn peer_snapshot(&self) -> Vec<TransportPeerSnapshot> {
let addrs = self.peer_addrs.read().unwrap_or_else(|p| p.into_inner());
let peers = self.peers.read().unwrap_or_else(|p| p.into_inner());
let mut out: Vec<TransportPeerSnapshot> = addrs
.iter()
.map(|(id, addr)| TransportPeerSnapshot {
peer_id: *id,
addr: addr.to_string(),
connected: peers.contains_key(id),
})
.collect();
out.sort_by_key(|p| p.peer_id);
out
}
}
/// Serialisable view of one known peer, as produced by
/// `NexarTransport::peer_snapshot`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct TransportPeerSnapshot {
/// Id of the peer (the key of the transport's peer address table).
pub peer_id: u64,
/// The peer's socket address rendered with `SocketAddr::to_string`.
pub addr: String,
/// Whether the transport held a connection entry for this peer when the
/// snapshot was taken.
pub connected: bool,
}