Skip to main content

datasynth_server/rest/
websocket.rs

1//! WebSocket handlers for real-time data streaming.
2
3use std::time::Duration;
4
5use axum::extract::ws::{Message, WebSocket};
6use futures::{SinkExt, StreamExt};
7use serde::Serialize;
8use tracing::{error, info, warn};
9
10use super::routes::AppState;
11use datasynth_runtime::{EnhancedOrchestrator, PhaseConfig};
12
13/// Metrics update sent via WebSocket.
14#[derive(Debug, Serialize)]
15pub struct MetricsUpdate {
16    pub timestamp: String,
17    pub total_entries: u64,
18    pub total_anomalies: u64,
19    pub entries_per_second: f64,
20    pub active_streams: u32,
21    pub uptime_seconds: u64,
22}
23
24/// Event sent via WebSocket.
25#[derive(Debug, Serialize)]
26pub struct EventUpdate {
27    pub sequence: u64,
28    pub timestamp: String,
29    pub event_type: String,
30    pub document_id: String,
31    pub company_code: String,
32    pub amount: String,
33    pub is_anomaly: bool,
34}
35
/// Marker type for metrics stream.
///
/// Carries no data; it is not referenced elsewhere in this module, so any
/// use of it lives in other files.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct MetricsStream;
38
39/// Handle WebSocket connection for metrics streaming.
40pub async fn handle_metrics_socket(socket: WebSocket, state: AppState) {
41    let (mut sender, mut receiver) = socket.split();
42
43    info!("Metrics WebSocket connected");
44
45    // Spawn a task to send metrics updates
46    let state_clone = state.clone();
47    let mut interval = tokio::time::interval(Duration::from_secs(1));
48
49    loop {
50        tokio::select! {
51            // Send metrics every second
52            _ = interval.tick() => {
53                let uptime = state_clone.server_state.uptime_seconds();
54                let total_entries = state_clone.server_state.total_entries.load(std::sync::atomic::Ordering::Relaxed);
55
56                let entries_per_second = if uptime > 0 {
57                    total_entries as f64 / uptime as f64
58                } else {
59                    0.0
60                };
61
62                let update = MetricsUpdate {
63                    timestamp: chrono::Utc::now().to_rfc3339(),
64                    total_entries,
65                    total_anomalies: state_clone.server_state.total_anomalies.load(std::sync::atomic::Ordering::Relaxed),
66                    entries_per_second,
67                    active_streams: state_clone.server_state.active_streams.load(std::sync::atomic::Ordering::Relaxed) as u32,
68                    uptime_seconds: uptime,
69                };
70
71                match serde_json::to_string(&update) {
72                    Ok(json) => {
73                        if sender.send(Message::Text(json.into())).await.is_err() {
74                            info!("Metrics WebSocket client disconnected");
75                            break;
76                        }
77                    }
78                    Err(e) => {
79                        error!("Failed to serialize metrics: {}", e);
80                    }
81                }
82            }
83            // Handle incoming messages (for ping/pong or close)
84            msg = receiver.next() => {
85                match msg {
86                    Some(Ok(Message::Close(_))) | None => {
87                        info!("Metrics WebSocket closed by client");
88                        break;
89                    }
90                    Some(Ok(Message::Ping(data))) => {
91                        if sender.send(Message::Pong(data)).await.is_err() {
92                            break;
93                        }
94                    }
95                    Some(Err(e)) => {
96                        warn!("Metrics WebSocket error: {}", e);
97                        break;
98                    }
99                    _ => {}
100                }
101            }
102        }
103    }
104}
105
106/// Handle WebSocket connection for event streaming.
107pub async fn handle_events_socket(socket: WebSocket, state: AppState) {
108    let (mut sender, mut receiver) = socket.split();
109
110    info!("Events WebSocket connected");
111
112    // Increment active streams
113    state
114        .server_state
115        .active_streams
116        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
117
118    let config = state.server_state.config.read().await.clone();
119
120    let phase_config = PhaseConfig {
121        generate_master_data: false,
122        generate_document_flows: false,
123        generate_journal_entries: true,
124        inject_anomalies: false,
125        show_progress: false,
126        ..Default::default()
127    };
128
129    let mut sequence = 0u64;
130    let delay = Duration::from_millis(100); // 10 events per second
131
132    loop {
133        // Check if we should stop
134        if state
135            .server_state
136            .stream_stopped
137            .load(std::sync::atomic::Ordering::Relaxed)
138        {
139            info!("Events stream stopped by control command");
140            break;
141        }
142
143        // Check if we should pause
144        while state
145            .server_state
146            .stream_paused
147            .load(std::sync::atomic::Ordering::Relaxed)
148        {
149            tokio::time::sleep(Duration::from_millis(100)).await;
150            if state
151                .server_state
152                .stream_stopped
153                .load(std::sync::atomic::Ordering::Relaxed)
154            {
155                break;
156            }
157        }
158
159        // Check for incoming messages
160        tokio::select! {
161            msg = receiver.next() => {
162                match msg {
163                    Some(Ok(Message::Close(_))) | None => {
164                        info!("Events WebSocket closed by client");
165                        break;
166                    }
167                    Some(Ok(Message::Ping(data))) => {
168                        if sender.send(Message::Pong(data)).await.is_err() {
169                            break;
170                        }
171                    }
172                    Some(Err(e)) => {
173                        warn!("Events WebSocket error: {}", e);
174                        break;
175                    }
176                    _ => {}
177                }
178            }
179            _ = tokio::time::sleep(delay) => {
180                // Generate and send an event
181                let mut orchestrator = match EnhancedOrchestrator::new(config.clone(), phase_config.clone()) {
182                    Ok(o) => o,
183                    Err(e) => {
184                        error!("Failed to create orchestrator: {}", e);
185                        break;
186                    }
187                };
188
189                let result = match orchestrator.generate() {
190                    Ok(r) => r,
191                    Err(e) => {
192                        error!("Generation failed: {}", e);
193                        break;
194                    }
195                };
196
197                // Send each entry
198                for entry in result.journal_entries.iter().take(1) {
199                    sequence += 1;
200                    state.server_state.total_stream_events.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
201                    state.server_state.total_entries.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
202
203                    let total_amount: rust_decimal::Decimal = entry.lines.iter()
204                        .map(|l| l.debit_amount)
205                        .sum();
206
207                    let event = EventUpdate {
208                        sequence,
209                        timestamp: chrono::Utc::now().to_rfc3339(),
210                        event_type: "JournalEntry".to_string(),
211                        document_id: entry.header.document_id.to_string(),
212                        company_code: entry.header.company_code.clone(),
213                        amount: total_amount.to_string(),
214                        is_anomaly: entry.header.is_fraud,
215                    };
216
217                    match serde_json::to_string(&event) {
218                        Ok(json) => {
219                            if sender.send(Message::Text(json.into())).await.is_err() {
220                                info!("Events WebSocket client disconnected");
221                                break;
222                            }
223                        }
224                        Err(e) => {
225                            error!("Failed to serialize event: {}", e);
226                        }
227                    }
228                }
229            }
230        }
231    }
232
233    // Decrement active streams
234    state
235        .server_state
236        .active_streams
237        .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
238}
239
#[cfg(test)]
#[allow(clippy::unwrap_used)] // unwrap is acceptable in tests: a failure should abort the test.
mod tests {
    use super::*;

    /// The wire format of `MetricsUpdate` must contain exactly these keys and
    /// values. Comparing parsed JSON avoids false positives from substring
    /// matches across fields.
    #[test]
    fn test_metrics_update_serialization() {
        let update = MetricsUpdate {
            timestamp: "2024-01-01T00:00:00Z".to_string(),
            total_entries: 1000,
            total_anomalies: 10,
            entries_per_second: 16.67,
            active_streams: 1,
            uptime_seconds: 60,
        };
        let value = serde_json::to_value(&update).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "timestamp": "2024-01-01T00:00:00Z",
                "total_entries": 1000,
                "total_anomalies": 10,
                "entries_per_second": 16.67,
                "active_streams": 1,
                "uptime_seconds": 60
            })
        );
    }

    /// The wire format of `EventUpdate` must match exactly. Among other
    /// things, this pins `amount` as a JSON string rather than a number.
    #[test]
    fn test_event_update_serialization() {
        let event = EventUpdate {
            sequence: 1,
            timestamp: "2024-01-01T00:00:00Z".to_string(),
            event_type: "JournalEntry".to_string(),
            document_id: "12345".to_string(),
            company_code: "1000".to_string(),
            amount: "1000.00".to_string(),
            is_anomaly: false,
        };
        let value = serde_json::to_value(&event).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "sequence": 1,
                "timestamp": "2024-01-01T00:00:00Z",
                "event_type": "JournalEntry",
                "document_id": "12345",
                "company_code": "1000",
                "amount": "1000.00",
                "is_anomaly": false
            })
        );
    }
}