Skip to main content

mockforge_http/handlers/
world_state.rs

1//! World State API handlers
2//!
3//! This module provides HTTP handlers for querying and visualizing the unified
4//! world state of MockForge, including REST API endpoints and WebSocket streaming.
5
6use axum::{
7    extract::{Path, Query, State, WebSocketUpgrade},
8    http::StatusCode,
9    response::{IntoResponse, Json},
10    routing::{get, post},
11    Router,
12};
13use futures_util::StreamExt;
14use mockforge_world_state::{
15    model::{StateLayer, WorldStateSnapshot},
16    WorldStateEngine, WorldStateQuery,
17};
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::collections::HashSet;
21use std::sync::Arc;
22use tokio::sync::RwLock;
23use tracing::{error, info};
24
25/// State for world state handlers
26#[derive(Clone)]
27pub struct WorldStateState {
28    /// World state engine
29    pub engine: Arc<RwLock<WorldStateEngine>>,
30}
31
32/// Query parameters for world state operations
33#[derive(Debug, Deserialize)]
34pub struct WorldStateQueryParams {
35    /// Workspace ID (optional)
36    pub workspace: Option<String>,
37    /// Layer filter (comma-separated)
38    pub layers: Option<String>,
39    /// Node type filter (comma-separated)
40    pub node_types: Option<String>,
41}
42
43/// Request body for querying world state
44#[derive(Debug, Deserialize)]
45pub struct WorldStateQueryRequest {
46    /// Filter by node types
47    pub node_types: Option<Vec<String>>,
48    /// Filter by layers
49    pub layers: Option<Vec<String>>,
50    /// Filter by node IDs
51    pub node_ids: Option<Vec<String>>,
52    /// Filter by relationship types
53    pub relationship_types: Option<Vec<String>>,
54    /// Include edges in results
55    #[serde(default = "default_true")]
56    pub include_edges: bool,
57    /// Maximum depth for traversal
58    pub max_depth: Option<usize>,
59}
60
/// Serde default that yields `true`.
///
/// Used for `WorldStateQueryRequest::include_edges`, so edges are included
/// unless a request explicitly opts out.
fn default_true() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
64
65/// Response for world state snapshot
66#[derive(Debug, Serialize)]
67pub struct WorldStateSnapshotResponse {
68    /// The snapshot
69    pub snapshot: WorldStateSnapshot,
70    /// Available layers
71    pub available_layers: Vec<String>,
72}
73
74/// Response for world state graph
75#[derive(Debug, Serialize)]
76pub struct WorldStateGraphResponse {
77    /// Graph nodes
78    pub nodes: Vec<Value>,
79    /// Graph edges
80    pub edges: Vec<Value>,
81    /// Metadata
82    pub metadata: Value,
83}
84
85/// Get current world state snapshot
86///
87/// GET /api/world-state/snapshot
88pub async fn get_current_snapshot(
89    State(state): State<WorldStateState>,
90) -> Result<Json<WorldStateSnapshotResponse>, StatusCode> {
91    let engine = state.engine.read().await;
92    let snapshot = engine.get_current_snapshot().await.map_err(|e| {
93        error!("Failed to create world state snapshot: {}", e);
94        StatusCode::INTERNAL_SERVER_ERROR
95    })?;
96
97    let layers: Vec<String> = engine.get_layers().iter().map(|l| l.name().to_string()).collect();
98
99    Ok(Json(WorldStateSnapshotResponse {
100        snapshot,
101        available_layers: layers,
102    }))
103}
104
105/// Get a specific snapshot by ID
106///
107/// GET /api/world-state/snapshot/{id}
108pub async fn get_snapshot(
109    State(state): State<WorldStateState>,
110    Path(snapshot_id): Path<String>,
111) -> Result<Json<WorldStateSnapshot>, StatusCode> {
112    let engine = state.engine.read().await;
113    let snapshot = engine.get_snapshot(&snapshot_id).await.ok_or_else(|| {
114        error!("Snapshot not found: {}", snapshot_id);
115        StatusCode::NOT_FOUND
116    })?;
117
118    Ok(Json(snapshot))
119}
120
121/// Get world state as a graph
122///
123/// GET /api/world-state/graph
124pub async fn get_world_state_graph(
125    State(state): State<WorldStateState>,
126    Query(_params): Query<WorldStateQueryParams>,
127) -> Result<Json<WorldStateGraphResponse>, StatusCode> {
128    let engine = state.engine.read().await;
129    let snapshot = engine.get_current_snapshot().await.map_err(|e| {
130        error!("Failed to create world state snapshot: {}", e);
131        StatusCode::INTERNAL_SERVER_ERROR
132    })?;
133
134    // Convert nodes and edges to JSON values
135    let nodes: Vec<Value> = snapshot
136        .nodes
137        .iter()
138        .map(|n| serde_json::to_value(n).unwrap_or_default())
139        .collect();
140
141    let edges: Vec<Value> = snapshot
142        .edges
143        .iter()
144        .map(|e| serde_json::to_value(e).unwrap_or_default())
145        .collect();
146
147    let metadata = serde_json::json!({
148        "node_count": nodes.len(),
149        "edge_count": edges.len(),
150        "timestamp": snapshot.timestamp.to_rfc3339(),
151    });
152
153    Ok(Json(WorldStateGraphResponse {
154        nodes,
155        edges,
156        metadata,
157    }))
158}
159
160/// Get available layers
161///
162/// GET /api/world-state/layers
163pub async fn get_layers(State(state): State<WorldStateState>) -> Result<Json<Value>, StatusCode> {
164    let engine = state.engine.read().await;
165    let layers: Vec<Value> = engine
166        .get_layers()
167        .iter()
168        .map(|layer| {
169            serde_json::json!({
170                "id": format!("{:?}", layer),
171                "name": layer.name(),
172            })
173        })
174        .collect();
175
176    Ok(Json(serde_json::json!({
177        "layers": layers,
178        "count": layers.len(),
179    })))
180}
181
182/// Query world state with filters
183///
184/// POST /api/world-state/query
185pub async fn query_world_state(
186    State(state): State<WorldStateState>,
187    Json(request): Json<WorldStateQueryRequest>,
188) -> Result<Json<WorldStateSnapshot>, StatusCode> {
189    let engine = state.engine.read().await;
190
191    // Build query from request
192    let mut query = WorldStateQuery::new();
193
194    if let Some(ref node_types) = request.node_types {
195        use mockforge_world_state::model::NodeType;
196        let types: HashSet<NodeType> = node_types
197            .iter()
198            .filter_map(|s| match s.as_str() {
199                "persona" => Some(NodeType::Persona),
200                "entity" => Some(NodeType::Entity),
201                "session" => Some(NodeType::Session),
202                "protocol" => Some(NodeType::Protocol),
203                "behavior" => Some(NodeType::Behavior),
204                "schema" => Some(NodeType::Schema),
205                "recorded" => Some(NodeType::Recorded),
206                "ai_modifier" => Some(NodeType::AiModifier),
207                "system" => Some(NodeType::System),
208                _ => None,
209            })
210            .collect();
211        if !types.is_empty() {
212            query = query.with_node_types(types);
213        }
214    }
215
216    if let Some(ref layers) = request.layers {
217        let layer_set: HashSet<StateLayer> = layers
218            .iter()
219            .filter_map(|s| {
220                // Parse layer string to StateLayer
221                match s.as_str() {
222                    "personas" => Some(StateLayer::Personas),
223                    "lifecycle" => Some(StateLayer::Lifecycle),
224                    "reality" => Some(StateLayer::Reality),
225                    "time" => Some(StateLayer::Time),
226                    "protocols" => Some(StateLayer::Protocols),
227                    "behavior" => Some(StateLayer::Behavior),
228                    "schemas" => Some(StateLayer::Schemas),
229                    "recorded" => Some(StateLayer::Recorded),
230                    "ai_modifiers" => Some(StateLayer::AiModifiers),
231                    "system" => Some(StateLayer::System),
232                    _ => None,
233                }
234            })
235            .collect();
236        if !layer_set.is_empty() {
237            query = query.with_layers(layer_set);
238        }
239    }
240
241    if let Some(ref node_ids) = request.node_ids {
242        let id_set: HashSet<String> = node_ids.iter().cloned().collect();
243        query = query.with_node_ids(id_set);
244    }
245
246    if let Some(ref rel_types) = request.relationship_types {
247        let rel_set: HashSet<String> = rel_types.iter().cloned().collect();
248        query = query.with_relationship_types(rel_set);
249    }
250
251    query = query.include_edges(request.include_edges);
252
253    if let Some(depth) = request.max_depth {
254        query = query.with_max_depth(depth);
255    }
256
257    let snapshot = engine.query(&query).await.map_err(|e| {
258        error!("Failed to query world state: {}", e);
259        StatusCode::INTERNAL_SERVER_ERROR
260    })?;
261
262    Ok(Json(snapshot))
263}
264
265/// WebSocket handler for real-time world state updates
266///
267/// WS /api/world-state/stream
268pub async fn world_state_websocket_handler(
269    ws: WebSocketUpgrade,
270    State(state): State<WorldStateState>,
271) -> impl IntoResponse {
272    ws.on_upgrade(|socket| handle_world_state_stream(socket, state))
273}
274
275/// Handle WebSocket stream for world state updates
276async fn handle_world_state_stream(
277    mut socket: axum::extract::ws::WebSocket,
278    state: WorldStateState,
279) {
280    use axum::extract::ws::Message;
281    use tokio::time::{interval, Duration};
282
283    // Send initial snapshot
284    {
285        let engine = state.engine.read().await;
286        if let Ok(snapshot) = engine.get_current_snapshot().await {
287            if let Ok(json) = serde_json::to_string(&snapshot) {
288                let _ = socket.send(Message::Text(json.into())).await;
289            }
290        }
291    }
292
293    // Send periodic updates (every 5 seconds)
294    let mut interval = interval(Duration::from_secs(5));
295
296    loop {
297        tokio::select! {
298            // Handle incoming messages (for now, just acknowledge)
299            msg = socket.next() => {
300                match msg {
301                    Some(Ok(Message::Text(text))) => {
302                        info!("Received WebSocket message: {}", text);
303                        // Could handle commands like "subscribe to layer X"
304                    }
305                    Some(Ok(Message::Close(_))) => {
306                        info!("WebSocket connection closed");
307                        break;
308                    }
309                    Some(Err(e)) => {
310                        error!("WebSocket error: {}", e);
311                        break;
312                    }
313                    None => {
314                        break;
315                    }
316                    _ => {}
317                }
318            }
319            // Send periodic updates
320            _ = interval.tick() => {
321                let engine = state.engine.read().await;
322                if let Ok(snapshot) = engine.get_current_snapshot().await {
323                    if let Ok(json) = serde_json::to_string(&snapshot) {
324                        if socket.send(Message::Text(json.into())).await.is_err() {
325                            break;
326                        }
327                    }
328                }
329            }
330        }
331    }
332}
333
334/// Create the world state router
335pub fn world_state_router() -> Router<WorldStateState> {
336    Router::new()
337        .route("/snapshot", get(get_current_snapshot))
338        .route("/snapshot/{id}", get(get_snapshot))
339        .route("/graph", get(get_world_state_graph))
340        .route("/layers", get(get_layers))
341        .route("/query", post(query_world_state))
342        .route("/stream", get(world_state_websocket_handler))
343}