use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, info};
use crate::config::SrxConfig;
use crate::error::{Result, SessionError, SrxError};
use crate::pipeline::Payload;
use crate::session::{Handshake, Session};
use crate::signaling::inband::Signal;
use crate::transport::TcpTransport;
use crate::{SrxNode, TransportManager};
/// An SRX session carried over a single TCP connection.
///
/// Instances are created by `SecureTcpSession::connect` (initiator side) or
/// `SecureTcpSession::accept` (responder side); both run the handshake over the
/// transport before returning the session.
pub struct SecureTcpSession {
/// Framed transport over the TCP stream; used for every send and receive.
transport: Arc<TcpTransport>,
/// Node built from the handshake result. It is locked while outgoing payloads or
/// signals are turned into envelopes and while incoming envelopes are processed.
node: Mutex<SrxNode>,
}
impl SecureTcpSession {
pub async fn connect(addr: SocketAddr, config: SrxConfig) -> Result<Self> {
info!("SRX: connecting to {addr}");
let stream = tokio::net::TcpStream::connect(addr).await?;
debug!("SRX: TCP connection established to {addr}");
let transport = Arc::new(TcpTransport::from_stream(stream));
Self::connect_over_transport(transport, config).await
}
pub async fn accept(stream: tokio::net::TcpStream, config: SrxConfig) -> Result<Self> {
debug!("SRX: accepting incoming session");
let transport = Arc::new(TcpTransport::from_stream(stream));
Self::accept_over_transport(transport, config).await
}
async fn connect_over_transport(
transport: Arc<TcpTransport>,
config: SrxConfig,
) -> Result<Self> {
debug!("SRX handshake: initiating as client");
let mut hs = Handshake::new_initiator();
let client_hello = hs.client_hello()?;
debug!("SRX handshake: sending ClientHello ({} bytes)", client_hello.len());
transport.send_framed(&client_hello).await?;
debug!("SRX handshake: waiting for ServerHello");
let server_hello = transport.recv_framed().await?;
debug!("SRX handshake: received ServerHello ({} bytes)", server_hello.len());
let client_finished = hs.finalize(&server_hello)?;
debug!("SRX handshake: sending ClientFinished ({} bytes)", client_finished.len());
transport.send_framed(&client_finished).await?;
let master = hs.master_secret().ok_or_else(|| {
SrxError::Session(SessionError::HandshakeFailed(
"master secret not established".to_string(),
))
})?;
let session = session_from_client_hello(false, master, &client_hello)?;
let aead = Arc::new(config.build_aead_pipeline(&session.data_key)?);
let node = SrxNode::from_session(config, session, aead, TransportManager::new())?;
info!("SRX handshake: session established (initiator)");
Ok(Self {
transport,
node: Mutex::new(node),
})
}
async fn accept_over_transport(
transport: Arc<TcpTransport>,
config: SrxConfig,
) -> Result<Self> {
debug!("SRX handshake: initiating as server (responder)");
let mut hs = Handshake::new_responder();
debug!("SRX handshake: waiting for ClientHello");
let client_hello = transport.recv_framed().await?;
debug!("SRX handshake: received ClientHello ({} bytes)", client_hello.len());
let server_hello = hs.server_hello(&client_hello)?;
debug!("SRX handshake: sending ServerHello ({} bytes)", server_hello.len());
transport.send_framed(&server_hello).await?;
debug!("SRX handshake: waiting for ClientFinished");
let client_finished = transport.recv_framed().await?;
debug!("SRX handshake: received ClientFinished ({} bytes)", client_finished.len());
hs.server_finish(&client_finished)?;
let master = hs.master_secret().ok_or_else(|| {
SrxError::Session(SessionError::HandshakeFailed(
"master secret not established".to_string(),
))
})?;
let session = session_from_client_hello(true, master, &client_hello)?;
let aead = Arc::new(config.build_aead_pipeline(&session.data_key)?);
let node = SrxNode::from_session(config, session, aead, TransportManager::new())?;
info!("SRX handshake: session established (responder)");
Ok(Self {
transport,
node: Mutex::new(node),
})
}
pub async fn send_data(&self, payload: &[u8]) -> Result<()> {
let envelope = {
let mut node = self.node.lock().await;
node.prepare_outgoing(payload)?
};
self.transport.send_framed(&envelope).await?;
Ok(())
}
pub async fn send_signal(&self, signal: &Signal) -> Result<()> {
let envelope = {
let mut node = self.node.lock().await;
node.send_signal(signal)?
};
self.transport.send_framed(&envelope).await?;
Ok(())
}
pub async fn recv(&self) -> Result<Payload> {
let envelope = self.transport.recv_framed().await?;
let node = self.node.lock().await;
node.process_incoming_dispatched(&envelope)
}
}
/// Builds a [`Session`] from the handshake master secret and the raw ClientHello.
///
/// The timestamp and nonce are taken from the ClientHello via `parse_client_hello`.
/// The session id passed to `Session::from_master_secret` depends only on the local
/// role: `1` when `is_server` is true, `2` otherwise.
///
/// # Errors
/// Propagates any error from parsing the ClientHello or from
/// `Session::from_master_secret`.
fn session_from_client_hello(
    is_server: bool,
    master_secret: [u8; 32],
    client_hello: &[u8],
) -> Result<Session> {
    let role_id = if is_server { 1 } else { 2 };
    parse_client_hello(client_hello).and_then(|(hello_timestamp, hello_nonce)| {
        Session::from_master_secret(role_id, &master_secret, hello_timestamp, &hello_nonce)
    })
}
fn parse_client_hello(client_hello: &[u8]) -> Result<(u64, [u8; 16])> {
if client_hello.len() < 10 {
return Err(SrxError::Session(SessionError::HandshakeFailed(
"ClientHello too short".to_string(),
)));
}
if &client_hello[0..4] != b"SRXH" {
return Err(SrxError::Session(SessionError::HandshakeFailed(
"Invalid magic in ClientHello".to_string(),
)));
}
if client_hello[5] != 1 {
return Err(SrxError::Session(SessionError::HandshakeFailed(
"Invalid message type in ClientHello".to_string(),
)));
}
let payload_len = u32::from_be_bytes(client_hello[6..10].try_into().expect("len")) as usize;
if client_hello.len() != 10 + payload_len || payload_len != 24 {
return Err(SrxError::Session(SessionError::HandshakeFailed(
"Invalid ClientHello payload length".to_string(),
)));
}
let payload = &client_hello[10..];
let timestamp = u64::from_be_bytes(payload[0..8].try_into().expect("timestamp"));
let mut nonce = [0u8; 16];
nonce.copy_from_slice(&payload[8..24]);
Ok((timestamp, nonce))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration with `replay.persist_enabled` switched off, as used by
    /// both ends of the round-trip test.
    fn test_config() -> SrxConfig {
        let mut config = SrxConfig::default();
        config.replay.persist_enabled = false;
        config
    }

    /// End-to-end check over loopback: the initiator sends `ping`, the responder
    /// receives it as data and answers `pong`, which the initiator then receives.
    #[tokio::test]
    async fn secure_tcp_session_roundtrip() {
        // Port 0 lets the OS pick a free port; the real address is read back below.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        let responder = tokio::spawn(async move {
            let (stream, _peer) = listener.accept().await.unwrap();
            let session = SecureTcpSession::accept(stream, test_config())
                .await
                .unwrap();
            if let Payload::Data(body) = session.recv().await.unwrap() {
                assert_eq!(body, b"ping");
                session.send_data(b"pong").await.unwrap();
            } else {
                panic!("expected data");
            }
        });

        let session = SecureTcpSession::connect(server_addr, test_config())
            .await
            .unwrap();
        session.send_data(b"ping").await.unwrap();
        if let Payload::Data(body) = session.recv().await.unwrap() {
            assert_eq!(body, b"pong");
        } else {
            panic!("expected data");
        }

        // Surface any assertion failure or panic from the responder task.
        responder.await.unwrap();
    }
}