use crate::{
i2cp::message::{Message, MessageType, I2CP_HEADER_SIZE},
runtime::{AsyncRead, AsyncWrite, Runtime},
};
use bytes::BytesMut;
use futures::Stream;
use alloc::{collections::VecDeque, vec, vec::Vec};
use core::{
mem,
pin::Pin,
task::{Context, Poll, Waker},
};
/// `tracing` target used by every log event emitted from this module.
const LOG_TARGET: &str = "emissary::i2cp::socket";
/// Progress of the inbound (read) half of the socket.
enum ReadState {
    /// Reading the fixed-size I2CP header.
    ReadHeader {
        /// Number of header bytes already received.
        offset: usize,
    },

    /// Header parsed; reading the message payload.
    ReadFrame {
        /// Raw message type byte taken from the header.
        msg_type: u8,
        /// Payload length announced by the header.
        size: usize,
        /// Number of payload bytes already received.
        offset: usize,
    },
}
/// Progress of the outbound (write) half of the socket.
enum WriteState {
    /// No message in flight; the next one is taken from the pending queue.
    GetMessage,

    /// A message has been taken from the queue and is (partially) written.
    SendMessage {
        /// The serialized message being written.
        message: BytesMut,
        /// Number of bytes of `message` already written to the stream.
        offset: usize,
    },

    /// Temporary placeholder swapped in while the real state is moved out; observing it
    /// means a previous poll exited without restoring the state.
    Poisoned,
}
/// I2CP connection over a runtime TCP stream.
///
/// Polling it as a `Stream` yields parsed [`Message`]s read from the socket and, in the
/// same poll, writes out frames queued with `send_message`.
pub struct I2cpSocket<R: Runtime> {
    // Outbound serialized frames waiting to be written, in FIFO order.
    pending_frames: VecDeque<BytesMut>,
    // Scratch buffer that receives both the header and the payload of incoming frames.
    read_buffer: Vec<u8>,
    // Where the reader currently is within the incoming frame.
    read_state: ReadState,
    // Underlying TCP stream provided by the runtime.
    stream: R::TcpStream,
    // Waker stored when a poll returns `Pending`; `send_message` wakes it so newly queued
    // frames get written.
    waker: Option<Waker>,
    // Where the writer currently is within the outbound frame.
    write_state: WriteState,
}
impl<R: Runtime> I2cpSocket<R> {
    /// Create a new [`I2cpSocket`] wrapping `stream`.
    ///
    /// Incoming headers and payloads are read into a pre-allocated `0xffff`-byte buffer, so
    /// its length is the largest payload this socket can accept.
    pub fn new(stream: R::TcpStream) -> Self {
        Self {
            pending_frames: VecDeque::new(),
            read_buffer: vec![0u8; 0xffff],
            read_state: ReadState::ReadHeader { offset: 0usize },
            stream,
            write_state: WriteState::GetMessage,
            waker: None,
        }
    }

    /// Queue an already-serialized `message` for sending.
    ///
    /// The message is written the next time the socket is polled. If a task is currently
    /// parked on this socket, it is woken so that poll happens.
    pub fn send_message(&mut self, message: BytesMut) {
        self.pending_frames.push_back(message);

        // The waker is owned here (taken out of the `Option`), so consume it with `wake()`
        // instead of `wake_by_ref()` followed by an implicit drop.
        if let Some(waker) = self.waker.take() {
            waker.wake();
        }
    }
}
impl<R: Runtime> I2cpSocket<R> {
    /// Drive the read state machine until one complete message has been parsed.
    ///
    /// Returns `Poll::Ready(Some(message))` for a parsed message, `Poll::Pending` when the
    /// stream has no more data right now, and `Poll::Ready(None)` when the connection must be
    /// closed (read error, EOF, or a frame larger than `read_buffer`). Frames with an unknown
    /// message type or an unparseable payload are logged and skipped.
    fn poll_read_message(&mut self, cx: &mut Context<'_>) -> Poll<Option<Message>> {
        loop {
            match self.read_state {
                ReadState::ReadHeader { offset } => {
                    let nread = match Pin::new(&mut self.stream)
                        .poll_read(cx, &mut self.read_buffer[offset..I2CP_HEADER_SIZE])
                    {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(Err(error)) => {
                            tracing::debug!(
                                target: LOG_TARGET,
                                ?error,
                                "socket read error",
                            );
                            return Poll::Ready(None);
                        }
                        Poll::Ready(Ok(0)) => {
                            tracing::debug!(
                                target: LOG_TARGET,
                                ?offset,
                                "read zero bytes from socket (header)",
                            );
                            return Poll::Ready(None);
                        }
                        Poll::Ready(Ok(nread)) => nread,
                    };

                    // Partial header: remember progress and keep reading.
                    if offset + nread < I2CP_HEADER_SIZE {
                        self.read_state = ReadState::ReadHeader {
                            offset: offset + nread,
                        };
                        continue;
                    }

                    // Header layout: 4-byte big-endian payload length, then the 1-byte type.
                    let size = u32::from_be_bytes(
                        TryInto::<[u8; 4]>::try_into(&self.read_buffer[..4])
                            .expect("to succeed"),
                    );
                    let msg_type = self.read_buffer[4];

                    if size == 0 {
                        self.read_state = ReadState::ReadHeader { offset: 0usize };

                        let Some(msg_type) = MessageType::from_u8(msg_type) else {
                            tracing::warn!(
                                target: LOG_TARGET,
                                ?msg_type,
                                "invalid message type",
                            );
                            continue;
                        };

                        let Some(message) = Message::parse::<R>(msg_type, []) else {
                            tracing::warn!(
                                target: LOG_TARGET,
                                ?msg_type,
                                "failed to parse i2cp message with no payload",
                            );
                            continue;
                        };

                        return Poll::Ready(Some(message));
                    }

                    // The payload is read into `read_buffer`, and the announced length comes
                    // from the remote peer. A larger value would make the payload slice panic,
                    // so treat it as a protocol violation and close the connection.
                    let max_size = self.read_buffer.len();
                    let Some(size) = usize::try_from(size).ok().filter(|size| *size <= max_size)
                    else {
                        tracing::warn!(
                            target: LOG_TARGET,
                            ?size,
                            ?max_size,
                            "i2cp message too large",
                        );
                        return Poll::Ready(None);
                    };

                    self.read_state = ReadState::ReadFrame {
                        size,
                        msg_type,
                        offset: 0usize,
                    };
                }
                ReadState::ReadFrame {
                    size,
                    msg_type,
                    offset,
                } => {
                    let nread = match Pin::new(&mut self.stream)
                        .poll_read(cx, &mut self.read_buffer[offset..size])
                    {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(Err(error)) => {
                            tracing::debug!(
                                target: LOG_TARGET,
                                ?error,
                                "socket read error",
                            );
                            return Poll::Ready(None);
                        }
                        Poll::Ready(Ok(0)) => {
                            tracing::debug!(
                                target: LOG_TARGET,
                                "read zero bytes from socket (payload)",
                            );
                            return Poll::Ready(None);
                        }
                        Poll::Ready(Ok(nread)) => nread,
                    };

                    // Partial payload: remember progress and keep reading.
                    if offset + nread < size {
                        self.read_state = ReadState::ReadFrame {
                            size,
                            msg_type,
                            offset: offset + nread,
                        };
                        continue;
                    }

                    // Full frame received; the next read starts a fresh header.
                    self.read_state = ReadState::ReadHeader { offset: 0usize };

                    let Some(msg_type) = MessageType::from_u8(msg_type) else {
                        tracing::warn!(
                            target: LOG_TARGET,
                            ?msg_type,
                            "invalid message type",
                        );
                        continue;
                    };

                    let Some(message) = Message::parse::<R>(msg_type, &self.read_buffer[..size])
                    else {
                        tracing::warn!(
                            target: LOG_TARGET,
                            ?msg_type,
                            "failed to parse i2cp message",
                        );
                        continue;
                    };

                    return Poll::Ready(Some(message));
                }
            }
        }
    }

    /// Write queued frames until the queue is empty or the stream stops accepting data.
    ///
    /// Returns `false` when the connection must be closed (write error, zero-length write,
    /// or a poisoned write state), and `true` otherwise.
    fn drive_writes(&mut self, cx: &mut Context<'_>) -> bool {
        loop {
            // Move the state out so the owned `BytesMut` can be carried between states;
            // every non-fatal path below restores a real state.
            match mem::replace(&mut self.write_state, WriteState::Poisoned) {
                WriteState::GetMessage => match self.pending_frames.pop_front() {
                    None => {
                        self.write_state = WriteState::GetMessage;
                        return true;
                    }
                    // An empty frame would make `poll_write` report `Ok(0)`, which is treated
                    // as a dead connection below, so there is nothing to send; skip it.
                    Some(message) if message.is_empty() => {
                        self.write_state = WriteState::GetMessage;
                    }
                    Some(message) => {
                        self.write_state = WriteState::SendMessage {
                            offset: 0usize,
                            message,
                        };
                    }
                },
                WriteState::SendMessage { offset, message } => {
                    match Pin::new(&mut self.stream).poll_write(cx, &message[offset..]) {
                        Poll::Pending => {
                            self.write_state = WriteState::SendMessage { offset, message };
                            return true;
                        }
                        Poll::Ready(Err(error)) => {
                            tracing::debug!(
                                target: LOG_TARGET,
                                ?error,
                                "socket write error",
                            );
                            return false;
                        }
                        Poll::Ready(Ok(0)) => {
                            tracing::debug!(
                                target: LOG_TARGET,
                                "wrote zero bytes to socket",
                            );
                            return false;
                        }
                        Poll::Ready(Ok(nwritten)) => {
                            let offset = offset + nwritten;
                            self.write_state = if offset == message.len() {
                                WriteState::GetMessage
                            } else {
                                WriteState::SendMessage { offset, message }
                            };
                        }
                    }
                }
                WriteState::Poisoned => {
                    tracing::warn!(
                        target: LOG_TARGET,
                        "write state is poisoned",
                    );
                    debug_assert!(false);
                    return false;
                }
            }
        }
    }
}

impl<R: Runtime> Stream for I2cpSocket<R> {
    type Item = Message;

    /// Flush queued outbound frames, then try to read the next inbound message.
    ///
    /// Writes are driven first so that a client that keeps the read side busy cannot delay
    /// outbound frames indefinitely. Returns `Poll::Ready(None)` when the connection failed
    /// on either side or the peer sent an oversized frame. On `Poll::Pending` the task's
    /// waker is stored so that `send_message` can wake it.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = &mut *self;

        if !this.drive_writes(cx) {
            return Poll::Ready(None);
        }

        let poll = this.poll_read_message(cx);
        if poll.is_pending() {
            this.waker = Some(cx.waker().clone());
        }

        poll
    }
}