use std::future::Future;
use std::pin::Pin;
use bytes::Bytes;
use futures_util::{Stream, StreamExt};
use crate::error::Result;
/// A pinned, heap-allocated, type-erased [`Stream`] that yields items of type `T`.
///
/// The trait object is `Send` and may borrow data for the lifetime `'a`.
pub type BoxStream<'a, T> = Pin<Box<dyn Stream<Item = T> + Send + 'a>>;
/// Minimal asynchronous HTTP client abstraction.
///
/// Implementors must be `Send + Sync`, and every future returned by the trait
/// methods must be `Send`.
pub trait HttpClient: Send + Sync {
    /// Error produced both when a request fails and when reading a body chunk fails.
    type Error: std::error::Error + Send + 'static;

    /// Requests `url`, attaching each `(name, value)` pair in `headers`, and
    /// resolves to the response body as a stream of [`Bytes`] chunks.
    ///
    /// The returned stream is `'static`: it does not borrow from `self`, `url`
    /// or `headers`.
    ///
    /// # Errors
    ///
    /// The outer `Result` reports a failure to obtain the stream at all; each
    /// stream item is itself a `Result` that reports a failure for that chunk.
    fn stream(
        &self,
        url: &str,
        headers: &[(String, String)],
    ) -> impl Future<
        Output = std::result::Result<
            BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
            Self::Error,
        >,
    > + Send;

    /// Probes `url` and resolves to an optional `u64` describing it.
    ///
    /// The implementations in this file use that value for the content length
    /// in bytes, with `None` meaning the length is not known.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the probe itself fails.
    fn head(
        &self,
        url: &str,
    ) -> impl Future<Output = std::result::Result<Option<u64>, Self::Error>> + Send;
}
/// [`HttpClient`] implementation backed by the `reqwest` crate.
///
/// Compiled only when the `reqwest` feature is enabled.
#[cfg(feature = "reqwest")]
mod reqwest_impl {
    use super::*;

    /// [`HttpClient`] that delegates every request to a `reqwest::Client`.
    #[derive(Debug, Clone)]
    pub struct ReqwestClient {
        client: reqwest::Client,
    }

    impl ReqwestClient {
        /// Creates a client wrapping a default-configured `reqwest::Client`.
        ///
        /// # Errors
        ///
        /// This implementation always returns `Ok`; the `Result` return type
        /// is kept so existing callers are unaffected.
        ///
        /// # Panics
        ///
        /// NOTE(review): per the reqwest documentation, `reqwest::Client::new`
        /// panics if the TLS backend cannot be initialized.
        pub fn new() -> Result<Self> {
            Ok(Self {
                client: reqwest::Client::new(),
            })
        }
    }

    impl HttpClient for ReqwestClient {
        type Error = reqwest::Error;

        /// Sends a GET request to `url` with the given headers and returns the
        /// response body as a stream of chunks.
        ///
        /// # Errors
        ///
        /// Fails if the request cannot be sent, or if the server answers with
        /// a 4xx/5xx status. The trait gives callers no access to the status
        /// code, so an error page must not be handed back as payload data.
        async fn stream(
            &self,
            url: &str,
            headers: &[(String, String)],
        ) -> std::result::Result<
            BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
            Self::Error,
        > {
            let mut request = self.client.get(url);
            for (key, value) in headers {
                request = request.header(key, value);
            }

            // Reject HTTP error statuses before exposing the body.
            let response = request.send().await?.error_for_status()?;
            Ok(Box::pin(response.bytes_stream()))
        }

        /// Sends a HEAD request to `url` and returns the value of its
        /// `Content-Length` header, if present and parseable as `u64`.
        ///
        /// Returns `Ok(None)` when the status is not 2xx: the header would
        /// then describe the error response rather than the resource, and the
        /// probe is treated as advisory so that servers which reject HEAD do
        /// not turn into hard failures.
        ///
        /// # Errors
        ///
        /// Fails only if the request itself cannot be sent.
        async fn head(&self, url: &str) -> std::result::Result<Option<u64>, Self::Error> {
            let response = self.client.head(url).send().await?;
            if !response.status().is_success() {
                return Ok(None);
            }

            // NOTE(review): the header is parsed directly instead of calling
            // `Response::content_length()`; presumably because that accessor
            // reflects the (empty) HEAD body - confirm.
            let content_length = response
                .headers()
                .get(reqwest::header::CONTENT_LENGTH)
                .and_then(|value| value.to_str().ok())
                .and_then(|text| text.parse::<u64>().ok());
            Ok(content_length)
        }
    }
}
#[cfg(feature = "reqwest")]
pub use reqwest_impl::ReqwestClient;
/// Unit tests that exercise the [`HttpClient`] contract through an in-memory mock.
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream::{self, StreamExt};

    /// URL handed to the mock; the mock ignores its value.
    const URL: &str = "http://example.com";

    /// In-memory [`HttpClient`] whose behavior is fixed at construction time.
    struct MockHttpClient {
        /// When `true`, both `stream` and `head` return a [`MockError`].
        should_fail: bool,
        /// Value returned by a successful `head` call.
        content_length: Option<u64>,
    }

    impl MockHttpClient {
        /// Succeeding client that reports a content length of 1024.
        fn new() -> Self {
            Self::with_content_length(1024)
        }

        /// Client whose every request fails.
        fn with_error() -> Self {
            Self {
                should_fail: true,
                content_length: None,
            }
        }

        /// Succeeding client that reports `length` from `head`.
        fn with_content_length(length: u64) -> Self {
            Self {
                should_fail: false,
                content_length: Some(length),
            }
        }

        /// Succeeding client that reports no content length from `head`.
        fn without_content_length() -> Self {
            Self {
                should_fail: false,
                content_length: None,
            }
        }
    }

    /// Error type of the mock; displays as the wrapped message.
    #[derive(Debug)]
    struct MockError(String);

    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(&self.0)
        }
    }

    impl std::error::Error for MockError {}

    impl HttpClient for MockHttpClient {
        type Error = MockError;

        /// Yields the single chunk `"test data"`, or fails with `"Stream failed"`.
        async fn stream(
            &self,
            _url: &str,
            _headers: &[(String, String)],
        ) -> std::result::Result<
            BoxStream<'static, std::result::Result<Bytes, Self::Error>>,
            Self::Error,
        > {
            if self.should_fail {
                return Err(MockError("Stream failed".to_string()));
            }

            let chunks = stream::iter(vec![Bytes::from("test data")]).map(Ok);
            Ok(Box::pin(chunks) as BoxStream<'static, _>)
        }

        /// Returns the configured content length, or fails with `"HEAD request failed"`.
        async fn head(&self, _url: &str) -> std::result::Result<Option<u64>, Self::Error> {
            if self.should_fail {
                return Err(MockError("HEAD request failed".to_string()));
            }

            Ok(self.content_length)
        }
    }

    #[tokio::test]
    async fn test_mock_http_client_stream_success() {
        let client = MockHttpClient::new();
        let mut stream = client
            .stream(URL, &[])
            .await
            .expect("mock stream should open");

        match stream.next().await {
            Some(Ok(bytes)) => assert_eq!(bytes, Bytes::from("test data")),
            other => panic!("Expected data, got {other:?}"),
        }
        // The mock produces exactly one chunk, so the stream must now be finished.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_mock_http_client_stream_error() {
        let client = MockHttpClient::with_error();

        // `unwrap_err` is unavailable here because the `Ok` payload (a boxed
        // stream) does not implement `Debug`.
        let Err(err) = client.stream(URL, &[]).await else {
            panic!("Expected error");
        };
        assert_eq!(err.to_string(), "Stream failed");
    }

    #[tokio::test]
    async fn test_mock_http_client_head_with_content_length() {
        let client = MockHttpClient::with_content_length(2048);
        let length = client.head(URL).await.expect("mock HEAD should succeed");
        assert_eq!(length, Some(2048));
    }

    #[tokio::test]
    async fn test_mock_http_client_head_without_content_length() {
        let client = MockHttpClient::without_content_length();
        let length = client.head(URL).await.expect("mock HEAD should succeed");
        assert_eq!(length, None);
    }

    #[tokio::test]
    async fn test_mock_http_client_head_error() {
        let client = MockHttpClient::with_error();

        let Err(err) = client.head(URL).await else {
            panic!("Expected error");
        };
        assert_eq!(err.to_string(), "HEAD request failed");
    }

    /// Compile-time check that an empty stream coerces to [`BoxStream`].
    #[test]
    fn test_box_stream_type_alias() {
        let _stream: BoxStream<'static, std::result::Result<Bytes, MockError>> =
            Box::pin(stream::empty());
    }

    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn test_reqwest_client_creation() {
        assert!(ReqwestClient::new().is_ok());
    }
}