Skip to main content

foctet_core/
frame.rs

1use std::{
2    pin::Pin,
3    task::{Context, Poll, ready},
4};
5
6use bytes::{Buf, BytesMut};
7use futures_core::Stream;
8use futures_sink::Sink;
9
10use crate::{
11    CoreError,
12    control::ControlMessage,
13    crypto::{Direction, TrafficKeys, decrypt_frame_with_key, encrypt_frame},
14    io::PollIo,
15    payload::{self, Tlv},
16    replay::{DEFAULT_REPLAY_WINDOW, ReplayProtector},
17    session::Session,
18};
19
/// Wire-version byte emitted and accepted by Draft v0 framing.
pub const WIRE_VERSION_V0: u8 = 0x00;
/// Profile byte for the only Draft v0 suite (X25519 + HKDF + XChaCha20-Poly1305).
pub const PROFILE_X25519_HKDF_XCHACHA20POLY1305: u8 = 0x01;
/// Size of an encoded [`FrameHeader`] in bytes.
///
/// magic (2) + version (1) + flags (1) + profile_id (1) + key_id (1)
/// + stream_id (4) + seq (8) + ct_len (4).
pub const FRAME_HEADER_LEN: usize = 2 + 1 + 1 + 1 + 1 + 4 + 8 + 4;

/// Two-byte marker (`0xF0 0xC7`) that opens every Draft v0 frame.
pub const DRAFT_MAGIC: [u8; 2] = [0xF0, 0xC7];
29
/// Bit values for [`FrameHeader::flags`].
pub mod flags {
    /// Set when routing information lives in an outer layer.
    pub const HAS_ROUTING: u8 = 0b0000_0001;
    /// Set when the frame payload is a control message.
    pub const IS_CONTROL: u8 = 0b0000_0010;
    /// Hint that the receiver should acknowledge delivery.
    pub const ACK_REQUIRED: u8 = 0b0000_0100;
    /// Set when the ciphertext contains semantic padding.
    pub const PADDING: u8 = 0b0000_1000;
    /// Union of every flag defined by Draft v0; any other bit is unknown.
    pub const ALL_KNOWN_BITS: u8 = PADDING | ACK_REQUIRED | IS_CONTROL | HAS_ROUTING;
}
43
/// Cleartext frame header; it is authenticated (as AEAD additional data) but not encrypted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FrameHeader {
    /// Leading magic bytes identifying a Draft frame.
    pub magic: [u8; 2],
    /// Wire protocol version.
    pub version: u8,
    /// Bitfield of frame flags.
    pub flags: u8,
    /// Identifier of the cryptographic profile in use.
    pub profile_id: u8,
    /// Identifier of the traffic key that protects this frame.
    pub key_id: u8,
    /// Logical stream this frame belongs to.
    pub stream_id: u32,
    /// Per-stream/direction/key sequence number.
    pub seq: u64,
    /// Number of ciphertext bytes that follow the header.
    pub ct_len: u32,
}
64
65impl FrameHeader {
66    /// Creates a header value for a frame.
67    pub fn new(
68        flags: u8,
69        profile_id: u8,
70        key_id: u8,
71        stream_id: u32,
72        seq: u64,
73        ct_len: u32,
74    ) -> Self {
75        Self {
76            magic: DRAFT_MAGIC,
77            version: WIRE_VERSION_V0,
78            flags,
79            profile_id,
80            key_id,
81            stream_id,
82            seq,
83            ct_len,
84        }
85    }
86
87    /// Serializes header into fixed-width wire bytes.
88    pub fn encode(&self) -> [u8; FRAME_HEADER_LEN] {
89        let mut out = [0u8; FRAME_HEADER_LEN];
90        out[0..2].copy_from_slice(&self.magic);
91        out[2] = self.version;
92        out[3] = self.flags;
93        out[4] = self.profile_id;
94        out[5] = self.key_id;
95        out[6..10].copy_from_slice(&self.stream_id.to_be_bytes());
96        out[10..18].copy_from_slice(&self.seq.to_be_bytes());
97        out[18..22].copy_from_slice(&self.ct_len.to_be_bytes());
98        out
99    }
100
101    /// Parses a fixed-width frame header from bytes.
102    pub fn decode(buf: &[u8]) -> Result<Self, CoreError> {
103        if buf.len() != FRAME_HEADER_LEN {
104            return Err(CoreError::InvalidHeaderLength(buf.len()));
105        }
106
107        let mut magic = [0u8; 2];
108        magic.copy_from_slice(&buf[0..2]);
109        let version = buf[2];
110        let flags = buf[3];
111        let profile_id = buf[4];
112        let key_id = buf[5];
113
114        let mut stream_id_bytes = [0u8; 4];
115        stream_id_bytes.copy_from_slice(&buf[6..10]);
116        let stream_id = u32::from_be_bytes(stream_id_bytes);
117
118        let mut seq_bytes = [0u8; 8];
119        seq_bytes.copy_from_slice(&buf[10..18]);
120        let seq = u64::from_be_bytes(seq_bytes);
121
122        let mut ct_len_bytes = [0u8; 4];
123        ct_len_bytes.copy_from_slice(&buf[18..22]);
124        let ct_len = u32::from_be_bytes(ct_len_bytes);
125
126        Ok(Self {
127            magic,
128            version,
129            flags,
130            profile_id,
131            key_id,
132            stream_id,
133            seq,
134            ct_len,
135        })
136    }
137
138    /// Validates version/profile/flags according to Draft v0.
139    pub fn validate_v0(&self) -> Result<(), CoreError> {
140        if self.magic != DRAFT_MAGIC {
141            return Err(CoreError::InvalidMagic);
142        }
143        if self.version != WIRE_VERSION_V0 {
144            return Err(CoreError::UnsupportedVersion(self.version));
145        }
146        if self.profile_id != PROFILE_X25519_HKDF_XCHACHA20POLY1305 {
147            return Err(CoreError::UnsupportedProfile(self.profile_id));
148        }
149        if self.flags & !flags::ALL_KNOWN_BITS != 0 {
150            return Err(CoreError::UnknownFlags(self.flags));
151        }
152        Ok(())
153    }
154}
155
156/// Complete encrypted frame (header + ciphertext).
157#[derive(Clone, Debug, Eq, PartialEq)]
158pub struct Frame {
159    /// Plaintext authenticated header.
160    pub header: FrameHeader,
161    /// AEAD ciphertext (including tag).
162    pub ciphertext: Vec<u8>,
163}
164
165impl Frame {
166    /// Serializes frame to wire bytes.
167    pub fn to_bytes(&self) -> Vec<u8> {
168        let mut out = Vec::with_capacity(FRAME_HEADER_LEN + self.ciphertext.len());
169        out.extend_from_slice(&self.header.encode());
170        out.extend_from_slice(&self.ciphertext);
171        out
172    }
173
174    /// Parses a complete frame from wire bytes.
175    pub fn from_bytes(buf: &[u8]) -> Result<Self, CoreError> {
176        if buf.len() < FRAME_HEADER_LEN {
177            return Err(CoreError::InvalidHeaderLength(buf.len()));
178        }
179        let header = FrameHeader::decode(&buf[..FRAME_HEADER_LEN])?;
180        let ciphertext = buf[FRAME_HEADER_LEN..].to_vec();
181        if ciphertext.len() != header.ct_len as usize {
182            return Err(CoreError::CiphertextLengthMismatch {
183                expected: header.ct_len as usize,
184                actual: ciphertext.len(),
185            });
186        }
187        Ok(Self { header, ciphertext })
188    }
189}
190
191/// Frame decoded by transport layer with plaintext payload bytes.
192#[derive(Clone, Debug, Eq, PartialEq)]
193pub struct DecodedFrame {
194    /// Parsed authenticated header.
195    pub header: FrameHeader,
196    /// Decrypted frame payload bytes.
197    pub plaintext: Vec<u8>,
198}
199
200/// Framed Foctet transport over a poll-based I/O backend.
201#[derive(Clone, Debug)]
202pub struct FoctetFramed<T> {
203    io: T,
204    keys: Vec<TrafficKeys>,
205    active_key_id: u8,
206    max_retained_keys: usize,
207    inbound_direction: Direction,
208    outbound_direction: Direction,
209    default_stream_id: u32,
210    default_flags: u8,
211    next_seq: u64,
212    max_ciphertext_len: usize,
213    rx: BytesMut,
214    tx: BytesMut,
215    replay: ReplayProtector,
216    eof: bool,
217}
218
219impl<T> FoctetFramed<T> {
220    /// Creates a framed transport with initial traffic keys.
221    pub fn new(
222        io: T,
223        keys: TrafficKeys,
224        inbound_direction: Direction,
225        outbound_direction: Direction,
226    ) -> Self {
227        Self {
228            io,
229            active_key_id: keys.key_id,
230            keys: vec![keys],
231            max_retained_keys: 2,
232            inbound_direction,
233            outbound_direction,
234            default_stream_id: 0,
235            default_flags: 0,
236            next_seq: 0,
237            max_ciphertext_len: 16 * 1024 * 1024,
238            rx: BytesMut::with_capacity(8 * 1024),
239            tx: BytesMut::new(),
240            replay: ReplayProtector::new(DEFAULT_REPLAY_WINDOW),
241            eof: false,
242        }
243    }
244
245    /// Sets default stream ID for sink-based sending.
246    pub fn with_stream_id(mut self, stream_id: u32) -> Self {
247        self.default_stream_id = stream_id;
248        self
249    }
250
251    /// Sets default frame flags for sink-based sending.
252    pub fn with_default_flags(mut self, flags: u8) -> Self {
253        self.default_flags = flags;
254        self
255    }
256
257    /// Sets inbound ciphertext size limit.
258    pub fn with_max_ciphertext_len(mut self, max_len: usize) -> Self {
259        self.max_ciphertext_len = max_len;
260        self
261    }
262
263    /// Sets number of retained previous keys.
264    pub fn with_max_retained_keys(mut self, max: usize) -> Self {
265        self.max_retained_keys = max.max(1);
266        self
267    }
268
269    /// Returns immutable reference to underlying I/O object.
270    pub fn get_ref(&self) -> &T {
271        &self.io
272    }
273
274    /// Returns mutable reference to underlying I/O object.
275    pub fn get_mut(&mut self) -> &mut T {
276        &mut self.io
277    }
278
279    /// Consumes wrapper and returns underlying I/O object.
280    pub fn into_inner(self) -> T {
281        self.io
282    }
283
284    /// Returns known key IDs, active first.
285    pub fn known_key_ids(&self) -> Vec<u8> {
286        self.keys.iter().map(|k| k.key_id).collect()
287    }
288
289    /// Returns active key ID.
290    pub fn active_key_id(&self) -> u8 {
291        self.active_key_id
292    }
293
294    /// Installs new active keys and retains previous keys.
295    pub fn install_active_keys(&mut self, keys: TrafficKeys) {
296        self.keys.retain(|k| k.key_id != keys.key_id);
297        self.keys.insert(0, keys.clone());
298        self.active_key_id = keys.key_id;
299        let keep = self.max_retained_keys + 1;
300        if self.keys.len() > keep {
301            self.keys.truncate(keep);
302        }
303    }
304
305    fn active_keys(&self) -> Result<&TrafficKeys, CoreError> {
306        self.keys
307            .iter()
308            .find(|k| k.key_id == self.active_key_id)
309            .ok_or(CoreError::MissingSessionSecret)
310    }
311
312    fn key_for_id(&self, key_id: u8) -> Option<&TrafficKeys> {
313        self.keys.iter().find(|k| k.key_id == key_id)
314    }
315
316    fn set_key_ring_from_session(&mut self, session: &Session) -> Result<(), CoreError> {
317        let ring = session.key_ring()?;
318        self.keys = ring;
319        self.active_key_id = self
320            .keys
321            .first()
322            .map(|k| k.key_id)
323            .ok_or(CoreError::InvalidSessionState)?;
324        let keep = self.max_retained_keys + 1;
325        if self.keys.len() > keep {
326            self.keys.truncate(keep);
327        }
328        Ok(())
329    }
330
331    fn enqueue_with_specific_key(
332        &mut self,
333        key_id: u8,
334        flags: u8,
335        stream_id: u32,
336        plaintext: &[u8],
337    ) -> Result<(), CoreError> {
338        let keys = self
339            .key_for_id(key_id)
340            .ok_or(CoreError::UnexpectedKeyId {
341                expected: self.active_key_id,
342                actual: key_id,
343            })?
344            .clone();
345        let frame = encrypt_frame(
346            &keys,
347            self.outbound_direction,
348            flags,
349            stream_id,
350            self.next_seq,
351            plaintext,
352        )?;
353        self.next_seq = self
354            .next_seq
355            .checked_add(1)
356            .ok_or(CoreError::SequenceExhausted)?;
357        self.tx.extend_from_slice(&frame.to_bytes());
358        Ok(())
359    }
360}
361
362impl<T: PollIo + Unpin> FoctetFramed<T> {
363    /// Poll-based helper to send one payload frame.
364    pub fn poll_send_frame(
365        self: Pin<&mut Self>,
366        cx: &mut Context<'_>,
367        flags: u8,
368        stream_id: u32,
369        plaintext: &[u8],
370    ) -> Poll<Result<(), CoreError>> {
371        let this = self.get_mut();
372        ready!(Pin::new(&mut *this).poll_ready(cx))?;
373        Pin::new(&mut *this).start_send_with(flags, stream_id, plaintext)?;
374        Poll::Ready(Ok(()))
375    }
376
377    /// Enqueues one encrypted frame into outbound buffer.
378    pub fn start_send_with(
379        self: Pin<&mut Self>,
380        flags: u8,
381        stream_id: u32,
382        plaintext: &[u8],
383    ) -> Result<(), CoreError> {
384        let this = self.get_mut();
385        let active = this.active_keys()?.clone();
386        let frame = encrypt_frame(
387            &active,
388            this.outbound_direction,
389            flags,
390            stream_id,
391            this.next_seq,
392            plaintext,
393        )?;
394        this.next_seq = this
395            .next_seq
396            .checked_add(1)
397            .ok_or(CoreError::SequenceExhausted)?;
398        this.tx.extend_from_slice(&frame.to_bytes());
399        Ok(())
400    }
401
402    /// Enqueues a control payload frame.
403    pub fn start_send_control(
404        self: Pin<&mut Self>,
405        stream_id: u32,
406        msg: &ControlMessage,
407    ) -> Result<(), CoreError> {
408        self.start_send_with(flags::IS_CONTROL, stream_id, &msg.encode())
409    }
410
411    /// Enqueues TLV payload bytes in a data frame.
412    pub fn start_send_tlvs_with(
413        self: Pin<&mut Self>,
414        flags: u8,
415        stream_id: u32,
416        tlvs: &[Tlv],
417    ) -> Result<(), CoreError> {
418        let payload = payload::encode_tlvs(tlvs)?;
419        self.start_send_with(flags, stream_id, &payload)
420    }
421
422    /// Enqueues a control payload frame using explicit key ID.
423    pub fn start_send_control_with_key_id(
424        self: Pin<&mut Self>,
425        stream_id: u32,
426        key_id: u8,
427        msg: &ControlMessage,
428    ) -> Result<(), CoreError> {
429        let this = self.get_mut();
430        this.enqueue_with_specific_key(key_id, flags::IS_CONTROL, stream_id, &msg.encode())
431    }
432
433    /// Decodes a control message from a decoded control frame.
434    pub fn decode_control(frame: &DecodedFrame) -> Result<ControlMessage, CoreError> {
435        if frame.header.flags & flags::IS_CONTROL == 0 {
436            return Err(CoreError::UnexpectedControlMessage);
437        }
438        ControlMessage::decode(&frame.plaintext)
439    }
440
441    /// Decodes TLV records from a decoded frame payload.
442    pub fn decode_tlvs(frame: &DecodedFrame) -> Result<Vec<Tlv>, CoreError> {
443        payload::decode_tlvs(&frame.plaintext)
444    }
445
446    /// Sends application payload with session-aware automatic rekey handling.
447    pub fn start_send_data_with_session(
448        self: Pin<&mut Self>,
449        session: &mut Session,
450        flags: u8,
451        stream_id: u32,
452        plaintext: &[u8],
453    ) -> Result<(), CoreError> {
454        let this = self.get_mut();
455        this.set_key_ring_from_session(session)?;
456        let app_tlv = Tlv::application_data(plaintext)?;
457        let app_payload = payload::encode_tlvs(&[app_tlv])?;
458        this.enqueue_with_specific_key(this.active_key_id, flags, stream_id, &app_payload)?;
459
460        if let Some(ctrl) = session.on_outbound_payload(plaintext.len())? {
461            let ctrl_bytes = ctrl.encode();
462            let rekey_old = match ctrl {
463                ControlMessage::Rekey { old_key_id, .. } => Some(old_key_id),
464                _ => None,
465            };
466            if let Some(old_key_id) = rekey_old {
467                this.enqueue_with_specific_key(old_key_id, flags::IS_CONTROL, 0, &ctrl_bytes)?;
468                this.set_key_ring_from_session(session)?;
469            } else {
470                this.enqueue_with_specific_key(
471                    this.active_key_id,
472                    flags::IS_CONTROL,
473                    0,
474                    &ctrl_bytes,
475                )?;
476            }
477        }
478        Ok(())
479    }
480
481    /// Processes one incoming decoded frame with session-aware control handling.
482    pub fn handle_incoming_with_session(
483        self: Pin<&mut Self>,
484        session: &mut Session,
485        frame: DecodedFrame,
486    ) -> Result<Option<Vec<u8>>, CoreError> {
487        let this = self.get_mut();
488        if frame.header.flags & flags::IS_CONTROL != 0 {
489            let msg = ControlMessage::decode(&frame.plaintext)?;
490            let response = session.handle_control(&msg)?;
491            this.set_key_ring_from_session(session)?;
492            if let Some(resp) = response {
493                this.enqueue_with_specific_key(
494                    this.active_key_id,
495                    flags::IS_CONTROL,
496                    0,
497                    &resp.encode(),
498                )?;
499            }
500            return Ok(None);
501        }
502        Ok(Some(frame.plaintext))
503    }
504
505    fn try_decode(&mut self) -> Result<Option<DecodedFrame>, CoreError> {
506        if self.rx.len() < FRAME_HEADER_LEN {
507            return Ok(None);
508        }
509
510        let header = FrameHeader::decode(&self.rx[..FRAME_HEADER_LEN])?;
511        header.validate_v0()?;
512
513        let ct_len = header.ct_len as usize;
514        if ct_len > self.max_ciphertext_len {
515            return Err(CoreError::FrameTooLarge);
516        }
517
518        let total = FRAME_HEADER_LEN + ct_len;
519        if self.rx.len() < total {
520            return Ok(None);
521        }
522
523        let frame_bytes = self.rx.split_to(total);
524        let frame = Frame::from_bytes(&frame_bytes)?;
525
526        self.replay.check_and_record(
527            frame.header.key_id,
528            frame.header.stream_id,
529            frame.header.seq,
530        )?;
531
532        let keys = self
533            .key_for_id(frame.header.key_id)
534            .ok_or(CoreError::UnexpectedKeyId {
535                expected: self.active_key_id,
536                actual: frame.header.key_id,
537            })?;
538        let plaintext = decrypt_frame_with_key(keys, self.inbound_direction, &frame)?;
539
540        Ok(Some(DecodedFrame {
541            header: frame.header,
542            plaintext,
543        }))
544    }
545
546    fn poll_fill_rx(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), CoreError>> {
547        let mut tmp = [0u8; 8192];
548        match Pin::new(&mut self.io).poll_read(cx, &mut tmp) {
549            Poll::Pending => Poll::Pending,
550            Poll::Ready(Ok(0)) => {
551                self.eof = true;
552                Poll::Ready(Ok(()))
553            }
554            Poll::Ready(Ok(n)) => {
555                self.rx.extend_from_slice(&tmp[..n]);
556                Poll::Ready(Ok(()))
557            }
558            Poll::Ready(Err(e)) => Poll::Ready(Err(CoreError::Io(e))),
559        }
560    }
561
562    fn poll_drain_tx(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), CoreError>> {
563        while !self.tx.is_empty() {
564            let n = ready!(Pin::new(&mut self.io).poll_write(cx, &self.tx))?;
565            if n == 0 {
566                return Poll::Ready(Err(CoreError::UnexpectedEof));
567            }
568            self.tx.advance(n);
569        }
570        Poll::Ready(Ok(()))
571    }
572}
573
574impl<T: PollIo + Unpin> Stream for FoctetFramed<T> {
575    type Item = Result<DecodedFrame, CoreError>;
576
577    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
578        let this = self.get_mut();
579
580        loop {
581            match this.try_decode() {
582                Ok(Some(frame)) => return Poll::Ready(Some(Ok(frame))),
583                Ok(None) => {}
584                Err(e) => return Poll::Ready(Some(Err(e))),
585            }
586
587            if this.eof {
588                if this.rx.is_empty() {
589                    return Poll::Ready(None);
590                }
591                return Poll::Ready(Some(Err(CoreError::UnexpectedEof)));
592            }
593
594            ready!(this.poll_fill_rx(cx))?;
595        }
596    }
597}
598
599impl<T: PollIo + Unpin> Sink<Vec<u8>> for FoctetFramed<T> {
600    type Error = CoreError;
601
602    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
603        let this = self.get_mut();
604        if this.tx.is_empty() {
605            return Poll::Ready(Ok(()));
606        }
607        this.poll_drain_tx(cx)
608    }
609
610    fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
611        let this = self.get_mut();
612        let active = this.active_keys()?.clone();
613        let frame = encrypt_frame(
614            &active,
615            this.outbound_direction,
616            this.default_flags,
617            this.default_stream_id,
618            this.next_seq,
619            &item,
620        )?;
621        this.next_seq = this
622            .next_seq
623            .checked_add(1)
624            .ok_or(CoreError::SequenceExhausted)?;
625        this.tx.extend_from_slice(&frame.to_bytes());
626        Ok(())
627    }
628
629    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
630        let this = self.get_mut();
631        ready!(this.poll_drain_tx(cx))?;
632        Pin::new(&mut this.io).poll_flush(cx).map_err(CoreError::Io)
633    }
634
635    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
636        let this = self.get_mut();
637        ready!(this.poll_drain_tx(cx))?;
638        ready!(Pin::new(&mut this.io).poll_flush(cx)).map_err(CoreError::Io)?;
639        Pin::new(&mut this.io).poll_close(cx).map_err(CoreError::Io)
640    }
641}
642
643/// Plain byte-stream convenience wrapper over [`FoctetFramed`].
644#[derive(Clone, Debug)]
645pub struct FoctetStream<T> {
646    framed: FoctetFramed<T>,
647    read_buf: BytesMut,
648    max_write_frame: usize,
649}
650
651impl<T> FoctetStream<T> {
652    /// Creates a stream wrapper from a framed transport.
653    pub fn new(framed: FoctetFramed<T>) -> Self {
654        Self {
655            framed,
656            read_buf: BytesMut::new(),
657            max_write_frame: 64 * 1024,
658        }
659    }
660
661    /// Sets max plaintext bytes packed per write call.
662    pub fn with_max_write_frame(mut self, max: usize) -> Self {
663        self.max_write_frame = max.max(1);
664        self
665    }
666
667    /// Returns inner framed transport.
668    pub fn into_inner(self) -> FoctetFramed<T> {
669        self.framed
670    }
671
672    /// Returns immutable reference to inner framed transport.
673    pub fn framed_ref(&self) -> &FoctetFramed<T> {
674        &self.framed
675    }
676
677    /// Returns mutable reference to inner framed transport.
678    pub fn framed_mut(&mut self) -> &mut FoctetFramed<T> {
679        &mut self.framed
680    }
681}
682
683impl<T: PollIo + Unpin> FoctetStream<T> {
684    /// Poll-based plaintext read from encrypted framed transport.
685    pub fn poll_read_plain(
686        self: Pin<&mut Self>,
687        cx: &mut Context<'_>,
688        out: &mut [u8],
689    ) -> Poll<Result<usize, CoreError>> {
690        let this = self.get_mut();
691
692        if !this.read_buf.is_empty() {
693            let n = out.len().min(this.read_buf.len());
694            out[..n].copy_from_slice(&this.read_buf.split_to(n));
695            return Poll::Ready(Ok(n));
696        }
697
698        match ready!(Pin::new(&mut this.framed).poll_next(cx)) {
699            Some(Ok(frame)) => {
700                this.read_buf.extend_from_slice(&frame.plaintext);
701                let n = out.len().min(this.read_buf.len());
702                out[..n].copy_from_slice(&this.read_buf.split_to(n));
703                Poll::Ready(Ok(n))
704            }
705            Some(Err(e)) => Poll::Ready(Err(e)),
706            None => Poll::Ready(Ok(0)),
707        }
708    }
709
710    /// Poll-based plaintext write into encrypted framed transport.
711    pub fn poll_write_plain(
712        self: Pin<&mut Self>,
713        cx: &mut Context<'_>,
714        buf: &[u8],
715    ) -> Poll<Result<usize, CoreError>> {
716        let this = self.get_mut();
717        let n = buf.len().min(this.max_write_frame);
718        if n == 0 {
719            return Poll::Ready(Ok(0));
720        }
721
722        ready!(Pin::new(&mut this.framed).poll_ready(cx))?;
723        Pin::new(&mut this.framed).start_send(buf[..n].to_vec())?;
724        Poll::Ready(Ok(n))
725    }
726
727    /// Poll-based flush for pending encrypted writes.
728    pub fn poll_flush_plain(
729        self: Pin<&mut Self>,
730        cx: &mut Context<'_>,
731    ) -> Poll<Result<(), CoreError>> {
732        Pin::new(&mut self.get_mut().framed).poll_flush(cx)
733    }
734
735    /// Poll-based close for encrypted transport.
736    pub fn poll_close_plain(
737        self: Pin<&mut Self>,
738        cx: &mut Context<'_>,
739    ) -> Poll<Result<(), CoreError>> {
740        Pin::new(&mut self.get_mut().framed).poll_close(cx)
741    }
742}
743
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        pin::Pin,
        task::{Context, Poll, Waker},
    };

    use futures_core::Stream;
    use futures_sink::Sink;

    use crate::{
        crypto::{Direction, EphemeralKeyPair, derive_traffic_keys, random_session_salt},
        io::{PollRead, PollWrite},
    };

    use super::{FoctetFramed, flags};

    /// In-memory backend: reads drain `inbound`, writes append to `outbound`.
    #[derive(Default, Debug)]
    struct MemoryIo {
        inbound: VecDeque<u8>,
        outbound: Vec<u8>,
    }

    impl MemoryIo {
        /// Queues bytes to be returned by later reads.
        fn push_inbound(&mut self, bytes: &[u8]) {
            self.inbound.extend(bytes);
        }
    }

    impl PollRead for MemoryIo {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            let take = buf.len().min(self.inbound.len());
            for (slot, byte) in buf.iter_mut().zip(self.inbound.drain(..take)) {
                *slot = byte;
            }
            Poll::Ready(Ok(take))
        }
    }

    impl PollWrite for MemoryIo {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            self.outbound.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Builds a transport whose inbound and outbound directions match, so it can
    /// decrypt the frames it wrote itself.
    fn loopback_framed() -> FoctetFramed<MemoryIo> {
        let eph_a = EphemeralKeyPair::generate();
        let eph_b = EphemeralKeyPair::generate();
        let ss = eph_a.shared_secret(eph_b.public).expect("shared secret");
        let salt = random_session_salt();
        let keys = derive_traffic_keys(&ss, &salt, 1).expect("traffic keys");
        FoctetFramed::new(MemoryIo::default(), keys, Direction::C2S, Direction::C2S)
    }

    #[test]
    fn framed_sink_stream_roundtrip() {
        let mut framed = loopback_framed()
            .with_stream_id(9)
            .with_default_flags(flags::IS_CONTROL);
        let mut cx = Context::from_waker(Waker::noop());

        Pin::new(&mut framed)
            .start_send(b"hello framed".to_vec())
            .expect("queue");
        let Poll::Ready(Ok(())) = Pin::new(&mut framed).poll_flush(&mut cx) else {
            panic!("flush failed");
        };

        // Loop everything the sink wrote back into the reader side.
        let written = framed.get_ref().outbound.clone();
        framed.get_mut().push_inbound(&written);

        let item = match Pin::new(&mut framed).poll_next(&mut cx) {
            Poll::Ready(Some(Ok(frame))) => frame,
            other => panic!("unexpected poll_next: {other:?}"),
        };
        assert_eq!(item.plaintext, b"hello framed");
        assert_eq!(item.header.stream_id, 9);
        assert_eq!(item.header.flags, flags::IS_CONTROL);
    }
}