1use std::future::Future;
2use std::pin::Pin;
3
4use bytes::Bytes;
5use futures_util::{Stream, StreamExt};
6
7use crate::error::Result;
8
9pub type BoxStream<'a, T> = Pin<Box<dyn Stream<Item = T> + Send + 'a>>;
14
15pub trait HttpClient: Send + Sync {
26 type Error: std::error::Error + Send + 'static;
28
29 fn stream(
46 &self,
47 url: &str,
48 headers: &[(String, String)],
49 ) -> impl Future<
50 Output = std::result::Result<
51 BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
52 Self::Error,
53 >,
54 > + Send;
55
56 fn head(
70 &self,
71 url: &str,
72 ) -> impl Future<Output = std::result::Result<Option<u64>, Self::Error>> + Send;
73}
74
#[cfg(feature = "reqwest")]
mod reqwest_impl {
    use super::*;

    /// [`HttpClient`] implementation backed by [`reqwest::Client`].
    ///
    /// Per the reqwest documentation, `reqwest::Client` already shares its
    /// connection pool through an internal `Arc`. Cloning a `ReqwestClient` is
    /// therefore cheap, and clones reuse the same pool.
    #[derive(Debug, Clone)]
    pub struct ReqwestClient {
        client: reqwest::Client,
    }

    impl ReqwestClient {
        /// Creates a client using reqwest's default configuration.
        ///
        /// # Errors
        ///
        /// Currently never returns `Err`.
        ///
        /// # Panics
        ///
        /// reqwest documents that `reqwest::Client::new` panics if a TLS backend
        /// cannot be initialized or the resolver cannot load the system
        /// configuration. To control construction yourself, build a
        /// `reqwest::Client` and convert it with `ReqwestClient::from`.
        pub fn new() -> Result<Self> {
            Ok(Self::from(reqwest::Client::new()))
        }
    }

    impl From<reqwest::Client> for ReqwestClient {
        /// Wraps an already-configured `reqwest::Client` (custom timeouts,
        /// proxies, default headers, ...).
        fn from(client: reqwest::Client) -> Self {
            Self { client }
        }
    }

    impl HttpClient for ReqwestClient {
        type Error = reqwest::Error;

        /// Sends a `GET` request to `url` with `headers` and streams the body.
        ///
        /// # Errors
        ///
        /// Fails on transport errors and on any non-success (4xx/5xx) status, so
        /// an error page is never mistaken for the requested content.
        async fn stream(
            &self,
            url: &str,
            headers: &[(String, String)],
        ) -> std::result::Result<
            BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
            Self::Error,
        > {
            let mut request = self.client.get(url);
            for (key, value) in headers {
                // Invalid header names/values are reported by `send` below.
                request = request.header(key, value);
            }

            // `send` only fails on transport errors. Without `error_for_status`,
            // a 404/500 response body would be streamed back as if it were data.
            let response = request.send().await?.error_for_status()?;

            // `StreamExt::boxed` yields `Pin<Box<dyn Stream + Send>>`, i.e. `BoxStream`.
            Ok(response.bytes_stream().boxed())
        }

        /// Sends a `HEAD` request to `url` and returns its advertised `Content-Length`.
        ///
        /// Returns `Ok(None)` when the header is missing or malformed, or when the
        /// server answers with a non-success status.
        ///
        /// # Errors
        ///
        /// Fails only when the request cannot be sent or no response is received.
        async fn head(&self, url: &str) -> std::result::Result<Option<u64>, Self::Error> {
            let response = self.client.head(url).send().await?;

            // A non-success response (404, or 405 from servers that do not support
            // HEAD, ...) describes an error page rather than the resource. Report
            // the size as unknown instead of returning the error page's length.
            if !response.status().is_success() {
                return Ok(None);
            }

            let content_length = response
                .headers()
                .get(reqwest::header::CONTENT_LENGTH)
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok());

            Ok(content_length)
        }
    }
}
128
129#[cfg(feature = "reqwest")]
130pub use reqwest_impl::ReqwestClient;
131
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream::{self, StreamExt};

    /// Offline [`HttpClient`] whose responses are fixed at construction time.
    struct MockHttpClient {
        /// When set, both `stream` and `head` fail with a [`MockError`].
        should_fail: bool,
        /// Length reported by a successful `head` call.
        content_length: Option<u64>,
    }

    impl MockHttpClient {
        /// Succeeding client that reports a length of 1024.
        fn new() -> Self {
            Self::with_content_length(1024)
        }

        /// Client whose every request fails.
        fn with_error() -> Self {
            Self {
                should_fail: true,
                content_length: None,
            }
        }

        /// Succeeding client that reports `length`.
        fn with_content_length(length: u64) -> Self {
            Self {
                should_fail: false,
                content_length: Some(length),
            }
        }

        /// Succeeding client that reports no length.
        fn without_content_length() -> Self {
            Self {
                should_fail: false,
                content_length: None,
            }
        }
    }

    /// Error type produced by [`MockHttpClient`]; displays its message verbatim.
    #[derive(Debug)]
    struct MockError(String);

    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(&self.0)
        }
    }

    impl std::error::Error for MockError {}

    impl HttpClient for MockHttpClient {
        type Error = MockError;

        /// Yields the single chunk `"test data"`, or fails if `should_fail` is set.
        async fn stream(
            &self,
            _url: &str,
            _headers: &[(String, String)],
        ) -> std::result::Result<
            BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
            Self::Error,
        > {
            if self.should_fail {
                return Err(MockError("Stream failed".to_string()));
            }
            Ok(stream::iter([Bytes::from("test data")]).map(Ok).boxed())
        }

        /// Returns the configured length, or fails if `should_fail` is set.
        async fn head(&self, _url: &str) -> std::result::Result<Option<u64>, Self::Error> {
            if self.should_fail {
                Err(MockError("HEAD request failed".to_string()))
            } else {
                Ok(self.content_length)
            }
        }
    }

    #[tokio::test]
    async fn test_mock_http_client_stream_success() {
        let client = MockHttpClient::new();
        let Ok(mut stream) = client.stream("http://example.com", &[]).await else {
            panic!("Expected stream");
        };

        let chunk = stream
            .next()
            .await
            .expect("Expected data")
            .expect("Expected chunk to be Ok");
        assert_eq!(chunk, Bytes::from("test data"));
        // The mock serves exactly one chunk, so the stream must now be exhausted.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_mock_http_client_stream_error() {
        let client = MockHttpClient::with_error();
        let Err(e) = client.stream("http://example.com", &[]).await else {
            panic!("Expected error");
        };
        assert_eq!(e.to_string(), "Stream failed");
    }

    #[tokio::test]
    async fn test_mock_http_client_head_with_content_length() {
        let client = MockHttpClient::with_content_length(2048);
        let length = client
            .head("http://example.com")
            .await
            .expect("HEAD should succeed");
        assert_eq!(length, Some(2048));
    }

    #[tokio::test]
    async fn test_mock_http_client_head_without_content_length() {
        let client = MockHttpClient::without_content_length();
        let length = client
            .head("http://example.com")
            .await
            .expect("HEAD should succeed");
        assert_eq!(length, None);
    }

    #[tokio::test]
    async fn test_mock_http_client_head_error() {
        let client = MockHttpClient::with_error();
        let err = client.head("http://example.com").await.unwrap_err();
        assert_eq!(err.to_string(), "HEAD request failed");
    }

    #[test]
    fn test_box_stream_type_alias() {
        // Compile-time check: an arbitrary `Send + 'static` stream fits the alias.
        let _stream: BoxStream<'static, std::result::Result<Bytes, MockError>> =
            Box::pin(stream::empty());
    }

    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn test_reqwest_client_creation() {
        let _client: ReqwestClient =
            ReqwestClient::new().expect("default reqwest client should build");
    }
}