// ort_http/server.rs

1use drain::Watch as Drain;
2use futures::prelude::*;
3use ort_core::{Error, Ort, Reply, Spec};
4use std::{convert::Infallible, net::SocketAddr};
5use tokio::time;
6
/// HTTP server that delegates the generation of every reply to the wrapped `inner` value.
///
/// Cloning a `Server` clones `inner`; `serve` hands a clone to each connection and request.
#[derive(Clone, Debug)]
pub struct Server<O> {
    /// The reply generator. `serve` and `handle` require it to implement `Ort`.
    inner: O,
}

12impl<O: Ort> Server<O> {
13    pub fn new(inner: O) -> Self {
14        Self { inner }
15    }
16
17    async fn handle(
18        mut self,
19        req: http::Request<hyper::Body>,
20    ) -> Result<http::Response<hyper::Body>, Error> {
21        if req.method() == http::Method::GET {
22            let mut spec = Spec::default();
23
24            if let Some(q) = req.uri().query() {
25                for kv in q.split('&') {
26                    let mut kv = kv.splitn(2, '=');
27                    match kv.next() {
28                        Some("latency_ms") => {
29                            if let Some(ms) = kv.next().and_then(|v| v.parse::<u64>().ok()) {
30                                spec.latency = time::Duration::from_millis(ms);
31                            }
32                        }
33                        Some("size") => {
34                            if let Some(sz) = kv.next().and_then(|v| v.parse::<usize>().ok()) {
35                                spec.response_size = sz;
36                            }
37                        }
38                        Some(_) | None => {}
39                    }
40                }
41            }
42
43            let Reply { data } = self.inner.ort(spec).await?;
44            return http::Response::builder()
45                .status(http::StatusCode::OK)
46                .header(http::header::CONTENT_TYPE, "application/octet-stream")
47                .body(data.into())
48                .map_err(Into::into);
49        }
50
51        http::Response::builder()
52            .status(http::StatusCode::BAD_REQUEST)
53            .body(hyper::Body::default())
54            .map_err(Into::into)
55    }
56
57    pub async fn serve(self, addr: SocketAddr, drain: Drain) -> Result<(), Error> {
58        let svc = hyper::service::make_service_fn(move |_| {
59            let handler = self.clone();
60            async move {
61                Ok::<_, Infallible>(hyper::service::service_fn(
62                    move |req: http::Request<hyper::Body>| handler.clone().handle(req),
63                ))
64            }
65        });
66
67        let (close, closed) = tokio::sync::oneshot::channel();
68        tokio::pin! {
69            let srv = hyper::Server::bind(&addr)
70                .serve(svc)
71                .with_graceful_shutdown(closed.map(|_| ()));
72        }
73
74        tokio::select! {
75            _ = (&mut srv) => {}
76            handle = drain.signaled() => {
77                let _ = close.send(());
78                handle.release_after(srv).await?;
79            }
80        }
81
82        Ok(())
83    }
84}