use std::future::Future;
use std::io;
use std::net::SocketAddr;
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use crate::runtime::{Context, ProcessHandle, Runtime};
impl Runtime {
    /// Binds a TCP listener on `addr` and spawns an acceptor process that hands
    /// every accepted connection to `handler` in a newly spawned process.
    ///
    /// `handler` is cloned once per accepted connection and called with the new
    /// process's [`Context`] and the accepted [`TcpStream`]; the future it
    /// returns is what that connection's process runs.
    ///
    /// Returns the address the listener actually bound (useful when `addr`
    /// requests port `0`) together with the acceptor's [`ProcessHandle`].
    /// Killing the acceptor and waiting for it to finish stops the listener, so
    /// later connection attempts to the returned address fail.
    ///
    /// Accept errors that only concern a single peer are skipped, so one
    /// misbehaving client cannot take the whole listener down. These are a peer
    /// aborting, resetting or refusing before the connection was accepted, or
    /// an interrupted call. Any other accept error ends the acceptor process.
    ///
    /// # Errors
    ///
    /// Returns an error if binding `addr` fails or the bound local address
    /// cannot be queried.
    pub async fn listen<F, Fut>(
        &self,
        addr: impl ToSocketAddrs,
        handler: F,
    ) -> io::Result<(SocketAddr, ProcessHandle)>
    where
        F: Fn(Context, TcpStream) -> Fut + Clone + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let listener = TcpListener::bind(addr).await?;
        let local = listener.local_addr()?;
        let rt = self.clone();
        let acceptor = self.spawn(move |_ctx| async move {
            loop {
                match listener.accept().await {
                    Ok((stream, _peer)) => {
                        let handler = handler.clone();
                        rt.spawn(move |ctx| handler(ctx, stream));
                    }
                    // Only this one pending connection is affected; keep
                    // accepting everyone else instead of killing the listener.
                    Err(err) if is_per_connection_error(&err) => {}
                    // Anything else is treated as fatal for the acceptor, as before.
                    // NOTE(review): this includes resource exhaustion (e.g. running
                    // out of file descriptors); consider retrying with a backoff if
                    // the listener must survive that.
                    Err(_) => break,
                }
            }
        });
        Ok((local, acceptor))
    }

    /// Opens a TCP connection to `addr` and returns the stream to the caller.
    ///
    /// No process is spawned for the connection; this simply delegates to
    /// [`TcpStream::connect`].
    ///
    /// # Errors
    ///
    /// Returns any error produced by [`TcpStream::connect`], for example when
    /// nothing is listening on `addr`.
    pub async fn connect(&self, addr: impl ToSocketAddrs) -> io::Result<TcpStream> {
        TcpStream::connect(addr).await
    }
}

/// Returns `true` for `accept` errors that concern only a single incoming
/// connection (or an interrupted call), i.e. errors after which the listener
/// itself is still usable and accepting should simply continue.
fn is_per_connection_error(err: &io::Error) -> bool {
    matches!(
        err.kind(),
        io::ErrorKind::ConnectionAborted
            | io::ErrorKind::ConnectionReset
            | io::ErrorKind::ConnectionRefused
            | io::ErrorKind::Interrupted
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// A single client round-trips five bytes through a per-connection echo handler.
    #[tokio::test]
    async fn echoes_over_a_process_per_connection() {
        let runtime = Runtime::new();
        let (server_addr, _acceptor) = runtime
            .listen("127.0.0.1:0", |_ctx, mut conn| async move {
                let mut incoming = [0u8; 5];
                conn.read_exact(&mut incoming).await.unwrap();
                conn.write_all(&incoming).await.unwrap();
            })
            .await
            .unwrap();

        let mut client = runtime.connect(server_addr).await.unwrap();
        client.write_all(b"hello").await.unwrap();

        let mut echoed = [0u8; 5];
        client.read_exact(&mut echoed).await.unwrap();
        assert_eq!(&echoed, b"hello");
    }

    /// Three clients plus the acceptor itself should show up as four live
    /// processes while each connection process waits on its mailbox.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn every_connection_is_its_own_live_process() {
        let runtime = Runtime::new();
        let (server_addr, acceptor) = runtime
            .listen("127.0.0.1:0", |mut ctx, _stream| async move {
                // Park the connection process until a message arrives.
                let _ = ctx.recv().await;
            })
            .await
            .unwrap();

        let mut clients = Vec::new();
        for _ in 0..3 {
            clients.push(runtime.connect(server_addr).await.unwrap());
        }

        // Accepting and spawning happen concurrently with this test body, so
        // allow a bounded number of scheduler turns for them to catch up.
        let expected = 4;
        let mut attempts = 0;
        while runtime.process_count() != expected && attempts < 1000 {
            tokio::task::yield_now().await;
            attempts += 1;
        }
        assert_eq!(runtime.process_count(), expected);

        drop(clients);
        acceptor.kill();
    }

    /// After the acceptor is killed and has finished, connecting to its
    /// address fails.
    #[tokio::test]
    async fn killing_the_acceptor_closes_the_port() {
        let runtime = Runtime::new();
        let (server_addr, acceptor) = runtime
            .listen("127.0.0.1:0", |_ctx, _stream| async {})
            .await
            .unwrap();

        acceptor.kill();
        acceptor.join().await;

        let reconnect = runtime.connect(server_addr).await;
        assert!(reconnect.is_err());
    }
}