use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use axum::http::Request;
use tower::{Layer, Service};
use crate::{
cache::scope_cache,
driver::{CacheHandle, Driver},
};
/// Tower [`Layer`] that wraps services in a [`CacheService`].
///
/// All services produced by one `CacheLayer` share a single reference-counted
/// [`CacheHandle`], so cloning the layer is cheap (an `Arc` refcount bump).
#[derive(Clone)]
pub struct CacheLayer {
    /// Shared cache handle, handed out to each wrapped service.
    handle: Arc<CacheHandle>,
}

impl CacheLayer {
    /// Builds a layer from a cache `driver` and a key `prefix`.
    ///
    /// Both arguments are forwarded unchanged to `CacheHandle::new`. The
    /// resulting handle is wrapped in an `Arc` so it can be shared.
    pub fn new(driver: Driver, prefix: impl Into<String>) -> Self {
        let shared = Arc::new(CacheHandle::new(driver, prefix));
        Self { handle: shared }
    }
}
impl<S> Layer<S> for CacheLayer {
type Service = CacheService<S>;
fn layer(&self, inner: S) -> Self::Service {
CacheService {
inner,
handle: self.handle.clone(),
}
}
}
/// Tower [`Service`] produced by [`CacheLayer`].
///
/// Each request is forwarded to the wrapped `inner` service. The resulting
/// future is passed, together with the shared [`CacheHandle`], to
/// `scope_cache`. Presumably this makes the handle available while the
/// request runs — NOTE(review): confirm against `scope_cache`'s definition.
#[derive(Clone)]
pub struct CacheService<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Cache handle shared with the originating layer.
    handle: Arc<CacheHandle>,
}

// `S` is deliberately NOT required to be `Clone` or `Send`:
// - the inner service is called in place (never cloned);
// - `S` itself is never moved into the returned future.
// Only the inner *future* must be `Send + 'static` to fit in the boxed output.
impl<S, ReqBody> Service<Request<ReqBody>> for CacheService<S>
where
    S: Service<Request<ReqBody>> + 'static,
    S::Future: Send + 'static,
    ReqBody: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<S::Response, S::Error>> + Send + 'static>>;

    /// Delegates readiness to the inner service; this layer adds no backpressure.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Forwards `req` to the inner service and wraps its future with `scope_cache`.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        // Call the same instance that `poll_ready` drove to readiness.
        // Calling a fresh clone here would break tower's readiness contract.
        let future = self.inner.call(req);
        Box::pin(scope_cache(Arc::clone(&self.handle), future))
    }
}