1use std::{
2 io::{Read, Write},
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use crate::{
8 CoreError,
9 control::ControlMessage,
10 crypto::{Direction, TrafficKeys, decrypt_frame_with_key, encrypt_frame},
11 frame::{FRAME_HEADER_LEN, Frame, FrameHeader},
12 payload::{self, Tlv},
13 replay::{DEFAULT_REPLAY_WINDOW, ReplayProtector},
14 session::Session,
15};
16
17#[cfg(any(feature = "runtime-tokio", feature = "runtime-futures"))]
18use crate::frame::{FoctetFramed, FoctetStream};
19
/// Runtime-agnostic, poll-based byte source.
///
/// This is the read half of the transport abstraction used by the framed types; the
/// `TokioIo` / `FuturesIo` adapters in this module implement it for Tokio and `futures-io`
/// readers respectively.
pub trait PollRead {
    /// Attempts to read bytes into `buf`.
    ///
    /// Returns `Poll::Ready(Ok(n))` with the number of bytes placed at the front of `buf`,
    /// `Poll::Pending` when no data is available yet, or `Poll::Ready(Err(_))` on failure.
    /// NOTE(review): the adapters in this module forward the underlying runtime's semantics,
    /// where `Ok(0)` for a non-empty `buf` conventionally means end-of-stream — confirm all
    /// implementors follow that convention.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>>;
}
29
/// Runtime-agnostic, poll-based byte sink.
///
/// This is the write half of the transport abstraction; the `TokioIo` / `FuturesIo` adapters
/// in this module implement it for Tokio and `futures-io` writers respectively.
pub trait PollWrite {
    /// Attempts to write bytes from `buf`, returning how many were accepted.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>>;
    /// Attempts to flush any buffered data to the underlying destination.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>>;
    /// Attempts to close the writer (the Tokio adapter maps this to `poll_shutdown`,
    /// the futures adapter to `poll_close`).
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>>;
}
43
/// Shorthand bound for transports that are both [`PollRead`] and [`PollWrite`].
pub trait PollIo: PollRead + PollWrite {}

// Blanket impl: every `PollRead + PollWrite` type is automatically a `PollIo`.
impl<T: PollRead + PollWrite> PollIo for T {}
48
#[cfg(feature = "runtime-tokio")]
/// Adapter exposing a Tokio `AsyncRead`/`AsyncWrite` object through the runtime-agnostic
/// [`PollRead`]/[`PollWrite`] traits.
#[derive(Debug, Clone)]
pub struct TokioIo<T> {
    inner: T,
}

#[cfg(feature = "runtime-tokio")]
impl<T> TokioIo<T> {
    /// Wraps a Tokio I/O object.
    pub fn new(inner: T) -> Self {
        TokioIo { inner }
    }

    /// Consumes the adapter and hands back the wrapped Tokio I/O object.
    pub fn into_inner(self) -> T {
        let TokioIo { inner } = self;
        inner
    }
}
68
#[cfg(feature = "runtime-tokio")]
impl<T> PollRead for TokioIo<T>
where
    T: tokio::io::AsyncRead + Unpin,
{
    /// Bridges Tokio's `ReadBuf`-based read into a plain byte-slice read.
    ///
    /// The slice is wrapped in a `ReadBuf`, and on success the number of bytes the inner
    /// reader filled is reported back.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        let mut read_buf = tokio::io::ReadBuf::new(buf);
        Pin::new(&mut this.inner)
            .poll_read(cx, &mut read_buf)
            .map_ok(|()| read_buf.filled().len())
    }
}
87
#[cfg(feature = "runtime-tokio")]
impl<T> PollWrite for TokioIo<T>
where
    T: tokio::io::AsyncWrite + Unpin,
{
    /// Delegates to the wrapped writer's `AsyncWrite::poll_write`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_write(cx, buf)
    }

    /// Delegates to the wrapped writer's `AsyncWrite::poll_flush`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_flush(cx)
    }

    /// Tokio names the close operation `poll_shutdown`; this maps `poll_close` onto it.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_shutdown(cx)
    }
}
109
#[cfg(feature = "runtime-futures")]
/// Adapter exposing a `futures-io` `AsyncRead`/`AsyncWrite` object through the
/// runtime-agnostic [`PollRead`]/[`PollWrite`] traits.
#[derive(Debug, Clone)]
pub struct FuturesIo<T> {
    inner: T,
}

#[cfg(feature = "runtime-futures")]
impl<T> FuturesIo<T> {
    /// Wraps a `futures-io` I/O object.
    pub fn new(inner: T) -> Self {
        FuturesIo { inner }
    }

    /// Consumes the adapter and hands back the wrapped I/O object.
    pub fn into_inner(self) -> T {
        let FuturesIo { inner } = self;
        inner
    }
}
129
#[cfg(feature = "runtime-futures")]
impl<T> PollRead for FuturesIo<T>
where
    T: futures_io::AsyncRead + Unpin,
{
    /// Delegates directly to the wrapped reader; the `futures-io` signature already matches.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_read(cx, buf)
    }
}
143
#[cfg(feature = "runtime-futures")]
impl<T> PollWrite for FuturesIo<T>
where
    T: futures_io::AsyncWrite + Unpin,
{
    /// Delegates to the wrapped writer's `AsyncWrite::poll_write`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_write(cx, buf)
    }

    /// Delegates to the wrapped writer's `AsyncWrite::poll_flush`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_flush(cx)
    }

    /// Delegates to the wrapped writer's `AsyncWrite::poll_close`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_close(cx)
    }
}
165
#[cfg(feature = "runtime-tokio")]
impl<T> FoctetFramed<TokioIo<T>>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    /// Builds a framed transport directly from a Tokio I/O object, wrapping it in
    /// `TokioIo` first.
    pub fn from_tokio(
        io: T,
        keys: TrafficKeys,
        inbound_direction: Direction,
        outbound_direction: Direction,
    ) -> Self {
        let transport = TokioIo::new(io);
        Self::new(transport, keys, inbound_direction, outbound_direction)
    }
}
186
#[cfg(feature = "runtime-futures")]
impl<T> FoctetFramed<FuturesIo<T>>
where
    T: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin,
{
    /// Builds a framed transport directly from a `futures-io` object, wrapping it in
    /// `FuturesIo` first.
    pub fn from_futures(
        io: T,
        keys: TrafficKeys,
        inbound_direction: Direction,
        outbound_direction: Direction,
    ) -> Self {
        let transport = FuturesIo::new(io);
        Self::new(transport, keys, inbound_direction, outbound_direction)
    }
}
207
#[cfg(feature = "runtime-tokio")]
impl<T> FoctetStream<TokioIo<T>>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    /// Builds a plaintext byte stream over a Tokio I/O object by layering it on
    /// `FoctetFramed::from_tokio`.
    pub fn from_tokio(
        io: T,
        keys: TrafficKeys,
        inbound_direction: Direction,
        outbound_direction: Direction,
    ) -> Self {
        Self::new(FoctetFramed::from_tokio(
            io,
            keys,
            inbound_direction,
            outbound_direction,
        ))
    }
}
224
#[cfg(feature = "runtime-futures")]
impl<T> FoctetStream<FuturesIo<T>>
where
    T: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin,
{
    /// Builds a plaintext byte stream over a `futures-io` object by layering it on
    /// `FoctetFramed::from_futures`.
    pub fn from_futures(
        io: T,
        keys: TrafficKeys,
        inbound_direction: Direction,
        outbound_direction: Direction,
    ) -> Self {
        Self::new(FoctetFramed::from_futures(
            io,
            keys,
            inbound_direction,
            outbound_direction,
        ))
    }
}
241
#[cfg(feature = "runtime-tokio")]
impl<T> tokio::io::AsyncRead for FoctetStream<T>
where
    T: PollRead + PollWrite + Unpin,
{
    /// Reads decrypted plaintext into the unfilled part of `buf`.
    ///
    /// A full `buf` completes immediately without touching the stream. Stream errors are
    /// wrapped with `std::io::Error::other`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }
        let unfilled = buf.initialize_unfilled();
        Pin::new(&mut *self)
            .poll_read_plain(cx, unfilled)
            .map_ok(|read| buf.advance(read))
            .map_err(std::io::Error::other)
    }
}
266
#[cfg(feature = "runtime-tokio")]
impl<T> tokio::io::AsyncWrite for FoctetStream<T>
where
    T: PollRead + PollWrite + Unpin,
{
    /// Writes plaintext through the encrypted stream; errors are wrapped with
    /// `std::io::Error::other`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        Pin::new(&mut *self)
            .poll_write_plain(cx, buf)
            .map_err(std::io::Error::other)
    }

    /// Flushes the encrypted stream; errors are wrapped with `std::io::Error::other`.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut *self)
            .poll_flush_plain(cx)
            .map_err(std::io::Error::other)
    }

    /// Tokio's shutdown maps onto the stream's close; errors are wrapped with
    /// `std::io::Error::other`.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut *self)
            .poll_close_plain(cx)
            .map_err(std::io::Error::other)
    }
}
300
#[cfg(feature = "runtime-futures")]
impl<T> futures_io::AsyncRead for FoctetStream<T>
where
    T: PollRead + PollWrite + Unpin,
{
    /// Reads decrypted plaintext into `buf`, reporting the byte count; errors are
    /// wrapped with `std::io::Error::other`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        Pin::new(&mut *self)
            .poll_read_plain(cx, buf)
            .map_err(std::io::Error::other)
    }
}
318
#[cfg(feature = "runtime-futures")]
impl<T> futures_io::AsyncWrite for FoctetStream<T>
where
    T: PollRead + PollWrite + Unpin,
{
    /// Writes plaintext through the encrypted stream; errors are wrapped with
    /// `std::io::Error::other`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        Pin::new(&mut *self)
            .poll_write_plain(cx, buf)
            .map_err(std::io::Error::other)
    }

    /// Flushes the encrypted stream; errors are wrapped with `std::io::Error::other`.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut *self)
            .poll_flush_plain(cx)
            .map_err(std::io::Error::other)
    }

    /// Closes the encrypted stream; errors are wrapped with `std::io::Error::other`.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut *self)
            .poll_close_plain(cx)
            .map_err(std::io::Error::other)
    }
}
352
#[derive(Debug)]
/// Blocking, encrypted, framed channel over a transport `T`.
///
/// Sending and receiving require `T: std::io::Read + std::io::Write`; construction and key
/// management do not.
// NOTE(review): the derived `Debug` includes `keys`. Whether secret key bytes can end up in
// logs depends on `TrafficKeys`' own `Debug` impl — confirm it redacts key material.
pub struct SyncIo<T> {
    // Underlying blocking transport.
    io: T,
    // Key ring. After `install_active_keys` / `set_key_ring_from_session` the active key is
    // at index 0; additional entries let inbound frames sealed under other known key ids
    // still be decrypted (inbound lookup is by the frame's `key_id`).
    keys: Vec<TrafficKeys>,
    // Key id used to seal outbound frames.
    active_key_id: u8,
    // How many keys besides the active one the ring may hold (ring is truncated to this + 1).
    max_retained_keys: usize,
    // Direction passed to decryption of inbound frames.
    inbound_direction: Direction,
    // Direction passed to encryption of outbound frames.
    outbound_direction: Direction,
    // Stream id used by `send`.
    default_stream_id: u32,
    // Frame flags used by `send`.
    default_flags: u8,
    // Sequence number for the next outbound frame; a single counter shared by every stream
    // id and key.
    next_seq: u64,
    // Inbound frames declaring a larger ciphertext are rejected before the body is read.
    max_ciphertext_len: usize,
    // Replay window keyed by inbound `(key_id, stream_id, seq)`.
    replay: ReplayProtector,
}
368
369impl<T> SyncIo<T> {
370 pub fn new(
372 io: T,
373 keys: TrafficKeys,
374 inbound_direction: Direction,
375 outbound_direction: Direction,
376 ) -> Self {
377 Self {
378 io,
379 active_key_id: keys.key_id,
380 keys: vec![keys],
381 max_retained_keys: 2,
382 inbound_direction,
383 outbound_direction,
384 default_stream_id: 0,
385 default_flags: 0,
386 next_seq: 0,
387 max_ciphertext_len: 16 * 1024 * 1024,
388 replay: ReplayProtector::new(DEFAULT_REPLAY_WINDOW),
389 }
390 }
391
392 pub fn with_stream_id(mut self, stream_id: u32) -> Self {
394 self.default_stream_id = stream_id;
395 self
396 }
397
398 pub fn with_default_flags(mut self, flags: u8) -> Self {
400 self.default_flags = flags;
401 self
402 }
403
404 pub fn with_max_ciphertext_len(mut self, max_len: usize) -> Self {
406 self.max_ciphertext_len = max_len;
407 self
408 }
409
410 pub fn with_max_retained_keys(mut self, max: usize) -> Self {
412 self.max_retained_keys = max.max(1);
413 self
414 }
415
416 pub fn active_key_id(&self) -> u8 {
418 self.active_key_id
419 }
420
421 pub fn known_key_ids(&self) -> Vec<u8> {
423 self.keys.iter().map(|k| k.key_id).collect()
424 }
425
426 pub fn install_active_keys(&mut self, keys: TrafficKeys) {
428 self.keys.retain(|k| k.key_id != keys.key_id);
429 self.keys.insert(0, keys.clone());
430 self.active_key_id = keys.key_id;
431 let keep = self.max_retained_keys + 1;
432 if self.keys.len() > keep {
433 self.keys.truncate(keep);
434 }
435 }
436
437 pub fn into_inner(self) -> T {
439 self.io
440 }
441
442 fn active_keys(&self) -> Result<&TrafficKeys, CoreError> {
443 self.keys
444 .iter()
445 .find(|k| k.key_id == self.active_key_id)
446 .ok_or(CoreError::MissingSessionSecret)
447 }
448
449 fn key_for_id(&self, key_id: u8) -> Option<&TrafficKeys> {
450 self.keys.iter().find(|k| k.key_id == key_id)
451 }
452
453 fn set_key_ring_from_session(&mut self, session: &Session) -> Result<(), CoreError> {
454 let ring = session.key_ring()?;
455 self.keys = ring;
456 self.active_key_id = self
457 .keys
458 .first()
459 .map(|k| k.key_id)
460 .ok_or(CoreError::InvalidSessionState)?;
461 let keep = self.max_retained_keys + 1;
462 if self.keys.len() > keep {
463 self.keys.truncate(keep);
464 }
465 Ok(())
466 }
467}
468
469impl<T: Read + Write> SyncIo<T> {
470 fn send_with_key(
471 &mut self,
472 keys: &TrafficKeys,
473 flags: u8,
474 stream_id: u32,
475 plaintext: &[u8],
476 ) -> Result<(), CoreError> {
477 let frame = encrypt_frame(
478 keys,
479 self.outbound_direction,
480 flags,
481 stream_id,
482 self.next_seq,
483 plaintext,
484 )?;
485 self.next_seq = self.next_seq.wrapping_add(1);
486 self.io.write_all(&frame.to_bytes())?;
487 self.io.flush()?;
488 Ok(())
489 }
490
491 pub fn send(&mut self, plaintext: &[u8]) -> Result<(), CoreError> {
493 self.send_with(self.default_flags, self.default_stream_id, plaintext)
494 }
495
496 pub fn send_with(
498 &mut self,
499 flags: u8,
500 stream_id: u32,
501 plaintext: &[u8],
502 ) -> Result<(), CoreError> {
503 let active = self.active_keys()?.clone();
504 self.send_with_key(&active, flags, stream_id, plaintext)
505 }
506
507 pub fn send_tlvs_with(
509 &mut self,
510 flags: u8,
511 stream_id: u32,
512 tlvs: &[Tlv],
513 ) -> Result<(), CoreError> {
514 let payload = payload::encode_tlvs(tlvs)?;
515 self.send_with(flags, stream_id, &payload)
516 }
517
518 pub fn recv(&mut self) -> Result<Vec<u8>, CoreError> {
520 let mut header_buf = [0u8; FRAME_HEADER_LEN];
521 self.io.read_exact(&mut header_buf)?;
522 let header = FrameHeader::decode(&header_buf)?;
523 header.validate_v0()?;
524
525 let ct_len = header.ct_len as usize;
526 if ct_len > self.max_ciphertext_len {
527 return Err(CoreError::FrameTooLarge);
528 }
529
530 let mut ciphertext = vec![0u8; ct_len];
531 self.io.read_exact(&mut ciphertext)?;
532
533 self.replay
534 .check_and_record(header.key_id, header.stream_id, header.seq)?;
535
536 let keys = self
537 .key_for_id(header.key_id)
538 .ok_or(CoreError::UnexpectedKeyId {
539 expected: self.active_key_id,
540 actual: header.key_id,
541 })?;
542
543 let frame = Frame { header, ciphertext };
544 decrypt_frame_with_key(keys, self.inbound_direction, &frame)
545 }
546
547 pub fn send_control(&mut self, stream_id: u32, msg: &ControlMessage) -> Result<(), CoreError> {
549 self.send_with(crate::frame::flags::IS_CONTROL, stream_id, &msg.encode())
550 }
551
552 pub fn send_control_with_key_id(
554 &mut self,
555 stream_id: u32,
556 key_id: u8,
557 msg: &ControlMessage,
558 ) -> Result<(), CoreError> {
559 let key = self
560 .key_for_id(key_id)
561 .ok_or(CoreError::UnexpectedKeyId {
562 expected: self.active_key_id,
563 actual: key_id,
564 })?
565 .clone();
566 self.send_with_key(
567 &key,
568 crate::frame::flags::IS_CONTROL,
569 stream_id,
570 &msg.encode(),
571 )
572 }
573
574 pub fn recv_control(&mut self) -> Result<ControlMessage, CoreError> {
576 let plaintext = self.recv()?;
577 ControlMessage::decode(&plaintext)
578 }
579
580 pub fn recv_tlvs(&mut self) -> Result<Vec<Tlv>, CoreError> {
582 let plaintext = self.recv()?;
583 payload::decode_tlvs(&plaintext)
584 }
585
586 pub fn send_data_with_session(
588 &mut self,
589 session: &mut Session,
590 flags: u8,
591 stream_id: u32,
592 plaintext: &[u8],
593 ) -> Result<(), CoreError> {
594 self.set_key_ring_from_session(session)?;
595 let app_tlv = Tlv::application_data(plaintext)?;
596 self.send_tlvs_with(flags, stream_id, &[app_tlv])?;
597
598 if let Some(ctrl) = session.on_outbound_payload(plaintext.len())? {
599 let rekey_old = match &ctrl {
600 ControlMessage::Rekey { old_key_id, .. } => Some(*old_key_id),
601 _ => None,
602 };
603 if let Some(old_key_id) = rekey_old {
604 self.send_control_with_key_id(0, old_key_id, &ctrl)?;
605 self.set_key_ring_from_session(session)?;
606 } else {
607 self.send_control(0, &ctrl)?;
608 }
609 }
610 Ok(())
611 }
612
613 pub fn recv_application_with_session(
615 &mut self,
616 session: &mut Session,
617 ) -> Result<Option<Vec<u8>>, CoreError> {
618 let mut header_buf = [0u8; FRAME_HEADER_LEN];
619 self.io.read_exact(&mut header_buf)?;
620 let header = FrameHeader::decode(&header_buf)?;
621 header.validate_v0()?;
622
623 let ct_len = header.ct_len as usize;
624 if ct_len > self.max_ciphertext_len {
625 return Err(CoreError::FrameTooLarge);
626 }
627
628 let mut ciphertext = vec![0u8; ct_len];
629 self.io.read_exact(&mut ciphertext)?;
630
631 self.replay
632 .check_and_record(header.key_id, header.stream_id, header.seq)?;
633
634 let keys = self
635 .key_for_id(header.key_id)
636 .ok_or(CoreError::UnexpectedKeyId {
637 expected: self.active_key_id,
638 actual: header.key_id,
639 })?;
640
641 let frame = Frame { header, ciphertext };
642 let plaintext = decrypt_frame_with_key(keys, self.inbound_direction, &frame)?;
643
644 if frame.header.flags & crate::frame::flags::IS_CONTROL != 0 {
645 let msg = ControlMessage::decode(&plaintext)?;
646 let response = session.handle_control(&msg)?;
647 self.set_key_ring_from_session(session)?;
648 if let Some(resp) = response {
649 self.send_control(0, &resp)?;
650 }
651 return Ok(None);
652 }
653
654 Ok(Some(plaintext))
655 }
656}
657
658impl From<CoreError> for std::io::Error {
659 fn from(value: CoreError) -> Self {
660 std::io::Error::other(value)
661 }
662}