use crate::{
crypto::{chachapoly::ChaChaPoly, StaticPublicKey},
destination::session::{
context::KeyContext,
inbound::InboundSession,
outbound::OutboundSession,
tag_set::{TagSet, TagSetEntry},
LOG_TARGET, NUM_TAGS_TO_GENERATE,
},
error::SessionError,
i2np::{
garlic::{
DeliveryInstructions as GarlicDeliveryInstructions, GarlicMessage, GarlicMessageBlock,
GarlicMessageBuilder, NextKeyKind,
},
MessageType, I2NP_MESSAGE_EXPIRATION,
},
primitives::{DestinationId, MessageId},
runtime::{Instant, Runtime},
};
use bytes::{BufMut, Bytes, BytesMut};
use hashbrown::{HashMap, HashSet};
use rand::Rng;
#[cfg(feature = "std")]
use parking_lot::RwLock;
#[cfg(feature = "no_std")]
use spin::rwlock::RwLock;
use alloc::{boxed::Box, collections::VecDeque, sync::Arc, vec::Vec};
use core::{fmt, mem, time::Duration};
/// Extra capacity reserved on top of the plaintext length when an encrypted garlic message is
/// serialized: the 8-byte garlic tag written in front of the ciphertext plus, presumably, the
/// 16-byte AEAD authentication tag (confirm against `ChaChaPoly::encrypt_with_ad_new`).
const GARLIC_MESSAGE_OVERHEAD: usize = 24;

/// Age after which a pending (not yet fully established) session is considered expired.
const PENDING_SESSION_MAX_AGE: Duration = Duration::from_secs(60 * 5);

/// How long an established session keeps the outbound sessions of its handshake around so
/// that new-session-reply messages arriving late can still be processed.
const NSR_CONTEXT_MAX_AGE: Duration = Duration::from_secs(60 * 3);

/// How long receive tags that existed before a receive tag set ratchet remain valid.
const PREV_TAG_MAX_AGE: Duration = Duration::from_secs(60 * 3);

/// Number of additional tags generated from the old receive tag set right before it is ratcheted.
const NUM_EXTRA_TAGS_TO_GENERATE: usize = 128;
/// Outcome of advancing a [`PendingSession`] with an outbound or inbound message.
pub enum PendingSessionEvent<R: Runtime> {
    /// Nothing to do, e.g., a duplicate NSR could not be handled while waiting to send the
    /// first ES message.
    DoNothing,

    /// Send `message` to the remote destination.
    SendMessage {
        /// Serialized handshake message (NS or NSR) to send.
        message: Vec<u8>,
    },

    /// Hand a message extracted from an inbound NSR back to the caller.
    ReturnMessage {
        /// Message returned by `OutboundSession::handle_new_session_reply`.
        message: Vec<u8>,

        /// ID of the tag set the NSR's garlic tag belonged to.
        tag_set_id: u16,

        /// Index of the NSR's garlic tag within its tag set.
        tag_index: u16,
    },

    /// The handshake has completed and an active [`Session`] can be created from `context`.
    CreateSession {
        /// For an outbound handshake, the encrypted first ES message to send; for an inbound
        /// handshake, the message returned from handling the remote's first ES.
        message: Vec<u8>,

        /// State needed to construct the active session.
        context: Box<SessionContext<R>>,

        /// ID of the tag set of the tag used for `message`.
        tag_set_id: u16,

        /// Index of the tag used for `message` within its tag set.
        tag_index: u16,
    },
}
/// Handshake state of a [`PendingSession`].
enum PendingSessionState<R: Runtime> {
    /// Handshake was started by the remote destination; we reply with NSR(s) and wait for
    /// the remote's first ES message.
    InboundActive {
        /// Inbound handshake state.
        inbound: Box<InboundSession<R>>,

        /// Shared garlic tag -> destination map.
        garlic_tags: Arc<RwLock<HashMap<u64, DestinationId>>>,

        /// Receive tags produced while creating NSR messages, keyed by garlic tag.
        tag_set_entries: HashMap<u64, TagSetEntry>,
    },

    /// Handshake was started locally; one or more NS messages have been sent and an NSR is
    /// expected.
    AwaitingNsr {
        /// Outbound sessions, one per NS message sent, keyed by insertion index.
        outbound: HashMap<usize, OutboundSession<R>>,

        /// Static public key of the remote destination.
        remote_public_key: StaticPublicKey,

        /// Shared garlic tag -> destination map.
        garlic_tags: Arc<RwLock<HashMap<u64, DestinationId>>>,

        /// NSR tags keyed by garlic tag, paired with the index of the outbound session
        /// (key of `outbound`) that generated them.
        nsr_tag_set_entries: HashMap<u64, (usize, TagSetEntry)>,

        /// Receive tags; empty until an NSR has been handled.
        tag_set_entries: HashMap<u64, TagSetEntry>,
    },

    /// An NSR has been handled and the send/receive tag sets derived; waiting for the first
    /// outbound message, which is sent as an ES message and completes the handshake.
    AwaitingEsTransmit {
        /// Outbound sessions kept so that further NSRs can still be handled.
        outbound: HashMap<usize, OutboundSession<R>>,

        /// Tag set used to encrypt outbound ES messages.
        send_tag_set: Box<TagSet>,

        /// Tag set used to generate receive tags for inbound ES messages.
        recv_tag_set: Box<TagSet>,

        /// Static public key of the remote destination.
        remote_public_key: StaticPublicKey,

        /// Shared garlic tag -> destination map.
        garlic_tags: Arc<RwLock<HashMap<u64, DestinationId>>>,

        /// Remaining NSR tags, paired with the index of the outbound session that generated them.
        nsr_tag_set_entries: HashMap<u64, (usize, TagSetEntry)>,

        /// Receive tags generated from `recv_tag_set`, keyed by garlic tag.
        tag_set_entries: HashMap<u64, TagSetEntry>,
    },

    /// Placeholder stored while the real state is temporarily moved out with `mem::replace`.
    Poisoned,
}
impl<R: Runtime> fmt::Debug for PendingSessionState<R> {
    /// Format only the variant name; the variant fields hold key material and session state
    /// that are deliberately not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let variant_name = match self {
            Self::InboundActive { .. } => "PendingSessionState::InboundActive",
            Self::AwaitingNsr { .. } => "PendingSessionState::AwaitingNsr",
            Self::AwaitingEsTransmit { .. } => "PendingSessionState::AwaitingEsTransmit",
            Self::Poisoned => "PendingSessionState::Poisoned",
        };

        f.debug_struct(variant_name).finish_non_exhaustive()
    }
}
/// Session between `local` and `remote` whose handshake has not completed yet.
pub struct PendingSession<R: Runtime> {
    /// When the pending session was created, used by [`PendingSession::is_expired`].
    created: R::Instant,

    /// Key context used to create additional outbound sessions (repeated NS messages).
    key_context: KeyContext<R>,

    /// Local destination.
    local: DestinationId,

    /// Remote destination.
    remote: DestinationId,

    /// Current handshake state.
    state: PendingSessionState<R>,
}
impl<R: Runtime> PendingSession<R> {
    /// Create a pending session for a handshake started by the remote destination.
    ///
    /// The session starts in `InboundActive`: outbound messages are sent as NSRs and the next
    /// inbound message is expected to be the remote's first ES message.
    pub fn new_inbound(
        local: DestinationId,
        remote: DestinationId,
        inbound: InboundSession<R>,
        garlic_tags: Arc<RwLock<HashMap<u64, DestinationId>>>,
        key_context: KeyContext<R>,
    ) -> Self {
        Self {
            created: R::now(),
            key_context,
            local,
            remote,
            state: PendingSessionState::InboundActive {
                inbound: Box::new(inbound),
                tag_set_entries: HashMap::new(),
                garlic_tags,
            },
        }
    }

    /// Create a pending session for a handshake started locally with `outbound`.
    ///
    /// The NSR tags `outbound` generates are registered in the shared `garlic_tags` map
    /// (tag -> `remote`) and tracked under outbound session index `0`. The session starts in
    /// `AwaitingNsr`.
    pub fn new_outbound(
        local: DestinationId,
        remote: DestinationId,
        remote_public_key: StaticPublicKey,
        outbound: OutboundSession<R>,
        garlic_tags: Arc<RwLock<HashMap<u64, DestinationId>>>,
        key_context: KeyContext<R>,
        ratchet_threshold: u16,
    ) -> Self {
        let nsr_tag_set_entries = {
            // hold the write lock only while the NSR tags are registered
            let mut inner = garlic_tags.write();

            outbound
                .generate_new_session_reply_tags(ratchet_threshold)
                .map(|tag_set| {
                    inner.insert(tag_set.tag, remote.clone());
                    (tag_set.tag, (0usize, tag_set))
                })
                .collect()
        };

        Self {
            created: R::now(),
            key_context,
            local,
            remote,
            state: PendingSessionState::AwaitingNsr {
                outbound: HashMap::from_iter([(0usize, outbound)]),
                remote_public_key,
                garlic_tags,
                nsr_tag_set_entries,
                tag_set_entries: HashMap::new(),
            },
        }
    }

    /// Returns `true` once the pending session is older than `PENDING_SESSION_MAX_AGE`.
    pub fn is_expired(&self) -> bool {
        self.created.elapsed() > PENDING_SESSION_MAX_AGE
    }

    /// Advance the handshake with an outbound `message`.
    ///
    /// - `InboundActive`: `message` is wrapped into a garlic clove and sent as an NSR. The
    ///   receive tags returned for the NSR are registered and the state is kept.
    /// - `AwaitingNsr`: another NS message is created with `key_context` (the only state that
    ///   uses `lease_set`) and its NSR tags are registered under a new session index.
    /// - `AwaitingEsTransmit`: `message` is sent as the first ES message and
    ///   [`PendingSessionEvent::CreateSession`] is returned. The remaining outbound sessions
    ///   move into an `NsrContext` so late NSRs can still be handled. The pending state is
    ///   consumed and left `Poisoned`.
    ///
    /// # Errors
    ///
    /// `SessionError::InvalidState` if the send tag set is out of tags or the state is
    /// poisoned. Errors from NSR creation and encryption are propagated.
    ///
    /// NOTE(review): any early `?` return leaves `self.state` as `Poisoned` because the state
    /// was moved out with `mem::replace`. Confirm callers discard the pending session on error.
    pub fn advance_outbound(
        &mut self,
        lease_set: Bytes,
        message: Vec<u8>,
        ratchet_threshold: u16,
    ) -> Result<PendingSessionEvent<R>, SessionError> {
        match mem::replace(&mut self.state, PendingSessionState::Poisoned) {
            PendingSessionState::InboundActive {
                mut inbound,
                mut tag_set_entries,
                garlic_tags,
            } => {
                tracing::trace!(
                    target: LOG_TARGET,
                    local = %self.local,
                    remote = %self.remote,
                    "send NSR",
                );

                let hash = self.remote.to_vec();
                let message = GarlicMessageBuilder::default()
                    .with_garlic_clove(
                        MessageType::Data,
                        MessageId::from(R::rng().next_u32()),
                        R::time_since_epoch() + I2NP_MESSAGE_EXPIRATION,
                        GarlicDeliveryInstructions::Destination { hash: &hash },
                        // clove payload: 4-byte big-endian length followed by the message
                        &{
                            let mut out = BytesMut::with_capacity(message.len() + 4);
                            out.put_u32(message.len() as u32);
                            out.put_slice(&message);
                            out
                        },
                    )
                    .build();

                let (message, entries) =
                    inbound.create_new_session_reply(message, ratchet_threshold)?;

                // register the receive tags derived for this NSR
                {
                    let mut inner = garlic_tags.write();
                    entries.into_iter().for_each(|entry| {
                        inner.insert(entry.tag, self.remote.clone());
                        tag_set_entries.insert(entry.tag, entry);
                    })
                }

                self.state = PendingSessionState::InboundActive {
                    inbound,
                    tag_set_entries,
                    garlic_tags,
                };

                Ok(PendingSessionEvent::SendMessage { message })
            }
            PendingSessionState::AwaitingNsr {
                garlic_tags,
                mut outbound,
                mut nsr_tag_set_entries,
                remote_public_key,
                tag_set_entries,
            } => {
                tracing::debug!(
                    target: LOG_TARGET,
                    local = %self.local,
                    remote = %self.remote,
                    "send another NS",
                );

                let (session, message) = self.key_context.create_outbound_session(
                    self.local.clone(),
                    self.remote.clone(),
                    &remote_public_key,
                    lease_set,
                    &message,
                );

                // the new session's index is `outbound.len()`; it is inserted under that same
                // key below, so NSR tags and session stay paired
                {
                    let mut inner = garlic_tags.write();
                    session.generate_new_session_reply_tags(ratchet_threshold).for_each(
                        |tag_set| {
                            inner.insert(tag_set.tag, self.remote.clone());
                            nsr_tag_set_entries.insert(tag_set.tag, (outbound.len(), tag_set));
                        },
                    );
                }
                outbound.insert(outbound.len(), session);

                self.state = PendingSessionState::AwaitingNsr {
                    outbound,
                    garlic_tags,
                    nsr_tag_set_entries,
                    remote_public_key,
                    tag_set_entries,
                };

                Ok(PendingSessionEvent::SendMessage { message })
            }
            PendingSessionState::AwaitingEsTransmit {
                outbound,
                mut send_tag_set,
                recv_tag_set,
                garlic_tags,
                nsr_tag_set_entries,
                tag_set_entries,
                ..
            } => {
                let TagSetEntry {
                    key,
                    tag,
                    tag_index,
                    tag_set_id,
                } = send_tag_set.next_entry().ok_or_else(|| {
                    tracing::warn!(
                        target: LOG_TARGET,
                        local = %self.local,
                        remote = %self.remote,
                        "`TagSet` ran out of tags",
                    );
                    debug_assert!(false);
                    SessionError::InvalidState
                })?;

                tracing::trace!(
                    target: LOG_TARGET,
                    local = %self.local,
                    remote = %self.remote,
                    garlic_tag = ?tag,
                    "send first ES message",
                );

                // clove payload: 4-byte big-endian length followed by the message
                let message = {
                    let mut out = BytesMut::with_capacity(message.len() + 4);
                    out.put_u32(message.len() as u32);
                    out.put_slice(&message);
                    out
                };

                let hash = self.remote.to_vec();
                let builder = GarlicMessageBuilder::default().with_garlic_clove(
                    MessageType::Data,
                    MessageId::from(R::rng().next_u32()),
                    R::time_since_epoch() + I2NP_MESSAGE_EXPIRATION,
                    GarlicDeliveryInstructions::Destination { hash: &hash },
                    &message,
                );

                // encrypt with the tag's key, the tag index as nonce and the tag as associated
                // data; output is the little-endian tag followed by the ciphertext
                let mut message = builder.build();
                let mut out = BytesMut::with_capacity(message.len() + GARLIC_MESSAGE_OVERHEAD);

                ChaChaPoly::with_nonce(&key, tag_index as u64)
                    .encrypt_with_ad_new(&tag.to_le_bytes(), &mut message)?;

                out.put_u64_le(tag);
                out.put_slice(&message);

                // the pending state is consumed here; `self.state` stays `Poisoned`
                Ok(PendingSessionEvent::CreateSession {
                    message: out.freeze().to_vec(),
                    context: Box::new(SessionContext {
                        garlic_tags,
                        local: self.local.clone(),
                        recv_tag_set: *recv_tag_set,
                        remote: self.remote.clone(),
                        send_tag_set: *send_tag_set,
                        tag_set_entries,
                        nsr_context: NsrContext::new(nsr_tag_set_entries, outbound),
                    }),
                    tag_set_id,
                    tag_index,
                })
            }
            state => {
                tracing::warn!(
                    target: LOG_TARGET,
                    local = %self.local,
                    remote = %self.remote,
                    ?state,
                    "invalid session state",
                );

                debug_assert!(false);
                Err(SessionError::InvalidState)
            }
        }
    }

    /// Advance the handshake with an inbound `message` received with `garlic_tag`.
    ///
    /// - `InboundActive`: `message` is the remote's first ES message. On success
    ///   [`PendingSessionEvent::CreateSession`] is returned (no NSR context, since the handshake
    ///   was inbound) and the pending state is consumed.
    /// - `AwaitingNsr`: `message` is an NSR for one of the outbound sessions. The derived tag
    ///   sets are stored, `NUM_TAGS_TO_GENERATE` receive tags are registered, and the state
    ///   moves to `AwaitingEsTransmit`.
    /// - `AwaitingEsTransmit`: another NSR. Its message is returned, but its derived tag sets
    ///   are discarded. If handling fails, the state is restored and
    ///   [`PendingSessionEvent::DoNothing`] is returned.
    ///
    /// # Errors
    ///
    /// `SessionError::InvalidState` if `garlic_tag` is unknown in the current state or the
    /// state is poisoned. Errors from handling the ES/NSR are propagated, except in
    /// `AwaitingEsTransmit`.
    ///
    /// NOTE(review): in `InboundActive` and `AwaitingNsr`, an early `?` return leaves
    /// `self.state` as `Poisoned`. Only the `AwaitingEsTransmit` arm restores it on failure.
    /// Confirm this asymmetry is intended.
    pub fn advance_inbound(
        &mut self,
        garlic_tag: u64,
        message: Vec<u8>,
        ratchet_threshold: u16,
    ) -> Result<PendingSessionEvent<R>, SessionError> {
        match mem::replace(&mut self.state, PendingSessionState::Poisoned) {
            PendingSessionState::InboundActive {
                mut inbound,
                mut tag_set_entries,
                garlic_tags,
            } => {
                tracing::trace!(
                    target: LOG_TARGET,
                    local = %self.local,
                    remote = %self.remote,
                    "ES received",
                );

                let tag_set_entry = tag_set_entries.remove(&garlic_tag).ok_or_else(|| {
                    tracing::warn!(
                        target: LOG_TARGET,
                        local = %self.local,
                        remote = %self.remote,
                        ?garlic_tag,
                        "`TagSetEntry` doesn't exist for ES",
                    );
                    debug_assert!(false);
                    SessionError::InvalidState
                })?;

                let tag_set_id = tag_set_entry.tag_set_id;
                let tag_index = tag_set_entry.tag_index;
                let (message, send_tag_set, recv_tag_set) =
                    inbound.handle_existing_session(garlic_tag, tag_set_entry, message)?;

                Ok(PendingSessionEvent::CreateSession {
                    message,
                    context: Box::new(SessionContext {
                        recv_tag_set,
                        send_tag_set,
                        tag_set_entries,
                        garlic_tags,
                        local: self.local.clone(),
                        remote: self.remote.clone(),
                        nsr_context: NsrContext::Inactive,
                    }),
                    tag_set_id,
                    tag_index,
                })
            }
            PendingSessionState::AwaitingNsr {
                garlic_tags,
                mut outbound,
                mut nsr_tag_set_entries,
                mut tag_set_entries,
                remote_public_key,
            } => {
                tracing::debug!(
                    target: LOG_TARGET,
                    local = %self.local,
                    remote = %self.remote,
                    "NSR received"
                );

                let (session_idx, tag_set_entry) =
                    nsr_tag_set_entries.remove(&garlic_tag).ok_or_else(|| {
                        tracing::warn!(
                            target: LOG_TARGET,
                            local = %self.local,
                            remote = %self.remote,
                            ?garlic_tag,
                            "`TagSetEntry` doesn't exist for NSR",
                        );
                        debug_assert!(false);
                        SessionError::InvalidState
                    })?;

                let tag_set_id = tag_set_entry.tag_set_id;
                let tag_index = tag_set_entry.tag_index;

                // every NSR tag was inserted together with its session, so the lookup holds
                let (message, send_tag_set, mut recv_tag_set) = outbound
                    .get_mut(&session_idx)
                    .expect("to exist")
                    .handle_new_session_reply(tag_set_entry, message, ratchet_threshold)?;

                // pre-generate receive tags for the remote's ES messages
                {
                    let mut inner = garlic_tags.write();
                    (0..NUM_TAGS_TO_GENERATE).for_each(|_| {
                        let entry = recv_tag_set.next_entry().expect("to succeed");
                        inner.insert(entry.tag, self.remote.clone());
                        tag_set_entries.insert(entry.tag, entry);
                    });
                };

                self.state = PendingSessionState::AwaitingEsTransmit {
                    outbound,
                    send_tag_set: Box::new(send_tag_set),
                    recv_tag_set: Box::new(recv_tag_set),
                    garlic_tags,
                    tag_set_entries,
                    remote_public_key,
                    nsr_tag_set_entries,
                };

                Ok(PendingSessionEvent::ReturnMessage {
                    tag_set_id,
                    tag_index,
                    message,
                })
            }
            PendingSessionState::AwaitingEsTransmit {
                mut outbound,
                send_tag_set,
                recv_tag_set,
                remote_public_key,
                garlic_tags,
                mut nsr_tag_set_entries,
                tag_set_entries,
            } => {
                tracing::debug!(
                    target: LOG_TARGET,
                    local = %self.local,
                    remote = %self.remote,
                    "NSR received while waiting to send ES"
                );

                let (session_idx, tag_set_entry) =
                    nsr_tag_set_entries.remove(&garlic_tag).ok_or_else(|| {
                        tracing::warn!(
                            target: LOG_TARGET,
                            local = %self.local,
                            remote = %self.remote,
                            ?garlic_tag,
                            "TagSetEntry doesn't exist for NSR",
                        );
                        debug_assert!(false);
                        SessionError::InvalidState
                    })?;

                let tag_set_id = tag_set_entry.tag_set_id;
                let tag_index = tag_set_entry.tag_index;

                // only the message is kept; the tag sets derived from the first NSR stay in use
                let message = match outbound
                    .get_mut(&session_idx)
                    .expect("to exist")
                    .handle_new_session_reply(tag_set_entry, message, ratchet_threshold)
                {
                    Ok((message, ..)) => message,
                    Err(error) => {
                        tracing::warn!(
                            target: LOG_TARGET,
                            local = %self.local,
                            remote = %self.remote,
                            ?garlic_tag,
                            ?error,
                            "failed to handle NSR",
                        );

                        // restore the state so the handshake can still complete
                        self.state = PendingSessionState::AwaitingEsTransmit {
                            outbound,
                            send_tag_set,
                            recv_tag_set,
                            garlic_tags,
                            tag_set_entries,
                            remote_public_key,
                            nsr_tag_set_entries,
                        };

                        return Ok(PendingSessionEvent::DoNothing);
                    }
                };

                self.state = PendingSessionState::AwaitingEsTransmit {
                    outbound,
                    send_tag_set,
                    recv_tag_set,
                    garlic_tags,
                    tag_set_entries,
                    remote_public_key,
                    nsr_tag_set_entries,
                };

                Ok(PendingSessionEvent::ReturnMessage {
                    tag_set_id,
                    tag_index,
                    message,
                })
            }
            PendingSessionState::Poisoned => {
                tracing::warn!(
                    target: LOG_TARGET,
                    local = %self.local,
                    remote = %self.remote,
                    "session state has been poisoned",
                );

                debug_assert!(false);
                Err(SessionError::InvalidState)
            }
        }
    }
}
/// Outbound sessions of a completed outbound handshake, kept so that NSR messages arriving
/// after the session was established can still be handled.
enum NsrContext<R: Runtime> {
    /// No late NSRs are accepted (inbound handshake, or the context has expired).
    Inactive,

    /// Late NSRs are accepted until `NSR_CONTEXT_MAX_AGE` has passed since `created`.
    Active {
        /// When the context was created.
        created: R::Instant,

        /// Remaining NSR tags, paired with the index of the session (key of `sessions`) that
        /// generated them.
        tag_set_entries: HashMap<u64, (usize, TagSetEntry)>,

        /// Outbound sessions keyed by index.
        sessions: HashMap<usize, OutboundSession<R>>,
    },
}
impl<R: Runtime> NsrContext<R> {
    /// Build an active context from the remaining NSR tags and the outbound sessions they
    /// belong to. The expiry clock starts now.
    pub fn new(
        tag_set_entries: HashMap<u64, (usize, TagSetEntry)>,
        sessions: HashMap<usize, OutboundSession<R>>,
    ) -> Self {
        let created = R::now();

        Self::Active {
            created,
            tag_set_entries,
            sessions,
        }
    }

    /// Try to handle `message` as a late NSR tagged with `garlic_tag`.
    ///
    /// On success, returns `(tag_set_id, tag_index, message)`, where `message` is the one
    /// returned by `OutboundSession::handle_new_session_reply`. Returns
    /// `SessionError::UnknownTag` in three cases: the context is inactive, the tag is not an
    /// NSR tag, or its session is gone. The NSR tag is consumed even if handling fails.
    fn decrypt(
        &mut self,
        garlic_tag: u64,
        message: Vec<u8>,
        ratchet_threshold: u16,
    ) -> Result<(u16, u16, Vec<u8>), SessionError> {
        let (entries, outbound_sessions) = match self {
            Self::Inactive => return Err(SessionError::UnknownTag),
            Self::Active {
                tag_set_entries,
                sessions,
                ..
            } => (tag_set_entries, sessions),
        };

        let Some((session_idx, entry)) = entries.remove(&garlic_tag) else {
            return Err(SessionError::UnknownTag);
        };
        let Some(outbound) = outbound_sessions.get_mut(&session_idx) else {
            return Err(SessionError::UnknownTag);
        };

        tracing::debug!(
            target: LOG_TARGET,
            ?garlic_tag,
            "late NSR message",
        );

        let (tag_set_id, tag_index) = (entry.tag_set_id, entry.tag_index);

        outbound
            .handle_new_session_reply(entry, message, ratchet_threshold)
            .map(|(payload, _, _)| (tag_set_id, tag_index, payload))
    }

    /// Deactivate the context once it's older than `NSR_CONTEXT_MAX_AGE`.
    ///
    /// Returns the NSR tags that were still outstanding so the caller can unregister them.
    /// Returns `None` if the context is inactive or not yet expired.
    fn try_expire(&mut self) -> Option<impl Iterator<Item = u64>> {
        let stale_tags = match self {
            Self::Active {
                created,
                tag_set_entries,
                ..
            } if created.elapsed() >= NSR_CONTEXT_MAX_AGE =>
                tag_set_entries.keys().copied().collect::<Vec<_>>(),
            _ => return None,
        };

        *self = Self::Inactive;
        Some(stale_tags.into_iter())
    }
}
/// State handed from a completed [`PendingSession`] to [`Session::new`].
pub struct SessionContext<R: Runtime> {
    /// Shared garlic tag -> destination map.
    garlic_tags: Arc<RwLock<HashMap<u64, DestinationId>>>,

    /// Local destination.
    local: DestinationId,

    /// Tag set used to generate receive tags.
    recv_tag_set: TagSet,

    /// Remote destination.
    remote: DestinationId,

    /// Tag set used to encrypt outbound messages.
    send_tag_set: TagSet,

    /// Receive tags generated so far, keyed by garlic tag.
    tag_set_entries: HashMap<u64, TagSetEntry>,

    /// Context for handling late NSRs (`Inactive` for inbound handshakes).
    nsr_context: NsrContext<R>,
}
/// Established session between `local` and `remote`.
pub struct Session<R: Runtime> {
    /// Batches of receive tags that are removed after `PREV_TAG_MAX_AGE`, pushed in
    /// creation order whenever the receive tag set is ratcheted.
    expiring: VecDeque<(R::Instant, HashSet<u64>)>,

    /// Shared garlic tag -> destination map.
    garlic_tags: Arc<RwLock<HashMap<u64, DestinationId>>>,

    /// Local destination.
    local: DestinationId,

    /// Context for handling late NSRs.
    nsr_context: NsrContext<R>,

    /// `NextKey` block to attach to the next outbound message, set while handling inbound
    /// `NextKey` blocks.
    pending_next_key: Option<NextKeyKind>,

    /// Tag set used to generate receive tags.
    recv_tag_set: TagSet,

    /// Remote destination.
    remote: DestinationId,

    /// Tag set used to encrypt outbound messages.
    send_tag_set: TagSet,

    /// Outstanding receive tags, keyed by garlic tag.
    tag_set_entries: HashMap<u64, TagSetEntry>,
}
impl<R: Runtime> Session<R> {
    /// Create an active session from the context produced when a [`PendingSession`] completes.
    pub fn new(context: SessionContext<R>) -> Self {
        let SessionContext {
            garlic_tags,
            local,
            recv_tag_set,
            remote,
            send_tag_set,
            tag_set_entries,
            nsr_context,
        } = context;

        Self {
            expiring: VecDeque::new(),
            garlic_tags,
            local,
            nsr_context,
            pending_next_key: None,
            recv_tag_set,
            remote,
            send_tag_set,
            tag_set_entries,
        }
    }

    /// Returns `true` if a `NextKey` block is queued for inclusion in the next message built
    /// by [`Session::encrypt`].
    pub fn has_pending_next_key(&self) -> bool {
        self.pending_next_key.is_some()
    }

    /// Decrypt an inbound existing-session (ES) message received with `garlic_tag`.
    ///
    /// If `garlic_tag` is not one of this session's receive tags, the message is offered to
    /// the NSR context in case it is a late new-session-reply. Otherwise the tag is consumed,
    /// a replacement receive tag is generated, and the payload is decrypted and parsed. Any
    /// `NextKey` blocks in the payload are then processed. A forward key may ratchet the
    /// receive tag set, scheduling the pre-ratchet tags for expiry; a reverse key updates the
    /// send tag set.
    ///
    /// Returns `(tag_set_id, tag_index, payload)`, where `payload` is the decrypted garlic
    /// message.
    ///
    /// # Errors
    ///
    /// - `SessionError::InvalidState` if the tag is unknown and not a late NSR.
    /// - `SessionError::Malformed` if the message is shorter than its 12-byte header or the
    ///   decrypted payload is not a valid garlic message.
    /// - Errors from decryption and `NextKey` handling are propagated.
    pub fn decrypt(
        &mut self,
        garlic_tag: u64,
        message: Vec<u8>,
        ratchet_threshold: u16,
    ) -> Result<(u16, u16, Vec<u8>), SessionError> {
        tracing::trace!(
            target: LOG_TARGET,
            local = %self.local,
            remote = %self.remote,
            ?garlic_tag,
            len = ?message.len(),
            "received ES",
        );

        let Some(TagSetEntry {
            key,
            tag,
            tag_index,
            tag_set_id,
        }) = self.tag_set_entries.remove(&garlic_tag)
        else {
            // not an ES tag of this session; it may still be a late NSR for one of the
            // outbound sessions kept alive by `nsr_context`
            return self.nsr_context.decrypt(garlic_tag, message, ratchet_threshold).map_err(
                |error| {
                    tracing::warn!(
                        target: LOG_TARGET,
                        local = %self.local,
                        remote = %self.remote,
                        ?garlic_tag,
                        ?error,
                        "`TagSetEntry` doesn't exist and failed to handle as NSR",
                    );
                    debug_assert!(false);
                    SessionError::InvalidState
                },
            );
        };

        // replace the consumed receive tag with a fresh one
        match self.recv_tag_set.next_entry() {
            None => tracing::debug!(
                target: LOG_TARGET,
                local = %self.local,
                remote = %self.remote,
                "receive tag set ran out of tags",
            ),
            Some(entry) => {
                self.garlic_tags.write().insert(entry.tag, self.remote.clone());
                self.tag_set_entries.insert(entry.tag, entry);
            }
        }

        // skip the 12-byte header (presumably the 4-byte length prefix followed by the 8-byte
        // garlic tag; confirm against caller). The message comes from the network, so a
        // truncated one is rejected instead of panicking on an out-of-bounds slice.
        let Some(ciphertext) = message.get(12..) else {
            tracing::warn!(
                target: LOG_TARGET,
                local = %self.local,
                remote = %self.remote,
                ?garlic_tag,
                len = ?message.len(),
                "ES message is too short",
            );
            return Err(SessionError::Malformed);
        };

        let mut payload = ciphertext.to_vec();
        let payload = ChaChaPoly::with_nonce(&key, tag_index as u64)
            .decrypt_with_ad(&tag.to_le_bytes(), &mut payload)
            .map(|_| payload)?;

        let message = GarlicMessage::parse(&payload).map_err(|error| {
            tracing::warn!(
                target: LOG_TARGET,
                local = %self.local,
                remote = %self.remote,
                ?error,
                "malformed garlic message",
            );
            SessionError::Malformed
        })?;

        message.blocks.iter().try_for_each(|block| match block {
            GarlicMessageBlock::NextKey { kind } => {
                tracing::debug!(
                    target: LOG_TARGET,
                    local = %self.local,
                    remote = %self.remote,
                    ?kind,
                    "handle `NextKey` block",
                );

                match **kind {
                    NextKeyKind::ForwardKey { .. } => {
                        self.pending_next_key = if self.recv_tag_set.can_ratchet(kind) {
                            tracing::trace!(
                                target: LOG_TARGET,
                                local = %self.local,
                                remote = %self.remote,
                                ?kind,
                                "ratchet receive tag set",
                            );

                            // generate extra tags from the old tag set so messages still in
                            // flight can be decrypted, then schedule every pre-ratchet tag for
                            // expiry
                            {
                                let mut inner = self.garlic_tags.write();

                                for _ in 0..NUM_EXTRA_TAGS_TO_GENERATE {
                                    if let Some(entry) = self.recv_tag_set.next_entry() {
                                        inner.insert(entry.tag, self.remote.clone());
                                        self.tag_set_entries.insert(entry.tag, entry);
                                    }
                                }

                                let expiring_tags =
                                    self.tag_set_entries.keys().copied().collect::<HashSet<_>>();
                                self.expiring.push_back((R::now(), expiring_tags));
                            }

                            let next_key = self.recv_tag_set.handle_next_key::<R>(kind)?;

                            // populate receive tags from the ratcheted tag set
                            {
                                let mut inner = self.garlic_tags.write();

                                (0..NUM_TAGS_TO_GENERATE).for_each(|_| {
                                    let entry =
                                        self.recv_tag_set.next_entry().expect("to succeed");
                                    inner.insert(entry.tag, self.remote.clone());
                                    self.tag_set_entries.insert(entry.tag, entry);
                                });
                            }

                            next_key
                        } else {
                            tracing::trace!(
                                target: LOG_TARGET,
                                local = %self.local,
                                remote = %self.remote,
                                ?kind,
                                "receive tag set already ratcheted",
                            );

                            self.recv_tag_set.handle_next_key::<R>(kind)?
                        };
                    }
                    NextKeyKind::ReverseKey { .. } => {
                        tracing::trace!(
                            target: LOG_TARGET,
                            local = %self.local,
                            remote = %self.remote,
                            ?kind,
                            "handle reverse key",
                        );

                        self.pending_next_key = self.send_tag_set.handle_next_key::<R>(kind)?;
                    }
                }

                Ok::<_, SessionError>(())
            }
            _ => Ok::<_, SessionError>(()),
        })?;

        Ok((tag_set_id, tag_index, payload))
    }

    /// Encrypt the garlic message built by `message_builder` with the next send tag.
    ///
    /// A forward `NextKey` block is added when the send tag set produces one
    /// (`try_generate_next_key`), and any pending `NextKey` queued by [`Session::decrypt`] is
    /// attached as well.
    ///
    /// Returns `(tag_set_id, tag_index, message)`, where `message` is the 8-byte
    /// little-endian garlic tag followed by the ciphertext.
    ///
    /// # Errors
    ///
    /// `SessionError::InvalidState` if the send tag set is out of tags. Errors from
    /// `NextKey` generation and encryption are propagated.
    pub fn encrypt(
        &mut self,
        mut message_builder: GarlicMessageBuilder,
    ) -> Result<(u16, u16, Vec<u8>), SessionError> {
        let TagSetEntry {
            key,
            tag,
            tag_index,
            tag_set_id,
        } = self.send_tag_set.next_entry().ok_or_else(|| {
            tracing::warn!(
                target: LOG_TARGET,
                local = %self.local,
                remote = %self.remote,
                "`TagSet` ran out of tags",
            );
            debug_assert!(false);
            SessionError::InvalidState
        })?;

        tracing::trace!(
            target: LOG_TARGET,
            local = %self.local,
            remote = %self.remote,
            garlic_tag = ?tag,
            "send ES",
        );

        message_builder = match self.send_tag_set.try_generate_next_key::<R>()? {
            Some(kind) => {
                tracing::debug!(
                    target: LOG_TARGET,
                    local = %self.local,
                    remote = %self.remote,
                    ?kind,
                    "send forward `NextKey` block",
                );

                message_builder.with_next_key(kind)
            }
            None => message_builder,
        };

        message_builder = match self.pending_next_key.take() {
            Some(kind) => {
                tracing::debug!(
                    target: LOG_TARGET,
                    local = %self.local,
                    remote = %self.remote,
                    ?kind,
                    "send reverse `NextKey` block",
                );

                message_builder.with_next_key(kind)
            }
            None => message_builder,
        };

        // tag key + tag index as nonce, tag bytes as associated data
        let mut message = message_builder.build();
        let mut out = BytesMut::with_capacity(message.len() + GARLIC_MESSAGE_OVERHEAD);

        ChaChaPoly::with_nonce(&key, tag_index as u64)
            .encrypt_with_ad_new(&tag.to_le_bytes(), &mut message)?;

        out.put_u64_le(tag);
        out.put_slice(&message);

        Ok((tag_set_id, tag_index, out.to_vec()))
    }

    /// Periodic maintenance.
    ///
    /// Unregisters the outstanding NSR tags once the NSR context expires, and drops every
    /// pre-ratchet receive tag batch older than `PREV_TAG_MAX_AGE`.
    pub fn maintain(&mut self) {
        if let Some(tags) = self.nsr_context.try_expire() {
            let mut inner = self.garlic_tags.write();

            tags.for_each(|tag| {
                inner.remove(&tag);
            });
        }

        // batches are pushed in creation order, so stop at the first one that is still fresh
        while let Some((created, _)) = self.expiring.front() {
            if created.elapsed() < PREV_TAG_MAX_AGE {
                break;
            }

            let (_, tags) = self.expiring.pop_front().expect("to exist");
            let mut inner = self.garlic_tags.write();

            tags.into_iter().for_each(|tag| {
                inner.remove(&tag);
                self.tag_set_entries.remove(&tag);
            });
        }
    }

    /// Destroy the session, unregistering every garlic tag it still owns from the shared map.
    ///
    /// This covers both the outstanding receive tags and, if the NSR context is still
    /// active, the outstanding NSR tags.
    pub fn destroy(self) {
        tracing::info!(
            target: LOG_TARGET,
            local = %self.local,
            remote = %self.remote,
            "destroy session",
        );

        let mut inner = self.garlic_tags.write();

        self.tag_set_entries.keys().for_each(|tag| {
            inner.remove(tag);
        });

        // NSR tags were registered in the shared map for the still-active outbound sessions.
        // `maintain` can no longer expire them once the session is consumed, so release them
        // here or they stay mapped to this remote.
        if let NsrContext::Active {
            tag_set_entries, ..
        } = &self.nsr_context
        {
            tag_set_entries.keys().for_each(|tag| {
                inner.remove(tag);
            });
        }
    }
}