// phantom_frame/control.rs

1use crate::cache::CacheHandle;
2use axum::{
3    extract::State,
4    http::{header, HeaderMap, StatusCode},
5    response::IntoResponse,
6    routing::post,
7    Json, Router,
8};
9use serde::Deserialize;
10use std::sync::Arc;
11
12#[derive(Clone)]
13pub struct ControlState {
14    /// Named server handles — (server_name, handle) pairs.
15    handles: Vec<(String, CacheHandle)>,
16    auth_token: Option<String>,
17}
18
19impl ControlState {
20    pub fn new(handles: Vec<(String, CacheHandle)>, auth_token: Option<String>) -> Self {
21        Self { handles, auth_token }
22    }
23
24    /// Return handles matching `server` (if provided) or all handles.
25    /// Returns `Err` when a name was given but no server matched.
26    fn resolve_handles(
27        &self,
28        server: Option<&str>,
29    ) -> Result<Vec<&CacheHandle>, (StatusCode, String)> {
30        match server {
31            None => Ok(self.handles.iter().map(|(_, h)| h).collect()),
32            Some(name) => {
33                let matched: Vec<&CacheHandle> = self
34                    .handles
35                    .iter()
36                    .filter(|(n, _)| n == name)
37                    .map(|(_, h)| h)
38                    .collect();
39                if matched.is_empty() {
40                    Err((
41                        StatusCode::NOT_FOUND,
42                        format!("No server named '{}' found", name),
43                    ))
44                } else {
45                    Ok(matched)
46                }
47            }
48        }
49    }
50
51    /// Like `resolve_handles`, but for snapshot operations:
52    /// - When a specific server is named, return it even if it's in Dynamic mode
53    ///   (the operation will then fail with BAD_REQUEST from the handle itself).
54    /// - When broadcasting (no server specified), silently skip Dynamic-mode
55    ///   servers that don't support snapshots.
56    fn resolve_snapshot_handles(
57        &self,
58        server: Option<&str>,
59    ) -> Result<Vec<&CacheHandle>, (StatusCode, String)> {
60        match server {
61            None => {
62                let handles: Vec<&CacheHandle> = self
63                    .handles
64                    .iter()
65                    .filter(|(_, h)| h.is_snapshot_capable())
66                    .map(|(_, h)| h)
67                    .collect();
68                if handles.is_empty() {
69                    return Err((
70                        StatusCode::BAD_REQUEST,
71                        "No servers running in PreGenerate mode — snapshot operations are not available".to_string(),
72                    ));
73                }
74                Ok(handles)
75            }
76            Some(name) => {
77                let matched: Vec<&CacheHandle> = self
78                    .handles
79                    .iter()
80                    .filter(|(n, _)| n == name)
81                    .map(|(_, h)| h)
82                    .collect();
83                if matched.is_empty() {
84                    Err((
85                        StatusCode::NOT_FOUND,
86                        format!("No server named '{}' found", name),
87                    ))
88                } else {
89                    Ok(matched)
90                }
91            }
92        }
93    }
94}
95
96#[derive(Deserialize)]
97struct PatternBody {
98    pattern: String,
99    /// Optional: only invalidate this named server's cache.
100    server: Option<String>,
101}
102
103#[derive(Deserialize)]
104struct PathBody {
105    path: String,
106    /// Optional: only operate on this named server.
107    /// When omitted, the operation is broadcast to all servers.
108    server: Option<String>,
109}
110
111/// Returns `Err(UNAUTHORIZED)` when the request lacks a valid Bearer token.
112fn check_auth(state: &ControlState, headers: &HeaderMap) -> Result<(), StatusCode> {
113    if let Some(required_token) = &state.auth_token {
114        let auth_header = headers
115            .get(header::AUTHORIZATION)
116            .and_then(|h| h.to_str().ok());
117        let expected = format!("Bearer {}", required_token);
118        if auth_header != Some(expected.as_str()) {
119            tracing::warn!("Unauthorized control endpoint attempt");
120            return Err(StatusCode::UNAUTHORIZED);
121        }
122    }
123    Ok(())
124}
125
126/// POST /invalidate_all — invalidate every cached entry across all servers.
127async fn invalidate_all_handler(
128    State(state): State<Arc<ControlState>>,
129    headers: HeaderMap,
130) -> Result<impl IntoResponse, StatusCode> {
131    check_auth(&state, &headers)?;
132
133    for (_, handle) in &state.handles {
134        handle.invalidate_all();
135    }
136    tracing::info!(
137        "invalidate_all triggered via control endpoint ({} server(s))",
138        state.handles.len()
139    );
140    Ok((StatusCode::OK, "Cache invalidated"))
141}
142
143/// POST /invalidate — invalidate entries matching a wildcard pattern.
144///
145/// Body: `{ "pattern": "/api/*" }` or `{ "pattern": "/api/*", "server": "frontend" }`
146async fn invalidate_handler(
147    State(state): State<Arc<ControlState>>,
148    headers: HeaderMap,
149    Json(body): Json<PatternBody>,
150) -> Result<impl IntoResponse, (StatusCode, String)> {
151    check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
152
153    let handles = state.resolve_handles(body.server.as_deref())?;
154    for handle in handles {
155        handle.invalidate(&body.pattern);
156    }
157    tracing::info!(
158        "invalidate('{}') triggered via control endpoint (server={:?})",
159        body.pattern,
160        body.server
161    );
162    Ok((StatusCode::OK, "Pattern invalidation triggered".to_string()))
163}
164
165/// POST /add_snapshot — fetch a path from upstream, cache it, and track it.
166///
167/// Only available when the proxy is running in `PreGenerate` mode.
168/// Body: `{ "path": "/about" }` or `{ "path": "/about", "server": "frontend" }`
169async fn add_snapshot_handler(
170    State(state): State<Arc<ControlState>>,
171    headers: HeaderMap,
172    Json(body): Json<PathBody>,
173) -> Result<impl IntoResponse, (StatusCode, String)> {
174    check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
175
176    let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
177    for handle in handles {
178        handle
179            .add_snapshot(&body.path)
180            .await
181            .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
182    }
183    tracing::info!(
184        "add_snapshot('{}') triggered via control endpoint (server={:?})",
185        body.path, body.server
186    );
187    Ok((StatusCode::OK, "Snapshot added".to_string()))
188}
189
190/// POST /refresh_snapshot — re-fetch a cached snapshot path from upstream.
191///
192/// Only available when the proxy is running in `PreGenerate` mode.
193/// Body: `{ "path": "/about" }` or `{ "path": "/about", "server": "frontend" }`
194async fn refresh_snapshot_handler(
195    State(state): State<Arc<ControlState>>,
196    headers: HeaderMap,
197    Json(body): Json<PathBody>,
198) -> Result<impl IntoResponse, (StatusCode, String)> {
199    check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
200
201    let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
202    for handle in handles {
203        handle
204            .refresh_snapshot(&body.path)
205            .await
206            .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
207    }
208    tracing::info!(
209        "refresh_snapshot('{}') triggered via control endpoint (server={:?})",
210        body.path, body.server
211    );
212    Ok((StatusCode::OK, "Snapshot refreshed".to_string()))
213}
214
215/// POST /remove_snapshot — remove a path from the cache and snapshot list.
216///
217/// Only available when the proxy is running in `PreGenerate` mode.
218/// Body: `{ "path": "/about" }` or `{ "path": "/about", "server": "frontend" }`
219async fn remove_snapshot_handler(
220    State(state): State<Arc<ControlState>>,
221    headers: HeaderMap,
222    Json(body): Json<PathBody>,
223) -> Result<impl IntoResponse, (StatusCode, String)> {
224    check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
225
226    let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
227    for handle in handles {
228        handle
229            .remove_snapshot(&body.path)
230            .await
231            .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
232    }
233    tracing::info!(
234        "remove_snapshot('{}') triggered via control endpoint (server={:?})",
235        body.path, body.server
236    );
237    Ok((StatusCode::OK, "Snapshot removed".to_string()))
238}
239
240/// POST /refresh_all_snapshots — re-fetch every tracked snapshot from upstream.
241///
242/// Only available when the proxy is running in `PreGenerate` mode.
243/// Optional body: `{ "server": "frontend" }` to target a specific server.
244async fn refresh_all_snapshots_handler(
245    State(state): State<Arc<ControlState>>,
246    headers: HeaderMap,
247    body: Option<Json<serde_json::Value>>,
248) -> Result<impl IntoResponse, (StatusCode, String)> {
249    check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
250
251    let server_filter = body
252        .as_ref()
253        .and_then(|Json(v)| v.get("server"))
254        .and_then(|v| v.as_str())
255        .map(|s| s.to_string());
256
257    let handles = state.resolve_snapshot_handles(server_filter.as_deref())?;
258    for handle in handles {
259        handle
260            .refresh_all_snapshots()
261            .await
262            .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
263    }
264    tracing::info!(
265        "refresh_all_snapshots triggered via control endpoint (server={:?})",
266        server_filter
267    );
268    Ok((StatusCode::OK, "All snapshots refreshed".to_string()))
269}
270
271/// Create the control server router.
272///
273/// `handles` contains one `(server_name, CacheHandle)` pair per named proxy server.
274pub fn create_control_router(
275    handles: Vec<(String, CacheHandle)>,
276    auth_token: Option<String>,
277) -> Router {
278    let state = Arc::new(ControlState::new(handles, auth_token));
279
280    Router::new()
281        .route("/invalidate_all", post(invalidate_all_handler))
282        .route("/invalidate", post(invalidate_handler))
283        .route("/add_snapshot", post(add_snapshot_handler))
284        .route("/refresh_snapshot", post(refresh_snapshot_handler))
285        .route("/remove_snapshot", post(remove_snapshot_handler))
286        .route("/refresh_all_snapshots", post(refresh_all_snapshots_handler))
287        .with_state(state)
288}