use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
pub use super::actions::ScenarioActions;
pub use super::conditions::EvalConditions;
pub use super::dependency::DependencyGraphConfig;
pub use super::llm::{LlmConfig, LlmConfigOverride};
pub use super::manager::{
BatchProcessorConfig, ManagerActivationConfig, ManagerConfig, ManagerTemplate,
};
pub use super::milestone::Milestone;
/// The task given to the agents in a scenario.
///
/// Derives `Default`, so an omitted task becomes an empty goal with no expected
/// result and an empty [`TaskContext`].
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TaskConfig {
    /// Goal description. Has no serde default, so it must be present whenever the
    /// task table itself is deserialized.
    pub goal: String,
    /// Optional expected result (e.g. `"src/auth/handler.rs:42"`).
    /// NOTE(review): how this is compared against agent output is decided outside this file.
    #[serde(default)]
    pub expected: Option<String>,
    /// Extra context for the task; defaults to an empty [`TaskContext`].
    #[serde(default)]
    pub context: TaskContext,
}
/// Additional, mostly optional context attached to a [`TaskConfig`].
///
/// Known keys map to the named fields. Any other keys in the same table are captured
/// verbatim in [`TaskContext::extra`] via `#[serde(flatten)]`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TaskContext {
    /// Optional target path.
    #[serde(default)]
    pub target_path: Option<String>,
    /// Optional working directory.
    #[serde(default)]
    pub working_dir: Option<String>,
    /// Optional depth limit. NOTE(review): what "depth" measures is defined by the
    /// consumer, not here — confirm.
    #[serde(default)]
    pub max_depth: Option<usize>,
    /// All remaining keys of the context table, kept as raw TOML values.
    #[serde(default, flatten)]
    pub extra: HashMap<String, toml::Value>,
}
/// Identifier of a scenario, a newtype around its string form.
///
/// IDs used in this module look like `namespace:name:version`
/// (e.g. `"user:troubleshooting:v2"`), but any string is accepted.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ScenarioId(pub String);
impl ScenarioId {
    /// Creates a scenario ID from anything convertible into a `String`.
    pub fn new(id: impl Into<String>) -> Self {
        Self(id.into())
    }

    /// Returns the raw ID string.
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Derives a sanitized key for grouping learning data of this scenario.
    ///
    /// For colon-separated IDs the second segment is used, so the namespace and
    /// version are dropped: `"user:troubleshooting:v2"` and `"builtin:troubleshooting:v1"`
    /// both map to `"troubleshooting"`. IDs without a colon, or whose second segment is
    /// empty (e.g. `"user::v1"`), use the whole ID instead, so the key is never empty for
    /// a non-empty ID.
    ///
    /// In every case, characters other than alphanumerics, `-` and `_` are replaced
    /// with `_` (e.g. `"Service Troubleshooting"` → `"Service_Troubleshooting"`).
    pub fn learning_key(&self) -> String {
        let base = match self.0.split(':').nth(1) {
            Some(segment) if !segment.is_empty() => segment,
            _ => self.0.as_str(),
        };
        // Sanitize both branches so the key has the same character set regardless
        // of the ID shape.
        base.chars()
            .map(|c| {
                if c.is_alphanumeric() || c == '-' || c == '_' {
                    c
                } else {
                    '_'
                }
            })
            .collect()
    }
}
impl std::fmt::Display for ScenarioId {
    /// Writes the raw ID string unchanged.
    ///
    /// Width, fill and alignment flags of the outer format spec are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
impl From<&str> for ScenarioId {
fn from(s: &str) -> Self {
Self::new(s)
}
}
impl From<String> for ScenarioId {
fn from(s: String) -> Self {
Self::new(s)
}
}
/// A named set of overrides layered on top of a base [`EvalScenario`].
///
/// Selected by name in `EvalScenario::with_variant`. Every override is optional;
/// unset fields leave the base scenario untouched.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScenarioVariant {
    /// Name used to look up this variant (required).
    pub name: String,
    /// Informational description; `with_variant` does not read it.
    #[serde(default)]
    pub description: String,
    /// Partial LLM override applied to the base `llm` config.
    #[serde(default)]
    pub llm: Option<LlmConfigOverride>,
    /// Environment parameters. When this is a JSON object, `with_variant` merges its
    /// keys into the base scenario's `environment.params` (variant values win).
    #[serde(default)]
    pub environment_params: serde_json::Value,
    /// Replaces the base dependency graph when set.
    #[serde(default)]
    pub dependency_graph: Option<DependencyGraphConfig>,
    /// Per-field override of the base `app_config`.
    #[serde(default)]
    pub app_config: Option<AppConfigOverride>,
    /// Replaces `app_config.max_ticks` when set.
    #[serde(default)]
    pub max_ticks: Option<u64>,
    /// Sets `count` on the first worker template only.
    #[serde(default)]
    pub workers_count: Option<usize>,
    /// Sets `count` on the first manager template only.
    #[serde(default)]
    pub managers_count: Option<usize>,
}
/// A complete evaluation scenario definition (the tests in this module load it from TOML).
///
/// Fields without `#[serde(default)]` (`meta`, `app_config`, `environment`, `agents`,
/// `conditions`) must be present when deserializing; the others fall back to their
/// `Default` values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalScenario {
    /// Scenario identity and metadata (required).
    pub meta: ScenarioMeta,
    /// Task given to the agents; defaults to an empty task.
    #[serde(default)]
    pub task: TaskConfig,
    /// LLM configuration.
    #[serde(default)]
    pub llm: LlmConfig,
    /// Manager configuration.
    #[serde(default)]
    pub manager: ManagerConfig,
    /// Batch processor configuration.
    #[serde(default)]
    pub batch_processor: BatchProcessorConfig,
    /// Optional dependency graph configuration; `None` when omitted.
    #[serde(default)]
    pub dependency_graph: Option<DependencyGraphConfig>,
    /// Scenario actions.
    #[serde(default)]
    pub actions: ScenarioActions,
    /// Runtime settings (required table, but every field inside has a default,
    /// so an empty table is accepted).
    pub app_config: AppConfigTemplate,
    /// Environment definition (required).
    pub environment: EnvironmentConfig,
    /// Worker and manager templates (required table; may be empty).
    pub agents: AgentsConfig,
    /// Evaluation conditions (required).
    pub conditions: EvalConditions,
    /// Milestones; empty when omitted.
    #[serde(default)]
    pub milestones: Vec<Milestone>,
    /// Named override sets, applied with `EvalScenario::with_variant`.
    #[serde(default)]
    pub variants: Vec<ScenarioVariant>,
}
impl EvalScenario {
pub fn with_variant(&self, variant_name: &str) -> Option<EvalScenario> {
let variant = self.variants.iter().find(|v| v.name == variant_name)?;
let mut scenario = self.clone();
if let Some(ref llm_override) = variant.llm {
llm_override.apply_to(&mut scenario.llm);
}
if !variant.environment_params.is_null() {
if let serde_json::Value::Object(override_map) = &variant.environment_params {
if let serde_json::Value::Object(ref mut base_map) = scenario.environment.params {
for (key, value) in override_map {
base_map.insert(key.clone(), value.clone());
}
}
}
}
if variant.dependency_graph.is_some() {
scenario.dependency_graph = variant.dependency_graph.clone();
}
if let Some(ref app_override) = variant.app_config {
if let Some(ref strategy) = app_override.management_strategy {
scenario.app_config.management_strategy = strategy.clone();
}
if let Some(tick_ms) = app_override.tick_duration_ms {
scenario.app_config.tick_duration_ms = tick_ms;
}
if let Some(enable_exp) = app_override.enable_exploration {
scenario.app_config.enable_exploration = enable_exp;
}
}
if let Some(max_ticks) = variant.max_ticks {
scenario.app_config.max_ticks = max_ticks;
}
if let Some(workers_count) = variant.workers_count {
if let Some(first_worker) = scenario.agents.workers.first_mut() {
first_worker.count = workers_count;
}
}
if let Some(managers_count) = variant.managers_count {
if let Some(first_manager) = scenario.agents.managers.first_mut() {
first_manager.count = managers_count;
if first_manager.id_pattern.is_none() {
if let Some(ref id) = first_manager.id {
first_manager.id_pattern = Some(format!("{}_{{i}}", id));
first_manager.id = None;
}
}
}
}
scenario.meta.name = format!("{} ({})", self.meta.name, variant_name);
Some(scenario)
}
pub fn variant_names(&self) -> Vec<&str> {
self.variants.iter().map(|v| v.name.as_str()).collect()
}
}
/// Identity and descriptive metadata of a scenario.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScenarioMeta {
    /// Human-readable name (required). `with_variant` appends `" (<variant>)"` to it.
    pub name: String,
    /// Version string; defaults to `"1.0.0"`.
    #[serde(default = "default_version")]
    pub version: String,
    /// Scenario identifier (required).
    pub id: ScenarioId,
    /// Free-text description; empty when omitted.
    #[serde(default)]
    pub description: String,
    /// Tags; empty when omitted.
    #[serde(default)]
    pub tags: Vec<String>,
}
/// Serde default for `ScenarioMeta::version`.
fn default_version() -> String {
    String::from("1.0.0")
}
/// Runtime settings for a scenario run.
///
/// Every field has a serde default, so an empty table deserializes to the same
/// values as `AppConfigTemplate::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppConfigTemplate {
    /// Duration of one tick in milliseconds; defaults to 10.
    #[serde(default = "default_tick_duration_ms")]
    pub tick_duration_ms: u64,
    /// Maximum number of ticks; defaults to 1000.
    #[serde(default = "default_max_ticks")]
    pub max_ticks: u64,
    /// Management scheduling strategy; defaults to `IntervalBased { max_interval: 20 }`.
    #[serde(default)]
    pub management_strategy: ManagementStrategyConfig,
    /// Exploration flag; defaults to `false`.
    #[serde(default)]
    pub enable_exploration: bool,
}
/// Serde default for `AppConfigTemplate::tick_duration_ms` (milliseconds).
fn default_tick_duration_ms() -> u64 {
    const TICK_MS: u64 = 10;
    TICK_MS
}
/// Serde default for `AppConfigTemplate::max_ticks`.
fn default_max_ticks() -> u64 {
    const MAX_TICKS: u64 = 1_000;
    MAX_TICKS
}
impl AppConfigTemplate {
    /// Returns the configured tick length as a [`Duration`].
    pub fn tick_duration(&self) -> Duration {
        let millis = self.tick_duration_ms;
        Duration::from_millis(millis)
    }
}
impl Default for AppConfigTemplate {
    /// Builds the same values the serde field defaults produce: 10 ms ticks,
    /// 1000 max ticks, the default management strategy, and exploration off.
    fn default() -> Self {
        let tick_duration_ms = default_tick_duration_ms();
        let max_ticks = default_max_ticks();
        Self {
            tick_duration_ms,
            max_ticks,
            management_strategy: Default::default(),
            enable_exploration: false,
        }
    }
}
/// Per-field override of [`AppConfigTemplate`] used by scenario variants.
///
/// Each `Some` value replaces the corresponding base field; `None` keeps the base.
/// `max_ticks` is not here; variants override it through their own `max_ticks` field.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AppConfigOverride {
    /// Replacement management strategy.
    #[serde(default)]
    pub management_strategy: Option<ManagementStrategyConfig>,
    /// Replacement tick duration in milliseconds.
    #[serde(default)]
    pub tick_duration_ms: Option<u64>,
    /// Replacement exploration flag.
    #[serde(default)]
    pub enable_exploration: Option<bool>,
}
/// Management scheduling strategy selected by `app_config.management_strategy`.
///
/// Internally tagged: the variant is chosen by a `"type"` key holding the snake_case
/// variant name, e.g. `{"type": "hybrid", "max_interval": 30, "triggers": ["event_a"]}`.
/// The `Default` impl is `IntervalBased { max_interval: 20 }`.
///
/// NOTE(review): the runtime meaning of each variant is implemented outside this
/// file; the notes below describe only the accepted fields — confirm semantics there.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ManagementStrategyConfig {
    /// `type = "every_tick"`; takes no parameters.
    EveryTick {},
    /// `type = "interval_based"`.
    IntervalBased {
        /// Interval bound (presumably in ticks — confirm); defaults to 20.
        #[serde(default = "default_max_interval")]
        max_interval: u64,
    },
    /// `type = "event_driven"`.
    EventDriven {
        /// Trigger names; empty when omitted.
        #[serde(default)]
        triggers: Vec<String>,
    },
    /// `type = "hybrid"`; accepts both an interval bound and trigger names.
    Hybrid {
        /// Interval bound; defaults to 20.
        #[serde(default = "default_max_interval")]
        max_interval: u64,
        /// Trigger names; empty when omitted.
        #[serde(default)]
        triggers: Vec<String>,
    },
    /// `type = "disabled"`; takes no parameters.
    // NOTE(review): `rename_all = "snake_case"` already maps this variant to
    // "disabled", so this alias appears redundant — confirm whether a different
    // spelling was intended.
    #[serde(alias = "disabled")]
    Disabled {},
}
/// Serde default for the `max_interval` field of interval-based and hybrid strategies.
fn default_max_interval() -> u64 {
    const MAX_INTERVAL: u64 = 20;
    MAX_INTERVAL
}
impl Default for ManagementStrategyConfig {
    /// Defaults to interval-based management with the default `max_interval` (20).
    fn default() -> Self {
        let max_interval = default_max_interval();
        Self::IntervalBased { max_interval }
    }
}
/// Environment definition for a scenario.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnvironmentConfig {
    /// Environment type identifier (required). NOTE(review): how it is resolved to an
    /// implementation is outside this file.
    pub env_type: String,
    /// Free-form environment parameters; `null` when omitted.
    #[serde(default)]
    pub params: serde_json::Value,
    /// Optional initial-state configuration; `None` when omitted.
    #[serde(default)]
    pub initial_state: Option<InitialStateConfig>,
}
/// How the environment's initial state is specified.
///
/// Internally tagged by a `"type"` key holding the snake_case variant name.
/// NOTE(review): the state generators live outside this file; the notes below only
/// describe the accepted fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InitialStateConfig {
    /// `type = "seeded_random"`; takes no parameters.
    // NOTE(review): `rename_all = "snake_case"` already maps this variant to
    // "seeded_random", so this alias appears redundant — confirm intent.
    #[serde(alias = "seeded_random")]
    SeededRandom {},
    /// `type = "fixed"`.
    Fixed {
        /// State value; this field has no serde default.
        state: serde_json::Value,
    },
    /// `type = "custom"`.
    Custom {
        /// Generator name; this field has no serde default.
        generator: String,
        /// Generator parameters; this field has no serde default.
        params: serde_json::Value,
    },
}
/// Agent templates for a scenario; both lists are empty when omitted.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AgentsConfig {
    /// Worker templates. Variant `workers_count` overrides only affect the first one.
    #[serde(default)]
    pub workers: Vec<WorkerTemplate>,
    /// Manager templates. Variant `managers_count` overrides only affect the first one.
    #[serde(default)]
    pub managers: Vec<ManagerTemplate>,
}
/// Template that expands into `count` worker agents (see `WorkerTemplate::generate_ids`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerTemplate {
    /// ID pattern (required). Each `{i}` is replaced by the 0-based instance index.
    pub id_pattern: String,
    /// Number of workers to create; defaults to 1.
    #[serde(default = "default_worker_count")]
    pub count: usize,
    /// Role name; empty string when omitted.
    #[serde(default)]
    pub role: String,
    /// Free-form per-worker configuration; `null` when omitted.
    #[serde(default)]
    pub config: serde_json::Value,
}
/// Serde default for `WorkerTemplate::count`: a single worker.
fn default_worker_count() -> usize {
    const SINGLE_WORKER: usize = 1;
    SINGLE_WORKER
}
impl WorkerTemplate {
    /// Expands `id_pattern` into `count` IDs, substituting every `{i}` with the
    /// 0-based index (`"worker_{i}"`, 3 → `["worker_0", "worker_1", "worker_2"]`).
    ///
    /// A pattern without `{i}` yields `count` identical IDs; `count == 0` yields none.
    pub fn generate_ids(&self) -> Vec<String> {
        let mut ids = Vec::with_capacity(self.count);
        for index in 0..self.count {
            let index_text = index.to_string();
            ids.push(self.id_pattern.replace("{i}", &index_text));
        }
        ids
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scenario::llm::LlmProvider;

    #[test]
    fn test_scenario_id() {
        let scenario_id = ScenarioId::new("test:scenario:v1");
        assert_eq!(scenario_id.as_str(), "test:scenario:v1");
    }

    #[test]
    fn test_scenario_id_learning_key() {
        // Each pair is (raw scenario ID, expected learning key).
        let cases = [
            ("user:troubleshooting:v2", "troubleshooting"),
            ("builtin:resource_gathering:v1", "resource_gathering"),
            ("simple_scenario", "simple_scenario"),
            ("Service Troubleshooting", "Service_Troubleshooting"),
        ];
        for &(raw, expected) in cases.iter() {
            assert_eq!(ScenarioId::new(raw).learning_key(), expected);
        }
    }

    #[test]
    fn test_worker_template_generate_ids() {
        let gatherers = WorkerTemplate {
            id_pattern: "worker_{i}".to_string(),
            count: 3,
            role: "gatherer".to_string(),
            config: serde_json::Value::Null,
        };
        let expected = vec!["worker_0", "worker_1", "worker_2"];
        assert_eq!(gatherers.generate_ids(), expected);
    }

    #[test]
    fn test_app_config_template_default() {
        let AppConfigTemplate {
            tick_duration_ms,
            max_ticks,
            ..
        } = AppConfigTemplate::default();
        assert_eq!(tick_duration_ms, 10);
        assert_eq!(max_ticks, 1000);
    }

    #[test]
    fn test_management_strategy_deserialize() {
        let raw = r#"{"type": "hybrid", "max_interval": 30, "triggers": ["event_a"]}"#;
        let parsed: ManagementStrategyConfig = serde_json::from_str(raw).unwrap();
        if let ManagementStrategyConfig::Hybrid {
            max_interval,
            triggers,
        } = parsed
        {
            assert_eq!(max_interval, 30);
            assert_eq!(triggers, vec!["event_a"]);
        } else {
            panic!("Expected Hybrid variant");
        }
    }

    #[test]
    fn test_task_config_default() {
        let TaskConfig { goal, expected, .. } = TaskConfig::default();
        assert!(goal.is_empty());
        assert!(expected.is_none());
    }

    #[test]
    fn test_task_config_deserialize_toml() {
        let toml_str = r#"
goal = "Find the function that handles authentication"
expected = "src/auth/handler.rs:42"
[context]
target_path = "/path/to/codebase"
working_dir = "/path/to/codebase"
max_depth = 5
"#;
        let parsed: TaskConfig = toml::from_str(toml_str).unwrap();
        let TaskConfig {
            goal,
            expected,
            context,
        } = parsed;
        assert_eq!(goal, "Find the function that handles authentication");
        assert_eq!(expected, Some("src/auth/handler.rs:42".to_string()));
        assert_eq!(context.target_path, Some("/path/to/codebase".to_string()));
        assert_eq!(context.max_depth, Some(5));
    }

    #[test]
    fn test_scenario_variant_with_llm_override() {
        let toml_str = r#"
[meta]
name = "Test Scenario"
id = "test:scenario:v1"
[task]
goal = "Test goal"
[llm]
provider = "ollama"
model = "llama3:8b"
[app_config]
max_ticks = 100
[environment]
env_type = "test"
[agents]
[conditions]
on_timeout = "fail"
[[variants]]
name = "mistral"
description = "Use mistral.rs local inference"
[variants.llm]
provider = "mistral"
model = "LiquidAI/LFM2.5-1.2B-Instruct-GGUF"
gguf_files = ["LFM2.5-1.2B-Instruct-Q4_K_M.gguf"]
"#;
        let base: EvalScenario = toml::from_str(toml_str).unwrap();
        assert_eq!(base.llm.provider, LlmProvider::Ollama);
        assert_eq!(base.variants.len(), 1);

        let mistral = base.with_variant("mistral").unwrap();
        assert_eq!(mistral.llm.provider, LlmProvider::Mistral);
        assert_eq!(mistral.llm.model, "LiquidAI/LFM2.5-1.2B-Instruct-GGUF");
        assert!(mistral.llm.is_gguf());
        assert_eq!(mistral.meta.name, "Test Scenario (mistral)");
    }

    #[test]
    fn test_scenario_variant_partial_llm_override() {
        let toml_str = r#"
[meta]
name = "Test"
id = "test:v1"
[task]
goal = "Test"
[llm]
provider = "ollama"
model = "llama3:8b"
temperature = 0.1
num_ctx = 4096
[app_config]
max_ticks = 100
[environment]
env_type = "test"
[agents]
[conditions]
on_timeout = "fail"
[[variants]]
name = "high_temp"
[variants.llm]
temperature = 0.9
"#;
        let base: EvalScenario = toml::from_str(toml_str).unwrap();
        let hot = base.with_variant("high_temp").unwrap();
        let llm = &hot.llm;
        assert!((llm.temperature - 0.9).abs() < f32::EPSILON);
        assert_eq!(llm.provider, LlmProvider::Ollama);
        assert_eq!(llm.model, "llama3:8b");
        assert_eq!(llm.num_ctx, Some(4096));
    }
}