// foctet_core/secure_channel.rs

1use std::{
2    future::poll_fn,
3    io::{Read, Write},
4    pin::Pin,
5};
6
7use futures_core::Stream;
8use futures_sink::Sink;
9
10use crate::{
11    CoreError, FoctetFramed, Session,
12    io::SyncIo,
13    payload::{self, Tlv, tlv_type},
14};
15
16/// High-level blocking facade that combines `Session`, `SyncIo`, and TLV helpers.
17///
18/// This wrapper is intended for the common case where callers want:
19/// - automatic session-aware control/rekey handling
20/// - application-data TLV framing by default
21/// - a simple send/receive application-data API
22#[derive(Debug)]
23pub struct SecureChannel<T> {
24    io: SyncIo<T>,
25    session: Session,
26    app_stream_id: u32,
27    app_flags: u8,
28}
29
30/// High-level async facade that combines `Session`, `FoctetFramed`, and TLV helpers.
31///
32/// This wrapper is intended for async runtimes where callers want:
33/// - automatic session-aware control/rekey handling
34/// - application-data TLV framing by default
35/// - an async send/receive application-data API
36#[derive(Debug)]
37pub struct AsyncSecureChannel<T> {
38    framed: FoctetFramed<T>,
39    session: Session,
40    app_stream_id: u32,
41    app_flags: u8,
42}
43
44impl<T: Read + Write> SecureChannel<T> {
45    /// Constructs a secure channel from an active session.
46    ///
47    /// The session must already be in `Active` state with derived traffic keys.
48    pub fn from_active_session(io: T, session: Session) -> Result<Self, CoreError> {
49        let active_keys = session
50            .active_keys()
51            .ok_or(CoreError::InvalidSessionState)?;
52        let inbound = session.inbound_direction();
53        let outbound = session.outbound_direction();
54
55        Ok(Self {
56            io: SyncIo::new(io, active_keys, inbound, outbound),
57            session,
58            app_stream_id: 0,
59            app_flags: 0,
60        })
61    }
62
63    /// Sets the default stream ID for application-data frames.
64    pub fn with_app_stream_id(mut self, stream_id: u32) -> Self {
65        self.app_stream_id = stream_id;
66        self
67    }
68
69    /// Sets the default plaintext frame flags for application-data frames.
70    pub fn with_app_flags(mut self, flags: u8) -> Self {
71        self.app_flags = flags;
72        self
73    }
74
75    /// Returns immutable reference to the underlying `Session`.
76    pub fn session(&self) -> &Session {
77        &self.session
78    }
79
80    /// Returns mutable reference to the underlying `Session`.
81    pub fn session_mut(&mut self) -> &mut Session {
82        &mut self.session
83    }
84
85    /// Sends application data in an `APPLICATION_DATA` TLV with session-aware rekey handling.
86    pub fn send_data(&mut self, plaintext: &[u8]) -> Result<(), CoreError> {
87        self.io.send_data_with_session(
88            &mut self.session,
89            self.app_flags,
90            self.app_stream_id,
91            plaintext,
92        )
93    }
94
95    /// Sends explicit TLVs with session-aware rekey handling.
96    ///
97    /// This bypasses `APPLICATION_DATA` convenience framing.
98    pub fn send_tlvs(&mut self, tlvs: &[Tlv]) -> Result<(), CoreError> {
99        let payload = payload::encode_tlvs(tlvs)?;
100        self.io.send_data_with_session(
101            &mut self.session,
102            self.app_flags,
103            self.app_stream_id,
104            &payload,
105        )
106    }
107
108    /// Receives the next application-data payload.
109    ///
110    /// Control frames are handled automatically. The method loops internally until
111    /// it receives a non-control frame, then decodes TLVs and returns the first
112    /// `APPLICATION_DATA` value.
113    pub fn recv_application(&mut self) -> Result<Vec<u8>, CoreError> {
114        loop {
115            let Some(plaintext) = self.io.recv_application_with_session(&mut self.session)? else {
116                continue;
117            };
118
119            let tlvs = payload::decode_tlvs(&plaintext)?;
120            let app = tlvs
121                .iter()
122                .find(|t| t.typ == tlv_type::APPLICATION_DATA)
123                .ok_or(CoreError::InvalidTlv)?;
124            return Ok(app.value.clone());
125        }
126    }
127
128    /// Receives the next non-control frame and returns decoded TLVs.
129    pub fn recv_tlvs(&mut self) -> Result<Vec<Tlv>, CoreError> {
130        loop {
131            let Some(plaintext) = self.io.recv_application_with_session(&mut self.session)? else {
132                continue;
133            };
134            return payload::decode_tlvs(&plaintext);
135        }
136    }
137
138    /// Consumes the wrapper and returns `(io, session)`.
139    pub fn into_parts(self) -> (T, Session) {
140        (self.io.into_inner(), self.session)
141    }
142}
143
144impl<T> AsyncSecureChannel<T> {
145    /// Sets the default stream ID for application-data frames.
146    pub fn with_app_stream_id(mut self, stream_id: u32) -> Self {
147        self.app_stream_id = stream_id;
148        self
149    }
150
151    /// Sets the default plaintext frame flags for application-data frames.
152    pub fn with_app_flags(mut self, flags: u8) -> Self {
153        self.app_flags = flags;
154        self
155    }
156
157    /// Returns immutable reference to the underlying `Session`.
158    pub fn session(&self) -> &Session {
159        &self.session
160    }
161
162    /// Returns mutable reference to the underlying `Session`.
163    pub fn session_mut(&mut self) -> &mut Session {
164        &mut self.session
165    }
166
167    /// Returns immutable reference to inner framed transport.
168    pub fn framed_ref(&self) -> &FoctetFramed<T> {
169        &self.framed
170    }
171
172    /// Returns mutable reference to inner framed transport.
173    pub fn framed_mut(&mut self) -> &mut FoctetFramed<T> {
174        &mut self.framed
175    }
176
177    /// Consumes the wrapper and returns `(framed, session)`.
178    pub fn into_parts(self) -> (FoctetFramed<T>, Session) {
179        (self.framed, self.session)
180    }
181}
182
183impl<T: crate::io::PollIo + Unpin> AsyncSecureChannel<T> {
184    /// Constructs an async secure channel from an active session.
185    ///
186    /// The session must already be in `Active` state with derived traffic keys.
187    pub fn from_active_session(io: T, session: Session) -> Result<Self, CoreError> {
188        let active_keys = session
189            .active_keys()
190            .ok_or(CoreError::InvalidSessionState)?;
191        let inbound = session.inbound_direction();
192        let outbound = session.outbound_direction();
193        let framed = FoctetFramed::new(io, active_keys, inbound, outbound);
194
195        Ok(Self {
196            framed,
197            session,
198            app_stream_id: 0,
199            app_flags: 0,
200        })
201    }
202
203    /// Sends application data in an `APPLICATION_DATA` TLV with session-aware rekey handling.
204    pub async fn send_data(&mut self, plaintext: &[u8]) -> Result<(), CoreError> {
205        poll_fn(|cx| {
206            let mut framed = Pin::new(&mut self.framed);
207            match framed.as_mut().poll_ready(cx) {
208                std::task::Poll::Pending => return std::task::Poll::Pending,
209                std::task::Poll::Ready(Err(e)) => return std::task::Poll::Ready(Err(e)),
210                std::task::Poll::Ready(Ok(())) => {}
211            }
212
213            framed.as_mut().start_send_data_with_session(
214                &mut self.session,
215                self.app_flags,
216                self.app_stream_id,
217                plaintext,
218            )?;
219
220            framed.poll_flush(cx)
221        })
222        .await
223    }
224
225    /// Sends explicit TLVs with session-aware rekey handling.
226    ///
227    /// This bypasses `APPLICATION_DATA` convenience framing.
228    pub async fn send_tlvs(&mut self, tlvs: &[Tlv]) -> Result<(), CoreError> {
229        let payload = payload::encode_tlvs(tlvs)?;
230        self.send_data(&payload).await
231    }
232
233    /// Receives the next application-data payload.
234    ///
235    /// Control frames are handled automatically. The method loops internally until
236    /// it receives a non-control frame, then decodes TLVs and returns the first
237    /// `APPLICATION_DATA` value.
238    pub async fn recv_application(&mut self) -> Result<Vec<u8>, CoreError> {
239        loop {
240            let item = poll_fn(|cx| Pin::new(&mut self.framed).poll_next(cx)).await;
241            let decoded = match item {
242                Some(Ok(frame)) => frame,
243                Some(Err(e)) => return Err(e),
244                None => return Err(CoreError::UnexpectedEof),
245            };
246
247            if let Some(plaintext) = Pin::new(&mut self.framed)
248                .handle_incoming_with_session(&mut self.session, decoded)?
249            {
250                let tlvs = payload::decode_tlvs(&plaintext)?;
251                let app = tlvs
252                    .iter()
253                    .find(|t| t.typ == tlv_type::APPLICATION_DATA)
254                    .ok_or(CoreError::InvalidTlv)?;
255                return Ok(app.value.clone());
256            }
257        }
258    }
259
260    /// Receives the next non-control frame and returns decoded TLVs.
261    pub async fn recv_tlvs(&mut self) -> Result<Vec<Tlv>, CoreError> {
262        loop {
263            let item = poll_fn(|cx| Pin::new(&mut self.framed).poll_next(cx)).await;
264            let decoded = match item {
265                Some(Ok(frame)) => frame,
266                Some(Err(e)) => return Err(e),
267                None => return Err(CoreError::UnexpectedEof),
268            };
269
270            if let Some(plaintext) = Pin::new(&mut self.framed)
271                .handle_incoming_with_session(&mut self.session, decoded)?
272            {
273                return payload::decode_tlvs(&plaintext);
274            }
275        }
276    }
277}
278
#[cfg(feature = "runtime-tokio")]
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin>
    AsyncSecureChannel<crate::io::TokioIo<T>>
{
    /// Adapts a Tokio byte stream with [`crate::io::TokioIo`] and builds a channel over it
    /// from an active session.
    ///
    /// Session requirements and error conditions are those of
    /// [`AsyncSecureChannel::from_active_session`].
    pub fn from_tokio(io: T, session: Session) -> Result<Self, CoreError> {
        let adapted = crate::io::TokioIo::new(io);
        Self::from_active_session(adapted, session)
    }
}
289
#[cfg(feature = "runtime-futures")]
impl<T: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin>
    AsyncSecureChannel<crate::io::FuturesIo<T>>
{
    /// Adapts a `futures-io` byte stream with [`crate::io::FuturesIo`] and builds a channel
    /// over it from an active session.
    ///
    /// Session requirements and error conditions are those of
    /// [`AsyncSecureChannel::from_active_session`].
    pub fn from_futures(io: T, session: Session) -> Result<Self, CoreError> {
        let adapted = crate::io::FuturesIo::new(io);
        Self::from_active_session(adapted, session)
    }
}
300
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        io::{Read, Write},
        sync::{Arc, Mutex},
        time::Duration,
    };

    use crate::{ControlMessage, RekeyThresholds, Session};

    use super::SecureChannel;

    /// In-memory duplex pipe: bytes written on one end become readable on the other.
    #[derive(Clone, Debug)]
    struct MemPipe {
        rx: Arc<Mutex<VecDeque<u8>>>,
        tx: Arc<Mutex<VecDeque<u8>>>,
    }

    impl MemPipe {
        /// Returns two connected ends; each end's `tx` queue is the other end's `rx` queue.
        fn pair() -> (Self, Self) {
            let first_inbox: Arc<Mutex<VecDeque<u8>>> = Arc::default();
            let second_inbox: Arc<Mutex<VecDeque<u8>>> = Arc::default();
            let first = Self {
                rx: Arc::clone(&first_inbox),
                tx: Arc::clone(&second_inbox),
            };
            let second = Self {
                rx: second_inbox,
                tx: first_inbox,
            };
            (first, second)
        }
    }

    impl Read for MemPipe {
        // Copies out as many queued bytes as fit in `buf`; an empty queue yields `Ok(0)`.
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            let mut inbox = self.rx.lock().expect("lock rx");
            let count = buf.len().min(inbox.len());
            for slot in &mut buf[..count] {
                *slot = inbox.pop_front().expect("rx byte");
            }
            Ok(count)
        }
    }

    impl Write for MemPipe {
        // Appends every byte to the peer's queue; writes never block or fail.
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.tx.lock().expect("lock tx").extend(buf.iter().copied());
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    /// Runs the ClientHello/ServerHello exchange and returns `(initiator, responder)`.
    fn make_session_pair() -> (Session, Session) {
        // `max_frames: 1` makes the second application payload go out after a rekey.
        let thresholds = RekeyThresholds {
            max_frames: 1,
            max_bytes: 1 << 30,
            max_age: Duration::from_secs(3600),
            max_previous_keys: 2,
        };

        let (mut initiator, client_hello) = Session::new_initiator(thresholds.clone());
        let mut responder = Session::new_responder(thresholds);
        let server_hello = responder
            .handle_control(&client_hello)
            .expect("responder handle client hello")
            .expect("server hello");
        let follow_up = initiator
            .handle_control(&server_hello)
            .expect("initiator handle server hello");
        assert!(follow_up.is_none());
        (initiator, responder)
    }

    #[test]
    fn secure_channel_roundtrip_and_rekey() {
        let (client_pipe, server_pipe) = MemPipe::pair();
        let (client_session, server_session) = make_session_pair();

        let mut client = SecureChannel::from_active_session(client_pipe, client_session)
            .expect("client channel")
            .with_app_stream_id(7);
        let mut server = SecureChannel::from_active_session(server_pipe, server_session)
            .expect("server channel")
            .with_app_stream_id(7);

        client.send_data(b"hello-1").expect("send 1");
        let first = server.recv_application().expect("recv 1");
        assert_eq!(first, b"hello-1");

        // max_frames=1 triggers rekey after first app payload.
        client.send_data(b"hello-2").expect("send 2");
        let second = server.recv_application().expect("recv 2");
        assert_eq!(second, b"hello-2");
    }

    #[test]
    fn secure_channel_rejects_non_active_session() {
        let (pipe, _peer) = MemPipe::pair();
        let fresh_responder = Session::new_responder(RekeyThresholds::default());
        let err = SecureChannel::from_active_session(pipe, fresh_responder)
            .expect_err("must reject non-active session");
        assert!(matches!(err, crate::CoreError::InvalidSessionState));
    }

    #[test]
    fn handshake_exchange_is_control_messages() {
        let thresholds = RekeyThresholds::default();
        let (_initiator, client_hello) = Session::new_initiator(thresholds.clone());
        let mut responder = Session::new_responder(thresholds);
        let server_hello = responder
            .handle_control(&client_hello)
            .expect("valid client hello")
            .expect("server hello");
        assert!(matches!(client_hello, ControlMessage::ClientHello { .. }));
        assert!(matches!(server_hello, ControlMessage::ServerHello { .. }));
    }
}