use axum::{
extract::{Path, Query, State, WebSocketUpgrade},
http::StatusCode,
response::{IntoResponse, Json},
routing::{get, post},
Router,
};
use futures_util::StreamExt;
use mockforge_world_state::{
model::{StateLayer, WorldStateSnapshot},
WorldStateEngine, WorldStateQuery,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{error, info};
/// Shared Axum state for the world state routes.
///
/// The engine sits behind an async `RwLock` inside an `Arc`, so cloning this
/// state (as Axum does per request) only bumps a reference count. The
/// handlers in this file only ever take the read lock.
#[derive(Clone)]
pub struct WorldStateState {
/// The world state engine shared by all handlers.
pub engine: Arc<RwLock<WorldStateEngine>>,
}
/// Query-string parameters accepted by the `GET /graph` endpoint.
///
/// NOTE(review): the graph handler binds these as `_params` and does not
/// apply them; the intended format of the raw strings (e.g. comma-separated
/// lists for `layers` / `node_types`) is unconfirmed.
#[derive(Debug, Deserialize)]
pub struct WorldStateQueryParams {
/// Workspace identifier (currently unused).
pub workspace: Option<String>,
/// Layer filter as a raw string (currently unused).
pub layers: Option<String>,
/// Node type filter as a raw string (currently unused).
pub node_types: Option<String>,
}
/// JSON body for `POST /query`, translated by the query handler into a
/// `WorldStateQuery`.
///
/// Every filter is optional; a filter that is omitted (`None`) is not added
/// to the query.
#[derive(Debug, Deserialize)]
pub struct WorldStateQueryRequest {
/// Node type names (e.g. `"persona"`, `"entity"`), mapped to `NodeType` values.
pub node_types: Option<Vec<String>>,
/// State layer names (e.g. `"personas"`, `"lifecycle"`), mapped to `StateLayer` values.
pub layers: Option<Vec<String>>,
/// Node IDs passed to `WorldStateQuery::with_node_ids`.
pub node_ids: Option<Vec<String>>,
/// Relationship type names passed to `WorldStateQuery::with_relationship_types`.
pub relationship_types: Option<Vec<String>>,
/// Passed to `WorldStateQuery::include_edges`; defaults to `true` when the
/// field is absent from the JSON body.
#[serde(default = "default_true")]
pub include_edges: bool,
/// Passed to `WorldStateQuery::with_max_depth` when present.
/// NOTE(review): exact depth semantics are defined by the engine — confirm there.
pub max_depth: Option<usize>,
}
/// Serde default used for `WorldStateQueryRequest::include_edges`.
fn default_true() -> bool {
true
}
/// Response body for `GET /snapshot`.
#[derive(Debug, Serialize)]
pub struct WorldStateSnapshotResponse {
/// The snapshot produced by `WorldStateEngine::get_current_snapshot`.
pub snapshot: WorldStateSnapshot,
/// Display names (`StateLayer::name`) of every layer reported by the engine.
pub available_layers: Vec<String>,
}
/// Response body for `GET /graph`: the snapshot flattened into generic JSON.
#[derive(Debug, Serialize)]
pub struct WorldStateGraphResponse {
/// Snapshot nodes, each serialized to a JSON value.
pub nodes: Vec<Value>,
/// Snapshot edges, each serialized to a JSON value.
pub edges: Vec<Value>,
/// Summary object built by the graph handler with `node_count`,
/// `edge_count` and an RFC 3339 `timestamp`.
pub metadata: Value,
}
pub async fn get_current_snapshot(
State(state): State<WorldStateState>,
) -> Result<Json<WorldStateSnapshotResponse>, StatusCode> {
let engine = state.engine.read().await;
let snapshot = engine.get_current_snapshot().await.map_err(|e| {
error!("Failed to create world state snapshot: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
let layers: Vec<String> = engine.get_layers().iter().map(|l| l.name().to_string()).collect();
Ok(Json(WorldStateSnapshotResponse {
snapshot,
available_layers: layers,
}))
}
pub async fn get_snapshot(
State(state): State<WorldStateState>,
Path(snapshot_id): Path<String>,
) -> Result<Json<WorldStateSnapshot>, StatusCode> {
let engine = state.engine.read().await;
let snapshot = engine.get_snapshot(&snapshot_id).await.ok_or_else(|| {
error!("Snapshot not found: {}", snapshot_id);
StatusCode::NOT_FOUND
})?;
Ok(Json(snapshot))
}
pub async fn get_world_state_graph(
State(state): State<WorldStateState>,
Query(_params): Query<WorldStateQueryParams>,
) -> Result<Json<WorldStateGraphResponse>, StatusCode> {
let engine = state.engine.read().await;
let snapshot = engine.get_current_snapshot().await.map_err(|e| {
error!("Failed to create world state snapshot: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
let nodes: Vec<Value> = snapshot
.nodes
.iter()
.map(|n| serde_json::to_value(n).unwrap_or_default())
.collect();
let edges: Vec<Value> = snapshot
.edges
.iter()
.map(|e| serde_json::to_value(e).unwrap_or_default())
.collect();
let metadata = serde_json::json!({
"node_count": nodes.len(),
"edge_count": edges.len(),
"timestamp": snapshot.timestamp.to_rfc3339(),
});
Ok(Json(WorldStateGraphResponse {
nodes,
edges,
metadata,
}))
}
/// `GET /layers`: lists every layer the engine reports.
///
/// Each entry has an `id` (the layer's `Debug` representation) and a `name`
/// (`StateLayer::name`); `count` is the number of entries.
pub async fn get_layers(State(state): State<WorldStateState>) -> Result<Json<Value>, StatusCode> {
    let engine = state.engine.read().await;

    let mut entries: Vec<Value> = Vec::new();
    for layer in engine.get_layers().iter() {
        entries.push(serde_json::json!({
            "id": format!("{:?}", layer),
            "name": layer.name(),
        }));
    }

    let count = entries.len();
    Ok(Json(serde_json::json!({
        "layers": entries,
        "count": count,
    })))
}
pub async fn query_world_state(
State(state): State<WorldStateState>,
Json(request): Json<WorldStateQueryRequest>,
) -> Result<Json<WorldStateSnapshot>, StatusCode> {
let engine = state.engine.read().await;
let mut query = WorldStateQuery::new();
if let Some(ref node_types) = request.node_types {
use mockforge_world_state::model::NodeType;
let types: HashSet<NodeType> = node_types
.iter()
.filter_map(|s| match s.as_str() {
"persona" => Some(NodeType::Persona),
"entity" => Some(NodeType::Entity),
"session" => Some(NodeType::Session),
"protocol" => Some(NodeType::Protocol),
"behavior" => Some(NodeType::Behavior),
"schema" => Some(NodeType::Schema),
"recorded" => Some(NodeType::Recorded),
"ai_modifier" => Some(NodeType::AiModifier),
"system" => Some(NodeType::System),
_ => None,
})
.collect();
if !types.is_empty() {
query = query.with_node_types(types);
}
}
if let Some(ref layers) = request.layers {
let layer_set: HashSet<StateLayer> = layers
.iter()
.filter_map(|s| {
match s.as_str() {
"personas" => Some(StateLayer::Personas),
"lifecycle" => Some(StateLayer::Lifecycle),
"reality" => Some(StateLayer::Reality),
"time" => Some(StateLayer::Time),
"protocols" => Some(StateLayer::Protocols),
"behavior" => Some(StateLayer::Behavior),
"schemas" => Some(StateLayer::Schemas),
"recorded" => Some(StateLayer::Recorded),
"ai_modifiers" => Some(StateLayer::AiModifiers),
"system" => Some(StateLayer::System),
_ => None,
}
})
.collect();
if !layer_set.is_empty() {
query = query.with_layers(layer_set);
}
}
if let Some(ref node_ids) = request.node_ids {
let id_set: HashSet<String> = node_ids.iter().cloned().collect();
query = query.with_node_ids(id_set);
}
if let Some(ref rel_types) = request.relationship_types {
let rel_set: HashSet<String> = rel_types.iter().cloned().collect();
query = query.with_relationship_types(rel_set);
}
query = query.include_edges(request.include_edges);
if let Some(depth) = request.max_depth {
query = query.with_max_depth(depth);
}
let snapshot = engine.query(&query).await.map_err(|e| {
error!("Failed to query world state: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(Json(snapshot))
}
/// `GET /stream`: upgrades the connection to a WebSocket and hands it, along
/// with the shared state, to the snapshot streaming loop.
pub async fn world_state_websocket_handler(
    ws: WebSocketUpgrade,
    State(state): State<WorldStateState>,
) -> impl IntoResponse {
    let stream_state = state;
    ws.on_upgrade(move |socket| handle_world_state_stream(socket, stream_state))
}
async fn handle_world_state_stream(
mut socket: axum::extract::ws::WebSocket,
state: WorldStateState,
) {
use axum::extract::ws::Message;
use tokio::time::{interval, Duration};
{
let engine = state.engine.read().await;
if let Ok(snapshot) = engine.get_current_snapshot().await {
if let Ok(json) = serde_json::to_string(&snapshot) {
let _ = socket.send(Message::Text(json.into())).await;
}
}
}
let mut interval = interval(Duration::from_secs(5));
loop {
tokio::select! {
msg = socket.next() => {
match msg {
Some(Ok(Message::Text(text))) => {
info!("Received WebSocket message: {}", text);
}
Some(Ok(Message::Close(_))) => {
info!("WebSocket connection closed");
break;
}
Some(Err(e)) => {
error!("WebSocket error: {}", e);
break;
}
None => {
break;
}
_ => {}
}
}
_ = interval.tick() => {
let engine = state.engine.read().await;
if let Ok(snapshot) = engine.get_current_snapshot().await {
if let Ok(json) = serde_json::to_string(&snapshot) {
if socket.send(Message::Text(json.into())).await.is_err() {
break;
}
}
}
}
}
}
}
/// Builds the world state API router (paths are relative to wherever it is
/// nested):
///
/// * `GET  /snapshot`      — current snapshot plus available layer names
/// * `GET  /snapshot/{id}` — a stored snapshot by ID
/// * `GET  /graph`         — current snapshot as a JSON node/edge graph
/// * `GET  /layers`        — layer listing
/// * `POST /query`         — filtered query
/// * `GET  /stream`        — WebSocket snapshot stream
pub fn world_state_router() -> Router<WorldStateState> {
    let read_routes = Router::new()
        .route("/snapshot", get(get_current_snapshot))
        .route("/snapshot/{id}", get(get_snapshot))
        .route("/graph", get(get_world_state_graph))
        .route("/layers", get(get_layers));

    read_routes
        .route("/query", post(query_world_state))
        .route("/stream", get(world_state_websocket_handler))
}