1use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
6use axum::extract::State;
7use axum::response::IntoResponse;
8use axum::routing::get;
9use axum::Router;
10use futures::stream::StreamExt;
11use futures::SinkExt;
12use serde::{Deserialize, Serialize};
13use tokio::sync::broadcast;
14use tracing::*;
15
/// Number of events the management broadcast channel retains; a subscriber that falls further
/// behind than this observes `RecvError::Lagged` and skips the overwritten events.
const DEFAULT_WS_BROADCAST_CAPACITY: usize = 1024;

/// Largest capacity `tokio::sync::broadcast::channel` accepts (it panics above this).
const MAX_WS_BROADCAST_CAPACITY: usize = usize::MAX / 2;

/// Reads the broadcast channel capacity from `MOCKFORGE_WS_BROADCAST_CAPACITY`.
///
/// Surrounding whitespace in the variable is ignored. Falls back to
/// [`DEFAULT_WS_BROADCAST_CAPACITY`] when the variable is unset, is not a valid `usize`, or lies
/// outside `1..=usize::MAX / 2`. `broadcast::channel` panics on capacities outside that range,
/// so accepting e.g. `0` here would turn a configuration typo into a panic when the channel is
/// created.
fn get_ws_broadcast_capacity() -> usize {
    std::env::var("MOCKFORGE_WS_BROADCAST_CAPACITY")
        .ok()
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        .filter(|&capacity| capacity > 0 && capacity <= MAX_WS_BROADCAST_CAPACITY)
        .unwrap_or(DEFAULT_WS_BROADCAST_CAPACITY)
}
26
/// Event broadcast to every connected management WebSocket client (see `handle_socket`).
///
/// Serialized as an internally tagged JSON object. The variant name, converted to
/// `snake_case`, is stored under the `"type"` key next to the variant's own fields, e.g.
/// `{"type": "mock_deleted", "id": "...", "timestamp": "..."}`.
///
/// Every variant carries a `timestamp` string. The constructors in `impl MockEvent` fill it
/// with the current UTC time in RFC 3339 format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MockEvent {
    /// A mock was created (`"type": "mock_created"`).
    MockCreated {
        /// Configuration of the created mock.
        mock: super::management::MockConfig,
        /// Event time (RFC 3339 when built via `MockEvent::mock_created`).
        timestamp: String,
    },
    /// A mock was updated (`"type": "mock_updated"`).
    MockUpdated {
        /// Configuration of the mock after the update.
        mock: super::management::MockConfig,
        /// Event time (RFC 3339 when built via `MockEvent::mock_updated`).
        timestamp: String,
    },
    /// A mock was deleted (`"type": "mock_deleted"`).
    MockDeleted {
        /// Identifier of the deleted mock (presumably `MockConfig::id` — confirm with callers).
        id: String,
        /// Event time (RFC 3339 when built via `MockEvent::mock_deleted`).
        timestamp: String,
    },
    /// Server statistics changed (`"type": "stats_updated"`).
    StatsUpdated {
        /// Snapshot of the server statistics.
        stats: super::management::ServerStats,
        /// Event time (RFC 3339 when built via `MockEvent::stats_updated`).
        timestamp: String,
    },
    /// Greeting sent to a client right after its WebSocket connection is established.
    Connected {
        /// Human-readable greeting text.
        message: String,
        /// Event time (RFC 3339 when built via `MockEvent::connected`).
        timestamp: String,
    },
    /// A state machine for a resource type was created or replaced
    /// (`"type": "state_machine_updated"`).
    StateMachineUpdated {
        /// Resource type the state machine is registered for.
        resource_type: String,
        /// The state machine definition.
        state_machine: mockforge_core::intelligent_behavior::rules::StateMachine,
        /// Event time (RFC 3339 when built via `MockEvent::state_machine_updated`).
        timestamp: String,
    },
    /// The state machine for a resource type was removed (`"type": "state_machine_deleted"`).
    StateMachineDeleted {
        /// Resource type whose state machine was removed.
        resource_type: String,
        /// Event time (RFC 3339 when built via `MockEvent::state_machine_deleted`).
        timestamp: String,
    },
    /// A state-machine instance was created for a resource (`"type": "state_instance_created"`).
    StateInstanceCreated {
        /// Identifier of the resource the instance tracks.
        resource_id: String,
        /// Resource type (selects which state machine applies).
        resource_type: String,
        /// State the instance starts in.
        initial_state: String,
        /// Event time (RFC 3339 when built via `MockEvent::state_instance_created`).
        timestamp: String,
    },
    /// A state-machine instance moved between states (`"type": "state_transitioned"`).
    StateTransitioned {
        /// Identifier of the resource whose instance transitioned.
        resource_id: String,
        /// Resource type of the instance.
        resource_type: String,
        /// State before the transition.
        from_state: String,
        /// State after the transition.
        to_state: String,
        /// Key/value data attached to the instance; presumably its data after the
        /// transition — TODO confirm with callers.
        state_data: std::collections::HashMap<String, serde_json::Value>,
        /// Event time (RFC 3339 when built via `MockEvent::state_transitioned`).
        timestamp: String,
    },
    /// A state-machine instance was removed (`"type": "state_instance_deleted"`).
    StateInstanceDeleted {
        /// Identifier of the resource whose instance was removed.
        resource_id: String,
        /// Resource type of the removed instance.
        resource_type: String,
        /// Event time (RFC 3339 when built via `MockEvent::state_instance_deleted`).
        timestamp: String,
    },
}
118
119impl MockEvent {
120 pub fn mock_created(mock: super::management::MockConfig) -> Self {
122 Self::MockCreated {
123 mock,
124 timestamp: chrono::Utc::now().to_rfc3339(),
125 }
126 }
127
128 pub fn mock_updated(mock: super::management::MockConfig) -> Self {
130 Self::MockUpdated {
131 mock,
132 timestamp: chrono::Utc::now().to_rfc3339(),
133 }
134 }
135
136 pub fn mock_deleted(id: String) -> Self {
138 Self::MockDeleted {
139 id,
140 timestamp: chrono::Utc::now().to_rfc3339(),
141 }
142 }
143
144 pub fn stats_updated(stats: super::management::ServerStats) -> Self {
146 Self::StatsUpdated {
147 stats,
148 timestamp: chrono::Utc::now().to_rfc3339(),
149 }
150 }
151
152 pub fn connected(message: String) -> Self {
154 Self::Connected {
155 message,
156 timestamp: chrono::Utc::now().to_rfc3339(),
157 }
158 }
159
160 pub fn state_machine_updated(
162 resource_type: String,
163 state_machine: mockforge_core::intelligent_behavior::rules::StateMachine,
164 ) -> Self {
165 Self::StateMachineUpdated {
166 resource_type,
167 state_machine,
168 timestamp: chrono::Utc::now().to_rfc3339(),
169 }
170 }
171
172 pub fn state_machine_deleted(resource_type: String) -> Self {
174 Self::StateMachineDeleted {
175 resource_type,
176 timestamp: chrono::Utc::now().to_rfc3339(),
177 }
178 }
179
180 pub fn state_instance_created(
182 resource_id: String,
183 resource_type: String,
184 initial_state: String,
185 ) -> Self {
186 Self::StateInstanceCreated {
187 resource_id,
188 resource_type,
189 initial_state,
190 timestamp: chrono::Utc::now().to_rfc3339(),
191 }
192 }
193
194 pub fn state_transitioned(
196 resource_id: String,
197 resource_type: String,
198 from_state: String,
199 to_state: String,
200 state_data: std::collections::HashMap<String, serde_json::Value>,
201 ) -> Self {
202 Self::StateTransitioned {
203 resource_id,
204 resource_type,
205 from_state,
206 to_state,
207 state_data,
208 timestamp: chrono::Utc::now().to_rfc3339(),
209 }
210 }
211
212 pub fn state_instance_deleted(resource_id: String, resource_type: String) -> Self {
214 Self::StateInstanceDeleted {
215 resource_id,
216 resource_type,
217 timestamp: chrono::Utc::now().to_rfc3339(),
218 }
219 }
220}
221
/// Shared state for the management WebSocket endpoint.
///
/// Holds the sending half of a tokio broadcast channel. Each WebSocket connection subscribes to
/// it in `handle_socket`, so every event sent through `tx` (e.g. via
/// `WsManagementState::broadcast`) fans out to all connected clients. Cloning the state clones
/// the sender, so all clones feed the same channel.
#[derive(Clone)]
pub struct WsManagementState {
    /// Broadcast sender for `MockEvent`s; `tx.subscribe()` yields a new receiver.
    pub tx: broadcast::Sender<MockEvent>,
}
228
229impl WsManagementState {
230 pub fn new() -> Self {
235 let capacity = get_ws_broadcast_capacity();
236 let (tx, _) = broadcast::channel(capacity);
237 Self { tx }
238 }
239
240 pub fn broadcast(
245 &self,
246 event: MockEvent,
247 ) -> Result<usize, Box<broadcast::error::SendError<MockEvent>>> {
248 match self.tx.send(event) {
249 Ok(n) => Ok(n),
250 Err(e) => {
251 warn!(
252 "WebSocket broadcast failed: no active receivers. Event dropped: {:?}",
253 std::mem::discriminant(&e.0)
254 );
255 Err(Box::new(e))
256 }
257 }
258 }
259}
260
261impl Default for WsManagementState {
262 fn default() -> Self {
263 Self::new()
264 }
265}
266
267async fn ws_handler(
269 ws: WebSocketUpgrade,
270 State(state): State<WsManagementState>,
271) -> impl IntoResponse {
272 ws.on_upgrade(move |socket| handle_socket(socket, state))
273}
274
275async fn handle_socket(socket: WebSocket, state: WsManagementState) {
277 let (mut sender, mut receiver) = socket.split();
278
279 let mut rx = state.tx.subscribe();
281
282 let connected_event = MockEvent::connected("Connected to MockForge management API".to_string());
284 if let Ok(json) = serde_json::to_string(&connected_event) {
285 if sender.send(Message::Text(json.into())).await.is_err() {
286 return;
287 }
288 }
289
290 let mut send_task = tokio::spawn(async move {
292 loop {
293 match rx.recv().await {
294 Ok(event) => {
295 if let Ok(json) = serde_json::to_string(&event) {
296 if sender.send(Message::Text(json.into())).await.is_err() {
297 break;
298 }
299 }
300 }
301 Err(broadcast::error::RecvError::Lagged(count)) => {
302 warn!("WebSocket client lagged behind, {} events dropped", count);
303 }
305 Err(broadcast::error::RecvError::Closed) => {
306 break;
307 }
308 }
309 }
310 });
311
312 let mut recv_task = tokio::spawn(async move {
314 while let Some(Ok(msg)) = receiver.next().await {
315 match msg {
316 Message::Text(text) => {
317 debug!("Received WebSocket message: {}", text);
318 }
320 Message::Close(_) => {
321 info!("WebSocket client disconnected");
322 break;
323 }
324 _ => {}
325 }
326 }
327 });
328
329 tokio::select! {
331 _ = &mut send_task => {
332 debug!("Send task completed");
333 recv_task.abort();
334 }
335 _ = &mut recv_task => {
336 debug!("Receive task completed");
337 send_task.abort();
338 }
339 }
340}
341
342pub fn ws_management_router(state: WsManagementState) -> Router {
344 Router::new().route("/", get(ws_handler)).with_state(state)
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ws_management_state_creation() {
        let state = WsManagementState::new();
        // The receiver created alongside the channel is discarded, so nobody is subscribed yet.
        assert_eq!(state.tx.receiver_count(), 0);
    }

    #[test]
    fn test_mock_event_creation() {
        use super::super::management::{MockConfig, MockResponse};

        let mock = MockConfig {
            id: "test-1".to_string(),
            name: "Test Mock".to_string(),
            method: "GET".to_string(),
            path: "/test".to_string(),
            response: MockResponse {
                body: serde_json::json!({"message": "test"}),
                headers: None,
            },
            enabled: true,
            latency_ms: None,
            status_code: Some(200),
            request_match: None,
            priority: None,
            scenario: None,
            required_scenario_state: None,
            new_scenario_state: None,
            version: 1,
        };

        let event = MockEvent::mock_created(mock);

        let json = serde_json::to_value(&event).unwrap();
        // `#[serde(tag = "type", rename_all = "snake_case")]` stores the variant name under "type".
        assert_eq!(json["type"], "mock_created");
        assert!(json["timestamp"].is_string());
    }

    #[test]
    fn test_broadcast_event() {
        let state = WsManagementState::new();

        let event = MockEvent::connected("Test connection".to_string());

        // With no subscribers the send fails and the undelivered event is handed back.
        let err = state
            .broadcast(event)
            .expect_err("broadcast without subscribers should fail");
        assert!(matches!(
            err.0,
            MockEvent::Connected { ref message, .. } if message == "Test connection"
        ));
    }

    #[test]
    fn test_broadcast_event_reaches_subscriber() {
        let state = WsManagementState::new();
        let mut rx = state.tx.subscribe();

        let delivered = state
            .broadcast(MockEvent::connected("hello".to_string()))
            .expect("broadcast with one subscriber should succeed");
        assert_eq!(delivered, 1);

        match rx.try_recv() {
            Ok(MockEvent::Connected { message, .. }) => assert_eq!(message, "hello"),
            other => panic!("unexpected receive result: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_ws_management_router_creation() {
        let state = WsManagementState::new();
        let _router = ws_management_router(state);
    }
}