#[cfg(not(feature = "std"))]
use alloc::format;
pub mod framing;
use crate::errors::CoreError;
use async_lock::Mutex;
use bytes::Bytes;
use embedded_io_async::{Error, Read, Write};
/// Length-prefixed frame transport over an `embedded_io_async` reader/writer pair.
///
/// The reader and writer halves sit behind independent async mutexes, so one
/// task can be sending a frame while another is receiving. `N` is the size of
/// the inline receive buffer and is passed to the framing helpers as the
/// frame-length limit for this leg.
pub struct EmbeddedLeg<R, W, const N: usize> {
    /// Reader half plus the fixed `N`-byte buffer that incoming payloads are
    /// read into before being copied out as `Bytes`.
    rx: Mutex<(R, [u8; N])>,
    /// Writer half; held for the whole header + payload + flush sequence of a
    /// send so frames from concurrent senders never interleave.
    tx: Mutex<W>,
}
impl<R, W, const N: usize> EmbeddedLeg<R, W, N> {
    /// Builds a leg from a reader and a writer, zero-filling the `N`-byte
    /// receive buffer.
    pub fn new(reader: R, writer: W) -> Self {
        let rx = Mutex::new((reader, [0u8; N]));
        let tx = Mutex::new(writer);
        Self { rx, tx }
    }

    /// Consumes the leg and returns the underlying `(reader, writer)` pair.
    ///
    /// The receive buffer is dropped along with the leg.
    pub fn into_inner(self) -> (R, W) {
        let Self { rx, tx } = self;
        let (reader, _scratch) = rx.into_inner();
        (reader, tx.into_inner())
    }
}
impl<R, W, const N: usize> EmbeddedLeg<R, W, N>
where
    R: Read,
    W: Write,
{
    /// Sends `data` as a single frame: the framing header, then the payload,
    /// then a flush of the writer.
    ///
    /// The writer lock is held across all three steps so concurrent callers
    /// cannot interleave their frames on the wire.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::NetworkError`] if the framing layer rejects
    /// `data.len()` for this leg's limit `N`, or if writing the header,
    /// writing the payload, or flushing fails.
    ///
    /// # Cancellation
    ///
    /// Not cancel-safe: dropping the future after the header has been written
    /// but before the payload completes leaves a partial frame on the wire.
    pub async fn send_frame(&self, data: &[u8]) -> Result<(), CoreError> {
        let header = framing::encode_header(data.len(), N)
            .map_err(|e| CoreError::NetworkError(format!("framing: {:?}", e)))?;
        let mut w = self.tx.lock().await;
        w.write_all(&header)
            .await
            .map_err(|e| CoreError::NetworkError(format!("write header: {:?}", e.kind())))?;
        w.write_all(data)
            .await
            .map_err(|e| CoreError::NetworkError(format!("write payload: {:?}", e.kind())))?;
        w.flush()
            .await
            .map_err(|e| CoreError::NetworkError(format!("flush: {:?}", e.kind())))?;
        Ok(())
    }

    /// Receives one frame and returns its payload as an owned [`Bytes`].
    ///
    /// Reads exactly `framing::HEADER_LEN` header bytes, decodes the payload
    /// length (bounded by `N`), then reads exactly that many payload bytes into
    /// the leg's internal buffer before copying them out.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::NetworkError`] when the header or payload read fails
    /// (including EOF before the expected byte count), or when the header
    /// describes a length the framing layer or the `N`-byte buffer cannot hold.
    ///
    /// # Cancellation
    ///
    /// Not cancel-safe: dropping the future mid-frame leaves the reader
    /// positioned inside that frame, desynchronising subsequent receives.
    pub async fn recv_frame(&self) -> Result<Bytes, CoreError> {
        let mut header = [0u8; framing::HEADER_LEN];
        let mut guard = self.rx.lock().await;
        let (r, buf) = &mut *guard;
        // Keep the underlying error (e.g. `UnexpectedEof`) in the message, matching
        // how `send_frame` reports its write failures.
        r.read_exact(&mut header)
            .await
            .map_err(|e| CoreError::NetworkError(format!("read header: {:?}", e)))?;
        let len = framing::decode_header(&header, N)
            .map_err(|e| CoreError::NetworkError(format!("framing: {:?}", e)))?;
        // `decode_header` is expected to enforce `len <= N`; re-check with a
        // non-panicking slice so a framing bug surfaces as an error, not a panic.
        let payload = buf.get_mut(..len).ok_or_else(|| {
            CoreError::NetworkError(format!(
                "framing: length {} exceeds receive buffer of {} bytes",
                len, N
            ))
        })?;
        r.read_exact(payload)
            .await
            .map_err(|e| CoreError::NetworkError(format!("read payload: {:?}", e)))?;
        Ok(Bytes::copy_from_slice(payload))
    }
}
/// Implements `SessionTransport` for one concrete
/// `EmbeddedLeg<$reader, $writer, $n>` instantiation.
///
/// The trait methods return `impl Future + Send`, so the `Send` requirement is
/// checked against the concrete reader/writer types at each expansion site.
/// Both methods simply delegate to the inherent `send_frame` / `recv_frame`.
///
/// The expansion names `bytes::Bytes`, so the invoking crate must be able to
/// resolve the `bytes` crate.
#[macro_export]
macro_rules! impl_embedded_session_transport {
    ($reader:ty, $writer:ty, $n:expr) => {
        impl $crate::transport::session_transport::SessionTransport
            for $crate::transport::legs::embedded::EmbeddedLeg<$reader, $writer, $n>
        {
            fn send_bytes(
                &self,
                data: &[u8],
            ) -> impl core::future::Future<Output = Result<(), $crate::errors::CoreError>> + Send
            {
                Self::send_frame(self, data)
            }

            fn recv_bytes(
                &self,
            ) -> impl core::future::Future<
                Output = Result<bytes::Bytes, $crate::errors::CoreError>,
            > + Send {
                Self::recv_frame(self)
            }
        }
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::session::{ConnectionState, PhantomSession};
    use crate::transport::handshake::{ClientHello, HandshakeResponse, HandshakeServer};
    use crate::transport::session_transport::SessionTransport;
    use crate::transport::types::{
        PacketFlags, PacketHeader, PhantomPacket, SessionId, StreamId as TransportStreamId,
    };
    use core::convert::Infallible;
    use std::collections::VecDeque;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::{Mutex as TokioMutex, Notify};

    crate::impl_embedded_session_transport!(MockReader, MockWriter, 1024);

    /// One direction of an in-memory byte pipe.
    struct Pipe {
        buf: VecDeque<u8>,
        /// When set, readers observe EOF (`Ok(0)`) once `buf` is drained.
        closed: bool,
    }

    /// Reading end of a `Pipe`; returns at most `max_read` bytes per call so
    /// tests can simulate fragmented delivery.
    struct MockReader {
        read_from: Arc<TokioMutex<Pipe>>,
        read_notify: Arc<Notify>,
        max_read: usize,
    }

    /// Writing end of a `Pipe`; wakes any waiting reader after each write.
    struct MockWriter {
        write_to: Arc<TokioMutex<Pipe>>,
        write_notify: Arc<Notify>,
    }

    /// Two cross-connected endpoints with unlimited per-read chunk size.
    fn duplex_pair() -> ((MockReader, MockWriter), (MockReader, MockWriter)) {
        duplex_pair_with_chunk(usize::MAX)
    }

    /// Two cross-connected endpoints: A's writer feeds B's reader and vice versa.
    fn duplex_pair_with_chunk(
        max_read: usize,
    ) -> ((MockReader, MockWriter), (MockReader, MockWriter)) {
        let ab = Arc::new(TokioMutex::new(Pipe {
            buf: VecDeque::new(),
            closed: false,
        }));
        let ba = Arc::new(TokioMutex::new(Pipe {
            buf: VecDeque::new(),
            closed: false,
        }));
        let n_ab = Arc::new(Notify::new());
        let n_ba = Arc::new(Notify::new());
        let a = (
            MockReader {
                read_from: ba.clone(),
                read_notify: n_ba.clone(),
                max_read,
            },
            MockWriter {
                write_to: ab.clone(),
                write_notify: n_ab.clone(),
            },
        );
        let b = (
            MockReader {
                read_from: ab,
                read_notify: n_ab,
                max_read,
            },
            MockWriter {
                write_to: ba,
                write_notify: n_ba,
            },
        );
        (a, b)
    }

    impl embedded_io_async::ErrorType for MockReader {
        type Error = Infallible;
    }

    impl embedded_io_async::ErrorType for MockWriter {
        type Error = Infallible;
    }

    impl Read for MockReader {
        async fn read(&mut self, out: &mut [u8]) -> Result<usize, Infallible> {
            if out.is_empty() {
                return Ok(0);
            }
            loop {
                // Register interest *before* inspecting the pipe so a
                // `notify_waiters` racing with the check is not lost.
                let notified = self.read_notify.notified();
                tokio::pin!(notified);
                notified.as_mut().enable();
                {
                    let mut p = self.read_from.lock().await;
                    if !p.buf.is_empty() {
                        let n = out.len().min(p.buf.len()).min(self.max_read);
                        for slot in out.iter_mut().take(n) {
                            *slot = p.buf.pop_front().expect("checked non-empty");
                        }
                        return Ok(n);
                    }
                    if p.closed {
                        return Ok(0);
                    }
                }
                notified.await;
            }
        }
    }

    impl Write for MockWriter {
        async fn write(&mut self, data: &[u8]) -> Result<usize, Infallible> {
            let mut p = self.write_to.lock().await;
            p.buf.extend(data.iter().copied());
            drop(p);
            self.write_notify.notify_waiters();
            Ok(data.len())
        }
    }

    #[tokio::test]
    async fn send_frame_writes_length_prefixed_payload() {
        let ((a_r, a_w), (mut b_r, _b_w)) = duplex_pair();
        let leg: EmbeddedLeg<MockReader, MockWriter, 1024> = EmbeddedLeg::new(a_r, a_w);
        leg.send_frame(b"hello").await.expect("send_frame");
        let mut buf = vec![0u8; 4 + 5];
        tokio::time::timeout(Duration::from_secs(1), b_r.read_exact(&mut buf))
            .await
            .expect("peer read should not hang")
            .expect("peer read_exact");
        assert_eq!(&buf[..4], &[0x00, 0x00, 0x00, 0x05], "length prefix");
        assert_eq!(&buf[4..], b"hello", "payload");
    }

    #[tokio::test]
    async fn recv_frame_reads_length_prefixed_payload() {
        let ((a_r, a_w), (_b_r, mut b_w)) = duplex_pair();
        let leg: EmbeddedLeg<MockReader, MockWriter, 1024> = EmbeddedLeg::new(a_r, a_w);
        b_w.write_all(&[0x00, 0x00, 0x00, 0x05]).await.unwrap();
        b_w.write_all(b"world").await.unwrap();
        let frame = tokio::time::timeout(Duration::from_secs(1), leg.recv_frame())
            .await
            .expect("recv should not hang")
            .expect("recv_frame");
        assert_eq!(&frame[..], b"world");
    }

    #[tokio::test]
    async fn recv_frame_reassembles_under_adversarial_chunking() {
        let ((a_r, a_w), (_b_r, mut b_w)) = duplex_pair_with_chunk(1);
        let leg: EmbeddedLeg<MockReader, MockWriter, 1024> = EmbeddedLeg::new(a_r, a_w);
        let writer = tokio::spawn(async move {
            for &b in &[0x00, 0x00, 0x00, 0x05] {
                b_w.write_all(&[b]).await.expect("write header byte");
                tokio::task::yield_now().await;
            }
            for &b in b"abcde" {
                b_w.write_all(&[b]).await.expect("write payload byte");
                tokio::task::yield_now().await;
            }
        });
        let frame = tokio::time::timeout(Duration::from_secs(1), leg.recv_frame())
            .await
            .expect("recv should not hang under 1-byte chunking")
            .expect("recv_frame");
        writer.await.expect("writer task");
        assert_eq!(&frame[..], b"abcde");
    }

    #[tokio::test]
    async fn recv_frame_rejects_oversized_header() {
        let ((a_r, a_w), (_b_r, mut b_w)) = duplex_pair();
        let leg: EmbeddedLeg<MockReader, MockWriter, 8> = EmbeddedLeg::new(a_r, a_w);
        b_w.write_all(&[0x00, 0x00, 0x00, 0x10])
            .await
            .expect("write bogus header");
        let err = tokio::time::timeout(Duration::from_secs(1), leg.recv_frame())
            .await
            .expect("recv should error fast, not hang on payload");
        match err {
            Err(CoreError::NetworkError(msg)) => {
                assert!(
                    msg.contains("framing"),
                    "expected framing error, got: {msg}"
                );
            }
            other => panic!("expected NetworkError(framing), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn recv_frame_returns_error_on_eof_mid_header() {
        let ((a_r, a_w), (_b_r, b_w)) = duplex_pair();
        let leg: EmbeddedLeg<MockReader, MockWriter, 1024> = EmbeddedLeg::new(a_r, a_w);
        let target = b_w.write_to.clone();
        let notify = b_w.write_notify.clone();
        {
            let mut p = target.lock().await;
            p.buf.extend([0x00u8, 0x00]);
            p.closed = true;
        }
        notify.notify_waiters();
        let err = tokio::time::timeout(Duration::from_secs(1), leg.recv_frame())
            .await
            .expect("recv should error fast on EOF");
        match err {
            Err(CoreError::NetworkError(msg)) => {
                assert!(
                    msg.contains("read header"),
                    "expected `read header` in error msg, got: {msg}"
                );
            }
            other => panic!("expected NetworkError(read header), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn send_recv_run_concurrently_without_blocking() {
        let ((a_r, a_w), (b_r, b_w)) = duplex_pair();
        let leg_a: Arc<EmbeddedLeg<MockReader, MockWriter, 1024>> =
            Arc::new(EmbeddedLeg::new(a_r, a_w));
        let leg_b: Arc<EmbeddedLeg<MockReader, MockWriter, 1024>> =
            Arc::new(EmbeddedLeg::new(b_r, b_w));
        let leg_a_send = Arc::clone(&leg_a);
        let send = tokio::spawn(async move { leg_a_send.send_frame(b"ping").await });
        let leg_b_recv = Arc::clone(&leg_b);
        let recv = tokio::spawn(async move { leg_b_recv.recv_frame().await });
        tokio::time::timeout(Duration::from_secs(1), async {
            send.await.expect("send task").expect("send_frame result");
            let frame = recv.await.expect("recv task").expect("recv_frame result");
            assert_eq!(&frame[..], b"ping");
        })
        .await
        .expect("concurrent send+recv should complete within 1s");
    }

    #[tokio::test]
    async fn session_transport_round_trip() {
        let ((a_r, a_w), (b_r, b_w)) = duplex_pair();
        let leg_a: EmbeddedLeg<MockReader, MockWriter, 1024> = EmbeddedLeg::new(a_r, a_w);
        let leg_b: EmbeddedLeg<MockReader, MockWriter, 1024> = EmbeddedLeg::new(b_r, b_w);
        tokio::time::timeout(Duration::from_secs(1), async {
            <EmbeddedLeg<MockReader, MockWriter, 1024> as SessionTransport>::send_bytes(
                &leg_a,
                b"hello-trait",
            )
            .await
            .expect("send_bytes");
            let frame =
                <EmbeddedLeg<MockReader, MockWriter, 1024> as SessionTransport>::recv_bytes(&leg_b)
                    .await
                    .expect("recv_bytes");
            assert_eq!(&frame[..], b"hello-trait");
        })
        .await
        .expect("trait round-trip should complete within 1s");
    }

    /// Parses a wire packet, asserts it is flagged ENCRYPTED, and decrypts it
    /// with the server-side session.
    fn decrypt_incoming_local(
        server_session: &crate::transport::session::Session,
        bytes: &[u8],
    ) -> Vec<u8> {
        let pkt = PhantomPacket::from_wire(bytes).expect("deserialize PhantomPacket");
        assert!(
            pkt.header.flags.contains(PacketFlags::ENCRYPTED),
            "expected ENCRYPTED flag on application data"
        );
        server_session
            .decrypt_packet(&pkt.header, &pkt.payload)
            .expect("decrypt application data")
    }

    /// Builds a RELIABLE|ENCRYPTED packet at the session's current epoch and
    /// serialises it to wire bytes.
    fn encrypt_outgoing_local(
        server_session: &crate::transport::session::Session,
        session_id: SessionId,
        stream_id: TransportStreamId,
        sequence: u32,
        payload: &[u8],
    ) -> Vec<u8> {
        let flag_bits = PacketFlags::RELIABLE | PacketFlags::ENCRYPTED;
        let header =
            PacketHeader::new(session_id, stream_id, sequence, PacketFlags::new(flag_bits))
                .with_epoch(server_session.current_epoch());
        let ct = server_session
            .encrypt_packet(&header, payload)
            .expect("encrypt reply");
        let packet = PhantomPacket::new(header, ct);
        packet.to_wire()
    }

    /// Writer that records every byte it forwards, so tests can inspect what
    /// actually hit the wire.
    struct TeeWriter {
        inner: MockWriter,
        recorder: Arc<TokioMutex<Vec<u8>>>,
    }

    impl embedded_io_async::ErrorType for TeeWriter {
        type Error = Infallible;
    }

    impl Write for TeeWriter {
        async fn write(&mut self, data: &[u8]) -> Result<usize, Infallible> {
            self.recorder.lock().await.extend_from_slice(data);
            self.inner.write(data).await
        }
    }

    crate::impl_embedded_session_transport!(MockReader, TeeWriter, 16384);
    crate::impl_embedded_session_transport!(MockReader, MockWriter, 16384);

    #[tokio::test]
    async fn test_phantom_session_handshake_via_embedded_leg() {
        let ((client_r, client_w_inner), (server_r, server_w)) = duplex_pair();
        let client_wire_recorder: Arc<TokioMutex<Vec<u8>>> = Arc::new(TokioMutex::new(Vec::new()));
        let client_w = TeeWriter {
            inner: client_w_inner,
            recorder: Arc::clone(&client_wire_recorder),
        };
        let client_leg: EmbeddedLeg<MockReader, TeeWriter, 16384> =
            EmbeddedLeg::new(client_r, client_w);
        let server_leg: EmbeddedLeg<MockReader, MockWriter, 16384> =
            EmbeddedLeg::new(server_r, server_w);
        let server_hs = HandshakeServer::new().expect("HandshakeServer::new");
        let server_pinned_key = server_hs.verifying_key().clone();
        let session = PhantomSession::connect_with_transport(
            "test-server:9000",
            client_leg,
            server_pinned_key,
        );
        session
            .send(b"early-data".to_vec())
            .await
            .expect("queue early-data");
        let server_handle = tokio::spawn(async move {
            let client_ip = "127.0.0.1".parse().expect("parse loopback IP");
            let client_hello_bytes =
                tokio::time::timeout(Duration::from_secs(5), server_leg.recv_frame())
                    .await
                    .expect("recv ClientHello within 5s")
                    .expect("recv ClientHello frame");
            let client_hello = borsh::from_slice::<ClientHello>(&client_hello_bytes)
                .expect("deserialize ClientHello");

            // The server may demand one HelloRetryRequest round-trip; anything
            // other than success after that retry fails the test.
            let response = match server_hs.process_client_hello(&client_hello, 0, client_ip) {
                HandshakeResponse::Retry(retry) => {
                    let retry_bytes =
                        borsh::to_vec(&retry).expect("serialize HelloRetryRequest");
                    tokio::time::timeout(
                        Duration::from_secs(5),
                        server_leg.send_frame(&retry_bytes),
                    )
                    .await
                    .expect("send retry within 5s")
                    .expect("send retry frame");
                    let next_bytes =
                        tokio::time::timeout(Duration::from_secs(5), server_leg.recv_frame())
                            .await
                            .expect("recv retried ClientHello within 5s")
                            .expect("recv retried ClientHello frame");
                    let next_hello = borsh::from_slice::<ClientHello>(&next_bytes)
                        .expect("deserialize retried ClientHello");
                    let resp = server_hs.process_client_hello(&next_hello, 0, client_ip);
                    assert!(
                        matches!(resp, HandshakeResponse::Success(..)),
                        "expected success after retry, got {resp:?}"
                    );
                    resp
                }
                other => other,
            };

            let server_session = match response {
                HandshakeResponse::Success(server_hello, session, _) => {
                    let server_hello_bytes =
                        borsh::to_vec(&server_hello).expect("serialize ServerHello");
                    tokio::time::timeout(
                        Duration::from_secs(5),
                        server_leg.send_frame(&server_hello_bytes),
                    )
                    .await
                    .expect("send ServerHello within 5s")
                    .expect("send ServerHello frame");
                    session
                }
                HandshakeResponse::Retry(_) => unreachable!("retry is handled above"),
                HandshakeResponse::Reject(r) => panic!("unexpected reject: {r:?}"),
                HandshakeResponse::Fail(e) => panic!("handshake failed: {e:?}"),
            };

            let session_id = *server_session.id();
            let early_frame = tokio::time::timeout(Duration::from_secs(5), server_leg.recv_frame())
                .await
                .expect("recv early-data within 5s")
                .expect("recv early-data frame");
            let early_plain = decrypt_incoming_local(&server_session, &early_frame);
            assert_eq!(early_plain, b"early-data");
            let post_frame = tokio::time::timeout(Duration::from_secs(5), server_leg.recv_frame())
                .await
                .expect("recv after-handshake within 5s")
                .expect("recv after-handshake frame");
            let post_plain = decrypt_incoming_local(&server_session, &post_frame);
            assert_eq!(post_plain, b"after-handshake");
            let reply = encrypt_outgoing_local(&server_session, session_id, 1, 1, b"server-reply");
            tokio::time::timeout(Duration::from_secs(5), server_leg.send_frame(&reply))
                .await
                .expect("send reply within 5s")
                .expect("send reply frame");
        });

        // Poll for the state transition instead of sleeping a fixed interval:
        // a fixed sleep is flaky on slow CI and wastes time on fast machines.
        tokio::time::timeout(Duration::from_secs(5), async {
            while session.connection_state() != ConnectionState::Connected {
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("session should reach Connected within 5s");

        session
            .send(b"after-handshake".to_vec())
            .await
            .expect("send after-handshake");
        let reply = tokio::time::timeout(Duration::from_secs(5), session.recv())
            .await
            .expect("recv reply within 5s")
            .expect("recv server-reply");
        assert_eq!(reply, b"server-reply");
        tokio::time::timeout(Duration::from_secs(5), server_handle)
            .await
            .expect("server task within 5s")
            .expect("server task joined");
        session.disconnect().await.expect("close session");

        // Nothing the client wrote should contain the application plaintext.
        let wire = client_wire_recorder.lock().await;
        assert!(
            !wire
                .windows(b"early-data".len())
                .any(|w| w == b"early-data"),
            "plaintext early-data leaked onto the embedded wire"
        );
        assert!(
            !wire
                .windows(b"after-handshake".len())
                .any(|w| w == b"after-handshake"),
            "plaintext after-handshake leaked onto the embedded wire"
        );
    }
}