// ort_http/client.rs

1use ort_core::{Error, MakeOrt, Ort, Reply, Spec};
2use std::convert::TryFrom;
3use tokio::time::Duration;
4
/// Factory for [`Http`] clients.
///
/// Holds the connection settings that are applied to every client it makes.
#[derive(Clone)]
pub struct MakeHttp {
    /// Optional limit handed to the client's per-host idle pool; `None`
    /// leaves the client's default untouched.
    concurrency: Option<usize>,
    /// Timeout applied when establishing each connection.
    connect_timeout: Duration,
}
10
11#[derive(Clone)]
12pub struct Http {
13    client: hyper::Client<hyper::client::HttpConnector>,
14    target: http::Uri,
15}
16
17impl MakeHttp {
18    pub fn new(concurrency: Option<usize>, connect_timeout: Duration) -> Self {
19        Self {
20            concurrency,
21            connect_timeout,
22        }
23    }
24}
25
26#[async_trait::async_trait]
27impl MakeOrt<http::Uri> for MakeHttp {
28    type Ort = Http;
29
30    async fn make_ort(&mut self, target: http::Uri) -> Result<Http, Error> {
31        let mut connect = hyper::client::HttpConnector::new();
32        connect.set_connect_timeout(Some(self.connect_timeout));
33        connect.set_nodelay(true);
34        connect.set_reuse_address(true);
35
36        let mut builder = hyper::Client::builder();
37        if let Some(c) = self.concurrency {
38            builder.pool_max_idle_per_host(c);
39        }
40        let client = builder.build(connect);
41
42        Ok(Http { client, target })
43    }
44}
45
46#[async_trait::async_trait]
47impl Ort for Http {
48    async fn ort(
49        &mut self,
50        Spec {
51            latency,
52            response_size,
53        }: Spec,
54    ) -> Result<Reply, Error> {
55        let mut uri = http::Uri::builder();
56        if let Some(s) = self.target.scheme() {
57            uri = uri.scheme(s.clone());
58        }
59
60        if let Some(a) = self.target.authority() {
61            uri = uri.authority(a.clone());
62        }
63
64        uri = {
65            let latency_ms = latency.as_millis() as i64;
66
67            tracing::trace!(latency_ms, response_size);
68            uri.path_and_query(
69                http::uri::PathAndQuery::try_from(
70                    format!("/?latency_ms={}&size={}", latency_ms, response_size).as_str(),
71                )
72                .expect("query must be valid"),
73            )
74        };
75
76        let rsp = self
77            .client
78            .request(
79                http::Request::builder()
80                    .uri(uri.build().unwrap())
81                    .body(hyper::Body::default())
82                    .unwrap(),
83            )
84            .await?;
85
86        let data = hyper::body::to_bytes(rsp.into_body()).await?;
87
88        Ok(Reply { data })
89    }
90}