1use axum::{
7 extract::{Path, Query, State, WebSocketUpgrade},
8 http::StatusCode,
9 response::{IntoResponse, Json},
10 routing::{get, post},
11 Router,
12};
13use futures_util::StreamExt;
14use mockforge_world_state::{
15 model::{StateLayer, WorldStateSnapshot},
16 WorldStateEngine, WorldStateQuery,
17};
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::collections::HashSet;
21use std::sync::Arc;
22use tokio::sync::RwLock;
23use tracing::{error, info};
24
25#[derive(Clone)]
27pub struct WorldStateState {
28 pub engine: Arc<RwLock<WorldStateEngine>>,
30}
31
32#[derive(Debug, Deserialize)]
34pub struct WorldStateQueryParams {
35 pub workspace: Option<String>,
37 pub layers: Option<String>,
39 pub node_types: Option<String>,
41}
42
43#[derive(Debug, Deserialize)]
45pub struct WorldStateQueryRequest {
46 pub node_types: Option<Vec<String>>,
48 pub layers: Option<Vec<String>>,
50 pub node_ids: Option<Vec<String>>,
52 pub relationship_types: Option<Vec<String>>,
54 #[serde(default = "default_true")]
56 pub include_edges: bool,
57 pub max_depth: Option<usize>,
59}
60
/// Serde `default` helper that yields `true`.
///
/// `bool::default()` is `false`, so fields that should default to enabled
/// (such as `include_edges`) need an explicit function.
fn default_true() -> bool { true }
64
65#[derive(Debug, Serialize)]
67pub struct WorldStateSnapshotResponse {
68 pub snapshot: WorldStateSnapshot,
70 pub available_layers: Vec<String>,
72}
73
74#[derive(Debug, Serialize)]
76pub struct WorldStateGraphResponse {
77 pub nodes: Vec<Value>,
79 pub edges: Vec<Value>,
81 pub metadata: Value,
83}
84
85pub async fn get_current_snapshot(
89 State(state): State<WorldStateState>,
90) -> Result<Json<WorldStateSnapshotResponse>, StatusCode> {
91 let engine = state.engine.read().await;
92 let snapshot = engine.get_current_snapshot().await.map_err(|e| {
93 error!("Failed to create world state snapshot: {}", e);
94 StatusCode::INTERNAL_SERVER_ERROR
95 })?;
96
97 let layers: Vec<String> = engine.get_layers().iter().map(|l| l.name().to_string()).collect();
98
99 Ok(Json(WorldStateSnapshotResponse {
100 snapshot,
101 available_layers: layers,
102 }))
103}
104
105pub async fn get_snapshot(
109 State(state): State<WorldStateState>,
110 Path(snapshot_id): Path<String>,
111) -> Result<Json<WorldStateSnapshot>, StatusCode> {
112 let engine = state.engine.read().await;
113 let snapshot = engine.get_snapshot(&snapshot_id).await.ok_or_else(|| {
114 error!("Snapshot not found: {}", snapshot_id);
115 StatusCode::NOT_FOUND
116 })?;
117
118 Ok(Json(snapshot))
119}
120
121pub async fn get_world_state_graph(
125 State(state): State<WorldStateState>,
126 Query(_params): Query<WorldStateQueryParams>,
127) -> Result<Json<WorldStateGraphResponse>, StatusCode> {
128 let engine = state.engine.read().await;
129 let snapshot = engine.get_current_snapshot().await.map_err(|e| {
130 error!("Failed to create world state snapshot: {}", e);
131 StatusCode::INTERNAL_SERVER_ERROR
132 })?;
133
134 let nodes: Vec<Value> = snapshot
136 .nodes
137 .iter()
138 .map(|n| serde_json::to_value(n).unwrap_or_default())
139 .collect();
140
141 let edges: Vec<Value> = snapshot
142 .edges
143 .iter()
144 .map(|e| serde_json::to_value(e).unwrap_or_default())
145 .collect();
146
147 let metadata = serde_json::json!({
148 "node_count": nodes.len(),
149 "edge_count": edges.len(),
150 "timestamp": snapshot.timestamp.to_rfc3339(),
151 });
152
153 Ok(Json(WorldStateGraphResponse {
154 nodes,
155 edges,
156 metadata,
157 }))
158}
159
160pub async fn get_layers(State(state): State<WorldStateState>) -> Result<Json<Value>, StatusCode> {
164 let engine = state.engine.read().await;
165 let layers: Vec<Value> = engine
166 .get_layers()
167 .iter()
168 .map(|layer| {
169 serde_json::json!({
170 "id": format!("{:?}", layer),
171 "name": layer.name(),
172 })
173 })
174 .collect();
175
176 Ok(Json(serde_json::json!({
177 "layers": layers,
178 "count": layers.len(),
179 })))
180}
181
182pub async fn query_world_state(
186 State(state): State<WorldStateState>,
187 Json(request): Json<WorldStateQueryRequest>,
188) -> Result<Json<WorldStateSnapshot>, StatusCode> {
189 let engine = state.engine.read().await;
190
191 let mut query = WorldStateQuery::new();
193
194 if let Some(ref node_types) = request.node_types {
195 use mockforge_world_state::model::NodeType;
196 let types: HashSet<NodeType> = node_types
197 .iter()
198 .filter_map(|s| match s.as_str() {
199 "persona" => Some(NodeType::Persona),
200 "entity" => Some(NodeType::Entity),
201 "session" => Some(NodeType::Session),
202 "protocol" => Some(NodeType::Protocol),
203 "behavior" => Some(NodeType::Behavior),
204 "schema" => Some(NodeType::Schema),
205 "recorded" => Some(NodeType::Recorded),
206 "ai_modifier" => Some(NodeType::AiModifier),
207 "system" => Some(NodeType::System),
208 _ => None,
209 })
210 .collect();
211 if !types.is_empty() {
212 query = query.with_node_types(types);
213 }
214 }
215
216 if let Some(ref layers) = request.layers {
217 let layer_set: HashSet<StateLayer> = layers
218 .iter()
219 .filter_map(|s| {
220 match s.as_str() {
222 "personas" => Some(StateLayer::Personas),
223 "lifecycle" => Some(StateLayer::Lifecycle),
224 "reality" => Some(StateLayer::Reality),
225 "time" => Some(StateLayer::Time),
226 "protocols" => Some(StateLayer::Protocols),
227 "behavior" => Some(StateLayer::Behavior),
228 "schemas" => Some(StateLayer::Schemas),
229 "recorded" => Some(StateLayer::Recorded),
230 "ai_modifiers" => Some(StateLayer::AiModifiers),
231 "system" => Some(StateLayer::System),
232 _ => None,
233 }
234 })
235 .collect();
236 if !layer_set.is_empty() {
237 query = query.with_layers(layer_set);
238 }
239 }
240
241 if let Some(ref node_ids) = request.node_ids {
242 let id_set: HashSet<String> = node_ids.iter().cloned().collect();
243 query = query.with_node_ids(id_set);
244 }
245
246 if let Some(ref rel_types) = request.relationship_types {
247 let rel_set: HashSet<String> = rel_types.iter().cloned().collect();
248 query = query.with_relationship_types(rel_set);
249 }
250
251 query = query.include_edges(request.include_edges);
252
253 if let Some(depth) = request.max_depth {
254 query = query.with_max_depth(depth);
255 }
256
257 let snapshot = engine.query(&query).await.map_err(|e| {
258 error!("Failed to query world state: {}", e);
259 StatusCode::INTERNAL_SERVER_ERROR
260 })?;
261
262 Ok(Json(snapshot))
263}
264
265pub async fn world_state_websocket_handler(
269 ws: WebSocketUpgrade,
270 State(state): State<WorldStateState>,
271) -> impl IntoResponse {
272 ws.on_upgrade(|socket| handle_world_state_stream(socket, state))
273}
274
275async fn handle_world_state_stream(
277 mut socket: axum::extract::ws::WebSocket,
278 state: WorldStateState,
279) {
280 use axum::extract::ws::Message;
281 use tokio::time::{interval, Duration};
282
283 {
285 let engine = state.engine.read().await;
286 if let Ok(snapshot) = engine.get_current_snapshot().await {
287 if let Ok(json) = serde_json::to_string(&snapshot) {
288 let _ = socket.send(Message::Text(json.into())).await;
289 }
290 }
291 }
292
293 let mut interval = interval(Duration::from_secs(5));
295
296 loop {
297 tokio::select! {
298 msg = socket.next() => {
300 match msg {
301 Some(Ok(Message::Text(text))) => {
302 info!("Received WebSocket message: {}", text);
303 }
305 Some(Ok(Message::Close(_))) => {
306 info!("WebSocket connection closed");
307 break;
308 }
309 Some(Err(e)) => {
310 error!("WebSocket error: {}", e);
311 break;
312 }
313 None => {
314 break;
315 }
316 _ => {}
317 }
318 }
319 _ = interval.tick() => {
321 let engine = state.engine.read().await;
322 if let Ok(snapshot) = engine.get_current_snapshot().await {
323 if let Ok(json) = serde_json::to_string(&snapshot) {
324 if socket.send(Message::Text(json.into())).await.is_err() {
325 break;
326 }
327 }
328 }
329 }
330 }
331 }
332}
333
334pub fn world_state_router() -> Router<WorldStateState> {
336 Router::new()
337 .route("/snapshot", get(get_current_snapshot))
338 .route("/snapshot/{id}", get(get_snapshot))
339 .route("/graph", get(get_world_state_graph))
340 .route("/layers", get(get_layers))
341 .route("/query", post(query_world_state))
342 .route("/stream", get(world_state_websocket_handler))
343}