use crate::protocol;
use crate::tls::TlsConfig;
use rustls::pki_types::ServerName;
use rustls::{ClientConnection, ServerConnection, StreamOwned};
use std::collections::HashMap;
use std::io::{self, BufWriter, Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::Arc;
use std::time::Duration;
use tracing::{info, warn};
use zamsync_core::ports::Transport;
use zamsync_core::{NodeId, SyncMessage, ZamError, ZamResult};
/// A TLS-wrapped TCP stream in either role, so that accepted (`accept_any`)
/// and dialed (`connect`) connections can be stored and used through one type.
enum TlsStream {
    /// Connection accepted by this node; we act as the TLS server.
    Server(StreamOwned<ServerConnection, TcpStream>),
    /// Connection dialed by this node; we act as the TLS client.
    Client(StreamOwned<ClientConnection, TcpStream>),
}
impl TlsStream {
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
match self {
TlsStream::Server(s) => s.get_ref().set_read_timeout(dur),
TlsStream::Client(s) => s.get_ref().set_read_timeout(dur),
}
}
}
impl Read for TlsStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
TlsStream::Server(s) => s.read(buf),
TlsStream::Client(s) => s.read(buf),
}
}
}
impl Write for TlsStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
TlsStream::Server(s) => s.write(buf),
TlsStream::Client(s) => s.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
TlsStream::Server(s) => s.flush(),
TlsStream::Client(s) => s.flush(),
}
}
}
/// State kept by `TlsTcpTransport` for one connected peer.
struct TlsPeerConn {
    /// TLS stream to the peer: the server role if accepted, the client role if dialed.
    stream: TlsStream,
    /// Passed to `try_read_frame` together with `stream` in `receive`;
    /// presumably holds partially received frames between calls — TODO confirm.
    frame_buf: crate::protocol::FrameBuffer,
    /// A message already decoded but not yet returned to the caller. `accept_any`
    /// stores the peer's initial `Handshake` here, and `receive` drains it
    /// before reading from `stream`.
    pending: Option<SyncMessage>,
}
/// TLS-over-TCP implementation of `Transport` that keeps one connection per peer.
///
/// Connections are either accepted from the listener (`accept_any`) or dialed
/// (`connect`), and are keyed by the peer's raw node id.
pub struct TlsTcpTransport {
    /// Listener for inbound connections. `bind` puts it in non-blocking mode,
    /// and `accept_any` temporarily switches it to blocking mode to wait.
    listener: TcpListener,
    /// rustls configuration for connections this node accepts.
    server_config: Arc<rustls::ServerConfig>,
    /// rustls configuration for connections this node dials.
    client_config: Arc<rustls::ClientConfig>,
    /// Open connections keyed by `NodeId.0`.
    peers: HashMap<u32, TlsPeerConn>,
}
/// Server name passed to rustls for every outbound connection. The same name
/// is used regardless of the address dialed, so peers' certificates are
/// presumably issued for this name — confirm against `crate::tls`.
const TLS_SERVER_NAME: &str = "zamsync.local";
impl TlsTcpTransport {
    /// Read timeout used while a TLS handshake is in progress (and, when
    /// accepting, while waiting for the peer's first message).
    const HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(5_000);

    /// Read timeout applied once a connection is registered. It bounds how
    /// long a single read in `receive` can block on an idle peer.
    const POLL_TIMEOUT: Duration = Duration::from_millis(50);

    /// Binds a non-blocking TCP listener on `addr` and builds the TLS server
    /// and client configurations from `config`.
    ///
    /// # Errors
    /// Fails if either TLS configuration cannot be built, or if the address
    /// cannot be bound or switched to non-blocking mode.
    pub fn bind(addr: &str, config: &TlsConfig) -> ZamResult<Self> {
        crate::tls::install_crypto_provider();
        let server_config = config.server_config()?;
        let client_config = config.client_config()?;
        let listener = TcpListener::bind(addr)?;
        listener.set_nonblocking(true)?;
        info!("TLS listener on {}", addr);
        Ok(Self {
            listener,
            server_config,
            client_config,
            peers: HashMap::new(),
        })
    }

    /// Returns the address the listener is actually bound to, which is useful
    /// after binding to port 0.
    pub fn local_addr(&self) -> ZamResult<SocketAddr> {
        Ok(self.listener.local_addr()?)
    }

    /// Blocks until one inbound connection arrives, completes the TLS handshake,
    /// and registers the peer under the `node_id` of its first message. That
    /// message must be a `SyncMessage::Handshake`.
    ///
    /// The handshake message is kept as the peer's pending message so that
    /// `receive` still delivers it. An existing connection registered under
    /// the same id is replaced.
    ///
    /// # Errors
    /// Fails if accepting fails, if the TLS handshake or first read fails or
    /// times out, or if the first message is not a `Handshake`.
    pub fn accept_any(&mut self) -> ZamResult<NodeId> {
        // Wait for a connection in blocking mode. Restore non-blocking mode
        // *before* looking at the result, so a failed accept does not leave the
        // listener stuck in blocking mode.
        self.listener.set_nonblocking(false)?;
        let accepted = self.listener.accept();
        self.listener.set_nonblocking(true)?;
        let (tcp, addr) = accepted?;

        // The handshake is driven implicitly by the first read below, so it
        // runs under the generous handshake timeout.
        tcp.set_read_timeout(Some(Self::HANDSHAKE_TIMEOUT))?;
        let conn = ServerConnection::new(Arc::clone(&self.server_config))
            .map_err(|e| ZamError::Config(format!("TLS server init: {e}")))?;
        let mut tls = TlsStream::Server(StreamOwned::new(conn, tcp));
        let msg = protocol::decode(&mut tls)?;
        let node_id = match &msg {
            SyncMessage::Handshake { node_id, .. } => *node_id,
            other => {
                warn!(?other, "expected Handshake as first TLS message");
                return Err(ZamError::Protocol(
                    "first message from TLS peer must be a Handshake".into(),
                ));
            }
        };

        // From now on reads are short polls driven by `receive`.
        tls.set_read_timeout(Some(Self::POLL_TIMEOUT))?;
        let previous = self.peers.insert(
            node_id.0,
            TlsPeerConn {
                stream: tls,
                frame_buf: crate::protocol::FrameBuffer::new(),
                pending: Some(msg),
            },
        );
        if previous.is_some() {
            warn!(peer = node_id.0, "replaced existing TLS connection to peer");
        }
        info!(peer = node_id.0, %addr, "TLS peer accepted");
        Ok(node_id)
    }

    /// Drops the connection to `peer_id`, if any.
    pub fn disconnect(&mut self, peer_id: NodeId) {
        self.peers.remove(&peer_id.0);
    }

    /// Dials `addr`, completes the TLS client handshake (using
    /// `TLS_SERVER_NAME` as the expected server name), and registers the
    /// connection under `peer_id`. An existing connection registered under
    /// the same id is replaced.
    ///
    /// No application message is sent here.
    ///
    /// # Errors
    /// Fails if the TCP connection cannot be opened, or if the TLS client
    /// cannot be initialised. Also fails if the handshake fails or does not
    /// progress within the handshake timeout.
    pub fn connect(&mut self, peer_id: NodeId, addr: &str) -> ZamResult<()> {
        crate::tls::install_crypto_provider();
        let mut tcp = TcpStream::connect(addr)?;
        // Complete the handshake here, under the same timeout the accepting
        // side uses. Deferring it to the first `send` would run it under the
        // short polling timeout, which fails whenever the round trip exceeds
        // it.
        tcp.set_read_timeout(Some(Self::HANDSHAKE_TIMEOUT))?;
        let server_name = ServerName::try_from(TLS_SERVER_NAME)
            .map_err(|e| ZamError::Config(format!("invalid TLS server name: {e}")))?
            .to_owned();
        let mut conn = ClientConnection::new(Arc::clone(&self.client_config), server_name)
            .map_err(|e| ZamError::Config(format!("TLS client init: {e}")))?;
        while conn.is_handshaking() {
            conn.complete_io(&mut tcp)?;
        }
        tcp.set_read_timeout(Some(Self::POLL_TIMEOUT))?;

        let tls = TlsStream::Client(StreamOwned::new(conn, tcp));
        let previous = self.peers.insert(
            peer_id.0,
            TlsPeerConn {
                stream: tls,
                frame_buf: crate::protocol::FrameBuffer::new(),
                pending: None,
            },
        );
        if previous.is_some() {
            warn!(peer = peer_id.0, "replaced existing TLS connection to peer");
        }
        info!(peer = peer_id.0, addr, "TLS connection established");
        Ok(())
    }

    /// Number of currently registered peer connections.
    pub fn peer_count(&self) -> usize {
        self.peers.len()
    }
}
impl Transport for TlsTcpTransport {
fn send(&mut self, peer_id: NodeId, message: &SyncMessage) -> ZamResult<()> {
let peer = self.peers.get_mut(&peer_id.0).ok_or_else(|| {
ZamError::Protocol(format!("no TLS connection to peer {}", peer_id.0))
})?;
let mut writer = BufWriter::new(&mut peer.stream);
protocol::encode(message, &mut writer)
}
fn receive(&mut self) -> ZamResult<Option<(NodeId, SyncMessage)>> {
let peer_ids: Vec<u32> = self.peers.keys().cloned().collect();
for peer_id_raw in peer_ids {
if let Some(peer) = self.peers.get_mut(&peer_id_raw) {
if let Some(msg) = peer.pending.take() {
return Ok(Some((NodeId(peer_id_raw), msg)));
}
match peer.frame_buf.try_read_frame(&mut peer.stream) {
Ok(Some(bytes)) => {
let msg = rkyv::from_bytes::<SyncMessage>(&bytes)
.map_err(|e| ZamError::Serialization(format!("{}", e)))?;
return Ok(Some((NodeId(peer_id_raw), msg)));
}
Ok(None) => continue,
Err(e) => return Err(e),
}
}
}
Ok(None)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tls::generate_credentials;
    use std::thread;
    use std::time::{Duration, Instant};
    use zamsync_core::{NodeId, SyncMessage, VersionVector};

    /// Longest a test waits for a message before failing, so a broken exchange
    /// fails loudly instead of hanging the test run.
    const TEST_DEADLINE: Duration = Duration::from_secs(10);

    /// Polls `transport.receive()` until a message arrives, panicking if none
    /// arrives within `TEST_DEADLINE`.
    fn receive_within_deadline(transport: &mut TlsTcpTransport) -> (NodeId, SyncMessage) {
        let deadline = Instant::now() + TEST_DEADLINE;
        loop {
            if let Some(pair) = transport.receive().unwrap() {
                return pair;
            }
            assert!(
                Instant::now() < deadline,
                "timed out waiting for a TLS message"
            );
            thread::sleep(Duration::from_millis(10));
        }
    }

    #[test]
    fn test_tls_handshake_and_message_exchange() {
        crate::tls::install_crypto_provider();
        let creds = generate_credentials().expect("keygen");
        let server_config = TlsConfig::from_pem(
            creds.node_cert_pem.clone(),
            creds.node_key_pem.clone(),
            creds.ca_cert_pem.clone(),
        );
        let client_config = TlsConfig::from_pem(
            creds.node_cert_pem.clone(),
            creds.node_key_pem.clone(),
            creds.ca_cert_pem.clone(),
        );
        let mut server = TlsTcpTransport::bind("127.0.0.1:0", &server_config).unwrap();
        let server_addr = server.local_addr().unwrap().to_string();

        let client_handle = thread::spawn(move || {
            let mut client = TlsTcpTransport::bind("127.0.0.1:0", &client_config).unwrap();
            let server_id = NodeId(99);
            client.connect(server_id, &server_addr).unwrap();
            let hs = SyncMessage::Handshake {
                node_id: NodeId(1),
                vv: VersionVector::new(),
            };
            client.send(server_id, &hs).unwrap();
            receive_within_deadline(&mut client).1
        });

        let peer_id = server.accept_any().unwrap();
        assert_eq!(peer_id.0, 1);
        let (from, msg) = receive_within_deadline(&mut server);
        assert_eq!(from.0, 1);
        assert!(matches!(msg, SyncMessage::Handshake { .. }));

        server.send(peer_id, &SyncMessage::SyncComplete).unwrap();
        let reply = client_handle.join().unwrap();
        assert!(matches!(reply, SyncMessage::SyncComplete));
    }

    #[test]
    fn test_tls_rejects_untrusted_client() {
        crate::tls::install_crypto_provider();
        let creds_a = generate_credentials().expect("keygen A");
        let server_config = TlsConfig::from_pem(
            creds_a.node_cert_pem.clone(),
            creds_a.node_key_pem.clone(),
            creds_a.ca_cert_pem.clone(),
        );
        // The client trusts credential set A's CA, so it accepts the server.
        // It presents credential set B's certificate, which the server is
        // expected to refuse.
        let creds_b = generate_credentials().expect("keygen B");
        let client_config = TlsConfig::from_pem(
            creds_b.node_cert_pem,
            creds_b.node_key_pem,
            creds_a.ca_cert_pem.clone(),
        );
        let mut server = TlsTcpTransport::bind("127.0.0.1:0", &server_config).unwrap();
        let server_addr = server.local_addr().unwrap().to_string();

        let client_handle = thread::spawn(move || {
            let mut client = TlsTcpTransport::bind("127.0.0.1:0", &client_config).unwrap();
            let server_id = NodeId(99);
            let hs = SyncMessage::Handshake {
                node_id: NodeId(2),
                vv: VersionVector::new(),
            };
            // Depending on the negotiated TLS version, the rejection may
            // surface while connecting or only on the first send. Either
            // counts as the client observing the failure.
            client
                .connect(server_id, &server_addr)
                .and_then(|()| client.send(server_id, &hs))
        });

        let server_result = server.accept_any();
        let client_result = client_handle.join().unwrap();
        assert!(
            server_result.is_err() || client_result.is_err(),
            "expected TLS rejection but both sides succeeded"
        );
    }
}