Skip to main content

srx/
high_api.rs

1//! High-level SRX API: secure framed TCP session with automatic handshake.
2//!
3//! This API hides handshake/session bootstrapping and provides simple
4//! send/receive primitives for application payloads and in-band signals.
5
6use std::net::SocketAddr;
7use std::sync::Arc;
8
9use tokio::sync::Mutex;
10use tracing::{debug, info};
11
12use crate::config::SrxConfig;
13use crate::error::{Result, SessionError, SrxError};
14use crate::pipeline::Payload;
15use crate::session::{Handshake, Session};
16use crate::signaling::inband::Signal;
17use crate::transport::TcpTransport;
18use crate::{SrxNode, TransportManager};
19
20/// Secure framed TCP session built on SRX.
21pub struct SecureTcpSession {
22    transport: Arc<TcpTransport>,
23    node: Mutex<SrxNode>,
24}
25
26impl SecureTcpSession {
27    /// Connect as initiator and establish secure SRX session over TCP.
28    pub async fn connect(addr: SocketAddr, config: SrxConfig) -> Result<Self> {
29        info!("SRX: connecting to {addr}");
30        let stream = tokio::net::TcpStream::connect(addr).await?;
31        debug!("SRX: TCP connection established to {addr}");
32        let transport = Arc::new(TcpTransport::from_stream(stream));
33        Self::connect_over_transport(transport, config).await
34    }
35
36    /// Accept as responder and establish secure SRX session over accepted stream.
37    pub async fn accept(stream: tokio::net::TcpStream, config: SrxConfig) -> Result<Self> {
38        debug!("SRX: accepting incoming session");
39        let transport = Arc::new(TcpTransport::from_stream(stream));
40        Self::accept_over_transport(transport, config).await
41    }
42
43    async fn connect_over_transport(
44        transport: Arc<TcpTransport>,
45        config: SrxConfig,
46    ) -> Result<Self> {
47        debug!("SRX handshake: initiating as client");
48        let mut hs = Handshake::new_initiator();
49        let client_hello = hs.client_hello()?;
50        debug!("SRX handshake: sending ClientHello ({} bytes)", client_hello.len());
51        transport.send_framed(&client_hello).await?;
52        debug!("SRX handshake: waiting for ServerHello");
53        let server_hello = transport.recv_framed().await?;
54        debug!("SRX handshake: received ServerHello ({} bytes)", server_hello.len());
55        let client_finished = hs.finalize(&server_hello)?;
56        debug!("SRX handshake: sending ClientFinished ({} bytes)", client_finished.len());
57        transport.send_framed(&client_finished).await?;
58        let master = hs.master_secret().ok_or_else(|| {
59            SrxError::Session(SessionError::HandshakeFailed(
60                "master secret not established".to_string(),
61            ))
62        })?;
63
64        let session = session_from_client_hello(false, master, &client_hello)?;
65        let aead = Arc::new(config.build_aead_pipeline(&session.data_key)?);
66        let node = SrxNode::from_session(config, session, aead, TransportManager::new())?;
67        info!("SRX handshake: session established (initiator)");
68        Ok(Self {
69            transport,
70            node: Mutex::new(node),
71        })
72    }
73
74    async fn accept_over_transport(
75        transport: Arc<TcpTransport>,
76        config: SrxConfig,
77    ) -> Result<Self> {
78        debug!("SRX handshake: initiating as server (responder)");
79        let mut hs = Handshake::new_responder();
80        debug!("SRX handshake: waiting for ClientHello");
81        let client_hello = transport.recv_framed().await?;
82        debug!("SRX handshake: received ClientHello ({} bytes)", client_hello.len());
83        let server_hello = hs.server_hello(&client_hello)?;
84        debug!("SRX handshake: sending ServerHello ({} bytes)", server_hello.len());
85        transport.send_framed(&server_hello).await?;
86        debug!("SRX handshake: waiting for ClientFinished");
87        let client_finished = transport.recv_framed().await?;
88        debug!("SRX handshake: received ClientFinished ({} bytes)", client_finished.len());
89        hs.server_finish(&client_finished)?;
90        let master = hs.master_secret().ok_or_else(|| {
91            SrxError::Session(SessionError::HandshakeFailed(
92                "master secret not established".to_string(),
93            ))
94        })?;
95
96        let session = session_from_client_hello(true, master, &client_hello)?;
97        let aead = Arc::new(config.build_aead_pipeline(&session.data_key)?);
98        let node = SrxNode::from_session(config, session, aead, TransportManager::new())?;
99        info!("SRX handshake: session established (responder)");
100        Ok(Self {
101            transport,
102            node: Mutex::new(node),
103        })
104    }
105
106    /// Send application data over secure channel.
107    pub async fn send_data(&self, payload: &[u8]) -> Result<()> {
108        let envelope = {
109            let mut node = self.node.lock().await;
110            node.prepare_outgoing(payload)?
111        };
112        self.transport.send_framed(&envelope).await?;
113        Ok(())
114    }
115
116    /// Send in-band control signal over secure channel.
117    pub async fn send_signal(&self, signal: &Signal) -> Result<()> {
118        let envelope = {
119            let mut node = self.node.lock().await;
120            node.send_signal(signal)?
121        };
122        self.transport.send_framed(&envelope).await?;
123        Ok(())
124    }
125
126    /// Receive next encrypted frame and decode it into data or signal.
127    pub async fn recv(&self) -> Result<Payload> {
128        let envelope = self.transport.recv_framed().await?;
129        let node = self.node.lock().await;
130        node.process_incoming_dispatched(&envelope)
131    }
132}
133
134fn session_from_client_hello(
135    is_server: bool,
136    master_secret: [u8; 32],
137    client_hello: &[u8],
138) -> Result<Session> {
139    let (timestamp, nonce) = parse_client_hello(client_hello)?;
140    let session_id = if is_server { 1 } else { 2 };
141    Session::from_master_secret(session_id, &master_secret, timestamp, &nonce)
142}
143
144fn parse_client_hello(client_hello: &[u8]) -> Result<(u64, [u8; 16])> {
145    if client_hello.len() < 10 {
146        return Err(SrxError::Session(SessionError::HandshakeFailed(
147            "ClientHello too short".to_string(),
148        )));
149    }
150    if &client_hello[0..4] != b"SRXH" {
151        return Err(SrxError::Session(SessionError::HandshakeFailed(
152            "Invalid magic in ClientHello".to_string(),
153        )));
154    }
155    if client_hello[5] != 1 {
156        return Err(SrxError::Session(SessionError::HandshakeFailed(
157            "Invalid message type in ClientHello".to_string(),
158        )));
159    }
160    let payload_len = u32::from_be_bytes(client_hello[6..10].try_into().expect("len")) as usize;
161    if client_hello.len() != 10 + payload_len || payload_len != 24 {
162        return Err(SrxError::Session(SessionError::HandshakeFailed(
163            "Invalid ClientHello payload length".to_string(),
164        )));
165    }
166
167    let payload = &client_hello[10..];
168    let timestamp = u64::from_be_bytes(payload[0..8].try_into().expect("timestamp"));
169    let mut nonce = [0u8; 16];
170    nonce.copy_from_slice(&payload[8..24]);
171    Ok((timestamp, nonce))
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// `SrxConfig::default()` with `replay.persist_enabled` switched off, used
    /// by both ends of the roundtrip test.
    fn test_config() -> SrxConfig {
        let mut config = SrxConfig::default();
        config.replay.persist_enabled = false;
        config
    }

    /// Handshake over loopback, then exchange one data frame in each direction.
    #[tokio::test]
    async fn secure_tcp_session_roundtrip() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Responder: expects "ping" and answers with "pong".
        let server_task = tokio::spawn(async move {
            let (stream, _peer) = listener.accept().await.unwrap();
            let responder = SecureTcpSession::accept(stream, test_config())
                .await
                .unwrap();
            match responder.recv().await.unwrap() {
                Payload::Signal(_) => panic!("expected data"),
                Payload::Data(request) => {
                    assert_eq!(request, b"ping");
                    responder.send_data(b"pong").await.unwrap();
                }
            }
        });

        // Initiator: sends "ping" and expects "pong" back.
        let initiator = SecureTcpSession::connect(addr, test_config()).await.unwrap();
        initiator.send_data(b"ping").await.unwrap();
        match initiator.recv().await.unwrap() {
            Payload::Signal(_) => panic!("expected data"),
            Payload::Data(reply) => assert_eq!(reply, b"pong"),
        }

        server_task.await.unwrap();
    }
}