1use axum::{
2 body::{to_bytes, Body},
3 extract::Request,
4 http::{header::CONTENT_TYPE, HeaderValue, Method, StatusCode},
5 response::{IntoResponse, Response},
6};
7use nestforge_core::{
8 framework_log_event, Container, Interceptor, NextFn, NextFuture, RequestContext,
9};
10use nestforge_data::CacheStore;
11use serde::{Deserialize, Serialize};
12
/// Serialized form of a cached HTTP response, written to the [`CacheStore`] as JSON.
///
/// Only the status code, the UTF-8 body text and the `Content-Type` header are kept.
/// This field set is the persisted format: changing it affects entries already in the store.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CachedHttpResponse {
    /// Numeric HTTP status code of the original response.
    status: u16,
    /// Response body; only bodies that decoded as UTF-8 are ever cached.
    body: String,
    /// `Content-Type` header of the original response, when it had a textual one.
    content_type: Option<String>,
}
25
26impl CachedHttpResponse {
27 fn into_response(self) -> Response {
28 let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::OK);
29 let mut response = (status, self.body).into_response();
30
31 if let Some(content_type) = self.content_type {
32 if let Ok(header) = HeaderValue::from_str(&content_type) {
33 response.headers_mut().insert(CONTENT_TYPE, header);
34 }
35 }
36
37 response
38 }
39}
40
41pub trait CachePolicy: Default + Clone + Send + Sync + 'static {
55 type Store: CacheStore + Send + Sync + 'static;
57
58 fn cache_key(&self, ctx: &RequestContext, req: &Request<Body>) -> Option<String> {
65 if ctx.method != Method::GET {
66 return None;
67 }
68
69 Some(format!("{}:{}", std::any::type_name::<Self>(), req.uri()))
70 }
71
72 fn ttl_seconds(&self) -> Option<u64> {
78 None
79 }
80
81 fn should_cache_response(&self, response: &Response) -> bool {
87 response.status() == StatusCode::OK
88 }
89}
90
91#[derive(Debug, Clone)]
101pub struct CacheInterceptor<P>
102where
103 P: CachePolicy,
104{
105 policy: P,
106}
107
108impl<P> Default for CacheInterceptor<P>
109where
110 P: CachePolicy,
111{
112 fn default() -> Self {
113 Self {
114 policy: P::default(),
115 }
116 }
117}
118
119impl<P> CacheInterceptor<P>
120where
121 P: CachePolicy,
122{
123 pub fn new(policy: P) -> Self {
127 Self { policy }
128 }
129}
130
131impl<P> Interceptor for CacheInterceptor<P>
132where
133 P: CachePolicy,
134{
135 fn around(&self, ctx: RequestContext, req: Request<Body>, next: NextFn) -> NextFuture {
136 let policy = self.policy.clone();
137
138 Box::pin(async move {
139 let Some(container) = req.extensions().get::<Container>().cloned() else {
140 return (next)(req).await;
141 };
142
143 let Some(cache_key) = policy.cache_key(&ctx, &req) else {
144 return (next)(req).await;
145 };
146
147 let Ok(store) = container.resolve::<P::Store>() else {
148 return (next)(req).await;
149 };
150
151 if let Ok(Some(cached)) = store.get(&cache_key).await {
152 match serde_json::from_str::<CachedHttpResponse>(&cached) {
153 Ok(response) => {
154 framework_log_event("response_cache_hit", &[("key", cache_key.clone())]);
155 return response.into_response();
156 }
157 Err(err) => {
158 framework_log_event(
159 "response_cache_deserialize_failed",
160 &[("key", cache_key.clone()), ("error", err.to_string())],
161 );
162 }
163 }
164 }
165
166 let response = (next)(req).await;
167 if !policy.should_cache_response(&response) {
168 return response;
169 }
170
171 let (parts, body) = response.into_parts();
172 let bytes = match to_bytes(body, usize::MAX).await {
173 Ok(bytes) => bytes,
174 Err(err) => {
175 framework_log_event(
176 "response_cache_read_failed",
177 &[("key", cache_key), ("error", err.to_string())],
178 );
179 return nestforge_core::HttpException::internal_server_error(
180 "Failed to read response body for caching",
181 )
182 .into_response();
183 }
184 };
185
186 let content_type = parts
187 .headers
188 .get(CONTENT_TYPE)
189 .and_then(|value| value.to_str().ok())
190 .map(str::to_string);
191
192 let response_for_client = Response::from_parts(parts, Body::from(bytes.clone()));
193
194 let Ok(body) = String::from_utf8(bytes.to_vec()) else {
195 return response_for_client;
196 };
197
198 let payload = CachedHttpResponse {
199 status: response_for_client.status().as_u16(),
200 body,
201 content_type,
202 };
203
204 match serde_json::to_string(&payload) {
205 Ok(serialized) => {
206 if let Err(err) = store
207 .set(&cache_key, &serialized, policy.ttl_seconds())
208 .await
209 {
210 framework_log_event(
211 "response_cache_store_failed",
212 &[("key", cache_key), ("error", err.to_string())],
213 );
214 }
215 }
216 Err(err) => {
217 framework_log_event(
218 "response_cache_serialize_failed",
219 &[("key", cache_key), ("error", err.to_string())],
220 );
221 }
222 }
223
224 response_for_client
225 })
226 }
227}
228
229#[derive(Debug)]
230pub struct DefaultCachePolicy<S>
231where
232 S: CacheStore + Send + Sync + 'static,
233{
234 _marker: std::marker::PhantomData<fn() -> S>,
235}
236
237impl<S> Clone for DefaultCachePolicy<S>
238where
239 S: CacheStore + Send + Sync + 'static,
240{
241 fn clone(&self) -> Self {
242 Self::default()
243 }
244}
245
246impl<S> Default for DefaultCachePolicy<S>
247where
248 S: CacheStore + Send + Sync + 'static,
249{
250 fn default() -> Self {
251 Self {
252 _marker: std::marker::PhantomData,
253 }
254 }
255}
256
257impl<S> CachePolicy for DefaultCachePolicy<S>
258where
259 S: CacheStore + Send + Sync + 'static,
260{
261 type Store = S;
262}
263
264pub fn cached_response_interceptor<S>() -> CacheInterceptor<DefaultCachePolicy<S>>
265where
266 S: CacheStore + Send + Sync + 'static,
267{
268 CacheInterceptor::default()
269}