1use crate::rag::{RagConfig, RagEngine};
7use crate::{Error, Result};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::time::Duration;
11use tokio::time::interval;
12
/// How each event payload is produced during replay.
///
/// Serialized in `snake_case` (`"static"`, `"augmented"`, `"generated"`).
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ReplayMode {
    /// Emit the configured `event_schema` value as-is, or a minimal object with `type` and
    /// `timestamp` when no schema is set. No RAG engine is created. This is the default.
    #[default]
    Static,
    /// Start from the static payload and merge in the object fields returned by the RAG engine.
    Augmented,
    /// Ask the RAG engine to generate every payload from the configured narrative.
    Generated,
}
25
/// Decides how many events a stream contains and when it ends.
///
/// Serialized in `snake_case` (`"time_based"`, `"count_based"`, `"conditional_based"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum EventStrategy {
    /// Emit one event per tick (`event_rate` ticks per second) until `duration_secs` elapse.
    TimeBased,
    /// Emit `event_count` events, optionally paced by `event_rate`.
    CountBased,
    /// Repeatedly evaluate `conditions` and perform the actions of those that hold.
    ConditionalBased,
}
37
/// Configuration for a [`ReplayAugmentationEngine`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplayAugmentationConfig {
    /// How each event payload is produced.
    pub mode: ReplayMode,
    /// Scenario description used in RAG prompts; required unless `mode` is `ReplayMode::Static`.
    pub narrative: Option<String>,
    /// Type label copied onto every `GeneratedEvent` the engine emits.
    pub event_type: String,
    /// Payload returned verbatim in static mode, and schema hint included in RAG prompts.
    pub event_schema: Option<Value>,
    /// How many events are emitted and when the stream stops.
    pub strategy: EventStrategy,
    /// Stream length in seconds; required by `EventStrategy::TimeBased`.
    pub duration_secs: Option<u64>,
    /// Number of events to emit; required by `EventStrategy::CountBased`.
    pub event_count: Option<usize>,
    /// Events per second. In count-based mode a non-positive value disables pacing; in
    /// time-based mode `None` means 1.0.
    pub event_rate: Option<f64>,
    /// Rules evaluated by `EventStrategy::ConditionalBased`; must be non-empty for that strategy.
    pub conditions: Vec<EventCondition>,
    /// RAG settings for non-static modes; `RagConfig::default()` is used when `None`.
    pub rag_config: Option<RagConfig>,
    /// Whether prompts ask the model to evolve the scenario from one event to the next.
    pub progressive_evolution: bool,
}
64
65impl Default for ReplayAugmentationConfig {
66 fn default() -> Self {
67 Self {
68 mode: ReplayMode::Static,
69 narrative: None,
70 event_type: "event".to_string(),
71 event_schema: None,
72 strategy: EventStrategy::CountBased,
73 duration_secs: None,
74 event_count: Some(10),
75 event_rate: Some(1.0),
76 conditions: Vec::new(),
77 rag_config: None,
78 progressive_evolution: true,
79 }
80 }
81}
82
/// A rule checked on every pass of the conditional strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EventCondition {
    /// Identifier for the condition; the engine does not read it.
    pub name: String,
    /// Either `true` / `false` (case-insensitive) or `<variable> <op> <integer>`, where
    /// `<variable>` is `count`, `sequence` or `events_generated` and `<op>` is one of
    /// `<`, `>`, `<=`, `>=`, `==`, `!=`. Unrecognized expressions log a warning and are
    /// treated as true.
    pub expression: String,
    /// What the engine does when `expression` holds.
    pub action: ConditionAction,
}
93
/// Action performed when an [`EventCondition`] evaluates to true.
///
/// Serialized in `snake_case`; data-carrying variants are externally tagged
/// (e.g. `{"change_rate": 5}`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ConditionAction {
    /// Generate one event.
    GenerateEvent,
    /// End the stream.
    Stop,
    /// Requested new rate. The engine currently accepts but ignores this action.
    ChangeRate(u64),
    /// Switch to the named scenario. The engine currently only clears its internal context.
    TransitionScenario(String),
}
107
/// A single event produced by the engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratedEvent {
    /// Type label; the engine uses the configured `event_type`.
    pub event_type: String,
    /// Creation time in UTC, taken when the event is constructed.
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Event payload.
    pub data: Value,
    /// Position in the stream; the engine numbers events starting at 1.
    pub sequence: usize,
    /// Free-form string metadata added via `with_metadata`.
    pub metadata: std::collections::HashMap<String, String>,
}
122
123impl GeneratedEvent {
124 pub fn new(event_type: String, data: Value, sequence: usize) -> Self {
126 Self {
127 event_type,
128 timestamp: chrono::Utc::now(),
129 data,
130 sequence,
131 metadata: std::collections::HashMap::new(),
132 }
133 }
134
135 pub fn with_metadata(mut self, key: String, value: String) -> Self {
137 self.metadata.insert(key, value);
138 self
139 }
140
141 pub fn to_json(&self) -> Result<String> {
143 serde_json::to_string(self)
144 .map_err(|e| Error::generic(format!("Failed to serialize event: {}", e)))
145 }
146}
147
/// Produces streams of [`GeneratedEvent`]s as described by a [`ReplayAugmentationConfig`].
pub struct ReplayAugmentationEngine {
    /// Configuration, validated when the engine is constructed.
    config: ReplayAugmentationConfig,
    /// RAG backend; created for every mode except `ReplayMode::Static`.
    rag_engine: Option<RagEngine>,
    /// Sequence number of the most recent event (0 before the first event or after a reset).
    sequence: usize,
    /// Running per-stream state used for prompts and condition evaluation.
    scenario_state: ScenarioState,
}
159
/// Mutable bookkeeping for the stream currently being generated.
#[derive(Debug, Clone)]
struct ScenarioState {
    /// Time the state was created; not read anywhere (hence the leading underscore).
    _current_time: std::time::Instant,
    /// Number of events produced since creation or the last reset.
    events_generated: usize,
    /// Payload of the most recent event, echoed into the next prompt.
    last_event: Option<Value>,
    /// Scenario context. Cleared on `TransitionScenario`; no reader of it is visible in this file.
    context: Vec<String>,
}
172
173impl Default for ScenarioState {
174 fn default() -> Self {
175 Self {
176 _current_time: std::time::Instant::now(),
177 events_generated: 0,
178 last_event: None,
179 context: Vec::new(),
180 }
181 }
182}
183
184impl ReplayAugmentationEngine {
185 pub fn new(config: ReplayAugmentationConfig) -> Result<Self> {
187 Self::validate_config(&config)?;
188
189 let rag_engine = if config.mode != ReplayMode::Static {
190 let rag_config = config.rag_config.clone().unwrap_or_default();
191 Some(RagEngine::new(rag_config))
192 } else {
193 None
194 };
195
196 Ok(Self {
197 config,
198 rag_engine,
199 sequence: 0,
200 scenario_state: ScenarioState::default(),
201 })
202 }
203
204 fn validate_config(config: &ReplayAugmentationConfig) -> Result<()> {
206 if config.mode != ReplayMode::Static && config.narrative.is_none() {
207 return Err(Error::generic(
208 "Narrative is required for augmented or generated replay modes",
209 ));
210 }
211
212 match config.strategy {
213 EventStrategy::TimeBased => {
214 if config.duration_secs.is_none() {
215 return Err(Error::generic(
216 "Duration must be specified for time-based strategy",
217 ));
218 }
219 }
220 EventStrategy::CountBased => {
221 if config.event_count.is_none() {
222 return Err(Error::generic(
223 "Event count must be specified for count-based strategy",
224 ));
225 }
226 }
227 EventStrategy::ConditionalBased => {
228 if config.conditions.is_empty() {
229 return Err(Error::generic(
230 "Conditions must be specified for conditional-based strategy",
231 ));
232 }
233 }
234 }
235
236 Ok(())
237 }
238
239 pub async fn generate_stream(&mut self) -> Result<Vec<GeneratedEvent>> {
241 match self.config.strategy {
242 EventStrategy::CountBased => self.generate_count_based().await,
243 EventStrategy::TimeBased => self.generate_time_based().await,
244 EventStrategy::ConditionalBased => self.generate_conditional_based().await,
245 }
246 }
247
248 async fn generate_count_based(&mut self) -> Result<Vec<GeneratedEvent>> {
250 let count = self.config.event_count.unwrap_or(10);
251 let mut events = Vec::with_capacity(count);
252
253 for i in 0..count {
254 let event = self.generate_single_event(i).await?;
255 events.push(event);
256
257 if let Some(rate) = self.config.event_rate {
259 if rate > 0.0 {
260 let delay_ms = (1000.0 / rate) as u64;
261 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
262 }
263 }
264 }
265
266 Ok(events)
267 }
268
269 async fn generate_time_based(&mut self) -> Result<Vec<GeneratedEvent>> {
271 let duration = Duration::from_secs(self.config.duration_secs.unwrap_or(60));
272 let rate = self.config.event_rate.unwrap_or(1.0);
273 let interval_ms = (1000.0 / rate) as u64;
274
275 let mut events = Vec::new();
276 let mut ticker = interval(Duration::from_millis(interval_ms));
277 let start = std::time::Instant::now();
278
279 let mut index = 0;
280 while start.elapsed() < duration {
281 ticker.tick().await;
282 let event = self.generate_single_event(index).await?;
283 events.push(event);
284 index += 1;
285 }
286
287 Ok(events)
288 }
289
290 async fn generate_conditional_based(&mut self) -> Result<Vec<GeneratedEvent>> {
292 let mut events = Vec::new();
293 let mut index = 0;
294 let max_events = 1000; while index < max_events {
297 let mut should_continue = true;
299 let conditions = self.config.conditions.clone(); for condition in &conditions {
302 if self.evaluate_condition(condition, &events) {
303 match &condition.action {
304 ConditionAction::GenerateEvent => {
305 let event = self.generate_single_event(index).await?;
306 events.push(event);
307 index += 1;
308 }
309 ConditionAction::Stop => {
310 should_continue = false;
311 break;
312 }
313 ConditionAction::ChangeRate(_rate) => {
314 }
316 ConditionAction::TransitionScenario(_scenario) => {
317 self.scenario_state.context.clear();
319 }
320 }
321 }
322 }
323
324 if !should_continue {
325 break;
326 }
327
328 if events.is_empty() && index > 10 {
330 break;
331 }
332
333 tokio::time::sleep(Duration::from_millis(100)).await;
334 }
335
336 Ok(events)
337 }
338
339 async fn generate_single_event(&mut self, index: usize) -> Result<GeneratedEvent> {
341 let data = match self.config.mode {
342 ReplayMode::Static => self.generate_static_event(),
343 ReplayMode::Augmented => self.generate_augmented_event(index).await?,
344 ReplayMode::Generated => self.generate_llm_event(index).await?,
345 };
346
347 self.sequence += 1;
348 self.scenario_state.events_generated += 1;
349 self.scenario_state.last_event = Some(data.clone());
350
351 Ok(GeneratedEvent::new(self.config.event_type.clone(), data, self.sequence))
352 }
353
354 fn generate_static_event(&self) -> Value {
356 if let Some(schema) = &self.config.event_schema {
357 schema.clone()
358 } else {
359 serde_json::json!({
360 "type": self.config.event_type,
361 "timestamp": chrono::Utc::now().to_rfc3339()
362 })
363 }
364 }
365
366 async fn generate_augmented_event(&mut self, index: usize) -> Result<Value> {
368 let mut base_event = self.generate_static_event();
369
370 if let Some(rag_engine) = &self.rag_engine {
371 let narrative = self.config.narrative.as_ref().unwrap();
372 let prompt = self.build_augmentation_prompt(narrative, index)?;
373
374 let enhancement = rag_engine.generate_text(&prompt).await?;
375 let enhancement_json = self.parse_json_response(&enhancement)?;
376
377 if let (Some(base_obj), Some(enhancement_obj)) =
379 (base_event.as_object_mut(), enhancement_json.as_object())
380 {
381 for (key, value) in enhancement_obj {
382 base_obj.insert(key.clone(), value.clone());
383 }
384 } else {
385 base_event = enhancement_json;
386 }
387 }
388
389 Ok(base_event)
390 }
391
392 async fn generate_llm_event(&mut self, index: usize) -> Result<Value> {
394 let rag_engine = self
395 .rag_engine
396 .as_ref()
397 .ok_or_else(|| Error::generic("RAG engine not initialized for generated mode"))?;
398
399 let narrative = self.config.narrative.as_ref().unwrap();
400 let prompt = self.build_generation_prompt(narrative, index)?;
401
402 let response = rag_engine.generate_text(&prompt).await?;
403 self.parse_json_response(&response)
404 }
405
406 fn build_augmentation_prompt(&self, narrative: &str, index: usize) -> Result<String> {
408 let mut prompt = format!(
409 "Enhance this event data based on the following scenario:\n\n{}\n\n",
410 narrative
411 );
412
413 prompt.push_str(&format!("Event #{} (out of ongoing stream)\n\n", index + 1));
414
415 if let Some(last_event) = &self.scenario_state.last_event {
416 prompt.push_str(&format!(
417 "Previous event:\n{}\n\n",
418 serde_json::to_string_pretty(last_event).unwrap_or_default()
419 ));
420 }
421
422 if self.config.progressive_evolution {
423 prompt.push_str("Progressively evolve the scenario with each event.\n");
424 }
425
426 if let Some(schema) = &self.config.event_schema {
427 prompt.push_str(&format!(
428 "Conform to this schema:\n{}\n\n",
429 serde_json::to_string_pretty(schema).unwrap_or_default()
430 ));
431 }
432
433 prompt.push_str("Return valid JSON only for the enhanced event data.");
434
435 Ok(prompt)
436 }
437
438 fn build_generation_prompt(&self, narrative: &str, index: usize) -> Result<String> {
440 let mut prompt =
441 format!("Generate realistic event data for this scenario:\n\n{}\n\n", narrative);
442
443 prompt.push_str(&format!("Event type: {}\n", self.config.event_type));
444 prompt.push_str(&format!("Event #{}\n\n", index + 1));
445
446 if let Some(last_event) = &self.scenario_state.last_event {
447 prompt.push_str(&format!(
448 "Previous event:\n{}\n\n",
449 serde_json::to_string_pretty(last_event).unwrap_or_default()
450 ));
451
452 if self.config.progressive_evolution {
453 prompt.push_str("Naturally evolve from the previous event.\n");
454 }
455 }
456
457 if let Some(schema) = &self.config.event_schema {
458 prompt.push_str(&format!(
459 "Conform to this schema:\n{}\n\n",
460 serde_json::to_string_pretty(schema).unwrap_or_default()
461 ));
462 }
463
464 prompt.push_str("Return valid JSON only.");
465
466 Ok(prompt)
467 }
468
469 fn parse_json_response(&self, response: &str) -> Result<Value> {
471 let trimmed = response.trim();
472
473 let json_str = if trimmed.starts_with("```json") {
475 trimmed
476 .strip_prefix("```json")
477 .and_then(|s| s.strip_suffix("```"))
478 .unwrap_or(trimmed)
479 .trim()
480 } else if trimmed.starts_with("```") {
481 trimmed
482 .strip_prefix("```")
483 .and_then(|s| s.strip_suffix("```"))
484 .unwrap_or(trimmed)
485 .trim()
486 } else {
487 trimmed
488 };
489
490 serde_json::from_str(json_str)
492 .map_err(|e| Error::generic(format!("Failed to parse LLM response as JSON: {}", e)))
493 }
494
495 fn evaluate_condition(&self, condition: &EventCondition, events: &[GeneratedEvent]) -> bool {
505 let expr = condition.expression.trim();
506
507 if expr.eq_ignore_ascii_case("true") {
509 return true;
510 }
511 if expr.eq_ignore_ascii_case("false") {
512 return false;
513 }
514
515 let parts: Vec<&str> = expr.splitn(3, ' ').collect();
517 if parts.len() != 3 {
518 tracing::warn!(
519 expression = expr,
520 "Unrecognized condition expression, defaulting to true"
521 );
522 return true;
523 }
524
525 let variable = parts[0];
526 let operator = parts[1];
527 let threshold: i64 = match parts[2].parse() {
528 Ok(v) => v,
529 Err(_) => {
530 tracing::warn!(
531 expression = expr,
532 "Could not parse threshold as integer, defaulting to true"
533 );
534 return true;
535 }
536 };
537
538 let actual: i64 = match variable {
539 "count" => events.len() as i64,
540 "sequence" => self.sequence as i64,
541 "events_generated" => self.scenario_state.events_generated as i64,
542 _ => {
543 tracing::warn!(variable, "Unknown condition variable, defaulting to true");
544 return true;
545 }
546 };
547
548 match operator {
549 "<" => actual < threshold,
550 ">" => actual > threshold,
551 "<=" => actual <= threshold,
552 ">=" => actual >= threshold,
553 "==" => actual == threshold,
554 "!=" => actual != threshold,
555 _ => {
556 tracing::warn!(operator, "Unknown comparison operator, defaulting to true");
557 true
558 }
559 }
560 }
561
562 pub fn reset(&mut self) {
564 self.sequence = 0;
565 self.scenario_state = ScenarioState::default();
566 }
567
568 pub fn sequence(&self) -> usize {
570 self.sequence
571 }
572
573 pub fn events_generated(&self) -> usize {
575 self.scenario_state.events_generated
576 }
577}
578
579pub mod scenarios {
581 use super::*;
582
583 pub fn stock_market_scenario() -> ReplayAugmentationConfig {
585 ReplayAugmentationConfig {
586 mode: ReplayMode::Generated,
587 narrative: Some(
588 "Simulate 10 minutes of live market data with realistic price movements, \
589 volume changes, and occasional volatility spikes."
590 .to_string(),
591 ),
592 event_type: "market_tick".to_string(),
593 event_schema: Some(serde_json::json!({
594 "symbol": "string",
595 "price": "number",
596 "volume": "number",
597 "timestamp": "string"
598 })),
599 strategy: EventStrategy::TimeBased,
600 duration_secs: Some(600), event_rate: Some(2.0), ..Default::default()
603 }
604 }
605
606 pub fn chat_messages_scenario() -> ReplayAugmentationConfig {
608 ReplayAugmentationConfig {
609 mode: ReplayMode::Generated,
610 narrative: Some(
611 "Simulate a group chat conversation between 3-5 users discussing a project, \
612 with natural message pacing and realistic content."
613 .to_string(),
614 ),
615 event_type: "chat_message".to_string(),
616 event_schema: Some(serde_json::json!({
617 "user_id": "string",
618 "message": "string",
619 "timestamp": "string"
620 })),
621 strategy: EventStrategy::CountBased,
622 event_count: Some(50),
623 event_rate: Some(0.5), ..Default::default()
625 }
626 }
627
628 pub fn iot_sensor_scenario() -> ReplayAugmentationConfig {
630 ReplayAugmentationConfig {
631 mode: ReplayMode::Generated,
632 narrative: Some(
633 "Simulate IoT sensor readings from a smart building with temperature, \
634 humidity, and occupancy data showing daily patterns."
635 .to_string(),
636 ),
637 event_type: "sensor_reading".to_string(),
638 event_schema: Some(serde_json::json!({
639 "sensor_id": "string",
640 "temperature": "number",
641 "humidity": "number",
642 "occupancy": "number",
643 "timestamp": "string"
644 })),
645 strategy: EventStrategy::CountBased,
646 event_count: Some(100),
647 event_rate: Some(1.0),
648 progressive_evolution: true,
649 ..Default::default()
650 }
651 }
652}
653
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a static-mode engine, which needs no RAG backend.
    fn static_engine() -> ReplayAugmentationEngine {
        match ReplayAugmentationEngine::new(ReplayAugmentationConfig::default()) {
            Ok(engine) => engine,
            Err(_) => panic!("the default configuration must be valid"),
        }
    }

    /// A `GenerateEvent` condition with the given expression.
    fn condition(expression: &str) -> EventCondition {
        EventCondition {
            name: "test".to_string(),
            expression: expression.to_string(),
            action: ConditionAction::GenerateEvent,
        }
    }

    #[test]
    fn test_replay_mode_default() {
        assert_eq!(ReplayMode::default(), ReplayMode::Static);
    }

    #[test]
    fn test_enums_serialize_as_snake_case() {
        assert_eq!(serde_json::to_string(&ReplayMode::Augmented).unwrap(), "\"augmented\"");
        assert_eq!(
            serde_json::to_string(&EventStrategy::ConditionalBased).unwrap(),
            "\"conditional_based\""
        );
        assert_eq!(
            serde_json::to_value(ConditionAction::ChangeRate(5)).unwrap(),
            serde_json::json!({"change_rate": 5})
        );

        let parsed: EventStrategy = serde_json::from_str("\"time_based\"").unwrap();
        assert_eq!(parsed, EventStrategy::TimeBased);
    }

    #[test]
    fn test_generated_event_creation() {
        let data = serde_json::json!({"test": "value"});
        let event = GeneratedEvent::new("test_event".to_string(), data, 1)
            .with_metadata("source".to_string(), "unit".to_string());

        assert_eq!(event.event_type, "test_event");
        assert_eq!(event.sequence, 1);
        assert_eq!(event.metadata.get("source").map(String::as_str), Some("unit"));

        let json: serde_json::Value = match event.to_json() {
            Ok(text) => serde_json::from_str(&text).expect("to_json must emit valid JSON"),
            Err(_) => panic!("serializing a GeneratedEvent should not fail"),
        };
        assert_eq!(json["data"]["test"], "value");
        assert_eq!(json["sequence"], 1);
    }

    #[test]
    fn test_replay_config_validation_missing_narrative() {
        let config = ReplayAugmentationConfig {
            mode: ReplayMode::Generated,
            ..Default::default()
        };

        assert!(ReplayAugmentationEngine::validate_config(&config).is_err());
    }

    #[test]
    fn test_replay_config_validation_strategy_requirements() {
        let default_config = ReplayAugmentationConfig::default();
        assert!(ReplayAugmentationEngine::validate_config(&default_config).is_ok());

        let time_based = ReplayAugmentationConfig {
            strategy: EventStrategy::TimeBased,
            duration_secs: None,
            ..Default::default()
        };
        assert!(ReplayAugmentationEngine::validate_config(&time_based).is_err());

        let count_based = ReplayAugmentationConfig {
            strategy: EventStrategy::CountBased,
            event_count: None,
            ..Default::default()
        };
        assert!(ReplayAugmentationEngine::validate_config(&count_based).is_err());

        let conditional = ReplayAugmentationConfig {
            strategy: EventStrategy::ConditionalBased,
            conditions: Vec::new(),
            ..Default::default()
        };
        assert!(ReplayAugmentationEngine::validate_config(&conditional).is_err());
    }

    #[test]
    fn test_evaluate_condition_comparisons() {
        let engine = static_engine();
        let events = vec![
            GeneratedEvent::new("e".to_string(), serde_json::json!({}), 1),
            GeneratedEvent::new("e".to_string(), serde_json::json!({}), 2),
        ];

        assert!(engine.evaluate_condition(&condition("true"), &events));
        assert!(!engine.evaluate_condition(&condition("FALSE"), &events));
        assert!(engine.evaluate_condition(&condition("count == 2"), &events));
        assert!(!engine.evaluate_condition(&condition("count < 2"), &events));
        assert!(engine.evaluate_condition(&condition("sequence <= 0"), &events));
        assert!(!engine.evaluate_condition(&condition("events_generated != 0"), &events));
        // Unknown variables fall back to `true` (with a warning).
        assert!(engine.evaluate_condition(&condition("unknown > 1"), &events));
    }

    #[test]
    fn test_parse_json_response_handles_fences() {
        let engine = static_engine();
        let expected = serde_json::json!({"a": 1});

        assert_eq!(engine.parse_json_response("{\"a\": 1}").ok(), Some(expected.clone()));
        assert_eq!(
            engine.parse_json_response("```json\n{\"a\": 1}\n```").ok(),
            Some(expected.clone())
        );
        assert_eq!(
            engine.parse_json_response("  ```\n{\"a\": 1}\n```  ").ok(),
            Some(expected)
        );
        assert!(engine.parse_json_response("not json").is_err());
    }

    #[test]
    fn test_scenario_templates() {
        let stock_scenario = scenarios::stock_market_scenario();
        assert_eq!(stock_scenario.mode, ReplayMode::Generated);
        assert!(stock_scenario.narrative.is_some());

        let chat_scenario = scenarios::chat_messages_scenario();
        assert_eq!(chat_scenario.event_type, "chat_message");

        let iot_scenario = scenarios::iot_sensor_scenario();
        assert!(iot_scenario.progressive_evolution);

        for scenario in [stock_scenario, chat_scenario, iot_scenario] {
            assert!(ReplayAugmentationEngine::validate_config(&scenario).is_ok());
        }
    }
}