1use axum::extract::ws::{Message, WebSocket};
7use mockforge_data::{ReplayAugmentationConfig, ReplayAugmentationEngine};
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use tokio::time::{sleep, Duration};
11use tracing::{debug, error, info};
12
13pub struct AiEventGenerator {
15 engine: Arc<RwLock<ReplayAugmentationEngine>>,
17}
18
19impl AiEventGenerator {
20 pub fn new(config: ReplayAugmentationConfig) -> mockforge_core::Result<Self> {
22 debug!("Creating AI event generator");
23 let engine = ReplayAugmentationEngine::new(config)
24 .map_err(|e| mockforge_core::Error::generic(e.to_string()))?;
25 Ok(Self {
26 engine: Arc::new(RwLock::new(engine)),
27 })
28 }
29
30 pub async fn stream_events(&self, mut socket: WebSocket, max_events: Option<usize>) {
35 info!("Starting AI event stream (max_events: {:?})", max_events);
36
37 let events = match self.engine.write().await.generate_stream().await {
39 Ok(events) => events,
40 Err(e) => {
41 error!("Failed to generate event stream: {}", e);
42 return;
43 }
44 };
45
46 info!("Generated {} events from AI engine", events.len());
47
48 let max = max_events.unwrap_or(events.len());
49 let events_to_send = events.into_iter().take(max);
50
51 for event in events_to_send {
52 let message_json = serde_json::json!({
54 "type": event.event_type,
55 "timestamp": event.timestamp.to_rfc3339(),
56 "sequence": event.sequence,
57 "data": event.data
58 });
59
60 let message_str = match serde_json::to_string(&message_json) {
61 Ok(s) => s,
62 Err(e) => {
63 error!("Failed to serialize event: {}", e);
64 continue;
65 }
66 };
67
68 debug!("Sending AI-generated event: {}", message_str);
69
70 if socket.send(Message::Text(message_str.into())).await.is_err() {
72 info!("Client disconnected, stopping event stream");
73 break;
74 }
75
76 sleep(Duration::from_millis(100)).await;
78 }
79
80 info!("AI event stream completed");
81 }
82
83 pub async fn stream_events_with_rate(
85 &self,
86 mut socket: WebSocket,
87 max_events: Option<usize>,
88 events_per_second: f64,
89 ) {
90 info!(
91 "Starting AI event stream (max_events: {:?}, rate: {} events/sec)",
92 max_events, events_per_second
93 );
94
95 let events = match self.engine.write().await.generate_stream().await {
97 Ok(events) => events,
98 Err(e) => {
99 error!("Failed to generate event stream: {}", e);
100 return;
101 }
102 };
103
104 info!("Generated {} events from AI engine", events.len());
105
106 let delay_ms = (1000.0 / events_per_second) as u64;
107 let max = max_events.unwrap_or(events.len());
108 let events_to_send = events.into_iter().take(max);
109
110 for event in events_to_send {
111 let message_json = serde_json::json!({
113 "type": event.event_type,
114 "timestamp": event.timestamp.to_rfc3339(),
115 "sequence": event.sequence,
116 "data": event.data
117 });
118
119 let message_str = match serde_json::to_string(&message_json) {
120 Ok(s) => s,
121 Err(e) => {
122 error!("Failed to serialize event: {}", e);
123 continue;
124 }
125 };
126
127 debug!("Sending AI-generated event: {}", message_str);
128
129 if socket.send(Message::Text(message_str.into())).await.is_err() {
131 info!("Client disconnected, stopping event stream");
132 break;
133 }
134
135 sleep(Duration::from_millis(delay_ms)).await;
137 }
138
139 info!("AI event stream completed");
140 }
141}
142
143#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
145pub struct WebSocketAiConfig {
146 pub enabled: bool,
148 pub replay: Option<ReplayAugmentationConfig>,
150 pub max_events: Option<usize>,
152 pub event_rate: Option<f64>,
154}
155
156impl Default for WebSocketAiConfig {
157 fn default() -> Self {
158 Self {
159 enabled: false,
160 replay: None,
161 max_events: Some(100),
162 event_rate: Some(1.0),
163 }
164 }
165}
166
167impl WebSocketAiConfig {
168 pub fn is_enabled(&self) -> bool {
170 self.enabled && self.replay.is_some()
171 }
172
173 pub fn create_generator(&self) -> mockforge_core::Result<Option<AiEventGenerator>> {
175 if let Some(replay_config) = &self.replay {
176 let generator = AiEventGenerator::new(replay_config.clone())?;
177 Ok(Some(generator))
178 } else {
179 Ok(None)
180 }
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use mockforge_data::{EventStrategy, ReplayMode};

    /// The default config is disabled and carries the documented caps.
    #[test]
    fn test_websocket_ai_config_default() {
        let config = WebSocketAiConfig::default();
        assert!(!config.is_enabled());
        assert_eq!(config.max_events, Some(100));
        assert_eq!(config.event_rate, Some(1.0));
    }

    #[test]
    fn test_websocket_ai_config_default_enabled_false() {
        let config = WebSocketAiConfig::default();
        assert!(!config.enabled);
        assert!(config.replay.is_none());
    }

    #[test]
    fn test_websocket_ai_config_is_enabled() {
        let mut config = WebSocketAiConfig {
            enabled: true,
            ..Default::default()
        };

        // The flag alone is not enough without a replay configuration.
        assert!(!config.is_enabled());

        config.replay = Some(ReplayAugmentationConfig {
            mode: ReplayMode::Generated,
            strategy: EventStrategy::CountBased,
            ..Default::default()
        });
        assert!(config.is_enabled());
    }

    #[test]
    fn test_websocket_ai_config_enabled_requires_both() {
        let config1 = WebSocketAiConfig {
            enabled: true,
            replay: None,
            max_events: None,
            event_rate: None,
        };
        assert!(!config1.is_enabled());

        let config2 = WebSocketAiConfig {
            enabled: false,
            replay: Some(ReplayAugmentationConfig::default()),
            max_events: None,
            event_rate: None,
        };
        assert!(!config2.is_enabled());

        let config3 = WebSocketAiConfig {
            enabled: true,
            replay: Some(ReplayAugmentationConfig::default()),
            max_events: None,
            event_rate: None,
        };
        assert!(config3.is_enabled());
    }

    #[test]
    fn test_websocket_ai_config_custom_values() {
        let config = WebSocketAiConfig {
            enabled: true,
            replay: Some(ReplayAugmentationConfig {
                mode: ReplayMode::Generated,
                strategy: EventStrategy::TimeBased,
                ..Default::default()
            }),
            max_events: Some(50),
            event_rate: Some(2.5),
        };

        assert!(config.is_enabled());
        assert_eq!(config.max_events, Some(50));
        assert_eq!(config.event_rate, Some(2.5));
    }

    #[test]
    fn test_websocket_ai_config_create_generator_none_when_no_replay() {
        let config = WebSocketAiConfig::default();
        let result = config.create_generator();
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn test_websocket_ai_config_create_generator_with_replay_set() {
        let config = WebSocketAiConfig {
            enabled: true,
            replay: Some(ReplayAugmentationConfig {
                mode: ReplayMode::Generated,
                strategy: EventStrategy::CountBased,
                ..Default::default()
            }),
            max_events: Some(10),
            event_rate: Some(1.0),
        };

        // Whether engine construction succeeds depends on `ReplayAugmentationEngine::new`,
        // which is outside this module. So only the success path is pinned here. The old
        // version of this test asserted nothing at all.
        if let Ok(generator) = config.create_generator() {
            assert!(
                generator.is_some(),
                "a present replay config must yield a generator on success"
            );
        }
    }

    #[test]
    fn test_replay_mode_generated() {
        let config = ReplayAugmentationConfig {
            mode: ReplayMode::Generated,
            strategy: EventStrategy::CountBased,
            ..Default::default()
        };
        assert!(matches!(config.mode, ReplayMode::Generated));
    }

    #[test]
    fn test_event_strategy_count_based() {
        let config = ReplayAugmentationConfig {
            mode: ReplayMode::Generated,
            strategy: EventStrategy::CountBased,
            ..Default::default()
        };
        assert!(matches!(config.strategy, EventStrategy::CountBased));
    }

    #[test]
    fn test_event_strategy_time_based() {
        let config = ReplayAugmentationConfig {
            mode: ReplayMode::Generated,
            strategy: EventStrategy::TimeBased,
            ..Default::default()
        };
        assert!(matches!(config.strategy, EventStrategy::TimeBased));
    }

    #[test]
    fn test_websocket_ai_config_serialize() {
        let config = WebSocketAiConfig {
            enabled: true,
            replay: None,
            max_events: Some(50),
            event_rate: Some(1.5),
        };

        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"enabled\":true"));
        assert!(json.contains("\"max_events\":50"));
        assert!(json.contains("\"event_rate\":1.5"));
    }

    #[test]
    fn test_websocket_ai_config_deserialize() {
        let json = r#"{
            "enabled": true,
            "replay": null,
            "max_events": 100,
            "event_rate": 2.0
        }"#;

        let config: WebSocketAiConfig = serde_json::from_str(json).unwrap();
        assert!(config.enabled);
        assert!(config.replay.is_none());
        assert_eq!(config.max_events, Some(100));
        assert_eq!(config.event_rate, Some(2.0));
    }

    #[test]
    fn test_websocket_ai_config_roundtrip() {
        let original = WebSocketAiConfig {
            enabled: true,
            replay: Some(ReplayAugmentationConfig::default()),
            max_events: Some(25),
            event_rate: Some(0.5),
        };

        let json = serde_json::to_string(&original).unwrap();
        let restored: WebSocketAiConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(original.enabled, restored.enabled);
        assert_eq!(original.max_events, restored.max_events);
        assert_eq!(original.event_rate, restored.event_rate);
        assert!(restored.replay.is_some());
    }

    #[test]
    fn test_websocket_ai_config_clone() {
        let config = WebSocketAiConfig {
            enabled: true,
            replay: Some(ReplayAugmentationConfig::default()),
            max_events: Some(50),
            event_rate: Some(1.0),
        };

        let cloned = config.clone();
        assert_eq!(config.enabled, cloned.enabled);
        assert_eq!(config.max_events, cloned.max_events);
        assert_eq!(config.event_rate, cloned.event_rate);
    }

    #[test]
    fn test_websocket_ai_config_debug() {
        let config = WebSocketAiConfig::default();
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("WebSocketAiConfig"));
        assert!(debug_str.contains("enabled"));
    }
}
411}