1use std::net::SocketAddr;
7use std::sync::Arc;
8
9use tokio::sync::Mutex;
10use tracing::{debug, info};
11
12use crate::config::SrxConfig;
13use crate::error::{Result, SessionError, SrxError};
14use crate::pipeline::Payload;
15use crate::session::{Handshake, Session};
16use crate::signaling::inband::Signal;
17use crate::transport::TcpTransport;
18use crate::{SrxNode, TransportManager};
19
20pub struct SecureTcpSession {
22 transport: Arc<TcpTransport>,
23 node: Mutex<SrxNode>,
24}
25
26impl SecureTcpSession {
27 pub async fn connect(addr: SocketAddr, config: SrxConfig) -> Result<Self> {
29 info!("SRX: connecting to {addr}");
30 let stream = tokio::net::TcpStream::connect(addr).await?;
31 debug!("SRX: TCP connection established to {addr}");
32 let transport = Arc::new(TcpTransport::from_stream(stream));
33 Self::connect_over_transport(transport, config).await
34 }
35
36 pub async fn accept(stream: tokio::net::TcpStream, config: SrxConfig) -> Result<Self> {
38 debug!("SRX: accepting incoming session");
39 let transport = Arc::new(TcpTransport::from_stream(stream));
40 Self::accept_over_transport(transport, config).await
41 }
42
43 async fn connect_over_transport(
44 transport: Arc<TcpTransport>,
45 config: SrxConfig,
46 ) -> Result<Self> {
47 debug!("SRX handshake: initiating as client");
48 let mut hs = Handshake::new_initiator();
49 let client_hello = hs.client_hello()?;
50 debug!("SRX handshake: sending ClientHello ({} bytes)", client_hello.len());
51 transport.send_framed(&client_hello).await?;
52 debug!("SRX handshake: waiting for ServerHello");
53 let server_hello = transport.recv_framed().await?;
54 debug!("SRX handshake: received ServerHello ({} bytes)", server_hello.len());
55 let client_finished = hs.finalize(&server_hello)?;
56 debug!("SRX handshake: sending ClientFinished ({} bytes)", client_finished.len());
57 transport.send_framed(&client_finished).await?;
58 let master = hs.master_secret().ok_or_else(|| {
59 SrxError::Session(SessionError::HandshakeFailed(
60 "master secret not established".to_string(),
61 ))
62 })?;
63
64 let session = session_from_client_hello(false, master, &client_hello)?;
65 let aead = Arc::new(config.build_aead_pipeline(&session.data_key)?);
66 let node = SrxNode::from_session(config, session, aead, TransportManager::new())?;
67 info!("SRX handshake: session established (initiator)");
68 Ok(Self {
69 transport,
70 node: Mutex::new(node),
71 })
72 }
73
74 async fn accept_over_transport(
75 transport: Arc<TcpTransport>,
76 config: SrxConfig,
77 ) -> Result<Self> {
78 debug!("SRX handshake: initiating as server (responder)");
79 let mut hs = Handshake::new_responder();
80 debug!("SRX handshake: waiting for ClientHello");
81 let client_hello = transport.recv_framed().await?;
82 debug!("SRX handshake: received ClientHello ({} bytes)", client_hello.len());
83 let server_hello = hs.server_hello(&client_hello)?;
84 debug!("SRX handshake: sending ServerHello ({} bytes)", server_hello.len());
85 transport.send_framed(&server_hello).await?;
86 debug!("SRX handshake: waiting for ClientFinished");
87 let client_finished = transport.recv_framed().await?;
88 debug!("SRX handshake: received ClientFinished ({} bytes)", client_finished.len());
89 hs.server_finish(&client_finished)?;
90 let master = hs.master_secret().ok_or_else(|| {
91 SrxError::Session(SessionError::HandshakeFailed(
92 "master secret not established".to_string(),
93 ))
94 })?;
95
96 let session = session_from_client_hello(true, master, &client_hello)?;
97 let aead = Arc::new(config.build_aead_pipeline(&session.data_key)?);
98 let node = SrxNode::from_session(config, session, aead, TransportManager::new())?;
99 info!("SRX handshake: session established (responder)");
100 Ok(Self {
101 transport,
102 node: Mutex::new(node),
103 })
104 }
105
106 pub async fn send_data(&self, payload: &[u8]) -> Result<()> {
108 let envelope = {
109 let mut node = self.node.lock().await;
110 node.prepare_outgoing(payload)?
111 };
112 self.transport.send_framed(&envelope).await?;
113 Ok(())
114 }
115
116 pub async fn send_signal(&self, signal: &Signal) -> Result<()> {
118 let envelope = {
119 let mut node = self.node.lock().await;
120 node.send_signal(signal)?
121 };
122 self.transport.send_framed(&envelope).await?;
123 Ok(())
124 }
125
126 pub async fn recv(&self) -> Result<Payload> {
128 let envelope = self.transport.recv_framed().await?;
129 let node = self.node.lock().await;
130 node.process_incoming_dispatched(&envelope)
131 }
132}
133
134fn session_from_client_hello(
135 is_server: bool,
136 master_secret: [u8; 32],
137 client_hello: &[u8],
138) -> Result<Session> {
139 let (timestamp, nonce) = parse_client_hello(client_hello)?;
140 let session_id = if is_server { 1 } else { 2 };
141 Session::from_master_secret(session_id, &master_secret, timestamp, &nonce)
142}
143
144fn parse_client_hello(client_hello: &[u8]) -> Result<(u64, [u8; 16])> {
145 if client_hello.len() < 10 {
146 return Err(SrxError::Session(SessionError::HandshakeFailed(
147 "ClientHello too short".to_string(),
148 )));
149 }
150 if &client_hello[0..4] != b"SRXH" {
151 return Err(SrxError::Session(SessionError::HandshakeFailed(
152 "Invalid magic in ClientHello".to_string(),
153 )));
154 }
155 if client_hello[5] != 1 {
156 return Err(SrxError::Session(SessionError::HandshakeFailed(
157 "Invalid message type in ClientHello".to_string(),
158 )));
159 }
160 let payload_len = u32::from_be_bytes(client_hello[6..10].try_into().expect("len")) as usize;
161 if client_hello.len() != 10 + payload_len || payload_len != 24 {
162 return Err(SrxError::Session(SessionError::HandshakeFailed(
163 "Invalid ClientHello payload length".to_string(),
164 )));
165 }
166
167 let payload = &client_hello[10..];
168 let timestamp = u64::from_be_bytes(payload[0..8].try_into().expect("timestamp"));
169 let mut nonce = [0u8; 16];
170 nonce.copy_from_slice(&payload[8..24]);
171 Ok((timestamp, nonce))
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration with `replay.persist_enabled` switched off, used
    /// by both ends of the test connection.
    fn test_config() -> SrxConfig {
        let mut cfg = SrxConfig::default();
        cfg.replay.persist_enabled = false;
        cfg
    }

    /// End-to-end check over loopback: the client sends "ping" and the server
    /// answers "pong".
    #[tokio::test]
    async fn secure_tcp_session_roundtrip() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let server_sess = SecureTcpSession::accept(stream, test_config()).await.unwrap();
            let request = match server_sess.recv().await.unwrap() {
                Payload::Data(d) => d,
                Payload::Signal(_) => panic!("expected data"),
            };
            assert_eq!(request, b"ping");
            server_sess.send_data(b"pong").await.unwrap();
        });

        let client_sess = SecureTcpSession::connect(addr, test_config()).await.unwrap();
        client_sess.send_data(b"ping").await.unwrap();
        let reply = match client_sess.recv().await.unwrap() {
            Payload::Data(d) => d,
            Payload::Signal(_) => panic!("expected data"),
        };
        assert_eq!(reply, b"pong");

        server.await.unwrap();
    }
}