use crate::{
crypto::noise::NoiseContext,
primitives::{RouterId, RouterInfo},
runtime::{Instant, Runtime},
transport::{
ssu2::{relay::types::RelayTagRequested, session::active::Ssu2SessionContext},
EncryptionKind, TerminationReason,
},
};
use bytes::{Bytes, BytesMut};
use futures::FutureExt;
use ml_kem::{
array::Array, DecapsulationKey, Encapsulate, EncapsulationKey, MlKem1024, MlKem512, MlKem768,
};
use alloc::{boxed::Box, collections::VecDeque, vec::Vec};
use core::{
fmt,
future::Future,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
pub mod inbound;
pub mod outbound;
/// Maximum clock skew: 60 seconds.
///
/// NOTE(review): not referenced in this part of the file; presumably consumed by
/// the `inbound`/`outbound` submodules — confirm.
const MAX_CLOCK_SKEW: Duration = Duration::from_secs(60);
/// Status value for a pending SSU2 session.
///
/// Every variant records `started`, the instant from which
/// [`PendingSsu2SessionStatus::duration`] measures elapsed time.
///
/// NOTE(review): the code that constructs and consumes these variants is not in
/// this part of the file; the field descriptions below are derived from names
/// and types and should be confirmed against the `inbound`/`outbound` modules.
pub enum PendingSsu2SessionStatus<R: Runtime> {
/// A new inbound session.
NewInboundSession {
/// Session context (type declared in `session::active`).
context: Ssu2SessionContext,
/// Destination connection ID — TODO confirm whose ID this is.
dst_id: u64,
/// 32-byte key; presumably SSU2 header key 2 — TODO confirm.
k_header_2: [u8; 32],
/// Packet bytes accompanying the new session; exact contents not visible here.
pkt: BytesMut,
/// Boxed router info, presumably of the remote router.
router_info: Box<RouterInfo>,
/// Serialized bytes; presumably the serialized form of `router_info` — confirm.
serialized: Bytes,
/// Instant used as the starting point for `duration()`.
started: R::Instant,
/// Socket address, presumably of the remote peer.
target: SocketAddr,
/// Relay tag request state for this session.
relay_tag_request: RelayTagRequested,
/// Encryption kind of the session.
encryption: EncryptionKind,
},
/// A new outbound session.
NewOutboundSession {
/// Session context (type declared in `session::active`).
context: Ssu2SessionContext,
/// External address, if one is known; presumably our address as seen by the peer — confirm.
external_address: Option<SocketAddr>,
/// Relay tag, if any.
relay_tag: Option<u32>,
/// Source connection ID — TODO confirm whose ID this is.
src_id: u64,
/// Instant used as the starting point for `duration()`.
started: R::Instant,
/// Encryption kind of the session.
encryption: EncryptionKind,
},
/// The pending session was terminated.
SessionTerminated {
/// Socket address associated with the session, if known.
address: Option<SocketAddr>,
/// Connection ID of the session.
connection_id: u64,
/// Router ID associated with the session, if known.
router_id: Option<RouterId>,
/// Instant used as the starting point for `duration()`.
started: R::Instant,
/// Relay tag, if any.
relay_tag: Option<u32>,
/// Why the session was terminated.
reason: TerminationReason,
},
/// The pending session timed out.
Timeout {
/// Connection ID of the session.
connection_id: u64,
/// Router ID associated with the session, if known.
router_id: Option<RouterId>,
/// Instant used as the starting point for `duration()`.
started: R::Instant,
/// Socket address associated with the session, if known.
address: Option<SocketAddr>,
},
/// The socket was closed (name-derived; the trigger is not visible here).
SocketClosed {
/// Instant used as the starting point for `duration()`.
started: R::Instant,
},
}
/// Compact `Debug` rendering of [`PendingSsu2SessionStatus`].
///
/// Only a subset of fields is printed for the larger variants; those end with
/// `finish_non_exhaustive()` so the output carries a trailing `..`. Variants whose
/// fields are all printed (`Timeout`, `SocketClosed`) end with `finish()`.
impl<R: Runtime> fmt::Debug for PendingSsu2SessionStatus<R> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NewInboundSession {
                dst_id,
                target,
                started,
                ..
            } => f
                .debug_struct("PendingSsu2SessionStatus::NewInboundSession")
                .field("dst_id", dst_id)
                .field("target", target)
                .field("started", started)
                .finish_non_exhaustive(),
            Self::NewOutboundSession {
                src_id, started, ..
            } => f
                .debug_struct("PendingSsu2SessionStatus::NewOutboundSession")
                .field("src_id", src_id)
                .field("started", started)
                .finish_non_exhaustive(),
            Self::SessionTerminated {
                address,
                connection_id,
                router_id,
                started,
                ..
            } => f
                .debug_struct("PendingSsu2SessionStatus::SessionTerminated")
                .field("address", address)
                .field("connection_id", connection_id)
                .field("router_id", router_id)
                .field("started", started)
                .finish_non_exhaustive(),
            // The pattern has no `..`: every field of `Timeout` is printed, so the
            // struct is closed with `finish()` instead of advertising hidden fields.
            Self::Timeout {
                connection_id,
                router_id,
                started,
                address,
            } => f
                .debug_struct("PendingSsu2SessionStatus::Timeout")
                .field("connection_id", connection_id)
                .field("router_id", router_id)
                .field("started", started)
                .field("address", address)
                .finish(),
            Self::SocketClosed { started } => f
                .debug_struct("PendingSsu2SessionStatus::SocketClosed")
                .field("started", started)
                .finish(),
        }
    }
}
impl<R: Runtime> PendingSsu2SessionStatus<R> {
    /// Time elapsed since the variant's `started` instant, in whole milliseconds
    /// converted to `f64`.
    pub fn duration(&self) -> f64 {
        // Every variant carries `started`; pick it out once and measure from there.
        let origin = match self {
            Self::NewInboundSession { started, .. }
            | Self::NewOutboundSession { started, .. }
            | Self::SessionTerminated { started, .. }
            | Self::Timeout { started, .. }
            | Self::SocketClosed { started, .. } => started,
        };

        origin.elapsed().as_millis() as f64
    }
}
/// Packet payload held by [`PacketRetransmitter`] and cloned into each
/// [`PacketRetransmitterEvent::Retransmit`].
#[derive(Clone)]
pub enum PacketKind {
/// One packet, as raw bytes.
Single(Vec<u8>),
/// Several packets, each as raw bytes (used by `PacketRetransmitter::session_confirmed()`).
Multi(Vec<Vec<u8>>),
}
/// Event produced when a [`PacketRetransmitter`] future resolves.
pub enum PacketRetransmitterEvent {
/// The timer fired and another timeout was still scheduled.
Retransmit {
/// Clone of the packet(s) stored in the retransmitter.
pkt: PacketKind,
},
/// The timer fired and no further timeouts were scheduled.
Timeout,
}
/// Timer-driven packet retransmitter, polled as a [`Future`] that resolves to a
/// [`PacketRetransmitterEvent`].
pub struct PacketRetransmitter<R: Runtime> {
/// Packet(s) cloned into each `Retransmit` event; `None` when created with
/// `PacketRetransmitter::inactive()`.
pkt: Option<PacketKind>,
/// Timeouts still to be armed, consumed front to back each time the timer fires.
timeouts: VecDeque<Duration>,
/// Currently armed timer.
timer: R::Timer,
}
impl<R: Runtime> PacketRetransmitter<R> {
pub fn inactive(timeout: Duration) -> Self {
Self {
pkt: None,
timeouts: VecDeque::new(),
timer: R::timer(timeout),
}
}
pub fn token_request(pkt: Vec<u8>) -> Self {
Self {
pkt: Some(PacketKind::Single(pkt)),
timeouts: VecDeque::from_iter([Duration::from_secs(6), Duration::from_secs(6)]),
timer: R::timer(Duration::from_secs(3)),
}
}
pub fn session_request(pkt: Vec<u8>) -> Self {
Self {
pkt: Some(PacketKind::Single(pkt)),
timeouts: VecDeque::from_iter([
Duration::from_millis(2500),
Duration::from_millis(5000),
Duration::from_millis(6250),
]),
timer: R::timer(Duration::from_millis(1250)),
}
}
pub fn session_created(pkt: Vec<u8>) -> Self {
Self {
pkt: Some(PacketKind::Single(pkt)),
timeouts: VecDeque::from_iter([
Duration::from_secs(2),
Duration::from_secs(4),
Duration::from_secs(5),
]),
timer: R::timer(Duration::from_secs(1)),
}
}
pub fn session_confirmed(pkts: Vec<Vec<u8>>) -> Self {
Self {
pkt: Some(PacketKind::Multi(pkts)),
timeouts: VecDeque::from_iter([
Duration::from_millis(2500),
Duration::from_millis(5000),
Duration::from_millis(6250),
]),
timer: R::timer(Duration::from_millis(1250)),
}
}
}
/// Drives the retransmission schedule.
///
/// Stays pending until the armed timer fires. On expiry the next timeout (if any)
/// is armed and a [`PacketRetransmitterEvent::Retransmit`] carrying a clone of the
/// stored packet is produced; once the schedule is exhausted the future resolves
/// to [`PacketRetransmitterEvent::Timeout`].
impl<R: Runtime> Future for PacketRetransmitter<R> {
    type Output = PacketRetransmitterEvent;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Nothing to report until the currently armed timer has fired.
        if self.timer.poll_unpin(cx).is_pending() {
            return Poll::Pending;
        }

        // No timeout left to arm: the final wait has elapsed.
        let next_timeout = match self.timeouts.pop_front() {
            Some(timeout) => timeout,
            None => return Poll::Ready(PacketRetransmitterEvent::Timeout),
        };

        // Arm the next timer and poll it once, discarding the result
        // (presumably to register the waker with the fresh timer).
        self.timer = R::timer(next_timeout);
        let _ = self.timer.poll_unpin(cx);

        if let Some(pkt) = self.pkt.as_ref() {
            Poll::Ready(PacketRetransmitterEvent::Retransmit { pkt: pkt.clone() })
        } else {
            // No packet stored: keep waiting on the newly armed timer.
            Poll::Pending
        }
    }
}
/// Encryption variant of a session, each variant wrapping a [`NoiseContext`].
pub enum EncryptionContext {
/// Plain x25519, without ML-KEM; `version()` reports 2.
X25519(NoiseContext),
/// ML-KEM-512 together with x25519; `version()` reports 3.
MlKem512X25519(NoiseContext),
/// ML-KEM-768 together with x25519; `version()` reports 4.
MlKem768X25519(NoiseContext),
/// ML-KEM-1024 together with x25519.
///
/// Marked `allow(unused)`; `version()` has no value for this variant and hits
/// `unreachable!()` if called on it.
#[allow(unused)]
MlKem1024X25519(NoiseContext),
}
impl EncryptionContext {
pub fn noise_ctx(&mut self) -> &mut NoiseContext {
match self {
Self::X25519(ctx) => ctx,
Self::MlKem512X25519(ctx) => ctx,
Self::MlKem768X25519(ctx) => ctx,
Self::MlKem1024X25519(ctx) => ctx,
}
}
pub fn encapsulation_key_size(&self) -> usize {
match self {
Self::X25519(_) => 0,
Self::MlKem512X25519(_) => 800,
Self::MlKem768X25519(_) => 1184,
Self::MlKem1024X25519(_) => 1568,
}
}
pub fn kem_ciphertext_size(&self) -> usize {
match self {
Self::X25519(_) => unreachable!(),
Self::MlKem512X25519(_) => 768,
Self::MlKem768X25519(_) => 1088,
Self::MlKem1024X25519(_) => 1568,
}
}
pub fn version(&self) -> u8 {
match self {
Self::X25519(_) => 2u8,
Self::MlKem512X25519(_) => 3u8,
Self::MlKem768X25519(_) => 4u8,
Self::MlKem1024X25519(_) => unreachable!(),
}
}
pub fn encapsulate<R: Runtime>(&self, encapsulation_key: &[u8]) -> Option<(Vec<u8>, Vec<u8>)> {
match self {
Self::X25519(_) => unreachable!(),
Self::MlKem512X25519(_) => {
let key = Array::try_from(encapsulation_key).ok()?;
let key = EncapsulationKey::<MlKem512>::new(&key).ok()?;
let (ciphertext, shared_key) = key.encapsulate_with_rng(&mut R::rng());
Some((ciphertext.to_vec(), shared_key.to_vec()))
}
Self::MlKem768X25519(_) => {
let key = Array::try_from(encapsulation_key).ok()?;
let key = EncapsulationKey::<MlKem768>::new(&key).ok()?;
let (ciphertext, shared_key) = key.encapsulate_with_rng(&mut R::rng());
Some((ciphertext.to_vec(), shared_key.to_vec()))
}
Self::MlKem1024X25519(_) => {
let key = Array::try_from(encapsulation_key).ok()?;
let key = EncapsulationKey::<MlKem1024>::new(&key).ok()?;
let (ciphertext, shared_key) = key.encapsulate_with_rng(&mut R::rng());
Some((ciphertext.to_vec(), shared_key.to_vec()))
}
}
}
}
/// Human-readable name of the encryption variant.
impl fmt::Display for EncryptionContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match self {
            Self::X25519(_) => "x25519",
            Self::MlKem512X25519(_) => "ml-kem-512",
            Self::MlKem768X25519(_) => "ml-kem-768",
            Self::MlKem1024X25519(_) => "ml-kem-1024",
        };

        f.write_str(label)
    }
}
/// Boxed ML-KEM decapsulation key, one variant per parameter set.
///
/// Only the 512 and 768 parameter sets have a variant here; there is no
/// counterpart to `EncryptionContext::MlKem1024X25519`.
pub enum MlKemContext {
/// Decapsulation key for ML-KEM-512.
MlKem512X25519(Box<DecapsulationKey<MlKem512>>),
/// Decapsulation key for ML-KEM-768.
MlKem768X25519(Box<DecapsulationKey<MlKem768>>),
}