Skip to main content

mockforge_data/
replay_augmentation.rs

1//! LLM-powered replay augmentation for WebSocket and GraphQL subscriptions
2//!
3//! This module enables AI-driven event stream generation for real-time protocols,
4//! allowing users to define high-level scenarios that generate realistic event sequences.
5
6use crate::rag::{RagConfig, RagEngine};
7use crate::{Error, Result};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::time::Duration;
11use tokio::time::interval;
12
/// Replay augmentation mode.
///
/// Serialized in `snake_case`. The engine only constructs a RAG engine, and
/// only requires a narrative, for the non-`Static` modes.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ReplayMode {
    /// Static replay from pre-recorded events: the event data is the configured
    /// schema value (or a minimal type/timestamp object) with no LLM call.
    /// This is the default mode.
    #[default]
    Static,
    /// LLM-augmented replay with scenario-based generation: the static event
    /// is used as a base and LLM output is merged on top of it.
    Augmented,
    /// Fully generated event stream from narrative description: the event data
    /// is taken entirely from the parsed LLM response.
    Generated,
}
25
/// Event generation strategy.
///
/// Selects which generation loop `ReplayAugmentationEngine::generate_stream`
/// runs. Serialized in `snake_case`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum EventStrategy {
    /// Time-based event generation: events are produced on a ticker derived
    /// from `event_rate` until `duration_secs` has elapsed.
    TimeBased,
    /// Count-based event generation: `event_count` events are produced.
    CountBased,
    /// Condition-based event generation: the configured `conditions` are
    /// evaluated repeatedly and their actions applied.
    ConditionalBased,
}
37
/// Replay augmentation configuration.
///
/// Which optional fields are required depends on `mode` and `strategy`;
/// `ReplayAugmentationEngine::new` validates the combination before an engine
/// is built.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplayAugmentationConfig {
    /// Replay mode
    pub mode: ReplayMode,
    /// Narrative description of the scenario (required for every mode except `Static`)
    pub narrative: Option<String>,
    /// Event type/name, copied onto every generated event
    pub event_type: String,
    /// Event schema (optional JSON schema). Static replay returns a clone of
    /// this value as the event data; LLM prompts ask the model to conform to it.
    pub event_schema: Option<Value>,
    /// Event generation strategy
    pub strategy: EventStrategy,
    /// Duration to replay, in seconds (required for time-based)
    pub duration_secs: Option<u64>,
    /// Number of events to generate (required for count-based)
    pub event_count: Option<usize>,
    /// Event rate (events per second)
    pub event_rate: Option<f64>,
    /// Conditions for event generation (must be non-empty for conditional-based)
    pub conditions: Vec<EventCondition>,
    /// RAG configuration for LLM; the `RagConfig` default is used when absent
    pub rag_config: Option<RagConfig>,
    /// Enable progressive scenario evolution (adds an "evolve" instruction to
    /// the LLM prompts)
    pub progressive_evolution: bool,
}
64
65impl Default for ReplayAugmentationConfig {
66    fn default() -> Self {
67        Self {
68            mode: ReplayMode::Static,
69            narrative: None,
70            event_type: "event".to_string(),
71            event_schema: None,
72            strategy: EventStrategy::CountBased,
73            duration_secs: None,
74            event_count: Some(10),
75            event_rate: Some(1.0),
76            conditions: Vec::new(),
77            rag_config: None,
78            progressive_evolution: true,
79        }
80    }
81}
82
/// Event generation condition.
///
/// Pairs an expression with the action the conditional strategy applies
/// whenever that expression evaluates to true.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EventCondition {
    /// Condition name/description
    pub name: String,
    /// Condition expression (simplified): a literal `true`/`false`, or
    /// `<variable> <op> <n>` where the variable is `count`, `sequence` or
    /// `events_generated` and `<op>` is one of `<`, `>`, `<=`, `>=`, `==`, `!=`
    pub expression: String,
    /// Action to take when condition is met
    pub action: ConditionAction,
}
93
/// Condition action.
///
/// What the conditional strategy does when the owning condition's expression
/// evaluates to true. Serialized in `snake_case`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ConditionAction {
    /// Generate a new event
    GenerateEvent,
    /// Stop event generation
    Stop,
    /// Change event rate; the payload is the new rate in events per second
    ChangeRate(u64), // events per second
    /// Transition to new scenario; the payload names the scenario
    TransitionScenario(String),
}
107
/// Generated event.
///
/// One element of the stream returned by the engine; serializable so it can
/// be rendered as JSON via `to_json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratedEvent {
    /// Event type
    pub event_type: String,
    /// Event timestamp (UTC); `GeneratedEvent::new` sets it to the creation time
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Event data
    pub data: Value,
    /// Sequence number
    pub sequence: usize,
    /// Event metadata as free-form string key/value pairs
    pub metadata: std::collections::HashMap<String, String>,
}
122
123impl GeneratedEvent {
124    /// Create a new generated event
125    pub fn new(event_type: String, data: Value, sequence: usize) -> Self {
126        Self {
127            event_type,
128            timestamp: chrono::Utc::now(),
129            data,
130            sequence,
131            metadata: std::collections::HashMap::new(),
132        }
133    }
134
135    /// Add metadata
136    pub fn with_metadata(mut self, key: String, value: String) -> Self {
137        self.metadata.insert(key, value);
138        self
139    }
140
141    /// Convert to JSON
142    pub fn to_json(&self) -> Result<String> {
143        serde_json::to_string(self)
144            .map_err(|e| Error::generic(format!("Failed to serialize event: {}", e)))
145    }
146}
147
/// Replay augmentation engine.
///
/// Owns a validated configuration plus the mutable bookkeeping needed to
/// produce an event stream from it.
pub struct ReplayAugmentationEngine {
    /// Configuration
    config: ReplayAugmentationConfig,
    /// RAG engine for LLM-based generation; `None` when the mode is `Static`
    rag_engine: Option<RagEngine>,
    /// Event sequence counter; incremented once per generated event and used
    /// as that event's sequence number
    sequence: usize,
    /// Current scenario state
    scenario_state: ScenarioState,
}
159
/// Scenario state tracking.
///
/// Mutable bookkeeping the engine carries between events; replaced with a
/// fresh default when the engine is reset.
#[derive(Debug, Clone)]
struct ScenarioState {
    /// Current timestamp in scenario; set when the state is created.
    /// NOTE(review): not read anywhere in the visible code (hence the `_` prefix).
    _current_time: std::time::Instant,
    /// Events generated so far
    events_generated: usize,
    /// Last event data (for progressive evolution); included in LLM prompts
    /// as the "Previous event"
    last_event: Option<Value>,
    /// Scenario context.
    /// NOTE(review): only ever cleared in the visible code, never appended to.
    context: Vec<String>,
}
172
173impl Default for ScenarioState {
174    fn default() -> Self {
175        Self {
176            _current_time: std::time::Instant::now(),
177            events_generated: 0,
178            last_event: None,
179            context: Vec::new(),
180        }
181    }
182}
183
184impl ReplayAugmentationEngine {
185    /// Create a new replay augmentation engine
186    pub fn new(config: ReplayAugmentationConfig) -> Result<Self> {
187        Self::validate_config(&config)?;
188
189        let rag_engine = if config.mode != ReplayMode::Static {
190            let rag_config = config.rag_config.clone().unwrap_or_default();
191            Some(RagEngine::new(rag_config))
192        } else {
193            None
194        };
195
196        Ok(Self {
197            config,
198            rag_engine,
199            sequence: 0,
200            scenario_state: ScenarioState::default(),
201        })
202    }
203
204    /// Validate configuration
205    fn validate_config(config: &ReplayAugmentationConfig) -> Result<()> {
206        if config.mode != ReplayMode::Static && config.narrative.is_none() {
207            return Err(Error::generic(
208                "Narrative is required for augmented or generated replay modes",
209            ));
210        }
211
212        match config.strategy {
213            EventStrategy::TimeBased => {
214                if config.duration_secs.is_none() {
215                    return Err(Error::generic(
216                        "Duration must be specified for time-based strategy",
217                    ));
218                }
219            }
220            EventStrategy::CountBased => {
221                if config.event_count.is_none() {
222                    return Err(Error::generic(
223                        "Event count must be specified for count-based strategy",
224                    ));
225                }
226            }
227            EventStrategy::ConditionalBased => {
228                if config.conditions.is_empty() {
229                    return Err(Error::generic(
230                        "Conditions must be specified for conditional-based strategy",
231                    ));
232                }
233            }
234        }
235
236        Ok(())
237    }
238
239    /// Generate event stream based on configuration
240    pub async fn generate_stream(&mut self) -> Result<Vec<GeneratedEvent>> {
241        match self.config.strategy {
242            EventStrategy::CountBased => self.generate_count_based().await,
243            EventStrategy::TimeBased => self.generate_time_based().await,
244            EventStrategy::ConditionalBased => self.generate_conditional_based().await,
245        }
246    }
247
248    /// Generate events based on count
249    async fn generate_count_based(&mut self) -> Result<Vec<GeneratedEvent>> {
250        let count = self.config.event_count.unwrap_or(10);
251        let mut events = Vec::with_capacity(count);
252
253        for i in 0..count {
254            let event = self.generate_single_event(i).await?;
255            events.push(event);
256
257            // Add delay between events if rate is specified
258            if let Some(rate) = self.config.event_rate {
259                if rate > 0.0 {
260                    let delay_ms = (1000.0 / rate) as u64;
261                    tokio::time::sleep(Duration::from_millis(delay_ms)).await;
262                }
263            }
264        }
265
266        Ok(events)
267    }
268
269    /// Generate events based on time duration
270    async fn generate_time_based(&mut self) -> Result<Vec<GeneratedEvent>> {
271        let duration = Duration::from_secs(self.config.duration_secs.unwrap_or(60));
272        let rate = self.config.event_rate.unwrap_or(1.0);
273        let interval_ms = (1000.0 / rate) as u64;
274
275        let mut events = Vec::new();
276        let mut ticker = interval(Duration::from_millis(interval_ms));
277        let start = std::time::Instant::now();
278
279        let mut index = 0;
280        while start.elapsed() < duration {
281            ticker.tick().await;
282            let event = self.generate_single_event(index).await?;
283            events.push(event);
284            index += 1;
285        }
286
287        Ok(events)
288    }
289
290    /// Generate events based on conditions
291    async fn generate_conditional_based(&mut self) -> Result<Vec<GeneratedEvent>> {
292        let mut events = Vec::new();
293        let mut index = 0;
294        let max_events = 1000; // Safety limit
295
296        while index < max_events {
297            // Check conditions
298            let mut should_continue = true;
299            let conditions = self.config.conditions.clone(); // Clone to avoid borrow issues
300
301            for condition in &conditions {
302                if self.evaluate_condition(condition, &events) {
303                    match &condition.action {
304                        ConditionAction::GenerateEvent => {
305                            let event = self.generate_single_event(index).await?;
306                            events.push(event);
307                            index += 1;
308                        }
309                        ConditionAction::Stop => {
310                            should_continue = false;
311                            break;
312                        }
313                        ConditionAction::ChangeRate(_rate) => {
314                            // Update rate (would require mutable config)
315                        }
316                        ConditionAction::TransitionScenario(_scenario) => {
317                            // Transition to new scenario
318                            self.scenario_state.context.clear();
319                        }
320                    }
321                }
322            }
323
324            if !should_continue {
325                break;
326            }
327
328            // Prevent infinite loop
329            if events.is_empty() && index > 10 {
330                break;
331            }
332
333            tokio::time::sleep(Duration::from_millis(100)).await;
334        }
335
336        Ok(events)
337    }
338
339    /// Generate a single event
340    async fn generate_single_event(&mut self, index: usize) -> Result<GeneratedEvent> {
341        let data = match self.config.mode {
342            ReplayMode::Static => self.generate_static_event(),
343            ReplayMode::Augmented => self.generate_augmented_event(index).await?,
344            ReplayMode::Generated => self.generate_llm_event(index).await?,
345        };
346
347        self.sequence += 1;
348        self.scenario_state.events_generated += 1;
349        self.scenario_state.last_event = Some(data.clone());
350
351        Ok(GeneratedEvent::new(self.config.event_type.clone(), data, self.sequence))
352    }
353
354    /// Generate static event (fallback)
355    fn generate_static_event(&self) -> Value {
356        if let Some(schema) = &self.config.event_schema {
357            schema.clone()
358        } else {
359            serde_json::json!({
360                "type": self.config.event_type,
361                "timestamp": chrono::Utc::now().to_rfc3339()
362            })
363        }
364    }
365
366    /// Generate augmented event (base + LLM enhancement)
367    async fn generate_augmented_event(&mut self, index: usize) -> Result<Value> {
368        let mut base_event = self.generate_static_event();
369
370        if let Some(rag_engine) = &self.rag_engine {
371            let narrative =
372                self.config.narrative.as_ref().ok_or_else(|| {
373                    Error::config("narrative is required for augmented replay mode")
374                })?;
375            let prompt = self.build_augmentation_prompt(narrative, index)?;
376
377            let enhancement = rag_engine.generate_text(&prompt).await?;
378            let enhancement_json = self.parse_json_response(&enhancement)?;
379
380            // Merge enhancement with base event
381            if let (Some(base_obj), Some(enhancement_obj)) =
382                (base_event.as_object_mut(), enhancement_json.as_object())
383            {
384                for (key, value) in enhancement_obj {
385                    base_obj.insert(key.clone(), value.clone());
386                }
387            } else {
388                base_event = enhancement_json;
389            }
390        }
391
392        Ok(base_event)
393    }
394
395    /// Generate fully LLM-generated event
396    async fn generate_llm_event(&mut self, index: usize) -> Result<Value> {
397        let rag_engine = self
398            .rag_engine
399            .as_ref()
400            .ok_or_else(|| Error::generic("RAG engine not initialized for generated mode"))?;
401
402        let narrative = self
403            .config
404            .narrative
405            .as_ref()
406            .ok_or_else(|| Error::config("narrative is required for generated replay mode"))?;
407        let prompt = self.build_generation_prompt(narrative, index)?;
408
409        let response = rag_engine.generate_text(&prompt).await?;
410        self.parse_json_response(&response)
411    }
412
413    /// Build augmentation prompt
414    fn build_augmentation_prompt(&self, narrative: &str, index: usize) -> Result<String> {
415        let mut prompt = format!(
416            "Enhance this event data based on the following scenario:\n\n{}\n\n",
417            narrative
418        );
419
420        prompt.push_str(&format!("Event #{} (out of ongoing stream)\n\n", index + 1));
421
422        if let Some(last_event) = &self.scenario_state.last_event {
423            prompt.push_str(&format!(
424                "Previous event:\n{}\n\n",
425                serde_json::to_string_pretty(last_event).unwrap_or_default()
426            ));
427        }
428
429        if self.config.progressive_evolution {
430            prompt.push_str("Progressively evolve the scenario with each event.\n");
431        }
432
433        if let Some(schema) = &self.config.event_schema {
434            prompt.push_str(&format!(
435                "Conform to this schema:\n{}\n\n",
436                serde_json::to_string_pretty(schema).unwrap_or_default()
437            ));
438        }
439
440        prompt.push_str("Return valid JSON only for the enhanced event data.");
441
442        Ok(prompt)
443    }
444
445    /// Build generation prompt
446    fn build_generation_prompt(&self, narrative: &str, index: usize) -> Result<String> {
447        let mut prompt =
448            format!("Generate realistic event data for this scenario:\n\n{}\n\n", narrative);
449
450        prompt.push_str(&format!("Event type: {}\n", self.config.event_type));
451        prompt.push_str(&format!("Event #{}\n\n", index + 1));
452
453        if let Some(last_event) = &self.scenario_state.last_event {
454            prompt.push_str(&format!(
455                "Previous event:\n{}\n\n",
456                serde_json::to_string_pretty(last_event).unwrap_or_default()
457            ));
458
459            if self.config.progressive_evolution {
460                prompt.push_str("Naturally evolve from the previous event.\n");
461            }
462        }
463
464        if let Some(schema) = &self.config.event_schema {
465            prompt.push_str(&format!(
466                "Conform to this schema:\n{}\n\n",
467                serde_json::to_string_pretty(schema).unwrap_or_default()
468            ));
469        }
470
471        prompt.push_str("Return valid JSON only.");
472
473        Ok(prompt)
474    }
475
476    /// Parse JSON response from LLM
477    fn parse_json_response(&self, response: &str) -> Result<Value> {
478        let trimmed = response.trim();
479
480        // Try to extract from markdown code blocks
481        let json_str = if trimmed.starts_with("```json") {
482            trimmed
483                .strip_prefix("```json")
484                .and_then(|s| s.strip_suffix("```"))
485                .unwrap_or(trimmed)
486                .trim()
487        } else if trimmed.starts_with("```") {
488            trimmed
489                .strip_prefix("```")
490                .and_then(|s| s.strip_suffix("```"))
491                .unwrap_or(trimmed)
492                .trim()
493        } else {
494            trimmed
495        };
496
497        // Parse JSON
498        serde_json::from_str(json_str)
499            .map_err(|e| Error::generic(format!("Failed to parse LLM response as JSON: {}", e)))
500    }
501
502    /// Evaluate a condition expression against current event state.
503    ///
504    /// Supported expressions:
505    /// - `count <op> <n>` — compare event count (e.g., "count < 100", "count >= 50")
506    /// - `sequence <op> <n>` — compare current sequence number
507    /// - `events_generated <op> <n>` — compare total events generated in scenario
508    /// - `true` / `false` — literal boolean
509    ///
510    /// where `<op>` is one of: `<`, `>`, `<=`, `>=`, `==`, `!=`
511    fn evaluate_condition(&self, condition: &EventCondition, events: &[GeneratedEvent]) -> bool {
512        let expr = condition.expression.trim();
513
514        // Literal booleans
515        if expr.eq_ignore_ascii_case("true") {
516            return true;
517        }
518        if expr.eq_ignore_ascii_case("false") {
519            return false;
520        }
521
522        // Parse "<variable> <op> <value>" expressions
523        let parts: Vec<&str> = expr.splitn(3, ' ').collect();
524        if parts.len() != 3 {
525            tracing::warn!(
526                expression = expr,
527                "Unrecognized condition expression, defaulting to true"
528            );
529            return true;
530        }
531
532        let variable = parts[0];
533        let operator = parts[1];
534        let threshold: i64 = match parts[2].parse() {
535            Ok(v) => v,
536            Err(_) => {
537                tracing::warn!(
538                    expression = expr,
539                    "Could not parse threshold as integer, defaulting to true"
540                );
541                return true;
542            }
543        };
544
545        let actual: i64 = match variable {
546            "count" => events.len() as i64,
547            "sequence" => self.sequence as i64,
548            "events_generated" => self.scenario_state.events_generated as i64,
549            _ => {
550                tracing::warn!(variable, "Unknown condition variable, defaulting to true");
551                return true;
552            }
553        };
554
555        match operator {
556            "<" => actual < threshold,
557            ">" => actual > threshold,
558            "<=" => actual <= threshold,
559            ">=" => actual >= threshold,
560            "==" => actual == threshold,
561            "!=" => actual != threshold,
562            _ => {
563                tracing::warn!(operator, "Unknown comparison operator, defaulting to true");
564                true
565            }
566        }
567    }
568
569    /// Reset the engine state
570    pub fn reset(&mut self) {
571        self.sequence = 0;
572        self.scenario_state = ScenarioState::default();
573    }
574
    /// Get current sequence number.
    ///
    /// This is the sequence assigned to the most recently generated event, or
    /// 0 when nothing has been generated since construction or the last reset.
    pub fn sequence(&self) -> usize {
        self.sequence
    }
579
    /// Get events generated count.
    ///
    /// Reads the counter kept in the scenario state, which is incremented once
    /// per generated event and starts over when the engine is reset.
    pub fn events_generated(&self) -> usize {
        self.scenario_state.events_generated
    }
584}
585
586/// Pre-defined scenario templates
587pub mod scenarios {
588    use super::*;
589
590    /// Stock market simulation scenario
591    pub fn stock_market_scenario() -> ReplayAugmentationConfig {
592        ReplayAugmentationConfig {
593            mode: ReplayMode::Generated,
594            narrative: Some(
595                "Simulate 10 minutes of live market data with realistic price movements, \
596                 volume changes, and occasional volatility spikes."
597                    .to_string(),
598            ),
599            event_type: "market_tick".to_string(),
600            event_schema: Some(serde_json::json!({
601                "symbol": "string",
602                "price": "number",
603                "volume": "number",
604                "timestamp": "string"
605            })),
606            strategy: EventStrategy::TimeBased,
607            duration_secs: Some(600), // 10 minutes
608            event_rate: Some(2.0),    // 2 events per second
609            ..Default::default()
610        }
611    }
612
613    /// Chat application scenario
614    pub fn chat_messages_scenario() -> ReplayAugmentationConfig {
615        ReplayAugmentationConfig {
616            mode: ReplayMode::Generated,
617            narrative: Some(
618                "Simulate a group chat conversation between 3-5 users discussing a project, \
619                 with natural message pacing and realistic content."
620                    .to_string(),
621            ),
622            event_type: "chat_message".to_string(),
623            event_schema: Some(serde_json::json!({
624                "user_id": "string",
625                "message": "string",
626                "timestamp": "string"
627            })),
628            strategy: EventStrategy::CountBased,
629            event_count: Some(50),
630            event_rate: Some(0.5), // One message every 2 seconds
631            ..Default::default()
632        }
633    }
634
635    /// IoT sensor data scenario
636    pub fn iot_sensor_scenario() -> ReplayAugmentationConfig {
637        ReplayAugmentationConfig {
638            mode: ReplayMode::Generated,
639            narrative: Some(
640                "Simulate IoT sensor readings from a smart building with temperature, \
641                 humidity, and occupancy data showing daily patterns."
642                    .to_string(),
643            ),
644            event_type: "sensor_reading".to_string(),
645            event_schema: Some(serde_json::json!({
646                "sensor_id": "string",
647                "temperature": "number",
648                "humidity": "number",
649                "occupancy": "number",
650                "timestamp": "string"
651            })),
652            strategy: EventStrategy::CountBased,
653            event_count: Some(100),
654            event_rate: Some(1.0),
655            progressive_evolution: true,
656            ..Default::default()
657        }
658    }
659}
660
#[cfg(test)]
mod tests {
    use super::*;

    /// The default replay mode is static replay.
    #[test]
    fn test_replay_mode_default() {
        let mode = ReplayMode::default();
        assert!(mode == ReplayMode::Static);
    }

    /// Each strategy variant matches its own pattern.
    #[test]
    fn test_event_strategy_variants() {
        let variants = [
            EventStrategy::TimeBased,
            EventStrategy::CountBased,
            EventStrategy::ConditionalBased,
        ];

        assert!(matches!(variants[0], EventStrategy::TimeBased));
        assert!(matches!(variants[1], EventStrategy::CountBased));
        assert!(matches!(variants[2], EventStrategy::ConditionalBased));
    }

    /// The constructor stores the type and sequence it was given.
    #[test]
    fn test_generated_event_creation() {
        let payload = serde_json::json!({"test": "value"});
        let event = GeneratedEvent::new(String::from("test_event"), payload, 1);

        assert_eq!(event.sequence, 1);
        assert_eq!(event.event_type, "test_event");
    }

    /// Generated mode without a narrative must be rejected.
    #[test]
    fn test_replay_config_validation_missing_narrative() {
        let config = ReplayAugmentationConfig {
            mode: ReplayMode::Generated,
            ..Default::default()
        };

        let outcome = ReplayAugmentationEngine::validate_config(&config);
        assert!(outcome.is_err());
    }

    /// Spot-check one distinguishing property of each scenario template.
    #[test]
    fn test_scenario_templates() {
        let stock = scenarios::stock_market_scenario();
        assert_eq!(stock.mode, ReplayMode::Generated);
        assert!(stock.narrative.is_some());

        let chat = scenarios::chat_messages_scenario();
        assert_eq!(chat.event_type, "chat_message");

        let iot = scenarios::iot_sensor_scenario();
        assert!(iot.progressive_evolution);
    }
}