// mockforge_data/replay_augmentation.rs

//! LLM-powered replay augmentation for WebSocket and GraphQL subscriptions
//!
//! This module enables AI-driven event stream generation for real-time protocols,
//! allowing users to define high-level scenarios that generate realistic event sequences.
5
6use crate::rag::{RagConfig, RagEngine};
7use crate::{Error, Result};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::time::Duration;
11use tokio::time::interval;
12
13/// Replay augmentation mode
14#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
15#[serde(rename_all = "snake_case")]
16pub enum ReplayMode {
17    /// Static replay from pre-recorded events
18    #[default]
19    Static,
20    /// LLM-augmented replay with scenario-based generation
21    Augmented,
22    /// Fully generated event stream from narrative description
23    Generated,
24}
25
26/// Event generation strategy
27#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
28#[serde(rename_all = "snake_case")]
29pub enum EventStrategy {
30    /// Time-based event generation
31    TimeBased,
32    /// Count-based event generation
33    CountBased,
34    /// Condition-based event generation
35    ConditionalBased,
36}
37
38/// Replay augmentation configuration
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ReplayAugmentationConfig {
41    /// Replay mode
42    pub mode: ReplayMode,
43    /// Narrative description of the scenario
44    pub narrative: Option<String>,
45    /// Event type/name
46    pub event_type: String,
47    /// Event schema (optional JSON schema)
48    pub event_schema: Option<Value>,
49    /// Event generation strategy
50    pub strategy: EventStrategy,
51    /// Duration to replay (for time-based)
52    pub duration_secs: Option<u64>,
53    /// Number of events to generate (for count-based)
54    pub event_count: Option<usize>,
55    /// Event rate (events per second)
56    pub event_rate: Option<f64>,
57    /// Conditions for event generation
58    pub conditions: Vec<EventCondition>,
59    /// RAG configuration for LLM
60    pub rag_config: Option<RagConfig>,
61    /// Enable progressive scenario evolution
62    pub progressive_evolution: bool,
63}
64
65impl Default for ReplayAugmentationConfig {
66    fn default() -> Self {
67        Self {
68            mode: ReplayMode::Static,
69            narrative: None,
70            event_type: "event".to_string(),
71            event_schema: None,
72            strategy: EventStrategy::CountBased,
73            duration_secs: None,
74            event_count: Some(10),
75            event_rate: Some(1.0),
76            conditions: Vec::new(),
77            rag_config: None,
78            progressive_evolution: true,
79        }
80    }
81}
82
83/// Event generation condition
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct EventCondition {
86    /// Condition name/description
87    pub name: String,
88    /// Condition expression (simplified)
89    pub expression: String,
90    /// Action to take when condition is met
91    pub action: ConditionAction,
92}
93
94/// Condition action
95#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
96#[serde(rename_all = "snake_case")]
97pub enum ConditionAction {
98    /// Generate a new event
99    GenerateEvent,
100    /// Stop event generation
101    Stop,
102    /// Change event rate
103    ChangeRate(u64), // events per second
104    /// Transition to new scenario
105    TransitionScenario(String),
106}
107
108/// Generated event
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct GeneratedEvent {
111    /// Event type
112    pub event_type: String,
113    /// Event timestamp
114    pub timestamp: chrono::DateTime<chrono::Utc>,
115    /// Event data
116    pub data: Value,
117    /// Sequence number
118    pub sequence: usize,
119    /// Event metadata
120    pub metadata: std::collections::HashMap<String, String>,
121}
122
123impl GeneratedEvent {
124    /// Create a new generated event
125    pub fn new(event_type: String, data: Value, sequence: usize) -> Self {
126        Self {
127            event_type,
128            timestamp: chrono::Utc::now(),
129            data,
130            sequence,
131            metadata: std::collections::HashMap::new(),
132        }
133    }
134
135    /// Add metadata
136    pub fn with_metadata(mut self, key: String, value: String) -> Self {
137        self.metadata.insert(key, value);
138        self
139    }
140
141    /// Convert to JSON
142    pub fn to_json(&self) -> Result<String> {
143        serde_json::to_string(self)
144            .map_err(|e| Error::generic(format!("Failed to serialize event: {}", e)))
145    }
146}
147
148/// Replay augmentation engine
149pub struct ReplayAugmentationEngine {
150    /// Configuration
151    config: ReplayAugmentationConfig,
152    /// RAG engine for LLM-based generation
153    rag_engine: Option<RagEngine>,
154    /// Event sequence counter
155    sequence: usize,
156    /// Current scenario state
157    scenario_state: ScenarioState,
158}
159
160/// Scenario state tracking
161#[derive(Debug, Clone)]
162struct ScenarioState {
163    /// Current timestamp in scenario
164    _current_time: std::time::Instant,
165    /// Events generated so far
166    events_generated: usize,
167    /// Last event data (for progressive evolution)
168    last_event: Option<Value>,
169    /// Scenario context
170    context: Vec<String>,
171}
172
173impl Default for ScenarioState {
174    fn default() -> Self {
175        Self {
176            _current_time: std::time::Instant::now(),
177            events_generated: 0,
178            last_event: None,
179            context: Vec::new(),
180        }
181    }
182}
183
184impl ReplayAugmentationEngine {
185    /// Create a new replay augmentation engine
186    pub fn new(config: ReplayAugmentationConfig) -> Result<Self> {
187        Self::validate_config(&config)?;
188
189        let rag_engine = if config.mode != ReplayMode::Static {
190            let rag_config = config.rag_config.clone().unwrap_or_default();
191            Some(RagEngine::new(rag_config))
192        } else {
193            None
194        };
195
196        Ok(Self {
197            config,
198            rag_engine,
199            sequence: 0,
200            scenario_state: ScenarioState::default(),
201        })
202    }
203
204    /// Validate configuration
205    fn validate_config(config: &ReplayAugmentationConfig) -> Result<()> {
206        if config.mode != ReplayMode::Static && config.narrative.is_none() {
207            return Err(Error::generic(
208                "Narrative is required for augmented or generated replay modes",
209            ));
210        }
211
212        match config.strategy {
213            EventStrategy::TimeBased => {
214                if config.duration_secs.is_none() {
215                    return Err(Error::generic(
216                        "Duration must be specified for time-based strategy",
217                    ));
218                }
219            }
220            EventStrategy::CountBased => {
221                if config.event_count.is_none() {
222                    return Err(Error::generic(
223                        "Event count must be specified for count-based strategy",
224                    ));
225                }
226            }
227            EventStrategy::ConditionalBased => {
228                if config.conditions.is_empty() {
229                    return Err(Error::generic(
230                        "Conditions must be specified for conditional-based strategy",
231                    ));
232                }
233            }
234        }
235
236        Ok(())
237    }
238
239    /// Generate event stream based on configuration
240    pub async fn generate_stream(&mut self) -> Result<Vec<GeneratedEvent>> {
241        match self.config.strategy {
242            EventStrategy::CountBased => self.generate_count_based().await,
243            EventStrategy::TimeBased => self.generate_time_based().await,
244            EventStrategy::ConditionalBased => self.generate_conditional_based().await,
245        }
246    }
247
248    /// Generate events based on count
249    async fn generate_count_based(&mut self) -> Result<Vec<GeneratedEvent>> {
250        let count = self.config.event_count.unwrap_or(10);
251        let mut events = Vec::with_capacity(count);
252
253        for i in 0..count {
254            let event = self.generate_single_event(i).await?;
255            events.push(event);
256
257            // Add delay between events if rate is specified
258            if let Some(rate) = self.config.event_rate {
259                if rate > 0.0 {
260                    let delay_ms = (1000.0 / rate) as u64;
261                    tokio::time::sleep(Duration::from_millis(delay_ms)).await;
262                }
263            }
264        }
265
266        Ok(events)
267    }
268
269    /// Generate events based on time duration
270    async fn generate_time_based(&mut self) -> Result<Vec<GeneratedEvent>> {
271        let duration = Duration::from_secs(self.config.duration_secs.unwrap_or(60));
272        let rate = self.config.event_rate.unwrap_or(1.0);
273        let interval_ms = (1000.0 / rate) as u64;
274
275        let mut events = Vec::new();
276        let mut ticker = interval(Duration::from_millis(interval_ms));
277        let start = std::time::Instant::now();
278
279        let mut index = 0;
280        while start.elapsed() < duration {
281            ticker.tick().await;
282            let event = self.generate_single_event(index).await?;
283            events.push(event);
284            index += 1;
285        }
286
287        Ok(events)
288    }
289
290    /// Generate events based on conditions
291    async fn generate_conditional_based(&mut self) -> Result<Vec<GeneratedEvent>> {
292        let mut events = Vec::new();
293        let mut index = 0;
294        let max_events = 1000; // Safety limit
295
296        while index < max_events {
297            // Check conditions
298            let mut should_continue = true;
299            let conditions = self.config.conditions.clone(); // Clone to avoid borrow issues
300
301            for condition in &conditions {
302                if self.evaluate_condition(condition, &events) {
303                    match &condition.action {
304                        ConditionAction::GenerateEvent => {
305                            let event = self.generate_single_event(index).await?;
306                            events.push(event);
307                            index += 1;
308                        }
309                        ConditionAction::Stop => {
310                            should_continue = false;
311                            break;
312                        }
313                        ConditionAction::ChangeRate(_rate) => {
314                            // Update rate (would require mutable config)
315                        }
316                        ConditionAction::TransitionScenario(_scenario) => {
317                            // Transition to new scenario
318                            self.scenario_state.context.clear();
319                        }
320                    }
321                }
322            }
323
324            if !should_continue {
325                break;
326            }
327
328            // Prevent infinite loop
329            if events.is_empty() && index > 10 {
330                break;
331            }
332
333            tokio::time::sleep(Duration::from_millis(100)).await;
334        }
335
336        Ok(events)
337    }
338
339    /// Generate a single event
340    async fn generate_single_event(&mut self, index: usize) -> Result<GeneratedEvent> {
341        let data = match self.config.mode {
342            ReplayMode::Static => self.generate_static_event(),
343            ReplayMode::Augmented => self.generate_augmented_event(index).await?,
344            ReplayMode::Generated => self.generate_llm_event(index).await?,
345        };
346
347        self.sequence += 1;
348        self.scenario_state.events_generated += 1;
349        self.scenario_state.last_event = Some(data.clone());
350
351        Ok(GeneratedEvent::new(self.config.event_type.clone(), data, self.sequence))
352    }
353
354    /// Generate static event (fallback)
355    fn generate_static_event(&self) -> Value {
356        if let Some(schema) = &self.config.event_schema {
357            schema.clone()
358        } else {
359            serde_json::json!({
360                "type": self.config.event_type,
361                "timestamp": chrono::Utc::now().to_rfc3339()
362            })
363        }
364    }
365
366    /// Generate augmented event (base + LLM enhancement)
367    async fn generate_augmented_event(&mut self, index: usize) -> Result<Value> {
368        let mut base_event = self.generate_static_event();
369
370        if let Some(rag_engine) = &self.rag_engine {
371            let narrative = self.config.narrative.as_ref().unwrap();
372            let prompt = self.build_augmentation_prompt(narrative, index)?;
373
374            let enhancement = rag_engine.generate_text(&prompt).await?;
375            let enhancement_json = self.parse_json_response(&enhancement)?;
376
377            // Merge enhancement with base event
378            if let (Some(base_obj), Some(enhancement_obj)) =
379                (base_event.as_object_mut(), enhancement_json.as_object())
380            {
381                for (key, value) in enhancement_obj {
382                    base_obj.insert(key.clone(), value.clone());
383                }
384            } else {
385                base_event = enhancement_json;
386            }
387        }
388
389        Ok(base_event)
390    }
391
392    /// Generate fully LLM-generated event
393    async fn generate_llm_event(&mut self, index: usize) -> Result<Value> {
394        let rag_engine = self
395            .rag_engine
396            .as_ref()
397            .ok_or_else(|| Error::generic("RAG engine not initialized for generated mode"))?;
398
399        let narrative = self.config.narrative.as_ref().unwrap();
400        let prompt = self.build_generation_prompt(narrative, index)?;
401
402        let response = rag_engine.generate_text(&prompt).await?;
403        self.parse_json_response(&response)
404    }
405
406    /// Build augmentation prompt
407    fn build_augmentation_prompt(&self, narrative: &str, index: usize) -> Result<String> {
408        let mut prompt = format!(
409            "Enhance this event data based on the following scenario:\n\n{}\n\n",
410            narrative
411        );
412
413        prompt.push_str(&format!("Event #{} (out of ongoing stream)\n\n", index + 1));
414
415        if let Some(last_event) = &self.scenario_state.last_event {
416            prompt.push_str(&format!(
417                "Previous event:\n{}\n\n",
418                serde_json::to_string_pretty(last_event).unwrap_or_default()
419            ));
420        }
421
422        if self.config.progressive_evolution {
423            prompt.push_str("Progressively evolve the scenario with each event.\n");
424        }
425
426        if let Some(schema) = &self.config.event_schema {
427            prompt.push_str(&format!(
428                "Conform to this schema:\n{}\n\n",
429                serde_json::to_string_pretty(schema).unwrap_or_default()
430            ));
431        }
432
433        prompt.push_str("Return valid JSON only for the enhanced event data.");
434
435        Ok(prompt)
436    }
437
438    /// Build generation prompt
439    fn build_generation_prompt(&self, narrative: &str, index: usize) -> Result<String> {
440        let mut prompt =
441            format!("Generate realistic event data for this scenario:\n\n{}\n\n", narrative);
442
443        prompt.push_str(&format!("Event type: {}\n", self.config.event_type));
444        prompt.push_str(&format!("Event #{}\n\n", index + 1));
445
446        if let Some(last_event) = &self.scenario_state.last_event {
447            prompt.push_str(&format!(
448                "Previous event:\n{}\n\n",
449                serde_json::to_string_pretty(last_event).unwrap_or_default()
450            ));
451
452            if self.config.progressive_evolution {
453                prompt.push_str("Naturally evolve from the previous event.\n");
454            }
455        }
456
457        if let Some(schema) = &self.config.event_schema {
458            prompt.push_str(&format!(
459                "Conform to this schema:\n{}\n\n",
460                serde_json::to_string_pretty(schema).unwrap_or_default()
461            ));
462        }
463
464        prompt.push_str("Return valid JSON only.");
465
466        Ok(prompt)
467    }
468
469    /// Parse JSON response from LLM
470    fn parse_json_response(&self, response: &str) -> Result<Value> {
471        let trimmed = response.trim();
472
473        // Try to extract from markdown code blocks
474        let json_str = if trimmed.starts_with("```json") {
475            trimmed
476                .strip_prefix("```json")
477                .and_then(|s| s.strip_suffix("```"))
478                .unwrap_or(trimmed)
479                .trim()
480        } else if trimmed.starts_with("```") {
481            trimmed
482                .strip_prefix("```")
483                .and_then(|s| s.strip_suffix("```"))
484                .unwrap_or(trimmed)
485                .trim()
486        } else {
487            trimmed
488        };
489
490        // Parse JSON
491        serde_json::from_str(json_str)
492            .map_err(|e| Error::generic(format!("Failed to parse LLM response as JSON: {}", e)))
493    }
494
495    /// Evaluate a condition expression against current event state.
496    ///
497    /// Supported expressions:
498    /// - `count <op> <n>` — compare event count (e.g., "count < 100", "count >= 50")
499    /// - `sequence <op> <n>` — compare current sequence number
500    /// - `events_generated <op> <n>` — compare total events generated in scenario
501    /// - `true` / `false` — literal boolean
502    ///
503    /// where `<op>` is one of: `<`, `>`, `<=`, `>=`, `==`, `!=`
504    fn evaluate_condition(&self, condition: &EventCondition, events: &[GeneratedEvent]) -> bool {
505        let expr = condition.expression.trim();
506
507        // Literal booleans
508        if expr.eq_ignore_ascii_case("true") {
509            return true;
510        }
511        if expr.eq_ignore_ascii_case("false") {
512            return false;
513        }
514
515        // Parse "<variable> <op> <value>" expressions
516        let parts: Vec<&str> = expr.splitn(3, ' ').collect();
517        if parts.len() != 3 {
518            tracing::warn!(
519                expression = expr,
520                "Unrecognized condition expression, defaulting to true"
521            );
522            return true;
523        }
524
525        let variable = parts[0];
526        let operator = parts[1];
527        let threshold: i64 = match parts[2].parse() {
528            Ok(v) => v,
529            Err(_) => {
530                tracing::warn!(
531                    expression = expr,
532                    "Could not parse threshold as integer, defaulting to true"
533                );
534                return true;
535            }
536        };
537
538        let actual: i64 = match variable {
539            "count" => events.len() as i64,
540            "sequence" => self.sequence as i64,
541            "events_generated" => self.scenario_state.events_generated as i64,
542            _ => {
543                tracing::warn!(variable, "Unknown condition variable, defaulting to true");
544                return true;
545            }
546        };
547
548        match operator {
549            "<" => actual < threshold,
550            ">" => actual > threshold,
551            "<=" => actual <= threshold,
552            ">=" => actual >= threshold,
553            "==" => actual == threshold,
554            "!=" => actual != threshold,
555            _ => {
556                tracing::warn!(operator, "Unknown comparison operator, defaulting to true");
557                true
558            }
559        }
560    }
561
562    /// Reset the engine state
563    pub fn reset(&mut self) {
564        self.sequence = 0;
565        self.scenario_state = ScenarioState::default();
566    }
567
568    /// Get current sequence number
569    pub fn sequence(&self) -> usize {
570        self.sequence
571    }
572
573    /// Get events generated count
574    pub fn events_generated(&self) -> usize {
575        self.scenario_state.events_generated
576    }
577}
578
579/// Pre-defined scenario templates
580pub mod scenarios {
581    use super::*;
582
583    /// Stock market simulation scenario
584    pub fn stock_market_scenario() -> ReplayAugmentationConfig {
585        ReplayAugmentationConfig {
586            mode: ReplayMode::Generated,
587            narrative: Some(
588                "Simulate 10 minutes of live market data with realistic price movements, \
589                 volume changes, and occasional volatility spikes."
590                    .to_string(),
591            ),
592            event_type: "market_tick".to_string(),
593            event_schema: Some(serde_json::json!({
594                "symbol": "string",
595                "price": "number",
596                "volume": "number",
597                "timestamp": "string"
598            })),
599            strategy: EventStrategy::TimeBased,
600            duration_secs: Some(600), // 10 minutes
601            event_rate: Some(2.0),    // 2 events per second
602            ..Default::default()
603        }
604    }
605
606    /// Chat application scenario
607    pub fn chat_messages_scenario() -> ReplayAugmentationConfig {
608        ReplayAugmentationConfig {
609            mode: ReplayMode::Generated,
610            narrative: Some(
611                "Simulate a group chat conversation between 3-5 users discussing a project, \
612                 with natural message pacing and realistic content."
613                    .to_string(),
614            ),
615            event_type: "chat_message".to_string(),
616            event_schema: Some(serde_json::json!({
617                "user_id": "string",
618                "message": "string",
619                "timestamp": "string"
620            })),
621            strategy: EventStrategy::CountBased,
622            event_count: Some(50),
623            event_rate: Some(0.5), // One message every 2 seconds
624            ..Default::default()
625        }
626    }
627
628    /// IoT sensor data scenario
629    pub fn iot_sensor_scenario() -> ReplayAugmentationConfig {
630        ReplayAugmentationConfig {
631            mode: ReplayMode::Generated,
632            narrative: Some(
633                "Simulate IoT sensor readings from a smart building with temperature, \
634                 humidity, and occupancy data showing daily patterns."
635                    .to_string(),
636            ),
637            event_type: "sensor_reading".to_string(),
638            event_schema: Some(serde_json::json!({
639                "sensor_id": "string",
640                "temperature": "number",
641                "humidity": "number",
642                "occupancy": "number",
643                "timestamp": "string"
644            })),
645            strategy: EventStrategy::CountBased,
646            event_count: Some(100),
647            event_rate: Some(1.0),
648            progressive_evolution: true,
649            ..Default::default()
650        }
651    }
652}
653
#[cfg(test)]
mod tests {
    use super::*;

    /// The default replay mode is the static one.
    #[test]
    fn test_replay_mode_default() {
        let mode = ReplayMode::default();
        assert_eq!(mode, ReplayMode::Static);
    }

    /// Each strategy variant matches its own pattern.
    #[test]
    fn test_event_strategy_variants() {
        assert!(matches!(EventStrategy::TimeBased, EventStrategy::TimeBased));
        assert!(matches!(EventStrategy::CountBased, EventStrategy::CountBased));
        assert!(matches!(
            EventStrategy::ConditionalBased,
            EventStrategy::ConditionalBased
        ));
    }

    /// The constructor stores the type and sequence it was given.
    #[test]
    fn test_generated_event_creation() {
        let payload = serde_json::json!({"test": "value"});
        let event = GeneratedEvent::new(String::from("test_event"), payload, 1);

        assert_eq!(event.event_type, "test_event");
        assert_eq!(event.sequence, 1);
    }

    /// Generated mode without a narrative must be rejected.
    #[test]
    fn test_replay_config_validation_missing_narrative() {
        let config = ReplayAugmentationConfig {
            mode: ReplayMode::Generated,
            ..Default::default()
        };

        let outcome = ReplayAugmentationEngine::validate_config(&config);
        assert!(outcome.is_err());
    }

    /// Spot-checks one characteristic field of every scenario template.
    #[test]
    fn test_scenario_templates() {
        let stock = scenarios::stock_market_scenario();
        assert_eq!(stock.mode, ReplayMode::Generated);
        assert!(stock.narrative.is_some());

        let chat = scenarios::chat_messages_scenario();
        assert_eq!(chat.event_type, "chat_message");

        let iot = scenarios::iot_sensor_scenario();
        assert!(iot.progressive_evolution);
    }
}