use crate::{
crypto::chachapoly::ChaChaPoly,
error::Ssu2Error,
primitives::RouterId,
runtime::{Runtime, UdpSocket},
transport::{
ssu2::{
message::{data::DataMessageBuilder, Block, HeaderKind, HeaderReader},
session::KeyContext,
Packet,
},
TerminationReason,
},
};
use bytes::Bytes;
use futures::FutureExt;
use thingbuf::mpsc::Receiver;
use alloc::{collections::VecDeque, vec::Vec};
use core::{
future::Future,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
/// `tracing` target used by every log statement in this module.
const LOG_TARGET: &str = "emissary::ssu2::terminating";
/// Upper bound on how long a terminating session stays alive: once a timer of this
/// length fires, the session future resolves.
const TERMINATION_TIMEOUT: Duration = Duration::from_millis(60 * 1_000);
/// State handed over from an active SSU2 session that is needed to construct a
/// [`TerminatingSsu2Session`].
pub struct TerminationContext<R: Runtime> {
    /// Remote socket address the termination packet is sent to.
    pub address: SocketAddr,
    /// Destination connection ID written into the termination packet; also returned
    /// when the terminating session resolves.
    pub dst_id: u64,
    /// Not read by [`TerminatingSsu2Session::new`]; presumably how long the terminated
    /// session was active — TODO confirm against the producer of this context.
    pub duration: Duration,
    /// Intro key used both when building the termination packet and when reading the
    /// headers of inbound packets.
    pub intro_key: [u8; 32],
    /// If set, inbound packets are parsed with this key and a `SessionConfirmed` is
    /// answered with the termination packet; if unset, inbound packets are handled as
    /// `Data` packets.
    pub k_session_confirmed: Option<[u8; 32]>,
    /// Packet number used for the termination packet.
    pub next_pkt_num: u32,
    /// Reason carried in the termination block. With
    /// [`TerminationReason::TerminationReceived`] the packet is not sent proactively.
    pub reason: TerminationReason,
    /// Keys used to read headers of, and decrypt, inbound `Data` packets.
    pub recv_key_ctx: KeyContext,
    /// ID of the remote router; used in logs and returned when the session resolves.
    pub router_id: RouterId,
    /// Channel delivering inbound packets for this session.
    pub rx: Receiver<Packet>,
    /// Keys used to build the outbound termination packet.
    pub send_key_ctx: KeyContext,
    /// Socket the termination packet is written to.
    pub socket: R::UdpSocket,
}
/// An SSU2 session that is being torn down.
///
/// It sends a termination packet to the remote router, and resends it in response to
/// certain inbound packets. When polled as a future it resolves to `(RouterId, u64)`
/// once its `TERMINATION_TIMEOUT` timer fires, the inbound channel closes, or the socket's
/// send returns `None`.
pub struct TerminatingSsu2Session<R: Runtime> {
    /// Remote address packets are sent to.
    address: SocketAddr,
    /// Destination connection ID of the session; part of the future's output.
    dst_id: u64,
    /// Intro key used to read headers of inbound packets.
    intro_key: [u8; 32],
    /// Key for parsing inbound packets as `SessionConfirmed`, if any.
    k_session_confirmed: Option<[u8; 32]>,
    /// Pre-built termination packet, cloned whenever it must be (re)sent.
    pkt: Bytes,
    /// Keys used to read headers of, and decrypt, inbound `Data` packets.
    recv_key_ctx: KeyContext,
    /// ID of the remote router; part of the future's output.
    router_id: RouterId,
    /// Inbound packets for this session.
    rx: Receiver<Packet>,
    /// Socket packets are written to.
    socket: R::UdpSocket,
    /// Shutdown timer started in `new()`; the session resolves when it fires.
    timer: R::Timer,
    /// Packets waiting to be written to `socket`.
    write_buffer: VecDeque<Bytes>,
}
impl<R: Runtime> TerminatingSsu2Session<R> {
pub fn new(ctx: TerminationContext<R>) -> Self {
let pkt = DataMessageBuilder::default()
.with_dst_id(ctx.dst_id)
.with_pkt_num(ctx.next_pkt_num)
.with_key_context(ctx.intro_key, &ctx.send_key_ctx)
.with_termination(ctx.reason)
.build::<R>()
.freeze();
let write_buffer = if !core::matches!(ctx.reason, TerminationReason::TerminationReceived) {
VecDeque::from([pkt.clone()])
} else {
VecDeque::new()
};
Self {
address: ctx.address,
dst_id: ctx.dst_id,
intro_key: ctx.intro_key,
k_session_confirmed: ctx.k_session_confirmed,
pkt,
recv_key_ctx: ctx.recv_key_ctx,
router_id: ctx.router_id,
rx: ctx.rx,
socket: ctx.socket,
timer: R::timer(TERMINATION_TIMEOUT),
write_buffer,
}
}
fn on_packet(&mut self, mut pkt: Vec<u8>) -> Result<(), Ssu2Error> {
let pkt_num = match HeaderReader::new(self.intro_key, &mut pkt)?
.parse(self.recv_key_ctx.k_header_2)?
{
HeaderKind::Data { pkt_num, .. } => pkt_num,
kind => {
tracing::trace!(
target: LOG_TARGET,
router_id = %self.router_id,
dst_id = ?self.dst_id,
?kind,
"invalid message, expected `Data`",
);
return Err(Ssu2Error::UnexpectedMessage);
}
};
let mut payload = pkt[16..].to_vec();
ChaChaPoly::with_nonce(&self.recv_key_ctx.k_data, pkt_num as u64)
.decrypt_with_ad(&pkt[..16], &mut payload)?;
if !Block::parse::<R>(&payload)
.map_err(|error| {
tracing::warn!(
target: LOG_TARGET,
?error,
"failed to parse message block",
);
Ssu2Error::Malformed
})?
.iter()
.any(|message| core::matches!(message, Block::Termination { .. }))
{
self.write_buffer.push_back(self.pkt.clone());
}
Ok(())
}
}
impl<R: Runtime> Future for TerminatingSsu2Session<R> {
    /// ID of the remote router and destination connection ID of the terminated session.
    type Output = (RouterId, u64);

    /// Drive the terminating session.
    ///
    /// All currently available inbound packets are processed first, which may queue
    /// retransmissions of the termination packet. Queued packets are then flushed to the socket
    /// in FIFO order. Resolves to `(router_id, dst_id)` when the inbound channel closes, the
    /// socket's send returns `None`, or the termination timer fires.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            let Packet { mut pkt, .. } = match self.rx.poll_recv(cx) {
                Poll::Pending => break,
                Poll::Ready(None) => return Poll::Ready((self.router_id.clone(), self.dst_id)),
                Poll::Ready(Some(packet)) => packet,
            };

            match self.k_session_confirmed {
                // A `SessionConfirmed` key is set: only a `SessionConfirmed` triggers a resend.
                Some(key) => match HeaderReader::new(self.intro_key, &mut pkt) {
                    Ok(mut reader) => match reader.parse(key) {
                        Ok(HeaderKind::SessionConfirmed { .. }) => {
                            // Cloned into a local first: `self` is `Pin<&mut Self>`, so reading
                            // `self.pkt` inside the `push_back()` call would overlap the
                            // mutable borrow of `self.write_buffer`.
                            let pkt = self.pkt.clone();
                            self.write_buffer.push_back(pkt);
                        }
                        Ok(pkt) => tracing::debug!(
                            target: LOG_TARGET,
                            router_id = %self.router_id,
                            dst_id = ?self.dst_id,
                            ?pkt,
                            "unexpected packet, expected SessionConfirmed",
                        ),
                        Err(error) => tracing::debug!(
                            target: LOG_TARGET,
                            router_id = %self.router_id,
                            dst_id = ?self.dst_id,
                            ?error,
                            "failed to parse packet with key meant for SessionConfirmed",
                        ),
                    },
                    Err(error) => tracing::debug!(
                        target: LOG_TARGET,
                        router_id = %self.router_id,
                        dst_id = ?self.dst_id,
                        ?error,
                        "failed to read packet header",
                    ),
                },
                None =>
                    if let Err(error) = self.on_packet(pkt) {
                        tracing::debug!(
                            target: LOG_TARGET,
                            router_id = %self.router_id,
                            dst_id = ?self.dst_id,
                            ?error,
                            "failed to handle packet",
                        );
                    },
            }
        }

        // Flush queued packets oldest-first. On backpressure the packet goes back to the head
        // of the queue so it is the first one retried on the next poll.
        let address = self.address;
        while let Some(pkt) = self.write_buffer.pop_front() {
            match Pin::new(&mut self.socket).poll_send_to(cx, &pkt, address) {
                Poll::Pending => {
                    self.write_buffer.push_front(pkt);
                    break;
                }
                Poll::Ready(None) => return Poll::Ready((self.router_id.clone(), self.dst_id)),
                Poll::Ready(Some(_)) => {}
            }
        }

        if self.timer.poll_unpin(cx).is_ready() {
            tracing::trace!(
                target: LOG_TARGET,
                router_id = %self.router_id,
                dst_id = ?self.dst_id,
                "shutting down session",
            );
            return Poll::Ready((self.router_id.clone(), self.dst_id));
        }

        Poll::Pending
    }
}