Skip to main content

nestforge_cache/
lib.rs

1use axum::{
2    body::{to_bytes, Body},
3    extract::Request,
4    http::{header::CONTENT_TYPE, HeaderValue, Method, StatusCode},
5    response::{IntoResponse, Response},
6};
7use nestforge_core::{
8    framework_log_event, Container, Interceptor, NextFn, NextFuture, RequestContext,
9};
10use nestforge_data::CacheStore;
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14struct CachedHttpResponse {
15    status: u16,
16    body: String,
17    content_type: Option<String>,
18}
19
20impl CachedHttpResponse {
21    fn into_response(self) -> Response {
22        let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::OK);
23        let mut response = (status, self.body).into_response();
24
25        if let Some(content_type) = self.content_type {
26            if let Ok(header) = HeaderValue::from_str(&content_type) {
27                response.headers_mut().insert(CONTENT_TYPE, header);
28            }
29        }
30
31        response
32    }
33}
34
35pub trait CachePolicy: Default + Clone + Send + Sync + 'static {
36    type Store: CacheStore + Send + Sync + 'static;
37
38    fn cache_key(&self, ctx: &RequestContext, req: &Request<Body>) -> Option<String> {
39        if ctx.method != Method::GET {
40            return None;
41        }
42
43        Some(format!("{}:{}", std::any::type_name::<Self>(), req.uri()))
44    }
45
46    fn ttl_seconds(&self) -> Option<u64> {
47        None
48    }
49
50    fn should_cache_response(&self, response: &Response) -> bool {
51        response.status() == StatusCode::OK
52    }
53}
54
55#[derive(Debug, Clone)]
56pub struct CacheInterceptor<P>
57where
58    P: CachePolicy,
59{
60    policy: P,
61}
62
63impl<P> Default for CacheInterceptor<P>
64where
65    P: CachePolicy,
66{
67    fn default() -> Self {
68        Self {
69            policy: P::default(),
70        }
71    }
72}
73
74impl<P> CacheInterceptor<P>
75where
76    P: CachePolicy,
77{
78    pub fn new(policy: P) -> Self {
79        Self { policy }
80    }
81}
82
83impl<P> Interceptor for CacheInterceptor<P>
84where
85    P: CachePolicy,
86{
87    fn around(&self, ctx: RequestContext, req: Request<Body>, next: NextFn) -> NextFuture {
88        let policy = self.policy.clone();
89
90        Box::pin(async move {
91            let Some(container) = req.extensions().get::<Container>().cloned() else {
92                return (next)(req).await;
93            };
94
95            let Some(cache_key) = policy.cache_key(&ctx, &req) else {
96                return (next)(req).await;
97            };
98
99            let Ok(store) = container.resolve::<P::Store>() else {
100                return (next)(req).await;
101            };
102
103            if let Ok(Some(cached)) = store.get(&cache_key).await {
104                match serde_json::from_str::<CachedHttpResponse>(&cached) {
105                    Ok(response) => {
106                        framework_log_event("response_cache_hit", &[("key", cache_key.clone())]);
107                        return response.into_response();
108                    }
109                    Err(err) => {
110                        framework_log_event(
111                            "response_cache_deserialize_failed",
112                            &[("key", cache_key.clone()), ("error", err.to_string())],
113                        );
114                    }
115                }
116            }
117
118            let response = (next)(req).await;
119            if !policy.should_cache_response(&response) {
120                return response;
121            }
122
123            let (parts, body) = response.into_parts();
124            let bytes = match to_bytes(body, usize::MAX).await {
125                Ok(bytes) => bytes,
126                Err(err) => {
127                    framework_log_event(
128                        "response_cache_read_failed",
129                        &[("key", cache_key), ("error", err.to_string())],
130                    );
131                    return nestforge_core::HttpException::internal_server_error(
132                        "Failed to read response body for caching",
133                    )
134                    .into_response();
135                }
136            };
137
138            let content_type = parts
139                .headers
140                .get(CONTENT_TYPE)
141                .and_then(|value| value.to_str().ok())
142                .map(str::to_string);
143
144            let response_for_client = Response::from_parts(parts, Body::from(bytes.clone()));
145
146            let Ok(body) = String::from_utf8(bytes.to_vec()) else {
147                return response_for_client;
148            };
149
150            let payload = CachedHttpResponse {
151                status: response_for_client.status().as_u16(),
152                body,
153                content_type,
154            };
155
156            match serde_json::to_string(&payload) {
157                Ok(serialized) => {
158                    if let Err(err) = store
159                        .set(&cache_key, &serialized, policy.ttl_seconds())
160                        .await
161                    {
162                        framework_log_event(
163                            "response_cache_store_failed",
164                            &[("key", cache_key), ("error", err.to_string())],
165                        );
166                    }
167                }
168                Err(err) => {
169                    framework_log_event(
170                        "response_cache_serialize_failed",
171                        &[("key", cache_key), ("error", err.to_string())],
172                    );
173                }
174            }
175
176            response_for_client
177        })
178    }
179}
180
181#[derive(Debug)]
182pub struct DefaultCachePolicy<S>
183where
184    S: CacheStore + Send + Sync + 'static,
185{
186    _marker: std::marker::PhantomData<fn() -> S>,
187}
188
189impl<S> Clone for DefaultCachePolicy<S>
190where
191    S: CacheStore + Send + Sync + 'static,
192{
193    fn clone(&self) -> Self {
194        Self::default()
195    }
196}
197
198impl<S> Default for DefaultCachePolicy<S>
199where
200    S: CacheStore + Send + Sync + 'static,
201{
202    fn default() -> Self {
203        Self {
204            _marker: std::marker::PhantomData,
205        }
206    }
207}
208
209impl<S> CachePolicy for DefaultCachePolicy<S>
210where
211    S: CacheStore + Send + Sync + 'static,
212{
213    type Store = S;
214}
215
216pub fn cached_response_interceptor<S>() -> CacheInterceptor<DefaultCachePolicy<S>>
217where
218    S: CacheStore + Send + Sync + 'static,
219{
220    CacheInterceptor::default()
221}