Skip to main content

dns_update/
http.rs

1/*
2 * Copyright Stalwart Labs LLC See the COPYING
3 * file at the top-level directory of this distribution.
4 *
5 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
8 * option. This file may not be copied, modified, or distributed
9 * except according to those terms.
10 */
11
12use std::time::Duration;
13
14use reqwest::{
15    Method,
16    header::{CONTENT_TYPE, HeaderMap, HeaderValue},
17};
18use serde::{Serialize, de::DeserializeOwned};
19
20use crate::Error;
21
22#[derive(Debug, Clone)]
23pub struct HttpClientBuilder {
24    timeout: Duration,
25    headers: HeaderMap<HeaderValue>,
26}
27
28#[derive(Debug, Clone)]
29pub struct HttpClient {
30    headers: HeaderMap<HeaderValue>,
31    client: reqwest::Client,
32}
33
34#[derive(Debug, Clone)]
35pub struct HttpRequest {
36    method: Method,
37    url: String,
38    headers: HeaderMap<HeaderValue>,
39    body: Option<String>,
40    client: reqwest::Client,
41}
42
43impl Default for HttpClientBuilder {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl HttpClientBuilder {
50    pub fn new() -> Self {
51        let mut headers = HeaderMap::new();
52        headers.append(CONTENT_TYPE, HeaderValue::from_static("application/json"));
53
54        Self {
55            timeout: Duration::from_secs(30),
56            headers,
57        }
58    }
59
60    pub fn with_header(mut self, name: &'static str, value: impl AsRef<str>) -> Self {
61        if let Ok(value) = HeaderValue::from_str(value.as_ref()) {
62            self.headers.append(name, value);
63        }
64        self
65    }
66
67    pub fn set_header(mut self, name: &'static str, value: impl AsRef<str>) -> Self {
68        if let Ok(value) = HeaderValue::from_str(value.as_ref()) {
69            self.headers.insert(name, value);
70        }
71        self
72    }
73
74    pub fn with_timeout(mut self, timeout: Option<Duration>) -> Self {
75        if let Some(timeout) = timeout {
76            self.timeout = timeout;
77        }
78        self
79    }
80
81    pub fn build(self) -> HttpClient {
82        let client = reqwest::Client::builder()
83            .timeout(self.timeout)
84            .build()
85            .unwrap_or_default();
86        HttpClient {
87            headers: self.headers,
88            client,
89        }
90    }
91}
92
93impl HttpClient {
94    pub fn request(&self, method: Method, url: impl Into<String>) -> HttpRequest {
95        HttpRequest {
96            method,
97            url: url.into(),
98            headers: self.headers.clone(),
99            body: None,
100            client: self.client.clone(),
101        }
102    }
103
104    pub fn get(&self, url: impl Into<String>) -> HttpRequest {
105        self.request(Method::GET, url)
106    }
107
108    pub fn post(&self, url: impl Into<String>) -> HttpRequest {
109        self.request(Method::POST, url)
110    }
111
112    pub fn put(&self, url: impl Into<String>) -> HttpRequest {
113        self.request(Method::PUT, url)
114    }
115
116    pub fn delete(&self, url: impl Into<String>) -> HttpRequest {
117        self.request(Method::DELETE, url)
118    }
119
120    pub fn patch(&self, url: impl Into<String>) -> HttpRequest {
121        self.request(Method::PATCH, url)
122    }
123}
124
125impl HttpRequest {
126    pub fn with_header(mut self, name: &'static str, value: impl AsRef<str>) -> Self {
127        if let Ok(value) = HeaderValue::from_str(value.as_ref()) {
128            self.headers.append(name, value);
129        }
130        self
131    }
132
133    pub fn set_header(mut self, name: &'static str, value: impl AsRef<str>) -> Self {
134        if let Ok(value) = HeaderValue::from_str(value.as_ref()) {
135            self.headers.insert(name, value);
136        }
137        self
138    }
139
140    pub fn with_body<B: Serialize>(mut self, body: B) -> crate::Result<Self> {
141        match serde_json::to_string(&body) {
142            Ok(body) => {
143                self.body = Some(body);
144                Ok(self)
145            }
146            Err(err) => Err(Error::Serialize(format!(
147                "Failed to serialize request: {err}"
148            ))),
149        }
150    }
151
152    pub fn with_raw_body(mut self, body: String) -> Self {
153        self.body = Some(body);
154        self
155    }
156
157    pub async fn send<T>(self) -> crate::Result<T>
158    where
159        T: DeserializeOwned,
160    {
161        let response = self.send_raw().await?;
162        serde_json::from_slice::<T>(response.as_bytes())
163            .map_err(|err| Error::Serialize(format!("Failed to deserialize response: {err}")))
164    }
165
166    pub async fn send_raw(self) -> crate::Result<String> {
167        self.send_raw_with_headers().await.map(|(body, _)| body)
168    }
169
170    pub async fn send_raw_with_headers(self) -> crate::Result<(String, HeaderMap<HeaderValue>)> {
171        let mut request = self
172            .client
173            .request(self.method, &self.url)
174            .headers(self.headers);
175
176        if let Some(body) = self.body {
177            request = request.body(body);
178        }
179
180        let response = request
181            .send()
182            .await
183            .map_err(|err| Error::Api(format!("Failed to send request to {}: {err}", self.url)))?;
184
185        let code = response.status().as_u16();
186        let headers = response.headers().clone();
187        match code {
188            204 => Ok((String::new(), headers)),
189            200..=299 => response
190                .text()
191                .await
192                .map(|body| (body, headers))
193                .map_err(|err| {
194                    Error::Api(format!("Failed to read response from {}: {err}", self.url))
195                }),
196            401 => Err(Error::Unauthorized),
197            404 => Err(Error::NotFound),
198            _ => {
199                let text = response.text().await.unwrap_or_default();
200                Err(Error::Api(http_status_message(code, &text)))
201            }
202        }
203    }
204
205    pub async fn send_with_retry<T>(self, max_retries: u32) -> crate::Result<T>
206    where
207        T: DeserializeOwned,
208    {
209        let mut attempts = 0;
210        let body = self.body;
211        loop {
212            let mut request = self
213                .client
214                .request(self.method.clone(), &self.url)
215                .headers(self.headers.clone());
216
217            if let Some(body) = body.as_ref() {
218                request = request.body(body.clone());
219            }
220
221            let response = request.send().await.map_err(|err| {
222                Error::Api(format!("Failed to send request to {}: {err}", self.url))
223            })?;
224
225            let code = response.status().as_u16();
226            return match code {
227                204 => serde_json::from_str("{}").map_err(|err| {
228                    Error::Serialize(format!("Failed to create empty response: {err}"))
229                }),
230                200..=299 => {
231                    let text = response.text().await.map_err(|err| {
232                        Error::Api(format!("Failed to read response from {}: {err}", self.url))
233                    })?;
234                    let parse_target = if text.trim().is_empty() { "{}" } else { &text };
235                    serde_json::from_str(parse_target).map_err(|err| {
236                        Error::Serialize(format!("Failed to deserialize response: {err}"))
237                    })
238                }
239                429 | 503 if attempts < max_retries => {
240                    let delay = retry_after(response.headers())
241                        .unwrap_or_else(|| Duration::from_secs(1u64 << attempts.min(6)));
242                    tokio::time::sleep(delay.min(MAX_RETRY_DELAY)).await;
243                    attempts += 1;
244                    continue;
245                }
246                401 => Err(Error::Unauthorized),
247                404 => Err(Error::NotFound),
248                _ => {
249                    let text = response.text().await.unwrap_or_default();
250                    Err(Error::Api(http_status_message(code, &text)))
251                }
252            };
253        }
254    }
255}
256
/// Upper bound on a single sleep between retries in `send_with_retry`.
///
/// Applied to both server-provided `Retry-After` values and the exponential backoff.
const MAX_RETRY_DELAY: Duration = Duration::new(60, 0);
258
259fn retry_after(headers: &HeaderMap<HeaderValue>) -> Option<Duration> {
260    headers
261        .get("retry-after")?
262        .to_str()
263        .ok()?
264        .parse::<u64>()
265        .ok()
266        .map(Duration::from_secs)
267}
268
/// Builds the error text for an unexpected HTTP status.
///
/// The response body is trimmed and appended when non-empty. A status of 400 is
/// reported as `BadRequest`; any other status is reported as `HTTP <code>`.
fn http_status_message(code: u16, body: &str) -> String {
    let detail = body.trim();
    match (code, detail.is_empty()) {
        (400, true) => "BadRequest".to_string(),
        (400, false) => format!("BadRequest {detail}"),
        (_, true) => format!("HTTP {code}"),
        (_, false) => format!("HTTP {code}: {detail}"),
    }
}
282}