Skip to main content

phantom_frame/
control.rs

1use crate::cache::CacheHandle;
2use axum::{
3    extract::State,
4    http::{header, HeaderMap, StatusCode},
5    response::IntoResponse,
6    routing::post,
7    Json, Router,
8};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tokio::task::JoinHandle;
12
13#[derive(Clone)]
14pub struct ControlState {
15    /// Named server handles — (server_name, handle) pairs.
16    handles: Vec<(String, CacheHandle)>,
17    auth_token: Option<String>,
18}
19
20impl ControlState {
21    pub fn new(handles: Vec<(String, CacheHandle)>, auth_token: Option<String>) -> Self {
22        Self {
23            handles,
24            auth_token,
25        }
26    }
27
28    /// Return handles matching `server` (if provided) or all handles.
29    /// Returns `Err` when a name was given but no server matched.
30    fn resolve_handles(
31        &self,
32        server: Option<&str>,
33    ) -> Result<Vec<&CacheHandle>, (StatusCode, String)> {
34        match server {
35            None => Ok(self.handles.iter().map(|(_, h)| h).collect()),
36            Some(name) => {
37                let matched: Vec<&CacheHandle> = self
38                    .handles
39                    .iter()
40                    .filter(|(n, _)| n == name)
41                    .map(|(_, h)| h)
42                    .collect();
43                if matched.is_empty() {
44                    Err((
45                        StatusCode::NOT_FOUND,
46                        format!("No server named '{}' found", name),
47                    ))
48                } else {
49                    Ok(matched)
50                }
51            }
52        }
53    }
54
55    /// Like `resolve_handles`, but for snapshot operations:
56    /// - When a specific server is named, return it even if it's in Dynamic mode
57    ///   (the operation will then fail with BAD_REQUEST from the handle itself).
58    /// - When broadcasting (no server specified), silently skip Dynamic-mode
59    ///   servers that don't support snapshots.
60    fn resolve_snapshot_handles(
61        &self,
62        server: Option<&str>,
63    ) -> Result<Vec<&CacheHandle>, (StatusCode, String)> {
64        match server {
65            None => {
66                let handles: Vec<&CacheHandle> = self
67                    .handles
68                    .iter()
69                    .filter(|(_, h)| h.is_snapshot_capable())
70                    .map(|(_, h)| h)
71                    .collect();
72                if handles.is_empty() {
73                    return Err((
74                        StatusCode::BAD_REQUEST,
75                        "No servers running in PreGenerate mode — snapshot operations are not available".to_string(),
76                    ));
77                }
78                Ok(handles)
79            }
80            Some(name) => {
81                let matched: Vec<&CacheHandle> = self
82                    .handles
83                    .iter()
84                    .filter(|(n, _)| n == name)
85                    .map(|(_, h)| h)
86                    .collect();
87                if matched.is_empty() {
88                    Err((
89                        StatusCode::NOT_FOUND,
90                        format!("No server named '{}' found", name),
91                    ))
92                } else {
93                    Ok(matched)
94                }
95            }
96        }
97    }
98}
99
100#[derive(Deserialize)]
101struct PatternBody {
102    pattern: String,
103    /// Optional: only invalidate this named server's cache.
104    server: Option<String>,
105}
106
107#[derive(Deserialize)]
108struct PathBody {
109    path: String,
110    /// Optional: only operate on this named server.
111    /// When omitted, the operation is broadcast to all servers.
112    server: Option<String>,
113}
114
115#[derive(Deserialize)]
116struct BulkPatternBody {
117    patterns: Vec<String>,
118    /// Optional: only invalidate this named server's cache.
119    server: Option<String>,
120}
121
122#[derive(Deserialize)]
123struct BulkPathBody {
124    paths: Vec<String>,
125    /// Optional: only operate on this named server.
126    /// When omitted, the operation is broadcast to all servers.
127    server: Option<String>,
128}
129
130#[derive(Serialize)]
131struct BulkOperationItemResult {
132    item: String,
133    success: bool,
134    error: Option<String>,
135}
136
137#[derive(Serialize)]
138struct BulkOperationResponse {
139    operation: &'static str,
140    server: Option<String>,
141    requested: usize,
142    succeeded: usize,
143    failed: usize,
144    results: Vec<BulkOperationItemResult>,
145}
146
/// Which snapshot operation a bulk request applies to each path.
#[derive(Clone, Copy)]
enum BulkSnapshotAction {
    /// Call `add_snapshot` on each handle.
    Add,
    /// Call `refresh_snapshot` on each handle.
    Refresh,
    /// Call `remove_snapshot` on each handle.
    Remove,
}
153
154/// Returns `Err(UNAUTHORIZED)` when the request lacks a valid Bearer token.
155fn check_auth(state: &ControlState, headers: &HeaderMap) -> Result<(), StatusCode> {
156    if let Some(required_token) = &state.auth_token {
157        let auth_header = headers
158            .get(header::AUTHORIZATION)
159            .and_then(|h| h.to_str().ok());
160        let expected = format!("Bearer {}", required_token);
161        if auth_header != Some(expected.as_str()) {
162            tracing::warn!("Unauthorized control endpoint attempt");
163            return Err(StatusCode::UNAUTHORIZED);
164        }
165    }
166    Ok(())
167}
168
169fn validate_bulk_items<T>(items: &[T], field_name: &str) -> Result<(), (StatusCode, String)> {
170    if items.is_empty() {
171        return Err((
172            StatusCode::BAD_REQUEST,
173            format!("'{}' must contain at least one item", field_name),
174        ));
175    }
176    Ok(())
177}
178
179fn bulk_response(
180    operation: &'static str,
181    server: Option<String>,
182    results: Vec<BulkOperationItemResult>,
183) -> (StatusCode, Json<BulkOperationResponse>) {
184    let requested = results.len();
185    let succeeded = results.iter().filter(|result| result.success).count();
186    let failed = requested - succeeded;
187
188    (
189        StatusCode::OK,
190        Json(BulkOperationResponse {
191            operation,
192            server,
193            requested,
194            succeeded,
195            failed,
196            results,
197        }),
198    )
199}
200
201async fn run_bulk_snapshot_operation(
202    handles: Vec<&CacheHandle>,
203    paths: &[String],
204    action: BulkSnapshotAction,
205) -> Vec<BulkOperationItemResult> {
206    let handles: Arc<Vec<CacheHandle>> = Arc::new(handles.into_iter().cloned().collect());
207    let tasks: Vec<JoinHandle<BulkOperationItemResult>> = paths
208        .iter()
209        .cloned()
210        .map(|path| {
211            let handles = Arc::clone(&handles);
212            tokio::spawn(async move {
213                let error = run_snapshot_operation_for_path(handles.as_ref(), &path, action).await;
214
215                BulkOperationItemResult {
216                    item: path,
217                    success: error.is_none(),
218                    error,
219                }
220            })
221        })
222        .collect();
223
224    let mut results = Vec::with_capacity(tasks.len());
225
226    for task in tasks {
227        match task.await {
228            Ok(result) => results.push(result),
229            Err(err) => {
230                tracing::error!("bulk snapshot task failed: {}", err);
231                results.push(BulkOperationItemResult {
232                    item: "<unknown>".to_string(),
233                    success: false,
234                    error: Some("bulk snapshot task failed".to_string()),
235                });
236            }
237        }
238    }
239
240    results
241}
242
243async fn run_snapshot_operation_for_path(
244    handles: &[CacheHandle],
245    path: &str,
246    action: BulkSnapshotAction,
247) -> Option<String> {
248    for handle in handles {
249        let outcome = match action {
250            BulkSnapshotAction::Add => handle.add_snapshot(path).await,
251            BulkSnapshotAction::Refresh => handle.refresh_snapshot(path).await,
252            BulkSnapshotAction::Remove => handle.remove_snapshot(path).await,
253        };
254
255        if let Err(err) = outcome {
256            return Some(err.to_string());
257        }
258    }
259
260    None
261}
262
263/// POST /invalidate_all — invalidate every cached entry across all servers.
264async fn invalidate_all_handler(
265    State(state): State<Arc<ControlState>>,
266    headers: HeaderMap,
267) -> Result<impl IntoResponse, StatusCode> {
268    check_auth(&state, &headers)?;
269
270    for (_, handle) in &state.handles {
271        handle.invalidate_all();
272    }
273    tracing::info!(
274        "invalidate_all triggered via control endpoint ({} server(s))",
275        state.handles.len()
276    );
277    Ok((StatusCode::OK, "Cache invalidated"))
278}
279
280/// POST /invalidate — invalidate entries matching a wildcard pattern.
281///
282/// Body: `{ "pattern": "/api/*" }` or `{ "pattern": "/api/*", "server": "frontend" }`
283async fn invalidate_handler(
284    State(state): State<Arc<ControlState>>,
285    headers: HeaderMap,
286    Json(body): Json<PatternBody>,
287) -> Result<impl IntoResponse, (StatusCode, String)> {
288    check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
289
290    let handles = state.resolve_handles(body.server.as_deref())?;
291    for handle in handles {
292        handle.invalidate(&body.pattern);
293    }
294    tracing::info!(
295        "invalidate('{}') triggered via control endpoint (server={:?})",
296        body.pattern,
297        body.server
298    );
299    Ok((StatusCode::OK, "Pattern invalidation triggered".to_string()))
300}
301
302/// POST /bulk_invalidate — invalidate entries matching multiple wildcard patterns.
303///
304/// Body: `{ "patterns": ["/api/*", "/blog/*"], "server": "frontend" }`
305/// or `{ "patterns": ["/api/*", "/blog/*"] }`
306async fn bulk_invalidate_handler(
307    State(state): State<Arc<ControlState>>,
308    headers: HeaderMap,
309    Json(body): Json<BulkPatternBody>,
310) -> Result<impl IntoResponse, (StatusCode, String)> {
311    check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
312    validate_bulk_items(&body.patterns, "patterns")?;
313
314    let handles = state.resolve_handles(body.server.as_deref())?;
315    let mut results = Vec::with_capacity(body.patterns.len());
316
317    for pattern in &body.patterns {
318        for handle in &handles {
319            handle.invalidate(pattern);
320        }
321
322        results.push(BulkOperationItemResult {
323            item: pattern.clone(),
324            success: true,
325            error: None,
326        });
327    }
328
329    tracing::info!(
330        "bulk_invalidate(count={}) triggered via control endpoint (server={:?})",
331        body.patterns.len(),
332        body.server
333    );
334
335    Ok(bulk_response("bulk_invalidate", body.server, results))
336}
337
338/// POST /add_snapshot — fetch a path from upstream, cache it, and track it.
339///
340/// Only available when the proxy is running in `PreGenerate` mode.
341/// Body: `{ "path": "/about" }` or `{ "path": "/about", "server": "frontend" }`
342async fn add_snapshot_handler(
343    State(state): State<Arc<ControlState>>,
344    headers: HeaderMap,
345    Json(body): Json<PathBody>,
346) -> Result<impl IntoResponse, (StatusCode, String)> {
347    check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
348
349    let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
350    for handle in handles {
351        handle
352            .add_snapshot(&body.path)
353            .await
354            .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
355    }
356    tracing::info!(
357        "add_snapshot('{}') triggered via control endpoint (server={:?})",
358        body.path,
359        body.server
360    );
361    Ok((StatusCode::OK, "Snapshot added".to_string()))
362}
363
364/// POST /bulk_add_snapshot — fetch multiple paths from upstream, cache them, and track them.
365///
366/// Only available when the proxy is running in `PreGenerate` mode.
367/// Body: `{ "paths": ["/about", "/pricing"], "server": "frontend" }`
368/// or `{ "paths": ["/about", "/pricing"] }`
369async fn bulk_add_snapshot_handler(
370    State(state): State<Arc<ControlState>>,
371    headers: HeaderMap,
372    Json(body): Json<BulkPathBody>,
373) -> Result<impl IntoResponse, (StatusCode, String)> {
374    check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
375    validate_bulk_items(&body.paths, "paths")?;
376
377    let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
378    let results = run_bulk_snapshot_operation(handles, &body.paths, BulkSnapshotAction::Add).await;
379
380    tracing::info!(
381        "bulk_add_snapshot(count={}) triggered via control endpoint (server={:?})",
382        body.paths.len(),
383        body.server
384    );
385
386    Ok(bulk_response("bulk_add_snapshot", body.server, results))
387}
388
389/// POST /refresh_snapshot — re-fetch a cached snapshot path from upstream.
390///
391/// Only available when the proxy is running in `PreGenerate` mode.
392/// Body: `{ "path": "/about" }` or `{ "path": "/about", "server": "frontend" }`
393async fn refresh_snapshot_handler(
394    State(state): State<Arc<ControlState>>,
395    headers: HeaderMap,
396    Json(body): Json<PathBody>,
397) -> Result<impl IntoResponse, (StatusCode, String)> {
398    check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
399
400    let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
401    for handle in handles {
402        handle
403            .refresh_snapshot(&body.path)
404            .await
405            .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
406    }
407    tracing::info!(
408        "refresh_snapshot('{}') triggered via control endpoint (server={:?})",
409        body.path,
410        body.server
411    );
412    Ok((StatusCode::OK, "Snapshot refreshed".to_string()))
413}
414
415/// POST /bulk_refresh_snapshot — re-fetch multiple cached snapshot paths from upstream.
416///
417/// Only available when the proxy is running in `PreGenerate` mode.
418/// Body: `{ "paths": ["/about", "/pricing"], "server": "frontend" }`
419/// or `{ "paths": ["/about", "/pricing"] }`
420async fn bulk_refresh_snapshot_handler(
421    State(state): State<Arc<ControlState>>,
422    headers: HeaderMap,
423    Json(body): Json<BulkPathBody>,
424) -> Result<impl IntoResponse, (StatusCode, String)> {
425    check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
426    validate_bulk_items(&body.paths, "paths")?;
427
428    let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
429    let results =
430        run_bulk_snapshot_operation(handles, &body.paths, BulkSnapshotAction::Refresh).await;
431
432    tracing::info!(
433        "bulk_refresh_snapshot(count={}) triggered via control endpoint (server={:?})",
434        body.paths.len(),
435        body.server
436    );
437
438    Ok(bulk_response("bulk_refresh_snapshot", body.server, results))
439}
440
441/// POST /remove_snapshot — remove a path from the cache and snapshot list.
442///
443/// Only available when the proxy is running in `PreGenerate` mode.
444/// Body: `{ "path": "/about" }` or `{ "path": "/about", "server": "frontend" }`
445async fn remove_snapshot_handler(
446    State(state): State<Arc<ControlState>>,
447    headers: HeaderMap,
448    Json(body): Json<PathBody>,
449) -> Result<impl IntoResponse, (StatusCode, String)> {
450    check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
451
452    let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
453    for handle in handles {
454        handle
455            .remove_snapshot(&body.path)
456            .await
457            .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
458    }
459    tracing::info!(
460        "remove_snapshot('{}') triggered via control endpoint (server={:?})",
461        body.path,
462        body.server
463    );
464    Ok((StatusCode::OK, "Snapshot removed".to_string()))
465}
466
467/// POST /bulk_remove_snapshot — remove multiple paths from the cache and snapshot list.
468///
469/// Only available when the proxy is running in `PreGenerate` mode.
470/// Body: `{ "paths": ["/about", "/pricing"], "server": "frontend" }`
471/// or `{ "paths": ["/about", "/pricing"] }`
472async fn bulk_remove_snapshot_handler(
473    State(state): State<Arc<ControlState>>,
474    headers: HeaderMap,
475    Json(body): Json<BulkPathBody>,
476) -> Result<impl IntoResponse, (StatusCode, String)> {
477    check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
478    validate_bulk_items(&body.paths, "paths")?;
479
480    let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
481    let results =
482        run_bulk_snapshot_operation(handles, &body.paths, BulkSnapshotAction::Remove).await;
483
484    tracing::info!(
485        "bulk_remove_snapshot(count={}) triggered via control endpoint (server={:?})",
486        body.paths.len(),
487        body.server
488    );
489
490    Ok(bulk_response("bulk_remove_snapshot", body.server, results))
491}
492
493/// POST /refresh_all_snapshots — re-fetch every tracked snapshot from upstream.
494///
495/// Only available when the proxy is running in `PreGenerate` mode.
496/// Optional body: `{ "server": "frontend" }` to target a specific server.
497async fn refresh_all_snapshots_handler(
498    State(state): State<Arc<ControlState>>,
499    headers: HeaderMap,
500    body: Option<Json<serde_json::Value>>,
501) -> Result<impl IntoResponse, (StatusCode, String)> {
502    check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
503
504    let server_filter = body
505        .as_ref()
506        .and_then(|Json(v)| v.get("server"))
507        .and_then(|v| v.as_str())
508        .map(|s| s.to_string());
509
510    let handles = state.resolve_snapshot_handles(server_filter.as_deref())?;
511    for handle in handles {
512        handle
513            .refresh_all_snapshots()
514            .await
515            .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
516    }
517    tracing::info!(
518        "refresh_all_snapshots triggered via control endpoint (server={:?})",
519        server_filter
520    );
521    Ok((StatusCode::OK, "All snapshots refreshed".to_string()))
522}
523
524/// Create the control server router.
525///
526/// `handles` contains one `(server_name, CacheHandle)` pair per named proxy server.
527pub fn create_control_router(
528    handles: Vec<(String, CacheHandle)>,
529    auth_token: Option<String>,
530) -> Router {
531    let state = Arc::new(ControlState::new(handles, auth_token));
532
533    Router::new()
534        .route("/invalidate_all", post(invalidate_all_handler))
535        .route("/invalidate", post(invalidate_handler))
536        .route("/bulk_invalidate", post(bulk_invalidate_handler))
537        .route("/add_snapshot", post(add_snapshot_handler))
538        .route("/bulk_add_snapshot", post(bulk_add_snapshot_handler))
539        .route("/refresh_snapshot", post(refresh_snapshot_handler))
540        .route(
541            "/bulk_refresh_snapshot",
542            post(bulk_refresh_snapshot_handler),
543        )
544        .route("/remove_snapshot", post(remove_snapshot_handler))
545        .route("/bulk_remove_snapshot", post(bulk_remove_snapshot_handler))
546        .route(
547            "/refresh_all_snapshots",
548            post(refresh_all_snapshots_handler),
549        )
550        .with_state(state)
551}