1use crate::rag::{RagConfig, RagEngine};
7use mockforge_core::{Error, Result};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::time::Duration;
11use tokio::time::interval;
12
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
15#[serde(rename_all = "snake_case")]
16pub enum ReplayMode {
17 Static,
19 Augmented,
21 Generated,
23}
24
25impl Default for ReplayMode {
26 fn default() -> Self {
27 Self::Static
28 }
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
33#[serde(rename_all = "snake_case")]
34pub enum EventStrategy {
35 TimeBased,
37 CountBased,
39 ConditionalBased,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ReplayAugmentationConfig {
46 pub mode: ReplayMode,
48 pub narrative: Option<String>,
50 pub event_type: String,
52 pub event_schema: Option<Value>,
54 pub strategy: EventStrategy,
56 pub duration_secs: Option<u64>,
58 pub event_count: Option<usize>,
60 pub event_rate: Option<f64>,
62 pub conditions: Vec<EventCondition>,
64 pub rag_config: Option<RagConfig>,
66 pub progressive_evolution: bool,
68}
69
70impl Default for ReplayAugmentationConfig {
71 fn default() -> Self {
72 Self {
73 mode: ReplayMode::Static,
74 narrative: None,
75 event_type: "event".to_string(),
76 event_schema: None,
77 strategy: EventStrategy::CountBased,
78 duration_secs: None,
79 event_count: Some(10),
80 event_rate: Some(1.0),
81 conditions: Vec::new(),
82 rag_config: None,
83 progressive_evolution: true,
84 }
85 }
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct EventCondition {
91 pub name: String,
93 pub expression: String,
95 pub action: ConditionAction,
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
101#[serde(rename_all = "snake_case")]
102pub enum ConditionAction {
103 GenerateEvent,
105 Stop,
107 ChangeRate(u64), TransitionScenario(String),
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct GeneratedEvent {
116 pub event_type: String,
118 pub timestamp: chrono::DateTime<chrono::Utc>,
120 pub data: Value,
122 pub sequence: usize,
124 pub metadata: std::collections::HashMap<String, String>,
126}
127
128impl GeneratedEvent {
129 pub fn new(event_type: String, data: Value, sequence: usize) -> Self {
131 Self {
132 event_type,
133 timestamp: chrono::Utc::now(),
134 data,
135 sequence,
136 metadata: std::collections::HashMap::new(),
137 }
138 }
139
140 pub fn with_metadata(mut self, key: String, value: String) -> Self {
142 self.metadata.insert(key, value);
143 self
144 }
145
146 pub fn to_json(&self) -> Result<String> {
148 serde_json::to_string(self)
149 .map_err(|e| Error::generic(format!("Failed to serialize event: {}", e)))
150 }
151}
152
153pub struct ReplayAugmentationEngine {
155 config: ReplayAugmentationConfig,
157 rag_engine: Option<RagEngine>,
159 sequence: usize,
161 scenario_state: ScenarioState,
163}
164
165#[derive(Debug, Clone)]
167struct ScenarioState {
168 _current_time: std::time::Instant,
170 events_generated: usize,
172 last_event: Option<Value>,
174 context: Vec<String>,
176}
177
178impl Default for ScenarioState {
179 fn default() -> Self {
180 Self {
181 _current_time: std::time::Instant::now(),
182 events_generated: 0,
183 last_event: None,
184 context: Vec::new(),
185 }
186 }
187}
188
189impl ReplayAugmentationEngine {
190 pub fn new(config: ReplayAugmentationConfig) -> Result<Self> {
192 Self::validate_config(&config)?;
193
194 let rag_engine = if config.mode != ReplayMode::Static {
195 let rag_config = config.rag_config.clone().unwrap_or_default();
196 Some(RagEngine::new(rag_config))
197 } else {
198 None
199 };
200
201 Ok(Self {
202 config,
203 rag_engine,
204 sequence: 0,
205 scenario_state: ScenarioState::default(),
206 })
207 }
208
209 fn validate_config(config: &ReplayAugmentationConfig) -> Result<()> {
211 if config.mode != ReplayMode::Static && config.narrative.is_none() {
212 return Err(Error::generic(
213 "Narrative is required for augmented or generated replay modes",
214 ));
215 }
216
217 match config.strategy {
218 EventStrategy::TimeBased => {
219 if config.duration_secs.is_none() {
220 return Err(Error::generic(
221 "Duration must be specified for time-based strategy",
222 ));
223 }
224 }
225 EventStrategy::CountBased => {
226 if config.event_count.is_none() {
227 return Err(Error::generic(
228 "Event count must be specified for count-based strategy",
229 ));
230 }
231 }
232 EventStrategy::ConditionalBased => {
233 if config.conditions.is_empty() {
234 return Err(Error::generic(
235 "Conditions must be specified for conditional-based strategy",
236 ));
237 }
238 }
239 }
240
241 Ok(())
242 }
243
244 pub async fn generate_stream(&mut self) -> Result<Vec<GeneratedEvent>> {
246 match self.config.strategy {
247 EventStrategy::CountBased => self.generate_count_based().await,
248 EventStrategy::TimeBased => self.generate_time_based().await,
249 EventStrategy::ConditionalBased => self.generate_conditional_based().await,
250 }
251 }
252
253 async fn generate_count_based(&mut self) -> Result<Vec<GeneratedEvent>> {
255 let count = self.config.event_count.unwrap_or(10);
256 let mut events = Vec::with_capacity(count);
257
258 for i in 0..count {
259 let event = self.generate_single_event(i).await?;
260 events.push(event);
261
262 if let Some(rate) = self.config.event_rate {
264 if rate > 0.0 {
265 let delay_ms = (1000.0 / rate) as u64;
266 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
267 }
268 }
269 }
270
271 Ok(events)
272 }
273
274 async fn generate_time_based(&mut self) -> Result<Vec<GeneratedEvent>> {
276 let duration = Duration::from_secs(self.config.duration_secs.unwrap_or(60));
277 let rate = self.config.event_rate.unwrap_or(1.0);
278 let interval_ms = (1000.0 / rate) as u64;
279
280 let mut events = Vec::new();
281 let mut ticker = interval(Duration::from_millis(interval_ms));
282 let start = std::time::Instant::now();
283
284 let mut index = 0;
285 while start.elapsed() < duration {
286 ticker.tick().await;
287 let event = self.generate_single_event(index).await?;
288 events.push(event);
289 index += 1;
290 }
291
292 Ok(events)
293 }
294
295 async fn generate_conditional_based(&mut self) -> Result<Vec<GeneratedEvent>> {
297 let mut events = Vec::new();
298 let mut index = 0;
299 let max_events = 1000; while index < max_events {
302 let mut should_continue = true;
304 let conditions = self.config.conditions.clone(); for condition in &conditions {
307 if self.evaluate_condition(condition, &events) {
308 match &condition.action {
309 ConditionAction::GenerateEvent => {
310 let event = self.generate_single_event(index).await?;
311 events.push(event);
312 index += 1;
313 }
314 ConditionAction::Stop => {
315 should_continue = false;
316 break;
317 }
318 ConditionAction::ChangeRate(_rate) => {
319 }
321 ConditionAction::TransitionScenario(_scenario) => {
322 self.scenario_state.context.clear();
324 }
325 }
326 }
327 }
328
329 if !should_continue {
330 break;
331 }
332
333 if events.is_empty() && index > 10 {
335 break;
336 }
337
338 tokio::time::sleep(Duration::from_millis(100)).await;
339 }
340
341 Ok(events)
342 }
343
344 async fn generate_single_event(&mut self, index: usize) -> Result<GeneratedEvent> {
346 let data = match self.config.mode {
347 ReplayMode::Static => self.generate_static_event(),
348 ReplayMode::Augmented => self.generate_augmented_event(index).await?,
349 ReplayMode::Generated => self.generate_llm_event(index).await?,
350 };
351
352 self.sequence += 1;
353 self.scenario_state.events_generated += 1;
354 self.scenario_state.last_event = Some(data.clone());
355
356 Ok(GeneratedEvent::new(self.config.event_type.clone(), data, self.sequence))
357 }
358
359 fn generate_static_event(&self) -> Value {
361 if let Some(schema) = &self.config.event_schema {
362 schema.clone()
363 } else {
364 serde_json::json!({
365 "type": self.config.event_type,
366 "timestamp": chrono::Utc::now().to_rfc3339()
367 })
368 }
369 }
370
371 async fn generate_augmented_event(&mut self, index: usize) -> Result<Value> {
373 let mut base_event = self.generate_static_event();
374
375 if let Some(rag_engine) = &self.rag_engine {
376 let narrative = self.config.narrative.as_ref().unwrap();
377 let prompt = self.build_augmentation_prompt(narrative, index)?;
378
379 let enhancement = rag_engine.generate_text(&prompt).await?;
380 let enhancement_json = self.parse_json_response(&enhancement)?;
381
382 if let (Some(base_obj), Some(enhancement_obj)) =
384 (base_event.as_object_mut(), enhancement_json.as_object())
385 {
386 for (key, value) in enhancement_obj {
387 base_obj.insert(key.clone(), value.clone());
388 }
389 } else {
390 base_event = enhancement_json;
391 }
392 }
393
394 Ok(base_event)
395 }
396
397 async fn generate_llm_event(&mut self, index: usize) -> Result<Value> {
399 let rag_engine = self
400 .rag_engine
401 .as_ref()
402 .ok_or_else(|| Error::generic("RAG engine not initialized for generated mode"))?;
403
404 let narrative = self.config.narrative.as_ref().unwrap();
405 let prompt = self.build_generation_prompt(narrative, index)?;
406
407 let response = rag_engine.generate_text(&prompt).await?;
408 self.parse_json_response(&response)
409 }
410
411 fn build_augmentation_prompt(&self, narrative: &str, index: usize) -> Result<String> {
413 let mut prompt = format!(
414 "Enhance this event data based on the following scenario:\n\n{}\n\n",
415 narrative
416 );
417
418 prompt.push_str(&format!("Event #{} (out of ongoing stream)\n\n", index + 1));
419
420 if let Some(last_event) = &self.scenario_state.last_event {
421 prompt.push_str(&format!(
422 "Previous event:\n{}\n\n",
423 serde_json::to_string_pretty(last_event).unwrap_or_default()
424 ));
425 }
426
427 if self.config.progressive_evolution {
428 prompt.push_str("Progressively evolve the scenario with each event.\n");
429 }
430
431 if let Some(schema) = &self.config.event_schema {
432 prompt.push_str(&format!(
433 "Conform to this schema:\n{}\n\n",
434 serde_json::to_string_pretty(schema).unwrap_or_default()
435 ));
436 }
437
438 prompt.push_str("Return valid JSON only for the enhanced event data.");
439
440 Ok(prompt)
441 }
442
443 fn build_generation_prompt(&self, narrative: &str, index: usize) -> Result<String> {
445 let mut prompt =
446 format!("Generate realistic event data for this scenario:\n\n{}\n\n", narrative);
447
448 prompt.push_str(&format!("Event type: {}\n", self.config.event_type));
449 prompt.push_str(&format!("Event #{}\n\n", index + 1));
450
451 if let Some(last_event) = &self.scenario_state.last_event {
452 prompt.push_str(&format!(
453 "Previous event:\n{}\n\n",
454 serde_json::to_string_pretty(last_event).unwrap_or_default()
455 ));
456
457 if self.config.progressive_evolution {
458 prompt.push_str("Naturally evolve from the previous event.\n");
459 }
460 }
461
462 if let Some(schema) = &self.config.event_schema {
463 prompt.push_str(&format!(
464 "Conform to this schema:\n{}\n\n",
465 serde_json::to_string_pretty(schema).unwrap_or_default()
466 ));
467 }
468
469 prompt.push_str("Return valid JSON only.");
470
471 Ok(prompt)
472 }
473
474 fn parse_json_response(&self, response: &str) -> Result<Value> {
476 let trimmed = response.trim();
477
478 let json_str = if trimmed.starts_with("```json") {
480 trimmed
481 .strip_prefix("```json")
482 .and_then(|s| s.strip_suffix("```"))
483 .unwrap_or(trimmed)
484 .trim()
485 } else if trimmed.starts_with("```") {
486 trimmed
487 .strip_prefix("```")
488 .and_then(|s| s.strip_suffix("```"))
489 .unwrap_or(trimmed)
490 .trim()
491 } else {
492 trimmed
493 };
494
495 serde_json::from_str(json_str)
497 .map_err(|e| Error::generic(format!("Failed to parse LLM response as JSON: {}", e)))
498 }
499
500 fn evaluate_condition(&self, _condition: &EventCondition, events: &[GeneratedEvent]) -> bool {
502 events.len() < 100 }
506
507 pub fn reset(&mut self) {
509 self.sequence = 0;
510 self.scenario_state = ScenarioState::default();
511 }
512
513 pub fn sequence(&self) -> usize {
515 self.sequence
516 }
517
518 pub fn events_generated(&self) -> usize {
520 self.scenario_state.events_generated
521 }
522}
523
524pub mod scenarios {
526 use super::*;
527
528 pub fn stock_market_scenario() -> ReplayAugmentationConfig {
530 ReplayAugmentationConfig {
531 mode: ReplayMode::Generated,
532 narrative: Some(
533 "Simulate 10 minutes of live market data with realistic price movements, \
534 volume changes, and occasional volatility spikes."
535 .to_string(),
536 ),
537 event_type: "market_tick".to_string(),
538 event_schema: Some(serde_json::json!({
539 "symbol": "string",
540 "price": "number",
541 "volume": "number",
542 "timestamp": "string"
543 })),
544 strategy: EventStrategy::TimeBased,
545 duration_secs: Some(600), event_rate: Some(2.0), ..Default::default()
548 }
549 }
550
551 pub fn chat_messages_scenario() -> ReplayAugmentationConfig {
553 ReplayAugmentationConfig {
554 mode: ReplayMode::Generated,
555 narrative: Some(
556 "Simulate a group chat conversation between 3-5 users discussing a project, \
557 with natural message pacing and realistic content."
558 .to_string(),
559 ),
560 event_type: "chat_message".to_string(),
561 event_schema: Some(serde_json::json!({
562 "user_id": "string",
563 "message": "string",
564 "timestamp": "string"
565 })),
566 strategy: EventStrategy::CountBased,
567 event_count: Some(50),
568 event_rate: Some(0.5), ..Default::default()
570 }
571 }
572
573 pub fn iot_sensor_scenario() -> ReplayAugmentationConfig {
575 ReplayAugmentationConfig {
576 mode: ReplayMode::Generated,
577 narrative: Some(
578 "Simulate IoT sensor readings from a smart building with temperature, \
579 humidity, and occupancy data showing daily patterns."
580 .to_string(),
581 ),
582 event_type: "sensor_reading".to_string(),
583 event_schema: Some(serde_json::json!({
584 "sensor_id": "string",
585 "temperature": "number",
586 "humidity": "number",
587 "occupancy": "number",
588 "timestamp": "string"
589 })),
590 strategy: EventStrategy::CountBased,
591 event_count: Some(100),
592 event_rate: Some(1.0),
593 progressive_evolution: true,
594 ..Default::default()
595 }
596 }
597}
598
#[cfg(test)]
mod tests {
    use super::*;

    /// The default replay mode is the static one.
    #[test]
    fn test_replay_mode_default() {
        let mode = ReplayMode::default();
        assert_eq!(mode, ReplayMode::Static);
    }

    /// Every strategy variant can be constructed and matched.
    #[test]
    fn test_event_strategy_variants() {
        assert!(matches!(EventStrategy::TimeBased, EventStrategy::TimeBased));
        assert!(matches!(EventStrategy::CountBased, EventStrategy::CountBased));
        assert!(matches!(
            EventStrategy::ConditionalBased,
            EventStrategy::ConditionalBased
        ));
    }

    /// `GeneratedEvent::new` stores the type and sequence it is given.
    #[test]
    fn test_generated_event_creation() {
        let payload = serde_json::json!({"test": "value"});
        let created = GeneratedEvent::new(String::from("test_event"), payload, 1);

        assert_eq!(created.event_type, "test_event");
        assert_eq!(created.sequence, 1);
    }

    /// Generated mode without a narrative is rejected by validation.
    #[test]
    fn test_replay_config_validation_missing_narrative() {
        let without_narrative = ReplayAugmentationConfig {
            mode: ReplayMode::Generated,
            ..Default::default()
        };

        let outcome = ReplayAugmentationEngine::validate_config(&without_narrative);
        assert!(outcome.is_err());
    }

    /// The bundled scenario templates expose the expected key settings.
    #[test]
    fn test_scenario_templates() {
        let stock = scenarios::stock_market_scenario();
        assert_eq!(stock.mode, ReplayMode::Generated);
        assert!(stock.narrative.is_some());

        let chat = scenarios::chat_messages_scenario();
        assert_eq!(chat.event_type, "chat_message");

        let iot = scenarios::iot_sensor_scenario();
        assert!(iot.progressive_evolution);
    }
}