1use axum::{
7 extract::{Path, Query, State, WebSocketUpgrade},
8 http::StatusCode,
9 response::{IntoResponse, Json},
10 routing::{get, post},
11 Router,
12};
13use futures_util::StreamExt;
14use mockforge_world_state::{
15 model::{StateLayer, WorldStateSnapshot},
16 WorldStateEngine, WorldStateQuery,
17};
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::collections::HashSet;
21use std::sync::Arc;
22use tokio::sync::RwLock;
23use tracing::{error, info};
24
25#[derive(Clone)]
27pub struct WorldStateState {
28 pub engine: Arc<RwLock<WorldStateEngine>>,
30}
31
32#[derive(Debug, Deserialize)]
34pub struct WorldStateQueryParams {
35 pub workspace: Option<String>,
37 pub layers: Option<String>,
39 pub node_types: Option<String>,
41}
42
43#[derive(Debug, Deserialize)]
45pub struct WorldStateQueryRequest {
46 pub node_types: Option<Vec<String>>,
48 pub layers: Option<Vec<String>>,
50 pub node_ids: Option<Vec<String>>,
52 pub relationship_types: Option<Vec<String>>,
54 #[serde(default = "default_true")]
56 pub include_edges: bool,
57 pub max_depth: Option<usize>,
59}
60
61fn default_true() -> bool {
62 true
63}
64
65#[derive(Debug, Serialize)]
67pub struct WorldStateSnapshotResponse {
68 pub snapshot: WorldStateSnapshot,
70 pub available_layers: Vec<String>,
72}
73
74#[derive(Debug, Serialize)]
76pub struct WorldStateGraphResponse {
77 pub nodes: Vec<Value>,
79 pub edges: Vec<Value>,
81 pub metadata: Value,
83}
84
85pub async fn get_current_snapshot(
89 State(state): State<WorldStateState>,
90) -> Result<Json<WorldStateSnapshotResponse>, StatusCode> {
91 let engine = state.engine.read().await;
92 let snapshot = engine.get_current_snapshot().await.map_err(|e| {
93 error!("Failed to create world state snapshot: {}", e);
94 StatusCode::INTERNAL_SERVER_ERROR
95 })?;
96
97 let layers: Vec<String> = engine.get_layers().iter().map(|l| l.name().to_string()).collect();
98
99 Ok(Json(WorldStateSnapshotResponse {
100 snapshot,
101 available_layers: layers,
102 }))
103}
104
105pub async fn get_snapshot(
109 State(state): State<WorldStateState>,
110 Path(snapshot_id): Path<String>,
111) -> Result<Json<WorldStateSnapshot>, StatusCode> {
112 let engine = state.engine.read().await;
113 let snapshot = engine.get_snapshot(&snapshot_id).await.ok_or_else(|| {
114 error!("Snapshot not found: {}", snapshot_id);
115 StatusCode::NOT_FOUND
116 })?;
117
118 Ok(Json(snapshot))
119}
120
121pub async fn get_world_state_graph(
125 State(state): State<WorldStateState>,
126 Query(_params): Query<WorldStateQueryParams>,
127) -> Result<Json<WorldStateGraphResponse>, StatusCode> {
128 let engine = state.engine.read().await;
129 let snapshot = engine.get_current_snapshot().await.map_err(|e| {
130 error!("Failed to create world state snapshot: {}", e);
131 StatusCode::INTERNAL_SERVER_ERROR
132 })?;
133
134 let nodes: Vec<Value> = snapshot
136 .nodes
137 .iter()
138 .map(|n| serde_json::to_value(n).unwrap_or_default())
139 .collect();
140
141 let edges: Vec<Value> = snapshot
142 .edges
143 .iter()
144 .map(|e| serde_json::to_value(e).unwrap_or_default())
145 .collect();
146
147 let metadata = serde_json::json!({
148 "node_count": nodes.len(),
149 "edge_count": edges.len(),
150 "timestamp": snapshot.timestamp.to_rfc3339(),
151 });
152
153 Ok(Json(WorldStateGraphResponse {
154 nodes,
155 edges,
156 metadata,
157 }))
158}
159
160pub async fn get_layers(State(state): State<WorldStateState>) -> Result<Json<Value>, StatusCode> {
164 let engine = state.engine.read().await;
165 let layers: Vec<Value> = engine
166 .get_layers()
167 .iter()
168 .map(|layer| {
169 serde_json::json!({
170 "id": format!("{:?}", layer),
171 "name": layer.name(),
172 })
173 })
174 .collect();
175
176 Ok(Json(serde_json::json!({
177 "layers": layers,
178 "count": layers.len(),
179 })))
180}
181
182pub async fn query_world_state(
186 State(state): State<WorldStateState>,
187 Json(request): Json<WorldStateQueryRequest>,
188) -> Result<Json<WorldStateSnapshot>, StatusCode> {
189 let engine = state.engine.read().await;
190
191 let mut query = WorldStateQuery::new();
193
194 if let Some(ref node_types) = request.node_types {
195 let _types: HashSet<_> = node_types
196 .iter()
197 .filter_map(|s| {
198 Some(s.as_str())
201 })
202 .collect();
203 }
205
206 if let Some(ref layers) = request.layers {
207 let layer_set: HashSet<StateLayer> = layers
208 .iter()
209 .filter_map(|s| {
210 match s.as_str() {
212 "personas" => Some(StateLayer::Personas),
213 "lifecycle" => Some(StateLayer::Lifecycle),
214 "reality" => Some(StateLayer::Reality),
215 "time" => Some(StateLayer::Time),
216 "protocols" => Some(StateLayer::Protocols),
217 "behavior" => Some(StateLayer::Behavior),
218 "schemas" => Some(StateLayer::Schemas),
219 "recorded" => Some(StateLayer::Recorded),
220 "ai_modifiers" => Some(StateLayer::AiModifiers),
221 "system" => Some(StateLayer::System),
222 _ => None,
223 }
224 })
225 .collect();
226 if !layer_set.is_empty() {
227 query = query.with_layers(layer_set);
228 }
229 }
230
231 if let Some(ref node_ids) = request.node_ids {
232 let id_set: HashSet<String> = node_ids.iter().cloned().collect();
233 query = query.with_node_ids(id_set);
234 }
235
236 if let Some(ref rel_types) = request.relationship_types {
237 let rel_set: HashSet<String> = rel_types.iter().cloned().collect();
238 query = query.with_relationship_types(rel_set);
239 }
240
241 query = query.include_edges(request.include_edges);
242
243 if let Some(depth) = request.max_depth {
244 query = query.with_max_depth(depth);
245 }
246
247 let snapshot = engine.query(&query).await.map_err(|e| {
248 error!("Failed to query world state: {}", e);
249 StatusCode::INTERNAL_SERVER_ERROR
250 })?;
251
252 Ok(Json(snapshot))
253}
254
255pub async fn world_state_websocket_handler(
259 ws: WebSocketUpgrade,
260 State(state): State<WorldStateState>,
261) -> impl IntoResponse {
262 ws.on_upgrade(|socket| handle_world_state_stream(socket, state))
263}
264
265async fn handle_world_state_stream(
267 mut socket: axum::extract::ws::WebSocket,
268 state: WorldStateState,
269) {
270 use axum::extract::ws::Message;
271 use tokio::time::{interval, Duration};
272
273 {
275 let engine = state.engine.read().await;
276 if let Ok(snapshot) = engine.get_current_snapshot().await {
277 if let Ok(json) = serde_json::to_string(&snapshot) {
278 let _ = socket.send(Message::Text(json.into())).await;
279 }
280 }
281 }
282
283 let mut interval = interval(Duration::from_secs(5));
285 let mut closed = false;
286
287 loop {
288 tokio::select! {
289 msg = socket.next() => {
291 match msg {
292 Some(Ok(Message::Text(text))) => {
293 info!("Received WebSocket message: {}", text);
294 }
296 Some(Ok(Message::Close(_))) => {
297 info!("WebSocket connection closed");
298 closed = true;
299 break;
300 }
301 Some(Err(e)) => {
302 error!("WebSocket error: {}", e);
303 closed = true;
304 break;
305 }
306 None => {
307 closed = true;
308 break;
309 }
310 _ => {}
311 }
312 }
313 _ = interval.tick() => {
315 let engine = state.engine.read().await;
316 if let Ok(snapshot) = engine.get_current_snapshot().await {
317 if let Ok(json) = serde_json::to_string(&snapshot) {
318 if socket.send(Message::Text(json.into())).await.is_err() {
319 closed = true;
320 break;
321 }
322 }
323 }
324 }
325 }
326
327 if closed {
328 break;
329 }
330 }
331}
332
333pub fn world_state_router() -> Router<WorldStateState> {
335 Router::new()
336 .route("/snapshot", get(get_current_snapshot))
337 .route("/snapshot/{id}", get(get_snapshot))
338 .route("/graph", get(get_world_state_graph))
339 .route("/layers", get(get_layers))
340 .route("/query", post(query_world_state))
341 .route("/stream", get(world_state_websocket_handler))
342}