use crate::{
crypto::{chachapoly::ChaChaPoly, siphash::SipHash},
events::EventHandle,
primitives::{RouterId, RouterInfo},
runtime::{AsyncRead, AsyncWrite, Counter, Histogram, Instant, MetricsHandle, Runtime},
subsystem::{OutboundMessage, OutboundMessageRecycle, SubsystemEvent},
transport::{
ntcp2::{
message::MessageBlock,
metrics::*,
session::{KeyContext, Role},
},
Direction, EncryptionKind, TerminationReason,
},
util::AsyncWriteExt,
};
use futures::FutureExt;
use futures_channel::oneshot;
use thingbuf::mpsc::{with_recycle, Receiver, Sender};
use alloc::{vec, vec::Vec};
use core::{
future::Future,
mem,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
const LOG_TARGET: &str = "emissary::ntcp2::active";
#[cfg(not(test))]
const IDLE_TIMEOUT: Duration = Duration::from_secs(120);
#[cfg(test)]
const IDLE_TIMEOUT: Duration = Duration::from_secs(3);
#[cfg(not(test))]
const IDLE_TIMEOUT_CHECK_INTERVAL: Duration = Duration::from_secs(2);
#[cfg(test)]
const IDLE_TIMEOUT_CHECK_INTERVAL: Duration = Duration::from_secs(1);
enum ReadState {
ReadSize {
offset: usize,
},
ReadFrame {
size: usize,
offset: usize,
},
}
enum WriteState {
GetMessage,
SendSize {
offset: usize,
size: Vec<u8>,
message: Vec<u8>,
feedback_tx: Option<oneshot::Sender<()>>,
},
SendMessage {
offset: usize,
message: Vec<u8>,
feedback_tx: Option<oneshot::Sender<()>>,
},
Poisoned,
}
pub struct Ntcp2Session<R: Runtime> {
address: SocketAddr,
direction: Direction,
encryption: EncryptionKind,
event_handle: EventHandle<R>,
idle_timeout_timer: R::Timer,
inbound_activity: R::Instant,
inbound_bandwidth: usize,
metrics_handle: R::MetricsHandle,
msg_rx: Receiver<OutboundMessage, OutboundMessageRecycle>,
msg_tx: Sender<OutboundMessage, OutboundMessageRecycle>,
outbound_activity: R::Instant,
outbound_bandwidth: usize,
read_buffer: Vec<u8>,
read_state: ReadState,
recv_cipher: ChaChaPoly,
role: Role,
router: RouterId,
router_info: RouterInfo,
send_cipher: ChaChaPoly,
sip: SipHash,
started: R::Instant,
stream: R::TcpStream,
transport_tx: Sender<SubsystemEvent>,
write_state: WriteState,
}
impl<R: Runtime> Ntcp2Session<R> {
pub fn new(
role: Role,
address: SocketAddr,
router_info: RouterInfo,
stream: R::TcpStream,
key_context: KeyContext,
direction: Direction,
event_handle: EventHandle<R>,
transport_tx: Sender<SubsystemEvent>,
started: R::Instant,
metrics_handle: R::MetricsHandle,
encryption: EncryptionKind,
) -> Self {
let KeyContext {
send_key,
recv_key,
sip,
} = key_context;
let (msg_tx, msg_rx) = with_recycle(512, OutboundMessageRecycle::default());
Self {
address,
direction,
encryption,
event_handle,
idle_timeout_timer: R::timer(IDLE_TIMEOUT_CHECK_INTERVAL),
inbound_activity: R::now(),
inbound_bandwidth: 0usize,
metrics_handle,
msg_rx,
msg_tx,
outbound_activity: R::now(),
outbound_bandwidth: 0usize,
read_buffer: vec![0u8; 0xffff],
read_state: ReadState::ReadSize { offset: 0usize },
recv_cipher: ChaChaPoly::new(&recv_key),
role,
router: router_info.identity.id(),
router_info,
send_cipher: ChaChaPoly::new(&send_key),
sip,
started,
stream,
transport_tx,
write_state: WriteState::GetMessage,
}
}
pub fn direction(&self) -> Direction {
self.direction
}
pub fn role(&self) -> Role {
self.role
}
pub fn address(&self) -> SocketAddr {
self.address
}
pub fn encryption(&self) -> EncryptionKind {
self.encryption
}
pub fn started(&self) -> &R::Instant {
&self.started
}
pub fn router(&self) -> RouterInfo {
self.router_info.clone()
}
pub async fn run(mut self) -> (RouterId, TerminationReason) {
tracing::trace!(
target: LOG_TARGET,
router_id = %self.router,
"start ntcp2 event loop",
);
self.transport_tx
.send(SubsystemEvent::ConnectionEstablished {
router_id: self.router.clone(),
tx: self.msg_tx.clone(),
})
.await
.expect("manager to stay alive");
let reason = match (&mut self).await {
Some(reason) => reason,
None => {
let message = MessageBlock::new_termination(TerminationReason::IdleTimeout);
let message = self.send_cipher.encrypt(&message).unwrap();
let mut payload = self.sip.obfuscate(message.len() as u16).to_be_bytes().to_vec();
payload.extend(&message);
if let Err(error) = self.stream.write_all(&payload).await {
tracing::debug!(
target: LOG_TARGET,
router_id = %self.router,
?error,
"failed to send termination block",
);
}
TerminationReason::IdleTimeout
}
};
self.transport_tx
.send(SubsystemEvent::ConnectionClosed {
router_id: self.router.clone(),
})
.await
.expect("manager to stay alive");
(self.router, reason)
}
}
impl<R: Runtime> Future for Ntcp2Session<R> {
type Output = Option<TerminationReason>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = &mut *self;
let mut stream = Pin::new(&mut this.stream);
loop {
match this.read_state {
ReadState::ReadSize { offset } => {
match stream.as_mut().poll_read(cx, &mut this.read_buffer[offset..2]) {
Poll::Pending => break,
Poll::Ready(Err(error)) => {
tracing::debug!(
target: LOG_TARGET,
router_id = %this.router,
?error,
"socket error",
);
return Poll::Ready(Some(TerminationReason::IoError));
}
Poll::Ready(Ok(nread)) => {
if nread == 0 {
return Poll::Ready(Some(TerminationReason::IoError));
}
if offset + nread != 2 {
this.read_state = ReadState::ReadSize {
offset: offset + nread,
};
continue;
}
let size = ((this.read_buffer[0] as u16) << 8)
| (this.read_buffer[1] as u16) & 0xff;
this.read_state = ReadState::ReadFrame {
size: this.sip.deobfuscate(size) as usize,
offset: 0usize,
};
this.inbound_activity = R::now();
}
}
}
ReadState::ReadFrame { size, offset } => {
match stream.as_mut().poll_read(cx, &mut this.read_buffer[offset..size]) {
Poll::Pending => break,
Poll::Ready(Err(_)) =>
return Poll::Ready(Some(TerminationReason::IoError)),
Poll::Ready(Ok(nread)) => {
if nread == 0 {
return Poll::Ready(Some(TerminationReason::IoError));
}
if offset + nread < size {
this.read_state = ReadState::ReadFrame {
size,
offset: offset + nread,
};
continue;
}
this.inbound_bandwidth += this.read_buffer[..size].len();
this.metrics_handle
.histogram(MESSAGE_SIZES)
.record(this.read_buffer[..size].len() as f64);
this.metrics_handle
.counter(INBOUND_BW)
.increment(this.read_buffer[..size].len());
this.metrics_handle.counter(NUM_INBOUND_MESSAGES).increment(1);
let data_block =
match this.recv_cipher.decrypt(this.read_buffer[..size].to_vec()) {
Ok(data_block) => data_block,
Err(_) =>
return Poll::Ready(Some(TerminationReason::AeadFailure)),
};
let messages = match MessageBlock::parse_multiple(&data_block) {
Ok(messages) => messages,
Err(error) => {
tracing::warn!(
target: LOG_TARGET,
router_id = %this.router,
?data_block,
?error,
"failed to parse message(s)",
);
continue;
}
};
this.metrics_handle
.histogram(NUM_BLOCKS_PER_MSG)
.record(messages.len() as f64);
tracing::trace!(
target: LOG_TARGET,
router_id = %this.router,
?size,
num_messages = ?messages.len(),
"read ntcp2 frame",
);
if let Some(MessageBlock::Termination { reason, .. }) =
messages.iter().find(|message| {
core::matches!(message, MessageBlock::Termination { .. })
})
{
tracing::debug!(
target: LOG_TARGET,
router_id = %this.router,
?reason,
"session terminated by remote router",
);
return Poll::Ready(Some(TerminationReason::ntcp2(*reason)));
}
let messages = messages
.into_iter()
.filter_map(|message| match message {
MessageBlock::I2Np { message } =>
if message.is_expired::<R>() {
tracing::trace!(
target: LOG_TARGET,
router_id = %this.router,
message_type = ?message.message_type,
message_id = ?message.message_id,
expiration = ?message.expiration,
"discarding expired message",
);
None
} else {
Some(message)
},
MessageBlock::Padding { .. } => None,
message => {
tracing::debug!(
target: LOG_TARGET,
router_id = %this.router,
?message,
"ignoring message",
);
None
}
})
.collect::<Vec<_>>();
if let Err(error) =
this.transport_tx.try_send(SubsystemEvent::Message {
messages: messages
.iter()
.map(|message| (this.router.clone(), message.clone()))
.collect(),
})
{
tracing::warn!(
target: LOG_TARGET,
router_id = %this.router,
?error,
"failed to dispatch messages to subsystems",
);
}
this.read_state = ReadState::ReadSize { offset: 0usize };
this.inbound_activity = R::now();
}
}
}
}
}
loop {
match mem::replace(&mut this.write_state, WriteState::Poisoned) {
WriteState::GetMessage => match this.msg_rx.poll_recv(cx) {
Poll::Pending => {
this.write_state = WriteState::GetMessage;
break;
}
Poll::Ready(None) => return Poll::Ready(Some(TerminationReason::Unspecified)),
Poll::Ready(Some(OutboundMessage::Message(message))) => {
if message.is_expired::<R>() {
this.write_state = WriteState::GetMessage;
continue;
}
let message = message.serialize_short();
assert!(message.len() as u16 <= u16::MAX, "too large message");
let test = MessageBlock::new_i2np_message(&message);
let data_block = this.send_cipher.encrypt(&test).unwrap();
let size = this.sip.obfuscate(data_block.len() as u16);
this.write_state = WriteState::SendSize {
size: size.to_be_bytes().to_vec(),
offset: 0usize,
message: data_block,
feedback_tx: None,
};
}
Poll::Ready(Some(OutboundMessage::MessageWithFeedback(
message,
feedback_tx,
))) => {
if message.is_expired::<R>() {
this.write_state = WriteState::GetMessage;
continue;
}
let message = message.serialize_short();
assert!(message.len() as u16 <= u16::MAX, "too large message");
let test = MessageBlock::new_i2np_message(&message);
let data_block = this.send_cipher.encrypt(&test).unwrap();
let size = this.sip.obfuscate(data_block.len() as u16);
this.write_state = WriteState::SendSize {
size: size.to_be_bytes().to_vec(),
offset: 0usize,
message: data_block,
feedback_tx: Some(feedback_tx),
};
}
Poll::Ready(Some(OutboundMessage::Messages(mut messages))) => {
assert!(!messages.is_empty());
if messages.len() > 1 {
todo!("not implemented")
}
let message = messages.pop().expect("message to exist");
if message.is_expired::<R>() {
this.write_state = WriteState::GetMessage;
continue;
}
let message = message.serialize_short();
assert!(message.len() as u16 <= u16::MAX, "too large message");
let test = MessageBlock::new_i2np_message(&message);
let data_block = this.send_cipher.encrypt(&test).unwrap();
let size = this.sip.obfuscate(data_block.len() as u16);
this.write_state = WriteState::SendSize {
size: size.to_be_bytes().to_vec(),
offset: 0usize,
message: data_block,
feedback_tx: None,
};
}
Poll::Ready(Some(OutboundMessage::Dummy)) => unreachable!(),
},
WriteState::SendSize {
offset,
size,
message,
feedback_tx,
} => match stream.as_mut().poll_write(cx, &size[offset..]) {
Poll::Pending => {
this.write_state = WriteState::SendSize {
offset,
size,
message,
feedback_tx,
};
break;
}
Poll::Ready(Err(_)) => return Poll::Ready(Some(TerminationReason::IoError)),
Poll::Ready(Ok(0)) => return Poll::Ready(Some(TerminationReason::IoError)),
Poll::Ready(Ok(nwritten)) => {
this.outbound_bandwidth += nwritten;
match nwritten + offset == size.len() {
true => {
this.write_state = WriteState::SendMessage {
offset: 0usize,
message,
feedback_tx,
};
}
false => {
this.write_state = WriteState::SendSize {
size,
offset: offset + nwritten,
message,
feedback_tx,
};
}
}
this.outbound_activity = R::now();
}
},
WriteState::SendMessage {
offset,
message,
feedback_tx,
} => match stream.as_mut().poll_write(cx, &message[offset..]) {
Poll::Pending => {
this.write_state = WriteState::SendMessage {
offset,
message,
feedback_tx,
};
break;
}
Poll::Ready(Err(_)) => return Poll::Ready(Some(TerminationReason::IoError)),
Poll::Ready(Ok(0)) => return Poll::Ready(Some(TerminationReason::IoError)),
Poll::Ready(Ok(nwritten)) => {
this.outbound_bandwidth += nwritten;
this.metrics_handle.counter(OUTBOUND_BW).increment(nwritten);
match nwritten + offset == message.len() {
true => {
if let Some(tx) = feedback_tx {
let _ = tx.send(());
}
this.metrics_handle.counter(NUM_OUTBOUND_MESSAGES).increment(1);
this.write_state = WriteState::GetMessage;
}
false => {
this.write_state = WriteState::SendMessage {
offset: offset + nwritten,
message,
feedback_tx,
};
}
}
this.outbound_activity = R::now();
}
},
WriteState::Poisoned => {
tracing::warn!(
target: LOG_TARGET,
router = %this.router,
"write state is poisoned",
);
debug_assert!(false);
return Poll::Ready(Some(TerminationReason::Unspecified));
}
}
}
if this.idle_timeout_timer.poll_unpin(cx).is_ready() {
if this.inbound_activity.elapsed() > IDLE_TIMEOUT
&& this.outbound_activity.elapsed() > IDLE_TIMEOUT
{
return Poll::Ready(None);
}
this.idle_timeout_timer = R::timer(IDLE_TIMEOUT_CHECK_INTERVAL);
let _ = this.idle_timeout_timer.poll_unpin(cx);
}
if this.event_handle.poll_unpin(cx).is_ready() {
self.event_handle.transport_inbound_bandwidth(self.inbound_bandwidth);
self.event_handle.transport_outbound_bandwidth(self.outbound_bandwidth);
self.inbound_bandwidth = 0;
self.outbound_bandwidth = 0;
}
Poll::Pending
}
}