use anyhow::Result;
use async_trait::async_trait;
use rand;
use runar_serializer::ArcValue;
use serde::{Deserialize, Serialize};
use std::future::Future;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener};
use std::ops::Range;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use rustls::client::danger::{ServerCertVerified, ServerCertVerifier};
use rustls_pki_types::{CertificateDer, ServerName};
pub mod connection_pool;
pub mod peer_state;
pub mod quic_transport; pub mod stream_pool;
use crate::routing::TopicPath;
pub use connection_pool::ConnectionPool;
pub use peer_state::PeerState;
pub use stream_pool::StreamPool;
/// Server-certificate verifier that performs no verification at all.
///
/// WARNING: the `ServerCertVerifier` implementation for this type in this file
/// returns success from every hook without looking at its arguments, so a
/// rustls client configured with it has no protection against server
/// impersonation or man-in-the-middle attacks.
///
/// NOTE(review): presumably intended for tests / local development only —
/// confirm it is never used in a production configuration.
#[derive(Debug)]
pub struct SkipServerVerification {}
impl ServerCertVerifier for SkipServerVerification {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls_pki_types::UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA1,
rustls::SignatureScheme::ECDSA_SHA1_Legacy,
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::ED448,
]
}
}
pub use quic_transport::{QuicTransport, QuicTransportOptions};
use super::discovery::multicast_discovery::PeerInfo;
use super::discovery::NodeInfo;
/// A pinned, heap-allocated, `Send` future that resolves to `T` and may borrow
/// data for the lifetime `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Configuration values for a network transport.
///
/// The `Default` implementation in this file fills in a 30-second timeout, a
/// maximum message size of `1024 * 1024`, and a bind address on the IPv4
/// unspecified address.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransportOptions {
/// Optional timeout. NOTE(review): which operations it applies to is not
/// visible in this file.
pub timeout: Option<Duration>,
/// Optional upper bound on message size — presumably in bytes; TODO confirm.
pub max_message_size: Option<usize>,
/// Local socket address for the transport.
pub bind_address: SocketAddr,
}
impl Default for TransportOptions {
fn default() -> Self {
let port = pick_free_port(50000..51000).unwrap_or(0);
let bind_address = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port);
Self {
timeout: Some(Duration::from_secs(30)),
max_message_size: Some(1024 * 1024), bind_address,
}
}
}
/// Picks a random port from `port_range` that can currently be bound for both
/// TCP and UDP on the IPv4 unspecified address (`0.0.0.0`).
///
/// Up to 50 random candidates are tried. A candidate is accepted when a TCP
/// listener can be bound to it and, while that listener is still open, a UDP
/// socket can be bound to the same port number.
///
/// Both probe sockets are closed again before this function returns, so the
/// port is only known to have been free at the time of the check; another
/// process may take it before the caller binds. Callers must still handle a
/// bind failure.
///
/// # Returns
///
/// * `Some(port)` — a port that passed both probes.
/// * `None` — `port_range` is empty (including `start > end`), or no candidate
///   could be bound within the attempt budget.
pub fn pick_free_port(port_range: Range<u16>) -> Option<u16> {
    use rand::Rng;

    const MAX_ATTEMPTS: usize = 50;

    // An empty or inverted range has no candidates. Sampling from it would
    // panic, and computing `end - start` on an inverted range would underflow.
    if port_range.is_empty() {
        return None;
    }

    let wildcard = IpAddr::V4(Ipv4Addr::UNSPECIFIED);
    let mut rng = rand::rng();

    for _ in 0..MAX_ATTEMPTS {
        let candidate = rng.random_range(port_range.clone());

        // Keep the TCP listener alive for the rest of the iteration so the UDP
        // probe runs while the TCP side of the port is still held.
        let tcp_listener = match TcpListener::bind(SocketAddr::new(wildcard, candidate)) {
            Ok(listener) => listener,
            Err(_) => continue,
        };

        // Read the port back from the listener: if the candidate was 0 the OS
        // picked the actual port.
        let bound_port = match tcp_listener.local_addr() {
            Ok(addr) => addr.port(),
            Err(_) => continue,
        };

        if std::net::UdpSocket::bind(SocketAddr::new(wildcard, bound_port)).is_ok() {
            return Some(bound_port);
        }
    }

    None
}
/// Categories of network message.
///
/// NOTE(review): `NetworkMessage::message_type` in this file is a plain `u32`,
/// and the `MESSAGE_TYPE_*` constants also define handshake and error kinds
/// that have no variant here — confirm how this enum relates to those
/// constants and whether it is still in use.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum NetworkMessageType {
/// A request message.
Request,
/// A response message.
Response,
/// An event message.
Event,
/// A discovery message.
Discovery,
/// A heartbeat message.
Heartbeat,
}
/// Context carried alongside a message payload or request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageContext {
/// Raw bytes of a profile public key.
/// NOTE(review): key type and encoding are not visible in this file.
pub profile_public_key: Vec<u8>,
}
/// One payload entry inside a `NetworkMessage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkMessagePayloadItem {
/// Path string for this payload — presumably a topic path; TODO confirm.
pub path: String,
/// Payload value as raw bytes; the serialization format is not visible here.
pub value_bytes: Vec<u8>,
/// Optional message context; `NetworkMessagePayloadItem::new` always sets it to `Some`.
pub context: Option<MessageContext>,
/// Correlation identifier — presumably used to match responses to requests; verify against callers.
pub correlation_id: String,
}
impl NetworkMessagePayloadItem {
pub fn new(
path: String,
value_bytes: Vec<u8>,
correlation_id: String,
context: MessageContext,
) -> Self {
Self {
path,
value_bytes,
correlation_id,
context: Some(context),
}
}
}
// Numeric message-type codes. `NetworkMessage::message_type` is a `u32`,
// presumably holding one of these values — verify against the transport code.

/// Message-type code for discovery messages.
pub const MESSAGE_TYPE_DISCOVERY: u32 = 1;
/// Message-type code for heartbeat messages.
pub const MESSAGE_TYPE_HEARTBEAT: u32 = 2;
/// Message-type code for handshake messages.
pub const MESSAGE_TYPE_HANDSHAKE: u32 = 3;
/// Message-type code for request messages.
pub const MESSAGE_TYPE_REQUEST: u32 = 4;
/// Message-type code for response messages.
pub const MESSAGE_TYPE_RESPONSE: u32 = 5;
/// Message-type code for event messages.
pub const MESSAGE_TYPE_EVENT: u32 = 6;
/// Message-type code for error messages.
pub const MESSAGE_TYPE_ERROR: u32 = 7;
/// A message exchanged between nodes: source and destination node ids, a
/// numeric type code, and a list of payload items.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkMessage {
/// Identifier of the sending node.
pub source_node_id: String,
/// Identifier of the receiving node.
pub destination_node_id: String,
/// Numeric message type — presumably one of the `MESSAGE_TYPE_*` constants; TODO confirm.
pub message_type: u32,
/// Payload items carried by this message.
pub payloads: Vec<NetworkMessagePayloadItem>,
}
pub type MessageHandler = Box<
dyn Fn(
NetworkMessage,
) -> std::pin::Pin<
Box<
dyn std::future::Future<Output = Result<Option<NetworkMessage>, NetworkError>>
+ Send,
>,
> + Send
+ Sync,
>;
pub type OneWayMessageHandler = Box<
dyn Fn(
NetworkMessage,
)
-> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), NetworkError>> + Send>>
+ Send
+ Sync,
>;
/// Shared (`Arc`), thread-safe async callback that takes a `NetworkMessage`
/// and resolves to a `Result<()>` (the `Result` alias imported at the top of
/// this file).
pub type MessageCallback =
Arc<dyn Fn(NetworkMessage) -> BoxFuture<'static, Result<()>> + Send + Sync>;
/// Shared (`Arc`), thread-safe async callback taking `(String, bool, Option<NodeInfo>)`.
///
/// NOTE(review): presumably the arguments are the peer node id, whether the
/// peer is now connected, and the peer's node info when available — confirm
/// against the call sites.
pub type ConnectionCallback =
Arc<dyn Fn(String, bool, Option<NodeInfo>) -> BoxFuture<'static, Result<()>> + Send + Sync>;
/// Interface implemented by network transports that exchange messages with
/// peer nodes.
///
/// Implementors must be `Send + Sync`. `start` and `connect_peer` take
/// `self: Arc<Self>`, so the transport has to be held in an `Arc` to call them.
#[async_trait]
pub trait NetworkTransport: Send + Sync {
/// Starts the transport.
async fn start(self: Arc<Self>) -> Result<(), NetworkError>;
/// Stops the transport.
async fn stop(&self) -> Result<(), NetworkError>;
/// Disconnects from the peer identified by `node_id`.
async fn disconnect(&self, node_id: &str) -> Result<(), NetworkError>;
/// Reports whether the peer identified by `node_id` is connected.
async fn is_connected(&self, node_id: &str) -> bool;
/// Sends a request for `topic_path` to `peer_node_id`, with optional
/// `params` and the given `context`, and resolves to the returned value.
async fn request(
&self,
topic_path: &TopicPath,
params: Option<ArcValue>,
peer_node_id: &str,
context: MessageContext,
) -> Result<ArcValue, NetworkError>;
/// Publishes to `topic_path` on `peer_node_id` with optional `params`;
/// no value is returned on success.
async fn publish(
&self,
topic_path: &TopicPath,
params: Option<ArcValue>,
peer_node_id: &str,
) -> Result<(), NetworkError>;
/// Connects to the peer described by `discovery_msg`.
async fn connect_peer(self: Arc<Self>, discovery_msg: PeerInfo) -> Result<(), NetworkError>;
/// Returns the transport's local address as a string; the exact format is
/// implementation-defined as far as this file shows.
fn get_local_address(&self) -> String;
/// Passes `node_info` to the transport.
/// NOTE(review): presumably this propagates the local node's updated info
/// to connected peers — confirm against implementations.
async fn update_peers(&self, node_info: NodeInfo) -> Result<(), NetworkError>;
/// Returns the `EnvelopeCrypto` implementation associated with this transport.
fn keystore(&self) -> Arc<dyn runar_serializer::traits::EnvelopeCrypto>;
/// Returns the `LabelResolver` associated with this transport.
fn label_resolver(&self) -> Arc<dyn runar_serializer::traits::LabelResolver>;
}
/// Errors returned by the network transport layer.
///
/// Each variant carries a free-form message; the `Display` output is the
/// variant's prefix followed by that message.
#[derive(Error, Debug)]
pub enum NetworkError {
/// A connection-related failure.
#[error("Connection error: {0}")]
ConnectionError(String),
/// A message-related failure.
#[error("Message error: {0}")]
MessageError(String),
/// A discovery-related failure.
#[error("Discovery error: {0}")]
DiscoveryError(String),
/// A transport-level failure.
#[error("Transport error: {0}")]
TransportError(String),
/// A configuration-related failure.
#[error("Configuration error: {0}")]
ConfigurationError(String),
}