1use std::{
2 pin::Pin,
3 task::{Context, Poll, ready},
4};
5
6use bytes::{Buf, BytesMut};
7use futures_core::Stream;
8use futures_sink::Sink;
9
10use crate::{
11 CoreError,
12 control::ControlMessage,
13 crypto::{Direction, TrafficKeys, decrypt_frame_with_key, encrypt_frame},
14 io::PollIo,
15 payload::{self, Tlv},
16 replay::{DEFAULT_REPLAY_WINDOW, ReplayProtector},
17 session::Session,
18};
19
/// Wire format version written by `FrameHeader::new` and required by
/// `FrameHeader::validate_v0`.
pub const WIRE_VERSION_V0: u8 = 0x00;

/// Cipher-suite profile identifier; the only profile `FrameHeader::validate_v0` accepts.
pub const PROFILE_X25519_HKDF_XCHACHA20POLY1305: u8 = 0x01;

/// Size in bytes of an encoded frame header: magic (2), version, flags,
/// profile id and key id (1 each), stream id (4), sequence number (8) and
/// ciphertext length (4).
pub const FRAME_HEADER_LEN: usize = 2 + 1 + 1 + 1 + 1 + 4 + 8 + 4;

/// Two-byte marker that opens every frame header.
pub const DRAFT_MAGIC: [u8; 2] = *b"\xF0\xC7";
29
/// Bit flags carried in the `flags` byte of a frame header.
pub mod flags {
    /// Flag bit 0 (`HAS_ROUTING`); not interpreted by the framing code itself.
    pub const HAS_ROUTING: u8 = 0b0000_0001;
    /// Flag bit 1: the frame's plaintext is an encoded control message.
    pub const IS_CONTROL: u8 = 0b0000_0010;
    /// Flag bit 2 (`ACK_REQUIRED`); not interpreted by the framing code itself.
    pub const ACK_REQUIRED: u8 = 0b0000_0100;
    /// Flag bit 3 (`PADDING`); not interpreted by the framing code itself.
    pub const PADDING: u8 = 0b0000_1000;
    /// Union of every defined bit; headers with any other bit set fail validation.
    pub const ALL_KNOWN_BITS: u8 = HAS_ROUTING | IS_CONTROL | ACK_REQUIRED | PADDING;
}
43
44#[derive(Clone, Debug, Eq, PartialEq)]
46pub struct FrameHeader {
47 pub magic: [u8; 2],
49 pub version: u8,
51 pub flags: u8,
53 pub profile_id: u8,
55 pub key_id: u8,
57 pub stream_id: u32,
59 pub seq: u64,
61 pub ct_len: u32,
63}
64
65impl FrameHeader {
66 pub fn new(
68 flags: u8,
69 profile_id: u8,
70 key_id: u8,
71 stream_id: u32,
72 seq: u64,
73 ct_len: u32,
74 ) -> Self {
75 Self {
76 magic: DRAFT_MAGIC,
77 version: WIRE_VERSION_V0,
78 flags,
79 profile_id,
80 key_id,
81 stream_id,
82 seq,
83 ct_len,
84 }
85 }
86
87 pub fn encode(&self) -> [u8; FRAME_HEADER_LEN] {
89 let mut out = [0u8; FRAME_HEADER_LEN];
90 out[0..2].copy_from_slice(&self.magic);
91 out[2] = self.version;
92 out[3] = self.flags;
93 out[4] = self.profile_id;
94 out[5] = self.key_id;
95 out[6..10].copy_from_slice(&self.stream_id.to_be_bytes());
96 out[10..18].copy_from_slice(&self.seq.to_be_bytes());
97 out[18..22].copy_from_slice(&self.ct_len.to_be_bytes());
98 out
99 }
100
101 pub fn decode(buf: &[u8]) -> Result<Self, CoreError> {
103 if buf.len() != FRAME_HEADER_LEN {
104 return Err(CoreError::InvalidHeaderLength(buf.len()));
105 }
106
107 let mut magic = [0u8; 2];
108 magic.copy_from_slice(&buf[0..2]);
109 let version = buf[2];
110 let flags = buf[3];
111 let profile_id = buf[4];
112 let key_id = buf[5];
113
114 let mut stream_id_bytes = [0u8; 4];
115 stream_id_bytes.copy_from_slice(&buf[6..10]);
116 let stream_id = u32::from_be_bytes(stream_id_bytes);
117
118 let mut seq_bytes = [0u8; 8];
119 seq_bytes.copy_from_slice(&buf[10..18]);
120 let seq = u64::from_be_bytes(seq_bytes);
121
122 let mut ct_len_bytes = [0u8; 4];
123 ct_len_bytes.copy_from_slice(&buf[18..22]);
124 let ct_len = u32::from_be_bytes(ct_len_bytes);
125
126 Ok(Self {
127 magic,
128 version,
129 flags,
130 profile_id,
131 key_id,
132 stream_id,
133 seq,
134 ct_len,
135 })
136 }
137
138 pub fn validate_v0(&self) -> Result<(), CoreError> {
140 if self.magic != DRAFT_MAGIC {
141 return Err(CoreError::InvalidMagic);
142 }
143 if self.version != WIRE_VERSION_V0 {
144 return Err(CoreError::UnsupportedVersion(self.version));
145 }
146 if self.profile_id != PROFILE_X25519_HKDF_XCHACHA20POLY1305 {
147 return Err(CoreError::UnsupportedProfile(self.profile_id));
148 }
149 if self.flags & !flags::ALL_KNOWN_BITS != 0 {
150 return Err(CoreError::UnknownFlags(self.flags));
151 }
152 Ok(())
153 }
154}
155
156#[derive(Clone, Debug, Eq, PartialEq)]
158pub struct Frame {
159 pub header: FrameHeader,
161 pub ciphertext: Vec<u8>,
163}
164
165impl Frame {
166 pub fn to_bytes(&self) -> Vec<u8> {
168 let mut out = Vec::with_capacity(FRAME_HEADER_LEN + self.ciphertext.len());
169 out.extend_from_slice(&self.header.encode());
170 out.extend_from_slice(&self.ciphertext);
171 out
172 }
173
174 pub fn from_bytes(buf: &[u8]) -> Result<Self, CoreError> {
176 if buf.len() < FRAME_HEADER_LEN {
177 return Err(CoreError::InvalidHeaderLength(buf.len()));
178 }
179 let header = FrameHeader::decode(&buf[..FRAME_HEADER_LEN])?;
180 let ciphertext = buf[FRAME_HEADER_LEN..].to_vec();
181 if ciphertext.len() != header.ct_len as usize {
182 return Err(CoreError::CiphertextLengthMismatch {
183 expected: header.ct_len as usize,
184 actual: ciphertext.len(),
185 });
186 }
187 Ok(Self { header, ciphertext })
188 }
189}
190
/// A received frame after decryption, as yielded by `FoctetFramed`'s `Stream` impl.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DecodedFrame {
    /// Header as decoded from the wire.
    pub header: FrameHeader,
    /// Decrypted payload bytes.
    pub plaintext: Vec<u8>,
}
199
200#[derive(Clone, Debug)]
202pub struct FoctetFramed<T> {
203 io: T,
204 keys: Vec<TrafficKeys>,
205 active_key_id: u8,
206 max_retained_keys: usize,
207 inbound_direction: Direction,
208 outbound_direction: Direction,
209 default_stream_id: u32,
210 default_flags: u8,
211 next_seq: u64,
212 max_ciphertext_len: usize,
213 rx: BytesMut,
214 tx: BytesMut,
215 replay: ReplayProtector,
216 eof: bool,
217}
218
219impl<T> FoctetFramed<T> {
220 pub fn new(
222 io: T,
223 keys: TrafficKeys,
224 inbound_direction: Direction,
225 outbound_direction: Direction,
226 ) -> Self {
227 Self {
228 io,
229 active_key_id: keys.key_id,
230 keys: vec![keys],
231 max_retained_keys: 2,
232 inbound_direction,
233 outbound_direction,
234 default_stream_id: 0,
235 default_flags: 0,
236 next_seq: 0,
237 max_ciphertext_len: 16 * 1024 * 1024,
238 rx: BytesMut::with_capacity(8 * 1024),
239 tx: BytesMut::new(),
240 replay: ReplayProtector::new(DEFAULT_REPLAY_WINDOW),
241 eof: false,
242 }
243 }
244
245 pub fn with_stream_id(mut self, stream_id: u32) -> Self {
247 self.default_stream_id = stream_id;
248 self
249 }
250
251 pub fn with_default_flags(mut self, flags: u8) -> Self {
253 self.default_flags = flags;
254 self
255 }
256
257 pub fn with_max_ciphertext_len(mut self, max_len: usize) -> Self {
259 self.max_ciphertext_len = max_len;
260 self
261 }
262
263 pub fn with_max_retained_keys(mut self, max: usize) -> Self {
265 self.max_retained_keys = max.max(1);
266 self
267 }
268
269 pub fn get_ref(&self) -> &T {
271 &self.io
272 }
273
274 pub fn get_mut(&mut self) -> &mut T {
276 &mut self.io
277 }
278
279 pub fn into_inner(self) -> T {
281 self.io
282 }
283
284 pub fn known_key_ids(&self) -> Vec<u8> {
286 self.keys.iter().map(|k| k.key_id).collect()
287 }
288
289 pub fn active_key_id(&self) -> u8 {
291 self.active_key_id
292 }
293
294 pub fn install_active_keys(&mut self, keys: TrafficKeys) {
296 self.keys.retain(|k| k.key_id != keys.key_id);
297 self.keys.insert(0, keys.clone());
298 self.active_key_id = keys.key_id;
299 let keep = self.max_retained_keys + 1;
300 if self.keys.len() > keep {
301 self.keys.truncate(keep);
302 }
303 }
304
305 fn active_keys(&self) -> Result<&TrafficKeys, CoreError> {
306 self.keys
307 .iter()
308 .find(|k| k.key_id == self.active_key_id)
309 .ok_or(CoreError::MissingSessionSecret)
310 }
311
312 fn key_for_id(&self, key_id: u8) -> Option<&TrafficKeys> {
313 self.keys.iter().find(|k| k.key_id == key_id)
314 }
315
316 fn set_key_ring_from_session(&mut self, session: &Session) -> Result<(), CoreError> {
317 let ring = session.key_ring()?;
318 self.keys = ring;
319 self.active_key_id = self
320 .keys
321 .first()
322 .map(|k| k.key_id)
323 .ok_or(CoreError::InvalidSessionState)?;
324 let keep = self.max_retained_keys + 1;
325 if self.keys.len() > keep {
326 self.keys.truncate(keep);
327 }
328 Ok(())
329 }
330
331 fn enqueue_with_specific_key(
332 &mut self,
333 key_id: u8,
334 flags: u8,
335 stream_id: u32,
336 plaintext: &[u8],
337 ) -> Result<(), CoreError> {
338 let keys = self
339 .key_for_id(key_id)
340 .ok_or(CoreError::UnexpectedKeyId {
341 expected: self.active_key_id,
342 actual: key_id,
343 })?
344 .clone();
345 let frame = encrypt_frame(
346 &keys,
347 self.outbound_direction,
348 flags,
349 stream_id,
350 self.next_seq,
351 plaintext,
352 )?;
353 self.next_seq = self
354 .next_seq
355 .checked_add(1)
356 .ok_or(CoreError::SequenceExhausted)?;
357 self.tx.extend_from_slice(&frame.to_bytes());
358 Ok(())
359 }
360}
361
362impl<T: PollIo + Unpin> FoctetFramed<T> {
363 pub fn poll_send_frame(
365 self: Pin<&mut Self>,
366 cx: &mut Context<'_>,
367 flags: u8,
368 stream_id: u32,
369 plaintext: &[u8],
370 ) -> Poll<Result<(), CoreError>> {
371 let this = self.get_mut();
372 ready!(Pin::new(&mut *this).poll_ready(cx))?;
373 Pin::new(&mut *this).start_send_with(flags, stream_id, plaintext)?;
374 Poll::Ready(Ok(()))
375 }
376
377 pub fn start_send_with(
379 self: Pin<&mut Self>,
380 flags: u8,
381 stream_id: u32,
382 plaintext: &[u8],
383 ) -> Result<(), CoreError> {
384 let this = self.get_mut();
385 let active = this.active_keys()?.clone();
386 let frame = encrypt_frame(
387 &active,
388 this.outbound_direction,
389 flags,
390 stream_id,
391 this.next_seq,
392 plaintext,
393 )?;
394 this.next_seq = this
395 .next_seq
396 .checked_add(1)
397 .ok_or(CoreError::SequenceExhausted)?;
398 this.tx.extend_from_slice(&frame.to_bytes());
399 Ok(())
400 }
401
402 pub fn start_send_control(
404 self: Pin<&mut Self>,
405 stream_id: u32,
406 msg: &ControlMessage,
407 ) -> Result<(), CoreError> {
408 self.start_send_with(flags::IS_CONTROL, stream_id, &msg.encode())
409 }
410
411 pub fn start_send_tlvs_with(
413 self: Pin<&mut Self>,
414 flags: u8,
415 stream_id: u32,
416 tlvs: &[Tlv],
417 ) -> Result<(), CoreError> {
418 let payload = payload::encode_tlvs(tlvs)?;
419 self.start_send_with(flags, stream_id, &payload)
420 }
421
422 pub fn start_send_control_with_key_id(
424 self: Pin<&mut Self>,
425 stream_id: u32,
426 key_id: u8,
427 msg: &ControlMessage,
428 ) -> Result<(), CoreError> {
429 let this = self.get_mut();
430 this.enqueue_with_specific_key(key_id, flags::IS_CONTROL, stream_id, &msg.encode())
431 }
432
433 pub fn decode_control(frame: &DecodedFrame) -> Result<ControlMessage, CoreError> {
435 if frame.header.flags & flags::IS_CONTROL == 0 {
436 return Err(CoreError::UnexpectedControlMessage);
437 }
438 ControlMessage::decode(&frame.plaintext)
439 }
440
441 pub fn decode_tlvs(frame: &DecodedFrame) -> Result<Vec<Tlv>, CoreError> {
443 payload::decode_tlvs(&frame.plaintext)
444 }
445
446 pub fn start_send_data_with_session(
448 self: Pin<&mut Self>,
449 session: &mut Session,
450 flags: u8,
451 stream_id: u32,
452 plaintext: &[u8],
453 ) -> Result<(), CoreError> {
454 let this = self.get_mut();
455 this.set_key_ring_from_session(session)?;
456 let app_tlv = Tlv::application_data(plaintext)?;
457 let app_payload = payload::encode_tlvs(&[app_tlv])?;
458 this.enqueue_with_specific_key(this.active_key_id, flags, stream_id, &app_payload)?;
459
460 if let Some(ctrl) = session.on_outbound_payload(plaintext.len())? {
461 let ctrl_bytes = ctrl.encode();
462 let rekey_old = match ctrl {
463 ControlMessage::Rekey { old_key_id, .. } => Some(old_key_id),
464 _ => None,
465 };
466 if let Some(old_key_id) = rekey_old {
467 this.enqueue_with_specific_key(old_key_id, flags::IS_CONTROL, 0, &ctrl_bytes)?;
468 this.set_key_ring_from_session(session)?;
469 } else {
470 this.enqueue_with_specific_key(
471 this.active_key_id,
472 flags::IS_CONTROL,
473 0,
474 &ctrl_bytes,
475 )?;
476 }
477 }
478 Ok(())
479 }
480
481 pub fn handle_incoming_with_session(
483 self: Pin<&mut Self>,
484 session: &mut Session,
485 frame: DecodedFrame,
486 ) -> Result<Option<Vec<u8>>, CoreError> {
487 let this = self.get_mut();
488 if frame.header.flags & flags::IS_CONTROL != 0 {
489 let msg = ControlMessage::decode(&frame.plaintext)?;
490 let response = session.handle_control(&msg)?;
491 this.set_key_ring_from_session(session)?;
492 if let Some(resp) = response {
493 this.enqueue_with_specific_key(
494 this.active_key_id,
495 flags::IS_CONTROL,
496 0,
497 &resp.encode(),
498 )?;
499 }
500 return Ok(None);
501 }
502 Ok(Some(frame.plaintext))
503 }
504
505 fn try_decode(&mut self) -> Result<Option<DecodedFrame>, CoreError> {
506 if self.rx.len() < FRAME_HEADER_LEN {
507 return Ok(None);
508 }
509
510 let header = FrameHeader::decode(&self.rx[..FRAME_HEADER_LEN])?;
511 header.validate_v0()?;
512
513 let ct_len = header.ct_len as usize;
514 if ct_len > self.max_ciphertext_len {
515 return Err(CoreError::FrameTooLarge);
516 }
517
518 let total = FRAME_HEADER_LEN + ct_len;
519 if self.rx.len() < total {
520 return Ok(None);
521 }
522
523 let frame_bytes = self.rx.split_to(total);
524 let frame = Frame::from_bytes(&frame_bytes)?;
525
526 self.replay.check_and_record(
527 frame.header.key_id,
528 frame.header.stream_id,
529 frame.header.seq,
530 )?;
531
532 let keys = self
533 .key_for_id(frame.header.key_id)
534 .ok_or(CoreError::UnexpectedKeyId {
535 expected: self.active_key_id,
536 actual: frame.header.key_id,
537 })?;
538 let plaintext = decrypt_frame_with_key(keys, self.inbound_direction, &frame)?;
539
540 Ok(Some(DecodedFrame {
541 header: frame.header,
542 plaintext,
543 }))
544 }
545
546 fn poll_fill_rx(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), CoreError>> {
547 let mut tmp = [0u8; 8192];
548 match Pin::new(&mut self.io).poll_read(cx, &mut tmp) {
549 Poll::Pending => Poll::Pending,
550 Poll::Ready(Ok(0)) => {
551 self.eof = true;
552 Poll::Ready(Ok(()))
553 }
554 Poll::Ready(Ok(n)) => {
555 self.rx.extend_from_slice(&tmp[..n]);
556 Poll::Ready(Ok(()))
557 }
558 Poll::Ready(Err(e)) => Poll::Ready(Err(CoreError::Io(e))),
559 }
560 }
561
562 fn poll_drain_tx(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), CoreError>> {
563 while !self.tx.is_empty() {
564 let n = ready!(Pin::new(&mut self.io).poll_write(cx, &self.tx))?;
565 if n == 0 {
566 return Poll::Ready(Err(CoreError::UnexpectedEof));
567 }
568 self.tx.advance(n);
569 }
570 Poll::Ready(Ok(()))
571 }
572}
573
574impl<T: PollIo + Unpin> Stream for FoctetFramed<T> {
575 type Item = Result<DecodedFrame, CoreError>;
576
577 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
578 let this = self.get_mut();
579
580 loop {
581 match this.try_decode() {
582 Ok(Some(frame)) => return Poll::Ready(Some(Ok(frame))),
583 Ok(None) => {}
584 Err(e) => return Poll::Ready(Some(Err(e))),
585 }
586
587 if this.eof {
588 if this.rx.is_empty() {
589 return Poll::Ready(None);
590 }
591 return Poll::Ready(Some(Err(CoreError::UnexpectedEof)));
592 }
593
594 ready!(this.poll_fill_rx(cx))?;
595 }
596 }
597}
598
599impl<T: PollIo + Unpin> Sink<Vec<u8>> for FoctetFramed<T> {
600 type Error = CoreError;
601
602 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
603 let this = self.get_mut();
604 if this.tx.is_empty() {
605 return Poll::Ready(Ok(()));
606 }
607 this.poll_drain_tx(cx)
608 }
609
610 fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
611 let this = self.get_mut();
612 let active = this.active_keys()?.clone();
613 let frame = encrypt_frame(
614 &active,
615 this.outbound_direction,
616 this.default_flags,
617 this.default_stream_id,
618 this.next_seq,
619 &item,
620 )?;
621 this.next_seq = this
622 .next_seq
623 .checked_add(1)
624 .ok_or(CoreError::SequenceExhausted)?;
625 this.tx.extend_from_slice(&frame.to_bytes());
626 Ok(())
627 }
628
629 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
630 let this = self.get_mut();
631 ready!(this.poll_drain_tx(cx))?;
632 Pin::new(&mut this.io).poll_flush(cx).map_err(CoreError::Io)
633 }
634
635 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
636 let this = self.get_mut();
637 ready!(this.poll_drain_tx(cx))?;
638 ready!(Pin::new(&mut this.io).poll_flush(cx)).map_err(CoreError::Io)?;
639 Pin::new(&mut this.io).poll_close(cx).map_err(CoreError::Io)
640 }
641}
642
643#[derive(Clone, Debug)]
645pub struct FoctetStream<T> {
646 framed: FoctetFramed<T>,
647 read_buf: BytesMut,
648 max_write_frame: usize,
649}
650
651impl<T> FoctetStream<T> {
652 pub fn new(framed: FoctetFramed<T>) -> Self {
654 Self {
655 framed,
656 read_buf: BytesMut::new(),
657 max_write_frame: 64 * 1024,
658 }
659 }
660
661 pub fn with_max_write_frame(mut self, max: usize) -> Self {
663 self.max_write_frame = max.max(1);
664 self
665 }
666
667 pub fn into_inner(self) -> FoctetFramed<T> {
669 self.framed
670 }
671
672 pub fn framed_ref(&self) -> &FoctetFramed<T> {
674 &self.framed
675 }
676
677 pub fn framed_mut(&mut self) -> &mut FoctetFramed<T> {
679 &mut self.framed
680 }
681}
682
683impl<T: PollIo + Unpin> FoctetStream<T> {
684 pub fn poll_read_plain(
686 self: Pin<&mut Self>,
687 cx: &mut Context<'_>,
688 out: &mut [u8],
689 ) -> Poll<Result<usize, CoreError>> {
690 let this = self.get_mut();
691
692 if !this.read_buf.is_empty() {
693 let n = out.len().min(this.read_buf.len());
694 out[..n].copy_from_slice(&this.read_buf.split_to(n));
695 return Poll::Ready(Ok(n));
696 }
697
698 match ready!(Pin::new(&mut this.framed).poll_next(cx)) {
699 Some(Ok(frame)) => {
700 this.read_buf.extend_from_slice(&frame.plaintext);
701 let n = out.len().min(this.read_buf.len());
702 out[..n].copy_from_slice(&this.read_buf.split_to(n));
703 Poll::Ready(Ok(n))
704 }
705 Some(Err(e)) => Poll::Ready(Err(e)),
706 None => Poll::Ready(Ok(0)),
707 }
708 }
709
710 pub fn poll_write_plain(
712 self: Pin<&mut Self>,
713 cx: &mut Context<'_>,
714 buf: &[u8],
715 ) -> Poll<Result<usize, CoreError>> {
716 let this = self.get_mut();
717 let n = buf.len().min(this.max_write_frame);
718 if n == 0 {
719 return Poll::Ready(Ok(0));
720 }
721
722 ready!(Pin::new(&mut this.framed).poll_ready(cx))?;
723 Pin::new(&mut this.framed).start_send(buf[..n].to_vec())?;
724 Poll::Ready(Ok(n))
725 }
726
727 pub fn poll_flush_plain(
729 self: Pin<&mut Self>,
730 cx: &mut Context<'_>,
731 ) -> Poll<Result<(), CoreError>> {
732 Pin::new(&mut self.get_mut().framed).poll_flush(cx)
733 }
734
735 pub fn poll_close_plain(
737 self: Pin<&mut Self>,
738 cx: &mut Context<'_>,
739 ) -> Poll<Result<(), CoreError>> {
740 Pin::new(&mut self.get_mut().framed).poll_close(cx)
741 }
742}
743
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        pin::Pin,
        task::{Context, Poll, Waker},
    };

    use futures_core::Stream;
    use futures_sink::Sink;

    use crate::{
        crypto::{Direction, EphemeralKeyPair, derive_traffic_keys, random_session_salt},
        io::{PollRead, PollWrite},
    };

    use super::{FoctetFramed, flags};

    /// Loopback transport: reads drain `inbound`, writes append to
    /// `outbound`, and every operation completes immediately.
    #[derive(Default, Debug)]
    struct MemoryIo {
        inbound: VecDeque<u8>,
        outbound: Vec<u8>,
    }

    impl MemoryIo {
        /// Makes `bytes` available to subsequent reads.
        fn push_inbound(&mut self, bytes: &[u8]) {
            self.inbound.extend(bytes);
        }
    }

    impl PollRead for MemoryIo {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            let count = self.inbound.len().min(buf.len());
            for (slot, byte) in buf.iter_mut().zip(self.inbound.drain(..count)) {
                *slot = byte;
            }
            Poll::Ready(Ok(count))
        }
    }

    impl PollWrite for MemoryIo {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            let written = buf.len();
            self.outbound.extend_from_slice(buf);
            Poll::Ready(Ok(written))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// A waker that does nothing; every poll in these tests completes synchronously.
    fn noop_waker() -> Waker {
        Waker::clone(Waker::noop())
    }

    #[test]
    fn framed_sink_stream_roundtrip() {
        let client = EphemeralKeyPair::generate();
        let server = EphemeralKeyPair::generate();
        let shared = client.shared_secret(server.public).expect("shared secret");
        let keys = derive_traffic_keys(&shared, &random_session_salt(), 1).expect("traffic keys");

        // Same direction both ways so the codec can decrypt its own output.
        let mut framed =
            FoctetFramed::new(MemoryIo::default(), keys, Direction::C2S, Direction::C2S)
                .with_stream_id(9)
                .with_default_flags(flags::IS_CONTROL);

        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        Pin::new(&mut framed)
            .start_send(b"hello framed".to_vec())
            .expect("queue");
        assert!(
            matches!(Pin::new(&mut framed).poll_flush(&mut cx), Poll::Ready(Ok(()))),
            "flush failed"
        );

        // Loop the written bytes back into the read side.
        let wire = framed.get_ref().outbound.clone();
        framed.get_mut().push_inbound(&wire);

        let decoded = match Pin::new(&mut framed).poll_next(&mut cx) {
            Poll::Ready(Some(Ok(frame))) => frame,
            other => panic!("unexpected poll_next: {other:?}"),
        };
        assert_eq!(decoded.plaintext, b"hello framed");
        assert_eq!(decoded.header.stream_id, 9);
        assert_eq!(decoded.header.flags, flags::IS_CONTROL);
    }
}