Skip to main content

ic_bn_lib/http/client/
clients_reqwest.rs

1use std::{
2    sync::{
3        Arc,
4        atomic::{AtomicUsize, Ordering},
5    },
6    time::{Duration, Instant},
7};
8
9use ahash::RandomState;
10use anyhow::Context;
11use async_trait::async_trait;
12use ic_bn_lib_common::{
13    traits::dns::CloneableDnsResolver,
14    types::http::{ClientOptions, Error, HttpVersion},
15};
16use moka::sync::{Cache, CacheBuilder};
17use prometheus::Registry;
18use reqwest::{Request, Response, dns::Resolve};
19use scopeguard::defer;
20use url::Url;
21
22use super::{Client, ClientStats, ClientWithStats, Metrics, Stats};
23
24// Extracts host:port from the URL
25fn extract_host(url: &Url) -> String {
26    format!(
27        "{}:{}",
28        url.host_str()
29            .and_then(|x| x.split('@').next_back())
30            .unwrap_or_default(),
31        url.port_or_known_default().unwrap_or_default()
32    )
33}
34
35/// Create a new `reqwest::Client` from `ClientOptions` and a resolver
36pub fn new<R: Resolve + 'static>(
37    opts: ClientOptions,
38    resolver: Option<R>,
39) -> Result<reqwest::Client, Error> {
40    let mut client = reqwest::Client::builder()
41        .connect_timeout(opts.timeout_connect)
42        .read_timeout(opts.timeout_read)
43        .timeout(opts.timeout)
44        .pool_idle_timeout(opts.pool_idle_timeout)
45        .tcp_nodelay(true)
46        .tcp_keepalive(opts.tcp_keepalive_delay)
47        .tcp_keepalive_interval(opts.tcp_keepalive_interval)
48        .tcp_keepalive_retries(opts.tcp_keepalive_retries)
49        .http2_keep_alive_interval(opts.http2_keepalive)
50        .http2_keep_alive_while_idle(opts.http2_keepalive_idle)
51        .http2_adaptive_window(true)
52        .user_agent(opts.user_agent)
53        .redirect(reqwest::redirect::Policy::none())
54        .no_proxy();
55
56    for (host, addr) in &opts.dns_overrides {
57        client = client.resolve(host, *addr);
58    }
59
60    match opts.http_version {
61        HttpVersion::Http1 => {
62            client = client.http1_only();
63        }
64        HttpVersion::Http2 => {
65            client = client.http2_prior_knowledge();
66        }
67        _ => {}
68    }
69
70    if let Some(v) = opts.http2_keepalive_timeout {
71        client = client.http2_keep_alive_timeout(v);
72    }
73
74    if let Some(v) = opts.pool_idle_max {
75        client = client.pool_max_idle_per_host(v);
76    }
77
78    if let Some(v) = opts.tls_config {
79        client = client.use_preconfigured_tls(v);
80    }
81
82    if let Some(v) = resolver {
83        client = client.dns_resolver(Arc::new(v));
84    }
85
86    Ok(client.build().context("unable to create reqwest client")?)
87}
88
89/// Reqwest-based HTTP client
90#[derive(Clone, Debug)]
91pub struct ReqwestClient(reqwest::Client);
92
93impl ReqwestClient {
94    pub fn new<R: Resolve + 'static>(
95        opts: ClientOptions,
96        resolver: Option<R>,
97    ) -> Result<Self, Error> {
98        Ok(Self(new(opts, resolver)?))
99    }
100}
101
102#[async_trait]
103impl Client for ReqwestClient {
104    async fn execute(&self, req: Request) -> Result<Response, reqwest::Error> {
105        self.0.execute(req).await
106    }
107}
108
109/// Reqwest-based HTTP client that pools a number of `Client`s and uses them in a round-robin fashion
110#[derive(Clone, Debug)]
111pub struct ReqwestClientRoundRobin {
112    inner: Arc<ReqwestClientRoundRobinInner>,
113}
114
115#[derive(Debug)]
116struct ReqwestClientRoundRobinInner {
117    cli: Vec<reqwest::Client>,
118    next: AtomicUsize,
119}
120
121impl ReqwestClientRoundRobin {
122    pub fn new<R: CloneableDnsResolver>(
123        opts: ClientOptions,
124        resolver: Option<R>,
125        count: usize,
126    ) -> Result<Self, Error> {
127        let inner = ReqwestClientRoundRobinInner {
128            cli: (0..count)
129                .map(|_| new(opts.clone(), resolver.clone()))
130                .collect::<Result<Vec<_>, _>>()?,
131            next: AtomicUsize::new(0),
132        };
133
134        Ok(Self {
135            inner: Arc::new(inner),
136        })
137    }
138}
139
140#[async_trait]
141impl Client for ReqwestClientRoundRobin {
142    async fn execute(&self, req: Request) -> Result<Response, reqwest::Error> {
143        let next = self.inner.next.fetch_add(1, Ordering::SeqCst) % self.inner.cli.len();
144        self.inner.cli[next].execute(req).await
145    }
146}
147
148/// Client that pools a defined number of `Client`s and picks the least loaded one for the next request.
149#[derive(Clone, Debug)]
150pub struct ReqwestClientLeastLoaded {
151    inner: Arc<Vec<ReqwestClientLeastLoadedInner>>,
152    metrics: Option<Metrics>,
153}
154
155#[derive(Debug)]
156struct ReqwestClientLeastLoadedInner {
157    cli: reqwest::Client,
158    outstanding: Cache<String, Arc<AtomicUsize>, RandomState>,
159}
160
161impl ReqwestClientLeastLoaded {
162    pub fn new<R: CloneableDnsResolver>(
163        opts: ClientOptions,
164        resolver: Option<R>,
165        count: usize,
166        registry: Option<&Registry>,
167    ) -> Result<Self, Error> {
168        let inner = (0..count)
169            .map(|_| -> Result<_, _> {
170                Ok::<_, Error>(ReqwestClientLeastLoadedInner {
171                    cli: new(opts.clone(), resolver.clone())?,
172                    // Creates a cache with some sensible max capacity to hold target hosts.
173                    // If the host isn't contacted in 10min then we remove it.
174                    // TODO should we make this configurable? Probably ok like this
175                    outstanding: CacheBuilder::new(16384)
176                        .time_to_idle(Duration::from_secs(600))
177                        .build_with_hasher(RandomState::default()),
178                })
179            })
180            .collect::<Result<Vec<_>, _>>()?;
181
182        Ok(Self {
183            inner: Arc::new(inner),
184            metrics: registry.map(Metrics::new),
185        })
186    }
187}
188
189#[async_trait]
190impl Client for ReqwestClientLeastLoaded {
191    async fn execute(&self, req: Request) -> Result<Response, reqwest::Error> {
192        // Extract host:port from the request URL
193        let host = extract_host(req.url());
194        let labels = &[&host];
195
196        self.metrics
197            .as_ref()
198            .inspect(|x| x.requests.with_label_values(labels).inc());
199
200        // Select the client with least outstanding requests for the given host
201        let (cli, counter) = self
202            .inner
203            .iter()
204            .map(|x| {
205                (
206                    &x.cli,
207                    // Get an atomic counter for the given host or create a new one
208                    x.outstanding
209                        .get_with_by_ref(&host, || Arc::new(AtomicUsize::new(0))),
210                )
211            })
212            .min_by_key(|x| x.1.load(Ordering::SeqCst))
213            .unwrap();
214
215        // The future can be cancelled so we have to use defer to make sure the counter is decreased
216        defer! {
217            counter.fetch_sub(1, Ordering::SeqCst);
218            self.metrics
219                .as_ref()
220                .inspect(|x| x.requests_inflight.with_label_values(labels).dec());
221        }
222
223        counter.fetch_add(1, Ordering::SeqCst);
224        self.metrics
225            .as_ref()
226            .inspect(|x| x.requests_inflight.with_label_values(labels).inc());
227
228        // Execute the request & observe duration
229        let start = Instant::now();
230        let result = cli.execute(req).await;
231        self.metrics.as_ref().inspect(|x| {
232            x.request_duration
233                .with_label_values(labels)
234                .observe(start.elapsed().as_secs_f64())
235        });
236
237        result
238    }
239}
240
241impl Stats for ReqwestClientLeastLoaded {
242    fn stats(&self) -> ClientStats {
243        ClientStats {
244            pool_size: self.inner.len(),
245            outstanding: self
246                .inner
247                .iter()
248                .flat_map(|x| x.outstanding.iter().map(|x| x.1.load(Ordering::SeqCst)))
249                .sum(),
250        }
251    }
252}
253
254impl ClientWithStats for ReqwestClientLeastLoaded {
255    fn to_client(self: Arc<Self>) -> Arc<dyn Client> {
256        self
257    }
258}
259
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_extract_host() {
        let cases = [
            // An explicit non-default port is kept.
            ("https://foo:123/bar/beef", "foo:123"),
            // An explicit default port yields the same key as an omitted one.
            ("https://foo:443/bar/beef", "foo:443"),
            ("https://foo/bar/beef", "foo:443"),
            ("http://foo:80/bar/beef", "foo:80"),
            ("http://foo/bar/beef", "foo:80"),
            // Credentials are never part of the key.
            ("https://top:secret@foo:123/bar/beef", "foo:123"),
            // IPv6 literals keep their brackets, so the port separator stays unambiguous.
            ("http://[::1]:8080/bar", "[::1]:8080"),
            ("https://[::1]/bar", "[::1]:443"),
        ];

        for (input, expected) in cases {
            assert_eq!(
                extract_host(&Url::parse(input).unwrap()),
                expected,
                "input: {input}"
            );
        }
    }
}