mockforge_ws/
ai_event_generator.rs

1//! AI-powered WebSocket event generation
2//!
3//! This module integrates LLM-powered replay augmentation into WebSocket
4//! event streaming, allowing realistic event generation from narrative descriptions.
5
6use axum::extract::ws::{Message, WebSocket};
7use mockforge_data::{ReplayAugmentationConfig, ReplayAugmentationEngine};
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use tokio::time::{sleep, Duration};
11use tracing::{debug, error, info};
12
13/// AI event generator for WebSocket connections
14pub struct AiEventGenerator {
15    /// Replay augmentation engine
16    engine: Arc<RwLock<ReplayAugmentationEngine>>,
17}
18
19impl AiEventGenerator {
20    /// Create a new AI event generator
21    pub fn new(config: ReplayAugmentationConfig) -> mockforge_core::Result<Self> {
22        debug!("Creating AI event generator");
23        let engine = ReplayAugmentationEngine::new(config)
24            .map_err(|e| mockforge_core::Error::generic(e.to_string()))?;
25        Ok(Self {
26            engine: Arc::new(RwLock::new(engine)),
27        })
28    }
29
30    /// Stream AI-generated events to a WebSocket connection
31    ///
32    /// This method generates events using the configured AI engine and sends them
33    /// to the client via WebSocket.
34    pub async fn stream_events(&self, mut socket: WebSocket, max_events: Option<usize>) {
35        info!("Starting AI event stream (max_events: {:?})", max_events);
36
37        // Generate all events at once
38        let events = match self.engine.write().await.generate_stream().await {
39            Ok(events) => events,
40            Err(e) => {
41                error!("Failed to generate event stream: {}", e);
42                return;
43            }
44        };
45
46        info!("Generated {} events from AI engine", events.len());
47
48        let max = max_events.unwrap_or(events.len());
49        let events_to_send = events.into_iter().take(max);
50
51        for event in events_to_send {
52            // Convert event to JSON message
53            let message_json = serde_json::json!({
54                "type": event.event_type,
55                "timestamp": event.timestamp.to_rfc3339(),
56                "sequence": event.sequence,
57                "data": event.data
58            });
59
60            let message_str = match serde_json::to_string(&message_json) {
61                Ok(s) => s,
62                Err(e) => {
63                    error!("Failed to serialize event: {}", e);
64                    continue;
65                }
66            };
67
68            debug!("Sending AI-generated event: {}", message_str);
69
70            // Send event to client
71            if socket.send(Message::Text(message_str.into())).await.is_err() {
72                info!("Client disconnected, stopping event stream");
73                break;
74            }
75
76            // Small delay between events (configurable event rate would be better)
77            sleep(Duration::from_millis(100)).await;
78        }
79
80        info!("AI event stream completed");
81    }
82
83    /// Stream events with custom event rate
84    pub async fn stream_events_with_rate(
85        &self,
86        mut socket: WebSocket,
87        max_events: Option<usize>,
88        events_per_second: f64,
89    ) {
90        info!(
91            "Starting AI event stream (max_events: {:?}, rate: {} events/sec)",
92            max_events, events_per_second
93        );
94
95        // Generate all events at once
96        let events = match self.engine.write().await.generate_stream().await {
97            Ok(events) => events,
98            Err(e) => {
99                error!("Failed to generate event stream: {}", e);
100                return;
101            }
102        };
103
104        info!("Generated {} events from AI engine", events.len());
105
106        let delay_ms = (1000.0 / events_per_second) as u64;
107        let max = max_events.unwrap_or(events.len());
108        let events_to_send = events.into_iter().take(max);
109
110        for event in events_to_send {
111            // Convert event to JSON message
112            let message_json = serde_json::json!({
113                "type": event.event_type,
114                "timestamp": event.timestamp.to_rfc3339(),
115                "sequence": event.sequence,
116                "data": event.data
117            });
118
119            let message_str = match serde_json::to_string(&message_json) {
120                Ok(s) => s,
121                Err(e) => {
122                    error!("Failed to serialize event: {}", e);
123                    continue;
124                }
125            };
126
127            debug!("Sending AI-generated event: {}", message_str);
128
129            // Send event to client
130            if socket.send(Message::Text(message_str.into())).await.is_err() {
131                info!("Client disconnected, stopping event stream");
132                break;
133            }
134
135            // Delay based on configured rate
136            sleep(Duration::from_millis(delay_ms)).await;
137        }
138
139        info!("AI event stream completed");
140    }
141}
142
143/// Configuration for WebSocket AI event generation
144#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
145pub struct WebSocketAiConfig {
146    /// Whether AI event generation is enabled
147    pub enabled: bool,
148    /// Replay augmentation configuration
149    pub replay: Option<ReplayAugmentationConfig>,
150    /// Maximum number of events to generate
151    pub max_events: Option<usize>,
152    /// Events per second
153    pub event_rate: Option<f64>,
154}
155
156impl Default for WebSocketAiConfig {
157    fn default() -> Self {
158        Self {
159            enabled: false,
160            replay: None,
161            max_events: Some(100),
162            event_rate: Some(1.0),
163        }
164    }
165}
166
167impl WebSocketAiConfig {
168    /// Check if AI features are enabled
169    pub fn is_enabled(&self) -> bool {
170        self.enabled && self.replay.is_some()
171    }
172
173    /// Create an AI event generator from this configuration
174    pub fn create_generator(&self) -> mockforge_core::Result<Option<AiEventGenerator>> {
175        if let Some(replay_config) = &self.replay {
176            let generator = AiEventGenerator::new(replay_config.clone())?;
177            Ok(Some(generator))
178        } else {
179            Ok(None)
180        }
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use mockforge_data::{EventStrategy, ReplayMode};

    // ==================== WebSocketAiConfig Tests ====================

    #[test]
    fn test_websocket_ai_config_default() {
        let config = WebSocketAiConfig::default();
        assert!(!config.is_enabled());
        assert_eq!(config.max_events, Some(100));
        assert_eq!(config.event_rate, Some(1.0));
    }

    #[test]
    fn test_websocket_ai_config_default_enabled_false() {
        let config = WebSocketAiConfig::default();
        assert!(!config.enabled);
        assert!(config.replay.is_none());
    }

    #[test]
    fn test_websocket_ai_config_is_enabled() {
        let mut config = WebSocketAiConfig {
            enabled: true,
            ..Default::default()
        };

        // Still not enabled without replay config
        assert!(!config.is_enabled());

        // Now enabled with replay config
        config.replay = Some(ReplayAugmentationConfig {
            mode: ReplayMode::Generated,
            strategy: EventStrategy::CountBased,
            ..Default::default()
        });
        assert!(config.is_enabled());
    }

    #[test]
    fn test_websocket_ai_config_enabled_requires_both() {
        // Only enabled flag set
        let config1 = WebSocketAiConfig {
            enabled: true,
            replay: None,
            max_events: None,
            event_rate: None,
        };
        assert!(!config1.is_enabled());

        // Only replay set, but enabled is false
        let config2 = WebSocketAiConfig {
            enabled: false,
            replay: Some(ReplayAugmentationConfig::default()),
            max_events: None,
            event_rate: None,
        };
        assert!(!config2.is_enabled());

        // Both set
        let config3 = WebSocketAiConfig {
            enabled: true,
            replay: Some(ReplayAugmentationConfig::default()),
            max_events: None,
            event_rate: None,
        };
        assert!(config3.is_enabled());
    }

    #[test]
    fn test_websocket_ai_config_custom_values() {
        let config = WebSocketAiConfig {
            enabled: true,
            replay: Some(ReplayAugmentationConfig {
                mode: ReplayMode::Generated,
                strategy: EventStrategy::TimeBased,
                ..Default::default()
            }),
            max_events: Some(50),
            event_rate: Some(2.5),
        };

        assert!(config.is_enabled());
        assert_eq!(config.max_events, Some(50));
        assert_eq!(config.event_rate, Some(2.5));
    }

    #[test]
    fn test_websocket_ai_config_create_generator_none_when_no_replay() {
        let config = WebSocketAiConfig::default();
        let result = config.create_generator();
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn test_websocket_ai_config_create_generator_with_replay_set() {
        let config = WebSocketAiConfig {
            enabled: true,
            replay: Some(ReplayAugmentationConfig {
                mode: ReplayMode::Generated,
                strategy: EventStrategy::CountBased,
                ..Default::default()
            }),
            max_events: Some(10),
            event_rate: Some(1.0),
        };

        // Whether the engine accepts this config depends on ReplayAugmentationEngine,
        // but a successful result must always carry a generator when replay is set.
        if let Ok(generator) = config.create_generator() {
            assert!(generator.is_some(), "replay config present but no generator built");
        }
    }

    // ==================== ReplayMode Tests ====================

    #[test]
    fn test_replay_mode_generated() {
        let config = ReplayAugmentationConfig {
            mode: ReplayMode::Generated,
            strategy: EventStrategy::CountBased,
            ..Default::default()
        };
        assert!(matches!(config.mode, ReplayMode::Generated));
    }

    // ==================== EventStrategy Tests ====================

    #[test]
    fn test_event_strategy_count_based() {
        let config = ReplayAugmentationConfig {
            mode: ReplayMode::Generated,
            strategy: EventStrategy::CountBased,
            ..Default::default()
        };
        assert!(matches!(config.strategy, EventStrategy::CountBased));
    }

    #[test]
    fn test_event_strategy_time_based() {
        let config = ReplayAugmentationConfig {
            mode: ReplayMode::Generated,
            strategy: EventStrategy::TimeBased,
            ..Default::default()
        };
        assert!(matches!(config.strategy, EventStrategy::TimeBased));
    }

    // ==================== Serialization Tests ====================

    #[test]
    fn test_websocket_ai_config_serialize() {
        let config = WebSocketAiConfig {
            enabled: true,
            replay: None,
            max_events: Some(50),
            event_rate: Some(1.5),
        };

        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"enabled\":true"));
        assert!(json.contains("\"max_events\":50"));
        assert!(json.contains("\"event_rate\":1.5"));
    }

    #[test]
    fn test_websocket_ai_config_deserialize() {
        let json = r#"{
            "enabled": true,
            "replay": null,
            "max_events": 100,
            "event_rate": 2.0
        }"#;

        let config: WebSocketAiConfig = serde_json::from_str(json).unwrap();
        assert!(config.enabled);
        assert!(config.replay.is_none());
        assert_eq!(config.max_events, Some(100));
        assert_eq!(config.event_rate, Some(2.0));
    }

    #[test]
    fn test_websocket_ai_config_deserialize_omitted_optional_fields() {
        // Missing `Option` fields deserialize as `None` rather than taking the
        // values from `WebSocketAiConfig::default()`.
        let config: WebSocketAiConfig = serde_json::from_str(r#"{"enabled": true}"#).unwrap();
        assert!(config.enabled);
        assert!(config.replay.is_none());
        assert_eq!(config.max_events, None);
        assert_eq!(config.event_rate, None);
        assert!(!config.is_enabled());
    }

    #[test]
    fn test_websocket_ai_config_roundtrip() {
        let original = WebSocketAiConfig {
            enabled: true,
            replay: Some(ReplayAugmentationConfig::default()),
            max_events: Some(25),
            event_rate: Some(0.5),
        };

        let json = serde_json::to_string(&original).unwrap();
        let restored: WebSocketAiConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(original.enabled, restored.enabled);
        assert_eq!(original.max_events, restored.max_events);
        assert_eq!(original.event_rate, restored.event_rate);
        assert!(restored.replay.is_some());
    }

    // ==================== Clone and Debug Tests ====================

    #[test]
    fn test_websocket_ai_config_clone() {
        let config = WebSocketAiConfig {
            enabled: true,
            replay: Some(ReplayAugmentationConfig::default()),
            max_events: Some(50),
            event_rate: Some(1.0),
        };

        let cloned = config.clone();
        assert_eq!(config.enabled, cloned.enabled);
        assert_eq!(config.max_events, cloned.max_events);
        assert_eq!(config.event_rate, cloned.event_rate);
    }

    #[test]
    fn test_websocket_ai_config_debug() {
        let config = WebSocketAiConfig::default();
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("WebSocketAiConfig"));
        assert!(debug_str.contains("enabled"));
    }
}