1use axum::{
2 body::{to_bytes, Body},
3 extract::Request,
4 http::{header::CONTENT_TYPE, HeaderValue, Method, StatusCode},
5 response::{IntoResponse, Response},
6};
7use nestforge_core::{
8 framework_log_event, Container, Interceptor, NextFn, NextFuture, RequestContext,
9};
10use nestforge_data::CacheStore;
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14struct CachedHttpResponse {
15 status: u16,
16 body: String,
17 content_type: Option<String>,
18}
19
20impl CachedHttpResponse {
21 fn into_response(self) -> Response {
22 let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::OK);
23 let mut response = (status, self.body).into_response();
24
25 if let Some(content_type) = self.content_type {
26 if let Ok(header) = HeaderValue::from_str(&content_type) {
27 response.headers_mut().insert(CONTENT_TYPE, header);
28 }
29 }
30
31 response
32 }
33}
34
35pub trait CachePolicy: Default + Clone + Send + Sync + 'static {
36 type Store: CacheStore + Send + Sync + 'static;
37
38 fn cache_key(&self, ctx: &RequestContext, req: &Request<Body>) -> Option<String> {
39 if ctx.method != Method::GET {
40 return None;
41 }
42
43 Some(format!("{}:{}", std::any::type_name::<Self>(), req.uri()))
44 }
45
46 fn ttl_seconds(&self) -> Option<u64> {
47 None
48 }
49
50 fn should_cache_response(&self, response: &Response) -> bool {
51 response.status() == StatusCode::OK
52 }
53}
54
55#[derive(Debug, Clone)]
56pub struct CacheInterceptor<P>
57where
58 P: CachePolicy,
59{
60 policy: P,
61}
62
63impl<P> Default for CacheInterceptor<P>
64where
65 P: CachePolicy,
66{
67 fn default() -> Self {
68 Self {
69 policy: P::default(),
70 }
71 }
72}
73
74impl<P> CacheInterceptor<P>
75where
76 P: CachePolicy,
77{
78 pub fn new(policy: P) -> Self {
79 Self { policy }
80 }
81}
82
83impl<P> Interceptor for CacheInterceptor<P>
84where
85 P: CachePolicy,
86{
87 fn around(&self, ctx: RequestContext, req: Request<Body>, next: NextFn) -> NextFuture {
88 let policy = self.policy.clone();
89
90 Box::pin(async move {
91 let Some(container) = req.extensions().get::<Container>().cloned() else {
92 return (next)(req).await;
93 };
94
95 let Some(cache_key) = policy.cache_key(&ctx, &req) else {
96 return (next)(req).await;
97 };
98
99 let Ok(store) = container.resolve::<P::Store>() else {
100 return (next)(req).await;
101 };
102
103 if let Ok(Some(cached)) = store.get(&cache_key).await {
104 match serde_json::from_str::<CachedHttpResponse>(&cached) {
105 Ok(response) => {
106 framework_log_event("response_cache_hit", &[("key", cache_key.clone())]);
107 return response.into_response();
108 }
109 Err(err) => {
110 framework_log_event(
111 "response_cache_deserialize_failed",
112 &[("key", cache_key.clone()), ("error", err.to_string())],
113 );
114 }
115 }
116 }
117
118 let response = (next)(req).await;
119 if !policy.should_cache_response(&response) {
120 return response;
121 }
122
123 let (parts, body) = response.into_parts();
124 let bytes = match to_bytes(body, usize::MAX).await {
125 Ok(bytes) => bytes,
126 Err(err) => {
127 framework_log_event(
128 "response_cache_read_failed",
129 &[("key", cache_key), ("error", err.to_string())],
130 );
131 return nestforge_core::HttpException::internal_server_error(
132 "Failed to read response body for caching",
133 )
134 .into_response();
135 }
136 };
137
138 let content_type = parts
139 .headers
140 .get(CONTENT_TYPE)
141 .and_then(|value| value.to_str().ok())
142 .map(str::to_string);
143
144 let response_for_client = Response::from_parts(parts, Body::from(bytes.clone()));
145
146 let Ok(body) = String::from_utf8(bytes.to_vec()) else {
147 return response_for_client;
148 };
149
150 let payload = CachedHttpResponse {
151 status: response_for_client.status().as_u16(),
152 body,
153 content_type,
154 };
155
156 match serde_json::to_string(&payload) {
157 Ok(serialized) => {
158 if let Err(err) = store
159 .set(&cache_key, &serialized, policy.ttl_seconds())
160 .await
161 {
162 framework_log_event(
163 "response_cache_store_failed",
164 &[("key", cache_key), ("error", err.to_string())],
165 );
166 }
167 }
168 Err(err) => {
169 framework_log_event(
170 "response_cache_serialize_failed",
171 &[("key", cache_key), ("error", err.to_string())],
172 );
173 }
174 }
175
176 response_for_client
177 })
178 }
179}
180
181#[derive(Debug)]
182pub struct DefaultCachePolicy<S>
183where
184 S: CacheStore + Send + Sync + 'static,
185{
186 _marker: std::marker::PhantomData<fn() -> S>,
187}
188
189impl<S> Clone for DefaultCachePolicy<S>
190where
191 S: CacheStore + Send + Sync + 'static,
192{
193 fn clone(&self) -> Self {
194 Self::default()
195 }
196}
197
198impl<S> Default for DefaultCachePolicy<S>
199where
200 S: CacheStore + Send + Sync + 'static,
201{
202 fn default() -> Self {
203 Self {
204 _marker: std::marker::PhantomData,
205 }
206 }
207}
208
209impl<S> CachePolicy for DefaultCachePolicy<S>
210where
211 S: CacheStore + Send + Sync + 'static,
212{
213 type Store = S;
214}
215
216pub fn cached_response_interceptor<S>() -> CacheInterceptor<DefaultCachePolicy<S>>
217where
218 S: CacheStore + Send + Sync + 'static,
219{
220 CacheInterceptor::default()
221}