use crate::{
crypto::SigningPrivateKey,
error::StreamingError,
primitives::{Destination, DestinationId},
runtime::Runtime,
sam::protocol::streaming::packet::{Packet, PacketBuilder},
};
use rand::Rng;
use alloc::{collections::VecDeque, vec::Vec};
/// Logging target used by this module's `tracing` events.
const LOG_TARGET: &str = "emissary::streaming::pending";

/// Maximum number of payloads a pending stream buffers before the remote is reset.
const INITIAL_WINDOW_SIZE: usize = 6;
/// Action the caller must take after a pending stream has processed an inbound packet.
pub enum PendingStreamResult {
    /// The packet required no response.
    DoNothing,

    /// Send `packet` (an ACK for buffered data) to the remote destination.
    Send { packet: Vec<u8> },

    /// Send `packet` (a RESET) to the remote destination and then destroy the stream.
    SendAndDestroy { packet: Vec<u8> },

    /// Destroy the stream without sending anything, e.g. after the remote closed or reset it.
    Destroy,
}
/// Stream in the pending state: it buffers payloads received from the remote destination
/// and acknowledges them until the caller takes the stream over (lifecycle handled elsewhere).
pub struct PendingStream<R>
where
    R: Runtime,
{
    /// ID of `remote_destination`, cached at construction time.
    pub destination_id: DestinationId,

    /// Destination of the remote peer.
    pub remote_destination: Destination,

    /// Time at which the pending stream was created.
    pub established: R::Instant,

    /// Received payloads in the order they were accepted (including a non-empty SYN payload).
    pub packets: VecDeque<Vec<u8>>,

    /// Stream ID supplied to `new()`; written into outbound packets via `with_send_stream_id`.
    pub recv_stream_id: u32,

    /// Randomly generated stream ID used as the builder's ID for outbound packets.
    pub send_stream_id: u32,

    /// Sequence number of the most recently accepted data packet (starts at 0).
    pub seq_nro: u32,
}
impl<R: Runtime> PendingStream<R> {
    /// Create a new pending stream.
    ///
    /// Generates a random `send_stream_id`, buffers `syn_payload` (if non-empty) as the first
    /// received payload and builds a signed packet with the SYN flag set, sequence number 0
    /// and `destination` included as the sender.
    ///
    /// Returns the pending stream and the serialized SYN packet. The caller presumably
    /// transmits that packet to `remote_destination`.
    pub fn new(
        destination: &Destination,
        remote_destination: Destination,
        recv_stream_id: u32,
        syn_payload: Vec<u8>,
        signing_key: &SigningPrivateKey,
    ) -> (Self, Vec<u8>) {
        let send_stream_id = R::rng().next_u32();
        let packet = PacketBuilder::new(send_stream_id)
            .with_send_stream_id(recv_stream_id)
            .with_seq_nro(0)
            .with_from_included(destination)
            .with_synchronize()
            .with_signature()
            .build_and_sign(signing_key)
            .to_vec();

        // Data carried by the SYN is the first payload of the stream.
        let mut packets = VecDeque::new();
        if !syn_payload.is_empty() {
            packets.push_back(syn_payload);
        }

        (
            Self {
                destination_id: remote_destination.id(),
                remote_destination,
                established: R::now(),
                packets,
                recv_stream_id,
                send_stream_id,
                seq_nro: 0u32,
            },
            packet,
        )
    }

    /// Parse and process an inbound packet.
    ///
    /// Returns `Ok(Some(ack))` if the packet carried the next in-sequence payload, which was
    /// buffered. Returns `Ok(None)` if the packet is ignored: it has no payload, is a
    /// duplicate, or arrived out of order.
    ///
    /// # Errors
    ///
    /// * any error returned by `Packet::parse`,
    /// * [`StreamingError::Closed`] if the packet has the RESET or CLOSE flag set,
    /// * [`StreamingError::ReceiveWindowFull`] if `INITIAL_WINDOW_SIZE` payloads are already
    ///   buffered.
    fn on_packet_inner(&mut self, packet: Vec<u8>) -> Result<Option<Vec<u8>>, StreamingError> {
        let Packet {
            send_stream_id,
            recv_stream_id,
            seq_nro,
            flags,
            payload,
            ..
        } = Packet::parse::<R>(&packet)?;

        tracing::trace!(
            target: LOG_TARGET,
            remote = %self.destination_id,
            ?send_stream_id,
            ?recv_stream_id,
            payload_len = ?payload.len(),
            "inbound message",
        );

        if flags.reset() || flags.close() {
            return Err(StreamingError::Closed);
        }

        // Packets without data and packets at or below the last accepted sequence number
        // carry nothing new.
        if payload.is_empty() || seq_nro <= self.seq_nro {
            return Ok(None);
        }

        // Only accept the next packet in sequence. Accepting a packet that follows a gap
        // would make `ack_through` acknowledge the missing packets. Those packets would then
        // be rejected as duplicates by the check above when they arrive late, so their data
        // would be silently lost. Leaving the packet un-ACKed lets the remote retransmit it.
        //
        // `seq_nro > self.seq_nro` holds here, so `self.seq_nro + 1` cannot overflow.
        if seq_nro != self.seq_nro + 1 {
            tracing::debug!(
                target: LOG_TARGET,
                remote = %self.destination_id,
                expected = ?(self.seq_nro + 1),
                ?seq_nro,
                "out-of-order packet, dropping",
            );
            return Ok(None);
        }

        // `>=` rather than `==` because `packets` is public and could be filled externally.
        if self.packets.len() >= INITIAL_WINDOW_SIZE {
            return Err(StreamingError::ReceiveWindowFull);
        }

        self.packets.push_back(payload.to_vec());
        self.seq_nro = seq_nro;

        Ok(Some(
            PacketBuilder::new(self.send_stream_id)
                .with_send_stream_id(self.recv_stream_id)
                .with_ack_through(seq_nro)
                .build()
                .to_vec(),
        ))
    }

    /// Handle an inbound packet and tell the caller what to do with the stream.
    ///
    /// A full receive window causes a RESET to be sent and the stream to be destroyed. A
    /// remote CLOSE/RESET destroys the stream. Malformed or otherwise unacceptable packets
    /// are ignored.
    pub fn on_packet(&mut self, packet: Vec<u8>) -> PendingStreamResult {
        match self.on_packet_inner(packet) {
            Ok(None) => PendingStreamResult::DoNothing,
            Ok(Some(packet)) => PendingStreamResult::Send { packet },
            Err(StreamingError::ReceiveWindowFull) => PendingStreamResult::SendAndDestroy {
                packet: PacketBuilder::new(self.send_stream_id)
                    .with_send_stream_id(self.recv_stream_id)
                    .with_reset()
                    .build()
                    .to_vec(),
            },
            Err(StreamingError::Closed) => PendingStreamResult::Destroy,
            Err(StreamingError::Malformed(_)) => PendingStreamResult::DoNothing,
            // The input is remote-controlled, so never panic on an unexpected error variant
            // (e.g. one added to `Packet::parse` later). Ignore the packet instead.
            Err(_) => {
                tracing::debug!(
                    target: LOG_TARGET,
                    remote = %self.destination_id,
                    "ignoring packet that failed processing",
                );
                PendingStreamResult::DoNothing
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{mock::MockRuntime, noop::NoopRuntime};

    /// Build a pending stream with a fixed local destination, a random remote destination,
    /// remote stream ID 1337 and `syn_payload` as the data carried by the SYN.
    fn pending_stream(syn_payload: Vec<u8>) -> PendingStream<NoopRuntime> {
        let local_key = SigningPrivateKey::from_bytes(&[0u8; 32]).unwrap();
        let (stream, _syn) = PendingStream::<NoopRuntime>::new(
            &Destination::new::<NoopRuntime>(local_key.public()),
            Destination::random().0,
            1337u32,
            syn_payload,
            &SigningPrivateKey::random(NoopRuntime::rng()),
        );
        stream
    }

    #[test]
    fn ignore_duplicate_ack() {
        let mut stream = pending_stream(Vec::new());

        let ack = PacketBuilder::new(stream.send_stream_id)
            .with_send_stream_id(stream.recv_stream_id)
            .with_seq_nro(0u32)
            .build()
            .to_vec();

        assert!(matches!(stream.on_packet(ack), PendingStreamResult::DoNothing));
    }

    #[test]
    fn destroy_stream_on_close() {
        let mut stream = pending_stream(Vec::new());

        let close = PacketBuilder::new(stream.send_stream_id)
            .with_send_stream_id(stream.recv_stream_id)
            .with_seq_nro(0u32)
            .with_close()
            .build()
            .to_vec();

        assert!(matches!(stream.on_packet(close), PendingStreamResult::Destroy));
    }

    #[test]
    fn destroy_stream_on_reset() {
        let mut stream = pending_stream(Vec::new());

        let reset = PacketBuilder::new(stream.send_stream_id)
            .with_send_stream_id(stream.recv_stream_id)
            .with_seq_nro(0u32)
            .with_reset()
            .build()
            .to_vec();

        assert!(matches!(stream.on_packet(reset), PendingStreamResult::Destroy));
    }

    #[test]
    fn buffer_data_correctly() {
        let mut stream = pending_stream(Vec::new());

        for seq_nro in 1u32..=3 {
            let data = PacketBuilder::new(stream.send_stream_id)
                .with_send_stream_id(stream.recv_stream_id)
                .with_seq_nro(seq_nro)
                .with_payload(b"hello, world")
                .build()
                .to_vec();

            match stream.on_packet(data) {
                PendingStreamResult::Send { packet: ack } => {
                    let Packet { ack_through, .. } = Packet::parse::<MockRuntime>(&ack).unwrap();
                    assert_eq!(ack_through, seq_nro);
                }
                _ => panic!("invalid result"),
            }
        }

        assert_eq!(stream.packets.len(), 3);
        assert!(stream.packets.iter().all(|buffered| buffered == b"hello, world"));

        // A payload-less retransmission of the last sequence number must be ignored.
        let empty = PacketBuilder::new(stream.send_stream_id)
            .with_send_stream_id(stream.recv_stream_id)
            .with_seq_nro(3u32)
            .build()
            .to_vec();

        assert!(matches!(stream.on_packet(empty), PendingStreamResult::DoNothing));
    }

    #[test]
    fn ignore_invalid_packets() {
        let mut stream = pending_stream(Vec::new());

        assert!(matches!(
            stream.on_packet(vec![1, 2, 3, 4]),
            PendingStreamResult::DoNothing
        ));
    }

    #[test]
    fn receive_window_full() {
        let mut stream = pending_stream(Vec::new());
        let window = INITIAL_WINDOW_SIZE as u32;

        for seq_nro in 1..=window {
            let data = PacketBuilder::new(stream.send_stream_id)
                .with_send_stream_id(stream.recv_stream_id)
                .with_seq_nro(seq_nro)
                .with_payload(b"hello, world")
                .build()
                .to_vec();

            match stream.on_packet(data) {
                PendingStreamResult::Send { packet: ack } => {
                    let Packet { ack_through, .. } = Packet::parse::<MockRuntime>(&ack).unwrap();
                    assert_eq!(ack_through, seq_nro);
                }
                _ => panic!("invalid result"),
            }
        }

        assert_eq!(stream.packets.len(), INITIAL_WINDOW_SIZE);
        assert!(stream.packets.iter().all(|buffered| buffered == b"hello, world"));

        // One packet past the window must reset the stream.
        let overflow = PacketBuilder::new(stream.send_stream_id)
            .with_send_stream_id(stream.recv_stream_id)
            .with_seq_nro(window + 1)
            .with_payload(b"hello, world")
            .build()
            .to_vec();

        match stream.on_packet(overflow) {
            PendingStreamResult::SendAndDestroy { packet: reset } => {
                assert!(Packet::parse::<MockRuntime>(&reset).unwrap().flags.reset());
            }
            _ => panic!("invalid result"),
        }
    }

    #[test]
    fn syn_payload_not_empty() {
        let mut stream = pending_stream(vec![1, 2, 3, 4]);

        assert_eq!(stream.packets.pop_front(), Some(vec![1, 2, 3, 4]));
    }
}