1use ort_core::{Error, MakeOrt, Ort, Reply, Spec};
2use std::convert::TryFrom;
3use tokio::time::Duration;
4
/// Factory for [`Http`] clients.
///
/// Holds the connection settings applied to every client it creates.
#[derive(Clone)]
pub struct MakeHttp {
    /// Timeout applied when establishing TCP connections.
    connect_timeout: Duration,
    /// Optional cap passed to the client's per-host idle connection pool.
    concurrency: Option<usize>,
}
10
11#[derive(Clone)]
12pub struct Http {
13 client: hyper::Client<hyper::client::HttpConnector>,
14 target: http::Uri,
15}
16
17impl MakeHttp {
18 pub fn new(concurrency: Option<usize>, connect_timeout: Duration) -> Self {
19 Self {
20 concurrency,
21 connect_timeout,
22 }
23 }
24}
25
26#[async_trait::async_trait]
27impl MakeOrt<http::Uri> for MakeHttp {
28 type Ort = Http;
29
30 async fn make_ort(&mut self, target: http::Uri) -> Result<Http, Error> {
31 let mut connect = hyper::client::HttpConnector::new();
32 connect.set_connect_timeout(Some(self.connect_timeout));
33 connect.set_nodelay(true);
34 connect.set_reuse_address(true);
35
36 let mut builder = hyper::Client::builder();
37 if let Some(c) = self.concurrency {
38 builder.pool_max_idle_per_host(c);
39 }
40 let client = builder.build(connect);
41
42 Ok(Http { client, target })
43 }
44}
45
46#[async_trait::async_trait]
47impl Ort for Http {
48 async fn ort(
49 &mut self,
50 Spec {
51 latency,
52 response_size,
53 }: Spec,
54 ) -> Result<Reply, Error> {
55 let mut uri = http::Uri::builder();
56 if let Some(s) = self.target.scheme() {
57 uri = uri.scheme(s.clone());
58 }
59
60 if let Some(a) = self.target.authority() {
61 uri = uri.authority(a.clone());
62 }
63
64 uri = {
65 let latency_ms = latency.as_millis() as i64;
66
67 tracing::trace!(latency_ms, response_size);
68 uri.path_and_query(
69 http::uri::PathAndQuery::try_from(
70 format!("/?latency_ms={}&size={}", latency_ms, response_size).as_str(),
71 )
72 .expect("query must be valid"),
73 )
74 };
75
76 let rsp = self
77 .client
78 .request(
79 http::Request::builder()
80 .uri(uri.build().unwrap())
81 .body(hyper::Body::default())
82 .unwrap(),
83 )
84 .await?;
85
86 let data = hyper::body::to_bytes(rsp.into_body()).await?;
87
88 Ok(Reply { data })
89 }
90}