Skip to main content

phantom_frame/
control.rs

1use crate::cache::CacheHandle;
2use axum::{
3    body::Body,
4    extract::State,
5    http::{header, Request, StatusCode},
6    response::IntoResponse,
7    routing::post,
8    Router,
9};
10use std::sync::Arc;
11
12#[derive(Clone)]
13pub struct ControlState {
14    handles: Vec<CacheHandle>,
15    auth_token: Option<String>,
16}
17
18impl ControlState {
19    pub fn new(handles: Vec<CacheHandle>, auth_token: Option<String>) -> Self {
20        Self { handles, auth_token }
21    }
22}
23
24/// Handler for POST /refresh-cache endpoint
25async fn refresh_cache_handler(
26    State(state): State<Arc<ControlState>>,
27    req: Request<Body>,
28) -> Result<impl IntoResponse, StatusCode> {
29    // Check authorization if auth_token is set
30    if let Some(required_token) = &state.auth_token {
31        let auth_header = req
32            .headers()
33            .get(header::AUTHORIZATION)
34            .and_then(|h| h.to_str().ok());
35
36        let expected = format!("Bearer {}", required_token);
37
38        if auth_header != Some(expected.as_str()) {
39            tracing::warn!("Unauthorized refresh-cache attempt");
40            return Err(StatusCode::UNAUTHORIZED);
41        }
42    }
43
44    // Trigger cache invalidation on all registered server caches
45    for handle in &state.handles {
46        handle.invalidate_all();
47    }
48    tracing::info!(
49        "Cache invalidation triggered via control endpoint ({} server(s))",
50        state.handles.len()
51    );
52
53    Ok((StatusCode::OK, "Cache refresh triggered"))
54}
55
56/// Create the control server router.
57///
58/// `handles` contains one [`CacheHandle`] per named proxy server.
59/// A single `/refresh-cache` call invalidates all of them.
60pub fn create_control_router(handles: Vec<CacheHandle>, auth_token: Option<String>) -> Router {
61    let state = Arc::new(ControlState::new(handles, auth_token));
62
63    Router::new()
64        .route("/refresh-cache", post(refresh_cache_handler))
65        .with_state(state)
66}