use std::sync::Arc;
use reqwest::Client;
use super::transport::{Transport, TransportError, TransportRequest, TransportResponse};
/// Connect-phase timeout, in seconds, passed to `ClientBuilder::connect_timeout`
/// for both underlying clients.
const CONNECT_TIMEOUT_SECS: u64 = 10;

/// [`Transport`] implementation backed by `reqwest`.
///
/// Holds two pre-built clients (see [`HttpTransport::new`]) and picks one per
/// request. Cloning is cheap: clones share the same clients through an [`Arc`].
#[derive(Clone, Debug)]
pub struct HttpTransport {
    inner: Arc<HttpTransportInner>,
}

/// Shared state behind [`HttpTransport`].
#[derive(Debug)]
struct HttpTransportInner {
    /// Client built without the `danger_*` overrides, i.e. with reqwest's default
    /// certificate and hostname verification.
    strict_client: Client,
    /// Client built with `danger_accept_invalid_certs(true)` and
    /// `danger_accept_invalid_hostnames(true)`; selected only for requests that set
    /// `accept_invalid_certs`.
    trust_all_client: Client,
}
impl HttpTransport {
    /// Creates a transport backed by two pre-configured `reqwest` clients.
    ///
    /// Both clients use a connect timeout of [`CONNECT_TIMEOUT_SECS`] seconds and
    /// never follow redirects (`redirect::Policy::none()`). The "strict" client keeps
    /// reqwest's default TLS verification; the "trust-all" client additionally
    /// accepts invalid certificates and hostnames.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::Other`] if either client fails to build.
    pub fn new() -> Result<Self, TransportError> {
        // Settings common to both clients, so the two configurations cannot drift apart.
        let base_builder = || {
            Client::builder()
                .connect_timeout(std::time::Duration::from_secs(CONNECT_TIMEOUT_SECS))
                .redirect(reqwest::redirect::Policy::none())
        };

        let strict = base_builder()
            .build()
            .map_err(|err| TransportError::Other(format!("build strict client: {err}")))?;

        let trust_all = base_builder()
            .danger_accept_invalid_certs(true)
            .danger_accept_invalid_hostnames(true)
            .build()
            .map_err(|err| TransportError::Other(format!("build trust-all client: {err}")))?;

        let shared = HttpTransportInner {
            strict_client: strict,
            trust_all_client: trust_all,
        };
        Ok(Self {
            inner: Arc::new(shared),
        })
    }
}
impl Transport for HttpTransport {
    /// Sends `request` over HTTP(S) and returns the status code, headers and full body.
    ///
    /// The trust-all client is used when `request.accept_invalid_certs` is set;
    /// otherwise the strict client is used. `request.timeout` is applied via
    /// `RequestBuilder::timeout`, which (per reqwest's docs) also covers reading the
    /// response body.
    ///
    /// All response headers are returned. Values that are not valid visible ASCII are
    /// decoded lossily (invalid UTF-8 becomes U+FFFD) rather than being dropped.
    ///
    /// # Errors
    ///
    /// - [`TransportError::Other`] if `request.method` is not a valid HTTP method.
    /// - [`TransportError::Timeout`] if reqwest reports a timeout (`is_timeout`), either
    ///   while sending or while reading the body.
    /// - [`TransportError::ConnectionFailed`] if reqwest reports a connect error.
    /// - [`TransportError::SendFailed`] for any other error before a response arrives.
    /// - [`TransportError::ReceiveFailed`] for any other error while reading the body.
    async fn send_request(
        &self,
        request: &TransportRequest,
    ) -> Result<TransportResponse, TransportError> {
        let client = if request.accept_invalid_certs {
            &self.inner.trust_all_client
        } else {
            &self.inner.strict_client
        };

        let method = reqwest::Method::from_bytes(request.method.as_bytes())
            .map_err(|e| TransportError::Other(format!("invalid method: {e}")))?;

        let mut builder = client.request(method, &request.url).timeout(request.timeout);
        for (k, v) in &request.headers {
            builder = builder.header(k, v);
        }
        if !request.body.is_empty() {
            builder = builder.body(request.body.clone());
        }

        // `is_timeout` is checked first so a connect timeout is reported as `Timeout`,
        // not `ConnectionFailed`.
        let response = match builder.send().await {
            Ok(r) => r,
            Err(e) if e.is_timeout() => return Err(TransportError::Timeout),
            Err(e) if e.is_connect() => {
                return Err(TransportError::ConnectionFailed(e.to_string()))
            }
            Err(e) => return Err(TransportError::SendFailed(e.to_string())),
        };

        let status_code = response.status().as_u16();

        // `HeaderValue::to_str` rejects any non-visible-ASCII byte (e.g. obs-text),
        // which previously made such headers vanish; decode lossily to keep them.
        let headers: Vec<(String, String)> = response
            .headers()
            .iter()
            .map(|(name, value)| {
                (
                    name.as_str().to_owned(),
                    String::from_utf8_lossy(value.as_bytes()).into_owned(),
                )
            })
            .collect();

        // The per-request timeout also covers the body download, so a timeout here is
        // classified the same way as one during `send`.
        let body = response
            .bytes()
            .await
            .map_err(|e| {
                if e.is_timeout() {
                    TransportError::Timeout
                } else {
                    TransportError::ReceiveFailed(e.to_string())
                }
            })?
            .to_vec();

        Ok(TransportResponse {
            status_code,
            body,
            headers,
        })
    }
}