use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tracing::Instrument as _;
use crate::conf::DataStore;
use crate::io::reactor::{ConnRole, TcpTransport};
use crate::net::client::{client_loop, ClientHandler};
use crate::net::conn::Conn;
use crate::net::dispatcher::Dispatcher;
use crate::net::listener::{bind_dual_stack, BindOptions};
use crate::net::NetError;
/// A TCP front end that accepts client connections and serves each one on its
/// own spawned task, handing every client handler a clone of a shared
/// [`Dispatcher`].
pub struct Proxy {
    /// Listening socket created by `bind_dual_stack` in [`Proxy::bind`].
    listener: TcpListener,
    /// Shared dispatcher; each client task receives its own `Arc` clone.
    dispatcher: Arc<dyn Dispatcher>,
    /// Data-store flavour passed to every [`ClientHandler`]
    /// (`DataStore::Redis` unless overridden with `with_data_store`).
    data_store: DataStore,
    /// Capacity of each client's response channel; `with_response_capacity`
    /// keeps it at least 1.
    response_capacity: usize,
}
impl Proxy {
pub fn bind<A: Into<SocketAddr>>(
addr: A,
dispatcher: Arc<dyn Dispatcher>,
) -> Result<Self, NetError> {
let listener = bind_dual_stack(addr.into(), BindOptions::default())?;
Ok(Self {
listener,
dispatcher,
data_store: DataStore::Redis,
response_capacity: 64,
})
}
#[must_use]
pub fn with_data_store(mut self, ds: DataStore) -> Self {
self.data_store = ds;
self
}
#[must_use]
pub fn with_response_capacity(mut self, n: usize) -> Self {
self.response_capacity = n.max(1);
self
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.listener.local_addr()
}
pub fn listener(&self) -> &TcpListener {
&self.listener
}
#[tracing::instrument(
name = "proxy.run",
skip_all,
fields(
local = self.listener.local_addr().map_or_else(|_| String::from("?"), |a| a.to_string()),
),
)]
pub async fn run(
self,
cancel: std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>,
) -> Result<(), NetError> {
let mut cancel = cancel;
let mut clients: Vec<JoinHandle<Result<(), NetError>>> = Vec::new();
loop {
let accept = self.listener.accept();
tokio::select! {
() = &mut cancel => break,
res = accept => {
let (sock, peer) = res?;
let _ = sock.set_nodelay(true);
let role = ConnRole::Client;
let transport = Box::new(TcpTransport::new(sock, role));
let conn = Conn::new(transport, role);
let dispatcher = Arc::clone(&self.dispatcher);
let cap = self.response_capacity;
let ds = self.data_store;
tracing::debug!(?peer, "proxy accepted client");
let accept_span = tracing::info_span!(
"client.accept",
peer = %peer,
);
let handle = tokio::spawn(
async move {
let (tx, rx) = mpsc::channel(cap);
let handler = ClientHandler::new(dispatcher, tx, ds);
client_loop(conn, handler, rx).await
}
.instrument(accept_span),
);
clients.push(handle);
}
}
clients.retain(|h| !h.is_finished());
}
for h in clients {
let _ = tokio::time::timeout(std::time::Duration::from_millis(250), h).await;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Binds a proxy on an ephemeral loopback port with the no-op dispatcher.
    fn loopback_proxy() -> Proxy {
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        Proxy::bind(addr, Arc::new(crate::net::NoopDispatcher)).unwrap()
    }

    #[tokio::test]
    async fn bind_returns_local_addr() {
        let local = loopback_proxy().local_addr().unwrap();
        assert!(local.ip().is_loopback());
    }

    #[tokio::test]
    async fn run_exits_on_cancel() {
        // A cancellation future that is ready on its first poll.
        let already_cancelled = Box::pin(std::future::ready(()));
        loopback_proxy().run(already_cancelled).await.unwrap();
    }
}