1use crate::cache::CacheHandle;
2use axum::{
3 body::Body,
4 extract::State,
5 http::{header, Request, StatusCode},
6 response::IntoResponse,
7 routing::post,
8 Router,
9};
10use std::sync::Arc;
11
12#[derive(Clone)]
13pub struct ControlState {
14 handles: Vec<CacheHandle>,
15 auth_token: Option<String>,
16}
17
18impl ControlState {
19 pub fn new(handles: Vec<CacheHandle>, auth_token: Option<String>) -> Self {
20 Self { handles, auth_token }
21 }
22}
23
24async fn refresh_cache_handler(
26 State(state): State<Arc<ControlState>>,
27 req: Request<Body>,
28) -> Result<impl IntoResponse, StatusCode> {
29 if let Some(required_token) = &state.auth_token {
31 let auth_header = req
32 .headers()
33 .get(header::AUTHORIZATION)
34 .and_then(|h| h.to_str().ok());
35
36 let expected = format!("Bearer {}", required_token);
37
38 if auth_header != Some(expected.as_str()) {
39 tracing::warn!("Unauthorized refresh-cache attempt");
40 return Err(StatusCode::UNAUTHORIZED);
41 }
42 }
43
44 for handle in &state.handles {
46 handle.invalidate_all();
47 }
48 tracing::info!(
49 "Cache invalidation triggered via control endpoint ({} server(s))",
50 state.handles.len()
51 );
52
53 Ok((StatusCode::OK, "Cache refresh triggered"))
54}
55
56pub fn create_control_router(handles: Vec<CacheHandle>, auth_token: Option<String>) -> Router {
61 let state = Arc::new(ControlState::new(handles, auth_token));
62
63 Router::new()
64 .route("/refresh-cache", post(refresh_cache_handler))
65 .with_state(state)
66}