use std::io::{BufReader, BufWriter};
use std::net::{SocketAddr, TcpListener, TcpStream};
use comp_cat_rs::effect::io::Io;
use crate::codec;
use crate::error::Error;
use crate::protocol::{Envelope, RequestId};
use crate::serve::Serve;
/// Socket address that [`serve`] binds its listener to.
///
/// A thin newtype over [`SocketAddr`]: constructing one is free and it is `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct ListenAddr(SocketAddr);

impl ListenAddr {
    /// Wraps `addr` as a listening address.
    #[must_use]
    pub fn new(addr: SocketAddr) -> Self {
        ListenAddr(addr)
    }

    /// Returns the wrapped socket address.
    #[must_use]
    pub fn addr(self) -> SocketAddr {
        let ListenAddr(inner) = self;
        inner
    }
}
#[must_use]
pub fn serve<S: Serve>(addr: ListenAddr, service: S) -> Io<Error, core::convert::Infallible> {
Io::suspend(move || {
let listener = TcpListener::bind(addr.addr())?;
listener
.incoming()
.try_for_each(|stream_result| {
let stream = stream_result?;
let svc = service.clone();
std::thread::Builder::new()
.spawn(move || {
let _: Result<(), Error> = handle_connection(stream, &svc);
})
.map(|_handle| ())
})
.map_err(Error::from)
.and(Err(Error::ConnectionClosed))
})
}
/// Serves requests on one accepted connection until the session ends.
///
/// The socket is duplicated with [`TcpStream::try_clone`] so it can be wrapped in a
/// buffered reader and a buffered writer at the same time. Requests are then processed
/// strictly one after another by `process_one_request`.
///
/// # Errors
///
/// `Error::ConnectionClosed` marks a normal end of session and is mapped to `Ok(())`.
/// Any other error, whether from cloning the stream or from processing a request, ends
/// the session and is returned.
fn handle_connection<S: Serve>(stream: TcpStream, service: &S) -> Result<(), Error> {
    let mut reader = BufReader::new(stream.try_clone()?);
    let mut writer = BufWriter::new(stream);
    // Infinite stream of request outcomes; `try_for_each` stops at the first `Err`.
    std::iter::repeat_with(|| process_one_request(&mut reader, &mut writer, service))
        .try_for_each(core::convert::identity)
        .or_else(|e| match e {
            Error::ConnectionClosed => Ok(()),
            other => Err(other),
        })
}
/// Reads one envelope from `reader`, answers it, and writes the reply to `writer`.
///
/// Only `Envelope::Request` is valid from a client. Its payload goes through
/// `deserialize_and_handle`, and the resulting envelope is encoded and then explicitly
/// flushed. Without the flush, a buffered reply could sit in `writer` while the caller
/// blocks reading the next request, and flush-on-drop would swallow any write error.
///
/// # Errors
///
/// - any error from `codec::decode` or `codec::encode`
///   (NOTE(review): end of stream presumably surfaces as `Error::ConnectionClosed`,
///   which the caller treats as a normal close — confirm in `codec`);
/// - `Error::Server` when the client sends a `Response` or `Error` envelope;
/// - an I/O error from flushing `writer`.
fn process_one_request<S: Serve>(
    reader: &mut impl std::io::Read,
    writer: &mut impl std::io::Write,
    service: &S,
) -> Result<(), Error> {
    let envelope: Envelope = codec::decode(reader)?;
    match envelope {
        Envelope::Request { id, payload } => {
            let response_envelope = deserialize_and_handle(id, &payload, service)?;
            // Reborrow so `writer` stays usable whether `encode` takes `&mut W` or `W`.
            codec::encode(&mut *writer, &response_envelope)?;
            std::io::Write::flush(writer)?;
            Ok(())
        }
        Envelope::Response { .. } | Envelope::Error { .. } => Err(Error::Server {
            message: "unexpected non-request envelope from client".to_owned(),
        }),
    }
}
fn deserialize_and_handle<S: Serve>(
id: RequestId,
payload: &str,
service: &S,
) -> Result<Envelope, Error> {
serde_json::from_str::<S::Request>(payload)
.map_err(Error::from_deserialize)
.and_then(|request| service.handle(request).run())
.and_then(|response| {
serde_json::to_string(&response)
.map_err(Error::from_serialize)
.map(|resp_payload| Envelope::Response {
id,
payload: resp_payload,
})
})
.or_else(|e| {
Ok(Envelope::Error {
id,
message: e.to_string(),
})
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client;
    use crate::serve::Serve;
    use serde::{Deserialize, Serialize};

    #[derive(Clone)]
    struct EchoService;

    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct EchoRequest {
        message: String,
    }

    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct EchoResponse {
        echo: String,
    }

    impl Serve for EchoService {
        type Request = EchoRequest;
        type Response = EchoResponse;

        fn handle(&self, request: EchoRequest) -> Io<Error, EchoResponse> {
            Io::pure(EchoResponse {
                echo: request.message,
            })
        }
    }

    /// Blocks until something accepts TCP connections on `addr`, polling for up to ~5 s.
    ///
    /// Replaces a fixed sleep, which races the server thread on slow or loaded machines.
    /// The probe connection is dropped immediately without sending anything.
    fn wait_until_listening(addr: SocketAddr) -> Result<(), Error> {
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
        loop {
            match TcpStream::connect(addr) {
                Ok(_probe) => return Ok(()),
                Err(e) if std::time::Instant::now() >= deadline => return Err(Error::from(e)),
                Err(_) => std::thread::sleep(std::time::Duration::from_millis(5)),
            }
        }
    }

    #[test]
    fn end_to_end_echo() -> Result<(), Error> {
        // Reserve an ephemeral port, then release it so `serve` can bind it. Another
        // process could grab the port in between; that race is inherent to an API that
        // takes an address rather than a pre-bound listener.
        let listener = TcpListener::bind("127.0.0.1:0")?;
        let addr = listener.local_addr()?;
        drop(listener);

        let listen_addr = ListenAddr::new(addr);
        let server_addr = client::ServerAddr::new(addr);
        let server_handle = std::thread::spawn(move || {
            let _: Result<core::convert::Infallible, Error> = serve(listen_addr, EchoService).run();
        });
        wait_until_listening(addr)?;

        let request = EchoRequest {
            message: "hello".to_owned(),
        };
        let response: EchoResponse = client::call(server_addr, request).run()?;
        assert_eq!(response.echo, "hello");

        let request2 = EchoRequest {
            message: "world".to_owned(),
        };
        let response2: EchoResponse = client::call(server_addr, request2).run()?;
        assert_eq!(response2.echo, "world");

        // The server loops forever; dropping the handle detaches its thread.
        drop(server_handle);
        Ok(())
    }
}