1use drain::Watch as Drain;
2use futures::prelude::*;
3use ort_core::{Error, Ort, Reply, Spec};
4use std::{convert::Infallible, net::SocketAddr};
5use tokio::time;
6
/// HTTP front end wrapping a value of type `O`.
///
/// The serving methods are available when `O: Ort`; each accepted request is
/// answered by delegating to `inner`.
#[derive(Debug, Clone)]
pub struct Server<O> {
    /// The value whose `ort` method produces each response.
    inner: O,
}
11
12impl<O: Ort> Server<O> {
13 pub fn new(inner: O) -> Self {
14 Self { inner }
15 }
16
17 async fn handle(
18 mut self,
19 req: http::Request<hyper::Body>,
20 ) -> Result<http::Response<hyper::Body>, Error> {
21 if req.method() == http::Method::GET {
22 let mut spec = Spec::default();
23
24 if let Some(q) = req.uri().query() {
25 for kv in q.split('&') {
26 let mut kv = kv.splitn(2, '=');
27 match kv.next() {
28 Some("latency_ms") => {
29 if let Some(ms) = kv.next().and_then(|v| v.parse::<u64>().ok()) {
30 spec.latency = time::Duration::from_millis(ms);
31 }
32 }
33 Some("size") => {
34 if let Some(sz) = kv.next().and_then(|v| v.parse::<usize>().ok()) {
35 spec.response_size = sz;
36 }
37 }
38 Some(_) | None => {}
39 }
40 }
41 }
42
43 let Reply { data } = self.inner.ort(spec).await?;
44 return http::Response::builder()
45 .status(http::StatusCode::OK)
46 .header(http::header::CONTENT_TYPE, "application/octet-stream")
47 .body(data.into())
48 .map_err(Into::into);
49 }
50
51 http::Response::builder()
52 .status(http::StatusCode::BAD_REQUEST)
53 .body(hyper::Body::default())
54 .map_err(Into::into)
55 }
56
57 pub async fn serve(self, addr: SocketAddr, drain: Drain) -> Result<(), Error> {
58 let svc = hyper::service::make_service_fn(move |_| {
59 let handler = self.clone();
60 async move {
61 Ok::<_, Infallible>(hyper::service::service_fn(
62 move |req: http::Request<hyper::Body>| handler.clone().handle(req),
63 ))
64 }
65 });
66
67 let (close, closed) = tokio::sync::oneshot::channel();
68 tokio::pin! {
69 let srv = hyper::Server::bind(&addr)
70 .serve(svc)
71 .with_graceful_shutdown(closed.map(|_| ()));
72 }
73
74 tokio::select! {
75 _ = (&mut srv) => {}
76 handle = drain.signaled() => {
77 let _ = close.send(());
78 handle.release_after(srv).await?;
79 }
80 }
81
82 Ok(())
83 }
84}