1use std::path::Path;
45use std::time::Duration;
46use reqwest::{Client, Response, header};
47use futures::StreamExt;
48use tokio::fs::File;
49use tokio::io::AsyncWriteExt;
50use bytes::Bytes;
51use thiserror::Error;
52use serde::{Serialize, de::DeserializeOwned};
53
54#[derive(Error, Debug)]
56pub enum HttpClientError {
57 #[error("Request error: {0}")]
58 RequestError(#[from] reqwest::Error),
59
60 #[error("IO error: {0}")]
61 IoError(#[from] std::io::Error),
62
63 #[error("Invalid URL: {0}")]
64 UrlError(String),
65
66 #[error("Timeout reached")]
67 TimeoutError,
68
69 #[error("Download failed: {0}")]
70 DownloadError(String),
71
72 #[error("Resume not supported by server")]
73 ResumeNotSupported,
74}
75
76#[derive(Debug, Clone)]
78pub struct HttpClient {
79 client: Client,
80 base_url: Option<String>,
81}
82
83impl Default for HttpClient {
84 fn default() -> Self {
85 Self::new()
86 }
87}
88
89impl HttpClient {
90 pub fn new() -> Self {
92 Self {
93 client: Client::new(),
94 base_url: None,
95 }
96 }
97
98 pub fn with_base_url(base_url: impl Into<String>) -> Self {
100 Self {
101 client: Client::new(),
102 base_url: Some(base_url.into()),
103 }
104 }
105
106 pub fn with_client(client: Client) -> Self {
108 Self {
109 client,
110 base_url: None,
111 }
112 }
113
114 pub fn with_timeout(mut self, timeout: Duration) -> Self {
116 self.client = Client::builder()
117 .timeout(timeout)
118 .build()
119 .expect("Failed to build client with timeout");
120 self
121 }
122
123 fn build_url(&self, endpoint: &str) -> Result<String, HttpClientError> {
124 match &self.base_url {
125 Some(base) => Ok(format!("{}{}", base, endpoint)),
126 None => Ok(endpoint.to_string()),
127 }
128 }
129
130 pub async fn get(&self, endpoint: &str) -> Result<Response, HttpClientError> {
132 let url = self.build_url(endpoint)?;
133 self.client.get(&url)
134 .send()
135 .await
136 .map_err(Into::into)
137 }
138
139 pub async fn get_with_query<T: Serialize + ?Sized>(
141 &self,
142 endpoint: &str,
143 query: &T,
144 ) -> Result<Response, HttpClientError> {
145 let url = self.build_url(endpoint)?;
146 self.client.get(&url)
147 .query(query)
148 .send()
149 .await
150 .map_err(Into::into)
151 }
152
153 pub async fn post<T: Serialize + ?Sized>(
155 &self,
156 endpoint: &str,
157 body: &T,
158 ) -> Result<Response, HttpClientError> {
159 let url = self.build_url(endpoint)?;
160 self.client.post(&url)
161 .json(body)
162 .send()
163 .await
164 .map_err(Into::into)
165 }
166
167 pub async fn post_raw(
169 &self,
170 endpoint: &str,
171 body: Vec<u8>,
172 content_type: &str,
173 ) -> Result<Response, HttpClientError> {
174 let url = self.build_url(endpoint)?;
175 self.client.post(&url)
176 .header("Content-Type", content_type)
177 .body(body)
178 .send()
179 .await
180 .map_err(Into::into)
181 }
182
183 pub async fn download_file(
185 &self,
186 url: &str,
187 destination: &Path,
188 mut progress_callback: impl FnMut(u64, u64),
189 ) -> Result<(), HttpClientError> {
190 let response = self.client.get(url)
191 .send()
192 .await?;
193
194 let total_size = response.content_length().unwrap_or(0);
195 let mut downloaded: u64 = 0;
196 let mut file = File::create(destination).await?;
197 let mut stream = response.bytes_stream();
198
199 while let Some(chunk) = stream.next().await {
200 let chunk = chunk?;
201 file.write_all(&chunk).await?;
202 downloaded += chunk.len() as u64;
203 progress_callback(downloaded, total_size);
204 }
205
206 Ok(())
207 }
208
209 pub async fn download_file_with_resume(
211 &self,
212 url: &str,
213 destination: &Path,
214 mut progress_callback: impl FnMut(u64, u64),
215 ) -> Result<(), HttpClientError> {
216 let file_exists = destination.exists();
218 let mut file = tokio::fs::OpenOptions::new()
219 .create(true)
220 .append(true)
221 .open(destination)
222 .await?;
223
224 let mut downloaded_bytes = if file_exists {
225 file.metadata().await?.len()
226 } else {
227 0
228 };
229
230 let mut request = self.client.get(url);
232 if downloaded_bytes > 0 {
233 request = request.header(header::RANGE, format!("bytes={}-", downloaded_bytes));
234 }
235
236 let response = request.send().await?;
238
239 let status = response.status();
241 if !status.is_success() && status != reqwest::StatusCode::PARTIAL_CONTENT {
242 return Err(HttpClientError::DownloadError(format!(
243 "Server returned error status: {}", status
244 )));
245 }
246
247 if downloaded_bytes > 0 && status != reqwest::StatusCode::PARTIAL_CONTENT {
249 return Err(HttpClientError::ResumeNotSupported);
250 }
251
252 let total_size = match status {
254 reqwest::StatusCode::PARTIAL_CONTENT => {
255 response.headers()
257 .get(header::CONTENT_RANGE)
258 .and_then(|h| h.to_str().ok())
259 .and_then(|s| {
260 s.split('/').last().and_then(|s| s.parse::<u64>().ok())
261 })
262 .unwrap_or(downloaded_bytes + response.content_length().unwrap_or(0))
263 }
264 _ => {
265 downloaded_bytes + response.content_length().unwrap_or(0)
266 }
267 };
268
269 let mut stream = response.bytes_stream();
271 while let Some(chunk) = stream.next().await {
272 let chunk = chunk?;
273 file.write_all(&chunk).await?;
274 downloaded_bytes += chunk.len() as u64;
275 progress_callback(downloaded_bytes, total_size);
276 }
277
278 Ok(())
279 }
280
281 pub async fn json<T: DeserializeOwned>(response: Response) -> Result<T, HttpClientError> {
283 response.json::<T>().await.map_err(Into::into)
284 }
285
286 pub async fn text(response: Response) -> Result<String, HttpClientError> {
288 response.text().await.map_err(Into::into)
289 }
290
291 pub async fn bytes(response: Response) -> Result<Bytes, HttpClientError> {
293 response.bytes().await.map_err(Into::into)
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use tokio::fs;

    #[test]
    fn build_url_without_base_returns_endpoint_unchanged() {
        let client = HttpClient::new();
        assert_eq!(
            client.build_url("https://example.com/x").unwrap(),
            "https://example.com/x"
        );
    }

    #[test]
    fn build_url_joins_base_and_endpoint() {
        let client = HttpClient::with_base_url("https://example.com");
        assert_eq!(client.build_url("/get").unwrap(), "https://example.com/get");

        let client = HttpClient::with_base_url("https://example.com/");
        assert_eq!(client.build_url("get").unwrap(), "https://example.com/get");
    }

    // Network tests are opt-in (`cargo test -- --ignored`) so the default suite
    // stays hermetic and does not fail when httpbin.org is unreachable.
    #[tokio::test]
    #[ignore = "requires network access to httpbin.org"]
    async fn test_get_request() {
        let client = HttpClient::new();
        let response = client
            .get("https://httpbin.org/get")
            .await
            .expect("GET request failed");
        assert!(response.status().is_success(), "status: {}", response.status());
    }

    #[tokio::test]
    #[ignore = "requires network access to httpbin.org"]
    async fn test_file_download() {
        // Include the process id so concurrent test runs don't share one file.
        let dest = std::env::temp_dir().join(format!("test_download_{}.bin", std::process::id()));
        let _ = fs::remove_file(&dest).await;

        let client = HttpClient::new();
        let progress_values = Arc::new(Mutex::new(Vec::new()));
        let progress_values_clone = Arc::clone(&progress_values);

        let result = client
            .download_file("https://httpbin.org/bytes/16", &dest, move |downloaded, total| {
                progress_values_clone.lock().unwrap().push((downloaded, total));
            })
            .await;

        // Capture the file size and clean up before asserting, so a failed
        // assertion does not leave the file behind.
        let written = fs::metadata(&dest).await.map(|m| m.len());
        let _ = fs::remove_file(&dest).await;

        assert!(result.is_ok(), "download failed: {:?}", result);
        assert_eq!(written.ok(), Some(16));
        let progress = progress_values.lock().unwrap();
        assert_eq!(progress.last().map(|&(downloaded, _)| downloaded), Some(16));
    }
}