use crate::{
crypto::{chachapoly::ChaChaPoly, SigningPublicKey},
error::Ssu2Error,
events::EventHandle,
i2np::Message,
primitives::{RouterId, RouterInfo},
router::context::RouterContext,
runtime::{Counter, Instant, MetricsHandle, Runtime, UdpSocket},
subsystem::{OutboundMessage, OutboundMessageRecycle, SubsystemEvent},
transport::{
ssu2::{
duplicate::DuplicateFilter,
message::{Block, HeaderKind, HeaderReader},
metrics::*,
peer_test::types::PeerTestHandle,
relay::types::RelayHandle,
session::{
active::{
ack::RemoteAckManager, fragment::FragmentHandler,
path_validation::PathValidationState, transmission::TransmissionManager,
},
terminating::TerminationContext,
KeyContext,
},
Packet,
},
TerminationReason,
},
};
use bytes::BytesMut;
use futures::{FutureExt, StreamExt};
use thingbuf::mpsc::{with_recycle, Receiver, Sender};
use alloc::{boxed::Box, collections::VecDeque, sync::Arc, vec};
use core::{
cmp::min,
future::Future,
net::SocketAddr,
pin::Pin,
sync::atomic::AtomicU32,
task::{Context, Poll},
time::Duration,
};
mod ack;
mod fragment;
mod path_validation;
mod peer_test;
mod relay;
mod transmission;
const LOG_TARGET: &str = "emissary::ssu2::active";
const CMD_CHANNEL_SIZE: usize = 512;
const SSU2_RESEND_TIMEOUT: Duration = Duration::from_millis(40);
const MAX_IMMEDIATE_ACK_TIMEOUT: Duration = Duration::from_millis(5);
const MAX_ACK_TIMEOUT: Duration = Duration::from_millis(150);
struct AckTimer<R: Runtime> {
immediate: Option<R::Timer>,
normal: Option<R::Timer>,
}
impl<R: Runtime> AckTimer<R> {
fn new() -> Self {
Self {
immediate: None,
normal: None,
}
}
fn schedule_immediate_ack(&mut self, rtt: Duration) {
if self.immediate.is_none() {
self.immediate = Some(R::timer(min(rtt / 16, MAX_IMMEDIATE_ACK_TIMEOUT)));
}
}
fn schedule_ack(&mut self, rtt: Duration) {
if self.immediate.is_none() && self.normal.is_none() {
self.normal = Some(R::timer(min(rtt / 6, MAX_ACK_TIMEOUT)));
}
}
}
impl<R: Runtime> Future for AckTimer<R> {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Some(timer) = &mut self.immediate {
if timer.poll_unpin(cx).is_ready() {
self.immediate = None;
self.normal = None;
return Poll::Ready(());
}
}
if let Some(timer) = &mut self.normal {
if timer.poll_unpin(cx).is_ready() {
self.immediate = None;
self.normal = None;
return Poll::Ready(());
}
}
Poll::Pending
}
}
pub struct Ssu2SessionContext {
pub address: SocketAddr,
pub dst_id: u64,
pub intro_key: [u8; 32],
pub max_payload_size: usize,
pub pkt_rx: Receiver<Packet>,
pub recv_key_ctx: KeyContext,
pub router_id: RouterId,
pub send_key_ctx: KeyContext,
pub verifying_key: SigningPublicKey,
}
pub struct Ssu2Session<R: Runtime> {
ack_timer: AckTimer<R>,
address: SocketAddr,
dst_id: u64,
duplicate_filter: DuplicateFilter<R>,
event_handle: EventHandle<R>,
fragment_handler: FragmentHandler<R>,
inbound_bandwidth: usize,
intro_key: [u8; 32],
msg_rx: Receiver<OutboundMessage, OutboundMessageRecycle>,
msg_tx: Sender<OutboundMessage, OutboundMessageRecycle>,
outbound_bandwidth: usize,
path_validation_state: PathValidationState<R>,
peer_test_handle: PeerTestHandle<R>,
pending_router_info: Option<Box<RouterInfo>>,
pkt_rx: Receiver<Packet>,
recv_key_ctx: KeyContext,
relay_handle: RelayHandle<R>,
remote_ack: RemoteAckManager,
resend_timer: Option<R::Timer>,
router_ctx: RouterContext<R>,
router_id: RouterId,
send_key_ctx: KeyContext,
socket: R::UdpSocket,
started: R::Instant,
transmission: TransmissionManager<R>,
transport_tx: Sender<SubsystemEvent>,
verifying_key: SigningPublicKey,
write_buffer: VecDeque<(BytesMut, Option<SocketAddr>)>,
}
impl<R: Runtime> Ssu2Session<R> {
pub fn new(
context: Ssu2SessionContext,
socket: R::UdpSocket,
transport_tx: Sender<SubsystemEvent>,
router_ctx: RouterContext<R>,
peer_test_handle: PeerTestHandle<R>,
relay_handle: RelayHandle<R>,
) -> Self {
let (msg_tx, msg_rx) = with_recycle(CMD_CHANNEL_SIZE, OutboundMessageRecycle::default());
let metrics = router_ctx.metrics_handle().clone();
let pkt_num = Arc::new(AtomicU32::new(1u32));
let event_handle = router_ctx.event_handle().clone();
tracing::debug!(
target: LOG_TARGET,
dst_id = ?context.dst_id,
address = ?context.address,
"starting active session",
);
Self {
ack_timer: AckTimer::<R>::new(),
address: context.address,
dst_id: context.dst_id,
duplicate_filter: DuplicateFilter::new(),
event_handle,
fragment_handler: FragmentHandler::<R>::new(metrics.clone()),
inbound_bandwidth: 0usize,
intro_key: context.intro_key,
msg_rx,
msg_tx,
path_validation_state: PathValidationState::default(),
outbound_bandwidth: 0usize,
peer_test_handle,
pending_router_info: None,
pkt_rx: context.pkt_rx,
recv_key_ctx: context.recv_key_ctx,
relay_handle,
remote_ack: RemoteAckManager::new(),
resend_timer: None,
router_ctx,
router_id: context.router_id.clone(),
send_key_ctx: context.send_key_ctx.clone(),
started: R::now(),
transmission: TransmissionManager::<R>::new(
context.dst_id,
context.router_id,
context.intro_key,
context.send_key_ctx,
pkt_num,
metrics,
context.max_payload_size,
),
socket,
transport_tx,
verifying_key: context.verifying_key,
write_buffer: VecDeque::new(),
}
}
fn handle_message(&mut self, message: Message) {
if message.is_expired::<R>() {
tracing::trace!(
target: LOG_TARGET,
router_id = %self.router_id,
message_type = ?message.message_type,
message_id = ?message.message_id,
expiration = ?message.expiration,
"discarding expired message",
);
self.router_ctx
.metrics_handle()
.counter(DROPPED_PKTS)
.increment_with_label(1, "reason", "expired");
return;
}
if !self.duplicate_filter.insert(message.message_id) {
tracing::debug!(
target: LOG_TARGET,
router_id = %self.router_id,
message_id = ?message.message_id,
message_type = ?message.message_type,
"ignoring duplicate message",
);
self.router_ctx
.metrics_handle()
.counter(DUPLICATE_PKTS)
.increment_with_label(1, "kind", "i2np");
return;
}
if let Err(error) = self.transport_tx.try_send(SubsystemEvent::Message {
messages: vec![(self.router_id.clone(), message.clone())],
}) {
tracing::warn!(
target: LOG_TARGET,
router_id = %self.router_id,
?error,
"failed to dispatch messages to subsystems",
);
}
}
fn handle_packet(&mut self, pkt: Packet) -> Result<(), Ssu2Error> {
let Packet { mut pkt, address } = pkt;
let (pkt_num, immediate_ack) = match HeaderReader::new(self.intro_key, &mut pkt)?
.parse(self.recv_key_ctx.k_header_2)?
{
HeaderKind::Data {
immediate_ack,
pkt_num,
} => (pkt_num, immediate_ack),
kind => {
tracing::warn!(
target: LOG_TARGET,
router_id = %self.router_id,
?kind,
"unexpected packet",
);
return Err(Ssu2Error::UnexpectedMessage);
}
};
tracing::trace!(
target: LOG_TARGET,
router_id = %self.router_id,
?pkt_num,
pkt_len = ?pkt.len(),
?immediate_ack,
"handle packet",
);
let mut payload = pkt[16..].to_vec();
ChaChaPoly::with_nonce(&self.recv_key_ctx.k_data, pkt_num as u64)
.decrypt_with_ad(&pkt[..16], &mut payload)?;
if immediate_ack {
self.ack_timer.schedule_immediate_ack(self.transmission.round_trip_time());
}
let blocks = Block::parse::<R>(&payload).map_err(Ssu2Error::Parse)?;
self.validate_pkt_address(address, pkt_num);
for block in blocks {
match block {
Block::Termination {
reason,
num_valid_pkts,
} => {
self.transmission.register_remote_pkt(pkt_num);
tracing::debug!(
target: LOG_TARGET,
router_id = %self.router_id,
?reason,
?num_valid_pkts,
"session terminated by remote router",
);
return Err(Ssu2Error::SessionTerminated(TerminationReason::ssu2(
reason,
)));
}
Block::I2Np { message } => {
self.handle_message(message);
self.transmission.register_remote_pkt(pkt_num);
self.ack_timer.schedule_ack(self.transmission.round_trip_time());
}
Block::FirstFragment {
message_type,
message_id,
expiration,
fragment,
} => {
self.transmission.register_remote_pkt(pkt_num);
self.ack_timer.schedule_ack(self.transmission.round_trip_time());
if let Some(message) = self.fragment_handler.first_fragment(
message_type,
message_id,
expiration,
fragment,
) {
self.handle_message(message);
}
}
Block::FollowOnFragment {
last,
message_id,
fragment_num,
fragment,
} => {
self.transmission.register_remote_pkt(pkt_num);
self.ack_timer.schedule_ack(self.transmission.round_trip_time());
if let Some(message) = self.fragment_handler.follow_on_fragment(
message_id,
fragment_num,
last,
fragment,
) {
self.handle_message(message);
}
}
Block::Ack {
ack_through,
num_acks,
ranges,
} => {
self.transmission.register_remote_pkt(pkt_num);
self.transmission.register_ack(ack_through, num_acks, &ranges);
}
Block::Address { .. } | Block::DateTime { .. } | Block::Padding { .. } => {
self.transmission.register_remote_pkt(pkt_num);
}
Block::PeerTest { message } => {
self.transmission.register_remote_pkt(pkt_num);
self.ack_timer.schedule_ack(self.transmission.round_trip_time());
self.handle_peer_test_message(message);
}
Block::RouterInfo { router_info, .. } => {
tracing::trace!(
target: LOG_TARGET,
router_id = %self.router_id,
received_router_id = %router_info.identity.id(),
"received an in-session router info",
);
self.pending_router_info = Some(router_info);
}
Block::RelayRequest {
nonce,
relay_tag,
address,
message,
signature,
} => {
self.remote_ack.register_pkt(pkt_num);
self.ack_timer.schedule_ack(self.transmission.round_trip_time());
self.handle_relay_request(nonce, relay_tag, address, message, signature);
}
Block::RelayIntro {
router_id,
nonce,
relay_tag,
address,
message,
signature,
} => {
self.remote_ack.register_pkt(pkt_num);
self.ack_timer.schedule_ack(self.transmission.round_trip_time());
self.handle_relay_intro(
router_id, nonce, relay_tag, address, message, signature,
);
}
Block::RelayResponse {
nonce,
address,
token,
rejection,
message,
signature,
} => {
self.remote_ack.register_pkt(pkt_num);
self.ack_timer.schedule_ack(self.transmission.round_trip_time());
self.handle_relay_response(
nonce, address, token, rejection, message, signature,
);
}
Block::PathChallenge { challenge } => {
self.remote_ack.register_pkt(pkt_num);
self.ack_timer.schedule_ack(self.transmission.round_trip_time());
self.handle_path_challenge(address, challenge);
}
Block::PathResponse { response } => {
self.remote_ack.register_pkt(pkt_num);
self.ack_timer.schedule_ack(self.transmission.round_trip_time());
self.handle_path_response(response);
}
block => {
tracing::trace!(
target: LOG_TARGET,
router_id = %self.router_id,
?block,
"ignoring block",
);
self.transmission.register_remote_pkt(pkt_num);
}
}
}
self.pending_router_info.take();
Ok(())
}
pub async fn run(mut self) -> TerminationContext<R> {
self.transport_tx
.send(SubsystemEvent::ConnectionEstablished {
router_id: self.router_id.clone(),
tx: self.msg_tx.clone(),
})
.await
.expect("manager to stay alive");
let reason = (&mut self).await;
self.transport_tx
.send(SubsystemEvent::ConnectionClosed {
router_id: self.router_id.clone(),
})
.await
.expect("manager to stay alive");
TerminationContext {
address: self.address,
dst_id: self.dst_id,
intro_key: self.intro_key,
k_session_confirmed: None,
next_pkt_num: self.transmission.next_pkt_num(),
reason,
recv_key_ctx: self.recv_key_ctx,
router_id: self.router_id,
rx: self.pkt_rx,
send_key_ctx: self.send_key_ctx,
socket: self.socket,
duration: self.started.elapsed(),
}
}
}
impl<R: Runtime> Future for Ssu2Session<R> {
type Output = TerminationReason;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
match self.pkt_rx.poll_recv(cx) {
Poll::Pending => break,
Poll::Ready(None) => return Poll::Ready(TerminationReason::Unspecified),
Poll::Ready(Some(pkt)) => {
self.inbound_bandwidth += pkt.pkt.len();
match self.handle_packet(pkt) {
Ok(()) => {}
Err(Ssu2Error::SessionTerminated(reason)) => return Poll::Ready(reason),
Err(error) => {
tracing::debug!(
target: LOG_TARGET,
?error,
"failed to parse ssu2 message blocks",
);
debug_assert!(false);
self.router_ctx
.metrics_handle()
.counter(DROPPED_PKTS)
.increment_with_label(1, "reason", "invalid-active");
}
}
}
}
}
while self.transmission.has_capacity() {
match self.msg_rx.poll_recv(cx) {
Poll::Pending => break,
Poll::Ready(None) => return Poll::Ready(TerminationReason::RouterShutdown),
Poll::Ready(Some(OutboundMessage::Message(message))) => {
self.transmission.schedule(message);
}
Poll::Ready(Some(OutboundMessage::MessageWithFeedback(message, feedback_tx))) => {
self.transmission.schedule(message);
let _ = feedback_tx.send(());
}
Poll::Ready(Some(OutboundMessage::Messages(messages))) =>
for message in messages {
self.transmission.schedule(message);
},
Poll::Ready(Some(OutboundMessage::Dummy)) => {}
}
}
loop {
let Some(ref mut timer) = self.resend_timer else {
break;
};
let Poll::Ready(()) = timer.poll_unpin(cx) else {
break;
};
if let Some(pkts) = self.transmission.drain_expired() {
let address = self.address;
for (pkt, destination) in pkts.into_iter() {
match Pin::new(&mut self.socket).poll_send_to(
cx,
&pkt,
destination.unwrap_or(address),
) {
Poll::Pending => {}
Poll::Ready(None) => return Poll::Ready(TerminationReason::RouterShutdown),
Poll::Ready(Some(_)) => {}
}
}
}
self.resend_timer = Some(R::timer(SSU2_RESEND_TIMEOUT));
}
loop {
match self.peer_test_handle.poll_next_unpin(cx) {
Poll::Pending => break,
Poll::Ready(None) => return Poll::Ready(TerminationReason::RouterShutdown),
Poll::Ready(Some(command)) => self.handle_peer_test_command(command),
}
}
loop {
match self.relay_handle.poll_next_unpin(cx) {
Poll::Pending => break,
Poll::Ready(None) => return Poll::Ready(TerminationReason::RouterShutdown),
Poll::Ready(Some(command)) => self.handle_relay_command(command),
}
}
if self.ack_timer.poll_unpin(cx).is_ready() {
let pkt = self.transmission.build_explicit_ack();
let address = self.address;
match Pin::new(&mut self.socket).poll_send_to(cx, &pkt, address) {
Poll::Pending => {
self.write_buffer.push_front((pkt, None));
}
Poll::Ready(None) => return Poll::Ready(TerminationReason::RouterShutdown),
Poll::Ready(Some(_)) => {}
}
}
{
let address = self.address;
while let Some((pkt, destination)) = self.write_buffer.pop_front() {
match Pin::new(&mut self.socket).poll_send_to(
cx,
&pkt,
destination.unwrap_or(address),
) {
Poll::Pending => {
self.write_buffer.push_front((pkt, destination));
break;
}
Poll::Ready(None) => return Poll::Ready(TerminationReason::RouterShutdown),
Poll::Ready(Some(nwritten)) => {
self.router_ctx
.metrics_handle()
.counter(OUTBOUND_BANDWIDTH)
.increment(nwritten);
self.router_ctx.metrics_handle().counter(OUTBOUND_PKTS).increment(1);
self.outbound_bandwidth += nwritten;
}
}
}
if self.write_buffer.is_empty() {
if let Some(pkts) = self.transmission.drain() {
for (pkt, destination) in pkts {
match Pin::new(&mut self.socket).poll_send_to(
cx,
&pkt,
destination.unwrap_or(address),
) {
Poll::Pending => {
self.write_buffer.push_back((pkt, destination));
}
Poll::Ready(None) =>
return Poll::Ready(TerminationReason::RouterShutdown),
Poll::Ready(Some(nwritten)) => {
self.router_ctx
.metrics_handle()
.counter(OUTBOUND_BANDWIDTH)
.increment(nwritten);
self.router_ctx
.metrics_handle()
.counter(OUTBOUND_PKTS)
.increment(1);
self.outbound_bandwidth += nwritten;
}
}
}
self.resend_timer = {
let mut timer = R::timer(SSU2_RESEND_TIMEOUT);
let _ = timer.poll_unpin(cx);
Some(timer)
};
}
}
}
if self.event_handle.poll_unpin(cx).is_ready() {
self.event_handle.transport_inbound_bandwidth(self.inbound_bandwidth);
self.event_handle.transport_outbound_bandwidth(self.outbound_bandwidth);
self.inbound_bandwidth = 0;
self.outbound_bandwidth = 0;
}
loop {
match self.path_validation_state.poll_unpin(cx) {
Poll::Pending => break,
Poll::Ready(()) => self.handle_path_response_timeout(),
}
}
let _ = self.duplicate_filter.poll_unpin(cx);
let _ = self.fragment_handler.poll_unpin(cx);
Poll::Pending
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
crypto::SigningPrivateKey,
events::EventManager,
i2np::{MessageType, I2NP_MESSAGE_EXPIRATION},
primitives::{MessageId, RouterInfoBuilder},
profile::ProfileStorage,
runtime::mock::MockRuntime,
transport::ssu2::{
message::data::DataMessageBuilder, peer_test::types::PeerTestEventRecycle,
},
};
use bytes::Bytes;
use thingbuf::mpsc::channel;
#[tokio::test]
async fn backpressure_works() {
let (from_socket_tx, from_socket_rx) = channel(128);
let socket = <MockRuntime as Runtime>::UdpSocket::bind("127.0.0.1:0".parse().unwrap())
.await
.unwrap();
let mut recv_socket =
<MockRuntime as Runtime>::UdpSocket::bind("127.0.0.1:0".parse().unwrap())
.await
.unwrap();
let remote_signing_key = SigningPrivateKey::random(rand::rng());
let (router_info, static_key, signing_key) = RouterInfoBuilder::default().build();
let (_event_mgr, _event_subscriber, event_handle) =
EventManager::new(None, MockRuntime::register_metrics(vec![], None));
let router_ctx = RouterContext::new(
MockRuntime::register_metrics(vec![], None),
ProfileStorage::<MockRuntime>::new(&Vec::new(), &Vec::new()),
router_info.identity.id(),
Bytes::from(router_info.serialize(&signing_key)),
static_key.clone(),
signing_key,
2u8,
event_handle.clone(),
);
let ctx = Ssu2SessionContext {
address: recv_socket.local_address().unwrap(),
dst_id: 1337u64,
intro_key: [1u8; 32],
max_payload_size: 1472,
pkt_rx: from_socket_rx,
recv_key_ctx: KeyContext {
k_data: [2u8; 32],
k_header_2: [3u8; 32],
},
router_id: RouterId::random(),
send_key_ctx: KeyContext {
k_data: [3u8; 32],
k_header_2: [2u8; 32],
},
verifying_key: remote_signing_key.public(),
};
let (peer_test_event_tx, _peer_test_event_rx) =
with_recycle(16, PeerTestEventRecycle::default());
let peer_test_handle = PeerTestHandle::new(peer_test_event_tx);
let (relay_event_tx, _relay_event_rx) = channel(16);
let relay_handle = RelayHandle::new(relay_event_tx);
let cmd_tx = {
let (transport_tx, transport_rx) = channel(16);
tokio::spawn(
Ssu2Session::<MockRuntime>::new(
ctx,
socket,
transport_tx,
router_ctx,
peer_test_handle,
relay_handle,
)
.run(),
);
match transport_rx.recv().await.unwrap() {
SubsystemEvent::ConnectionEstablished { tx, .. } => tx,
_ => panic!("invalid event"),
}
};
for _ in 0..CMD_CHANNEL_SIZE {
cmd_tx
.try_send(OutboundMessage::Message(Message {
message_type: MessageType::Data,
message_id: *MessageId::random(),
expiration: MockRuntime::time_since_epoch() + I2NP_MESSAGE_EXPIRATION,
payload: vec![1, 2, 3, 4],
}))
.unwrap();
}
assert!(cmd_tx
.try_send(OutboundMessage::Message(Message {
message_type: MessageType::Data,
message_id: *MessageId::random(),
expiration: MockRuntime::time_since_epoch() + I2NP_MESSAGE_EXPIRATION,
payload: vec![1, 2, 3, 4],
}))
.is_err());
let mut buffer = vec![0u8; 0xffff];
for _ in 0..16 {
let (nread, _from) = recv_socket.recv_from(&mut buffer).await.unwrap();
let pkt = &mut buffer[..nread];
match HeaderReader::new([1u8; 32], pkt).unwrap().parse([2u8; 32]).unwrap() {
HeaderKind::Data { .. } => {}
_ => panic!("invalid packet"),
}
}
for _ in 0..16 {
assert!(cmd_tx
.try_send(OutboundMessage::Message(Message {
message_type: MessageType::Data,
message_id: *MessageId::random(),
expiration: MockRuntime::time_since_epoch() + I2NP_MESSAGE_EXPIRATION,
payload: vec![1, 2, 3, 4],
}))
.is_ok());
}
assert!(cmd_tx
.try_send(OutboundMessage::Message(Message {
message_type: MessageType::Data,
message_id: *MessageId::random(),
expiration: MockRuntime::time_since_epoch() + I2NP_MESSAGE_EXPIRATION,
payload: vec![1, 2, 3, 4],
}))
.is_err());
let mut pkt = DataMessageBuilder::default()
.with_dst_id(1337u64)
.with_pkt_num(1)
.with_key_context(
[1u8; 32],
&KeyContext {
k_data: [2u8; 32],
k_header_2: [3u8; 32],
},
)
.with_ack(16, 5, None)
.build::<MockRuntime>()
.to_vec();
let mut reader = HeaderReader::new([1u8; 32], &mut pkt).unwrap();
let _dst_id = reader.dst_id();
from_socket_tx
.try_send(Packet {
pkt,
address: "127.0.0.1:8888".parse().unwrap(),
})
.unwrap();
let future = async move {
for _ in 0..6 {
cmd_tx
.send(OutboundMessage::Message(Message {
message_type: MessageType::Data,
message_id: *MessageId::random(),
expiration: MockRuntime::time_since_epoch() + I2NP_MESSAGE_EXPIRATION,
payload: vec![1, 2, 3, 4],
}))
.await
.unwrap();
}
};
let _ = tokio::time::timeout(Duration::from_secs(5), future).await.expect("no timeout");
}
}