//! zamsync-network 1.2.2
//!
//! TCP and mTLS transport for the ZamSync distributed sync engine.
use crate::protocol;
use crate::tls::TlsConfig;
use rustls::pki_types::ServerName;
use rustls::{ClientConnection, ServerConnection, StreamOwned};
use std::collections::HashMap;
use std::io::{self, BufWriter, Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::Arc;
use std::time::Duration;
use tracing::{info, warn};
use zamsync_core::ports::Transport;
use zamsync_core::{NodeId, SyncMessage, ZamError, ZamResult};

use super::peer::TlsPeerTransport;

// ---- TLS stream wrapper -----------------------------------------------------

/// Either end of an mTLS session behind a single `Read + Write` type.
///
/// `Server` wraps a socket obtained from the listener, `Client` wraps one opened
/// by an outbound connect. Visibility is `pub(super)` so `peer.rs` can hold one
/// as a field type without the type becoming part of the public API.
pub(super) enum TlsStream {
    Server(StreamOwned<ServerConnection, TcpStream>),
    Client(StreamOwned<ClientConnection, TcpStream>),
}

impl TlsStream {
    /// Applies `dur` as the read timeout of the underlying TCP socket.
    pub(super) fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
        let socket: &TcpStream = match self {
            Self::Server(inner) => inner.get_ref(),
            Self::Client(inner) => inner.get_ref(),
        };
        socket.set_read_timeout(dur)
    }
}

impl Read for TlsStream {
    /// Reads decrypted application data from whichever side this stream is.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self {
            Self::Server(inner) => inner.read(buf),
            Self::Client(inner) => inner.read(buf),
        }
    }
}

impl Write for TlsStream {
    /// Hands plaintext to the TLS session for encryption and transmission.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self {
            Self::Server(inner) => inner.write(buf),
            Self::Client(inner) => inner.write(buf),
        }
    }

    /// Flushes the TLS session and its socket.
    fn flush(&mut self) -> io::Result<()> {
        match self {
            Self::Server(inner) => inner.flush(),
            Self::Client(inner) => inner.flush(),
        }
    }
}

// ---- Internal per-connection state ------------------------------------------

/// One established peer connection owned by `TlsTcpTransport`.
struct TlsPeerConn {
    /// Encrypted stream to the peer.
    stream: TlsStream,
    /// Per-connection frame buffer fed by `try_read_frame` (presumably holds
    /// partially received frames between polls).
    frame_buf: protocol::FrameBuffer,
    /// A message already consumed from the wire (the initial `Handshake` on the
    /// accept path) that `receive` hands out before reading anything new.
    pending: Option<SyncMessage>,
}

impl TlsPeerConn {
    /// Pairs `stream` with a fresh frame buffer and an optional queued message.
    fn new(stream: TlsStream, pending: Option<SyncMessage>) -> Self {
        let frame_buf = protocol::FrameBuffer::new();
        TlsPeerConn {
            stream,
            frame_buf,
            pending,
        }
    }
}

// ---- TlsTcpTransport --------------------------------------------------------

/// SNI host name sent by every outbound ZamSync TLS connection.
///
/// Node certificates produced by `zamsync keygen` must list this exact name as
/// a SAN; it is the same for every node in a deployment.
const TLS_SERVER_NAME: &str = "zamsync.local";

/// Mutual-TLS (mTLS) transport running over plain TCP sockets.
///
/// Every session authenticates both ends against the shared deployment CA: a
/// node that cannot present a valid certificate is turned away during the TLS
/// handshake, before any ZamSync protocol traffic is exchanged. Credentials are
/// usually loaded with [`TlsConfig::from_files`].
pub struct TlsTcpTransport {
    /// Listening socket; kept non-blocking except while an accept is in progress.
    listener: TcpListener,
    /// Shared rustls configuration for accepted (server-side) sessions.
    server_config: Arc<rustls::ServerConfig>,
    /// Shared rustls configuration for outbound (client-side) sessions.
    client_config: Arc<rustls::ClientConfig>,
    /// Registered peers, keyed by the raw `NodeId` value.
    peers: HashMap<u32, TlsPeerConn>,
}

impl TlsTcpTransport {
    pub fn bind(addr: &str, config: &TlsConfig) -> ZamResult<Self> {
        crate::tls::install_crypto_provider();
        let server_config = config.server_config()?;
        let client_config = config.client_config()?;
        let listener = TcpListener::bind(addr)?;
        listener.set_nonblocking(true)?;
        info!("TLS listener on {}", addr);
        Ok(Self {
            listener,
            server_config,
            client_config,
            peers: HashMap::new(),
        })
    }

    pub fn local_addr(&self) -> ZamResult<SocketAddr> {
        Ok(self.listener.local_addr()?)
    }

    fn raw_accept(&mut self) -> ZamResult<(TcpStream, SocketAddr)> {
        self.listener.set_nonblocking(false)?;
        let result = self.listener.accept()?;
        self.listener.set_nonblocking(true)?;
        Ok(result)
    }

    fn make_server_stream(&self, tcp: TcpStream) -> ZamResult<TlsStream> {
        let conn = ServerConnection::new(Arc::clone(&self.server_config))
            .map_err(|e| ZamError::Config(format!("TLS server init: {e}")))?;
        Ok(TlsStream::Server(StreamOwned::new(conn, tcp)))
    }

    /// Blocking accept: reads the initial Handshake to discover peer NodeId.
    /// Performs TLS handshake + mTLS client-cert verification transparently.
    pub fn accept_any(&mut self) -> ZamResult<NodeId> {
        let (tcp, addr) = self.raw_accept()?;
        tcp.set_read_timeout(Some(Duration::from_millis(5_000)))?;

        let mut tls = self.make_server_stream(tcp)?;
        let msg = protocol::decode(&mut tls)?;
        let node_id = match &msg {
            SyncMessage::Handshake { node_id, .. } => *node_id,
            other => {
                warn!(?other, "expected Handshake as first TLS message");
                return Err(ZamError::Protocol(
                    "first message from TLS peer must be a Handshake".into(),
                ));
            }
        };

        tls.set_read_timeout(Some(Duration::from_millis(50)))?;
        self.peers
            .insert(node_id.0, TlsPeerConn::new(tls, Some(msg)));
        info!(peer = node_id.0, %addr, "TLS peer accepted");
        Ok(node_id)
    }

    pub fn disconnect(&mut self, peer_id: NodeId) {
        self.peers.remove(&peer_id.0);
    }

    pub fn connect(&mut self, peer_id: NodeId, addr: &str) -> ZamResult<()> {
        crate::tls::install_crypto_provider();
        let tcp = TcpStream::connect(addr)?;
        tcp.set_read_timeout(Some(Duration::from_millis(50)))?;

        let server_name = ServerName::try_from(TLS_SERVER_NAME)
            .map_err(|e| ZamError::Config(format!("invalid TLS server name: {e}")))?
            .to_owned();

        let conn = ClientConnection::new(Arc::clone(&self.client_config), server_name)
            .map_err(|e| ZamError::Config(format!("TLS client init: {e}")))?;

        let tls = TlsStream::Client(StreamOwned::new(conn, tcp));
        self.peers.insert(peer_id.0, TlsPeerConn::new(tls, None));
        info!(peer = peer_id.0, addr, "TLS connection established");
        Ok(())
    }

    pub fn peer_count(&self) -> usize {
        self.peers.len()
    }

    /// Accepts one TLS connection and returns a self-contained per-peer transport.
    ///
    /// Unlike [`accept_any`], the connection is not stored in the internal
    /// HashMap. The returned [`TlsPeerTransport`] is `Send` and can be moved
    /// into a worker thread for concurrent hub serving.
    pub fn accept_split(&mut self) -> ZamResult<TlsPeerTransport> {
        let (tcp, addr) = self.raw_accept()?;
        tcp.set_read_timeout(Some(Duration::from_millis(5_000)))?;

        let mut tls = self.make_server_stream(tcp)?;
        let msg = protocol::decode(&mut tls)?;
        let node_id = match &msg {
            SyncMessage::Handshake { node_id, .. } => *node_id,
            other => {
                warn!(?other, "expected Handshake as first TLS message");
                return Err(ZamError::Protocol(
                    "first message from TLS peer must be a Handshake".into(),
                ));
            }
        };

        tls.set_read_timeout(Some(Duration::from_millis(50)))?;
        info!(peer = node_id.0, %addr, "TLS peer accepted (split mode)");
        Ok(TlsPeerTransport::new(node_id, tls, Some(msg)))
    }
}

impl Transport for TlsTcpTransport {
    fn send(&mut self, peer_id: NodeId, message: &SyncMessage) -> ZamResult<()> {
        let peer = self.peers.get_mut(&peer_id.0).ok_or_else(|| {
            ZamError::Protocol(format!("no TLS connection to peer {}", peer_id.0))
        })?;
        let mut writer = BufWriter::new(&mut peer.stream);
        protocol::encode(message, &mut writer)
    }

    fn receive(&mut self) -> ZamResult<Option<(NodeId, SyncMessage)>> {
        let peer_ids: Vec<u32> = self.peers.keys().cloned().collect();
        for peer_id_raw in peer_ids {
            if let Some(peer) = self.peers.get_mut(&peer_id_raw) {
                if let Some(msg) = peer.pending.take() {
                    return Ok(Some((NodeId(peer_id_raw), msg)));
                }
                match peer.frame_buf.try_read_frame(&mut peer.stream) {
                    Ok(Some(bytes)) => {
                        let msg = rkyv::from_bytes::<SyncMessage>(&bytes)
                            .map_err(|e| ZamError::Serialization(format!("{}", e)))?;
                        return Ok(Some((NodeId(peer_id_raw), msg)));
                    }
                    Ok(None) => continue,
                    Err(e) => return Err(e),
                }
            }
        }
        Ok(None)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tls::generate_credentials;
    use std::thread;
    use std::time::Duration;
    use zamsync_core::{NodeId, SyncMessage, VersionVector};

    /// Happy path: a client and a server built from the same credentials complete
    /// the mTLS handshake and exchange one message in each direction.
    #[test]
    fn test_tls_handshake_and_message_exchange() {
        crate::tls::install_crypto_provider();
        let creds = generate_credentials().expect("keygen");

        // Both ends present the same node certificate, issued by the same CA.
        let server_config = TlsConfig::from_pem(
            creds.node_cert_pem.clone(),
            creds.node_key_pem.clone(),
            creds.ca_cert_pem.clone(),
        );
        let client_config = TlsConfig::from_pem(
            creds.node_cert_pem.clone(),
            creds.node_key_pem.clone(),
            creds.ca_cert_pem.clone(),
        );

        // Port 0 lets the OS pick a free port; read it back for the client.
        let mut server = TlsTcpTransport::bind("127.0.0.1:0", &server_config).unwrap();
        let server_addr = server.local_addr().unwrap().to_string();

        // The client runs on its own thread because `accept_any` below blocks.
        let client_handle = thread::spawn(move || {
            let mut client = TlsTcpTransport::bind("127.0.0.1:0", &client_config).unwrap();
            // Only used as the key for this connection in the client's peer map.
            let server_id = NodeId(99);
            client.connect(server_id, &server_addr).unwrap();

            // The server requires a Handshake as the first application message.
            let hs = SyncMessage::Handshake {
                node_id: NodeId(1),
                vv: VersionVector::new(),
            };
            client.send(server_id, &hs).unwrap();

            // Poll until the server's reply arrives.
            // NOTE(review): no deadline -- if the reply never arrives, the test
            // hangs instead of failing.
            loop {
                if let Some((_id, msg)) = client.receive().unwrap() {
                    return msg;
                }
                thread::sleep(Duration::from_millis(10));
            }
        });

        // Blocks for the client's connection, TLS handshake and Handshake frame.
        let peer_id = server.accept_any().unwrap();
        assert_eq!(peer_id.0, 1);

        // `accept_any` queued the Handshake it consumed, so with a single peer
        // the first successful `receive` returns that same message.
        let (from, msg) = loop {
            if let Some(pair) = server.receive().unwrap() {
                break pair;
            }
        };
        assert_eq!(from.0, 1);
        assert!(matches!(msg, SyncMessage::Handshake { .. }));

        server.send(peer_id, &SyncMessage::SyncComplete).unwrap();

        let reply = client_handle.join().unwrap();
        assert!(matches!(reply, SyncMessage::SyncComplete));
    }

    /// A client whose certificate was issued by a CA the server does not trust
    /// must be refused. The side that observes the failure may vary, so the test
    /// only requires that at least one side reports an error.
    #[test]
    fn test_tls_rejects_untrusted_client() {
        crate::tls::install_crypto_provider();

        // Server: node cert from CA A, trusts only CA A.
        let creds_a = generate_credentials().expect("keygen A");
        let server_config = TlsConfig::from_pem(
            creds_a.node_cert_pem.clone(),
            creds_a.node_key_pem.clone(),
            creds_a.ca_cert_pem.clone(),
        );

        // Client: node cert from an unrelated CA B.
        let creds_b = generate_credentials().expect("keygen B");
        let client_config = TlsConfig::from_pem(
            creds_b.node_cert_pem,
            creds_b.node_key_pem,
            creds_a.ca_cert_pem.clone(), // trusts server CA but presents wrong cert
        );

        let mut server = TlsTcpTransport::bind("127.0.0.1:0", &server_config).unwrap();
        let server_addr = server.local_addr().unwrap().to_string();

        // The client thread returns its `send` result instead of unwrapping it,
        // so a rejection on the client side is observable below.
        let client_handle = thread::spawn(move || {
            let mut client = TlsTcpTransport::bind("127.0.0.1:0", &client_config).unwrap();
            let server_id = NodeId(99);
            client.connect(server_id, &server_addr).unwrap();

            let hs = SyncMessage::Handshake {
                node_id: NodeId(2),
                vv: VersionVector::new(),
            };
            client.send(server_id, &hs)
        });

        let server_result = server.accept_any();
        let client_result = client_handle.join().unwrap();

        assert!(
            server_result.is_err() || client_result.is_err(),
            "expected TLS rejection but both sides succeeded"
        );
    }
}