oxidite_middleware/
cache.rs1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3use std::time::{Duration, SystemTime};
4use http::{Request, Response, Method};
5use http_body_util::Full;
6use bytes::Bytes;
7use tower::{Layer, Service};
8use std::task::{Context, Poll};
9use std::future::Future;
10use std::pin::Pin;
11
/// Configuration shared by [`CacheLayer`] and the [`CacheMiddleware`] it produces.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheConfig {
    /// Intended upper bound on the number of cached entries.
    ///
    /// NOTE(review): nothing in this file reads this value yet.
    pub max_entries: usize,
    /// Intended lifetime of a cached entry.
    ///
    /// NOTE(review): nothing in this file reads this value yet.
    pub default_ttl: Duration,
    /// Whether responses to `GET` requests are eligible for caching.
    pub cache_get: bool,
    /// Whether responses to `POST` requests are eligible for caching.
    pub cache_post: bool,
}

impl Default for CacheConfig {
    /// 1000 entries, a 300-second TTL, `GET` caching on, `POST` caching off.
    fn default() -> Self {
        Self {
            max_entries: 1000,
            default_ttl: Duration::from_secs(300),
            cache_get: true,
            cache_post: false,
        }
    }
}

36#[derive(Clone)]
38pub struct CacheLayer {
39 config: CacheConfig,
40}
41
42impl CacheLayer {
43 pub fn new(config: CacheConfig) -> Self {
44 Self {
45 config,
46 }
47 }
48
49 pub fn builder() -> CacheLayerBuilder {
50 CacheLayerBuilder::new()
51 }
52}
53
54impl<S> Layer<S> for CacheLayer {
55 type Service = CacheMiddleware<S>;
56
57 fn layer(&self, inner: S) -> Self::Service {
58 CacheMiddleware {
59 inner,
60 config: self.config.clone(),
61 }
62 }
63}
64
65pub struct CacheMiddleware<S> {
67 inner: S,
68 config: CacheConfig,
69}
70
71impl<S> CacheMiddleware<S> {
72 fn should_cache_method(&self, method: &Method) -> bool {
73 match *method {
74 Method::GET => self.config.cache_get,
75 Method::POST => self.config.cache_post,
76 _ => false,
77 }
78 }
79}
80
81impl<S, ReqBody> Service<Request<ReqBody>> for CacheMiddleware<S>
82where
83 S: Service<Request<ReqBody>> + Clone,
84 S::Error: std::error::Error + Send + Sync + 'static,
85{
86 type Response = S::Response;
87 type Error = S::Error;
88 type Future = S::Future;
89
90 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
91 self.inner.poll_ready(cx)
92 }
93
94 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
95 self.inner.call(req)
98 }
99}
100
101pub struct CacheLayerBuilder {
103 config: CacheConfig,
104}
105
106impl CacheLayerBuilder {
107 pub fn new() -> Self {
108 Self {
109 config: CacheConfig::default(),
110 }
111 }
112
113 pub fn max_entries(mut self, max: usize) -> Self {
114 self.config.max_entries = max;
115 self
116 }
117
118 pub fn default_ttl(mut self, ttl: Duration) -> Self {
119 self.config.default_ttl = ttl;
120 self
121 }
122
123 pub fn cache_get(mut self, enable: bool) -> Self {
124 self.config.cache_get = enable;
125 self
126 }
127
128 pub fn cache_post(mut self, enable: bool) -> Self {
129 self.config.cache_post = enable;
130 self
131 }
132
133 pub fn build(self) -> CacheLayer {
134 CacheLayer::new(self.config)
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Request, StatusCode};
    use tower::{Service, ServiceExt};

    #[tokio::test]
    async fn test_cache_middleware() {
        let config = CacheConfig {
            max_entries: 100,
            default_ttl: Duration::from_secs(3600),
            cache_get: true,
            cache_post: false,
        };

        let layer = CacheLayer::new(config);

        // Inner service that always answers 200 with a fixed body.
        let svc = tower::service_fn(|_req: Request<String>| async {
            Ok::<_, Box<dyn std::error::Error + Send + Sync>>(
                Response::builder()
                    .status(StatusCode::OK)
                    .body("Hello, world!".to_string())
                    .unwrap(),
            )
        });

        let mut cached_svc = layer.layer(svc);

        // Two identical GET requests; both must yield the inner service's
        // response unchanged (status and body).
        let req1 = Request::get("/test").body(String::new()).unwrap();
        let resp1 = cached_svc.ready().await.unwrap().call(req1).await.unwrap();
        assert_eq!(resp1.status(), StatusCode::OK);
        assert_eq!(resp1.body().as_str(), "Hello, world!");

        let req2 = Request::get("/test").body(String::new()).unwrap();
        let resp2 = cached_svc.ready().await.unwrap().call(req2).await.unwrap();
        assert_eq!(resp2.status(), StatusCode::OK);
        assert_eq!(resp2.body(), resp1.body());
    }
}