use crate::rag::{RagConfig, RagEngine};
use crate::{Error, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;
use tokio::time::interval;
/// Selects how the payload of each replayed event is produced.
///
/// Serialized in `snake_case`; the default variant is [`ReplayMode::Static`].
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ReplayMode {
/// Payload is the configured schema verbatim, or a minimal
/// `type`/`timestamp` object when no schema is configured.
#[default]
Static,
/// Static payload overlaid with the JSON returned by the RAG engine.
Augmented,
/// Payload is taken entirely from the JSON returned by the RAG engine.
Generated,
}
/// Determines what bounds an event stream produced by the engine.
///
/// Serialized in `snake_case`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum EventStrategy {
/// Emit events on a periodic tick until `duration_secs` has elapsed.
/// Validation requires `duration_secs` to be set.
TimeBased,
/// Emit `event_count` events. Validation requires `event_count` to be set.
CountBased,
/// Let the configured `conditions` decide when to emit and when to stop.
/// Validation requires at least one condition.
ConditionalBased,
}
/// Settings that drive a `ReplayAugmentationEngine`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplayAugmentationConfig {
/// How event payloads are produced.
pub mode: ReplayMode,
/// Scenario description embedded in the prompts; required by validation
/// whenever `mode` is not `Static`.
pub narrative: Option<String>,
/// Label copied onto every generated event and used in the default static
/// payload and the generation prompt.
pub event_type: String,
/// Used verbatim as the static payload and quoted in prompts as the schema
/// to conform to.
pub event_schema: Option<Value>,
/// What bounds the generated stream.
pub strategy: EventStrategy,
/// Stream length in seconds; required for the time-based strategy.
pub duration_secs: Option<u64>,
/// Number of events to emit; required for the count-based strategy.
pub event_count: Option<usize>,
/// Events per second; the pacing delay is derived as `1000 / rate` milliseconds.
pub event_rate: Option<f64>,
/// Rules evaluated by the conditional strategy; must be non-empty for it.
pub conditions: Vec<EventCondition>,
/// Settings for the RAG engine; the type's default is used when absent.
pub rag_config: Option<RagConfig>,
/// When set, prompts ask the model to evolve the scenario from event to event.
pub progressive_evolution: bool,
}
impl Default for ReplayAugmentationConfig {
    /// Baseline configuration: static payloads, count-based pacing with ten
    /// events at one event per second, and progressive evolution switched on.
    fn default() -> Self {
        Self {
            // What to produce and how to bound it.
            mode: ReplayMode::Static,
            strategy: EventStrategy::CountBased,
            event_type: String::from("event"),
            event_count: Some(10),
            event_rate: Some(1.0),
            progressive_evolution: true,
            // Everything optional starts out unset.
            narrative: None,
            event_schema: None,
            duration_secs: None,
            rag_config: None,
            conditions: vec![],
        }
    }
}
/// A named rule evaluated on every pass of the conditional strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EventCondition {
/// Human-readable label; the evaluation code in this file does not read it.
pub name: String,
/// Either `true`/`false` (case-insensitive) or `<variable> <op> <integer>`,
/// space-separated, where the variable is `count`, `sequence` or
/// `events_generated` and the operator is one of `<`, `>`, `<=`, `>=`,
/// `==`, `!=`. Expressions that cannot be interpreted are treated as true.
pub expression: String,
/// What to do when `expression` evaluates to true.
pub action: ConditionAction,
}
/// Effect applied by the conditional strategy when a condition holds.
///
/// Serialized in `snake_case`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ConditionAction {
/// Produce one more event and append it to the stream.
GenerateEvent,
/// End the stream immediately.
Stop,
// `ChangeRate` carries a new rate but the conditional loop in this file
// currently ignores it. `TransitionScenario` carries a scenario name; the
// loop reacts by clearing the accumulated scenario context and does not
// read the name itself.
ChangeRate(u64), TransitionScenario(String),
}
/// One event emitted by the replay engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratedEvent {
/// Label identifying the kind of event.
pub event_type: String,
/// UTC instant at which the event value was constructed.
pub timestamp: chrono::DateTime<chrono::Utc>,
/// JSON payload of the event.
pub data: Value,
/// Sequence number supplied by the creator; the engine passes its running
/// counter, which is 1 for the first event after construction or reset.
pub sequence: usize,
/// Free-form string annotations; empty unless added via `with_metadata`.
pub metadata: std::collections::HashMap<String, String>,
}
impl GeneratedEvent {
pub fn new(event_type: String, data: Value, sequence: usize) -> Self {
Self {
event_type,
timestamp: chrono::Utc::now(),
data,
sequence,
metadata: std::collections::HashMap::new(),
}
}
pub fn with_metadata(mut self, key: String, value: String) -> Self {
self.metadata.insert(key, value);
self
}
pub fn to_json(&self) -> Result<String> {
serde_json::to_string(self)
.map_err(|e| Error::generic(format!("Failed to serialize event: {}", e)))
}
}
/// Produces streams of `GeneratedEvent`s according to a validated
/// `ReplayAugmentationConfig`.
pub struct ReplayAugmentationEngine {
/// Configuration, checked by `validate_config` at construction time.
config: ReplayAugmentationConfig,
/// Created only when the configured mode is not `Static`.
rag_engine: Option<RagEngine>,
/// Number of events produced since construction or the last `reset`.
sequence: usize,
/// Running state carried from one event to the next.
scenario_state: ScenarioState,
}
/// Mutable bookkeeping the engine keeps between events.
#[derive(Debug, Clone)]
struct ScenarioState {
/// Instant captured when the state was created; not read anywhere in this file.
_current_time: std::time::Instant,
/// Count of events produced since the state was created.
events_generated: usize,
/// Payload of the most recently produced event, quoted in later prompts.
last_event: Option<Value>,
/// Accumulated context lines; cleared on a scenario transition. Nothing in
/// this file appends to it.
context: Vec<String>,
}
impl Default for ScenarioState {
    /// Fresh state: creation instant captured now, no events counted, no
    /// previous event and an empty context.
    fn default() -> Self {
        let created_at = std::time::Instant::now();
        Self {
            _current_time: created_at,
            events_generated: 0,
            last_event: None,
            context: vec![],
        }
    }
}
impl ReplayAugmentationEngine {
/// Validates `config` and builds an engine around it.
///
/// A RAG engine is created (from `config.rag_config`, or that type's default
/// when absent) for every mode except `Static`.
///
/// # Errors
///
/// Propagates any error reported by `validate_config`.
pub fn new(config: ReplayAugmentationConfig) -> Result<Self> {
    Self::validate_config(&config)?;

    let rag_engine = match config.mode {
        ReplayMode::Static => None,
        ReplayMode::Augmented | ReplayMode::Generated => {
            let settings = config.rag_config.clone().unwrap_or_default();
            Some(RagEngine::new(settings))
        }
    };

    Ok(Self {
        config,
        rag_engine,
        sequence: 0,
        scenario_state: ScenarioState::default(),
    })
}
fn validate_config(config: &ReplayAugmentationConfig) -> Result<()> {
if config.mode != ReplayMode::Static && config.narrative.is_none() {
return Err(Error::generic(
"Narrative is required for augmented or generated replay modes",
));
}
match config.strategy {
EventStrategy::TimeBased => {
if config.duration_secs.is_none() {
return Err(Error::generic(
"Duration must be specified for time-based strategy",
));
}
}
EventStrategy::CountBased => {
if config.event_count.is_none() {
return Err(Error::generic(
"Event count must be specified for count-based strategy",
));
}
}
EventStrategy::ConditionalBased => {
if config.conditions.is_empty() {
return Err(Error::generic(
"Conditions must be specified for conditional-based strategy",
));
}
}
}
Ok(())
}
/// Runs the generator that matches the configured strategy and returns every
/// event it produced, in order.
///
/// # Errors
///
/// Propagates whatever error the selected generator reports.
pub async fn generate_stream(&mut self) -> Result<Vec<GeneratedEvent>> {
    let events = match self.config.strategy {
        EventStrategy::TimeBased => self.generate_time_based().await?,
        EventStrategy::ConditionalBased => self.generate_conditional_based().await?,
        EventStrategy::CountBased => self.generate_count_based().await?,
    };
    Ok(events)
}
/// Emits `event_count` events (10 when unset), pausing `1000 / event_rate`
/// milliseconds between consecutive events.
///
/// No pause is inserted when the rate is absent or not strictly positive,
/// and none after the final event, so the stream is returned as soon as it
/// is complete.
///
/// # Errors
///
/// Stops at, and returns, the first error from event generation.
async fn generate_count_based(&mut self) -> Result<Vec<GeneratedEvent>> {
    // Upper bound on the up-front allocation; the vector still grows on demand.
    const MAX_PREALLOCATED_EVENTS: usize = 1024;

    let count = self.config.event_count.unwrap_or(10);

    // The pacing delay does not change during the run, so compute it once.
    // `rate > 0.0` is false for NaN as well, which disables pacing.
    let pause = self
        .config
        .event_rate
        .filter(|rate| *rate > 0.0)
        .map(|rate| Duration::from_millis((1000.0 / rate) as u64));

    let mut events = Vec::with_capacity(count.min(MAX_PREALLOCATED_EVENTS));
    for i in 0..count {
        events.push(self.generate_single_event(i).await?);

        // Pace only between events; sleeping after the last one would just
        // delay handing the finished stream back to the caller.
        if let Some(pause) = pause {
            if i + 1 < count {
                tokio::time::sleep(pause).await;
            }
        }
    }
    Ok(events)
}
async fn generate_time_based(&mut self) -> Result<Vec<GeneratedEvent>> {
let duration = Duration::from_secs(self.config.duration_secs.unwrap_or(60));
let rate = self.config.event_rate.unwrap_or(1.0);
let interval_ms = (1000.0 / rate) as u64;
let mut events = Vec::new();
let mut ticker = interval(Duration::from_millis(interval_ms));
let start = std::time::Instant::now();
let mut index = 0;
while start.elapsed() < duration {
ticker.tick().await;
let event = self.generate_single_event(index).await?;
events.push(event);
index += 1;
}
Ok(events)
}
/// Repeatedly evaluates the configured conditions, applying the action of
/// each one that holds, until a `Stop` action fires, the event cap is
/// reached, or the rules stop making progress.
///
/// Passes are spaced 100 ms apart. `ChangeRate` is accepted but currently has
/// no effect; `TransitionScenario` clears the accumulated scenario context.
///
/// # Errors
///
/// Stops at, and returns, the first error from event generation.
async fn generate_conditional_based(&mut self) -> Result<Vec<GeneratedEvent>> {
    // Hard cap on the number of events a single run may emit.
    const MAX_EVENTS: usize = 1000;
    // Consecutive passes without a new event tolerated before giving up.
    // Without this bound a rule set that neither emits nor stops would spin
    // forever, because the event index only advances when an event is made.
    const MAX_IDLE_PASSES: usize = 10;

    // The conditions do not change during a run; clone them once so no borrow
    // of `self.config` is held across the `&mut self` generation call.
    let conditions = self.config.conditions.clone();
    let mut events = Vec::new();
    let mut index = 0;
    let mut idle_passes = 0;

    'passes: while index < MAX_EVENTS {
        let index_before_pass = index;

        for condition in &conditions {
            if !self.evaluate_condition(condition, &events) {
                continue;
            }
            match &condition.action {
                ConditionAction::GenerateEvent => {
                    events.push(self.generate_single_event(index).await?);
                    index += 1;
                    // Several emitting rules in one pass must not overshoot the cap.
                    if index >= MAX_EVENTS {
                        break 'passes;
                    }
                }
                ConditionAction::Stop => break 'passes,
                ConditionAction::ChangeRate(_rate) => {
                    // Rate changes are not implemented for this strategy yet.
                }
                ConditionAction::TransitionScenario(_scenario) => {
                    self.scenario_state.context.clear();
                }
            }
        }

        if index == index_before_pass {
            idle_passes += 1;
            if idle_passes > MAX_IDLE_PASSES {
                break;
            }
        } else {
            idle_passes = 0;
        }

        tokio::time::sleep(Duration::from_millis(100)).await;
    }
    Ok(events)
}
/// Produces the payload for the configured mode, updates the running
/// counters and the remembered last payload, and wraps the payload in a
/// `GeneratedEvent` numbered with the incremented sequence counter.
///
/// # Errors
///
/// Propagates errors from the augmented and generated payload producers; no
/// state is updated in that case.
async fn generate_single_event(&mut self, index: usize) -> Result<GeneratedEvent> {
    let payload = match self.config.mode {
        ReplayMode::Generated => self.generate_llm_event(index).await?,
        ReplayMode::Augmented => self.generate_augmented_event(index).await?,
        ReplayMode::Static => self.generate_static_event(),
    };

    self.sequence += 1;
    let state = &mut self.scenario_state;
    state.events_generated += 1;
    state.last_event = Some(payload.clone());

    let event_type = self.config.event_type.clone();
    Ok(GeneratedEvent::new(event_type, payload, self.sequence))
}
/// Returns a copy of the configured schema, or, when none is configured, an
/// object holding the event type and the current UTC time in RFC 3339 form.
fn generate_static_event(&self) -> Value {
    match &self.config.event_schema {
        Some(schema) => schema.clone(),
        None => serde_json::json!({
            "type": self.config.event_type,
            "timestamp": chrono::Utc::now().to_rfc3339()
        }),
    }
}
async fn generate_augmented_event(&mut self, index: usize) -> Result<Value> {
let mut base_event = self.generate_static_event();
if let Some(rag_engine) = &self.rag_engine {
let narrative = self.config.narrative.as_ref().unwrap();
let prompt = self.build_augmentation_prompt(narrative, index)?;
let enhancement = rag_engine.generate_text(&prompt).await?;
let enhancement_json = self.parse_json_response(&enhancement)?;
if let (Some(base_obj), Some(enhancement_obj)) =
(base_event.as_object_mut(), enhancement_json.as_object())
{
for (key, value) in enhancement_obj {
base_obj.insert(key.clone(), value.clone());
}
} else {
base_event = enhancement_json;
}
}
Ok(base_event)
}
async fn generate_llm_event(&mut self, index: usize) -> Result<Value> {
let rag_engine = self
.rag_engine
.as_ref()
.ok_or_else(|| Error::generic("RAG engine not initialized for generated mode"))?;
let narrative = self.config.narrative.as_ref().unwrap();
let prompt = self.build_generation_prompt(narrative, index)?;
let response = rag_engine.generate_text(&prompt).await?;
self.parse_json_response(&response)
}
/// Assembles the prompt that asks the model to enrich an event.
///
/// The prompt contains, in order: the narrative, the 1-based event number,
/// the previous payload (if any), the progressive-evolution instruction
/// (if enabled), the schema (if configured) and the closing JSON-only
/// instruction. This function always returns `Ok`.
fn build_augmentation_prompt(&self, narrative: &str, index: usize) -> Result<String> {
    let mut sections = vec![
        format!(
            "Enhance this event data based on the following scenario:\n\n{}\n\n",
            narrative
        ),
        format!("Event #{} (out of ongoing stream)\n\n", index + 1),
    ];

    if let Some(previous) = self.scenario_state.last_event.as_ref() {
        let rendered = serde_json::to_string_pretty(previous).unwrap_or_default();
        sections.push(format!("Previous event:\n{}\n\n", rendered));
    }

    if self.config.progressive_evolution {
        sections.push("Progressively evolve the scenario with each event.\n".to_string());
    }

    if let Some(schema) = self.config.event_schema.as_ref() {
        let rendered = serde_json::to_string_pretty(schema).unwrap_or_default();
        sections.push(format!("Conform to this schema:\n{}\n\n", rendered));
    }

    sections.push("Return valid JSON only for the enhanced event data.".to_string());
    Ok(sections.concat())
}
/// Assembles the prompt that asks the model to invent a whole event.
///
/// The prompt contains, in order: the narrative, the event type, the 1-based
/// event number, the previous payload (if any) followed by the evolution
/// instruction (only when a previous payload exists and progressive
/// evolution is enabled), the schema (if configured) and the closing
/// JSON-only instruction. This function always returns `Ok`.
fn build_generation_prompt(&self, narrative: &str, index: usize) -> Result<String> {
    let mut text = String::new();
    text += &format!("Generate realistic event data for this scenario:\n\n{}\n\n", narrative);
    text += &format!("Event type: {}\n", self.config.event_type);
    text += &format!("Event #{}\n\n", index + 1);

    if let Some(previous) = self.scenario_state.last_event.as_ref() {
        let rendered = serde_json::to_string_pretty(previous).unwrap_or_default();
        text += &format!("Previous event:\n{}\n\n", rendered);
        if self.config.progressive_evolution {
            text += "Naturally evolve from the previous event.\n";
        }
    }

    if let Some(schema) = self.config.event_schema.as_ref() {
        let rendered = serde_json::to_string_pretty(schema).unwrap_or_default();
        text += &format!("Conform to this schema:\n{}\n\n", rendered);
    }

    text += "Return valid JSON only.";
    Ok(text)
}
/// Parses a model reply as JSON, tolerating a surrounding Markdown code
/// fence.
///
/// After trimming, an opening fence of three backticks (optionally followed
/// by the `json` tag) and a closing fence of three backticks are removed
/// independently, so a reply whose closing fence is missing still parses.
///
/// # Errors
///
/// Returns a generic error carrying the parser's message when the remaining
/// text is not valid JSON.
fn parse_json_response(&self, response: &str) -> Result<Value> {
    const FENCE: &str = "```";

    let trimmed = response.trim();
    let json_str = match trimmed.strip_prefix(FENCE) {
        Some(fenced) => {
            // Drop the optional language tag that follows the opening fence.
            let body = fenced.strip_prefix("json").unwrap_or(fenced);
            // The closing fence is optional: truncated replies often lack it.
            body.strip_suffix(FENCE).unwrap_or(body).trim()
        }
        None => trimmed,
    };

    serde_json::from_str(json_str)
        .map_err(|e| Error::generic(format!("Failed to parse LLM response as JSON: {}", e)))
}
/// Evaluates a condition's expression against the current stream state.
///
/// Accepted forms are the literals `true` / `false` (case-insensitive) and
/// `<variable> <op> <integer>`, with tokens separated by any amount of
/// whitespace. Variables: `count` (events in `events`), `sequence` (the
/// engine's sequence counter) and `events_generated` (the scenario counter).
/// Operators: `<`, `>`, `<=`, `>=`, `==`, `!=`.
///
/// Anything that cannot be interpreted logs a warning and evaluates to
/// `true`.
fn evaluate_condition(&self, condition: &EventCondition, events: &[GeneratedEvent]) -> bool {
    let expr = condition.expression.trim();
    if expr.eq_ignore_ascii_case("true") {
        return true;
    }
    if expr.eq_ignore_ascii_case("false") {
        return false;
    }

    // Tokenise on whitespace runs so repeated spaces or tabs between the
    // tokens do not produce empty or padded tokens; exactly three are required.
    let mut tokens = expr.split_whitespace();
    let (variable, operator, raw_threshold) =
        match (tokens.next(), tokens.next(), tokens.next(), tokens.next()) {
            (Some(variable), Some(operator), Some(threshold), None) => {
                (variable, operator, threshold)
            }
            _ => {
                tracing::warn!(
                    expression = expr,
                    "Unrecognized condition expression, defaulting to true"
                );
                return true;
            }
        };

    let threshold: i64 = match raw_threshold.parse() {
        Ok(v) => v,
        Err(_) => {
            tracing::warn!(
                expression = expr,
                "Could not parse threshold as integer, defaulting to true"
            );
            return true;
        }
    };

    let actual: i64 = match variable {
        "count" => events.len() as i64,
        "sequence" => self.sequence as i64,
        "events_generated" => self.scenario_state.events_generated as i64,
        _ => {
            tracing::warn!(variable, "Unknown condition variable, defaulting to true");
            return true;
        }
    };

    match operator {
        "<" => actual < threshold,
        ">" => actual > threshold,
        "<=" => actual <= threshold,
        ">=" => actual >= threshold,
        "==" => actual == threshold,
        "!=" => actual != threshold,
        _ => {
            tracing::warn!(operator, "Unknown comparison operator, defaulting to true");
            true
        }
    }
}
/// Returns the engine to its freshly constructed counters: the sequence
/// counter goes back to zero and the scenario state is replaced by a new
/// default one. The configuration and the RAG engine are left untouched.
pub fn reset(&mut self) {
    self.scenario_state = ScenarioState::default();
    self.sequence = 0;
}
/// Returns the current sequence counter, i.e. the sequence number given to
/// the most recently produced event (zero before any event or after `reset`).
pub fn sequence(&self) -> usize {
self.sequence
}
/// Returns how many events the current scenario state has recorded since it
/// was created (at construction or by the last `reset`).
pub fn events_generated(&self) -> usize {
self.scenario_state.events_generated
}
}
pub mod scenarios {
use super::*;
pub fn stock_market_scenario() -> ReplayAugmentationConfig {
ReplayAugmentationConfig {
mode: ReplayMode::Generated,
narrative: Some(
"Simulate 10 minutes of live market data with realistic price movements, \
volume changes, and occasional volatility spikes."
.to_string(),
),
event_type: "market_tick".to_string(),
event_schema: Some(serde_json::json!({
"symbol": "string",
"price": "number",
"volume": "number",
"timestamp": "string"
})),
strategy: EventStrategy::TimeBased,
duration_secs: Some(600), event_rate: Some(2.0), ..Default::default()
}
}
pub fn chat_messages_scenario() -> ReplayAugmentationConfig {
ReplayAugmentationConfig {
mode: ReplayMode::Generated,
narrative: Some(
"Simulate a group chat conversation between 3-5 users discussing a project, \
with natural message pacing and realistic content."
.to_string(),
),
event_type: "chat_message".to_string(),
event_schema: Some(serde_json::json!({
"user_id": "string",
"message": "string",
"timestamp": "string"
})),
strategy: EventStrategy::CountBased,
event_count: Some(50),
event_rate: Some(0.5), ..Default::default()
}
}
pub fn iot_sensor_scenario() -> ReplayAugmentationConfig {
ReplayAugmentationConfig {
mode: ReplayMode::Generated,
narrative: Some(
"Simulate IoT sensor readings from a smart building with temperature, \
humidity, and occupancy data showing daily patterns."
.to_string(),
),
event_type: "sensor_reading".to_string(),
event_schema: Some(serde_json::json!({
"sensor_id": "string",
"temperature": "number",
"humidity": "number",
"occupancy": "number",
"timestamp": "string"
})),
strategy: EventStrategy::CountBased,
event_count: Some(100),
event_rate: Some(1.0),
progressive_evolution: true,
..Default::default()
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The derived default mode is the static one.
    #[test]
    fn test_replay_mode_default() {
        let mode = ReplayMode::default();
        assert_eq!(mode, ReplayMode::Static);
    }

    /// Every strategy variant can be constructed and recognised.
    #[test]
    fn test_event_strategy_variants() {
        let time_based = EventStrategy::TimeBased;
        let count_based = EventStrategy::CountBased;
        let conditional = EventStrategy::ConditionalBased;

        assert_eq!(time_based, EventStrategy::TimeBased);
        assert_eq!(count_based, EventStrategy::CountBased);
        assert_eq!(conditional, EventStrategy::ConditionalBased);
    }

    /// The constructor stores the given type and sequence number.
    #[test]
    fn test_generated_event_creation() {
        let payload = serde_json::json!({"test": "value"});
        let GeneratedEvent {
            event_type,
            sequence,
            ..
        } = GeneratedEvent::new("test_event".to_string(), payload, 1);

        assert_eq!(event_type, "test_event");
        assert_eq!(sequence, 1);
    }

    /// A generated-mode configuration without a narrative is rejected.
    #[test]
    fn test_replay_config_validation_missing_narrative() {
        let config = ReplayAugmentationConfig {
            mode: ReplayMode::Generated,
            ..Default::default()
        };

        let outcome = ReplayAugmentationEngine::validate_config(&config);
        assert!(outcome.is_err());
    }

    /// The bundled presets expose the expected headline settings.
    #[test]
    fn test_scenario_templates() {
        let stock = scenarios::stock_market_scenario();
        assert_eq!(stock.mode, ReplayMode::Generated);
        assert!(stock.narrative.is_some());

        let chat = scenarios::chat_messages_scenario();
        assert_eq!(chat.event_type, "chat_message");

        let iot = scenarios::iot_sensor_scenario();
        assert!(iot.progressive_evolution);
    }
}