use crate::error::{AbortError, FetchError, NetworkError, Result};
use crate::{Headers, ReadableStream, Request, RequestInit, Response};
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use std::sync::OnceLock;
/// Concrete type of the process-wide HTTP(S) client: hyper-util's legacy
/// client over a TLS-capable connector, sending fully buffered request bodies.
type HttpsClient = Client<
    hyper_tls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
    http_body_util::Full<bytes::Bytes>,
>;

/// Lazily initialised client, handed out by [`get_client`].
static CLIENT: OnceLock<HttpsClient> = OnceLock::new();

/// Returns the shared client, building it on first use.
///
/// Initialisation runs at most once (guarded by [`OnceLock`]); every later
/// call returns a reference to the same instance.
fn get_client() -> &'static HttpsClient {
    CLIENT.get_or_init(build_client)
}

/// Constructs a new client driven by a Tokio executor over an HTTPS connector.
fn build_client() -> HttpsClient {
    let connector = hyper_tls::HttpsConnector::new();
    Client::builder(TokioExecutor::new()).build(connector)
}
pub async fn fetch(input: &str, init: Option<RequestInit>) -> Result<Response> {
let mut request = Request::new(input, init)?;
if let Some(signal) = request.signal() {
if signal.aborted() {
return Err(FetchError::Abort(AbortError::new(
"The operation was aborted",
)));
}
}
let client = get_client();
let method = http::Method::from_bytes(request.method().as_bytes())
.map_err(|_| FetchError::Network(NetworkError::new("Invalid method")))?;
let mut http_request = http::Request::builder()
.method(method)
.uri(request.get_url().as_str());
let header_map = request.headers().to_http_headers()?;
for (name, value) in header_map {
if let Some(header_name) = name {
http_request = http_request.header(header_name, value);
}
}
let body = match request.take_body() {
Some(body) => {
let bytes = body.to_bytes().await?;
http_body_util::Full::new(bytes)
}
None => http_body_util::Full::new(bytes::Bytes::new()),
};
let http_request = http_request.body(body)?;
let http_response = client.request(http_request).await?;
let (parts, incoming) = http_response.into_parts();
let headers = Headers::from_http_headers(&parts.headers);
let status_text = parts.status.canonical_reason().unwrap_or("").to_string();
let mut response = Response::from_parts(
parts.status.as_u16(),
status_text,
headers,
request.get_url().to_string(),
false, );
let body_bytes = http_body_util::BodyExt::collect(incoming)
.await
.map_err(|e| FetchError::Network(NetworkError::new(&e.to_string())))?
.to_bytes();
if !body_bytes.is_empty() {
response.set_body(ReadableStream::from_bytes(body_bytes));
}
Ok(response)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The shared client can be built outside an async runtime, and repeated
    /// calls hand back the same cached instance rather than building a new one.
    #[test]
    fn test_client_initialization() {
        let first = get_client();
        let second = get_client();
        assert!(std::ptr::eq(first, second));
    }

    /// An input that is not a URL must make `fetch` return an error.
    #[tokio::test]
    async fn test_fetch_invalid_url() {
        let result = fetch("not-a-url", None).await;
        assert!(result.is_err());
    }
}