use crate::{
crypto::{base64_encode, SigningPrivateKey},
error::{ConnectionError, Error},
primitives::Destination,
runtime::Runtime,
sam::{
parser::{DestinationContext, HostKind, SamCommand, SamVersion, SessionKind},
socket::SamSocket,
},
};
use bytes::{BufMut, BytesMut};
use futures::{FutureExt, StreamExt};
use hashbrown::HashMap;
use rand::Rng;
use alloc::{boxed::Box, format, string::String, sync::Arc};
use core::{
fmt,
future::Future,
mem,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
/// `tracing` target used by every log line emitted from this module.
const LOG_TARGET: &str = "emissary::sam::pending::connection";

/// How long a pending connection may sit idle before it is closed with
/// `ConnectionError::KeepAliveTimeout`; the timer is restarted once `HELLO` is answered.
const KEEP_ALIVE_TIMEOUT: Duration = Duration::from_secs(10);

/// Number of random bytes written into the encryption-key slot of the private key
/// returned for a generated destination.
const ELGAMAL_KEY_LEN: usize = 256;
/// What a pending SAM connection turned into: the first command received after the
/// `HELLO` handshake that hands the client socket over to a dedicated handler.
///
/// Every variant carries the client's SAM socket, the SAM version negotiated during
/// the handshake and the options parsed alongside the command.
pub enum ConnectionKind<R: Runtime> {
/// Produced for `SamCommand::CreateSession` (e.g. `SESSION CREATE ...`).
Session {
/// ID of the session the client wants to create.
session_id: Arc<str>,
/// Client socket the command was read from.
socket: Box<SamSocket<R>>,
/// Destination context parsed from the command.
destination: Box<DestinationContext>,
/// SAM version negotiated during the handshake.
version: SamVersion,
/// Session style requested by the client.
session_kind: SessionKind,
/// Options parsed alongside the command.
options: HashMap<String, String>,
},
/// Produced for `SamCommand::Connect`: an outbound stream on an existing session.
Stream {
/// ID of the session the stream belongs to.
session_id: Arc<str>,
/// Client socket the command was read from.
socket: Box<SamSocket<R>>,
/// SAM version negotiated during the handshake.
version: SamVersion,
/// Remote host the client asked to connect to.
host: HostKind,
/// Options parsed alongside the command.
options: HashMap<String, String>,
},
/// Produced for `SamCommand::Accept`: wait for an inbound stream on an existing session.
Accept {
/// ID of the session whose inbound stream should be accepted.
session_id: Arc<str>,
/// Client socket the command was read from.
socket: Box<SamSocket<R>>,
/// SAM version negotiated during the handshake.
version: SamVersion,
/// Options parsed alongside the command.
options: HashMap<String, String>,
},
/// Produced for `SamCommand::Forward`: forward inbound streams of an existing session.
Forward {
/// ID of the session whose inbound streams are forwarded.
session_id: Arc<str>,
/// Client socket the command was read from.
socket: Box<SamSocket<R>>,
/// SAM version negotiated during the handshake.
version: SamVersion,
/// Port supplied by the client in the `Forward` command.
port: u16,
/// Options parsed alongside the command.
options: HashMap<String, String>,
},
}
impl<R: Runtime> fmt::Debug for ConnectionKind<R> {
    /// Format the connection kind for diagnostics.
    ///
    /// Every variant prints its session ID, negotiated version and options.
    /// `Session` additionally prints its session kind, and `Forward` prints its port.
    /// The socket is never printed. Neither is the `Session` destination context or the
    /// `Stream` host. `finish_non_exhaustive` marks these omissions with `..`.
    /// NOTE(review): the destination context may hold key material; keep it out of
    /// `Debug` unless that is confirmed safe.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Session {
                version,
                session_id,
                session_kind,
                options,
                ..
            } => f
                .debug_struct("ConnectionKind::Session")
                .field("session_id", &session_id)
                .field("version", &version)
                .field("session_kind", &session_kind)
                .field("options", &options)
                .finish_non_exhaustive(),
            Self::Stream {
                session_id,
                version,
                options,
                ..
            } => f
                .debug_struct("ConnectionKind::Stream")
                .field("session_id", &session_id)
                .field("version", &version)
                .field("options", &options)
                .finish_non_exhaustive(),
            Self::Accept {
                session_id,
                version,
                options,
                ..
            } => f
                .debug_struct("ConnectionKind::Accept")
                .field("session_id", &session_id)
                .field("version", &version)
                .field("options", &options)
                .finish_non_exhaustive(),
            Self::Forward {
                session_id,
                version,
                port,
                options,
                ..
            } => f
                .debug_struct("ConnectionKind::Forward")
                .field("session_id", &session_id)
                .field("version", &version)
                .field("port", &port)
                .field("options", &options)
                .finish_non_exhaustive(),
        }
    }
}
/// Handshake progress of a [`PendingSamConnection`].
enum PendingConnectionState<R: Runtime> {
/// Waiting for the client's `HELLO` command.
AwaitingHandshake {
/// Client socket.
socket: Box<SamSocket<R>>,
},
/// `HELLO` has been answered; waiting for a session/stream command.
Handshaked {
/// Client socket.
socket: Box<SamSocket<R>>,
/// SAM version sent back in the `HELLO REPLY`.
version: SamVersion,
},
/// Placeholder stored while `poll` has temporarily moved the real state out.
/// Observing it at the start of a poll iteration indicates a bug.
Poisoned,
}
/// Future driving a freshly accepted SAM client through the `HELLO` handshake until it
/// issues a command that resolves it into a [`ConnectionKind`].
pub struct PendingSamConnection<R: Runtime> {
/// Current handshake state.
state: PendingConnectionState<R>,
/// When this fires while the socket has no pending input, the future resolves with
/// `ConnectionError::KeepAliveTimeout`; it is recreated after `HELLO` is answered.
keep_alive_timer: R::Timer,
}
impl<R: Runtime> PendingSamConnection<R> {
    /// Wrap an accepted TCP stream in a SAM socket and start waiting for `HELLO`.
    ///
    /// The keep-alive timer starts immediately, so a client that stays silent for
    /// `KEEP_ALIVE_TIMEOUT` is disconnected.
    pub fn new(stream: R::TcpStream) -> Self {
        let socket = Box::new(SamSocket::new(stream));
        let keep_alive_timer = R::timer(KEEP_ALIVE_TIMEOUT);

        Self {
            state: PendingConnectionState::AwaitingHandshake { socket },
            keep_alive_timer,
        }
    }
}
impl<R: Runtime> Future for PendingSamConnection<R> {
type Output = crate::Result<ConnectionKind<R>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
match mem::replace(&mut self.state, PendingConnectionState::Poisoned) {
PendingConnectionState::AwaitingHandshake { mut socket } => match socket
.poll_next_unpin(cx)
{
Poll::Pending => {
self.state = PendingConnectionState::AwaitingHandshake { socket };
break;
}
Poll::Ready(None) => {
tracing::debug!(
target: LOG_TARGET,
"client closed socket",
);
return Poll::Ready(Err(Error::Connection(ConnectionError::SocketClosed)));
}
Poll::Ready(Some(SamCommand::Hello { max, .. })) => {
let version = match max {
Some(SamVersion::V33) => {
tracing::debug!(
target: LOG_TARGET,
"v3.3 not supported",
);
SamVersion::V32
}
Some(max) => max,
None => SamVersion::V32,
};
tracing::debug!(
target: LOG_TARGET,
?version,
"client connected"
);
socket.send_message(
format!("HELLO REPLY RESULT=OK VERSION={version}\n")
.as_bytes()
.to_vec(),
);
self.state = PendingConnectionState::Handshaked { version, socket };
self.keep_alive_timer = R::timer(KEEP_ALIVE_TIMEOUT);
}
Poll::Ready(Some(command)) => {
tracing::debug!(
target: LOG_TARGET,
?command,
"received an unexpected command, expected `HELLO`",
);
return Poll::Ready(Err(Error::InvalidState));
}
},
PendingConnectionState::Handshaked {
mut socket,
version,
} => match socket.poll_next_unpin(cx) {
Poll::Pending => {
self.state = PendingConnectionState::Handshaked { socket, version };
break;
}
Poll::Ready(None) => {
tracing::debug!(
target: LOG_TARGET,
"client closed socket",
);
return Poll::Ready(Err(Error::Connection(ConnectionError::SocketClosed)));
}
Poll::Ready(Some(SamCommand::CreateSession {
session_id,
session_kind,
destination,
options,
})) => {
tracing::info!(
target: LOG_TARGET,
%session_id,
?session_kind,
?destination,
"create session"
);
return Poll::Ready(Ok(ConnectionKind::Session {
session_id: Arc::from(session_id),
destination,
socket,
version,
session_kind,
options,
}));
}
Poll::Ready(Some(SamCommand::Connect {
session_id,
host,
options,
})) => {
tracing::info!(
target: LOG_TARGET,
%session_id,
"connect to destination"
);
return Poll::Ready(Ok(ConnectionKind::Stream {
session_id: Arc::from(session_id),
socket,
version,
host,
options,
}));
}
Poll::Ready(Some(SamCommand::Accept {
session_id,
options,
})) => {
tracing::info!(
target: LOG_TARGET,
%session_id,
"accept inbound connection"
);
return Poll::Ready(Ok(ConnectionKind::Accept {
session_id: Arc::from(session_id),
socket,
version,
options,
}));
}
Poll::Ready(Some(SamCommand::Forward {
session_id,
port,
options,
})) => {
tracing::info!(
target: LOG_TARGET,
%session_id,
?port,
"forward inbound connections",
);
return Poll::Ready(Ok(ConnectionKind::Forward {
session_id: Arc::from(session_id),
socket,
port,
version,
options,
}));
}
Poll::Ready(Some(SamCommand::NamingLookup { name })) => {
tracing::debug!(
target: LOG_TARGET,
?version,
?name,
"destination lookup",
);
socket.send_message(
format!("NAMING REPLY RESULT=KEY_NOT_FOUND NAME={name}\n")
.as_bytes()
.to_vec(),
);
self.state = PendingConnectionState::Handshaked { version, socket };
}
Poll::Ready(Some(SamCommand::GenerateDestination)) => {
tracing::debug!(
target: LOG_TARGET,
?version,
"generate destination",
);
let (signing_key, destination) = {
let signing_key = SigningPrivateKey::random(R::rng());
let destination = Destination::new::<R>(signing_key.public());
(signing_key, destination)
};
let (privkey, destination) = {
let mut out =
BytesMut::with_capacity(destination.serialized_len() + 2 * 32);
let destination = destination.serialize();
out.put_slice(&destination);
out.put_slice(&{
{
let mut bytes = [0u8; ELGAMAL_KEY_LEN];
R::rng().fill_bytes(&mut bytes);
bytes
}
});
out.put_slice(signing_key.as_ref());
(base64_encode(out), base64_encode(&destination))
};
socket.send_message(
format!("DEST REPLY PUB={destination} PRIV={privkey}\n")
.as_bytes()
.to_vec(),
);
self.state = PendingConnectionState::Handshaked { version, socket };
}
Poll::Ready(Some(command)) => {
tracing::debug!(
target: LOG_TARGET,
?command,
"received an unexpected command, expected `SESSION`/`STREAM`",
);
return Poll::Ready(Err(Error::InvalidState));
}
},
PendingConnectionState::Poisoned => {
tracing::warn!(
target: LOG_TARGET,
"pending connection state has been poisoned",
);
debug_assert!(false);
return Poll::Ready(Err(Error::InvalidState));
}
}
}
if self.keep_alive_timer.poll_unpin(cx).is_ready() {
tracing::debug!(
target: LOG_TARGET,
"keep-alive timer expired, closing connection",
);
return Poll::Ready(Err(Error::Connection(ConnectionError::KeepAliveTimeout)));
}
Poll::Pending
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{
        mock::{MockRuntime, MockTcpStream},
        TcpStream as _,
    };
    use std::time::Duration;
    use tokio::{
        io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
        net::TcpListener,
    };

    /// Open a loopback TCP connection and wrap the router end in a
    /// [`PendingSamConnection`].
    ///
    /// Returns `(client, connection)`, where `client` is the socket the test uses to
    /// act as the SAM client.
    async fn connect() -> (tokio::net::TcpStream, PendingSamConnection<MockRuntime>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let address = listener.local_addr().unwrap();
        let (client, router) = tokio::join!(listener.accept(), MockTcpStream::connect(address));

        (
            client.unwrap().0,
            PendingSamConnection::<MockRuntime>::new(router.unwrap()),
        )
    }

    /// Send `hello` from `client` and poll `connection` until it reaches
    /// `PendingConnectionState::Handshaked`. Then read one reply line from the client
    /// socket.
    ///
    /// Panics if the connection resolves before the handshake is observed. Returns the
    /// client socket together with the reply line.
    async fn handshake(
        mut client: tokio::net::TcpStream,
        connection: &mut PendingSamConnection<MockRuntime>,
        hello: &[u8],
    ) -> (tokio::net::TcpStream, String) {
        client.write_all(hello).await.unwrap();

        loop {
            futures::future::poll_fn(|cx| match connection.poll_unpin(cx) {
                Poll::Pending => Poll::Ready(()),
                _ => panic!("invalid return value"),
            })
            .await;

            if matches!(connection.state, PendingConnectionState::Handshaked { .. }) {
                break;
            }
            tokio::time::sleep(Duration::from_secs(1)).await;
        }

        let mut reader = BufReader::new(client);
        let mut response = String::new();
        reader.read_line(&mut response).await.unwrap();

        (reader.into_inner(), response)
    }

    #[tokio::test]
    async fn client_closes_socket() {
        let (mut client, connection) = connect().await;
        client.shutdown().await.unwrap();

        match connection.await {
            Err(Error::Connection(ConnectionError::SocketClosed)) => {}
            _ => panic!("invalid result"),
        }
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive_timeout() {
        // Keep the client socket open so only the keep-alive timer can end the connection.
        let (_client, connection) = connect().await;

        match connection.await {
            Err(Error::Connection(ConnectionError::KeepAliveTimeout)) => {}
            _ => panic!("invalid result"),
        }
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive_timeout_after_handshake() {
        let (client, mut connection) = connect().await;
        let (_client, response) = handshake(client, &mut connection, b"HELLO VERSION\n").await;

        assert!(matches!(
            connection.state,
            PendingConnectionState::Handshaked {
                version: SamVersion::V32,
                ..
            }
        ));
        assert_eq!(response, "HELLO REPLY RESULT=OK VERSION=3.2\n");

        match connection.await {
            Err(Error::Connection(ConnectionError::KeepAliveTimeout)) => {}
            _ => panic!("invalid result"),
        }
    }

    #[tokio::test(start_paused = true)]
    async fn client_requests_no_version() {
        let (client, mut connection) = connect().await;
        let (_client, response) = handshake(client, &mut connection, b"HELLO VERSION\n").await;

        assert!(matches!(
            connection.state,
            PendingConnectionState::Handshaked {
                version: SamVersion::V32,
                ..
            }
        ));
        assert_eq!(response, "HELLO REPLY RESULT=OK VERSION=3.2\n");
    }

    #[tokio::test(start_paused = true)]
    async fn client_requests_max_version() {
        let (client, mut connection) = connect().await;
        let (_client, response) =
            handshake(client, &mut connection, b"HELLO VERSION MAX=3.1\n").await;

        assert!(matches!(
            connection.state,
            PendingConnectionState::Handshaked {
                version: SamVersion::V31,
                ..
            }
        ));
        assert_eq!(response, "HELLO REPLY RESULT=OK VERSION=3.1\n");
    }

    #[tokio::test(start_paused = true)]
    async fn client_requests_min_version() {
        let (client, mut connection) = connect().await;
        let (_client, response) =
            handshake(client, &mut connection, b"HELLO VERSION MIN=3.1\n").await;

        assert!(matches!(
            connection.state,
            PendingConnectionState::Handshaked {
                version: SamVersion::V32,
                ..
            }
        ));
        assert_eq!(response, "HELLO REPLY RESULT=OK VERSION=3.2\n");
    }

    #[tokio::test(start_paused = true)]
    async fn session_create() {
        let (client, mut connection) = connect().await;
        let (mut client, response) =
            handshake(client, &mut connection, b"HELLO VERSION\n").await;

        assert!(matches!(
            connection.state,
            PendingConnectionState::Handshaked {
                version: SamVersion::V32,
                ..
            }
        ));
        assert_eq!(response, "HELLO REPLY RESULT=OK VERSION=3.2\n");

        client
            .write_all(b"SESSION CREATE STYLE=STREAM ID=test DESTINATION=TRANSIENT\n")
            .await
            .unwrap();

        match tokio::time::timeout(Duration::from_secs(5), connection).await.unwrap() {
            Ok(ConnectionKind::Session {
                session_id,
                version: SamVersion::V32,
                session_kind: SessionKind::Stream,
                ..
            }) => {
                assert_eq!(&*session_id, "test");
            }
            Ok(kind) => panic!("invalid connection kind: {kind:?}"),
            Err(error) => panic!("failed to create session: {error:?}"),
        }
    }

    #[tokio::test]
    async fn send_session_create_before_handshake() {
        let (mut client, connection) = connect().await;

        client
            .write_all(b"SESSION CREATE STYLE=STREAM ID=test DESTINATION=TRANSIENT\n")
            .await
            .unwrap();

        match tokio::time::timeout(Duration::from_secs(5), connection).await.unwrap() {
            Err(Error::InvalidState) => {}
            Ok(kind) => panic!("session succeeded: {kind:?}"),
            Err(error) => panic!("invalid error: {error:?}"),
        }
    }
}