use std::{marker::PhantomData, pin::Pin, sync::LazyLock};
use log::debug;
pub use reqwest::{Client, Method, Request, Url};
use serde::de::DeserializeOwned;
use tower::{Layer, Service, ServiceBuilder, layer::util::Stack, util::MapRequestLayer};
/// Errors produced by [`BodyDownloader`] and [`JsonDownloader`].
///
/// `E` is the error type of the wrapped inner service.
#[derive(Debug, thiserror::Error)]
pub enum DownloaderError<E> {
/// The `reqwest` client failed while executing the request or reading the body.
#[error("HTTP client error: {0}")]
HttpClient(reqwest::Error),
/// The response body could not be parsed by `serde_json`.
#[error("Deserialization error: {0}")]
Deserialization(serde_json::Error),
/// The inner service returned an error from `poll_ready`.
#[error("Error polling inner service: {0}")]
InnerPoll(E),
/// The inner service returned an error from `call`.
#[error("Error calling inner service: {0}")]
InnerCall(E),
}
/// A [`Layer`] that wraps an inner `Service<String>` in a [`BodyDownloader`].
///
/// The layer carries no state, so it is `Copy` and can be reused to build any
/// number of services.
#[derive(Debug, Clone, Copy, Default)]
pub struct BodyDownloaderLayer;

impl<S> Layer<S> for BodyDownloaderLayer
where
    S: Service<String>,
{
    type Service = BodyDownloader<S>;

    fn layer(&self, inner: S) -> Self::Service {
        BodyDownloader::new(inner)
    }
}
pub struct BodyDownloader<S>
where
S: Service<String> {
client: Client,
inner: S
}
/// Default HTTP client handed to downloaders built with `new`.
///
/// This must be a `static`, not a `const`: a `const` is inlined at every use
/// site, so each access would run the initializer again and build a fresh
/// `Client` with its own connection pool instead of sharing one.
static HTTP_CLIENT: LazyLock<Client> = LazyLock::new(Client::new);
impl<S> BodyDownloader<S>
where
S: Service<String> {
pub fn new(inner: S) -> Self {
Self {
client: HTTP_CLIENT.clone(),
inner,
}
}
}
impl<S> Service<Request> for BodyDownloader<S>
where
S: Service<String> + Send + Clone + 'static,
<S as Service<String>>::Future: Send,
<S as Service<String>>::Error: Send {
type Response = ();
type Error = DownloaderError<S::Error>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
.map_err(|e| DownloaderError::InnerPoll(e))
}
fn call(&mut self, request: Request) -> Self::Future {
let client = self.client.clone();
let clone = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, clone);
Box::pin(async move {
debug!("BodyDownloader received {} request for {}", request.method(), request.url());
let text = client.execute(request).await
.map_err(|e| DownloaderError::HttpClient(e))?
.text().await
.map_err(|e| DownloaderError::HttpClient(e))?;
inner.call(text).await
.map_err(|e| DownloaderError::InnerCall(e))?;
Ok(())
})
}
}
/// A [`Layer`] that wraps an inner `Service<T>` in a [`JsonDownloader`].
///
/// `T` is only tracked through a `PhantomData<fn() -> T>` marker, so the layer
/// is `Send`, `Sync`, `Clone`, `Copy` and `Default` regardless of what `T` is.
pub struct JsonDownloaderLayer<T> {
    _t: PhantomData<fn() -> T>,
}

impl<T> JsonDownloaderLayer<T> {
    /// Creates a new layer.
    pub fn new() -> Self {
        Self { _t: PhantomData }
    }
}

// Manual impls: `#[derive]` would add `T: Default` / `T: Clone` / `T: Debug`
// bounds even though no `T` is ever stored.
impl<T> Default for JsonDownloaderLayer<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> Clone for JsonDownloaderLayer<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for JsonDownloaderLayer<T> {}

impl<T> std::fmt::Debug for JsonDownloaderLayer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JsonDownloaderLayer").finish()
    }
}

impl<S, T> Layer<S> for JsonDownloaderLayer<T>
where
    S: Service<T>,
    T: DeserializeOwned,
{
    type Service = JsonDownloader<S, T>;

    fn layer(&self, inner: S) -> Self::Service {
        JsonDownloader::new(inner)
    }
}
pub struct JsonDownloader<S, T>
where
S: Service<T>,
T: DeserializeOwned {
client: Client,
inner: S,
_t: PhantomData<T>,
}
impl<S, T> JsonDownloader<S, T>
where
S: Service<T>,
T: DeserializeOwned {
pub fn new(inner: S) -> Self {
Self {
client: HTTP_CLIENT.clone(),
inner,
_t: PhantomData::default(),
}
}
}
impl<S, T> Service<Request> for JsonDownloader<S, T>
where
S: Service<T> + Send + Clone + 'static,
<S as Service<T>>::Future: Send,
<S as Service<T>>::Error: Send,
T: DeserializeOwned {
type Response = ();
type Error = DownloaderError<S::Error>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
.map_err(|e| DownloaderError::InnerPoll(e))
}
fn call(&mut self, request: Request) -> Self::Future {
let client = self.client.clone();
let clone = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, clone);
Box::pin(async move {
debug!("JsonDownloader received {} request for {}", request.method(), request.url());
let text = client.execute(request).await
.map_err(|e| DownloaderError::HttpClient(e))?
.text().await
.map_err(|e| DownloaderError::HttpClient(e))?;
let obj = serde_json::from_str(&text)
.map_err(|e| DownloaderError::Deserialization(e))?;
inner.call(obj).await
.map_err(|e| DownloaderError::InnerCall(e))?;
Ok(())
})
}
}
/// Builds a `GET` [`Request`] targeting `url`.
///
/// Takes `String` by value to match the `fn(String) -> ...` pointer type it is
/// used as in `ServiceBuilderReqwestExt`.
///
/// # Errors
/// Returns the [`url::ParseError`] produced by [`Url::parse`] when `url` cannot
/// be parsed.
pub fn string_to_get_reqwest(url: String) -> Result<reqwest::Request, url::ParseError> {
    Url::parse(&url).map(|parsed| Request::new(Method::GET, parsed))
}
/// Extension methods for [`ServiceBuilder`] that turn URL strings into
/// `reqwest` requests.
pub trait ServiceBuilderReqwestExt {
    /// The builder type returned by `map_string_to_reqwest_get`.
    type Output;

    /// Adds a layer that maps each incoming `String` through
    /// [`string_to_get_reqwest`].
    ///
    /// Note that the next service in the stack receives the mapping's full
    /// output, `Result<reqwest::Request, url::ParseError>`, not a bare `Request`.
    #[must_use]
    fn map_string_to_reqwest_get(self) -> Self::Output;
}

impl<L> ServiceBuilderReqwestExt for ServiceBuilder<L> {
    type Output = ServiceBuilder<
        Stack<MapRequestLayer<fn(String) -> Result<reqwest::Request, url::ParseError>>, L>,
    >;

    fn map_string_to_reqwest_get(self) -> Self::Output {
        // Coerce the fn item to the fn-pointer type named in `Output` explicitly,
        // rather than relying on inference through `map_request`'s generic `F`.
        let mapper: fn(String) -> Result<reqwest::Request, url::ParseError> =
            string_to_get_reqwest;
        self.map_request(mapper)
    }
}