// oxidite_middleware/cache.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3use std::time::{Duration, SystemTime};
4use http::{Request, Response, Method};
5use http_body_util::Full;
6use bytes::Bytes;
7use tower::{Layer, Service};
8use std::task::{Context, Poll};
9use std::future::Future;
10use std::pin::Pin;
11
/// Configuration for the caching middleware.
///
/// Build it directly, via [`Default`], or through [`CacheLayerBuilder`].
///
/// Note: [`CacheMiddleware`] currently forwards every request to its inner service
/// unchanged, so none of these settings are applied yet.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheConfig {
    /// Maximum cache size (in number of entries).
    pub max_entries: usize,
    /// Default TTL for cached responses.
    pub default_ttl: Duration,
    /// Whether to cache responses for GET requests.
    pub cache_get: bool,
    /// Whether to cache responses for POST requests.
    pub cache_post: bool,
}

impl Default for CacheConfig {
    /// Returns the default configuration: 1000 entries, a 5-minute TTL,
    /// GET caching enabled and POST caching disabled.
    fn default() -> Self {
        Self {
            max_entries: 1000,
            default_ttl: Duration::from_secs(300), // 5 minutes
            cache_get: true,
            cache_post: false,
        }
    }
}
35
36/// Cache layer that wraps services with caching functionality
37#[derive(Clone)]
38pub struct CacheLayer {
39    config: CacheConfig,
40}
41
42impl CacheLayer {
43    pub fn new(config: CacheConfig) -> Self {
44        Self {
45            config,
46        }
47    }
48
49    pub fn builder() -> CacheLayerBuilder {
50        CacheLayerBuilder::new()
51    }
52}
53
54impl<S> Layer<S> for CacheLayer {
55    type Service = CacheMiddleware<S>;
56
57    fn layer(&self, inner: S) -> Self::Service {
58        CacheMiddleware {
59            inner,
60            config: self.config.clone(),
61        }
62    }
63}
64
65/// Cache middleware service
66pub struct CacheMiddleware<S> {
67    inner: S,
68    config: CacheConfig,
69}
70
71impl<S> CacheMiddleware<S> {
72    fn should_cache_method(&self, method: &Method) -> bool {
73        match *method {
74            Method::GET => self.config.cache_get,
75            Method::POST => self.config.cache_post,
76            _ => false,
77        }
78    }
79}
80
81impl<S, ReqBody> Service<Request<ReqBody>> for CacheMiddleware<S>
82where
83    S: Service<Request<ReqBody>> + Clone,
84    S::Error: std::error::Error + Send + Sync + 'static,
85{
86    type Response = S::Response;
87    type Error = S::Error;
88    type Future = S::Future;
89
90    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
91        self.inner.poll_ready(cx)
92    }
93
94    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
95        // Just pass through to the inner service for now
96        // Proper caching implementation would require more complex async handling
97        self.inner.call(req)
98    }
99}
100
101/// Builder for CacheLayer
102pub struct CacheLayerBuilder {
103    config: CacheConfig,
104}
105
106impl CacheLayerBuilder {
107    pub fn new() -> Self {
108        Self {
109            config: CacheConfig::default(),
110        }
111    }
112
113    pub fn max_entries(mut self, max: usize) -> Self {
114        self.config.max_entries = max;
115        self
116    }
117
118    pub fn default_ttl(mut self, ttl: Duration) -> Self {
119        self.config.default_ttl = ttl;
120        self
121    }
122
123    pub fn cache_get(mut self, enable: bool) -> Self {
124        self.config.cache_get = enable;
125        self
126    }
127
128    pub fn cache_post(mut self, enable: bool) -> Self {
129        self.config.cache_post = enable;
130        self
131    }
132
133    pub fn build(self) -> CacheLayer {
134        CacheLayer::new(self.config)
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Request, StatusCode};
    use tower::{Service, ServiceExt};

    /// The middleware currently passes requests straight through, so repeated
    /// requests must yield the inner service's response (status and body) unchanged.
    #[tokio::test]
    async fn test_cache_middleware() {
        let config = CacheConfig {
            max_entries: 100,
            default_ttl: Duration::from_secs(3600), // 1 hour
            cache_get: true,
            cache_post: false,
        };

        let layer = CacheLayer::new(config);

        // Simple service that always returns the same response. Its error type is
        // `Infallible`, which implements `std::error::Error + Send + Sync + 'static`.
        // `Box<dyn Error + Send + Sync>` does not implement `Error`, so it cannot
        // satisfy an `S::Error: Error` bound on the middleware.
        let svc = tower::service_fn(|_req: Request<String>| async {
            Ok::<_, std::convert::Infallible>(
                Response::builder()
                    .status(StatusCode::OK)
                    .body("Hello, world!".to_string())
                    .unwrap(),
            )
        });

        let mut cached_svc = layer.layer(svc);

        // First request
        let req1 = Request::get("/test").body(String::new()).unwrap();
        let resp1 = cached_svc.ready().await.unwrap().call(req1).await.unwrap();
        assert_eq!(resp1.status(), StatusCode::OK);
        assert_eq!(resp1.body(), "Hello, world!");

        // Second request to the same endpoint must produce the same response.
        let req2 = Request::get("/test").body(String::new()).unwrap();
        let resp2 = cached_svc.ready().await.unwrap().call(req2).await.unwrap();
        assert_eq!(resp2.status(), StatusCode::OK);
        assert_eq!(resp2.body(), "Hello, world!");
    }
}