use crate::cache::CacheHandle;
use axum::{
extract::State,
http::{header, HeaderMap, StatusCode},
response::IntoResponse,
routing::post,
Json, Router,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::task::JoinHandle;
/// Shared state handed to every control endpoint handler.
#[derive(Clone)]
pub struct ControlState {
    /// `(server name, cache handle)` pairs. Name lookups return every entry whose
    /// name matches exactly, so the same name may appear more than once.
    handles: Vec<(String, CacheHandle)>,
    /// When `Some`, requests must carry `Authorization: Bearer <token>`;
    /// when `None`, the control endpoints are unauthenticated.
    auth_token: Option<String>,
}
impl ControlState {
    /// Creates control state from named cache handles and an optional bearer token.
    pub fn new(handles: Vec<(String, CacheHandle)>, auth_token: Option<String>) -> Self {
        Self {
            handles,
            auth_token,
        }
    }

    /// Selects the cache handles an operation should target.
    ///
    /// `None` selects every configured handle (possibly none). `Some(name)` selects
    /// every handle registered under exactly that name.
    ///
    /// # Errors
    /// `404 NOT_FOUND` when `name` matches no configured server.
    fn resolve_handles(
        &self,
        server: Option<&str>,
    ) -> Result<Vec<&CacheHandle>, (StatusCode, String)> {
        let name = match server {
            Some(name) => name,
            None => return Ok(self.handles.iter().map(|(_, h)| h).collect()),
        };
        let matched: Vec<&CacheHandle> = self
            .handles
            .iter()
            .filter(|(n, _)| n == name)
            .map(|(_, h)| h)
            .collect();
        if matched.is_empty() {
            return Err((
                StatusCode::NOT_FOUND,
                format!("No server named '{}' found", name),
            ));
        }
        Ok(matched)
    }

    /// Like [`Self::resolve_handles`], but keeps only snapshot-capable handles.
    ///
    /// The capability filter is applied for both the "all servers" and the
    /// "named server" cases, so a snapshot operation never reaches a handle that
    /// reports `is_snapshot_capable() == false`.
    ///
    /// # Errors
    /// - `404 NOT_FOUND` when a named server does not exist.
    /// - `400 BAD_REQUEST` when no selected server is snapshot-capable.
    fn resolve_snapshot_handles(
        &self,
        server: Option<&str>,
    ) -> Result<Vec<&CacheHandle>, (StatusCode, String)> {
        let capable: Vec<&CacheHandle> = self
            .resolve_handles(server)?
            .into_iter()
            .filter(|h| h.is_snapshot_capable())
            .collect();
        if !capable.is_empty() {
            return Ok(capable);
        }
        let message = match server {
            None => "No servers running in PreGenerate mode — snapshot operations are not available".to_string(),
            // Previously a named non-capable server was returned as-is and the
            // snapshot call was attempted anyway; reject it up front instead.
            Some(name) => format!(
                "Server '{}' is not running in PreGenerate mode — snapshot operations are not available",
                name
            ),
        };
        Err((StatusCode::BAD_REQUEST, message))
    }
}
/// JSON body accepted by `POST /invalidate`.
#[derive(Deserialize)]
struct PatternBody {
    /// Passed verbatim to `CacheHandle::invalidate` on each targeted server.
    pattern: String,
    /// Optional server name; when absent, every configured server is targeted.
    server: Option<String>,
}
/// JSON body accepted by the single-path snapshot endpoints
/// (`/add_snapshot`, `/refresh_snapshot`, `/remove_snapshot`).
#[derive(Deserialize)]
struct PathBody {
    /// Path handed to the snapshot operation on each targeted server.
    path: String,
    /// Optional server name; when absent, every snapshot-capable server is targeted.
    server: Option<String>,
}
/// JSON body accepted by `POST /bulk_invalidate`.
#[derive(Deserialize)]
struct BulkPatternBody {
    /// Patterns to invalidate; the handler rejects an empty list with 400.
    patterns: Vec<String>,
    /// Optional server name; when absent, every configured server is targeted.
    server: Option<String>,
}
/// JSON body accepted by the bulk snapshot endpoints
/// (`/bulk_add_snapshot`, `/bulk_refresh_snapshot`, `/bulk_remove_snapshot`).
#[derive(Deserialize)]
struct BulkPathBody {
    /// Paths to operate on; the handlers reject an empty list with 400.
    paths: Vec<String>,
    /// Optional server name; when absent, every snapshot-capable server is targeted.
    server: Option<String>,
}
/// Outcome of one item (pattern or path) in a bulk operation.
/// Field order is part of the serialized JSON output.
#[derive(Serialize)]
struct BulkOperationItemResult {
    /// The pattern or path this result refers to.
    item: String,
    /// `true` when the operation succeeded on every targeted server.
    success: bool,
    /// Error text from the failing operation; `None` on success.
    error: Option<String>,
}
/// JSON response body shared by all bulk endpoints.
/// Field order is part of the serialized JSON output.
#[derive(Serialize)]
struct BulkOperationResponse {
    /// Operation identifier, e.g. `"bulk_invalidate"`.
    operation: &'static str,
    /// Server filter echoed from the request (`None` = no filter).
    server: Option<String>,
    /// Number of items processed (equals `results.len()`).
    requested: usize,
    /// Number of results with `success == true`.
    succeeded: usize,
    /// `requested - succeeded`.
    failed: usize,
    /// Per-item outcomes.
    results: Vec<BulkOperationItemResult>,
}
/// Which snapshot method a bulk request should invoke for each path.
///
/// `Copy` so a single value can be moved into every spawned per-path task.
#[derive(Copy, Clone)]
enum BulkSnapshotAction {
    /// Invoke `add_snapshot`.
    Add,
    /// Invoke `refresh_snapshot`.
    Refresh,
    /// Invoke `remove_snapshot`.
    Remove,
}
/// Enforces bearer-token authentication when a token is configured.
///
/// With `state.auth_token == None` every request is accepted. Otherwise the
/// `Authorization` header must be exactly `Bearer <token>`.
///
/// # Errors
/// Returns `401 UNAUTHORIZED` (and logs a warning) when the header is missing,
/// not valid visible ASCII, lacks the `Bearer ` prefix, or carries a different token.
fn check_auth(state: &ControlState, headers: &HeaderMap) -> Result<(), StatusCode> {
    let required_token = match state.auth_token.as_deref() {
        Some(token) => token,
        None => return Ok(()),
    };
    let presented = headers
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.strip_prefix("Bearer "));
    // Compare without early exit so response timing does not reveal how many
    // leading bytes of a guessed token were correct.
    let authorized = presented.map_or(false, |token| {
        constant_time_eq(token.as_bytes(), required_token.as_bytes())
    });
    if !authorized {
        tracing::warn!("Unauthorized control endpoint attempt");
        return Err(StatusCode::UNAUTHORIZED);
    }
    Ok(())
}

/// Byte-wise equality that examines every byte of equal-length inputs.
///
/// Only the length comparison short-circuits, so timing can reveal the token
/// length but not its contents. This is a best-effort mitigation; the optimizer
/// is not formally prevented from changing it.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
/// Rejects an empty bulk list.
///
/// `field_name` is the JSON field name quoted in the error message.
///
/// # Errors
/// `400 BAD_REQUEST` when `items` has no elements.
fn validate_bulk_items<T>(items: &[T], field_name: &str) -> Result<(), (StatusCode, String)> {
    match items.first() {
        Some(_) => Ok(()),
        None => Err((
            StatusCode::BAD_REQUEST,
            format!("'{}' must contain at least one item", field_name),
        )),
    }
}
/// Wraps per-item results in a `200 OK` JSON [`BulkOperationResponse`], tallying
/// how many items succeeded and failed.
fn bulk_response(
    operation: &'static str,
    server: Option<String>,
    results: Vec<BulkOperationItemResult>,
) -> (StatusCode, Json<BulkOperationResponse>) {
    let (succeeded, failed) = results.iter().fold((0usize, 0usize), |(ok, bad), entry| {
        if entry.success {
            (ok + 1, bad)
        } else {
            (ok, bad + 1)
        }
    });
    // Field initializers run in source order, so `results.len()` is read before
    // `results` is moved into the struct.
    let payload = BulkOperationResponse {
        operation,
        server,
        requested: results.len(),
        succeeded,
        failed,
        results,
    };
    (StatusCode::OK, Json(payload))
}
async fn run_bulk_snapshot_operation(
handles: Vec<&CacheHandle>,
paths: &[String],
action: BulkSnapshotAction,
) -> Vec<BulkOperationItemResult> {
let handles: Arc<Vec<CacheHandle>> = Arc::new(handles.into_iter().cloned().collect());
let tasks: Vec<JoinHandle<BulkOperationItemResult>> = paths
.iter()
.cloned()
.map(|path| {
let handles = Arc::clone(&handles);
tokio::spawn(async move {
let error = run_snapshot_operation_for_path(handles.as_ref(), &path, action).await;
BulkOperationItemResult {
item: path,
success: error.is_none(),
error,
}
})
})
.collect();
let mut results = Vec::with_capacity(tasks.len());
for task in tasks {
match task.await {
Ok(result) => results.push(result),
Err(err) => {
tracing::error!("bulk snapshot task failed: {}", err);
results.push(BulkOperationItemResult {
item: "<unknown>".to_string(),
success: false,
error: Some("bulk snapshot task failed".to_string()),
});
}
}
}
results
}
/// Applies `action` for `path` to each handle in order.
///
/// Returns `None` if every handle succeeded. Otherwise returns the error text of
/// the first failing handle. Later handles are not attempted, and earlier
/// successful ones are not rolled back.
async fn run_snapshot_operation_for_path(
    handles: &[CacheHandle],
    path: &str,
    action: BulkSnapshotAction,
) -> Option<String> {
    for cache in handles.iter() {
        let result = match action {
            BulkSnapshotAction::Add => cache.add_snapshot(path).await,
            BulkSnapshotAction::Refresh => cache.refresh_snapshot(path).await,
            BulkSnapshotAction::Remove => cache.remove_snapshot(path).await,
        };
        if let Some(message) = result.err().map(|failure| failure.to_string()) {
            return Some(message);
        }
    }
    None
}
/// `POST /invalidate_all` — calls `invalidate_all` on every configured server.
///
/// Responds `200 OK` with a plain-text confirmation, or the auth status code
/// when authentication fails.
async fn invalidate_all_handler(
    State(state): State<Arc<ControlState>>,
    headers: HeaderMap,
) -> Result<impl IntoResponse, StatusCode> {
    check_auth(&state, &headers)?;
    let server_count = state.handles.len();
    state.handles.iter().for_each(|(_, cache)| {
        cache.invalidate_all();
    });
    tracing::info!(
        "invalidate_all triggered via control endpoint ({} server(s))",
        server_count
    );
    Ok((StatusCode::OK, "Cache invalidated"))
}
async fn invalidate_handler(
State(state): State<Arc<ControlState>>,
headers: HeaderMap,
Json(body): Json<PatternBody>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
let handles = state.resolve_handles(body.server.as_deref())?;
for handle in handles {
handle.invalidate(&body.pattern);
}
tracing::info!(
"invalidate('{}') triggered via control endpoint (server={:?})",
body.pattern,
body.server
);
Ok((StatusCode::OK, "Pattern invalidation triggered".to_string()))
}
async fn bulk_invalidate_handler(
State(state): State<Arc<ControlState>>,
headers: HeaderMap,
Json(body): Json<BulkPatternBody>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
validate_bulk_items(&body.patterns, "patterns")?;
let handles = state.resolve_handles(body.server.as_deref())?;
let mut results = Vec::with_capacity(body.patterns.len());
for pattern in &body.patterns {
for handle in &handles {
handle.invalidate(pattern);
}
results.push(BulkOperationItemResult {
item: pattern.clone(),
success: true,
error: None,
});
}
tracing::info!(
"bulk_invalidate(count={}) triggered via control endpoint (server={:?})",
body.patterns.len(),
body.server
);
Ok(bulk_response("bulk_invalidate", body.server, results))
}
async fn add_snapshot_handler(
State(state): State<Arc<ControlState>>,
headers: HeaderMap,
Json(body): Json<PathBody>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
for handle in handles {
handle
.add_snapshot(&body.path)
.await
.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
}
tracing::info!(
"add_snapshot('{}') triggered via control endpoint (server={:?})",
body.path,
body.server
);
Ok((StatusCode::OK, "Snapshot added".to_string()))
}
async fn bulk_add_snapshot_handler(
State(state): State<Arc<ControlState>>,
headers: HeaderMap,
Json(body): Json<BulkPathBody>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
validate_bulk_items(&body.paths, "paths")?;
let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
let results = run_bulk_snapshot_operation(handles, &body.paths, BulkSnapshotAction::Add).await;
tracing::info!(
"bulk_add_snapshot(count={}) triggered via control endpoint (server={:?})",
body.paths.len(),
body.server
);
Ok(bulk_response("bulk_add_snapshot", body.server, results))
}
/// `POST /refresh_snapshot` — refreshes one path's snapshot on the selected
/// snapshot-capable server(s).
///
/// The first server error is returned as `400`. Servers processed before it
/// have already been refreshed.
async fn refresh_snapshot_handler(
    State(state): State<Arc<ControlState>>,
    headers: HeaderMap,
    Json(body): Json<PathBody>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    if let Err(status) = check_auth(&state, &headers) {
        return Err((status, String::new()));
    }
    let targets = state.resolve_snapshot_handles(body.server.as_deref())?;
    for cache in targets {
        match cache.refresh_snapshot(&body.path).await {
            Ok(_) => {}
            Err(err) => return Err((StatusCode::BAD_REQUEST, err.to_string())),
        }
    }
    tracing::info!(
        "refresh_snapshot('{}') triggered via control endpoint (server={:?})",
        body.path,
        body.server
    );
    Ok((StatusCode::OK, String::from("Snapshot refreshed")))
}
async fn bulk_refresh_snapshot_handler(
State(state): State<Arc<ControlState>>,
headers: HeaderMap,
Json(body): Json<BulkPathBody>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
validate_bulk_items(&body.paths, "paths")?;
let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
let results =
run_bulk_snapshot_operation(handles, &body.paths, BulkSnapshotAction::Refresh).await;
tracing::info!(
"bulk_refresh_snapshot(count={}) triggered via control endpoint (server={:?})",
body.paths.len(),
body.server
);
Ok(bulk_response("bulk_refresh_snapshot", body.server, results))
}
async fn remove_snapshot_handler(
State(state): State<Arc<ControlState>>,
headers: HeaderMap,
Json(body): Json<PathBody>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
let handles = state.resolve_snapshot_handles(body.server.as_deref())?;
for handle in handles {
handle
.remove_snapshot(&body.path)
.await
.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
}
tracing::info!(
"remove_snapshot('{}') triggered via control endpoint (server={:?})",
body.path,
body.server
);
Ok((StatusCode::OK, "Snapshot removed".to_string()))
}
/// `POST /bulk_remove_snapshot` — removes snapshots for many paths, one task per
/// path, reporting per-path outcomes in a `200 OK` JSON body.
async fn bulk_remove_snapshot_handler(
    State(state): State<Arc<ControlState>>,
    headers: HeaderMap,
    Json(body): Json<BulkPathBody>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    if let Err(status) = check_auth(&state, &headers) {
        return Err((status, String::new()));
    }
    validate_bulk_items(&body.paths, "paths")?;
    let targets = state.resolve_snapshot_handles(body.server.as_deref())?;
    let action = BulkSnapshotAction::Remove;
    let outcomes = run_bulk_snapshot_operation(targets, &body.paths, action).await;
    tracing::info!(
        "bulk_remove_snapshot(count={}) triggered via control endpoint (server={:?})",
        body.paths.len(),
        body.server
    );
    Ok(bulk_response("bulk_remove_snapshot", body.server, outcomes))
}
/// `POST /refresh_all_snapshots` — refreshes every snapshot on the selected
/// snapshot-capable server(s).
///
/// The body is optional. When present, an optional `"server"` field restricts
/// the operation to that server name. `null` or an absent field means "all
/// servers".
///
/// # Errors
/// - The auth status code (empty body) when authentication fails.
/// - `400` when `"server"` is present but is neither a string nor `null`.
///   Previously such a value was silently ignored, widening the refresh to
///   every server.
/// - `404`/`400` from server resolution.
/// - `400` with the first server error; servers already refreshed stay refreshed.
async fn refresh_all_snapshots_handler(
    State(state): State<Arc<ControlState>>,
    headers: HeaderMap,
    body: Option<Json<serde_json::Value>>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    check_auth(&state, &headers).map_err(|s| (s, String::new()))?;
    let server_filter: Option<String> =
        match body.as_ref().and_then(|Json(v)| v.get("server")) {
            None | Some(serde_json::Value::Null) => None,
            Some(serde_json::Value::String(name)) => Some(name.clone()),
            Some(_) => {
                return Err((
                    StatusCode::BAD_REQUEST,
                    "'server' must be a string".to_string(),
                ));
            }
        };
    let handles = state.resolve_snapshot_handles(server_filter.as_deref())?;
    for handle in handles {
        handle
            .refresh_all_snapshots()
            .await
            .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
    }
    tracing::info!(
        "refresh_all_snapshots triggered via control endpoint (server={:?})",
        server_filter
    );
    Ok((StatusCode::OK, "All snapshots refreshed".to_string()))
}
pub fn create_control_router(
handles: Vec<(String, CacheHandle)>,
auth_token: Option<String>,
) -> Router {
let state = Arc::new(ControlState::new(handles, auth_token));
Router::new()
.route("/invalidate_all", post(invalidate_all_handler))
.route("/invalidate", post(invalidate_handler))
.route("/bulk_invalidate", post(bulk_invalidate_handler))
.route("/add_snapshot", post(add_snapshot_handler))
.route("/bulk_add_snapshot", post(bulk_add_snapshot_handler))
.route("/refresh_snapshot", post(refresh_snapshot_handler))
.route(
"/bulk_refresh_snapshot",
post(bulk_refresh_snapshot_handler),
)
.route("/remove_snapshot", post(remove_snapshot_handler))
.route("/bulk_remove_snapshot", post(bulk_remove_snapshot_handler))
.route(
"/refresh_all_snapshots",
post(refresh_all_snapshots_handler),
)
.with_state(state)
}