Skip to main content

dns_update/
http.rs

/*
 * Copyright Stalwart Labs LLC See the COPYING
 * file at the top-level directory of this distribution.
 *
 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
 * https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
 * <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
 * option. This file may not be copied, modified, or distributed
 * except according to those terms.
 */
11
12use std::time::Duration;
13
14use reqwest::{
15    Method,
16    header::{CONTENT_TYPE, HeaderMap, HeaderValue},
17};
18use serde::{Serialize, de::DeserializeOwned};
19
20use crate::Error;
21
22#[derive(Debug, Clone)]
23pub struct HttpClientBuilder {
24    timeout: Duration,
25    headers: HeaderMap<HeaderValue>,
26}
27
28#[derive(Debug, Default, Clone)]
29pub struct HttpClient {
30    method: Method,
31    timeout: Duration,
32    url: String,
33    headers: HeaderMap<HeaderValue>,
34    body: Option<String>,
35}
36
37impl Default for HttpClientBuilder {
38    fn default() -> Self {
39        let mut headers = HeaderMap::new();
40        headers.append(CONTENT_TYPE, HeaderValue::from_static("application/json"));
41
42        Self {
43            timeout: Duration::from_secs(30),
44            headers,
45        }
46    }
47}
48
49impl HttpClientBuilder {
50    pub fn build(&self, method: Method, url: impl Into<String>) -> HttpClient {
51        HttpClient {
52            method,
53            url: url.into(),
54            headers: self.headers.clone(),
55            body: None,
56            timeout: self.timeout,
57        }
58    }
59
60    pub fn get(&self, url: impl Into<String>) -> HttpClient {
61        self.build(Method::GET, url)
62    }
63
64    pub fn post(&self, url: impl Into<String>) -> HttpClient {
65        self.build(Method::POST, url)
66    }
67
68    pub fn put(&self, url: impl Into<String>) -> HttpClient {
69        self.build(Method::PUT, url)
70    }
71
72    pub fn delete(&self, url: impl Into<String>) -> HttpClient {
73        self.build(Method::DELETE, url)
74    }
75
76    pub fn patch(&self, url: impl Into<String>) -> HttpClient {
77        self.build(Method::PATCH, url)
78    }
79
80    pub fn with_header(mut self, name: &'static str, value: impl AsRef<str>) -> Self {
81        if let Ok(value) = HeaderValue::from_str(value.as_ref()) {
82            self.headers.append(name, value);
83        }
84        self
85    }
86
87    pub fn with_timeout(mut self, timeout: Option<Duration>) -> Self {
88        if let Some(timeout) = timeout {
89            self.timeout = timeout;
90        }
91        self
92    }
93}
94
95impl HttpClient {
96    pub fn with_header(mut self, name: &'static str, value: impl AsRef<str>) -> Self {
97        if let Ok(value) = HeaderValue::from_str(value.as_ref()) {
98            self.headers.append(name, value);
99        }
100        self
101    }
102
103    pub fn with_body<B: Serialize>(mut self, body: B) -> crate::Result<Self> {
104        match serde_json::to_string(&body) {
105            Ok(body) => {
106                self.body = Some(body);
107                Ok(self)
108            }
109            Err(err) => Err(Error::Serialize(format!(
110                "Failed to serialize request: {err}"
111            ))),
112        }
113    }
114
115    pub fn with_raw_body(mut self, body: String) -> Self {
116        self.body = Some(body);
117        self
118    }
119
120    pub async fn send<T>(self) -> crate::Result<T>
121    where
122        T: DeserializeOwned,
123    {
124        let response = self.send_raw().await?;
125        serde_json::from_slice::<T>(response.as_bytes())
126            .map_err(|err| Error::Serialize(format!("Failed to deserialize response: {err}")))
127    }
128
129    pub async fn send_raw(self) -> crate::Result<String> {
130        let mut request = reqwest::Client::builder()
131            .timeout(self.timeout)
132            .build()
133            .unwrap_or_default()
134            .request(self.method, &self.url)
135            .headers(self.headers);
136
137        if let Some(body) = self.body {
138            request = request.body(body);
139        }
140
141        let response = request
142            .send()
143            .await
144            .map_err(|err| Error::Api(format!("Failed to send request to {}: {err}", self.url)))?;
145
146        match response.status().as_u16() {
147            204 => Ok(String::new()),
148            200..=299 => response.text().await.map_err(|err| {
149                Error::Api(format!("Failed to read response from {}: {err}", self.url))
150            }),
151            400 => {
152                let text = response.text().await.map_err(|err| {
153                    Error::Api(format!("Failed to read response from {}: {err}", self.url))
154                })?;
155                Err(Error::Api(format!("BadRequest {}", text)))
156            }
157            401 => Err(Error::Unauthorized),
158            404 => Err(Error::NotFound),
159            code => Err(Error::Api(format!(
160                "Invalid HTTP response code {code}: {:?}",
161                response.error_for_status()
162            ))),
163        }
164    }
165
166    pub async fn send_with_retry<T>(self, max_retries: u32) -> crate::Result<T>
167    where
168        T: DeserializeOwned,
169    {
170        let mut attempts = 0;
171        let body = self.body;
172        loop {
173            let mut request = reqwest::Client::builder()
174                .timeout(self.timeout)
175                .build()
176                .unwrap_or_default()
177                .request(self.method.clone(), &self.url)
178                .headers(self.headers.clone());
179
180            if let Some(body) = body.as_ref() {
181                request = request.body(body.clone());
182            }
183
184            let response = request.send().await.map_err(|err| {
185                Error::Api(format!("Failed to send request to {}: {err}", self.url))
186            })?;
187
188            return match response.status().as_u16() {
189                204 => serde_json::from_str("{}").map_err(|err| {
190                    Error::Serialize(format!("Failed to create empty response: {err}"))
191                }),
192                200..=299 => {
193                    let text = response.text().await.map_err(|err| {
194                        Error::Api(format!("Failed to read response from {}: {err}", self.url))
195                    })?;
196                    serde_json::from_str(&text).map_err(|err| {
197                        Error::Serialize(format!("Failed to deserialize response: {err}"))
198                    })
199                }
200                429 if attempts < max_retries => {
201                    if let Some(retry_after) = response.headers().get("retry-after")
202                        && let Ok(seconds) = retry_after.to_str().unwrap_or("0").parse::<u64>()
203                    {
204                        tokio::time::sleep(Duration::from_secs(seconds)).await;
205                        attempts += 1;
206                        continue;
207                    }
208                    Err(Error::Api("Rate limit exceeded".to_string()))
209                }
210                400 => {
211                    let text = response.text().await.map_err(|err| {
212                        Error::Api(format!("Failed to read response from {}: {err}", self.url))
213                    })?;
214                    Err(Error::Api(format!("BadRequest {}", text)))
215                }
216                401 => Err(Error::Unauthorized),
217                404 => Err(Error::NotFound),
218                code => Err(Error::Api(format!(
219                    "Invalid HTTP response code {code}: {:?}",
220                    response.error_for_status()
221                ))),
222            };
223        }
224    }
225}