use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::RwLock;
use tokio::io::{self, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::task::JoinSet;
use tracing::{debug, error, info};
/// Builder that collects the configuration of an [`NlbFrontend`].
///
/// `Default` yields the same empty builder as [`NlbFrontend::builder`]:
/// no listen address and no backends.
#[derive(Debug, Clone, Default)]
pub struct NlbFrontendBuilder {
    /// Address the frontend will bind to; `None` until `listen_addr` is called.
    listen_addr: Option<SocketAddr>,
    /// Backend addresses, in the order they were added.
    backends: Vec<SocketAddr>,
}
impl NlbFrontendBuilder {
    /// Sets the address the frontend will listen on, replacing any earlier value.
    pub fn listen_addr(self, addr: SocketAddr) -> Self {
        Self {
            listen_addr: Some(addr),
            ..self
        }
    }

    /// Appends a single backend address.
    pub fn backend(mut self, addr: SocketAddr) -> Self {
        self.backends.push(addr);
        self
    }

    /// Appends every address yielded by `addrs`, keeping iteration order.
    pub fn backends(self, addrs: impl IntoIterator<Item = SocketAddr>) -> Self {
        let Self {
            listen_addr,
            mut backends,
        } = self;
        backends.extend(addrs);
        Self {
            listen_addr,
            backends,
        }
    }

    /// Consumes the builder and produces the configured [`NlbFrontend`].
    ///
    /// # Panics
    ///
    /// Panics if no listen address was set, or if no backend was added.
    /// The listen address is checked first.
    pub fn build(self) -> NlbFrontend {
        let Self {
            listen_addr,
            backends,
        } = self;
        let listen_addr = listen_addr.expect("NlbFrontend requires a listen address");
        assert!(
            !backends.is_empty(),
            "NlbFrontend requires at least one backend"
        );
        NlbFrontend {
            listen_addr,
            backends,
        }
    }
}
/// A configured, not-yet-running TCP load-balancer frontend.
///
/// Created through `NlbFrontend::builder()` and started with `run`.
#[derive(Debug, Clone)]
pub struct NlbFrontend {
    /// Address the listener will bind to when the frontend is run.
    listen_addr: SocketAddr,
    /// Initial set of backend addresses connections may be routed to.
    backends: Vec<SocketAddr>,
}
impl NlbFrontend {
pub fn builder() -> NlbFrontendBuilder {
NlbFrontendBuilder {
listen_addr: None,
backends: vec![],
}
}
pub async fn run(self) -> Result<RunningNlbFrontend, io::Error> {
let listener = TcpListener::bind(self.listen_addr).await?;
let listen_addr = listener.local_addr()?;
info!("NLB frontend listening on {}", listen_addr);
let shared = Arc::new(NlbShared {
backends: RwLock::new(self.backends),
connection_tasks: tokio::sync::Mutex::new(JoinSet::new()),
});
let shared_for_handle = shared.clone();
let handle = tokio::task::spawn(async move {
loop {
match listener.accept().await {
Ok((stream, peer_addr)) => {
let connection_shared = shared.clone();
shared.connection_tasks.lock().await.spawn(async move {
let result =
Self::handle_connection(stream, peer_addr, &connection_shared)
.await;
if let Err(e) = result {
debug!("NLB connection from {} ended: {}", peer_addr, e);
}
});
}
Err(e) => {
error!("NLB accept error: {}", e);
break;
}
}
}
});
Ok(RunningNlbFrontend {
listen_addr,
handle,
shared: shared_for_handle,
})
}
async fn handle_connection(
driver_stream: TcpStream,
driver_addr: SocketAddr,
shared: &NlbShared,
) -> Result<(), io::Error> {
let backend_addr = {
let backends = shared
.backends
.read()
.expect("NlbShared backends RwLock poisoned");
if backends.is_empty() {
return Err(io::Error::new(
io::ErrorKind::NotConnected,
"NLB has no backends configured",
));
}
let idx = rand::random_range(0..backends.len());
backends[idx]
};
debug!(
"NLB routing connection from {} to backend {}",
driver_addr, backend_addr
);
let (mut driver_read, mut driver_write) = driver_stream.into_split();
let backend_tcp = TcpStream::connect(backend_addr).await?;
let (mut backend_read, mut backend_write) = backend_tcp.into_split();
let d2b = async {
let r = io::copy(&mut driver_read, &mut backend_write).await;
let _ = backend_write.shutdown().await;
r
};
let b2d = async {
let r = io::copy(&mut backend_read, &mut driver_write).await;
let _ = driver_write.shutdown().await;
r
};
match tokio::try_join!(d2b, b2d) {
Ok((to_backend, to_driver)) => {
debug!(
"NLB connection {} -> {} finished: {} bytes to backend, {} bytes to driver",
driver_addr, backend_addr, to_backend, to_driver
);
}
Err(e) => {
debug!(
"NLB connection {} -> {} error: {}",
driver_addr, backend_addr, e
);
}
}
Ok(())
}
}
/// State shared between the accept loop, the per-connection tasks and the
/// `RunningNlbFrontend` handle.
struct NlbShared {
    /// Current backend addresses, behind a `std::sync::RwLock`.
    backends: RwLock<Vec<SocketAddr>>,
    /// Set of spawned per-connection tasks, behind an async (tokio) mutex.
    connection_tasks: tokio::sync::Mutex<JoinSet<()>>,
}
/// Handle to a started NLB frontend, returned by `NlbFrontend::run`.
pub struct RunningNlbFrontend {
    /// Local address the listener is bound to.
    listen_addr: SocketAddr,
    /// Join handle of the task running the accept loop.
    handle: tokio::task::JoinHandle<()>,
    /// State shared with the accept loop and the connection tasks.
    shared: Arc<NlbShared>,
}
impl RunningNlbFrontend {
    /// Returns the local address the frontend is listening on.
    pub fn listen_addr(&self) -> SocketAddr {
        self.listen_addr
    }

    /// Replaces the whole backend list with `backends`.
    ///
    /// Each connection picks its backend when it is handled, so only
    /// connections handled after this call see the new list.
    ///
    /// # Panics
    ///
    /// Panics if the backends lock is poisoned.
    pub fn set_backends(&self, backends: Vec<SocketAddr>) {
        info!(
            "NLB on {} updating backends to: {:?}",
            self.listen_addr, backends
        );
        *self
            .shared
            .backends
            .write()
            .expect("NlbShared backends RwLock poisoned") = backends;
    }

    /// Stops the frontend: aborts the accept loop, waits for it to end, then
    /// shuts down the set of connection tasks.
    pub async fn finish(self) {
        let Self {
            listen_addr,
            handle,
            shared,
        } = self;

        handle.abort();
        // The join result (normally a cancellation error after `abort`) is
        // deliberately ignored.
        let _ = handle.await;

        {
            let mut tasks = shared.connection_tasks.lock().await;
            tasks.shutdown().await;
        }
        info!("NLB frontend on {} has shut down", listen_addr);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Spawns a server on `listener` that answers every chunk it reads from a
    /// client with `tag` followed by that same chunk.
    fn spawn_echo(listener: TcpListener, tag: u8) {
        tokio::spawn(async move {
            // Stop serving as soon as accepting fails.
            while let Ok((mut stream, _)) = listener.accept().await {
                tokio::spawn(async move {
                    let mut chunk = [0u8; 64];
                    loop {
                        // EOF and read errors both end this connection.
                        let len = match stream.read(&mut chunk).await {
                            Ok(len) if len > 0 => len,
                            _ => break,
                        };
                        let mut reply = Vec::with_capacity(len + 1);
                        reply.push(tag);
                        reply.extend_from_slice(&chunk[..len]);
                        if stream.write_all(&reply).await.is_err() {
                            break;
                        }
                    }
                });
            }
        });
    }

    /// Opens 20 connections through the frontend and checks that each reply
    /// is a tagged echo and that both backends were reached at least once.
    #[tokio::test]
    async fn test_nlb_random_balancing() {
        let first_backend = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let second_backend = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let first_addr = first_backend.local_addr().unwrap();
        let second_addr = second_backend.local_addr().unwrap();
        spawn_echo(first_backend, b'1');
        spawn_echo(second_backend, b'2');

        let running = NlbFrontend::builder()
            .listen_addr("127.0.0.1:0".parse().unwrap())
            .backend(first_addr)
            .backend(second_addr)
            .build()
            .run()
            .await
            .unwrap();
        let nlb_addr = running.listen_addr();

        let mut saw_backend = [false; 2];
        for _ in 0..20 {
            let mut conn = TcpStream::connect(nlb_addr).await.unwrap();
            conn.write_all(b"hi").await.unwrap();

            // One tag byte followed by the two echoed payload bytes.
            let mut reply = [0u8; 3];
            conn.read_exact(&mut reply).await.unwrap();
            assert_eq!(&reply[1..], b"hi");

            let slot = match reply[0] {
                b'1' => 0,
                b'2' => 1,
                other => panic!("Unexpected tag byte: {}", other),
            };
            saw_backend[slot] = true;
        }

        assert!(
            saw_backend[0] && saw_backend[1],
            "Expected both backends to be hit, but saw_backend = {:?}",
            saw_backend
        );
        running.finish().await;
    }
}