Skip to main content

dns_update/
http.rs

/*
 * Copyright Stalwart Labs LLC See the COPYING
 * file at the top-level directory of this distribution.
 *
 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
 * option. This file may not be copied, modified, or distributed
 * except according to those terms.
 */
11
12use std::time::Duration;
13
14use reqwest::{
15    Method,
16    header::{CONTENT_TYPE, HeaderMap, HeaderValue},
17};
18use serde::{Serialize, de::DeserializeOwned};
19
20use crate::Error;
21
22#[derive(Debug, Clone)]
23pub struct HttpClientBuilder {
24    timeout: Duration,
25    headers: HeaderMap<HeaderValue>,
26}
27
28#[derive(Debug, Default, Clone)]
29pub struct HttpClient {
30    method: Method,
31    timeout: Duration,
32    url: String,
33    headers: HeaderMap<HeaderValue>,
34    body: Option<String>,
35}
36
37impl Default for HttpClientBuilder {
38    fn default() -> Self {
39        let mut headers = HeaderMap::new();
40        headers.append(CONTENT_TYPE, HeaderValue::from_static("application/json"));
41
42        Self {
43            timeout: Duration::from_secs(30),
44            headers,
45        }
46    }
47}
48
49impl HttpClientBuilder {
50    pub fn build(&self, method: Method, url: impl Into<String>) -> HttpClient {
51        HttpClient {
52            method,
53            url: url.into(),
54            headers: self.headers.clone(),
55            body: None,
56            timeout: self.timeout,
57        }
58    }
59
60    pub fn get(&self, url: impl Into<String>) -> HttpClient {
61        self.build(Method::GET, url)
62    }
63
64    pub fn post(&self, url: impl Into<String>) -> HttpClient {
65        self.build(Method::POST, url)
66    }
67
68    pub fn put(&self, url: impl Into<String>) -> HttpClient {
69        self.build(Method::PUT, url)
70    }
71
72    pub fn delete(&self, url: impl Into<String>) -> HttpClient {
73        self.build(Method::DELETE, url)
74    }
75
76    pub fn patch(&self, url: impl Into<String>) -> HttpClient {
77        self.build(Method::PATCH, url)
78    }
79
80    pub fn with_header(mut self, name: &'static str, value: impl AsRef<str>) -> Self {
81        if let Ok(value) = HeaderValue::from_str(value.as_ref()) {
82            self.headers.append(name, value);
83        }
84        self
85    }
86
87    pub fn set_header(mut self, name: &'static str, value: impl AsRef<str>) -> Self {
88        if let Ok(value) = HeaderValue::from_str(value.as_ref()) {
89            self.headers.insert(name, value);
90        }
91        self
92    }
93
94    pub fn with_timeout(mut self, timeout: Option<Duration>) -> Self {
95        if let Some(timeout) = timeout {
96            self.timeout = timeout;
97        }
98        self
99    }
100}
101
102impl HttpClient {
103    pub fn with_header(mut self, name: &'static str, value: impl AsRef<str>) -> Self {
104        if let Ok(value) = HeaderValue::from_str(value.as_ref()) {
105            self.headers.append(name, value);
106        }
107        self
108    }
109
110    pub fn with_body<B: Serialize>(mut self, body: B) -> crate::Result<Self> {
111        match serde_json::to_string(&body) {
112            Ok(body) => {
113                self.body = Some(body);
114                Ok(self)
115            }
116            Err(err) => Err(Error::Serialize(format!(
117                "Failed to serialize request: {err}"
118            ))),
119        }
120    }
121
122    pub fn with_raw_body(mut self, body: String) -> Self {
123        self.body = Some(body);
124        self
125    }
126
127    pub async fn send<T>(self) -> crate::Result<T>
128    where
129        T: DeserializeOwned,
130    {
131        let response = self.send_raw().await?;
132        serde_json::from_slice::<T>(response.as_bytes())
133            .map_err(|err| Error::Serialize(format!("Failed to deserialize response: {err}")))
134    }
135
136    pub async fn send_raw(self) -> crate::Result<String> {
137        let mut request = reqwest::Client::builder()
138            .timeout(self.timeout)
139            .build()
140            .unwrap_or_default()
141            .request(self.method, &self.url)
142            .headers(self.headers);
143
144        if let Some(body) = self.body {
145            request = request.body(body);
146        }
147
148        let response = request
149            .send()
150            .await
151            .map_err(|err| Error::Api(format!("Failed to send request to {}: {err}", self.url)))?;
152
153        match response.status().as_u16() {
154            204 => Ok(String::new()),
155            200..=299 => response.text().await.map_err(|err| {
156                Error::Api(format!("Failed to read response from {}: {err}", self.url))
157            }),
158            400 => {
159                let text = response.text().await.map_err(|err| {
160                    Error::Api(format!("Failed to read response from {}: {err}", self.url))
161                })?;
162                Err(Error::Api(format!("BadRequest {}", text)))
163            }
164            401 => Err(Error::Unauthorized),
165            404 => Err(Error::NotFound),
166            code => Err(Error::Api(format!(
167                "Invalid HTTP response code {code}: {:?}",
168                response.error_for_status()
169            ))),
170        }
171    }
172
173    pub async fn send_with_retry<T>(self, max_retries: u32) -> crate::Result<T>
174    where
175        T: DeserializeOwned,
176    {
177        let mut attempts = 0;
178        let body = self.body;
179        loop {
180            let mut request = reqwest::Client::builder()
181                .timeout(self.timeout)
182                .build()
183                .unwrap_or_default()
184                .request(self.method.clone(), &self.url)
185                .headers(self.headers.clone());
186
187            if let Some(body) = body.as_ref() {
188                request = request.body(body.clone());
189            }
190
191            let response = request.send().await.map_err(|err| {
192                Error::Api(format!("Failed to send request to {}: {err}", self.url))
193            })?;
194
195            return match response.status().as_u16() {
196                204 => serde_json::from_str("{}").map_err(|err| {
197                    Error::Serialize(format!("Failed to create empty response: {err}"))
198                }),
199                200..=299 => {
200                    let text = response.text().await.map_err(|err| {
201                        Error::Api(format!("Failed to read response from {}: {err}", self.url))
202                    })?;
203                    serde_json::from_str(&text).map_err(|err| {
204                        Error::Serialize(format!("Failed to deserialize response: {err}"))
205                    })
206                }
207                429 if attempts < max_retries => {
208                    if let Some(retry_after) = response.headers().get("retry-after")
209                        && let Ok(seconds) = retry_after.to_str().unwrap_or("0").parse::<u64>()
210                    {
211                        tokio::time::sleep(Duration::from_secs(seconds)).await;
212                        attempts += 1;
213                        continue;
214                    }
215                    Err(Error::Api("Rate limit exceeded".to_string()))
216                }
217                400 => {
218                    let text = response.text().await.map_err(|err| {
219                        Error::Api(format!("Failed to read response from {}: {err}", self.url))
220                    })?;
221                    Err(Error::Api(format!("BadRequest {}", text)))
222                }
223                401 => Err(Error::Unauthorized),
224                404 => Err(Error::NotFound),
225                code => Err(Error::Api(format!(
226                    "Invalid HTTP response code {code}: {:?}",
227                    response.error_for_status()
228                ))),
229            };
230        }
231    }
232}