1use axum::{
7 extract::{Path, Query, State, WebSocketUpgrade},
8 http::StatusCode,
9 response::{IntoResponse, Json},
10 routing::{get, post},
11 Router,
12};
13use mockforge_world_state::{
14 model::{StateLayer, WorldStateSnapshot},
15 WorldStateEngine, WorldStateQuery,
16};
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19use std::collections::HashSet;
20use std::sync::Arc;
21use tokio::sync::RwLock;
22use tracing::{error, info};
23
/// Shared state injected into every world-state handler through axum's
/// `State` extractor.
#[derive(Clone)]
pub struct WorldStateState {
    /// The world state engine, shared across handlers behind an async
    /// read/write lock.
    pub engine: Arc<RwLock<WorldStateEngine>>,
}
30
/// Query-string parameters deserialized by the `GET /graph` handler.
///
/// NOTE(review): none of the handlers visible in this file read these fields;
/// they are parsed from the query string but have no effect on the response.
#[derive(Debug, Deserialize)]
pub struct WorldStateQueryParams {
    /// Optional workspace identifier.
    pub workspace: Option<String>,
    /// Optional layer filter (presumably a comma-separated list of layer
    /// names — TODO confirm the intended format).
    pub layers: Option<String>,
    /// Optional node type filter (presumably a comma-separated list — TODO
    /// confirm the intended format).
    pub node_types: Option<String>,
}
41
/// JSON request body deserialized by the `POST /query` handler.
///
/// Every filter is optional; an omitted filter places no restriction on the
/// query.
#[derive(Debug, Deserialize)]
pub struct WorldStateQueryRequest {
    /// Node type names to filter on.
    pub node_types: Option<Vec<String>>,
    /// Layer names to filter on (for example `"personas"` or `"system"`).
    pub layers: Option<Vec<String>>,
    /// Identifiers of the nodes to restrict the query to.
    pub node_ids: Option<Vec<String>>,
    /// Relationship type names to restrict the query to.
    pub relationship_types: Option<Vec<String>>,
    /// Whether edges are included in the result; `true` when the field is
    /// omitted from the request body.
    #[serde(default = "default_true")]
    pub include_edges: bool,
    /// Optional maximum depth forwarded to the query.
    pub max_depth: Option<usize>,
}
59
/// Serde default provider: yields `true` for boolean fields that are omitted
/// from the incoming payload (used by `WorldStateQueryRequest::include_edges`).
fn default_true() -> bool {
    true
}
63
/// Response body of the `GET /snapshot` handler.
#[derive(Debug, Serialize)]
pub struct WorldStateSnapshotResponse {
    /// The snapshot produced by the engine for this request.
    pub snapshot: WorldStateSnapshot,
    /// Names of the layers reported by the engine.
    pub available_layers: Vec<String>,
}
72
/// Response body of the `GET /graph` handler: the snapshot flattened into
/// generic JSON values.
#[derive(Debug, Serialize)]
pub struct WorldStateGraphResponse {
    /// Snapshot nodes, each serialized to a JSON value.
    pub nodes: Vec<Value>,
    /// Snapshot edges, each serialized to a JSON value.
    pub edges: Vec<Value>,
    /// Summary object carrying `node_count`, `edge_count` and `timestamp`.
    pub metadata: Value,
}
83
84pub async fn get_current_snapshot(
88 State(state): State<WorldStateState>,
89) -> Result<Json<WorldStateSnapshotResponse>, StatusCode> {
90 let engine = state.engine.read().await;
91 let snapshot = engine.get_current_snapshot().await.map_err(|e| {
92 error!("Failed to create world state snapshot: {}", e);
93 StatusCode::INTERNAL_SERVER_ERROR
94 })?;
95
96 let layers: Vec<String> = engine.get_layers().iter().map(|l| l.name().to_string()).collect();
97
98 Ok(Json(WorldStateSnapshotResponse {
99 snapshot,
100 available_layers: layers,
101 }))
102}
103
104pub async fn get_snapshot(
108 State(state): State<WorldStateState>,
109 Path(snapshot_id): Path<String>,
110) -> Result<Json<WorldStateSnapshot>, StatusCode> {
111 let engine = state.engine.read().await;
112 let snapshot = engine.get_snapshot(&snapshot_id).await.ok_or_else(|| {
113 error!("Snapshot not found: {}", snapshot_id);
114 StatusCode::NOT_FOUND
115 })?;
116
117 Ok(Json(snapshot))
118}
119
120pub async fn get_world_state_graph(
124 State(state): State<WorldStateState>,
125 Query(params): Query<WorldStateQueryParams>,
126) -> Result<Json<WorldStateGraphResponse>, StatusCode> {
127 let engine = state.engine.read().await;
128 let snapshot = engine.get_current_snapshot().await.map_err(|e| {
129 error!("Failed to create world state snapshot: {}", e);
130 StatusCode::INTERNAL_SERVER_ERROR
131 })?;
132
133 let nodes: Vec<Value> = snapshot
135 .nodes
136 .iter()
137 .map(|n| serde_json::to_value(n).unwrap_or_default())
138 .collect();
139
140 let edges: Vec<Value> = snapshot
141 .edges
142 .iter()
143 .map(|e| serde_json::to_value(e).unwrap_or_default())
144 .collect();
145
146 let metadata = serde_json::json!({
147 "node_count": nodes.len(),
148 "edge_count": edges.len(),
149 "timestamp": snapshot.timestamp.to_rfc3339(),
150 });
151
152 Ok(Json(WorldStateGraphResponse {
153 nodes,
154 edges,
155 metadata,
156 }))
157}
158
159pub async fn get_layers(State(state): State<WorldStateState>) -> Result<Json<Value>, StatusCode> {
163 let engine = state.engine.read().await;
164 let layers: Vec<Value> = engine
165 .get_layers()
166 .iter()
167 .map(|layer| {
168 serde_json::json!({
169 "id": format!("{:?}", layer),
170 "name": layer.name(),
171 })
172 })
173 .collect();
174
175 Ok(Json(serde_json::json!({
176 "layers": layers,
177 "count": layers.len(),
178 })))
179}
180
181pub async fn query_world_state(
185 State(state): State<WorldStateState>,
186 Json(request): Json<WorldStateQueryRequest>,
187) -> Result<Json<WorldStateSnapshot>, StatusCode> {
188 let engine = state.engine.read().await;
189
190 let mut query = WorldStateQuery::new();
192
193 if let Some(ref node_types) = request.node_types {
194 let types: HashSet<_> = node_types
195 .iter()
196 .filter_map(|s| {
197 Some(s.as_str())
200 })
201 .collect();
202 }
204
205 if let Some(ref layers) = request.layers {
206 let layer_set: HashSet<StateLayer> = layers
207 .iter()
208 .filter_map(|s| {
209 match s.as_str() {
211 "personas" => Some(StateLayer::Personas),
212 "lifecycle" => Some(StateLayer::Lifecycle),
213 "reality" => Some(StateLayer::Reality),
214 "time" => Some(StateLayer::Time),
215 "protocols" => Some(StateLayer::Protocols),
216 "behavior" => Some(StateLayer::Behavior),
217 "schemas" => Some(StateLayer::Schemas),
218 "recorded" => Some(StateLayer::Recorded),
219 "ai_modifiers" => Some(StateLayer::AiModifiers),
220 "system" => Some(StateLayer::System),
221 _ => None,
222 }
223 })
224 .collect();
225 if !layer_set.is_empty() {
226 query = query.with_layers(layer_set);
227 }
228 }
229
230 if let Some(ref node_ids) = request.node_ids {
231 let id_set: HashSet<String> = node_ids.iter().cloned().collect();
232 query = query.with_node_ids(id_set);
233 }
234
235 if let Some(ref rel_types) = request.relationship_types {
236 let rel_set: HashSet<String> = rel_types.iter().cloned().collect();
237 query = query.with_relationship_types(rel_set);
238 }
239
240 query = query.include_edges(request.include_edges);
241
242 if let Some(depth) = request.max_depth {
243 query = query.with_max_depth(depth);
244 }
245
246 let snapshot = engine.query(&query).await.map_err(|e| {
247 error!("Failed to query world state: {}", e);
248 StatusCode::INTERNAL_SERVER_ERROR
249 })?;
250
251 Ok(Json(snapshot))
252}
253
254pub async fn world_state_websocket_handler(
258 ws: WebSocketUpgrade,
259 State(state): State<WorldStateState>,
260) -> impl IntoResponse {
261 ws.on_upgrade(|socket| handle_world_state_stream(socket, state))
262}
263
264async fn handle_world_state_stream(
266 mut socket: axum::extract::ws::WebSocket,
267 state: WorldStateState,
268) {
269 use axum::extract::ws::Message;
270 use futures_util::{SinkExt, StreamExt};
271 use tokio::time::{interval, Duration};
272
273 {
275 let engine = state.engine.read().await;
276 if let Ok(snapshot) = engine.get_current_snapshot().await {
277 if let Ok(json) = serde_json::to_string(&snapshot) {
278 let _ = socket.send(Message::Text(json.into())).await;
279 }
280 }
281 }
282
283 let mut interval = interval(Duration::from_secs(5));
285 let mut closed = false;
286
287 loop {
288 tokio::select! {
289 msg = socket.next() => {
291 match msg {
292 Some(Ok(Message::Text(text))) => {
293 info!("Received WebSocket message: {}", text);
294 }
296 Some(Ok(Message::Close(_))) => {
297 info!("WebSocket connection closed");
298 closed = true;
299 break;
300 }
301 Some(Err(e)) => {
302 error!("WebSocket error: {}", e);
303 closed = true;
304 break;
305 }
306 None => {
307 closed = true;
308 break;
309 }
310 _ => {}
311 }
312 }
313 _ = interval.tick() => {
315 let engine = state.engine.read().await;
316 if let Ok(snapshot) = engine.get_current_snapshot().await {
317 if let Ok(json) = serde_json::to_string(&snapshot) {
318 if socket.send(Message::Text(json.into())).await.is_err() {
319 closed = true;
320 break;
321 }
322 }
323 }
324 }
325 }
326
327 if closed {
328 break;
329 }
330 }
331}
332
333pub fn world_state_router() -> Router<WorldStateState> {
335 Router::new()
336 .route("/snapshot", get(get_current_snapshot))
337 .route("/snapshot/:id", get(get_snapshot))
338 .route("/graph", get(get_world_state_graph))
339 .route("/layers", get(get_layers))
340 .route("/query", post(query_world_state))
341 .route("/stream", get(world_state_websocket_handler))
342}