use std::{
future::Future,
hash::RandomState,
io,
net::SocketAddr,
pin::pin,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::Duration,
};
use http::{Request, Response, StatusCode, header};
use http_body_util::{BodyExt, Full};
use hyper::body::{Bytes, Incoming};
use papaya::{HashMapRef, LocalGuard};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream,
sync::{mpsc, oneshot},
time::{sleep, timeout},
};
use uuid::Uuid;
/// How long a session may go without a successful upstream read before its
/// task gives up and the session is unregistered.
const CONNECTION_TIMEOUT: Duration = Duration::new(30, 0);
/// How long each request waits for upstream to produce data before it is
/// answered with an empty body.
const READ_TIMEOUT: Duration = Duration::from_millis(50);
/// Opens the upstream byte stream that backs a session.
///
/// Abstracted behind a trait so that something other than a real TCP socket
/// (for example an in-memory stream) can be plugged in.
pub trait UpstreamConnector
where
    Self: Clone + Send + Sync + 'static,
{
    /// Bidirectional stream returned by [`UpstreamConnector::connect`].
    type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;

    /// Connects to `addr`. The future must be `Send` because sessions are
    /// connected from inside a spawned tokio task.
    fn connect(&self, addr: SocketAddr) -> impl Future<Output = io::Result<Self::Stream>> + Send;
}
/// [`UpstreamConnector`] that opens plain tokio TCP connections.
#[derive(Clone, Default)]
pub struct TcpConnector;

impl UpstreamConnector for TcpConnector {
    type Stream = TcpStream;

    /// Dials `target` over TCP, propagating any connection error unchanged.
    async fn connect(&self, target: SocketAddr) -> io::Result<Self::Stream> {
        let stream = TcpStream::connect(target).await?;
        Ok(stream)
    }
}
/// Registry of live sessions keyed by session UUID, plus settings and
/// counters shared by all of them.
pub struct Sessions<C = TcpConnector>
where
    C: UpstreamConnector,
{
    /// Command sender of each registered session's task.
    sessions: papaya::HashMap<Uuid, mpsc::Sender<SessionCommand>>,
    /// Upstream address and session-header name.
    configuration: Configuration,
    /// Opens the upstream stream for each new session.
    connector: C,
    /// Count of sessions that received their first upstream data since the
    /// counter was last taken.
    successful_transfers: AtomicU64,
}
impl<C: UpstreamConnector> std::fmt::Debug for Sessions<C> {
    /// Shows the session map and configuration. The connector (which need not
    /// implement `Debug`) and the counter are left out, as marked by the
    /// trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Sessions");
        out.field("sessions", &self.sessions);
        out.field("configuration", &self.configuration);
        out.finish_non_exhaustive()
    }
}
/// Settings shared by every session of a [`Sessions`] registry.
#[derive(Debug)]
pub struct Configuration {
    /// Address every new session connects to.
    pub upstream: SocketAddr,
    /// Name of the HTTP request header that carries the session UUID.
    pub session_header_key: String,
}
impl Sessions<TcpConnector> {
pub fn new(upstream: SocketAddr, session_header_key: String) -> Arc<Self> {
Self::with_connector(upstream, session_header_key, TcpConnector)
}
}
impl<C: UpstreamConnector> Sessions<C> {
pub fn with_connector(
upstream: SocketAddr,
session_header_key: String,
connector: C,
) -> Arc<Self> {
let sessions = Sessions {
configuration: Configuration {
upstream,
session_header_key,
},
sessions: Default::default(),
connector,
successful_transfers: AtomicU64::new(0),
};
Arc::new(sessions)
}
pub async fn handle_request(
self: Arc<Self>,
request: Request<Incoming>,
) -> Response<Full<Bytes>> {
let Some(session_id) = request
.headers()
.get(&self.configuration.session_header_key)
.and_then(|value| Uuid::try_parse_ascii(value.as_ref()).ok())
else {
return Self::handle_session_error();
};
let Ok(body) = request.collect().await.map(|b| b.to_bytes()) else {
return Self::handle_session_error();
};
self.handle_request_inner(session_id, body).await
}
async fn handle_request_inner(
self: Arc<Self>,
session: Uuid,
data: Bytes,
) -> Response<Full<Bytes>> {
let cmd_tx = {
let map = self.sessions.pin();
match map.get(&session) {
Some(tx) => tx.clone(),
None => self.clone().handle_new_session(session, map),
}
};
return self
.clone()
.handle_existing_session_request(&cmd_tx, data)
.await;
}
async fn handle_existing_session_request(
self: Arc<Self>,
cmd_tx: &mpsc::Sender<SessionCommand>,
data: Bytes,
) -> Response<Full<Bytes>> {
let Ok(body) = SessionCommand::send(data, cmd_tx).await else {
log::error!("Failed to send command to session");
return Self::handle_session_error();
};
Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/octet-stream")
.body(Full::new(body))
.unwrap()
}
fn handle_new_session(
self: Arc<Self>,
new_session: Uuid,
session_map: HashMapRef<
'_,
Uuid,
mpsc::Sender<SessionCommand>,
RandomState,
LocalGuard<'_>,
>,
) -> mpsc::Sender<SessionCommand> {
let sessions = self.clone();
let session_id = new_session;
let (cmd_tx, cmd_rx) = mpsc::channel(1);
session_map.insert(new_session, cmd_tx.clone());
tokio::spawn(async move {
let Ok(mut session) = Session::connect(cmd_rx, session_id, sessions).await else {
return;
};
session.run().await;
});
cmd_tx
}
fn handle_session_error() -> Response<Full<Bytes>> {
Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::new()))
.unwrap()
}
pub fn take_successful_transfers(&self) -> u64 {
self.successful_transfers.swap(0, Ordering::Relaxed)
}
pub fn remove_session(self: Arc<Self>, session: &Uuid) {
log::debug!("Removing session {}", session);
let _ = self.sessions.pin().remove(session);
}
}
/// State owned by the spawned task that serves one session.
struct Session<C>
where
    C: UpstreamConnector,
{
    /// Upstream stream this session relays to.
    connection: C::Stream,
    /// Commands from HTTP handlers, served one at a time.
    cmd_rx: mpsc::Receiver<SessionCommand>,
    /// Key under which this session is registered.
    session_id: Uuid,
    /// Registry to unregister from and to report transfers to.
    sessions: Arc<Sessions<C>>,
    /// Whether this session has already been counted in `successful_transfers`.
    counted_transfer: bool,
}
impl<C: UpstreamConnector> Session<C> {
    /// Opens the upstream connection for a freshly registered session.
    ///
    /// # Errors
    /// Returns the connector's error after logging it and unregistering
    /// `session_id`, so a later request for that id starts a new session.
    pub async fn connect(
        cmd_rx: mpsc::Receiver<SessionCommand>,
        session_id: Uuid,
        sessions: Arc<Sessions<C>>,
    ) -> io::Result<Self> {
        let connection = match sessions
            .connector
            .connect(sessions.configuration.upstream)
            .await
        {
            Ok(conn) => conn,
            Err(err) => {
                log::error!("Failed to connect to upstream server: {}", err);
                sessions.remove_session(&session_id);
                return Err(err);
            }
        };
        Ok(Self {
            connection,
            session_id,
            cmd_rx,
            sessions,
            counted_transfer: false,
        })
    }

    /// Serves commands until the session ends.
    ///
    /// For each command, its payload is written upstream. Then one read is
    /// attempted, bounded by [`READ_TIMEOUT`]. The bytes read, or an empty
    /// body on timeout, are sent back to the requester.
    ///
    /// The session ends in any of these cases:
    /// - every command sender is gone;
    /// - no upstream data has arrived for [`CONNECTION_TIMEOUT`];
    /// - writing upstream fails;
    /// - reading upstream fails;
    /// - upstream closes the connection (EOF).
    ///
    /// In the last three cases the command in flight is dropped unanswered,
    /// so its requester sees an error rather than a misleading empty reply.
    pub async fn run(&mut self) {
        let Self {
            connection,
            cmd_rx,
            sessions,
            session_id,
            counted_transfer,
        } = self;
        // Idle deadline, pushed back every time upstream delivers data.
        let mut deadline = pin!(sleep(CONNECTION_TIMEOUT));
        let mut read_buffer = vec![0u8; 1024 * 64];
        loop {
            tokio::select! {
                maybe_cmd = cmd_rx.recv() => {
                    let Some(mut cmd) = maybe_cmd else {
                        return;
                    };
                    if let Some(tx_bytes) = cmd.take_payload() {
                        log::debug!("Received {} bytes for session {}", tx_bytes.len(), session_id);
                        if let Err(err) = connection.write_all(&tx_bytes).await {
                            log::error!("Failed to send data to upstream: {err}");
                            // The payload was not delivered. End the session so the
                            // requester gets an error instead of a 200 that hides the loss.
                            return;
                        }
                    }
                    let response_bytes = match timeout(READ_TIMEOUT, connection.read(&mut read_buffer)).await {
                        Ok(Ok(0)) => {
                            // `read_buffer` is non-empty, so 0 bytes means upstream hit EOF.
                            // Continuing would answer every poll with an empty body and keep
                            // resetting the deadline, leaving a session that never expires.
                            log::debug!("Upstream closed connection for session {}", session_id);
                            return;
                        }
                        Ok(Ok(bytes_read)) => {
                            deadline.set(sleep(CONNECTION_TIMEOUT));
                            if !*counted_transfer {
                                *counted_transfer = true;
                                sessions.successful_transfers.fetch_add(1, Ordering::Relaxed);
                            }
                            Bytes::copy_from_slice(&read_buffer[..bytes_read])
                        }
                        Ok(Err(connection_error)) => {
                            log::error!("Failed to receive data from upstream {connection_error}");
                            return;
                        }
                        // Nothing arrived in time: reply with an empty body.
                        Err(_timeout) => Bytes::new(),
                    };
                    cmd.respond_with(response_bytes);
                }
                _ = deadline.as_mut() => {
                    return;
                }
            }
        }
    }
}
impl<C: UpstreamConnector> Drop for Session<C> {
fn drop(&mut self) {
self.sessions.clone().remove_session(&self.session_id);
}
}
/// One request from an HTTP handler to a session task.
#[derive(Debug)]
struct SessionCommand {
    /// Bytes to write upstream; left `None` once the session takes them.
    tx_payload: Option<Bytes>,
    /// Channel on which the session returns the bytes it read back.
    return_tx: oneshot::Sender<Bytes>,
}
impl SessionCommand {
    /// Hands `payload` to the session behind `cmd_tx` and waits for its reply.
    ///
    /// # Errors
    /// Fails in either of two cases:
    /// - the session's receiver is gone, so the command cannot be queued;
    /// - the command is dropped without a response.
    async fn send(payload: Bytes, cmd_tx: &mpsc::Sender<SessionCommand>) -> anyhow::Result<Bytes> {
        let (command, reply) = Self::new(payload);
        cmd_tx.send(command).await?;
        Ok(reply.await?)
    }

    /// Builds a command carrying `tx_payload`, together with the receiver on
    /// which its response will arrive.
    fn new(tx_payload: Bytes) -> (Self, oneshot::Receiver<Bytes>) {
        let (return_tx, return_rx) = oneshot::channel();
        let command = Self {
            tx_payload: Some(tx_payload),
            return_tx,
        };
        (command, return_rx)
    }

    /// Takes the payload out of the command, leaving `None` behind.
    fn take_payload(&mut self) -> Option<Bytes> {
        std::mem::take(&mut self.tx_payload)
    }

    /// Answers the command with `received_bytes`. If the requester has
    /// already gone away, the bytes are simply discarded.
    fn respond_with(self, received_bytes: Bytes) {
        let _unanswered = self.return_tx.send(received_bytes);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::DuplexStream;

    /// Connector that hands out pre-created in-memory duplex streams instead of
    /// dialing the network. `connect` fails with `ConnectionRefused` once the
    /// supply is exhausted.
    #[derive(Clone)]
    struct MockConnector {
        streams: Arc<tokio::sync::Mutex<Vec<DuplexStream>>>,
    }

    impl MockConnector {
        fn new(streams: Vec<DuplexStream>) -> Self {
            Self {
                streams: Arc::new(tokio::sync::Mutex::new(streams)),
            }
        }
    }

    impl UpstreamConnector for MockConnector {
        type Stream = DuplexStream;
        // Streams are popped from the end of the vector, i.e. handed out last-first.
        async fn connect(&self, _addr: SocketAddr) -> io::Result<DuplexStream> {
            self.streams.lock().await.pop().ok_or_else(|| {
                io::Error::new(io::ErrorKind::ConnectionRefused, "no streams available")
            })
        }
    }

    /// Placeholder upstream address; `MockConnector` ignores it.
    fn dummy_addr() -> SocketAddr {
        "127.0.0.1:1234".parse().unwrap()
    }

    // Every test starts with tokio's clock paused, so READ_TIMEOUT and
    // CONNECTION_TIMEOUT elapse in virtual time rather than real time.
    #[tokio::test(start_paused = true)]
    async fn session_removed_after_connection_timeout() {
        // The remote end is kept alive (bound, not discarded) but never replies.
        let (upstream, _upstream_remote) = tokio::io::duplex(8192);
        let connector = MockConnector::new(vec![upstream]);
        let sessions = Sessions::with_connector(dummy_addr(), "X-Session".to_string(), connector);
        let session_id = Uuid::new_v4();
        let response = sessions
            .clone()
            .handle_request_inner(session_id, Bytes::from("hello"))
            .await;
        assert_eq!(response.status(), StatusCode::OK);
        assert!(
            sessions.sessions.pin().get(&session_id).is_some(),
            "Session should exist after first request"
        );
        // Jump past the idle deadline, then yield briefly so the session task
        // can notice it, return, and unregister itself when dropped.
        tokio::time::advance(CONNECTION_TIMEOUT + Duration::from_secs(1)).await;
        tokio::time::sleep(Duration::from_millis(1)).await;
        assert!(
            sessions.sessions.pin().get(&session_id).is_none(),
            "Session should be removed after connection timeout"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn read_timeout_returns_empty_body() {
        // Silent upstream: the session's read can only end by READ_TIMEOUT.
        let (upstream, _upstream_remote) = tokio::io::duplex(8192);
        let connector = MockConnector::new(vec![upstream]);
        let sessions = Sessions::with_connector(dummy_addr(), "X-Session".to_string(), connector);
        let session_id = Uuid::new_v4();
        let response = sessions
            .clone()
            .handle_request_inner(session_id, Bytes::from("ping"))
            .await;
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert!(
            body.is_empty(),
            "Body should be empty when upstream does not respond within read timeout"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn successful_transfer_counter_incremented() {
        let (upstream, mut upstream_remote) = tokio::io::duplex(8192);
        let connector = MockConnector::new(vec![upstream]);
        let sessions = Sessions::with_connector(dummy_addr(), "X-Session".to_string(), connector);
        let session_id = Uuid::new_v4();
        assert_eq!(sessions.take_successful_transfers(), 0);
        // Fake upstream: answer the first write with "response" and every later
        // one with "response2", until the session side closes or errors.
        tokio::spawn(async move {
            let mut buf = [0u8; 64];
            let _ = upstream_remote.read(&mut buf).await;
            upstream_remote.write_all(b"response").await.unwrap();
            loop {
                match upstream_remote.read(&mut buf).await {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {
                        upstream_remote.write_all(b"response2").await.unwrap();
                    }
                }
            }
        });
        let response = sessions
            .clone()
            .handle_request_inner(session_id, Bytes::from("hello"))
            .await;
        assert_eq!(response.status(), StatusCode::OK);
        // The first non-empty read counts this session once.
        assert_eq!(sessions.take_successful_transfers(), 1);
        let response = sessions
            .clone()
            .handle_request_inner(session_id, Bytes::from("hello again"))
            .await;
        assert_eq!(response.status(), StatusCode::OK);
        // Further data on the same session is not counted again, and the
        // previous take reset the counter.
        assert_eq!(sessions.take_successful_transfers(), 0);
    }
}