Skip to main content

nestforge_cache/
lib.rs

1use axum::{
2    body::{to_bytes, Body},
3    extract::Request,
4    http::{header::CONTENT_TYPE, HeaderValue, Method, StatusCode},
5    response::{IntoResponse, Response},
6};
7use nestforge_core::{framework_log_event, Container, Interceptor, NextFn, NextFuture, RequestContext};
8use nestforge_data::CacheStore;
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12struct CachedHttpResponse {
13    status: u16,
14    body: String,
15    content_type: Option<String>,
16}
17
18impl CachedHttpResponse {
19    fn into_response(self) -> Response {
20        let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::OK);
21        let mut response = (status, self.body).into_response();
22
23        if let Some(content_type) = self.content_type {
24            if let Ok(header) = HeaderValue::from_str(&content_type) {
25                response.headers_mut().insert(CONTENT_TYPE, header);
26            }
27        }
28
29        response
30    }
31}
32
33pub trait CachePolicy: Default + Clone + Send + Sync + 'static {
34    type Store: CacheStore + Send + Sync + 'static;
35
36    fn cache_key(&self, ctx: &RequestContext, req: &Request<Body>) -> Option<String> {
37        if ctx.method != Method::GET {
38            return None;
39        }
40
41        Some(format!(
42            "{}:{}",
43            std::any::type_name::<Self>(),
44            req.uri()
45        ))
46    }
47
48    fn ttl_seconds(&self) -> Option<u64> {
49        None
50    }
51
52    fn should_cache_response(&self, response: &Response) -> bool {
53        response.status() == StatusCode::OK
54    }
55}
56
57#[derive(Debug, Clone)]
58pub struct CacheInterceptor<P>
59where
60    P: CachePolicy,
61{
62    policy: P,
63}
64
65impl<P> Default for CacheInterceptor<P>
66where
67    P: CachePolicy,
68{
69    fn default() -> Self {
70        Self {
71            policy: P::default(),
72        }
73    }
74}
75
76impl<P> CacheInterceptor<P>
77where
78    P: CachePolicy,
79{
80    pub fn new(policy: P) -> Self {
81        Self { policy }
82    }
83}
84
85impl<P> Interceptor for CacheInterceptor<P>
86where
87    P: CachePolicy,
88{
89    fn around(&self, ctx: RequestContext, req: Request<Body>, next: NextFn) -> NextFuture {
90        let policy = self.policy.clone();
91
92        Box::pin(async move {
93            let Some(container) = req.extensions().get::<Container>().cloned() else {
94                return (next)(req).await;
95            };
96
97            let Some(cache_key) = policy.cache_key(&ctx, &req) else {
98                return (next)(req).await;
99            };
100
101            let Ok(store) = container.resolve::<P::Store>() else {
102                return (next)(req).await;
103            };
104
105            if let Ok(Some(cached)) = store.get(&cache_key).await {
106                match serde_json::from_str::<CachedHttpResponse>(&cached) {
107                    Ok(response) => {
108                        framework_log_event(
109                            "response_cache_hit",
110                            &[("key", cache_key.clone())],
111                        );
112                        return response.into_response();
113                    }
114                    Err(err) => {
115                        framework_log_event(
116                            "response_cache_deserialize_failed",
117                            &[
118                                ("key", cache_key.clone()),
119                                ("error", err.to_string()),
120                            ],
121                        );
122                    }
123                }
124            }
125
126            let response = (next)(req).await;
127            if !policy.should_cache_response(&response) {
128                return response;
129            }
130
131            let (parts, body) = response.into_parts();
132            let bytes = match to_bytes(body, usize::MAX).await {
133                Ok(bytes) => bytes,
134                Err(err) => {
135                    framework_log_event(
136                        "response_cache_read_failed",
137                        &[
138                            ("key", cache_key),
139                            ("error", err.to_string()),
140                        ],
141                    );
142                    return nestforge_core::HttpException::internal_server_error(
143                        "Failed to read response body for caching",
144                    )
145                    .into_response();
146                }
147            };
148
149            let content_type = parts
150                .headers
151                .get(CONTENT_TYPE)
152                .and_then(|value| value.to_str().ok())
153                .map(str::to_string);
154
155            let response_for_client = Response::from_parts(parts, Body::from(bytes.clone()));
156
157            let Ok(body) = String::from_utf8(bytes.to_vec()) else {
158                return response_for_client;
159            };
160
161            let payload = CachedHttpResponse {
162                status: response_for_client.status().as_u16(),
163                body,
164                content_type,
165            };
166
167            match serde_json::to_string(&payload) {
168                Ok(serialized) => {
169                    if let Err(err) = store
170                        .set(&cache_key, &serialized, policy.ttl_seconds())
171                        .await
172                    {
173                        framework_log_event(
174                            "response_cache_store_failed",
175                            &[
176                                ("key", cache_key),
177                                ("error", err.to_string()),
178                            ],
179                        );
180                    }
181                }
182                Err(err) => {
183                    framework_log_event(
184                        "response_cache_serialize_failed",
185                        &[
186                            ("key", cache_key),
187                            ("error", err.to_string()),
188                        ],
189                    );
190                }
191            }
192
193            response_for_client
194        })
195    }
196}
197
198#[derive(Debug)]
199pub struct DefaultCachePolicy<S>
200where
201    S: CacheStore + Send + Sync + 'static,
202{
203    _marker: std::marker::PhantomData<fn() -> S>,
204}
205
206impl<S> Clone for DefaultCachePolicy<S>
207where
208    S: CacheStore + Send + Sync + 'static,
209{
210    fn clone(&self) -> Self {
211        Self::default()
212    }
213}
214
215impl<S> Default for DefaultCachePolicy<S>
216where
217    S: CacheStore + Send + Sync + 'static,
218{
219    fn default() -> Self {
220        Self {
221            _marker: std::marker::PhantomData,
222        }
223    }
224}
225
226impl<S> CachePolicy for DefaultCachePolicy<S>
227where
228    S: CacheStore + Send + Sync + 'static,
229{
230    type Store = S;
231}
232
233pub fn cached_response_interceptor<S>() -> CacheInterceptor<DefaultCachePolicy<S>>
234where
235    S: CacheStore + Send + Sync + 'static,
236{
237    CacheInterceptor::default()
238}