1use crate::rag::{RagConfig, RagEngine};
7use crate::{Error, Result};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::time::Duration;
11use tokio::time::interval;
12
13#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
15#[serde(rename_all = "snake_case")]
16pub enum ReplayMode {
17 #[default]
19 Static,
20 Augmented,
22 Generated,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
28#[serde(rename_all = "snake_case")]
29pub enum EventStrategy {
30 TimeBased,
32 CountBased,
34 ConditionalBased,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ReplayAugmentationConfig {
41 pub mode: ReplayMode,
43 pub narrative: Option<String>,
45 pub event_type: String,
47 pub event_schema: Option<Value>,
49 pub strategy: EventStrategy,
51 pub duration_secs: Option<u64>,
53 pub event_count: Option<usize>,
55 pub event_rate: Option<f64>,
57 pub conditions: Vec<EventCondition>,
59 pub rag_config: Option<RagConfig>,
61 pub progressive_evolution: bool,
63}
64
65impl Default for ReplayAugmentationConfig {
66 fn default() -> Self {
67 Self {
68 mode: ReplayMode::Static,
69 narrative: None,
70 event_type: "event".to_string(),
71 event_schema: None,
72 strategy: EventStrategy::CountBased,
73 duration_secs: None,
74 event_count: Some(10),
75 event_rate: Some(1.0),
76 conditions: Vec::new(),
77 rag_config: None,
78 progressive_evolution: true,
79 }
80 }
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct EventCondition {
86 pub name: String,
88 pub expression: String,
90 pub action: ConditionAction,
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
96#[serde(rename_all = "snake_case")]
97pub enum ConditionAction {
98 GenerateEvent,
100 Stop,
102 ChangeRate(u64), TransitionScenario(String),
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct GeneratedEvent {
111 pub event_type: String,
113 pub timestamp: chrono::DateTime<chrono::Utc>,
115 pub data: Value,
117 pub sequence: usize,
119 pub metadata: std::collections::HashMap<String, String>,
121}
122
123impl GeneratedEvent {
124 pub fn new(event_type: String, data: Value, sequence: usize) -> Self {
126 Self {
127 event_type,
128 timestamp: chrono::Utc::now(),
129 data,
130 sequence,
131 metadata: std::collections::HashMap::new(),
132 }
133 }
134
135 pub fn with_metadata(mut self, key: String, value: String) -> Self {
137 self.metadata.insert(key, value);
138 self
139 }
140
141 pub fn to_json(&self) -> Result<String> {
143 serde_json::to_string(self)
144 .map_err(|e| Error::generic(format!("Failed to serialize event: {}", e)))
145 }
146}
147
148pub struct ReplayAugmentationEngine {
150 config: ReplayAugmentationConfig,
152 rag_engine: Option<RagEngine>,
154 sequence: usize,
156 scenario_state: ScenarioState,
158}
159
160#[derive(Debug, Clone)]
162struct ScenarioState {
163 _current_time: std::time::Instant,
165 events_generated: usize,
167 last_event: Option<Value>,
169 context: Vec<String>,
171}
172
173impl Default for ScenarioState {
174 fn default() -> Self {
175 Self {
176 _current_time: std::time::Instant::now(),
177 events_generated: 0,
178 last_event: None,
179 context: Vec::new(),
180 }
181 }
182}
183
184impl ReplayAugmentationEngine {
185 pub fn new(config: ReplayAugmentationConfig) -> Result<Self> {
187 Self::validate_config(&config)?;
188
189 let rag_engine = if config.mode != ReplayMode::Static {
190 let rag_config = config.rag_config.clone().unwrap_or_default();
191 Some(RagEngine::new(rag_config))
192 } else {
193 None
194 };
195
196 Ok(Self {
197 config,
198 rag_engine,
199 sequence: 0,
200 scenario_state: ScenarioState::default(),
201 })
202 }
203
204 fn validate_config(config: &ReplayAugmentationConfig) -> Result<()> {
206 if config.mode != ReplayMode::Static && config.narrative.is_none() {
207 return Err(Error::generic(
208 "Narrative is required for augmented or generated replay modes",
209 ));
210 }
211
212 match config.strategy {
213 EventStrategy::TimeBased => {
214 if config.duration_secs.is_none() {
215 return Err(Error::generic(
216 "Duration must be specified for time-based strategy",
217 ));
218 }
219 }
220 EventStrategy::CountBased => {
221 if config.event_count.is_none() {
222 return Err(Error::generic(
223 "Event count must be specified for count-based strategy",
224 ));
225 }
226 }
227 EventStrategy::ConditionalBased => {
228 if config.conditions.is_empty() {
229 return Err(Error::generic(
230 "Conditions must be specified for conditional-based strategy",
231 ));
232 }
233 }
234 }
235
236 Ok(())
237 }
238
239 pub async fn generate_stream(&mut self) -> Result<Vec<GeneratedEvent>> {
241 match self.config.strategy {
242 EventStrategy::CountBased => self.generate_count_based().await,
243 EventStrategy::TimeBased => self.generate_time_based().await,
244 EventStrategy::ConditionalBased => self.generate_conditional_based().await,
245 }
246 }
247
248 async fn generate_count_based(&mut self) -> Result<Vec<GeneratedEvent>> {
250 let count = self.config.event_count.unwrap_or(10);
251 let mut events = Vec::with_capacity(count);
252
253 for i in 0..count {
254 let event = self.generate_single_event(i).await?;
255 events.push(event);
256
257 if let Some(rate) = self.config.event_rate {
259 if rate > 0.0 {
260 let delay_ms = (1000.0 / rate) as u64;
261 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
262 }
263 }
264 }
265
266 Ok(events)
267 }
268
269 async fn generate_time_based(&mut self) -> Result<Vec<GeneratedEvent>> {
271 let duration = Duration::from_secs(self.config.duration_secs.unwrap_or(60));
272 let rate = self.config.event_rate.unwrap_or(1.0);
273 let interval_ms = (1000.0 / rate) as u64;
274
275 let mut events = Vec::new();
276 let mut ticker = interval(Duration::from_millis(interval_ms));
277 let start = std::time::Instant::now();
278
279 let mut index = 0;
280 while start.elapsed() < duration {
281 ticker.tick().await;
282 let event = self.generate_single_event(index).await?;
283 events.push(event);
284 index += 1;
285 }
286
287 Ok(events)
288 }
289
290 async fn generate_conditional_based(&mut self) -> Result<Vec<GeneratedEvent>> {
292 let mut events = Vec::new();
293 let mut index = 0;
294 let max_events = 1000; while index < max_events {
297 let mut should_continue = true;
299 let conditions = self.config.conditions.clone(); for condition in &conditions {
302 if self.evaluate_condition(condition, &events) {
303 match &condition.action {
304 ConditionAction::GenerateEvent => {
305 let event = self.generate_single_event(index).await?;
306 events.push(event);
307 index += 1;
308 }
309 ConditionAction::Stop => {
310 should_continue = false;
311 break;
312 }
313 ConditionAction::ChangeRate(_rate) => {
314 }
316 ConditionAction::TransitionScenario(_scenario) => {
317 self.scenario_state.context.clear();
319 }
320 }
321 }
322 }
323
324 if !should_continue {
325 break;
326 }
327
328 if events.is_empty() && index > 10 {
330 break;
331 }
332
333 tokio::time::sleep(Duration::from_millis(100)).await;
334 }
335
336 Ok(events)
337 }
338
339 async fn generate_single_event(&mut self, index: usize) -> Result<GeneratedEvent> {
341 let data = match self.config.mode {
342 ReplayMode::Static => self.generate_static_event(),
343 ReplayMode::Augmented => self.generate_augmented_event(index).await?,
344 ReplayMode::Generated => self.generate_llm_event(index).await?,
345 };
346
347 self.sequence += 1;
348 self.scenario_state.events_generated += 1;
349 self.scenario_state.last_event = Some(data.clone());
350
351 Ok(GeneratedEvent::new(self.config.event_type.clone(), data, self.sequence))
352 }
353
354 fn generate_static_event(&self) -> Value {
356 if let Some(schema) = &self.config.event_schema {
357 schema.clone()
358 } else {
359 serde_json::json!({
360 "type": self.config.event_type,
361 "timestamp": chrono::Utc::now().to_rfc3339()
362 })
363 }
364 }
365
366 async fn generate_augmented_event(&mut self, index: usize) -> Result<Value> {
368 let mut base_event = self.generate_static_event();
369
370 if let Some(rag_engine) = &self.rag_engine {
371 let narrative =
372 self.config.narrative.as_ref().ok_or_else(|| {
373 Error::config("narrative is required for augmented replay mode")
374 })?;
375 let prompt = self.build_augmentation_prompt(narrative, index)?;
376
377 let enhancement = rag_engine.generate_text(&prompt).await?;
378 let enhancement_json = self.parse_json_response(&enhancement)?;
379
380 if let (Some(base_obj), Some(enhancement_obj)) =
382 (base_event.as_object_mut(), enhancement_json.as_object())
383 {
384 for (key, value) in enhancement_obj {
385 base_obj.insert(key.clone(), value.clone());
386 }
387 } else {
388 base_event = enhancement_json;
389 }
390 }
391
392 Ok(base_event)
393 }
394
395 async fn generate_llm_event(&mut self, index: usize) -> Result<Value> {
397 let rag_engine = self
398 .rag_engine
399 .as_ref()
400 .ok_or_else(|| Error::generic("RAG engine not initialized for generated mode"))?;
401
402 let narrative = self
403 .config
404 .narrative
405 .as_ref()
406 .ok_or_else(|| Error::config("narrative is required for generated replay mode"))?;
407 let prompt = self.build_generation_prompt(narrative, index)?;
408
409 let response = rag_engine.generate_text(&prompt).await?;
410 self.parse_json_response(&response)
411 }
412
413 fn build_augmentation_prompt(&self, narrative: &str, index: usize) -> Result<String> {
415 let mut prompt = format!(
416 "Enhance this event data based on the following scenario:\n\n{}\n\n",
417 narrative
418 );
419
420 prompt.push_str(&format!("Event #{} (out of ongoing stream)\n\n", index + 1));
421
422 if let Some(last_event) = &self.scenario_state.last_event {
423 prompt.push_str(&format!(
424 "Previous event:\n{}\n\n",
425 serde_json::to_string_pretty(last_event).unwrap_or_default()
426 ));
427 }
428
429 if self.config.progressive_evolution {
430 prompt.push_str("Progressively evolve the scenario with each event.\n");
431 }
432
433 if let Some(schema) = &self.config.event_schema {
434 prompt.push_str(&format!(
435 "Conform to this schema:\n{}\n\n",
436 serde_json::to_string_pretty(schema).unwrap_or_default()
437 ));
438 }
439
440 prompt.push_str("Return valid JSON only for the enhanced event data.");
441
442 Ok(prompt)
443 }
444
445 fn build_generation_prompt(&self, narrative: &str, index: usize) -> Result<String> {
447 let mut prompt =
448 format!("Generate realistic event data for this scenario:\n\n{}\n\n", narrative);
449
450 prompt.push_str(&format!("Event type: {}\n", self.config.event_type));
451 prompt.push_str(&format!("Event #{}\n\n", index + 1));
452
453 if let Some(last_event) = &self.scenario_state.last_event {
454 prompt.push_str(&format!(
455 "Previous event:\n{}\n\n",
456 serde_json::to_string_pretty(last_event).unwrap_or_default()
457 ));
458
459 if self.config.progressive_evolution {
460 prompt.push_str("Naturally evolve from the previous event.\n");
461 }
462 }
463
464 if let Some(schema) = &self.config.event_schema {
465 prompt.push_str(&format!(
466 "Conform to this schema:\n{}\n\n",
467 serde_json::to_string_pretty(schema).unwrap_or_default()
468 ));
469 }
470
471 prompt.push_str("Return valid JSON only.");
472
473 Ok(prompt)
474 }
475
476 fn parse_json_response(&self, response: &str) -> Result<Value> {
478 let trimmed = response.trim();
479
480 let json_str = if trimmed.starts_with("```json") {
482 trimmed
483 .strip_prefix("```json")
484 .and_then(|s| s.strip_suffix("```"))
485 .unwrap_or(trimmed)
486 .trim()
487 } else if trimmed.starts_with("```") {
488 trimmed
489 .strip_prefix("```")
490 .and_then(|s| s.strip_suffix("```"))
491 .unwrap_or(trimmed)
492 .trim()
493 } else {
494 trimmed
495 };
496
497 serde_json::from_str(json_str)
499 .map_err(|e| Error::generic(format!("Failed to parse LLM response as JSON: {}", e)))
500 }
501
502 fn evaluate_condition(&self, condition: &EventCondition, events: &[GeneratedEvent]) -> bool {
512 let expr = condition.expression.trim();
513
514 if expr.eq_ignore_ascii_case("true") {
516 return true;
517 }
518 if expr.eq_ignore_ascii_case("false") {
519 return false;
520 }
521
522 let parts: Vec<&str> = expr.splitn(3, ' ').collect();
524 if parts.len() != 3 {
525 tracing::warn!(
526 expression = expr,
527 "Unrecognized condition expression, defaulting to true"
528 );
529 return true;
530 }
531
532 let variable = parts[0];
533 let operator = parts[1];
534 let threshold: i64 = match parts[2].parse() {
535 Ok(v) => v,
536 Err(_) => {
537 tracing::warn!(
538 expression = expr,
539 "Could not parse threshold as integer, defaulting to true"
540 );
541 return true;
542 }
543 };
544
545 let actual: i64 = match variable {
546 "count" => events.len() as i64,
547 "sequence" => self.sequence as i64,
548 "events_generated" => self.scenario_state.events_generated as i64,
549 _ => {
550 tracing::warn!(variable, "Unknown condition variable, defaulting to true");
551 return true;
552 }
553 };
554
555 match operator {
556 "<" => actual < threshold,
557 ">" => actual > threshold,
558 "<=" => actual <= threshold,
559 ">=" => actual >= threshold,
560 "==" => actual == threshold,
561 "!=" => actual != threshold,
562 _ => {
563 tracing::warn!(operator, "Unknown comparison operator, defaulting to true");
564 true
565 }
566 }
567 }
568
569 pub fn reset(&mut self) {
571 self.sequence = 0;
572 self.scenario_state = ScenarioState::default();
573 }
574
575 pub fn sequence(&self) -> usize {
577 self.sequence
578 }
579
580 pub fn events_generated(&self) -> usize {
582 self.scenario_state.events_generated
583 }
584}
585
586pub mod scenarios {
588 use super::*;
589
590 pub fn stock_market_scenario() -> ReplayAugmentationConfig {
592 ReplayAugmentationConfig {
593 mode: ReplayMode::Generated,
594 narrative: Some(
595 "Simulate 10 minutes of live market data with realistic price movements, \
596 volume changes, and occasional volatility spikes."
597 .to_string(),
598 ),
599 event_type: "market_tick".to_string(),
600 event_schema: Some(serde_json::json!({
601 "symbol": "string",
602 "price": "number",
603 "volume": "number",
604 "timestamp": "string"
605 })),
606 strategy: EventStrategy::TimeBased,
607 duration_secs: Some(600), event_rate: Some(2.0), ..Default::default()
610 }
611 }
612
613 pub fn chat_messages_scenario() -> ReplayAugmentationConfig {
615 ReplayAugmentationConfig {
616 mode: ReplayMode::Generated,
617 narrative: Some(
618 "Simulate a group chat conversation between 3-5 users discussing a project, \
619 with natural message pacing and realistic content."
620 .to_string(),
621 ),
622 event_type: "chat_message".to_string(),
623 event_schema: Some(serde_json::json!({
624 "user_id": "string",
625 "message": "string",
626 "timestamp": "string"
627 })),
628 strategy: EventStrategy::CountBased,
629 event_count: Some(50),
630 event_rate: Some(0.5), ..Default::default()
632 }
633 }
634
635 pub fn iot_sensor_scenario() -> ReplayAugmentationConfig {
637 ReplayAugmentationConfig {
638 mode: ReplayMode::Generated,
639 narrative: Some(
640 "Simulate IoT sensor readings from a smart building with temperature, \
641 humidity, and occupancy data showing daily patterns."
642 .to_string(),
643 ),
644 event_type: "sensor_reading".to_string(),
645 event_schema: Some(serde_json::json!({
646 "sensor_id": "string",
647 "temperature": "number",
648 "humidity": "number",
649 "occupancy": "number",
650 "timestamp": "string"
651 })),
652 strategy: EventStrategy::CountBased,
653 event_count: Some(100),
654 event_rate: Some(1.0),
655 progressive_evolution: true,
656 ..Default::default()
657 }
658 }
659}
660
#[cfg(test)]
mod tests {
    use super::*;

    /// Engine for the default (static, count-based) configuration; needs no RAG backend.
    fn static_engine() -> ReplayAugmentationEngine {
        ReplayAugmentationEngine::new(ReplayAugmentationConfig::default())
            .expect("default configuration should be valid")
    }

    /// A `GenerateEvent` condition with the given expression.
    fn condition(expression: &str) -> EventCondition {
        EventCondition {
            name: "test".to_string(),
            expression: expression.to_string(),
            action: ConditionAction::GenerateEvent,
        }
    }

    #[test]
    fn test_replay_mode_default() {
        assert_eq!(ReplayMode::default(), ReplayMode::Static);
    }

    #[test]
    fn test_enums_serialize_as_snake_case() {
        assert_eq!(serde_json::to_string(&ReplayMode::Augmented).unwrap(), "\"augmented\"");
        assert_eq!(
            serde_json::to_string(&EventStrategy::ConditionalBased).unwrap(),
            "\"conditional_based\""
        );
        let strategy: EventStrategy = serde_json::from_str("\"time_based\"").unwrap();
        assert_eq!(strategy, EventStrategy::TimeBased);
    }

    #[test]
    fn test_generated_event_creation() {
        let data = serde_json::json!({"test": "value"});
        let event = GeneratedEvent::new("test_event".to_string(), data.clone(), 1)
            .with_metadata("source".to_string(), "unit".to_string());

        assert_eq!(event.event_type, "test_event");
        assert_eq!(event.sequence, 1);
        assert_eq!(event.data, data);
        assert_eq!(event.metadata.get("source").map(String::as_str), Some("unit"));
        assert!(event.to_json().unwrap().contains("\"test_event\""));
    }

    #[test]
    fn test_replay_config_validation_missing_narrative() {
        let config = ReplayAugmentationConfig {
            mode: ReplayMode::Generated,
            ..Default::default()
        };

        assert!(ReplayAugmentationEngine::validate_config(&config).is_err());
    }

    #[test]
    fn test_replay_config_validation_strategy_requirements() {
        assert!(
            ReplayAugmentationEngine::validate_config(&ReplayAugmentationConfig::default())
                .is_ok()
        );

        let time_based = ReplayAugmentationConfig {
            strategy: EventStrategy::TimeBased,
            duration_secs: None,
            ..Default::default()
        };
        assert!(ReplayAugmentationEngine::validate_config(&time_based).is_err());

        let count_based = ReplayAugmentationConfig {
            event_count: None,
            ..Default::default()
        };
        assert!(ReplayAugmentationEngine::validate_config(&count_based).is_err());

        let conditional = ReplayAugmentationConfig {
            strategy: EventStrategy::ConditionalBased,
            ..Default::default()
        };
        assert!(ReplayAugmentationEngine::validate_config(&conditional).is_err());
    }

    #[test]
    fn test_evaluate_condition() {
        let engine = static_engine();

        assert!(engine.evaluate_condition(&condition("true"), &[]));
        assert!(!engine.evaluate_condition(&condition("FALSE"), &[]));
        assert!(engine.evaluate_condition(&condition("count < 5"), &[]));
        assert!(!engine.evaluate_condition(&condition("count > 5"), &[]));
        assert!(engine.evaluate_condition(&condition("sequence == 0"), &[]));
        assert!(!engine.evaluate_condition(&condition("events_generated != 0"), &[]));
        // Unrecognized expressions, variables and thresholds fall back to true.
        assert!(engine.evaluate_condition(&condition("nonsense"), &[]));
        assert!(engine.evaluate_condition(&condition("unknown < 5"), &[]));
        assert!(engine.evaluate_condition(&condition("count < five"), &[]));
    }

    #[test]
    fn test_parse_json_response_strips_code_fences() {
        let engine = static_engine();
        let expected = serde_json::json!({"a": 1});

        assert_eq!(engine.parse_json_response("  {\"a\": 1}  ").unwrap(), expected);
        assert_eq!(engine.parse_json_response("```json\n{\"a\": 1}\n```").unwrap(), expected);
        assert_eq!(engine.parse_json_response("```\n{\"a\": 1}\n```").unwrap(), expected);
        assert!(engine.parse_json_response("not json").is_err());
    }

    #[test]
    fn test_static_event_without_schema() {
        let engine = static_engine();
        let event = engine.generate_static_event();

        assert_eq!(event["type"], "event");
        assert!(event["timestamp"].is_string());
        assert_eq!(engine.sequence(), 0);
        assert_eq!(engine.events_generated(), 0);
    }

    #[test]
    fn test_scenario_templates() {
        let stock_scenario = scenarios::stock_market_scenario();
        assert_eq!(stock_scenario.mode, ReplayMode::Generated);
        assert!(stock_scenario.narrative.is_some());
        assert!(ReplayAugmentationEngine::validate_config(&stock_scenario).is_ok());

        let chat_scenario = scenarios::chat_messages_scenario();
        assert_eq!(chat_scenario.event_type, "chat_message");
        assert!(ReplayAugmentationEngine::validate_config(&chat_scenario).is_ok());

        let iot_scenario = scenarios::iot_sensor_scenario();
        assert!(iot_scenario.progressive_evolution);
        assert!(ReplayAugmentationEngine::validate_config(&iot_scenario).is_ok());
    }
}