1use axum::{
2 body::{to_bytes, Body},
3 extract::Request,
4 http::{header::CONTENT_TYPE, HeaderValue, Method, StatusCode},
5 response::{IntoResponse, Response},
6};
7use nestforge_core::{framework_log_event, Container, Interceptor, NextFn, NextFuture, RequestContext};
8use nestforge_data::CacheStore;
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12struct CachedHttpResponse {
13 status: u16,
14 body: String,
15 content_type: Option<String>,
16}
17
18impl CachedHttpResponse {
19 fn into_response(self) -> Response {
20 let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::OK);
21 let mut response = (status, self.body).into_response();
22
23 if let Some(content_type) = self.content_type {
24 if let Ok(header) = HeaderValue::from_str(&content_type) {
25 response.headers_mut().insert(CONTENT_TYPE, header);
26 }
27 }
28
29 response
30 }
31}
32
33pub trait CachePolicy: Default + Clone + Send + Sync + 'static {
34 type Store: CacheStore + Send + Sync + 'static;
35
36 fn cache_key(&self, ctx: &RequestContext, req: &Request<Body>) -> Option<String> {
37 if ctx.method != Method::GET {
38 return None;
39 }
40
41 Some(format!(
42 "{}:{}",
43 std::any::type_name::<Self>(),
44 req.uri()
45 ))
46 }
47
48 fn ttl_seconds(&self) -> Option<u64> {
49 None
50 }
51
52 fn should_cache_response(&self, response: &Response) -> bool {
53 response.status() == StatusCode::OK
54 }
55}
56
57#[derive(Debug, Clone)]
58pub struct CacheInterceptor<P>
59where
60 P: CachePolicy,
61{
62 policy: P,
63}
64
65impl<P> Default for CacheInterceptor<P>
66where
67 P: CachePolicy,
68{
69 fn default() -> Self {
70 Self {
71 policy: P::default(),
72 }
73 }
74}
75
76impl<P> CacheInterceptor<P>
77where
78 P: CachePolicy,
79{
80 pub fn new(policy: P) -> Self {
81 Self { policy }
82 }
83}
84
85impl<P> Interceptor for CacheInterceptor<P>
86where
87 P: CachePolicy,
88{
89 fn around(&self, ctx: RequestContext, req: Request<Body>, next: NextFn) -> NextFuture {
90 let policy = self.policy.clone();
91
92 Box::pin(async move {
93 let Some(container) = req.extensions().get::<Container>().cloned() else {
94 return (next)(req).await;
95 };
96
97 let Some(cache_key) = policy.cache_key(&ctx, &req) else {
98 return (next)(req).await;
99 };
100
101 let Ok(store) = container.resolve::<P::Store>() else {
102 return (next)(req).await;
103 };
104
105 if let Ok(Some(cached)) = store.get(&cache_key).await {
106 match serde_json::from_str::<CachedHttpResponse>(&cached) {
107 Ok(response) => {
108 framework_log_event(
109 "response_cache_hit",
110 &[("key", cache_key.clone())],
111 );
112 return response.into_response();
113 }
114 Err(err) => {
115 framework_log_event(
116 "response_cache_deserialize_failed",
117 &[
118 ("key", cache_key.clone()),
119 ("error", err.to_string()),
120 ],
121 );
122 }
123 }
124 }
125
126 let response = (next)(req).await;
127 if !policy.should_cache_response(&response) {
128 return response;
129 }
130
131 let (parts, body) = response.into_parts();
132 let bytes = match to_bytes(body, usize::MAX).await {
133 Ok(bytes) => bytes,
134 Err(err) => {
135 framework_log_event(
136 "response_cache_read_failed",
137 &[
138 ("key", cache_key),
139 ("error", err.to_string()),
140 ],
141 );
142 return nestforge_core::HttpException::internal_server_error(
143 "Failed to read response body for caching",
144 )
145 .into_response();
146 }
147 };
148
149 let content_type = parts
150 .headers
151 .get(CONTENT_TYPE)
152 .and_then(|value| value.to_str().ok())
153 .map(str::to_string);
154
155 let response_for_client = Response::from_parts(parts, Body::from(bytes.clone()));
156
157 let Ok(body) = String::from_utf8(bytes.to_vec()) else {
158 return response_for_client;
159 };
160
161 let payload = CachedHttpResponse {
162 status: response_for_client.status().as_u16(),
163 body,
164 content_type,
165 };
166
167 match serde_json::to_string(&payload) {
168 Ok(serialized) => {
169 if let Err(err) = store
170 .set(&cache_key, &serialized, policy.ttl_seconds())
171 .await
172 {
173 framework_log_event(
174 "response_cache_store_failed",
175 &[
176 ("key", cache_key),
177 ("error", err.to_string()),
178 ],
179 );
180 }
181 }
182 Err(err) => {
183 framework_log_event(
184 "response_cache_serialize_failed",
185 &[
186 ("key", cache_key),
187 ("error", err.to_string()),
188 ],
189 );
190 }
191 }
192
193 response_for_client
194 })
195 }
196}
197
198#[derive(Debug)]
199pub struct DefaultCachePolicy<S>
200where
201 S: CacheStore + Send + Sync + 'static,
202{
203 _marker: std::marker::PhantomData<fn() -> S>,
204}
205
206impl<S> Clone for DefaultCachePolicy<S>
207where
208 S: CacheStore + Send + Sync + 'static,
209{
210 fn clone(&self) -> Self {
211 Self::default()
212 }
213}
214
215impl<S> Default for DefaultCachePolicy<S>
216where
217 S: CacheStore + Send + Sync + 'static,
218{
219 fn default() -> Self {
220 Self {
221 _marker: std::marker::PhantomData,
222 }
223 }
224}
225
226impl<S> CachePolicy for DefaultCachePolicy<S>
227where
228 S: CacheStore + Send + Sync + 'static,
229{
230 type Store = S;
231}
232
233pub fn cached_response_interceptor<S>() -> CacheInterceptor<DefaultCachePolicy<S>>
234where
235 S: CacheStore + Send + Sync + 'static,
236{
237 CacheInterceptor::default()
238}