use {
crate::{Error, body::Body, execute::ExecuteCtx},
futures::future::{self, Ready},
hyper::{
http::{Request, Response},
server::conn::AddrStream,
service::Service,
},
std::{
convert::Infallible,
future::Future,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{self, Poll},
},
tracing::{Level, event},
};
/// Connection-level service handed to the hyper server.
///
/// For every accepted connection it produces a [`RequestService`] that shares
/// this service's execution context.
pub struct ViceroyService {
    /// Execution context shared (via `Arc`) with every per-connection service.
    ctx: Arc<ExecuteCtx>,
}
impl ViceroyService {
    /// Creates a service that will hand `ctx` to every connection it accepts.
    pub fn new(ctx: Arc<ExecuteCtx>) -> Self {
        Self { ctx }
    }

    /// Builds the per-connection [`RequestService`] for the connection `remote`.
    fn make_service(&self, remote: &AddrStream) -> RequestService {
        RequestService::new(Arc::clone(&self.ctx), remote)
    }

    /// Binds an HTTP server to `addr` and drives it until it finishes.
    ///
    /// The address actually bound is logged at `INFO` level before serving.
    ///
    /// # Errors
    ///
    /// Returns a [`hyper::Error`] if `addr` cannot be bound, or if the running
    /// server terminates with an error.
    pub async fn serve(self, addr: SocketAddr) -> Result<(), hyper::Error> {
        // `Server::bind` panics when binding fails; `try_bind` reports the
        // failure through the `Result` this function already returns.
        let server = hyper::Server::try_bind(&addr)?.serve(self);
        event!(Level::INFO, "Listening on http://{}", server.local_addr());
        server.await
    }
}
/// Lets the hyper server ask [`ViceroyService`] for a new [`RequestService`]
/// each time a connection (`AddrStream`) is accepted.
impl<'addr> Service<&'addr AddrStream> for ViceroyService {
    type Response = RequestService;
    type Error = Infallible;
    type Future = Ready<Result<Self::Response, Self::Error>>;

    /// Always reports readiness; this service never applies back-pressure.
    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Returns an already-resolved future holding the service for `addr`.
    fn call(&mut self, addr: &'addr AddrStream) -> Self::Future {
        let per_connection = self.make_service(addr);
        future::ready(Ok(per_connection))
    }
}
/// Per-connection service that forwards each HTTP request to the shared
/// execution context, together with the connection's two socket addresses.
#[derive(Clone)]
pub struct RequestService {
    /// Execution context shared with the parent service.
    ctx: Arc<ExecuteCtx>,
    /// Local address of the connection, as reported by the `AddrStream`.
    local_addr: SocketAddr,
    /// Remote (peer) address of the connection, as reported by the `AddrStream`.
    remote_addr: SocketAddr,
}
impl RequestService {
    /// Creates the service for one connection, capturing the local and remote
    /// addresses reported by `addr`.
    fn new(ctx: Arc<ExecuteCtx>, addr: &AddrStream) -> Self {
        Self {
            ctx,
            local_addr: addr.local_addr(),
            remote_addr: addr.remote_addr(),
        }
    }
}
type ServiceFuture = dyn Future<Output = Result<Response<Body>, Error>> + Send;
impl Service<Request<hyper::Body>> for RequestService {
type Response = Response<Body>;
type Error = Error;
type Future = Pin<Box<ServiceFuture>>;
fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<hyper::Body>) -> Self::Future {
let ctx = self.ctx.clone();
let local = self.local_addr;
let remote = self.remote_addr;
Box::pin(async move {
ctx.handle_request(req, local, remote)
.await
.map(|result| result.0)
})
}
}