Skip to main content

foctet_core/
io.rs

1use std::{
2    io::{Read, Write},
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use crate::{
8    CoreError,
9    control::ControlMessage,
10    crypto::{Direction, TrafficKeys, decrypt_frame_with_key, encrypt_frame},
11    frame::{FRAME_HEADER_LEN, Frame, FrameHeader},
12    payload::{self, Tlv},
13    replay::{DEFAULT_REPLAY_WINDOW, ReplayProtector},
14    session::Session,
15};
16
17#[cfg(any(feature = "runtime-tokio", feature = "runtime-futures"))]
18use crate::frame::{FoctetFramed, FoctetStream};
19
/// Minimal poll-based read trait used by Foctet runtime adapters.
///
/// Lets framing code read from runtime-specific async readers (such as the
/// `TokioIo` and `FuturesIo` adapters in this module) through one interface.
pub trait PollRead {
    /// Attempts to read bytes into `buf`.
    ///
    /// `Poll::Ready(Ok(n))` reports how many bytes were placed into `buf`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>>;
}
29
/// Minimal poll-based write trait used by Foctet runtime adapters.
///
/// Mirrors the write half of the runtime async-I/O traits: writing, flushing
/// and closing are each exposed as a separate poll method.
pub trait PollWrite {
    /// Attempts to write bytes from `buf`.
    ///
    /// `Poll::Ready(Ok(n))` reports how many bytes of `buf` were accepted.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>>;

    /// Flushes pending writes.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>>;

    /// Closes the writer side.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>>;
}
43
44/// Combined poll-based I/O trait.
45pub trait PollIo: PollRead + PollWrite {}
46
47impl<T: PollRead + PollWrite> PollIo for T {}
48
/// Tokio adapter implementing [`PollRead`] and [`PollWrite`].
///
/// A thin newtype over a Tokio async I/O object that forwards every poll call
/// to it.
#[cfg(feature = "runtime-tokio")]
#[derive(Clone, Debug)]
pub struct TokioIo<T> {
    /// Wrapped Tokio I/O object.
    inner: T,
}
55
#[cfg(feature = "runtime-tokio")]
impl<T> TokioIo<T> {
    /// Wraps a Tokio I/O object so it can be polled through [`PollRead`] and
    /// [`PollWrite`].
    pub fn new(inner: T) -> Self {
        TokioIo { inner }
    }

    /// Consumes the adapter and hands back the wrapped Tokio I/O object.
    pub fn into_inner(self) -> T {
        let TokioIo { inner } = self;
        inner
    }
}
68
#[cfg(feature = "runtime-tokio")]
impl<T> PollRead for TokioIo<T>
where
    T: tokio::io::AsyncRead + Unpin,
{
    /// Reads through Tokio's `ReadBuf` API, translating the filled length of
    /// the buffer into a byte count.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        let mut read_buf = tokio::io::ReadBuf::new(buf);
        Pin::new(&mut this.inner)
            .poll_read(cx, &mut read_buf)
            .map_ok(|()| read_buf.filled().len())
    }
}
87
#[cfg(feature = "runtime-tokio")]
impl<T> PollWrite for TokioIo<T>
where
    T: tokio::io::AsyncWrite + Unpin,
{
    /// Forwards to the wrapped writer's `poll_write`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let inner = &mut self.get_mut().inner;
        Pin::new(inner).poll_write(cx, buf)
    }

    /// Forwards to the wrapped writer's `poll_flush`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let inner = &mut self.get_mut().inner;
        Pin::new(inner).poll_flush(cx)
    }

    /// Maps "close" onto Tokio's `poll_shutdown`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let inner = &mut self.get_mut().inner;
        Pin::new(inner).poll_shutdown(cx)
    }
}
109
/// Futures-io adapter implementing [`PollRead`] and [`PollWrite`].
///
/// A thin newtype over a `futures_io` object that forwards every poll call to
/// it.
#[cfg(feature = "runtime-futures")]
#[derive(Clone, Debug)]
pub struct FuturesIo<T> {
    /// Wrapped futures-io object.
    inner: T,
}
116
#[cfg(feature = "runtime-futures")]
impl<T> FuturesIo<T> {
    /// Wraps a futures-io object so it can be polled through [`PollRead`] and
    /// [`PollWrite`].
    pub fn new(inner: T) -> Self {
        FuturesIo { inner }
    }

    /// Consumes the adapter and hands back the wrapped futures-io object.
    pub fn into_inner(self) -> T {
        let FuturesIo { inner } = self;
        inner
    }
}
129
#[cfg(feature = "runtime-futures")]
impl<T> PollRead for FuturesIo<T>
where
    T: futures_io::AsyncRead + Unpin,
{
    /// Forwards directly to the wrapped reader, whose signature already
    /// matches [`PollRead::poll_read`].
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        let inner = &mut self.get_mut().inner;
        Pin::new(inner).poll_read(cx, buf)
    }
}
143
#[cfg(feature = "runtime-futures")]
impl<T> PollWrite for FuturesIo<T>
where
    T: futures_io::AsyncWrite + Unpin,
{
    /// Forwards to the wrapped writer's `poll_write`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let inner = &mut self.get_mut().inner;
        Pin::new(inner).poll_write(cx, buf)
    }

    /// Forwards to the wrapped writer's `poll_flush`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let inner = &mut self.get_mut().inner;
        Pin::new(inner).poll_flush(cx)
    }

    /// Forwards to the wrapped writer's `poll_close`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let inner = &mut self.get_mut().inner;
        Pin::new(inner).poll_close(cx)
    }
}
165
#[cfg(feature = "runtime-tokio")]
impl<T> FoctetFramed<TokioIo<T>>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    /// Constructs [`FoctetFramed`] from Tokio async I/O.
    ///
    /// `io` is wrapped in [`TokioIo`] and passed to `FoctetFramed::new` along
    /// with the traffic keys and both frame directions.
    pub fn from_tokio(
        io: T,
        keys: TrafficKeys,
        inbound_direction: Direction,
        outbound_direction: Direction,
    ) -> Self {
        let transport = TokioIo::new(io);
        Self::new(transport, keys, inbound_direction, outbound_direction)
    }
}
186
#[cfg(feature = "runtime-futures")]
impl<T> FoctetFramed<FuturesIo<T>>
where
    T: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin,
{
    /// Constructs [`FoctetFramed`] from futures-io async I/O.
    ///
    /// `io` is wrapped in [`FuturesIo`] and passed to `FoctetFramed::new` along
    /// with the traffic keys and both frame directions.
    pub fn from_futures(
        io: T,
        keys: TrafficKeys,
        inbound_direction: Direction,
        outbound_direction: Direction,
    ) -> Self {
        let transport = FuturesIo::new(io);
        Self::new(transport, keys, inbound_direction, outbound_direction)
    }
}
207
#[cfg(feature = "runtime-tokio")]
impl<T> FoctetStream<TokioIo<T>>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    /// Constructs [`FoctetStream`] from Tokio async I/O.
    ///
    /// Builds the framed transport with [`FoctetFramed::from_tokio`] and
    /// wraps it in a stream.
    pub fn from_tokio(
        io: T,
        keys: TrafficKeys,
        inbound_direction: Direction,
        outbound_direction: Direction,
    ) -> Self {
        Self::new(FoctetFramed::from_tokio(
            io,
            keys,
            inbound_direction,
            outbound_direction,
        ))
    }
}
224
#[cfg(feature = "runtime-futures")]
impl<T> FoctetStream<FuturesIo<T>>
where
    T: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin,
{
    /// Constructs [`FoctetStream`] from futures-io async I/O.
    ///
    /// Builds the framed transport with [`FoctetFramed::from_futures`] and
    /// wraps it in a stream.
    pub fn from_futures(
        io: T,
        keys: TrafficKeys,
        inbound_direction: Direction,
        outbound_direction: Direction,
    ) -> Self {
        Self::new(FoctetFramed::from_futures(
            io,
            keys,
            inbound_direction,
            outbound_direction,
        ))
    }
}
241
#[cfg(feature = "runtime-tokio")]
impl<T> tokio::io::AsyncRead for FoctetStream<T>
where
    T: PollRead + PollWrite + Unpin,
{
    /// Reads plaintext into the unfilled part of `buf` via the stream's
    /// `poll_read_plain`, converting Foctet errors into `io::Error`s.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // No room left: report success without polling the inner stream.
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }
        // `initialize_unfilled` exposes the spare capacity as an initialized
        // `&mut [u8]`, which is what `poll_read_plain` expects.
        let read = match Pin::new(&mut *self).poll_read_plain(cx, buf.initialize_unfilled()) {
            Poll::Ready(Ok(count)) => count,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(std::io::Error::other(err))),
            Poll::Pending => return Poll::Pending,
        };
        buf.advance(read);
        Poll::Ready(Ok(()))
    }
}
266
#[cfg(feature = "runtime-tokio")]
impl<T> tokio::io::AsyncWrite for FoctetStream<T>
where
    T: PollRead + PollWrite + Unpin,
{
    /// Writes plaintext via `poll_write_plain`; Foctet errors become
    /// `io::Error`s of kind `Other`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        Pin::new(&mut *self)
            .poll_write_plain(cx, buf)
            .map_err(std::io::Error::other)
    }

    /// Flushes via `poll_flush_plain`, converting Foctet errors.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut *self)
            .poll_flush_plain(cx)
            .map_err(std::io::Error::other)
    }

    /// Shuts down via `poll_close_plain`, converting Foctet errors.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut *self)
            .poll_close_plain(cx)
            .map_err(std::io::Error::other)
    }
}
300
#[cfg(feature = "runtime-futures")]
impl<T> futures_io::AsyncRead for FoctetStream<T>
where
    T: PollRead + PollWrite + Unpin,
{
    /// Reads plaintext into `buf` via `poll_read_plain`; Foctet errors become
    /// `io::Error`s of kind `Other`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        Pin::new(&mut *self)
            .poll_read_plain(cx, buf)
            .map_err(std::io::Error::other)
    }
}
318
#[cfg(feature = "runtime-futures")]
impl<T> futures_io::AsyncWrite for FoctetStream<T>
where
    T: PollRead + PollWrite + Unpin,
{
    /// Writes plaintext via `poll_write_plain`; Foctet errors become
    /// `io::Error`s of kind `Other`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        Pin::new(&mut *self)
            .poll_write_plain(cx, buf)
            .map_err(std::io::Error::other)
    }

    /// Flushes via `poll_flush_plain`, converting Foctet errors.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut *self)
            .poll_flush_plain(cx)
            .map_err(std::io::Error::other)
    }

    /// Closes via `poll_close_plain`, converting Foctet errors.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut *self)
            .poll_close_plain(cx)
            .map_err(std::io::Error::other)
    }
}
352
353/// Blocking `Read + Write` adapter for Foctet framed transport.
354#[derive(Debug)]
355pub struct SyncIo<T> {
356    io: T,
357    keys: Vec<TrafficKeys>,
358    active_key_id: u8,
359    max_retained_keys: usize,
360    inbound_direction: Direction,
361    outbound_direction: Direction,
362    default_stream_id: u32,
363    default_flags: u8,
364    next_seq: u64,
365    max_ciphertext_len: usize,
366    replay: ReplayProtector,
367}
368
369impl<T> SyncIo<T> {
370    /// Creates a blocking Foctet transport wrapper.
371    pub fn new(
372        io: T,
373        keys: TrafficKeys,
374        inbound_direction: Direction,
375        outbound_direction: Direction,
376    ) -> Self {
377        Self {
378            io,
379            active_key_id: keys.key_id,
380            keys: vec![keys],
381            max_retained_keys: 2,
382            inbound_direction,
383            outbound_direction,
384            default_stream_id: 0,
385            default_flags: 0,
386            next_seq: 0,
387            max_ciphertext_len: 16 * 1024 * 1024,
388            replay: ReplayProtector::new(DEFAULT_REPLAY_WINDOW),
389        }
390    }
391
392    /// Sets default stream ID for [`SyncIo::send`].
393    pub fn with_stream_id(mut self, stream_id: u32) -> Self {
394        self.default_stream_id = stream_id;
395        self
396    }
397
398    /// Sets default frame flags for [`SyncIo::send`].
399    pub fn with_default_flags(mut self, flags: u8) -> Self {
400        self.default_flags = flags;
401        self
402    }
403
404    /// Sets inbound ciphertext size limit.
405    pub fn with_max_ciphertext_len(mut self, max_len: usize) -> Self {
406        self.max_ciphertext_len = max_len;
407        self
408    }
409
410    /// Sets number of retained previous keys.
411    pub fn with_max_retained_keys(mut self, max: usize) -> Self {
412        self.max_retained_keys = max.max(1);
413        self
414    }
415
416    /// Returns current active key ID.
417    pub fn active_key_id(&self) -> u8 {
418        self.active_key_id
419    }
420
421    /// Returns known key IDs, active first.
422    pub fn known_key_ids(&self) -> Vec<u8> {
423        self.keys.iter().map(|k| k.key_id).collect()
424    }
425
426    /// Installs new active keys and retains previous keys.
427    pub fn install_active_keys(&mut self, keys: TrafficKeys) {
428        self.keys.retain(|k| k.key_id != keys.key_id);
429        self.keys.insert(0, keys.clone());
430        self.active_key_id = keys.key_id;
431        let keep = self.max_retained_keys + 1;
432        if self.keys.len() > keep {
433            self.keys.truncate(keep);
434        }
435    }
436
437    /// Consumes wrapper and returns underlying I/O object.
438    pub fn into_inner(self) -> T {
439        self.io
440    }
441
442    fn active_keys(&self) -> Result<&TrafficKeys, CoreError> {
443        self.keys
444            .iter()
445            .find(|k| k.key_id == self.active_key_id)
446            .ok_or(CoreError::MissingSessionSecret)
447    }
448
449    fn key_for_id(&self, key_id: u8) -> Option<&TrafficKeys> {
450        self.keys.iter().find(|k| k.key_id == key_id)
451    }
452
453    fn set_key_ring_from_session(&mut self, session: &Session) -> Result<(), CoreError> {
454        let ring = session.key_ring()?;
455        self.keys = ring;
456        self.active_key_id = self
457            .keys
458            .first()
459            .map(|k| k.key_id)
460            .ok_or(CoreError::InvalidSessionState)?;
461        let keep = self.max_retained_keys + 1;
462        if self.keys.len() > keep {
463            self.keys.truncate(keep);
464        }
465        Ok(())
466    }
467}
468
469impl<T: Read + Write> SyncIo<T> {
470    fn send_with_key(
471        &mut self,
472        keys: &TrafficKeys,
473        flags: u8,
474        stream_id: u32,
475        plaintext: &[u8],
476    ) -> Result<(), CoreError> {
477        let frame = encrypt_frame(
478            keys,
479            self.outbound_direction,
480            flags,
481            stream_id,
482            self.next_seq,
483            plaintext,
484        )?;
485        self.next_seq = self.next_seq.wrapping_add(1);
486        self.io.write_all(&frame.to_bytes())?;
487        self.io.flush()?;
488        Ok(())
489    }
490
491    /// Sends plaintext using default flags and stream ID.
492    pub fn send(&mut self, plaintext: &[u8]) -> Result<(), CoreError> {
493        self.send_with(self.default_flags, self.default_stream_id, plaintext)
494    }
495
496    /// Sends plaintext with explicit frame flags and stream ID.
497    pub fn send_with(
498        &mut self,
499        flags: u8,
500        stream_id: u32,
501        plaintext: &[u8],
502    ) -> Result<(), CoreError> {
503        let active = self.active_keys()?.clone();
504        self.send_with_key(&active, flags, stream_id, plaintext)
505    }
506
507    /// Sends TLV payload records as a single encrypted frame payload.
508    pub fn send_tlvs_with(
509        &mut self,
510        flags: u8,
511        stream_id: u32,
512        tlvs: &[Tlv],
513    ) -> Result<(), CoreError> {
514        let payload = payload::encode_tlvs(tlvs)?;
515        self.send_with(flags, stream_id, &payload)
516    }
517
518    /// Receives and decrypts one frame payload.
519    pub fn recv(&mut self) -> Result<Vec<u8>, CoreError> {
520        let mut header_buf = [0u8; FRAME_HEADER_LEN];
521        self.io.read_exact(&mut header_buf)?;
522        let header = FrameHeader::decode(&header_buf)?;
523        header.validate_v0()?;
524
525        let ct_len = header.ct_len as usize;
526        if ct_len > self.max_ciphertext_len {
527            return Err(CoreError::FrameTooLarge);
528        }
529
530        let mut ciphertext = vec![0u8; ct_len];
531        self.io.read_exact(&mut ciphertext)?;
532
533        self.replay
534            .check_and_record(header.key_id, header.stream_id, header.seq)?;
535
536        let keys = self
537            .key_for_id(header.key_id)
538            .ok_or(CoreError::UnexpectedKeyId {
539                expected: self.active_key_id,
540                actual: header.key_id,
541            })?;
542
543        let frame = Frame { header, ciphertext };
544        decrypt_frame_with_key(keys, self.inbound_direction, &frame)
545    }
546
547    /// Sends one control message.
548    pub fn send_control(&mut self, stream_id: u32, msg: &ControlMessage) -> Result<(), CoreError> {
549        self.send_with(crate::frame::flags::IS_CONTROL, stream_id, &msg.encode())
550    }
551
552    /// Sends one control message using an explicit key ID.
553    pub fn send_control_with_key_id(
554        &mut self,
555        stream_id: u32,
556        key_id: u8,
557        msg: &ControlMessage,
558    ) -> Result<(), CoreError> {
559        let key = self
560            .key_for_id(key_id)
561            .ok_or(CoreError::UnexpectedKeyId {
562                expected: self.active_key_id,
563                actual: key_id,
564            })?
565            .clone();
566        self.send_with_key(
567            &key,
568            crate::frame::flags::IS_CONTROL,
569            stream_id,
570            &msg.encode(),
571        )
572    }
573
574    /// Receives and decodes one control message.
575    pub fn recv_control(&mut self) -> Result<ControlMessage, CoreError> {
576        let plaintext = self.recv()?;
577        ControlMessage::decode(&plaintext)
578    }
579
580    /// Receives and decodes TLV payload records.
581    pub fn recv_tlvs(&mut self) -> Result<Vec<Tlv>, CoreError> {
582        let plaintext = self.recv()?;
583        payload::decode_tlvs(&plaintext)
584    }
585
586    /// Sends application payload and auto-handles session rekey controls.
587    pub fn send_data_with_session(
588        &mut self,
589        session: &mut Session,
590        flags: u8,
591        stream_id: u32,
592        plaintext: &[u8],
593    ) -> Result<(), CoreError> {
594        self.set_key_ring_from_session(session)?;
595        let app_tlv = Tlv::application_data(plaintext)?;
596        self.send_tlvs_with(flags, stream_id, &[app_tlv])?;
597
598        if let Some(ctrl) = session.on_outbound_payload(plaintext.len())? {
599            let rekey_old = match &ctrl {
600                ControlMessage::Rekey { old_key_id, .. } => Some(*old_key_id),
601                _ => None,
602            };
603            if let Some(old_key_id) = rekey_old {
604                self.send_control_with_key_id(0, old_key_id, &ctrl)?;
605                self.set_key_ring_from_session(session)?;
606            } else {
607                self.send_control(0, &ctrl)?;
608            }
609        }
610        Ok(())
611    }
612
613    /// Receives next frame and applies session-aware control handling.
614    pub fn recv_application_with_session(
615        &mut self,
616        session: &mut Session,
617    ) -> Result<Option<Vec<u8>>, CoreError> {
618        let mut header_buf = [0u8; FRAME_HEADER_LEN];
619        self.io.read_exact(&mut header_buf)?;
620        let header = FrameHeader::decode(&header_buf)?;
621        header.validate_v0()?;
622
623        let ct_len = header.ct_len as usize;
624        if ct_len > self.max_ciphertext_len {
625            return Err(CoreError::FrameTooLarge);
626        }
627
628        let mut ciphertext = vec![0u8; ct_len];
629        self.io.read_exact(&mut ciphertext)?;
630
631        self.replay
632            .check_and_record(header.key_id, header.stream_id, header.seq)?;
633
634        let keys = self
635            .key_for_id(header.key_id)
636            .ok_or(CoreError::UnexpectedKeyId {
637                expected: self.active_key_id,
638                actual: header.key_id,
639            })?;
640
641        let frame = Frame { header, ciphertext };
642        let plaintext = decrypt_frame_with_key(keys, self.inbound_direction, &frame)?;
643
644        if frame.header.flags & crate::frame::flags::IS_CONTROL != 0 {
645            let msg = ControlMessage::decode(&plaintext)?;
646            let response = session.handle_control(&msg)?;
647            self.set_key_ring_from_session(session)?;
648            if let Some(resp) = response {
649                self.send_control(0, &resp)?;
650            }
651            return Ok(None);
652        }
653
654        Ok(Some(plaintext))
655    }
656}
657
658impl From<CoreError> for std::io::Error {
659    fn from(value: CoreError) -> Self {
660        std::io::Error::other(value)
661    }
662}