1use axum::{
7 extract::{Path, Query, State, WebSocketUpgrade},
8 http::StatusCode,
9 response::{IntoResponse, Json},
10 routing::{get, post},
11 Router,
12};
13use futures_util::StreamExt;
14use mockforge_world_state::{
15 model::{StateLayer, WorldStateSnapshot},
16 WorldStateEngine, WorldStateQuery,
17};
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::collections::HashSet;
21use std::sync::Arc;
22use tokio::sync::RwLock;
23use tracing::{error, info};
24
25#[derive(Clone)]
27pub struct WorldStateState {
28 pub engine: Arc<RwLock<WorldStateEngine>>,
30}
31
32#[derive(Debug, Deserialize)]
34pub struct WorldStateQueryParams {
35 pub workspace: Option<String>,
37 pub layers: Option<String>,
39 pub node_types: Option<String>,
41}
42
43#[derive(Debug, Deserialize)]
45pub struct WorldStateQueryRequest {
46 pub node_types: Option<Vec<String>>,
48 pub layers: Option<Vec<String>>,
50 pub node_ids: Option<Vec<String>>,
52 pub relationship_types: Option<Vec<String>>,
54 #[serde(default = "default_true")]
56 pub include_edges: bool,
57 pub max_depth: Option<usize>,
59}
60
/// Serde default helper: makes an omitted boolean field deserialize as `true`.
fn default_true() -> bool {
    let enabled = true;
    enabled
}
64
65#[derive(Debug, Serialize)]
67pub struct WorldStateSnapshotResponse {
68 pub snapshot: WorldStateSnapshot,
70 pub available_layers: Vec<String>,
72}
73
74#[derive(Debug, Serialize)]
76pub struct WorldStateGraphResponse {
77 pub nodes: Vec<Value>,
79 pub edges: Vec<Value>,
81 pub metadata: Value,
83}
84
85pub async fn get_current_snapshot(
89 State(state): State<WorldStateState>,
90) -> Result<Json<WorldStateSnapshotResponse>, StatusCode> {
91 let engine = state.engine.read().await;
92 let snapshot = engine.get_current_snapshot().await.map_err(|e| {
93 error!("Failed to create world state snapshot: {}", e);
94 StatusCode::INTERNAL_SERVER_ERROR
95 })?;
96
97 let layers: Vec<String> = engine.get_layers().iter().map(|l| l.name().to_string()).collect();
98
99 Ok(Json(WorldStateSnapshotResponse {
100 snapshot,
101 available_layers: layers,
102 }))
103}
104
105pub async fn get_snapshot(
109 State(state): State<WorldStateState>,
110 Path(snapshot_id): Path<String>,
111) -> Result<Json<WorldStateSnapshot>, StatusCode> {
112 let engine = state.engine.read().await;
113 let snapshot = engine.get_snapshot(&snapshot_id).await.ok_or_else(|| {
114 error!("Snapshot not found: {}", snapshot_id);
115 StatusCode::NOT_FOUND
116 })?;
117
118 Ok(Json(snapshot))
119}
120
121pub async fn get_world_state_graph(
125 State(state): State<WorldStateState>,
126 Query(params): Query<WorldStateQueryParams>,
127) -> Result<Json<WorldStateGraphResponse>, StatusCode> {
128 let engine = state.engine.read().await;
129 let snapshot = engine.get_current_snapshot().await.map_err(|e| {
130 error!("Failed to create world state snapshot: {}", e);
131 StatusCode::INTERNAL_SERVER_ERROR
132 })?;
133
134 let nodes: Vec<Value> = snapshot
136 .nodes
137 .iter()
138 .map(|n| serde_json::to_value(n).unwrap_or_default())
139 .collect();
140
141 let edges: Vec<Value> = snapshot
142 .edges
143 .iter()
144 .map(|e| serde_json::to_value(e).unwrap_or_default())
145 .collect();
146
147 let metadata = serde_json::json!({
148 "node_count": nodes.len(),
149 "edge_count": edges.len(),
150 "timestamp": snapshot.timestamp.to_rfc3339(),
151 });
152
153 Ok(Json(WorldStateGraphResponse {
154 nodes,
155 edges,
156 metadata,
157 }))
158}
159
160pub async fn get_layers(State(state): State<WorldStateState>) -> Result<Json<Value>, StatusCode> {
164 let engine = state.engine.read().await;
165 let layers: Vec<Value> = engine
166 .get_layers()
167 .iter()
168 .map(|layer| {
169 serde_json::json!({
170 "id": format!("{:?}", layer),
171 "name": layer.name(),
172 })
173 })
174 .collect();
175
176 Ok(Json(serde_json::json!({
177 "layers": layers,
178 "count": layers.len(),
179 })))
180}
181
182pub async fn query_world_state(
186 State(state): State<WorldStateState>,
187 Json(request): Json<WorldStateQueryRequest>,
188) -> Result<Json<WorldStateSnapshot>, StatusCode> {
189 let engine = state.engine.read().await;
190
191 let mut query = WorldStateQuery::new();
193
194 if let Some(ref node_types) = request.node_types {
195 let types: HashSet<_> = node_types
196 .iter()
197 .filter_map(|s| {
198 Some(s.as_str())
201 })
202 .collect();
203 }
205
206 if let Some(ref layers) = request.layers {
207 let layer_set: HashSet<StateLayer> = layers
208 .iter()
209 .filter_map(|s| {
210 match s.as_str() {
212 "personas" => Some(StateLayer::Personas),
213 "lifecycle" => Some(StateLayer::Lifecycle),
214 "reality" => Some(StateLayer::Reality),
215 "time" => Some(StateLayer::Time),
216 "protocols" => Some(StateLayer::Protocols),
217 "behavior" => Some(StateLayer::Behavior),
218 "schemas" => Some(StateLayer::Schemas),
219 "recorded" => Some(StateLayer::Recorded),
220 "ai_modifiers" => Some(StateLayer::AiModifiers),
221 "system" => Some(StateLayer::System),
222 _ => None,
223 }
224 })
225 .collect();
226 if !layer_set.is_empty() {
227 query = query.with_layers(layer_set);
228 }
229 }
230
231 if let Some(ref node_ids) = request.node_ids {
232 let id_set: HashSet<String> = node_ids.iter().cloned().collect();
233 query = query.with_node_ids(id_set);
234 }
235
236 if let Some(ref rel_types) = request.relationship_types {
237 let rel_set: HashSet<String> = rel_types.iter().cloned().collect();
238 query = query.with_relationship_types(rel_set);
239 }
240
241 query = query.include_edges(request.include_edges);
242
243 if let Some(depth) = request.max_depth {
244 query = query.with_max_depth(depth);
245 }
246
247 let snapshot = engine.query(&query).await.map_err(|e| {
248 error!("Failed to query world state: {}", e);
249 StatusCode::INTERNAL_SERVER_ERROR
250 })?;
251
252 Ok(Json(snapshot))
253}
254
255pub async fn world_state_websocket_handler(
259 ws: WebSocketUpgrade,
260 State(state): State<WorldStateState>,
261) -> impl IntoResponse {
262 ws.on_upgrade(|socket| handle_world_state_stream(socket, state))
263}
264
265async fn handle_world_state_stream(
267 mut socket: axum::extract::ws::WebSocket,
268 state: WorldStateState,
269) {
270 use axum::extract::ws::Message;
271 use futures_util::SinkExt;
272 use tokio::time::{interval, Duration};
273
274 {
276 let engine = state.engine.read().await;
277 if let Ok(snapshot) = engine.get_current_snapshot().await {
278 if let Ok(json) = serde_json::to_string(&snapshot) {
279 let _ = socket.send(Message::Text(json.into())).await;
280 }
281 }
282 }
283
284 let mut interval = interval(Duration::from_secs(5));
286 let mut closed = false;
287
288 loop {
289 tokio::select! {
290 msg = socket.next() => {
292 match msg {
293 Some(Ok(Message::Text(text))) => {
294 info!("Received WebSocket message: {}", text);
295 }
297 Some(Ok(Message::Close(_))) => {
298 info!("WebSocket connection closed");
299 closed = true;
300 break;
301 }
302 Some(Err(e)) => {
303 error!("WebSocket error: {}", e);
304 closed = true;
305 break;
306 }
307 None => {
308 closed = true;
309 break;
310 }
311 _ => {}
312 }
313 }
314 _ = interval.tick() => {
316 let engine = state.engine.read().await;
317 if let Ok(snapshot) = engine.get_current_snapshot().await {
318 if let Ok(json) = serde_json::to_string(&snapshot) {
319 if socket.send(Message::Text(json.into())).await.is_err() {
320 closed = true;
321 break;
322 }
323 }
324 }
325 }
326 }
327
328 if closed {
329 break;
330 }
331 }
332}
333
334pub fn world_state_router() -> Router<WorldStateState> {
336 Router::new()
337 .route("/snapshot", get(get_current_snapshot))
338 .route("/snapshot/{id}", get(get_snapshot))
339 .route("/graph", get(get_world_state_graph))
340 .route("/layers", get(get_layers))
341 .route("/query", post(query_world_state))
342 .route("/stream", get(world_state_websocket_handler))
343}