1use crate::rag::{RagConfig, RagEngine};
7use crate::{Error, Result};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::time::Duration;
11use tokio::time::interval;
12
13#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
15#[serde(rename_all = "snake_case")]
16pub enum ReplayMode {
17 #[default]
19 Static,
20 Augmented,
22 Generated,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
28#[serde(rename_all = "snake_case")]
29pub enum EventStrategy {
30 TimeBased,
32 CountBased,
34 ConditionalBased,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ReplayAugmentationConfig {
41 pub mode: ReplayMode,
43 pub narrative: Option<String>,
45 pub event_type: String,
47 pub event_schema: Option<Value>,
49 pub strategy: EventStrategy,
51 pub duration_secs: Option<u64>,
53 pub event_count: Option<usize>,
55 pub event_rate: Option<f64>,
57 pub conditions: Vec<EventCondition>,
59 pub rag_config: Option<RagConfig>,
61 pub progressive_evolution: bool,
63}
64
65impl Default for ReplayAugmentationConfig {
66 fn default() -> Self {
67 Self {
68 mode: ReplayMode::Static,
69 narrative: None,
70 event_type: "event".to_string(),
71 event_schema: None,
72 strategy: EventStrategy::CountBased,
73 duration_secs: None,
74 event_count: Some(10),
75 event_rate: Some(1.0),
76 conditions: Vec::new(),
77 rag_config: None,
78 progressive_evolution: true,
79 }
80 }
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct EventCondition {
86 pub name: String,
88 pub expression: String,
90 pub action: ConditionAction,
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
96#[serde(rename_all = "snake_case")]
97pub enum ConditionAction {
98 GenerateEvent,
100 Stop,
102 ChangeRate(u64), TransitionScenario(String),
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct GeneratedEvent {
111 pub event_type: String,
113 pub timestamp: chrono::DateTime<chrono::Utc>,
115 pub data: Value,
117 pub sequence: usize,
119 pub metadata: std::collections::HashMap<String, String>,
121}
122
123impl GeneratedEvent {
124 pub fn new(event_type: String, data: Value, sequence: usize) -> Self {
126 Self {
127 event_type,
128 timestamp: chrono::Utc::now(),
129 data,
130 sequence,
131 metadata: std::collections::HashMap::new(),
132 }
133 }
134
135 pub fn with_metadata(mut self, key: String, value: String) -> Self {
137 self.metadata.insert(key, value);
138 self
139 }
140
141 pub fn to_json(&self) -> Result<String> {
143 serde_json::to_string(self)
144 .map_err(|e| Error::generic(format!("Failed to serialize event: {}", e)))
145 }
146}
147
148pub struct ReplayAugmentationEngine {
150 config: ReplayAugmentationConfig,
152 rag_engine: Option<RagEngine>,
154 sequence: usize,
156 scenario_state: ScenarioState,
158}
159
160#[derive(Debug, Clone)]
162struct ScenarioState {
163 _current_time: std::time::Instant,
165 events_generated: usize,
167 last_event: Option<Value>,
169 context: Vec<String>,
171}
172
173impl Default for ScenarioState {
174 fn default() -> Self {
175 Self {
176 _current_time: std::time::Instant::now(),
177 events_generated: 0,
178 last_event: None,
179 context: Vec::new(),
180 }
181 }
182}
183
184impl ReplayAugmentationEngine {
185 pub fn new(config: ReplayAugmentationConfig) -> Result<Self> {
187 Self::validate_config(&config)?;
188
189 let rag_engine = if config.mode != ReplayMode::Static {
190 let rag_config = config.rag_config.clone().unwrap_or_default();
191 Some(RagEngine::new(rag_config))
192 } else {
193 None
194 };
195
196 Ok(Self {
197 config,
198 rag_engine,
199 sequence: 0,
200 scenario_state: ScenarioState::default(),
201 })
202 }
203
204 fn validate_config(config: &ReplayAugmentationConfig) -> Result<()> {
206 if config.mode != ReplayMode::Static && config.narrative.is_none() {
207 return Err(Error::generic(
208 "Narrative is required for augmented or generated replay modes",
209 ));
210 }
211
212 match config.strategy {
213 EventStrategy::TimeBased => {
214 if config.duration_secs.is_none() {
215 return Err(Error::generic(
216 "Duration must be specified for time-based strategy",
217 ));
218 }
219 }
220 EventStrategy::CountBased => {
221 if config.event_count.is_none() {
222 return Err(Error::generic(
223 "Event count must be specified for count-based strategy",
224 ));
225 }
226 }
227 EventStrategy::ConditionalBased => {
228 if config.conditions.is_empty() {
229 return Err(Error::generic(
230 "Conditions must be specified for conditional-based strategy",
231 ));
232 }
233 }
234 }
235
236 Ok(())
237 }
238
239 pub async fn generate_stream(&mut self) -> Result<Vec<GeneratedEvent>> {
241 match self.config.strategy {
242 EventStrategy::CountBased => self.generate_count_based().await,
243 EventStrategy::TimeBased => self.generate_time_based().await,
244 EventStrategy::ConditionalBased => self.generate_conditional_based().await,
245 }
246 }
247
248 async fn generate_count_based(&mut self) -> Result<Vec<GeneratedEvent>> {
250 let count = self.config.event_count.unwrap_or(10);
251 let mut events = Vec::with_capacity(count);
252
253 for i in 0..count {
254 let event = self.generate_single_event(i).await?;
255 events.push(event);
256
257 if let Some(rate) = self.config.event_rate {
259 if rate > 0.0 {
260 let delay_ms = (1000.0 / rate) as u64;
261 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
262 }
263 }
264 }
265
266 Ok(events)
267 }
268
269 async fn generate_time_based(&mut self) -> Result<Vec<GeneratedEvent>> {
271 let duration = Duration::from_secs(self.config.duration_secs.unwrap_or(60));
272 let rate = self.config.event_rate.unwrap_or(1.0);
273 let interval_ms = (1000.0 / rate) as u64;
274
275 let mut events = Vec::new();
276 let mut ticker = interval(Duration::from_millis(interval_ms));
277 let start = std::time::Instant::now();
278
279 let mut index = 0;
280 while start.elapsed() < duration {
281 ticker.tick().await;
282 let event = self.generate_single_event(index).await?;
283 events.push(event);
284 index += 1;
285 }
286
287 Ok(events)
288 }
289
290 async fn generate_conditional_based(&mut self) -> Result<Vec<GeneratedEvent>> {
292 let mut events = Vec::new();
293 let mut index = 0;
294 let max_events = 1000; while index < max_events {
297 let mut should_continue = true;
299 let conditions = self.config.conditions.clone(); for condition in &conditions {
302 if self.evaluate_condition(condition, &events) {
303 match &condition.action {
304 ConditionAction::GenerateEvent => {
305 let event = self.generate_single_event(index).await?;
306 events.push(event);
307 index += 1;
308 }
309 ConditionAction::Stop => {
310 should_continue = false;
311 break;
312 }
313 ConditionAction::ChangeRate(_rate) => {
314 }
316 ConditionAction::TransitionScenario(_scenario) => {
317 self.scenario_state.context.clear();
319 }
320 }
321 }
322 }
323
324 if !should_continue {
325 break;
326 }
327
328 if events.is_empty() && index > 10 {
330 break;
331 }
332
333 tokio::time::sleep(Duration::from_millis(100)).await;
334 }
335
336 Ok(events)
337 }
338
339 async fn generate_single_event(&mut self, index: usize) -> Result<GeneratedEvent> {
341 let data = match self.config.mode {
342 ReplayMode::Static => self.generate_static_event(),
343 ReplayMode::Augmented => self.generate_augmented_event(index).await?,
344 ReplayMode::Generated => self.generate_llm_event(index).await?,
345 };
346
347 self.sequence += 1;
348 self.scenario_state.events_generated += 1;
349 self.scenario_state.last_event = Some(data.clone());
350
351 Ok(GeneratedEvent::new(self.config.event_type.clone(), data, self.sequence))
352 }
353
354 fn generate_static_event(&self) -> Value {
356 if let Some(schema) = &self.config.event_schema {
357 schema.clone()
358 } else {
359 serde_json::json!({
360 "type": self.config.event_type,
361 "timestamp": chrono::Utc::now().to_rfc3339()
362 })
363 }
364 }
365
366 async fn generate_augmented_event(&mut self, index: usize) -> Result<Value> {
368 let mut base_event = self.generate_static_event();
369
370 if let Some(rag_engine) = &self.rag_engine {
371 let narrative = self.config.narrative.as_ref().unwrap();
372 let prompt = self.build_augmentation_prompt(narrative, index)?;
373
374 let enhancement = rag_engine.generate_text(&prompt).await?;
375 let enhancement_json = self.parse_json_response(&enhancement)?;
376
377 if let (Some(base_obj), Some(enhancement_obj)) =
379 (base_event.as_object_mut(), enhancement_json.as_object())
380 {
381 for (key, value) in enhancement_obj {
382 base_obj.insert(key.clone(), value.clone());
383 }
384 } else {
385 base_event = enhancement_json;
386 }
387 }
388
389 Ok(base_event)
390 }
391
392 async fn generate_llm_event(&mut self, index: usize) -> Result<Value> {
394 let rag_engine = self
395 .rag_engine
396 .as_ref()
397 .ok_or_else(|| Error::generic("RAG engine not initialized for generated mode"))?;
398
399 let narrative = self.config.narrative.as_ref().unwrap();
400 let prompt = self.build_generation_prompt(narrative, index)?;
401
402 let response = rag_engine.generate_text(&prompt).await?;
403 self.parse_json_response(&response)
404 }
405
406 fn build_augmentation_prompt(&self, narrative: &str, index: usize) -> Result<String> {
408 let mut prompt = format!(
409 "Enhance this event data based on the following scenario:\n\n{}\n\n",
410 narrative
411 );
412
413 prompt.push_str(&format!("Event #{} (out of ongoing stream)\n\n", index + 1));
414
415 if let Some(last_event) = &self.scenario_state.last_event {
416 prompt.push_str(&format!(
417 "Previous event:\n{}\n\n",
418 serde_json::to_string_pretty(last_event).unwrap_or_default()
419 ));
420 }
421
422 if self.config.progressive_evolution {
423 prompt.push_str("Progressively evolve the scenario with each event.\n");
424 }
425
426 if let Some(schema) = &self.config.event_schema {
427 prompt.push_str(&format!(
428 "Conform to this schema:\n{}\n\n",
429 serde_json::to_string_pretty(schema).unwrap_or_default()
430 ));
431 }
432
433 prompt.push_str("Return valid JSON only for the enhanced event data.");
434
435 Ok(prompt)
436 }
437
438 fn build_generation_prompt(&self, narrative: &str, index: usize) -> Result<String> {
440 let mut prompt =
441 format!("Generate realistic event data for this scenario:\n\n{}\n\n", narrative);
442
443 prompt.push_str(&format!("Event type: {}\n", self.config.event_type));
444 prompt.push_str(&format!("Event #{}\n\n", index + 1));
445
446 if let Some(last_event) = &self.scenario_state.last_event {
447 prompt.push_str(&format!(
448 "Previous event:\n{}\n\n",
449 serde_json::to_string_pretty(last_event).unwrap_or_default()
450 ));
451
452 if self.config.progressive_evolution {
453 prompt.push_str("Naturally evolve from the previous event.\n");
454 }
455 }
456
457 if let Some(schema) = &self.config.event_schema {
458 prompt.push_str(&format!(
459 "Conform to this schema:\n{}\n\n",
460 serde_json::to_string_pretty(schema).unwrap_or_default()
461 ));
462 }
463
464 prompt.push_str("Return valid JSON only.");
465
466 Ok(prompt)
467 }
468
469 fn parse_json_response(&self, response: &str) -> Result<Value> {
471 let trimmed = response.trim();
472
473 let json_str = if trimmed.starts_with("```json") {
475 trimmed
476 .strip_prefix("```json")
477 .and_then(|s| s.strip_suffix("```"))
478 .unwrap_or(trimmed)
479 .trim()
480 } else if trimmed.starts_with("```") {
481 trimmed
482 .strip_prefix("```")
483 .and_then(|s| s.strip_suffix("```"))
484 .unwrap_or(trimmed)
485 .trim()
486 } else {
487 trimmed
488 };
489
490 serde_json::from_str(json_str)
492 .map_err(|e| Error::generic(format!("Failed to parse LLM response as JSON: {}", e)))
493 }
494
495 fn evaluate_condition(&self, _condition: &EventCondition, events: &[GeneratedEvent]) -> bool {
497 events.len() < 100 }
501
502 pub fn reset(&mut self) {
504 self.sequence = 0;
505 self.scenario_state = ScenarioState::default();
506 }
507
508 pub fn sequence(&self) -> usize {
510 self.sequence
511 }
512
513 pub fn events_generated(&self) -> usize {
515 self.scenario_state.events_generated
516 }
517}
518
519pub mod scenarios {
521 use super::*;
522
523 pub fn stock_market_scenario() -> ReplayAugmentationConfig {
525 ReplayAugmentationConfig {
526 mode: ReplayMode::Generated,
527 narrative: Some(
528 "Simulate 10 minutes of live market data with realistic price movements, \
529 volume changes, and occasional volatility spikes."
530 .to_string(),
531 ),
532 event_type: "market_tick".to_string(),
533 event_schema: Some(serde_json::json!({
534 "symbol": "string",
535 "price": "number",
536 "volume": "number",
537 "timestamp": "string"
538 })),
539 strategy: EventStrategy::TimeBased,
540 duration_secs: Some(600), event_rate: Some(2.0), ..Default::default()
543 }
544 }
545
546 pub fn chat_messages_scenario() -> ReplayAugmentationConfig {
548 ReplayAugmentationConfig {
549 mode: ReplayMode::Generated,
550 narrative: Some(
551 "Simulate a group chat conversation between 3-5 users discussing a project, \
552 with natural message pacing and realistic content."
553 .to_string(),
554 ),
555 event_type: "chat_message".to_string(),
556 event_schema: Some(serde_json::json!({
557 "user_id": "string",
558 "message": "string",
559 "timestamp": "string"
560 })),
561 strategy: EventStrategy::CountBased,
562 event_count: Some(50),
563 event_rate: Some(0.5), ..Default::default()
565 }
566 }
567
568 pub fn iot_sensor_scenario() -> ReplayAugmentationConfig {
570 ReplayAugmentationConfig {
571 mode: ReplayMode::Generated,
572 narrative: Some(
573 "Simulate IoT sensor readings from a smart building with temperature, \
574 humidity, and occupancy data showing daily patterns."
575 .to_string(),
576 ),
577 event_type: "sensor_reading".to_string(),
578 event_schema: Some(serde_json::json!({
579 "sensor_id": "string",
580 "temperature": "number",
581 "humidity": "number",
582 "occupancy": "number",
583 "timestamp": "string"
584 })),
585 strategy: EventStrategy::CountBased,
586 event_count: Some(100),
587 event_rate: Some(1.0),
588 progressive_evolution: true,
589 ..Default::default()
590 }
591 }
592}
593
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_replay_mode_default() {
        assert_eq!(ReplayMode::default(), ReplayMode::Static);
    }

    #[test]
    fn test_replay_mode_serde_names() {
        // `#[serde(rename_all = "snake_case")]` fixes the wire names.
        let cases = [
            (ReplayMode::Static, "\"static\""),
            (ReplayMode::Augmented, "\"augmented\""),
            (ReplayMode::Generated, "\"generated\""),
        ];
        for (mode, wire) in cases.iter() {
            assert_eq!(serde_json::to_string(mode).unwrap(), *wire);
            let parsed: ReplayMode = serde_json::from_str(wire).unwrap();
            assert_eq!(&parsed, mode);
        }
    }

    #[test]
    fn test_event_strategy_variants() {
        // Round-trip every variant through its snake_case wire name.
        let cases = [
            (EventStrategy::TimeBased, "\"time_based\""),
            (EventStrategy::CountBased, "\"count_based\""),
            (EventStrategy::ConditionalBased, "\"conditional_based\""),
        ];
        for (strategy, wire) in cases.iter() {
            assert_eq!(serde_json::to_string(strategy).unwrap(), *wire);
            let parsed: EventStrategy = serde_json::from_str(wire).unwrap();
            assert_eq!(&parsed, strategy);
        }
    }

    #[test]
    fn test_condition_action_serde() {
        assert_eq!(
            serde_json::to_string(&ConditionAction::GenerateEvent).unwrap(),
            "\"generate_event\""
        );
        assert_eq!(
            serde_json::to_string(&ConditionAction::ChangeRate(5)).unwrap(),
            r#"{"change_rate":5}"#
        );
        let parsed: ConditionAction =
            serde_json::from_str(r#"{"transition_scenario":"next"}"#).unwrap();
        assert_eq!(parsed, ConditionAction::TransitionScenario("next".to_string()));
    }

    #[test]
    fn test_generated_event_creation() {
        let data = serde_json::json!({"test": "value"});
        let event = GeneratedEvent::new("test_event".to_string(), data.clone(), 1)
            .with_metadata("source".to_string(), "unit".to_string());

        assert_eq!(event.event_type, "test_event");
        assert_eq!(event.sequence, 1);
        assert_eq!(event.data, data);
        assert_eq!(event.metadata.get("source").map(String::as_str), Some("unit"));

        let text = match event.to_json() {
            Ok(text) => text,
            Err(_) => panic!("GeneratedEvent should serialize"),
        };
        let json: Value = serde_json::from_str(&text).unwrap();
        assert_eq!(json["event_type"], "test_event");
        assert_eq!(json["sequence"], 1);
        assert_eq!(json["metadata"]["source"], "unit");
    }

    #[test]
    fn test_default_config_is_valid() {
        let config = ReplayAugmentationConfig::default();
        assert!(ReplayAugmentationEngine::validate_config(&config).is_ok());
    }

    #[test]
    fn test_replay_config_validation_missing_narrative() {
        for mode in [ReplayMode::Augmented, ReplayMode::Generated].iter() {
            let config = ReplayAugmentationConfig {
                mode: mode.clone(),
                ..Default::default()
            };
            assert!(ReplayAugmentationEngine::validate_config(&config).is_err());
        }
    }

    #[test]
    fn test_replay_config_validation_strategy_requirements() {
        let time_based = ReplayAugmentationConfig {
            strategy: EventStrategy::TimeBased,
            duration_secs: None,
            ..Default::default()
        };
        assert!(ReplayAugmentationEngine::validate_config(&time_based).is_err());

        let time_based_ok = ReplayAugmentationConfig {
            strategy: EventStrategy::TimeBased,
            duration_secs: Some(5),
            ..Default::default()
        };
        assert!(ReplayAugmentationEngine::validate_config(&time_based_ok).is_ok());

        let count_based = ReplayAugmentationConfig {
            strategy: EventStrategy::CountBased,
            event_count: None,
            ..Default::default()
        };
        assert!(ReplayAugmentationEngine::validate_config(&count_based).is_err());

        let conditional = ReplayAugmentationConfig {
            strategy: EventStrategy::ConditionalBased,
            ..Default::default()
        };
        assert!(ReplayAugmentationEngine::validate_config(&conditional).is_err());

        let conditional_ok = ReplayAugmentationConfig {
            strategy: EventStrategy::ConditionalBased,
            conditions: vec![EventCondition {
                name: "always".to_string(),
                expression: "true".to_string(),
                action: ConditionAction::GenerateEvent,
            }],
            ..Default::default()
        };
        assert!(ReplayAugmentationEngine::validate_config(&conditional_ok).is_ok());
    }

    #[test]
    fn test_scenario_templates() {
        let stock_scenario = scenarios::stock_market_scenario();
        assert_eq!(stock_scenario.mode, ReplayMode::Generated);
        assert!(stock_scenario.narrative.is_some());
        assert!(ReplayAugmentationEngine::validate_config(&stock_scenario).is_ok());

        let chat_scenario = scenarios::chat_messages_scenario();
        assert_eq!(chat_scenario.event_type, "chat_message");
        assert!(ReplayAugmentationEngine::validate_config(&chat_scenario).is_ok());

        let iot_scenario = scenarios::iot_sensor_scenario();
        assert!(iot_scenario.progressive_evolution);
        assert!(ReplayAugmentationEngine::validate_config(&iot_scenario).is_ok());
    }
}