Skip to main content

nestforge_cache/
lib.rs

1use axum::{
2    body::{to_bytes, Body},
3    extract::Request,
4    http::{header::CONTENT_TYPE, HeaderValue, Method, StatusCode},
5    response::{IntoResponse, Response},
6};
7use nestforge_core::{
8    framework_log_event, Container, Interceptor, NextFn, NextFuture, RequestContext,
9};
10use nestforge_data::CacheStore;
11use serde::{Deserialize, Serialize};
12
13/**
14 * CachedHttpResponse
15 *
16 * Internal representation of a cached HTTP response.
17 * Stores the status code, body, and content type.
18 */
19#[derive(Debug, Clone, Serialize, Deserialize)]
20struct CachedHttpResponse {
21    status: u16,
22    body: String,
23    content_type: Option<String>,
24}
25
26impl CachedHttpResponse {
27    fn into_response(self) -> Response {
28        let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::OK);
29        let mut response = (status, self.body).into_response();
30
31        if let Some(content_type) = self.content_type {
32            if let Ok(header) = HeaderValue::from_str(&content_type) {
33                response.headers_mut().insert(CONTENT_TYPE, header);
34            }
35        }
36
37        response
38    }
39}
40
41/**
42 * CachePolicy Trait
43 *
44 * Defines the caching strategy for responses.
45 *
46 * # Type Parameters
47 * - `Self`: The policy type
48 * - `Store`: The cache store implementation
49 *
50 * # Default Behavior
51 * By default, only GET requests are cached, with no TTL,
52 * and only successful (200 OK) responses are cached.
53 */
54pub trait CachePolicy: Default + Clone + Send + Sync + 'static {
55    /** The cache store type used by this policy */
56    type Store: CacheStore + Send + Sync + 'static;
57
58    /**
59     * Generates a cache key for the given request.
60     *
61     * Return `None` to skip caching for this request.
62     * Default implementation caches all GET requests.
63     */
64    fn cache_key(&self, ctx: &RequestContext, req: &Request<Body>) -> Option<String> {
65        if ctx.method != Method::GET {
66            return None;
67        }
68
69        Some(format!("{}:{}", std::any::type_name::<Self>(), req.uri()))
70    }
71
72    /**
73     * Returns the time-to-live in seconds for cached responses.
74     *
75     * Return `None` for no expiration.
76     */
77    fn ttl_seconds(&self) -> Option<u64> {
78        None
79    }
80
81    /**
82     * Determines whether a response should be cached.
83     *
84     * Default implementation caches only 200 OK responses.
85     */
86    fn should_cache_response(&self, response: &Response) -> bool {
87        response.status() == StatusCode::OK
88    }
89}
90
91/**
92 * CacheInterceptor
93 *
94 * An interceptor that implements HTTP response caching based on
95 * a configurable CachePolicy.
96 *
97 * # Type Parameters
98 * - `P`: The cache policy implementing CachePolicy
99 */
100#[derive(Debug, Clone)]
101pub struct CacheInterceptor<P>
102where
103    P: CachePolicy,
104{
105    policy: P,
106}
107
108impl<P> Default for CacheInterceptor<P>
109where
110    P: CachePolicy,
111{
112    fn default() -> Self {
113        Self {
114            policy: P::default(),
115        }
116    }
117}
118
119impl<P> CacheInterceptor<P>
120where
121    P: CachePolicy,
122{
123    /**
124     * Creates a new CacheInterceptor with the given policy.
125     */
126    pub fn new(policy: P) -> Self {
127        Self { policy }
128    }
129}
130
131impl<P> Interceptor for CacheInterceptor<P>
132where
133    P: CachePolicy,
134{
135    fn around(&self, ctx: RequestContext, req: Request<Body>, next: NextFn) -> NextFuture {
136        let policy = self.policy.clone();
137
138        Box::pin(async move {
139            let Some(container) = req.extensions().get::<Container>().cloned() else {
140                return (next)(req).await;
141            };
142
143            let Some(cache_key) = policy.cache_key(&ctx, &req) else {
144                return (next)(req).await;
145            };
146
147            let Ok(store) = container.resolve::<P::Store>() else {
148                return (next)(req).await;
149            };
150
151            if let Ok(Some(cached)) = store.get(&cache_key).await {
152                match serde_json::from_str::<CachedHttpResponse>(&cached) {
153                    Ok(response) => {
154                        framework_log_event("response_cache_hit", &[("key", cache_key.clone())]);
155                        return response.into_response();
156                    }
157                    Err(err) => {
158                        framework_log_event(
159                            "response_cache_deserialize_failed",
160                            &[("key", cache_key.clone()), ("error", err.to_string())],
161                        );
162                    }
163                }
164            }
165
166            let response = (next)(req).await;
167            if !policy.should_cache_response(&response) {
168                return response;
169            }
170
171            let (parts, body) = response.into_parts();
172            let bytes = match to_bytes(body, usize::MAX).await {
173                Ok(bytes) => bytes,
174                Err(err) => {
175                    framework_log_event(
176                        "response_cache_read_failed",
177                        &[("key", cache_key), ("error", err.to_string())],
178                    );
179                    return nestforge_core::HttpException::internal_server_error(
180                        "Failed to read response body for caching",
181                    )
182                    .into_response();
183                }
184            };
185
186            let content_type = parts
187                .headers
188                .get(CONTENT_TYPE)
189                .and_then(|value| value.to_str().ok())
190                .map(str::to_string);
191
192            let response_for_client = Response::from_parts(parts, Body::from(bytes.clone()));
193
194            let Ok(body) = String::from_utf8(bytes.to_vec()) else {
195                return response_for_client;
196            };
197
198            let payload = CachedHttpResponse {
199                status: response_for_client.status().as_u16(),
200                body,
201                content_type,
202            };
203
204            match serde_json::to_string(&payload) {
205                Ok(serialized) => {
206                    if let Err(err) = store
207                        .set(&cache_key, &serialized, policy.ttl_seconds())
208                        .await
209                    {
210                        framework_log_event(
211                            "response_cache_store_failed",
212                            &[("key", cache_key), ("error", err.to_string())],
213                        );
214                    }
215                }
216                Err(err) => {
217                    framework_log_event(
218                        "response_cache_serialize_failed",
219                        &[("key", cache_key), ("error", err.to_string())],
220                    );
221                }
222            }
223
224            response_for_client
225        })
226    }
227}
228
229#[derive(Debug)]
230pub struct DefaultCachePolicy<S>
231where
232    S: CacheStore + Send + Sync + 'static,
233{
234    _marker: std::marker::PhantomData<fn() -> S>,
235}
236
237impl<S> Clone for DefaultCachePolicy<S>
238where
239    S: CacheStore + Send + Sync + 'static,
240{
241    fn clone(&self) -> Self {
242        Self::default()
243    }
244}
245
246impl<S> Default for DefaultCachePolicy<S>
247where
248    S: CacheStore + Send + Sync + 'static,
249{
250    fn default() -> Self {
251        Self {
252            _marker: std::marker::PhantomData,
253        }
254    }
255}
256
257impl<S> CachePolicy for DefaultCachePolicy<S>
258where
259    S: CacheStore + Send + Sync + 'static,
260{
261    type Store = S;
262}
263
264pub fn cached_response_interceptor<S>() -> CacheInterceptor<DefaultCachePolicy<S>>
265where
266    S: CacheStore + Send + Sync + 'static,
267{
268    CacheInterceptor::default()
269}