use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::Json,
};
use mockforge_core::consistency::ConsistencyEngine;
use mockforge_core::snapshots::{ProtocolStateExporter, SnapshotComponents, SnapshotManager};
use mockforge_core::workspace_persistence::WorkspacePersistence;
use serde::Deserialize;
use serde_json::Value;
use std::sync::Arc;
use tracing::{debug, error, info, warn};
/// Shared state handed to every snapshot HTTP handler through axum's `State` extractor.
///
/// Every field is either an `Arc` or an `Option<Arc<_>>`, so cloning this state per
/// request only bumps reference counts. Only `manager` is mandatory; the optional
/// subsystems are forwarded to the manager (or queried for exportable state) when present.
#[derive(Clone)]
pub struct SnapshotState {
    /// Performs the actual save / load / list / info / delete / validate operations.
    pub manager: Arc<SnapshotManager>,
    /// Forwarded to the manager on save and load when configured.
    pub consistency_engine: Option<Arc<ConsistencyEngine>>,
    /// Forwarded to the manager on save and load when configured.
    pub workspace_persistence: Option<Arc<WorkspacePersistence>>,
    /// Exporter whose state is captured as `vbr_state` when a snapshot is saved.
    pub vbr_engine: Option<Arc<dyn ProtocolStateExporter>>,
    /// Exporter whose state is captured as `recorder_state` when a snapshot is saved.
    pub recorder: Option<Arc<dyn ProtocolStateExporter>>,
}
/// JSON body accepted by `POST /api/v1/snapshots`.
#[derive(Deserialize, Debug)]
pub struct SaveSnapshotRequest {
    /// Name under which the snapshot is stored.
    pub name: String,
    /// Optional free-form description stored alongside the snapshot.
    pub description: Option<String>,
    /// Which components to capture; the handler falls back to
    /// `SnapshotComponents::all()` when this is omitted.
    pub components: Option<SnapshotComponents>,
}
/// JSON body accepted by `POST /api/v1/snapshots/{name}/load`.
#[derive(Deserialize, Debug)]
pub struct LoadSnapshotRequest {
    /// Optional subset of components to restore; passed to the manager unchanged.
    pub components: Option<SnapshotComponents>,
}
/// Query-string parameters shared by all snapshot endpoints.
#[derive(Deserialize, Debug)]
pub struct SnapshotQuery {
    /// Workspace the snapshot belongs to; `"default"` when `?workspace=` is absent.
    #[serde(default = "default_workspace")]
    pub workspace: String,
}
/// Serde default for [`SnapshotQuery::workspace`]: the literal workspace name `"default"`.
fn default_workspace() -> String {
    String::from("default")
}
/// Exports the current state of the VBR exporter, if one is configured.
///
/// Returns `None` (after logging) when no exporter is configured or when the export
/// fails, so a snapshot save never fails because of this component.
async fn extract_vbr_state(vbr_engine: &Option<Arc<dyn ProtocolStateExporter>>) -> Option<Value> {
    let Some(exporter) = vbr_engine.as_ref() else {
        debug!("No VBR engine available for state extraction");
        return None;
    };

    match exporter.export_state().await {
        Err(err) => {
            warn!("Failed to extract VBR state: {}", err);
            None
        }
        Ok(exported) => {
            // The summary is only fetched for the log line.
            let summary = exporter.state_summary().await;
            info!("Extracted VBR state from {} engine: {}", exporter.protocol_name(), summary);
            Some(exported)
        }
    }
}
/// Exports the current state of the recorder exporter, if one is configured.
///
/// Returns `None` (after logging) when no recorder is configured or when the export
/// fails, so a snapshot save never fails because of this component.
async fn extract_recorder_state(
    recorder: &Option<Arc<dyn ProtocolStateExporter>>,
) -> Option<Value> {
    let Some(exporter) = recorder.as_ref() else {
        debug!("No Recorder available for state extraction");
        return None;
    };

    match exporter.export_state().await {
        Err(err) => {
            warn!("Failed to extract Recorder state: {}", err);
            None
        }
        Ok(exported) => {
            // The summary is only fetched for the log line.
            let summary = exporter.state_summary().await;
            info!("Extracted Recorder state from {} engine: {}", exporter.protocol_name(), summary);
            Some(exported)
        }
    }
}
pub async fn save_snapshot(
State(state): State<SnapshotState>,
Query(params): Query<SnapshotQuery>,
Json(request): Json<SaveSnapshotRequest>,
) -> Result<Json<Value>, StatusCode> {
let components = request.components.unwrap_or_else(SnapshotComponents::all);
let consistency_engine = state.consistency_engine.as_deref();
let workspace_persistence = state.workspace_persistence.as_deref();
let vbr_state = if components.vbr_state {
extract_vbr_state(&state.vbr_engine).await
} else {
None
};
let recorder_state = if components.recorder_state {
extract_recorder_state(&state.recorder).await
} else {
None
};
let manifest = state
.manager
.save_snapshot(
request.name.clone(),
request.description,
params.workspace.clone(),
components,
consistency_engine,
workspace_persistence,
vbr_state,
recorder_state,
)
.await
.map_err(|e| {
error!("Failed to save snapshot: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
info!("Saved snapshot '{}' for workspace '{}'", request.name, params.workspace);
Ok(Json(serde_json::json!({
"success": true,
"manifest": manifest,
})))
}
pub async fn load_snapshot(
State(state): State<SnapshotState>,
Path(name): Path<String>,
Query(params): Query<SnapshotQuery>,
Json(request): Json<LoadSnapshotRequest>,
) -> Result<Json<Value>, StatusCode> {
let consistency_engine = state.consistency_engine.as_deref();
let workspace_persistence = state.workspace_persistence.as_deref();
let (manifest, vbr_state, recorder_state) = state
.manager
.load_snapshot(
name.clone(),
params.workspace.clone(),
request.components,
consistency_engine,
workspace_persistence,
)
.await
.map_err(|e| {
error!("Failed to load snapshot: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
info!("Loaded snapshot '{}' for workspace '{}'", name, params.workspace);
Ok(Json(serde_json::json!({
"success": true,
"manifest": manifest,
"vbr_state": vbr_state,
"recorder_state": recorder_state,
})))
}
/// `GET /api/v1/snapshots` — lists every snapshot stored for the requested workspace.
///
/// The response contains the workspace name, the snapshot list and its length.
///
/// # Errors
/// Returns `500 Internal Server Error` if the snapshot manager fails to list.
pub async fn list_snapshots(
    State(state): State<SnapshotState>,
    Query(params): Query<SnapshotQuery>,
) -> Result<Json<Value>, StatusCode> {
    // The original source had a mis-encoded `¶ms` here (HTML `&para` entity
    // garbling); it must be a borrow of the workspace name.
    let snapshots = state
        .manager
        .list_snapshots(&params.workspace)
        .await
        .map_err(|e| {
            error!("Failed to list snapshots: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;
    Ok(Json(serde_json::json!({
        "workspace": params.workspace,
        "snapshots": snapshots,
        "count": snapshots.len(),
    })))
}
pub async fn get_snapshot_info(
State(state): State<SnapshotState>,
Path(name): Path<String>,
Query(params): Query<SnapshotQuery>,
) -> Result<Json<Value>, StatusCode> {
let manifest = state
.manager
.get_snapshot_info(name.clone(), params.workspace.clone())
.await
.map_err(|e| {
error!("Failed to get snapshot info: {}", e);
StatusCode::NOT_FOUND
})?;
Ok(Json(serde_json::json!({
"success": true,
"manifest": manifest,
})))
}
pub async fn delete_snapshot(
State(state): State<SnapshotState>,
Path(name): Path<String>,
Query(params): Query<SnapshotQuery>,
) -> Result<Json<Value>, StatusCode> {
state
.manager
.delete_snapshot(name.clone(), params.workspace.clone())
.await
.map_err(|e| {
error!("Failed to delete snapshot: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
info!("Deleted snapshot '{}' for workspace '{}'", name, params.workspace);
Ok(Json(serde_json::json!({
"success": true,
"message": format!("Snapshot '{}' deleted successfully", name),
})))
}
pub async fn validate_snapshot(
State(state): State<SnapshotState>,
Path(name): Path<String>,
Query(params): Query<SnapshotQuery>,
) -> Result<Json<Value>, StatusCode> {
let is_valid = state
.manager
.validate_snapshot(name.clone(), params.workspace.clone())
.await
.map_err(|e| {
error!("Failed to validate snapshot: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(Json(serde_json::json!({
"success": true,
"valid": is_valid,
"snapshot": name,
"workspace": params.workspace,
})))
}
pub fn snapshot_router(state: SnapshotState) -> axum::Router {
use axum::routing::{get, post};
axum::Router::new()
.route("/api/v1/snapshots", get(list_snapshots).post(save_snapshot))
.route("/api/v1/snapshots/{name}", get(get_snapshot_info).delete(delete_snapshot))
.route("/api/v1/snapshots/{name}/load", post(load_snapshot))
.route("/api/v1/snapshots/{name}/validate", get(validate_snapshot))
.with_state(state)
}