1use std::{
2 future::poll_fn,
3 io::{Read, Write},
4 pin::Pin,
5};
6
7use futures_core::Stream;
8use futures_sink::Sink;
9
10use crate::{
11 CoreError, FoctetFramed, Session,
12 io::SyncIo,
13 payload::{self, Tlv, tlv_type},
14};
15
/// Blocking secure channel: a [`SyncIo`] transport paired with the [`Session`]
/// whose active keys and directions it was built from.
///
/// Outgoing application data is tagged with `app_stream_id` and `app_flags`
/// (both `0` unless changed through the `with_*` builders).
#[derive(Debug)]
pub struct SecureChannel<T> {
    // Framed I/O built from the session's active keys and inbound/outbound directions.
    io: SyncIo<T>,
    // Session passed to every send/receive call on `io`.
    session: Session,
    // Stream id attached to outgoing application data.
    app_stream_id: u32,
    // Flags byte attached to outgoing application data.
    app_flags: u8,
}
29
/// Asynchronous secure channel: a [`FoctetFramed`] transport paired with the
/// [`Session`] whose active keys and directions it was built from.
///
/// Outgoing application data is tagged with `app_stream_id` and `app_flags`
/// (both `0` unless changed through the `with_*` builders).
#[derive(Debug)]
pub struct AsyncSecureChannel<T> {
    // Stream/sink over the transport, built from the session's active keys and directions.
    framed: FoctetFramed<T>,
    // Session passed to every send/receive call on `framed`.
    session: Session,
    // Stream id attached to outgoing application data.
    app_stream_id: u32,
    // Flags byte attached to outgoing application data.
    app_flags: u8,
}
43
44impl<T: Read + Write> SecureChannel<T> {
45 pub fn from_active_session(io: T, session: Session) -> Result<Self, CoreError> {
49 let active_keys = session
50 .active_keys()
51 .ok_or(CoreError::InvalidSessionState)?;
52 let inbound = session.inbound_direction();
53 let outbound = session.outbound_direction();
54
55 Ok(Self {
56 io: SyncIo::new(io, active_keys, inbound, outbound),
57 session,
58 app_stream_id: 0,
59 app_flags: 0,
60 })
61 }
62
63 pub fn with_app_stream_id(mut self, stream_id: u32) -> Self {
65 self.app_stream_id = stream_id;
66 self
67 }
68
69 pub fn with_app_flags(mut self, flags: u8) -> Self {
71 self.app_flags = flags;
72 self
73 }
74
75 pub fn session(&self) -> &Session {
77 &self.session
78 }
79
80 pub fn session_mut(&mut self) -> &mut Session {
82 &mut self.session
83 }
84
85 pub fn send_data(&mut self, plaintext: &[u8]) -> Result<(), CoreError> {
87 self.io.send_data_with_session(
88 &mut self.session,
89 self.app_flags,
90 self.app_stream_id,
91 plaintext,
92 )
93 }
94
95 pub fn send_tlvs(&mut self, tlvs: &[Tlv]) -> Result<(), CoreError> {
99 let payload = payload::encode_tlvs(tlvs)?;
100 self.io.send_data_with_session(
101 &mut self.session,
102 self.app_flags,
103 self.app_stream_id,
104 &payload,
105 )
106 }
107
108 pub fn recv_application(&mut self) -> Result<Vec<u8>, CoreError> {
114 loop {
115 let Some(plaintext) = self.io.recv_application_with_session(&mut self.session)? else {
116 continue;
117 };
118
119 let tlvs = payload::decode_tlvs(&plaintext)?;
120 let app = tlvs
121 .iter()
122 .find(|t| t.typ == tlv_type::APPLICATION_DATA)
123 .ok_or(CoreError::InvalidTlv)?;
124 return Ok(app.value.clone());
125 }
126 }
127
128 pub fn recv_tlvs(&mut self) -> Result<Vec<Tlv>, CoreError> {
130 loop {
131 let Some(plaintext) = self.io.recv_application_with_session(&mut self.session)? else {
132 continue;
133 };
134 return payload::decode_tlvs(&plaintext);
135 }
136 }
137
138 pub fn into_parts(self) -> (T, Session) {
140 (self.io.into_inner(), self.session)
141 }
142}
143
144impl<T> AsyncSecureChannel<T> {
145 pub fn with_app_stream_id(mut self, stream_id: u32) -> Self {
147 self.app_stream_id = stream_id;
148 self
149 }
150
151 pub fn with_app_flags(mut self, flags: u8) -> Self {
153 self.app_flags = flags;
154 self
155 }
156
157 pub fn session(&self) -> &Session {
159 &self.session
160 }
161
162 pub fn session_mut(&mut self) -> &mut Session {
164 &mut self.session
165 }
166
167 pub fn framed_ref(&self) -> &FoctetFramed<T> {
169 &self.framed
170 }
171
172 pub fn framed_mut(&mut self) -> &mut FoctetFramed<T> {
174 &mut self.framed
175 }
176
177 pub fn into_parts(self) -> (FoctetFramed<T>, Session) {
179 (self.framed, self.session)
180 }
181}
182
183impl<T: crate::io::PollIo + Unpin> AsyncSecureChannel<T> {
184 pub fn from_active_session(io: T, session: Session) -> Result<Self, CoreError> {
188 let active_keys = session
189 .active_keys()
190 .ok_or(CoreError::InvalidSessionState)?;
191 let inbound = session.inbound_direction();
192 let outbound = session.outbound_direction();
193 let framed = FoctetFramed::new(io, active_keys, inbound, outbound);
194
195 Ok(Self {
196 framed,
197 session,
198 app_stream_id: 0,
199 app_flags: 0,
200 })
201 }
202
203 pub async fn send_data(&mut self, plaintext: &[u8]) -> Result<(), CoreError> {
205 poll_fn(|cx| {
206 let mut framed = Pin::new(&mut self.framed);
207 match framed.as_mut().poll_ready(cx) {
208 std::task::Poll::Pending => return std::task::Poll::Pending,
209 std::task::Poll::Ready(Err(e)) => return std::task::Poll::Ready(Err(e)),
210 std::task::Poll::Ready(Ok(())) => {}
211 }
212
213 framed.as_mut().start_send_data_with_session(
214 &mut self.session,
215 self.app_flags,
216 self.app_stream_id,
217 plaintext,
218 )?;
219
220 framed.poll_flush(cx)
221 })
222 .await
223 }
224
225 pub async fn send_tlvs(&mut self, tlvs: &[Tlv]) -> Result<(), CoreError> {
229 let payload = payload::encode_tlvs(tlvs)?;
230 self.send_data(&payload).await
231 }
232
233 pub async fn recv_application(&mut self) -> Result<Vec<u8>, CoreError> {
239 loop {
240 let item = poll_fn(|cx| Pin::new(&mut self.framed).poll_next(cx)).await;
241 let decoded = match item {
242 Some(Ok(frame)) => frame,
243 Some(Err(e)) => return Err(e),
244 None => return Err(CoreError::UnexpectedEof),
245 };
246
247 if let Some(plaintext) = Pin::new(&mut self.framed)
248 .handle_incoming_with_session(&mut self.session, decoded)?
249 {
250 let tlvs = payload::decode_tlvs(&plaintext)?;
251 let app = tlvs
252 .iter()
253 .find(|t| t.typ == tlv_type::APPLICATION_DATA)
254 .ok_or(CoreError::InvalidTlv)?;
255 return Ok(app.value.clone());
256 }
257 }
258 }
259
260 pub async fn recv_tlvs(&mut self) -> Result<Vec<Tlv>, CoreError> {
262 loop {
263 let item = poll_fn(|cx| Pin::new(&mut self.framed).poll_next(cx)).await;
264 let decoded = match item {
265 Some(Ok(frame)) => frame,
266 Some(Err(e)) => return Err(e),
267 None => return Err(CoreError::UnexpectedEof),
268 };
269
270 if let Some(plaintext) = Pin::new(&mut self.framed)
271 .handle_incoming_with_session(&mut self.session, decoded)?
272 {
273 return payload::decode_tlvs(&plaintext);
274 }
275 }
276 }
277}
278
#[cfg(feature = "runtime-tokio")]
impl<T> AsyncSecureChannel<crate::io::TokioIo<T>>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    /// Builds a channel over a Tokio I/O object by first wrapping it in
    /// `crate::io::TokioIo`.
    ///
    /// # Errors
    ///
    /// Same as `from_active_session`: [`CoreError::InvalidSessionState`] when the
    /// session has no active keys.
    pub fn from_tokio(io: T, session: Session) -> Result<Self, CoreError> {
        let adapted = crate::io::TokioIo::new(io);
        Self::from_active_session(adapted, session)
    }
}
289
#[cfg(feature = "runtime-futures")]
impl<T> AsyncSecureChannel<crate::io::FuturesIo<T>>
where
    T: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin,
{
    /// Builds a channel over a `futures-io` object by first wrapping it in
    /// `crate::io::FuturesIo`.
    ///
    /// # Errors
    ///
    /// Same as `from_active_session`: [`CoreError::InvalidSessionState`] when the
    /// session has no active keys.
    pub fn from_futures(io: T, session: Session) -> Result<Self, CoreError> {
        let adapted = crate::io::FuturesIo::new(io);
        Self::from_active_session(adapted, session)
    }
}
300
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        io::{Read, Write},
        sync::{Arc, Mutex},
        time::Duration,
    };

    use crate::{ControlMessage, RekeyThresholds, Session};

    use super::SecureChannel;

    /// One direction of the in-memory pipe: a shared FIFO of bytes.
    type ByteQueue = Arc<Mutex<VecDeque<u8>>>;

    /// In-memory duplex byte pipe. Each end reads from its own queue and
    /// writes into its peer's queue.
    #[derive(Clone, Debug)]
    struct MemPipe {
        rx: ByteQueue,
        tx: ByteQueue,
    }

    impl MemPipe {
        /// Creates two connected ends: bytes written to one are read by the other.
        fn pair() -> (Self, Self) {
            let first_inbox = ByteQueue::default();
            let second_inbox = ByteQueue::default();
            let first = Self {
                rx: Arc::clone(&first_inbox),
                tx: Arc::clone(&second_inbox),
            };
            let second = Self {
                rx: second_inbox,
                tx: first_inbox,
            };
            (first, second)
        }
    }

    impl Read for MemPipe {
        /// Moves up to `buf.len()` queued bytes into `buf`. An empty queue
        /// yields `Ok(0)` instead of blocking.
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            let mut inbox = self.rx.lock().expect("lock rx");
            let count = inbox.len().min(buf.len());
            for slot in &mut buf[..count] {
                *slot = inbox.pop_front().expect("rx byte");
            }
            Ok(count)
        }
    }

    impl Write for MemPipe {
        /// Appends all of `buf` to the peer's queue; never short-writes.
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.tx.lock().expect("lock tx").extend(buf);
            Ok(buf.len())
        }

        /// Nothing is buffered locally, so flushing is a no-op.
        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    /// Runs the ClientHello/ServerHello exchange and returns
    /// `(initiator, responder)`.
    ///
    /// `max_frames: 1` presumably forces a rekey after every frame, so the
    /// roundtrip test crosses a key change.
    fn make_session_pair() -> (Session, Session) {
        let thresholds = RekeyThresholds {
            max_frames: 1,
            max_bytes: 1 << 30,
            max_age: Duration::from_secs(3600),
            max_previous_keys: 2,
        };

        let (mut initiator, client_hello) = Session::new_initiator(thresholds.clone());
        let mut responder = Session::new_responder(thresholds);

        let server_hello = responder
            .handle_control(&client_hello)
            .expect("responder handle client hello")
            .expect("server hello");
        let initiator_reply = initiator
            .handle_control(&server_hello)
            .expect("initiator handle server hello");
        assert!(initiator_reply.is_none());

        (initiator, responder)
    }

    #[test]
    fn secure_channel_roundtrip_and_rekey() {
        let (client_io, server_io) = MemPipe::pair();
        let (client_session, server_session) = make_session_pair();

        let mut client = SecureChannel::from_active_session(client_io, client_session)
            .expect("client channel")
            .with_app_stream_id(7);
        let mut server = SecureChannel::from_active_session(server_io, server_session)
            .expect("server channel")
            .with_app_stream_id(7);

        client.send_data(b"hello-1").expect("send 1");
        assert_eq!(server.recv_application().expect("recv 1"), b"hello-1");

        client.send_data(b"hello-2").expect("send 2");
        assert_eq!(server.recv_application().expect("recv 2"), b"hello-2");
    }

    #[test]
    fn secure_channel_rejects_non_active_session() {
        let (io, _peer) = MemPipe::pair();
        let fresh_responder = Session::new_responder(RekeyThresholds::default());
        let err = SecureChannel::from_active_session(io, fresh_responder)
            .expect_err("must reject non-active session");
        assert!(matches!(err, crate::CoreError::InvalidSessionState));
    }

    #[test]
    fn handshake_exchange_is_control_messages() {
        let thresholds = RekeyThresholds::default();
        let (_initiator, client_hello) = Session::new_initiator(thresholds.clone());
        let mut responder = Session::new_responder(thresholds);
        let server_hello = responder
            .handle_control(&client_hello)
            .expect("valid client hello")
            .expect("server hello");
        assert!(matches!(client_hello, ControlMessage::ClientHello { .. }));
        assert!(matches!(server_hello, ControlMessage::ServerHello { .. }));
    }
}