Skip to main content

datasynth_server/rest/
websocket.rs

1//! WebSocket handlers for real-time data streaming.
2
3use std::time::Duration;
4
5use axum::extract::ws::{Message, WebSocket};
6use futures::{SinkExt, StreamExt};
7use serde::Serialize;
8use tracing::{error, info, warn};
9
10use super::routes::AppState;
11use datasynth_runtime::{EnhancedOrchestrator, PhaseConfig};
12
13/// Metrics update sent via WebSocket.
14#[derive(Debug, Serialize)]
15pub struct MetricsUpdate {
16    pub timestamp: String,
17    pub total_entries: u64,
18    pub total_anomalies: u64,
19    pub entries_per_second: f64,
20    pub active_streams: u32,
21    pub uptime_seconds: u64,
22}
23
24/// Event sent via WebSocket.
25#[derive(Debug, Serialize)]
26pub struct EventUpdate {
27    pub sequence: u64,
28    pub timestamp: String,
29    pub event_type: String,
30    pub document_id: String,
31    pub company_code: String,
32    pub amount: String,
33    pub is_anomaly: bool,
34}
35
36/// Marker type for metrics stream.
37pub struct MetricsStream;
38
39/// Handle WebSocket connection for metrics streaming.
40pub async fn handle_metrics_socket(socket: WebSocket, state: AppState) {
41    let (mut sender, mut receiver) = socket.split();
42
43    info!("Metrics WebSocket connected");
44
45    // Spawn a task to send metrics updates
46    let state_clone = state.clone();
47    let mut interval = tokio::time::interval(Duration::from_secs(1));
48
49    loop {
50        tokio::select! {
51            // Send metrics every second
52            _ = interval.tick() => {
53                let uptime = state_clone.server_state.uptime_seconds();
54                let total_entries = state_clone.server_state.total_entries.load(std::sync::atomic::Ordering::Relaxed);
55
56                let entries_per_second = if uptime > 0 {
57                    total_entries as f64 / uptime as f64
58                } else {
59                    0.0
60                };
61
62                let update = MetricsUpdate {
63                    timestamp: chrono::Utc::now().to_rfc3339(),
64                    total_entries,
65                    total_anomalies: state_clone.server_state.total_anomalies.load(std::sync::atomic::Ordering::Relaxed),
66                    entries_per_second,
67                    active_streams: state_clone.server_state.active_streams.load(std::sync::atomic::Ordering::Relaxed) as u32,
68                    uptime_seconds: uptime,
69                };
70
71                match serde_json::to_string(&update) {
72                    Ok(json) => {
73                        if sender.send(Message::Text(json.into())).await.is_err() {
74                            info!("Metrics WebSocket client disconnected");
75                            break;
76                        }
77                    }
78                    Err(e) => {
79                        error!("Failed to serialize metrics: {}", e);
80                    }
81                }
82            }
83            // Handle incoming messages (for ping/pong or close)
84            msg = receiver.next() => {
85                match msg {
86                    Some(Ok(Message::Close(_))) | None => {
87                        info!("Metrics WebSocket closed by client");
88                        break;
89                    }
90                    Some(Ok(Message::Ping(data))) => {
91                        if sender.send(Message::Pong(data)).await.is_err() {
92                            break;
93                        }
94                    }
95                    Some(Err(e)) => {
96                        warn!("Metrics WebSocket error: {}", e);
97                        break;
98                    }
99                    _ => {}
100                }
101            }
102        }
103    }
104}
105
106/// Handle WebSocket connection for event streaming.
107pub async fn handle_events_socket(socket: WebSocket, state: AppState) {
108    let (mut sender, mut receiver) = socket.split();
109
110    info!("Events WebSocket connected");
111
112    // Increment active streams
113    state
114        .server_state
115        .active_streams
116        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
117
118    let config = state.server_state.config.read().await.clone();
119
120    let phase_config = PhaseConfig {
121        generate_master_data: false,
122        generate_document_flows: false,
123        generate_journal_entries: true,
124        inject_anomalies: false,
125        show_progress: false,
126        ..Default::default()
127    };
128
129    let mut sequence = 0u64;
130    let delay = Duration::from_millis(100); // 10 events per second
131
132    // Create orchestrator once outside the loop to avoid per-iteration overhead
133    let mut orchestrator = match EnhancedOrchestrator::new(config.clone(), phase_config.clone()) {
134        Ok(o) => o,
135        Err(e) => {
136            error!("Failed to create orchestrator: {}", e);
137            state
138                .server_state
139                .active_streams
140                .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
141            return;
142        }
143    };
144
145    loop {
146        // Check if we should stop
147        if state
148            .server_state
149            .stream_stopped
150            .load(std::sync::atomic::Ordering::Relaxed)
151        {
152            info!("Events stream stopped by control command");
153            break;
154        }
155
156        // Check if we should pause
157        while state
158            .server_state
159            .stream_paused
160            .load(std::sync::atomic::Ordering::Relaxed)
161        {
162            tokio::time::sleep(Duration::from_millis(100)).await;
163            if state
164                .server_state
165                .stream_stopped
166                .load(std::sync::atomic::Ordering::Relaxed)
167            {
168                break;
169            }
170        }
171
172        // Check for incoming messages
173        tokio::select! {
174            msg = receiver.next() => {
175                match msg {
176                    Some(Ok(Message::Close(_))) | None => {
177                        info!("Events WebSocket closed by client");
178                        break;
179                    }
180                    Some(Ok(Message::Ping(data))) => {
181                        if sender.send(Message::Pong(data)).await.is_err() {
182                            break;
183                        }
184                    }
185                    Some(Err(e)) => {
186                        warn!("Events WebSocket error: {}", e);
187                        break;
188                    }
189                    _ => {}
190                }
191            }
192            _ = tokio::time::sleep(delay) => {
193                let result = match orchestrator.generate() {
194                    Ok(r) => r,
195                    Err(e) => {
196                        error!("Generation failed: {}", e);
197                        break;
198                    }
199                };
200
201                // Stream all journal entries, not just the first one
202                for entry in result.journal_entries.iter() {
203                    sequence += 1;
204                    state.server_state.total_stream_events.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
205                    state.server_state.total_entries.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
206
207                    let total_amount: rust_decimal::Decimal = entry.lines.iter()
208                        .map(|l| l.debit_amount)
209                        .sum();
210
211                    let event = EventUpdate {
212                        sequence,
213                        timestamp: chrono::Utc::now().to_rfc3339(),
214                        event_type: "JournalEntry".to_string(),
215                        document_id: entry.header.document_id.to_string(),
216                        company_code: entry.header.company_code.clone(),
217                        amount: total_amount.to_string(),
218                        is_anomaly: entry.header.is_fraud,
219                    };
220
221                    match serde_json::to_string(&event) {
222                        Ok(json) => {
223                            if sender.send(Message::Text(json.into())).await.is_err() {
224                                info!("Events WebSocket client disconnected");
225                                break;
226                            }
227                        }
228                        Err(e) => {
229                            error!("Failed to serialize event: {}", e);
230                        }
231                    }
232                }
233            }
234        }
235    }
236
237    // Decrement active streams
238    state
239        .server_state
240        .active_streams
241        .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
242}
243
244#[cfg(test)]
245#[allow(clippy::unwrap_used)]
246mod tests {
247    use super::*;
248
249    #[test]
250    fn test_metrics_update_serialization() {
251        let update = MetricsUpdate {
252            timestamp: "2024-01-01T00:00:00Z".to_string(),
253            total_entries: 1000,
254            total_anomalies: 10,
255            entries_per_second: 16.67,
256            active_streams: 1,
257            uptime_seconds: 60,
258        };
259        let json = serde_json::to_string(&update).unwrap();
260        assert!(json.contains("total_entries"));
261        assert!(json.contains("1000"));
262    }
263
264    #[test]
265    fn test_event_update_serialization() {
266        let event = EventUpdate {
267            sequence: 1,
268            timestamp: "2024-01-01T00:00:00Z".to_string(),
269            event_type: "JournalEntry".to_string(),
270            document_id: "12345".to_string(),
271            company_code: "1000".to_string(),
272            amount: "1000.00".to_string(),
273            is_anomaly: false,
274        };
275        let json = serde_json::to_string(&event).unwrap();
276        assert!(json.contains("JournalEntry"));
277        assert!(json.contains("12345"));
278    }
279}