1use crate::cache::CacheHandle;
2use axum::{
3 extract::State,
4 http::{header, HeaderMap, StatusCode},
5 response::IntoResponse,
6 routing::post,
7 Json, Router,
8};
9use serde::Deserialize;
10use std::sync::Arc;
11
/// Shared state for the cache control HTTP endpoints.
#[derive(Clone)]
pub struct ControlState {
    /// Cache handles, each paired with the server name that control requests
    /// use (via their optional `server` field) to target it.
    handles: Vec<(String, CacheHandle)>,
    /// When `Some`, requests must send `Authorization: Bearer <token>`;
    /// when `None`, the control endpoints accept unauthenticated requests.
    auth_token: Option<String>,
}
18
19impl ControlState {
20 pub fn new(handles: Vec<(String, CacheHandle)>, auth_token: Option<String>) -> Self {
21 Self { handles, auth_token }
22 }
23
24 fn resolve_handles(
27 &self,
28 server: Option<&str>,
29 ) -> Result<Vec<&CacheHandle>, (StatusCode, String)> {
30 match server {
31 None => Ok(self.handles.iter().map(|(_, h)| h).collect()),
32 Some(name) => {
33 let matched: Vec<&CacheHandle> = self
34 .handles
35 .iter()
36 .filter(|(n, _)| n == name)
37 .map(|(_, h)| h)
38 .collect();
39 if matched.is_empty() {
40 Err((
41 StatusCode::NOT_FOUND,
42 format!("No server named '{}' found", name),
43 ))
44 } else {
45 Ok(matched)
46 }
47 }
48 }
49 }
50
51 fn resolve_snapshot_handles(
57 &self,
58 server: Option<&str>,
59 ) -> Result<Vec<&CacheHandle>, (StatusCode, String)> {
60 match server {
61 None => {
62 let handles: Vec<&CacheHandle> = self
63 .handles
64 .iter()
65 .filter(|(_, h)| h.is_snapshot_capable())
66 .map(|(_, h)| h)
67 .collect();
68 if handles.is_empty() {
69 return Err((
70 StatusCode::BAD_REQUEST,
71 "No servers running in PreGenerate mode — snapshot operations are not available".to_string(),
72 ));
73 }
74 Ok(handles)
75 }
76 Some(name) => {
77 let matched: Vec<&CacheHandle> = self
78 .handles
79 .iter()
80 .filter(|(n, _)| n == name)
81 .map(|(_, h)| h)
82 .collect();
83 if matched.is_empty() {
84 Err((
85 StatusCode::NOT_FOUND,
86 format!("No server named '{}' found", name),
87 ))
88 } else {
89 Ok(matched)
90 }
91 }
92 }
93 }
94}
95
/// JSON request body for `POST /invalidate`.
#[derive(Deserialize)]
struct PatternBody {
    /// Pattern passed verbatim to `invalidate` on each targeted handle.
    pattern: String,
    /// Optional server name; when absent, every registered server is targeted.
    server: Option<String>,
}
102
/// JSON request body for the single-path snapshot endpoints
/// (`/add_snapshot`, `/refresh_snapshot`, `/remove_snapshot`).
#[derive(Deserialize)]
struct PathBody {
    /// Path passed to the snapshot operation on each targeted handle.
    path: String,
    /// Optional server name; when absent, every snapshot-capable server is targeted.
    server: Option<String>,
}
110
111fn check_auth(state: &ControlState, headers: &HeaderMap) -> Result<(), StatusCode> {
113 if let Some(required_token) = &state.auth_token {
114 let auth_header = headers
115 .get(header::AUTHORIZATION)
116 .and_then(|h| h.to_str().ok());
117 let expected = format!("Bearer {}", required_token);
118 if auth_header != Some(expected.as_str()) {
119 tracing::warn!("Unauthorized control endpoint attempt");
120 return Err(StatusCode::UNAUTHORIZED);
121 }
122 }
123 Ok(())
124}
125
126async fn invalidate_all_handler(
128 State(state): State<Arc<ControlState>>,
129 headers: HeaderMap,
130) -> Result<impl IntoResponse, StatusCode> {
131 check_auth(&state, &headers)?;
132
133 for (_, handle) in &state.handles {
134 handle.invalidate_all();
135 }
136 tracing::info!(
137 "invalidate_all triggered via control endpoint ({} server(s))",
138 state.handles.len()
139 );
140 Ok((StatusCode::OK, "Cache invalidated"))
141}
142
143async fn invalidate_handler(
147 State(state): State<Arc<ControlState>>,
148 headers: HeaderMap,
149 Json(body): Json<PatternBody>,
150) -> Result<impl IntoResponse, (StatusCode, String)> {
151 check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
152
153 let handles = state.resolve_handles(body.server.as_deref())?;
154 for handle in handles {
155 handle.invalidate(&body.pattern);
156 }
157 tracing::info!(
158 "invalidate('{}') triggered via control endpoint (server={:?})",
159 body.pattern,
160 body.server
161 );
162 Ok((StatusCode::OK, "Pattern invalidation triggered".to_string()))
163}
164
165async fn add_snapshot_handler(
170 State(state): State<Arc<ControlState>>,
171 headers: HeaderMap,
172 Json(body): Json<PathBody>,
173) -> Result<impl IntoResponse, (StatusCode, String)> {
174 check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
175
176 let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
177 for handle in handles {
178 handle
179 .add_snapshot(&body.path)
180 .await
181 .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
182 }
183 tracing::info!(
184 "add_snapshot('{}') triggered via control endpoint (server={:?})",
185 body.path, body.server
186 );
187 Ok((StatusCode::OK, "Snapshot added".to_string()))
188}
189
190async fn refresh_snapshot_handler(
195 State(state): State<Arc<ControlState>>,
196 headers: HeaderMap,
197 Json(body): Json<PathBody>,
198) -> Result<impl IntoResponse, (StatusCode, String)> {
199 check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
200
201 let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
202 for handle in handles {
203 handle
204 .refresh_snapshot(&body.path)
205 .await
206 .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
207 }
208 tracing::info!(
209 "refresh_snapshot('{}') triggered via control endpoint (server={:?})",
210 body.path, body.server
211 );
212 Ok((StatusCode::OK, "Snapshot refreshed".to_string()))
213}
214
215async fn remove_snapshot_handler(
220 State(state): State<Arc<ControlState>>,
221 headers: HeaderMap,
222 Json(body): Json<PathBody>,
223) -> Result<impl IntoResponse, (StatusCode, String)> {
224 check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
225
226 let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
227 for handle in handles {
228 handle
229 .remove_snapshot(&body.path)
230 .await
231 .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
232 }
233 tracing::info!(
234 "remove_snapshot('{}') triggered via control endpoint (server={:?})",
235 body.path, body.server
236 );
237 Ok((StatusCode::OK, "Snapshot removed".to_string()))
238}
239
240async fn refresh_all_snapshots_handler(
245 State(state): State<Arc<ControlState>>,
246 headers: HeaderMap,
247 body: Option<Json<serde_json::Value>>,
248) -> Result<impl IntoResponse, (StatusCode, String)> {
249 check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
250
251 let server_filter = body
252 .as_ref()
253 .and_then(|Json(v)| v.get("server"))
254 .and_then(|v| v.as_str())
255 .map(|s| s.to_string());
256
257 let handles = state.resolve_snapshot_handles(server_filter.as_deref())?;
258 for handle in handles {
259 handle
260 .refresh_all_snapshots()
261 .await
262 .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
263 }
264 tracing::info!(
265 "refresh_all_snapshots triggered via control endpoint (server={:?})",
266 server_filter
267 );
268 Ok((StatusCode::OK, "All snapshots refreshed".to_string()))
269}
270
271pub fn create_control_router(
275 handles: Vec<(String, CacheHandle)>,
276 auth_token: Option<String>,
277) -> Router {
278 let state = Arc::new(ControlState::new(handles, auth_token));
279
280 Router::new()
281 .route("/invalidate_all", post(invalidate_all_handler))
282 .route("/invalidate", post(invalidate_handler))
283 .route("/add_snapshot", post(add_snapshot_handler))
284 .route("/refresh_snapshot", post(refresh_snapshot_handler))
285 .route("/remove_snapshot", post(remove_snapshot_handler))
286 .route("/refresh_all_snapshots", post(refresh_all_snapshots_handler))
287 .with_state(state)
288}