1use crate::cache::CacheHandle;
2use axum::{
3 extract::State,
4 http::{header, HeaderMap, StatusCode},
5 response::IntoResponse,
6 routing::post,
7 Json, Router,
8};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tokio::task::JoinHandle;
12
/// Shared state for the control endpoints: the cache handles they operate on
/// and the optional bearer token that guards them.
#[derive(Clone)]
pub struct ControlState {
    /// `(server name, cache handle)` pairs. The name is matched against the
    /// optional `server` field of request bodies; duplicates are allowed and
    /// all matching handles are targeted.
    handles: Vec<(String, CacheHandle)>,
    /// When `Some`, every request must send `Authorization: Bearer <token>`;
    /// when `None`, the endpoints are unauthenticated.
    auth_token: Option<String>,
}
19
20impl ControlState {
21 pub fn new(handles: Vec<(String, CacheHandle)>, auth_token: Option<String>) -> Self {
22 Self {
23 handles,
24 auth_token,
25 }
26 }
27
28 fn resolve_handles(
31 &self,
32 server: Option<&str>,
33 ) -> Result<Vec<&CacheHandle>, (StatusCode, String)> {
34 match server {
35 None => Ok(self.handles.iter().map(|(_, h)| h).collect()),
36 Some(name) => {
37 let matched: Vec<&CacheHandle> = self
38 .handles
39 .iter()
40 .filter(|(n, _)| n == name)
41 .map(|(_, h)| h)
42 .collect();
43 if matched.is_empty() {
44 Err((
45 StatusCode::NOT_FOUND,
46 format!("No server named '{}' found", name),
47 ))
48 } else {
49 Ok(matched)
50 }
51 }
52 }
53 }
54
55 fn resolve_snapshot_handles(
61 &self,
62 server: Option<&str>,
63 ) -> Result<Vec<&CacheHandle>, (StatusCode, String)> {
64 match server {
65 None => {
66 let handles: Vec<&CacheHandle> = self
67 .handles
68 .iter()
69 .filter(|(_, h)| h.is_snapshot_capable())
70 .map(|(_, h)| h)
71 .collect();
72 if handles.is_empty() {
73 return Err((
74 StatusCode::BAD_REQUEST,
75 "No servers running in PreGenerate mode — snapshot operations are not available".to_string(),
76 ));
77 }
78 Ok(handles)
79 }
80 Some(name) => {
81 let matched: Vec<&CacheHandle> = self
82 .handles
83 .iter()
84 .filter(|(n, _)| n == name)
85 .map(|(_, h)| h)
86 .collect();
87 if matched.is_empty() {
88 Err((
89 StatusCode::NOT_FOUND,
90 format!("No server named '{}' found", name),
91 ))
92 } else {
93 Ok(matched)
94 }
95 }
96 }
97 }
98}
99
/// JSON body for `POST /invalidate`.
#[derive(Deserialize)]
struct PatternBody {
    /// Pattern handed to `CacheHandle::invalidate` on each targeted server.
    pattern: String,
    /// Optional server name; when absent, every configured server is targeted.
    server: Option<String>,
}
106
/// JSON body for the single-path snapshot endpoints
/// (`/add_snapshot`, `/refresh_snapshot`, `/remove_snapshot`).
#[derive(Deserialize)]
struct PathBody {
    /// Path passed to the snapshot operation on each targeted server.
    path: String,
    /// Optional server name; when absent, every snapshot-capable server is
    /// targeted.
    server: Option<String>,
}
114
/// JSON body for `POST /bulk_invalidate`.
#[derive(Deserialize)]
struct BulkPatternBody {
    /// Patterns to invalidate; the handler rejects an empty list with 400.
    patterns: Vec<String>,
    /// Optional server name; when absent, every configured server is targeted.
    server: Option<String>,
}
121
/// JSON body for the bulk snapshot endpoints
/// (`/bulk_add_snapshot`, `/bulk_refresh_snapshot`, `/bulk_remove_snapshot`).
#[derive(Deserialize)]
struct BulkPathBody {
    /// Paths to process; the handlers reject an empty list with 400.
    paths: Vec<String>,
    /// Optional server name; when absent, every snapshot-capable server is
    /// targeted.
    server: Option<String>,
}
129
/// Outcome of one item (pattern or path) within a bulk request.
#[derive(Serialize)]
struct BulkOperationItemResult {
    /// The pattern or path this result refers to.
    item: String,
    /// `true` when the operation succeeded on every targeted server.
    success: bool,
    /// Error text for a failed item; `None` on success.
    error: Option<String>,
}
136
/// JSON response body returned by every bulk endpoint.
#[derive(Serialize)]
struct BulkOperationResponse {
    /// Operation name, e.g. `"bulk_invalidate"`.
    operation: &'static str,
    /// Echo of the request's optional server filter.
    server: Option<String>,
    /// Number of per-item results (one per submitted item).
    requested: usize,
    /// Count of results with `success == true`.
    succeeded: usize,
    /// `requested - succeeded`.
    failed: usize,
    /// Per-item outcomes, in submission order.
    results: Vec<BulkOperationItemResult>,
}
146
/// Snapshot operation applied to each path by the bulk snapshot endpoints.
///
/// `Debug`/`PartialEq`/`Eq` are derived so the action can be logged and
/// compared; the type stays `Copy` so it can be moved into spawned tasks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BulkSnapshotAction {
    /// Dispatches to `CacheHandle::add_snapshot`.
    Add,
    /// Dispatches to `CacheHandle::refresh_snapshot`.
    Refresh,
    /// Dispatches to `CacheHandle::remove_snapshot`.
    Remove,
}
153
154fn check_auth(state: &ControlState, headers: &HeaderMap) -> Result<(), StatusCode> {
156 if let Some(required_token) = &state.auth_token {
157 let auth_header = headers
158 .get(header::AUTHORIZATION)
159 .and_then(|h| h.to_str().ok());
160 let expected = format!("Bearer {}", required_token);
161 if auth_header != Some(expected.as_str()) {
162 tracing::warn!("Unauthorized control endpoint attempt");
163 return Err(StatusCode::UNAUTHORIZED);
164 }
165 }
166 Ok(())
167}
168
169fn validate_bulk_items<T>(items: &[T], field_name: &str) -> Result<(), (StatusCode, String)> {
170 if items.is_empty() {
171 return Err((
172 StatusCode::BAD_REQUEST,
173 format!("'{}' must contain at least one item", field_name),
174 ));
175 }
176 Ok(())
177}
178
179fn bulk_response(
180 operation: &'static str,
181 server: Option<String>,
182 results: Vec<BulkOperationItemResult>,
183) -> (StatusCode, Json<BulkOperationResponse>) {
184 let requested = results.len();
185 let succeeded = results.iter().filter(|result| result.success).count();
186 let failed = requested - succeeded;
187
188 (
189 StatusCode::OK,
190 Json(BulkOperationResponse {
191 operation,
192 server,
193 requested,
194 succeeded,
195 failed,
196 results,
197 }),
198 )
199}
200
201async fn run_bulk_snapshot_operation(
202 handles: Vec<&CacheHandle>,
203 paths: &[String],
204 action: BulkSnapshotAction,
205) -> Vec<BulkOperationItemResult> {
206 let handles: Arc<Vec<CacheHandle>> = Arc::new(handles.into_iter().cloned().collect());
207 let tasks: Vec<JoinHandle<BulkOperationItemResult>> = paths
208 .iter()
209 .cloned()
210 .map(|path| {
211 let handles = Arc::clone(&handles);
212 tokio::spawn(async move {
213 let error = run_snapshot_operation_for_path(handles.as_ref(), &path, action).await;
214
215 BulkOperationItemResult {
216 item: path,
217 success: error.is_none(),
218 error,
219 }
220 })
221 })
222 .collect();
223
224 let mut results = Vec::with_capacity(tasks.len());
225
226 for task in tasks {
227 match task.await {
228 Ok(result) => results.push(result),
229 Err(err) => {
230 tracing::error!("bulk snapshot task failed: {}", err);
231 results.push(BulkOperationItemResult {
232 item: "<unknown>".to_string(),
233 success: false,
234 error: Some("bulk snapshot task failed".to_string()),
235 });
236 }
237 }
238 }
239
240 results
241}
242
243async fn run_snapshot_operation_for_path(
244 handles: &[CacheHandle],
245 path: &str,
246 action: BulkSnapshotAction,
247) -> Option<String> {
248 for handle in handles {
249 let outcome = match action {
250 BulkSnapshotAction::Add => handle.add_snapshot(path).await,
251 BulkSnapshotAction::Refresh => handle.refresh_snapshot(path).await,
252 BulkSnapshotAction::Remove => handle.remove_snapshot(path).await,
253 };
254
255 if let Err(err) = outcome {
256 return Some(err.to_string());
257 }
258 }
259
260 None
261}
262
263async fn invalidate_all_handler(
265 State(state): State<Arc<ControlState>>,
266 headers: HeaderMap,
267) -> Result<impl IntoResponse, StatusCode> {
268 check_auth(&state, &headers)?;
269
270 for (_, handle) in &state.handles {
271 handle.invalidate_all();
272 }
273 tracing::info!(
274 "invalidate_all triggered via control endpoint ({} server(s))",
275 state.handles.len()
276 );
277 Ok((StatusCode::OK, "Cache invalidated"))
278}
279
280async fn invalidate_handler(
284 State(state): State<Arc<ControlState>>,
285 headers: HeaderMap,
286 Json(body): Json<PatternBody>,
287) -> Result<impl IntoResponse, (StatusCode, String)> {
288 check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
289
290 let handles = state.resolve_handles(body.server.as_deref())?;
291 for handle in handles {
292 handle.invalidate(&body.pattern);
293 }
294 tracing::info!(
295 "invalidate('{}') triggered via control endpoint (server={:?})",
296 body.pattern,
297 body.server
298 );
299 Ok((StatusCode::OK, "Pattern invalidation triggered".to_string()))
300}
301
302async fn bulk_invalidate_handler(
307 State(state): State<Arc<ControlState>>,
308 headers: HeaderMap,
309 Json(body): Json<BulkPatternBody>,
310) -> Result<impl IntoResponse, (StatusCode, String)> {
311 check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
312 validate_bulk_items(&body.patterns, "patterns")?;
313
314 let handles = state.resolve_handles(body.server.as_deref())?;
315 let mut results = Vec::with_capacity(body.patterns.len());
316
317 for pattern in &body.patterns {
318 for handle in &handles {
319 handle.invalidate(pattern);
320 }
321
322 results.push(BulkOperationItemResult {
323 item: pattern.clone(),
324 success: true,
325 error: None,
326 });
327 }
328
329 tracing::info!(
330 "bulk_invalidate(count={}) triggered via control endpoint (server={:?})",
331 body.patterns.len(),
332 body.server
333 );
334
335 Ok(bulk_response("bulk_invalidate", body.server, results))
336}
337
338async fn add_snapshot_handler(
343 State(state): State<Arc<ControlState>>,
344 headers: HeaderMap,
345 Json(body): Json<PathBody>,
346) -> Result<impl IntoResponse, (StatusCode, String)> {
347 check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
348
349 let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
350 for handle in handles {
351 handle
352 .add_snapshot(&body.path)
353 .await
354 .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
355 }
356 tracing::info!(
357 "add_snapshot('{}') triggered via control endpoint (server={:?})",
358 body.path,
359 body.server
360 );
361 Ok((StatusCode::OK, "Snapshot added".to_string()))
362}
363
364async fn bulk_add_snapshot_handler(
370 State(state): State<Arc<ControlState>>,
371 headers: HeaderMap,
372 Json(body): Json<BulkPathBody>,
373) -> Result<impl IntoResponse, (StatusCode, String)> {
374 check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
375 validate_bulk_items(&body.paths, "paths")?;
376
377 let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
378 let results = run_bulk_snapshot_operation(handles, &body.paths, BulkSnapshotAction::Add).await;
379
380 tracing::info!(
381 "bulk_add_snapshot(count={}) triggered via control endpoint (server={:?})",
382 body.paths.len(),
383 body.server
384 );
385
386 Ok(bulk_response("bulk_add_snapshot", body.server, results))
387}
388
389async fn refresh_snapshot_handler(
394 State(state): State<Arc<ControlState>>,
395 headers: HeaderMap,
396 Json(body): Json<PathBody>,
397) -> Result<impl IntoResponse, (StatusCode, String)> {
398 check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
399
400 let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
401 for handle in handles {
402 handle
403 .refresh_snapshot(&body.path)
404 .await
405 .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
406 }
407 tracing::info!(
408 "refresh_snapshot('{}') triggered via control endpoint (server={:?})",
409 body.path,
410 body.server
411 );
412 Ok((StatusCode::OK, "Snapshot refreshed".to_string()))
413}
414
415async fn bulk_refresh_snapshot_handler(
421 State(state): State<Arc<ControlState>>,
422 headers: HeaderMap,
423 Json(body): Json<BulkPathBody>,
424) -> Result<impl IntoResponse, (StatusCode, String)> {
425 check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
426 validate_bulk_items(&body.paths, "paths")?;
427
428 let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
429 let results =
430 run_bulk_snapshot_operation(handles, &body.paths, BulkSnapshotAction::Refresh).await;
431
432 tracing::info!(
433 "bulk_refresh_snapshot(count={}) triggered via control endpoint (server={:?})",
434 body.paths.len(),
435 body.server
436 );
437
438 Ok(bulk_response("bulk_refresh_snapshot", body.server, results))
439}
440
441async fn remove_snapshot_handler(
446 State(state): State<Arc<ControlState>>,
447 headers: HeaderMap,
448 Json(body): Json<PathBody>,
449) -> Result<impl IntoResponse, (StatusCode, String)> {
450 check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
451
452 let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
453 for handle in handles {
454 handle
455 .remove_snapshot(&body.path)
456 .await
457 .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
458 }
459 tracing::info!(
460 "remove_snapshot('{}') triggered via control endpoint (server={:?})",
461 body.path,
462 body.server
463 );
464 Ok((StatusCode::OK, "Snapshot removed".to_string()))
465}
466
467async fn bulk_remove_snapshot_handler(
473 State(state): State<Arc<ControlState>>,
474 headers: HeaderMap,
475 Json(body): Json<BulkPathBody>,
476) -> Result<impl IntoResponse, (StatusCode, String)> {
477 check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
478 validate_bulk_items(&body.paths, "paths")?;
479
480 let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
481 let results =
482 run_bulk_snapshot_operation(handles, &body.paths, BulkSnapshotAction::Remove).await;
483
484 tracing::info!(
485 "bulk_remove_snapshot(count={}) triggered via control endpoint (server={:?})",
486 body.paths.len(),
487 body.server
488 );
489
490 Ok(bulk_response("bulk_remove_snapshot", body.server, results))
491}
492
493async fn refresh_all_snapshots_handler(
498 State(state): State<Arc<ControlState>>,
499 headers: HeaderMap,
500 body: Option<Json<serde_json::Value>>,
501) -> Result<impl IntoResponse, (StatusCode, String)> {
502 check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
503
504 let server_filter = body
505 .as_ref()
506 .and_then(|Json(v)| v.get("server"))
507 .and_then(|v| v.as_str())
508 .map(|s| s.to_string());
509
510 let handles = state.resolve_snapshot_handles(server_filter.as_deref())?;
511 for handle in handles {
512 handle
513 .refresh_all_snapshots()
514 .await
515 .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
516 }
517 tracing::info!(
518 "refresh_all_snapshots triggered via control endpoint (server={:?})",
519 server_filter
520 );
521 Ok((StatusCode::OK, "All snapshots refreshed".to_string()))
522}
523
524pub fn create_control_router(
528 handles: Vec<(String, CacheHandle)>,
529 auth_token: Option<String>,
530) -> Router {
531 let state = Arc::new(ControlState::new(handles, auth_token));
532
533 Router::new()
534 .route("/invalidate_all", post(invalidate_all_handler))
535 .route("/invalidate", post(invalidate_handler))
536 .route("/bulk_invalidate", post(bulk_invalidate_handler))
537 .route("/add_snapshot", post(add_snapshot_handler))
538 .route("/bulk_add_snapshot", post(bulk_add_snapshot_handler))
539 .route("/refresh_snapshot", post(refresh_snapshot_handler))
540 .route(
541 "/bulk_refresh_snapshot",
542 post(bulk_refresh_snapshot_handler),
543 )
544 .route("/remove_snapshot", post(remove_snapshot_handler))
545 .route("/bulk_remove_snapshot", post(bulk_remove_snapshot_handler))
546 .route(
547 "/refresh_all_snapshots",
548 post(refresh_all_snapshots_handler),
549 )
550 .with_state(state)
551}