Skip to main content

icann_rdap_client/http/
wrapped.rs

1//! Wrapped Client.
2
3#![allow(mismatched_lifetime_syntaxes)] // TODO see if this can be removed with a buildstructor upgrade
4
5use std::collections::HashSet;
6
7use icann_rdap_common::prelude::ExtensionId;
8pub use reqwest::{header::HeaderValue, Client as ReqwestClient, Error as ReqwestError};
9use {
10    icann_rdap_common::httpdata::HttpData,
11    reqwest::header::{
12        ACCESS_CONTROL_ALLOW_ORIGIN, CACHE_CONTROL, CONTENT_TYPE, EXPIRES, LOCATION, RETRY_AFTER,
13        STRICT_TRANSPORT_SECURITY,
14    },
15};
16
17use {
18    super::{create_reqwest_client, ReqwestClientConfig},
19    crate::RdapClientError,
20};
21
22#[cfg(not(target_arch = "wasm32"))]
23use {
24    super::create_reqwest_client_with_addr, chrono::DateTime, chrono::Utc, reqwest::StatusCode,
25    std::net::SocketAddr, tracing::debug, tracing::info,
26};
27
/// Retry behavior used by the request functions when a server answers with
/// HTTP 429 (Too Many Requests).
///
/// The fields are crate-private; set them through the builders on
/// [`ClientConfig`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RequestOptions {
    /// Upper bound, in seconds, on the wait requested before a retry.
    pub(crate) max_retry_secs: u32,
    /// Wait, in seconds, used when the server provides no usable `Retry-After` value.
    pub(crate) def_retry_secs: u32,
    /// Maximum number of times a request is re-sent after an HTTP 429 response.
    pub(crate) max_retries: u16,
}
35
36impl Default for RequestOptions {
37    fn default() -> Self {
38        Self {
39            max_retry_secs: 120,
40            def_retry_secs: 60,
41            max_retries: 1,
42        }
43    }
44}
45
46/// Default set of extension IDs.
47///
48/// This is the default set of extension ids used in the RDAP
49/// media type "exts_list" parameter.
50pub fn default_exts_list() -> HashSet<ExtensionId> {
51    let mut exts_list = HashSet::new();
52    exts_list.insert(ExtensionId::Cidr0);
53    exts_list.insert(ExtensionId::Exts);
54    exts_list.insert(ExtensionId::JsContact);
55    exts_list.insert(ExtensionId::SimpleRedaction);
56    exts_list.insert(ExtensionId::Redacted);
57    exts_list.insert(ExtensionId::Ttl0);
58    exts_list
59}
60
61/// Configures the HTTP client.
62#[derive(Clone)]
63pub struct ClientConfig {
64    /// Config for the Reqwest client.
65    client_config: ReqwestClientConfig,
66
67    /// Request options.
68    request_options: RequestOptions,
69}
70
71impl Default for ClientConfig {
72    fn default() -> Self {
73        Self {
74            client_config: ReqwestClientConfig {
75                exts_list: default_exts_list(),
76                ..Default::default()
77            },
78            request_options: Default::default(),
79        }
80    }
81}
82
83#[buildstructor::buildstructor]
84impl ClientConfig {
85    #[builder]
86    pub fn new(
87        user_agent_suffix: Option<String>,
88        https_only: Option<bool>,
89        accept_invalid_host_names: Option<bool>,
90        accept_invalid_certificates: Option<bool>,
91        follow_redirects: Option<bool>,
92        host: Option<HeaderValue>,
93        origin: Option<HeaderValue>,
94        timeout_secs: Option<u64>,
95        max_retry_secs: Option<u32>,
96        def_retry_secs: Option<u32>,
97        max_retries: Option<u16>,
98        exts_list: Option<HashSet<ExtensionId>>,
99    ) -> Self {
100        let default_cc = ReqwestClientConfig::default();
101        let default_ro = RequestOptions::default();
102        Self {
103            client_config: ReqwestClientConfig {
104                user_agent_suffix: user_agent_suffix.unwrap_or(default_cc.user_agent_suffix),
105                https_only: https_only.unwrap_or(default_cc.https_only),
106                accept_invalid_host_names: accept_invalid_host_names
107                    .unwrap_or(default_cc.accept_invalid_host_names),
108                accept_invalid_certificates: accept_invalid_certificates
109                    .unwrap_or(default_cc.accept_invalid_certificates),
110                follow_redirects: follow_redirects.unwrap_or(default_cc.follow_redirects),
111                host,
112                origin,
113                timeout_secs: timeout_secs.unwrap_or(default_cc.timeout_secs),
114                exts_list: exts_list.unwrap_or(default_exts_list()),
115            },
116            request_options: RequestOptions {
117                max_retry_secs: max_retry_secs.unwrap_or(default_ro.max_retry_secs),
118                def_retry_secs: def_retry_secs.unwrap_or(default_ro.def_retry_secs),
119                max_retries: max_retries.unwrap_or(default_ro.max_retries),
120            },
121        }
122    }
123
124    #[builder(entry = "from_config", exit = "build")]
125    pub fn new_from_config(
126        &self,
127        user_agent_suffix: Option<String>,
128        https_only: Option<bool>,
129        accept_invalid_host_names: Option<bool>,
130        accept_invalid_certificates: Option<bool>,
131        follow_redirects: Option<bool>,
132        host: Option<HeaderValue>,
133        origin: Option<HeaderValue>,
134        timeout_secs: Option<u64>,
135        max_retry_secs: Option<u32>,
136        def_retry_secs: Option<u32>,
137        max_retries: Option<u16>,
138        exts_list: Option<HashSet<ExtensionId>>,
139    ) -> Self {
140        Self {
141            client_config: ReqwestClientConfig {
142                user_agent_suffix: user_agent_suffix
143                    .unwrap_or(self.client_config.user_agent_suffix.clone()),
144                https_only: https_only.unwrap_or(self.client_config.https_only),
145                accept_invalid_host_names: accept_invalid_host_names
146                    .unwrap_or(self.client_config.accept_invalid_host_names),
147                accept_invalid_certificates: accept_invalid_certificates
148                    .unwrap_or(self.client_config.accept_invalid_certificates),
149                follow_redirects: follow_redirects.unwrap_or(self.client_config.follow_redirects),
150                host: host.map_or(self.client_config.host.clone(), Some),
151                origin: origin.map_or(self.client_config.origin.clone(), Some),
152                timeout_secs: timeout_secs.unwrap_or(self.client_config.timeout_secs),
153                exts_list: exts_list.unwrap_or(self.client_config.exts_list.clone()),
154            },
155            request_options: RequestOptions {
156                max_retry_secs: max_retry_secs.unwrap_or(self.request_options.max_retry_secs),
157                def_retry_secs: def_retry_secs.unwrap_or(self.request_options.def_retry_secs),
158                max_retries: max_retries.unwrap_or(self.request_options.max_retries),
159            },
160        }
161    }
162}
163
164/// A wrapper around Reqwest client to give additional features when used with the request functions.
165pub struct Client {
166    /// The reqwest client.
167    pub(crate) reqwest_client: ReqwestClient,
168
169    /// Request options.
170    pub(crate) request_options: RequestOptions,
171}
172
173impl Client {
174    pub fn new(reqwest_client: ReqwestClient, request_options: RequestOptions) -> Self {
175        Self {
176            reqwest_client,
177            request_options,
178        }
179    }
180}
181
182/// Creates a wrapped HTTP client. The wrapped
183/// client holds its own connection pools, so in many
184/// uses cases creating only one client per process is
185/// necessary.
186pub fn create_client(config: &ClientConfig) -> Result<Client, RdapClientError> {
187    let client = create_reqwest_client(&config.client_config)?;
188    Ok(Client::new(client, config.request_options))
189}
190
191/// Creates a wrapped HTTP client.
192/// This will direct the underlying client to connect to a specific socket.
193#[cfg(not(target_arch = "wasm32"))]
194pub fn create_client_with_addr(
195    config: &ClientConfig,
196    domain: &str,
197    addr: SocketAddr,
198) -> Result<Client, RdapClientError> {
199    let client = create_reqwest_client_with_addr(&config.client_config, domain, addr)?;
200    Ok(Client::new(client, config.request_options))
201}
202
203pub(crate) struct WrappedResponse {
204    pub(crate) http_data: HttpData,
205    pub(crate) text: String,
206}
207
208pub(crate) async fn wrapped_request(
209    request_uri: &str,
210    client: &Client,
211) -> Result<WrappedResponse, ReqwestError> {
212    // send request and loop for possible retries
213    #[allow(unused_mut)] //because of wasm32 exclusion below
214    let mut response = client.reqwest_client.get(request_uri).send().await?;
215
216    // this doesn't work on wasm32 because tokio doesn't work on wasm
217    #[cfg(not(target_arch = "wasm32"))]
218    {
219        let mut tries: u16 = 0;
220        loop {
221            debug!("HTTP version: {:?}", response.version());
222            // don't repeat the request
223            if !matches!(response.status(), StatusCode::TOO_MANY_REQUESTS) {
224                break;
225            }
226            // loop if HTTP 429
227            let retry_after_header = response
228                .headers()
229                .get(RETRY_AFTER)
230                .map(|value| value.to_str().unwrap().to_string());
231            let retry_after = if let Some(rt) = retry_after_header {
232                info!("Server says too many requests and to retry-after '{rt}'.");
233                rt
234            } else {
235                info!("Server says too many requests but does not offer 'retry-after' value.");
236                client.request_options.def_retry_secs.to_string()
237            };
238            let mut wait_time_seconds = if let Ok(date) = DateTime::parse_from_rfc2822(&retry_after)
239            {
240                (date.with_timezone(&Utc) - Utc::now()).num_seconds() as u64
241            } else if let Ok(seconds) = retry_after.parse::<u64>() {
242                seconds
243            } else {
244                info!(
245                    "Unable to parse retry-after header value. Using {}",
246                    client.request_options.def_retry_secs
247                );
248                client.request_options.def_retry_secs.into()
249            };
250            if wait_time_seconds == 0 {
251                info!("Given {wait_time_seconds} for retry-after. Does not make sense.");
252                wait_time_seconds = client.request_options.def_retry_secs as u64;
253            }
254            if wait_time_seconds > client.request_options.max_retry_secs as u64 {
255                info!(
256                    "Server is asking to wait longer than configured max of {}.",
257                    client.request_options.max_retry_secs
258                );
259                wait_time_seconds = client.request_options.max_retry_secs as u64;
260            }
261            info!("Waiting {wait_time_seconds} seconds to retry.");
262            tokio::time::sleep(tokio::time::Duration::from_secs(wait_time_seconds + 1)).await;
263            tries += 1;
264            if tries > client.request_options.max_retries {
265                info!("Max query retries reached.");
266                break;
267            } else {
268                // send the query again
269                response = client.reqwest_client.get(request_uri).send().await?;
270            }
271        }
272    }
273
274    // throw an error if not 200 OK
275    //let response = response.error_for_status()?;
276
277    // get the response
278    let content_type = response
279        .headers()
280        .get(CONTENT_TYPE)
281        .map(|value| value.to_str().unwrap().to_string());
282    let expires = response
283        .headers()
284        .get(EXPIRES)
285        .map(|value| value.to_str().unwrap().to_string());
286    let cache_control = response
287        .headers()
288        .get(CACHE_CONTROL)
289        .map(|value| value.to_str().unwrap().to_string());
290    let location = response
291        .headers()
292        .get(LOCATION)
293        .map(|value| value.to_str().unwrap().to_string());
294    let access_control_allow_origin = response
295        .headers()
296        .get(ACCESS_CONTROL_ALLOW_ORIGIN)
297        .map(|value| value.to_str().unwrap().to_string());
298    let strict_transport_security = response
299        .headers()
300        .get(STRICT_TRANSPORT_SECURITY)
301        .map(|value| value.to_str().unwrap().to_string());
302    let retry_after = response
303        .headers()
304        .get(RETRY_AFTER)
305        .map(|value| value.to_str().unwrap().to_string());
306    let content_length = response.content_length();
307    let status_code = response.status().as_u16();
308    let url = response.url().to_owned();
309    let text = response.text().await?;
310
311    let http_data = HttpData::now()
312        .status_code(status_code)
313        .and_location(location)
314        .and_content_length(content_length)
315        .and_content_type(content_type)
316        .scheme(url.scheme())
317        .host(
318            url.host_str()
319                .expect("URL has no host. This shouldn't happen.")
320                .to_owned(),
321        )
322        .and_expires(expires)
323        .and_cache_control(cache_control)
324        .and_access_control_allow_origin(access_control_allow_origin)
325        .and_strict_transport_security(strict_transport_security)
326        .and_retry_after(retry_after)
327        .request_uri(request_uri)
328        .build();
329
330    Ok(WrappedResponse { http_data, text })
331}