use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use chrono::{DateTime, Utc};
use crate::ids::*;
/// Session-wide configuration.
///
/// Every field is optional on the wire: a missing key falls back to the
/// field type's `Default`, except `max_parallel_agents`, which falls back to
/// [`default_max_agents`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionConfig {
    /// Optional path; the name suggests the session working directory
    /// (NOTE(review): confirm with callers).
    #[serde(default)]
    pub cwd: Option<PathBuf>,
    /// Optional model identifier string.
    #[serde(default)]
    pub model: Option<String>,
    /// Optional free-form instructions text.
    #[serde(default)]
    pub instructions: Option<String>,
    /// MCP server entries; empty when the key is absent.
    #[serde(default)]
    pub mcp_servers: Vec<McpServerConfig>,
    /// Approval mode; falls back to `ApprovalMode::default()`.
    #[serde(default)]
    pub approval_mode: ApprovalMode,
    /// Sandbox settings; falls back to `SandboxConfig::default()`.
    #[serde(default)]
    pub sandbox: SandboxConfig,
    /// Upper bound on parallel agents (per the field name); falls back to
    /// [`default_max_agents`] when the key is absent.
    #[serde(default = "default_max_agents")]
    pub max_parallel_agents: usize,
}

/// Hand-written so that `SessionConfig::default()` agrees with deserializing
/// `{}`. A derived impl would set `max_parallel_agents` to `0`, silently
/// diverging from the serde fallback used for the same field.
impl Default for SessionConfig {
    fn default() -> Self {
        Self {
            cwd: None,
            model: None,
            instructions: None,
            mcp_servers: Vec::new(),
            approval_mode: Default::default(),
            sandbox: Default::default(),
            max_parallel_agents: default_max_agents(),
        }
    }
}

/// Fallback value for `SessionConfig::max_parallel_agents`, shared by the
/// serde attribute and the `Default` impl so the two cannot drift apart.
fn default_max_agents() -> usize {
    8
}
/// Per-session settings.
///
/// The container-level `#[serde(default)]` fills any missing key from
/// `SessionSettings::default()`, i.e. `false`, `None` and
/// `PlanGranularity::default()` respectively.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct SessionSettings {
    /// Flag named for showing rate-limit information.
    pub show_rate_limit: bool,
    /// Optional sub-agent concurrency value.
    pub subagent_concurrency: Option<usize>,
    /// Plan granularity selection.
    pub plan_granularity: PlanGranularity,
}
/// Approval mode selector.
///
/// Variants are (de)serialized under their snake_case names; `RiskBased` is
/// the `Default` variant.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ApprovalMode {
/// Wire name: `always`.
Always,
/// Wire name: `never`.
Never,
/// Wire name: `risk_based`. Default variant.
#[default]
RiskBased,
/// Wire name: `custom`.
Custom,
}
/// Sandbox settings.
///
/// All keys are optional on the wire. `enabled` falls back to
/// [`default_true`]; the other fields fall back to their type's `Default`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SandboxConfig {
    /// Whether the sandbox is enabled; `true` when the key is absent.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Network policy; falls back to `NetworkPolicy::default()`.
    #[serde(default)]
    pub network: NetworkPolicy,
    /// Paths listed as writable; empty when the key is absent.
    #[serde(default)]
    pub writable_paths: Vec<PathBuf>,
    /// Optional timeout, in seconds per the field name.
    #[serde(default)]
    pub timeout_secs: Option<u64>,
}

/// Hand-written so that `SandboxConfig::default()` agrees with deserializing
/// `{}`. A derived impl would yield `enabled == false`, which also leaked into
/// any parent struct that fills a missing `sandbox` key via
/// `SandboxConfig::default()`.
impl Default for SandboxConfig {
    fn default() -> Self {
        Self {
            enabled: default_true(),
            network: Default::default(),
            writable_paths: Vec::new(),
            timeout_secs: None,
        }
    }
}

/// Serde fallback helper that yields `true`; also used by the `Default` impl
/// of `SandboxConfig` to keep both defaults identical.
fn default_true() -> bool {
    true
}
/// Network policy used by `SandboxConfig::network`.
///
/// Variant names are (de)serialized in snake_case; `None` is the `Default`
/// variant.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum NetworkPolicy {
/// Wire name: `none`. Default variant.
#[default]
None,
/// Wire name: `localhost`.
Localhost,
/// Wire name: `allowlist`; carries a list of strings
/// (presumably host names — TODO confirm the expected entry format).
Allowlist(Vec<String>),
/// Wire name: `full`.
Full,
}
/// Configuration entry for one MCP server.
///
/// `id`, `name` and `transport` are required when deserializing; `env` is
/// optional and defaults to an empty map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
/// Identifier string for the server entry.
pub id: String,
/// Name string for the server entry.
pub name: String,
/// How the server is reached; see `McpTransport`.
pub transport: McpTransport,
/// String-to-string map, empty when absent
/// (presumably environment variables for the server — TODO confirm).
#[serde(default)]
pub env: HashMap<String, String>,
}
/// Transport description for an MCP server.
///
/// Internally tagged: the variant is selected by a `type` key whose value is
/// the snake_case variant name (`stdio`, `socket`, `http`), and the variant's
/// fields sit alongside that key in the same object.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum McpTransport {
/// `type: "stdio"` — a command string plus an argument list.
Stdio {
/// Command string (required).
command: String,
/// Argument list; empty when the key is absent.
#[serde(default)]
args: Vec<String>,
},
/// `type: "socket"` — a filesystem path.
Socket {
/// Path value (required).
path: PathBuf,
},
/// `type: "http"` — a URL string.
Http {
/// URL string (required); no validation is performed by this type.
url: String,
},
}
/// Role assigned to an agent.
///
/// Uses serde's default externally tagged representation with snake_case
/// variant names: unit variants serialize as a bare string (e.g. `"worker"`)
/// and data-carrying variants as a single-key object
/// (e.g. `{"domain_lead": {"domain": "..."}}`).
///
/// `Worker` is the `Default` variant; the default is derived via `#[default]`
/// to match the other enums in this module.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AgentRole {
    /// Wire name: `orchestrator`.
    Orchestrator,
    /// Wire name: `domain_lead`; carries the domain string.
    DomainLead { domain: String },
    /// Wire name: `worker`. Default variant.
    #[default]
    Worker,
    /// Wire name: `specialist`; carries the specialty string.
    Specialist { specialty: String },
    /// Wire name: `scout`.
    Scout,
    /// Wire name: `reviewer`.
    Reviewer,
    /// Wire name: `custom`; carries a caller-chosen role name.
    Custom { name: String },
}
/// Configuration for a single agent.
///
/// The container-level `#[serde(default)]` fills every missing key from
/// `AgentConfig::default()`: `AgentRole::default()` for `role`, `None` for the
/// optional fields, an empty list for `tools`, and `false` for `can_spawn`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct AgentConfig {
    /// Role of the agent.
    pub role: AgentRole,
    /// Optional model identifier string.
    pub model: Option<String>,
    /// Optional path; the name suggests the agent working directory
    /// (NOTE(review): confirm with callers).
    pub cwd: Option<PathBuf>,
    /// Optional worktree string.
    pub worktree: Option<String>,
    /// List of tool name strings.
    pub tools: Vec<String>,
    /// Flag named for permission to spawn further agents.
    pub can_spawn: bool,
    /// Optional limit on children, per the field name.
    pub max_children: Option<usize>,
    /// Optional token budget, per the field name.
    pub token_budget: Option<u64>,
}
/// Status of an agent.
///
/// Externally tagged with snake_case variant names; `Waiting` additionally
/// carries a `reason` string. `Spawning` is the `Default` variant, derived via
/// `#[default]` to match the other enums in this module.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AgentStatus {
    /// Wire name: `spawning`. Default variant.
    #[default]
    Spawning,
    /// Wire name: `initializing`.
    Initializing,
    /// Wire name: `running`.
    Running,
    /// Wire name: `waiting`; carries the reason as free-form text.
    Waiting { reason: String },
    /// Wire name: `completed`.
    Completed,
    /// Wire name: `failed`.
    Failed,
    /// Wire name: `terminated`.
    Terminated,
}
/// Result record for an agent.
///
/// `success` and `summary` are required when deserializing; the remaining
/// fields fall back to their type's `Default` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentResult {
/// Success flag.
pub success: bool,
/// Summary text.
pub summary: String,
/// Paths recorded as changed; empty when the key is absent.
#[serde(default)]
pub files_changed: Vec<PathBuf>,
/// Arbitrary JSON payload; `serde_json::Value::default()` when absent.
#[serde(default)]
pub output: serde_json::Value,
}
/// Context attached to a task.
///
/// The container-level `#[serde(default)]` fills every missing key from
/// `TaskContext::default()`: `None` for `cwd` and empty collections for the
/// rest.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct TaskContext {
    /// Optional path; the name suggests a working directory
    /// (NOTE(review): confirm with callers).
    pub cwd: Option<PathBuf>,
    /// List of file paths.
    pub files: Vec<PathBuf>,
    /// List of context strings.
    pub memory_context: Vec<String>,
    /// Arbitrary JSON values keyed by string.
    pub metadata: HashMap<String, serde_json::Value>,
}
/// A task assignment record.
///
/// `task_id` and `description` are required when deserializing; the list
/// fields default to empty and `context` to `TaskContext::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskAssignment {
/// Identifier of the task.
pub task_id: TaskId,
/// Description text.
pub description: String,
/// List of deliverable strings; empty when the key is absent.
#[serde(default)]
pub deliverables: Vec<String>,
/// Task ids listed as dependencies; empty when the key is absent.
#[serde(default)]
pub dependencies: Vec<TaskId>,
/// Attached context; `TaskContext::default()` when the key is absent.
#[serde(default)]
pub context: TaskContext,
}
/// Result record for a task.
///
/// `task_id`, `success` and `summary` are required when deserializing;
/// `files_changed` defaults to empty and `token_usage` to
/// `TokenUsage::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskResult {
/// Identifier of the task this result belongs to.
pub task_id: TaskId,
/// Success flag.
pub success: bool,
/// Summary text.
pub summary: String,
/// Paths recorded as changed; empty when the key is absent.
#[serde(default)]
pub files_changed: Vec<PathBuf>,
/// Token accounting; `TokenUsage::default()` when the key is absent.
#[serde(default)]
pub token_usage: TokenUsage,
}
/// Output record of a tool invocation.
///
/// `success` and `content` are required when deserializing; `data` and
/// `exit_code` are optional and default to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolOutput {
/// Success flag.
pub success: bool,
/// Textual content.
pub content: String,
/// Optional arbitrary JSON payload.
#[serde(default)]
pub data: Option<serde_json::Value>,
/// Optional exit code.
#[serde(default)]
pub exit_code: Option<i32>,
}
/// Risk level classification.
///
/// Variants are (de)serialized under their snake_case names. `Medium` is the
/// `Default` variant, derived via `#[default]` to match the other enums in
/// this module.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RiskLevel {
    /// Wire name: `none`.
    None,
    /// Wire name: `low`.
    Low,
    /// Wire name: `medium`. Default variant.
    #[default]
    Medium,
    /// Wire name: `high`.
    High,
    /// Wire name: `critical`.
    Critical,
}
/// Plan granularity selector.
///
/// Variants are (de)serialized under their snake_case names; `Auto` is the
/// `Default` variant.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PlanGranularity {
/// Wire name: `coarse`.
Coarse,
/// Wire name: `detailed`.
Detailed,
/// Wire name: `auto`. Default variant.
#[default]
Auto,
}
/// A plan made of steps.
///
/// All fields except `estimated_tokens` are required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskPlan {
/// Text of the original request.
pub original_request: String,
/// Steps of the plan, in stored order.
pub steps: Vec<PlanStep>,
/// Roles keyed by string
/// (presumably `PlanStep::id` values — TODO confirm against producers).
pub agent_assignments: HashMap<String, AgentRole>,
/// Pairs of strings
/// (presumably step ids; which side depends on which is not visible here — TODO confirm).
pub dependencies: Vec<(String, String)>,
/// Token estimate; `0` when the key is absent.
#[serde(default)]
pub estimated_tokens: u64,
}
/// One step of a `TaskPlan`.
///
/// `id`, `description` and `expected_outcome` are required when
/// deserializing; `complexity` falls back to `StepComplexity::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanStep {
/// Identifier string of the step.
pub id: String,
/// Description text.
pub description: String,
/// Expected outcome text.
pub expected_outcome: String,
/// Complexity rating; `StepComplexity::default()` when the key is absent.
#[serde(default)]
pub complexity: StepComplexity,
}
/// Complexity rating of a plan step.
///
/// Variants are (de)serialized under their snake_case names; `Moderate` is
/// the `Default` variant.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StepComplexity {
/// Wire name: `simple`.
Simple,
/// Wire name: `moderate`. Default variant.
#[default]
Moderate,
/// Wire name: `complex`.
Complex,
}
/// Recursive tree node describing an agent and its child agents.
///
/// `children` holds further `AgentTree` nodes and defaults to empty when the
/// key is absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentTree {
/// Identifier of the agent at this node.
pub agent_id: AgentId,
/// Role of the agent.
pub role: AgentRole,
/// Status of the agent.
pub status: AgentStatus,
/// Optional summary text for the agent's task.
pub task_summary: Option<String>,
/// Child nodes; empty when the key is absent.
#[serde(default)]
pub children: Vec<AgentTree>,
}
/// Metadata record describing a checkpoint.
///
/// No field carries a serde default attribute.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMeta {
/// Identifier of the checkpoint.
pub id: CheckpointId,
/// Optional name string.
pub name: Option<String>,
/// UTC timestamp associated with the checkpoint.
pub timestamp: DateTime<Utc>,
/// Size in bytes, per the field name.
pub size_bytes: u64,
/// Optional task id associated with the checkpoint.
pub task_id: Option<TaskId>,
/// Summary text.
pub summary: String,
}
/// Token accounting record.
///
/// `Default` yields zero counts and no cost. When deserializing, the three
/// counters are required keys; only `estimated_cost_usd` may be omitted.
/// Nothing in this type enforces a relationship between `total_tokens` and the
/// other two counters.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
/// Input token count.
pub input_tokens: u64,
/// Output token count.
pub output_tokens: u64,
/// Total token count, stored independently of the two counters above.
pub total_tokens: u64,
/// Optional cost estimate, in US dollars per the field name.
#[serde(default)]
pub estimated_cost_usd: Option<f64>,
}
/// Kind of a message.
///
/// Variants are (de)serialized under their snake_case names. `Text` is the
/// `Default` variant, derived via `#[default]` to match the other enums in
/// this module.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageType {
    /// Wire name: `text`. Default variant.
    #[default]
    Text,
    /// Wire name: `thinking`.
    Thinking,
    /// Wire name: `code`.
    Code,
    /// Wire name: `error`.
    Error,
    /// Wire name: `status`.
    Status,
    /// Wire name: `progress`.
    Progress,
}
/// An image attachment.
///
/// `data` and `mime_type` are required when deserializing; `filename` is
/// optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageAttachment {
/// Image payload as a string
/// (presumably base64-encoded bytes — TODO confirm the expected encoding).
pub data: String,
/// MIME type string, e.g. `image/png`.
pub mime_type: String,
/// Optional file name.
#[serde(default)]
pub filename: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- SessionConfig -------------------------------------------------

    /// An empty object picks up the serde fallbacks.
    #[test]
    fn test_session_config_defaults() {
        let config: SessionConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(config.max_parallel_agents, 8);
        assert_eq!(config.approval_mode, ApprovalMode::RiskBased);
    }

    /// Explicit keys override the fallbacks.
    #[test]
    fn test_session_config_custom_values() {
        let json = r#"{
"max_parallel_agents": 16,
"approval_mode": "always",
"model": "claude-3-opus"
}"#;
        let config: SessionConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.max_parallel_agents, 16);
        assert_eq!(config.approval_mode, ApprovalMode::Always);
        assert_eq!(config.model, Some("claude-3-opus".into()));
    }

    /// Nested MCP server entries appear in the serialized output.
    #[test]
    fn test_session_config_with_mcp_servers() {
        let config = SessionConfig {
            mcp_servers: vec![McpServerConfig {
                id: "my-server".into(),
                name: "My Server".into(),
                transport: McpTransport::Stdio {
                    command: "node".into(),
                    args: vec!["server.js".into()],
                },
                env: Default::default(),
            }],
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("my-server"));
        assert!(json.contains("My Server"));
    }

    // ---- ApprovalMode --------------------------------------------------

    /// Each mode serializes under its snake_case name and round-trips.
    #[test]
    fn test_approval_mode_serialization() {
        for (mode, expected) in [
            (ApprovalMode::Always, "always"),
            (ApprovalMode::Never, "never"),
            (ApprovalMode::RiskBased, "risk_based"),
            (ApprovalMode::Custom, "custom"),
        ] {
            let json = serde_json::to_string(&mode).unwrap();
            assert!(json.contains(expected));
            let parsed: ApprovalMode = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, mode);
        }
    }

    // ---- SandboxConfig / NetworkPolicy ---------------------------------

    /// An empty object yields an enabled sandbox with no network access.
    #[test]
    fn test_sandbox_config_defaults() {
        let config: SandboxConfig = serde_json::from_str("{}").unwrap();
        assert!(config.enabled);
        assert_eq!(config.network, NetworkPolicy::None);
        assert!(config.writable_paths.is_empty());
    }

    /// Explicit sandbox settings survive a JSON round-trip.
    #[test]
    fn test_sandbox_config_custom() {
        let config = SandboxConfig {
            enabled: true,
            network: NetworkPolicy::Localhost,
            writable_paths: vec![PathBuf::from("/tmp")],
            timeout_secs: Some(120),
        };
        let json = serde_json::to_string(&config).unwrap();
        let parsed: SandboxConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.network, NetworkPolicy::Localhost);
        assert_eq!(parsed.timeout_secs, Some(120));
    }

    /// Every network policy variant round-trips through JSON.
    #[test]
    fn test_network_policy_variants() {
        for policy in [
            NetworkPolicy::None,
            NetworkPolicy::Localhost,
            NetworkPolicy::Full,
            NetworkPolicy::Allowlist(vec!["api.example.com".into()]),
        ] {
            let json = serde_json::to_string(&policy).unwrap();
            let parsed: NetworkPolicy = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, policy);
        }
    }

    // ---- McpServerConfig / McpTransport --------------------------------
    //
    // These tests inspect the parsed JSON value rather than searching the
    // raw string, so the `type` tag itself is checked (a plain substring
    // search for "http" would also match the URL).

    /// The stdio transport is tagged `stdio` and keeps its command.
    #[test]
    fn test_mcp_server_stdio_transport() {
        let config = McpServerConfig {
            id: "test".into(),
            name: "Test Server".into(),
            transport: McpTransport::Stdio {
                command: "npx".into(),
                args: vec!["my-server".into()],
            },
            env: [("API_KEY".into(), "secret".into())].into(),
        };
        let value = serde_json::to_value(&config).unwrap();
        assert_eq!(value["transport"]["type"], "stdio");
        assert_eq!(value["transport"]["command"], "npx");
    }

    /// The socket transport is tagged `socket` and keeps its path.
    #[test]
    fn test_mcp_server_socket_transport() {
        let config = McpServerConfig {
            id: "test".into(),
            name: "Test Server".into(),
            transport: McpTransport::Socket {
                path: PathBuf::from("/var/run/mcp.sock"),
            },
            env: Default::default(),
        };
        let value = serde_json::to_value(&config).unwrap();
        assert_eq!(value["transport"]["type"], "socket");
        assert_eq!(value["transport"]["path"], "/var/run/mcp.sock");
    }

    /// The HTTP transport is tagged `http` and keeps its URL.
    #[test]
    fn test_mcp_server_http_transport() {
        let config = McpServerConfig {
            id: "test".into(),
            name: "Test Server".into(),
            transport: McpTransport::Http {
                url: "http://localhost:3000".into(),
            },
            env: Default::default(),
        };
        let value = serde_json::to_value(&config).unwrap();
        assert_eq!(value["transport"]["type"], "http");
        assert_eq!(value["transport"]["url"], "http://localhost:3000");
    }

    // ---- AgentRole / AgentConfig / AgentStatus / AgentResult -----------

    /// Data-carrying roles serialize their snake_case name and payload.
    #[test]
    fn test_agent_role_serialization() {
        let role = AgentRole::DomainLead {
            domain: "frontend".to_string(),
        };
        let json = serde_json::to_string(&role).unwrap();
        assert!(json.contains("domain_lead"));
        assert!(json.contains("frontend"));
    }

    /// Every role variant round-trips through JSON.
    #[test]
    fn test_agent_role_variants() {
        for role in [
            AgentRole::Orchestrator,
            AgentRole::Worker,
            AgentRole::Scout,
            AgentRole::Reviewer,
            AgentRole::DomainLead {
                domain: "backend".into(),
            },
            AgentRole::Specialist {
                specialty: "security".into(),
            },
            AgentRole::Custom {
                name: "my-role".into(),
            },
        ] {
            let json = serde_json::to_string(&role).unwrap();
            let parsed: AgentRole = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, role);
        }
    }

    /// The default role is `Worker`.
    #[test]
    fn test_agent_role_default() {
        let role: AgentRole = Default::default();
        assert_eq!(role, AgentRole::Worker);
    }

    /// A default agent config is a non-spawning worker with no model.
    #[test]
    fn test_agent_config_defaults() {
        let config: AgentConfig = Default::default();
        assert_eq!(config.role, AgentRole::Worker);
        assert!(!config.can_spawn);
        assert!(config.model.is_none());
    }

    /// Explicit agent settings survive a JSON round-trip.
    #[test]
    fn test_agent_config_custom() {
        let config = AgentConfig {
            role: AgentRole::Orchestrator,
            model: Some("claude-3-opus".into()),
            cwd: Some(PathBuf::from("/home/user/project")),
            can_spawn: true,
            max_children: Some(4),
            token_budget: Some(100_000),
            tools: vec!["read_file".into(), "write_file".into()],
            worktree: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        let parsed: AgentConfig = serde_json::from_str(&json).unwrap();
        assert!(parsed.can_spawn);
        assert_eq!(parsed.max_children, Some(4));
        assert_eq!(parsed.token_budget, Some(100_000));
    }

    /// Every status variant round-trips through JSON.
    #[test]
    fn test_agent_status_variants() {
        for status in [
            AgentStatus::Spawning,
            AgentStatus::Initializing,
            AgentStatus::Running,
            AgentStatus::Waiting {
                reason: "Waiting for approval".into(),
            },
            AgentStatus::Completed,
            AgentStatus::Failed,
            AgentStatus::Terminated,
        ] {
            let json = serde_json::to_string(&status).unwrap();
            let parsed: AgentStatus = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, status);
        }
    }

    /// The default status is `Spawning`.
    #[test]
    fn test_agent_status_default() {
        let status: AgentStatus = Default::default();
        assert_eq!(status, AgentStatus::Spawning);
    }

    /// Summary and changed files appear in the serialized result.
    #[test]
    fn test_agent_result() {
        let result = AgentResult {
            success: true,
            summary: "Task completed successfully".into(),
            files_changed: vec![PathBuf::from("src/main.rs")],
            output: serde_json::json!({"lines_added": 50}),
        };
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("Task completed successfully"));
        assert!(json.contains("src/main.rs"));
    }

    // ---- TaskContext / TaskAssignment / TaskResult / ToolOutput --------

    /// A default task context is empty.
    #[test]
    fn test_task_context_defaults() {
        let ctx: TaskContext = Default::default();
        assert!(ctx.cwd.is_none());
        assert!(ctx.files.is_empty());
        assert!(ctx.memory_context.is_empty());
    }

    /// A populated task context survives a JSON round-trip.
    #[test]
    fn test_task_context_custom() {
        let ctx = TaskContext {
            cwd: Some(PathBuf::from("/project")),
            files: vec![PathBuf::from("src/lib.rs")],
            memory_context: vec!["Previous context".into()],
            metadata: [("key".into(), serde_json::json!("value"))].into(),
        };
        let json = serde_json::to_string(&ctx).unwrap();
        let parsed: TaskContext = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.cwd, Some(PathBuf::from("/project")));
        assert!(!parsed.files.is_empty());
    }

    /// Description and deliverables appear in the serialized assignment.
    #[test]
    fn test_task_assignment() {
        let dep_id = TaskId::new();
        let task = TaskAssignment {
            task_id: TaskId::new(),
            description: "Implement feature X".into(),
            deliverables: vec!["Code".into(), "Tests".into()],
            dependencies: vec![dep_id],
            context: TaskContext::default(),
        };
        let json = serde_json::to_string(&task).unwrap();
        assert!(json.contains("Implement feature X"));
        assert!(json.contains("deliverables"));
    }

    /// Token totals and cost appear in the serialized task result.
    #[test]
    fn test_task_result() {
        let result = TaskResult {
            task_id: TaskId::new(),
            success: true,
            summary: "Done".into(),
            files_changed: vec![PathBuf::from("a.rs"), PathBuf::from("b.rs")],
            token_usage: TokenUsage {
                input_tokens: 5000,
                output_tokens: 2000,
                total_tokens: 7000,
                estimated_cost_usd: Some(0.07),
            },
        };
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("7000"));
        assert!(json.contains("0.07"));
    }

    /// The exit code key is present in the serialized tool output.
    #[test]
    fn test_tool_output() {
        let output = ToolOutput {
            success: true,
            content: "File read successfully".into(),
            data: Some(serde_json::json!({"lines": 100})),
            exit_code: Some(0),
        };
        let json = serde_json::to_string(&output).unwrap();
        assert!(json.contains("exit_code"));
    }

    // ---- RiskLevel / PlanGranularity -----------------------------------

    /// Every risk level round-trips through JSON.
    #[test]
    fn test_risk_level_variants() {
        for level in [
            RiskLevel::None,
            RiskLevel::Low,
            RiskLevel::Medium,
            RiskLevel::High,
            RiskLevel::Critical,
        ] {
            let json = serde_json::to_string(&level).unwrap();
            let parsed: RiskLevel = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, level);
        }
    }

    /// The default risk level is `Medium`.
    #[test]
    fn test_risk_level_default() {
        let level: RiskLevel = Default::default();
        assert_eq!(level, RiskLevel::Medium);
    }

    /// Every granularity round-trips through JSON.
    #[test]
    fn test_plan_granularity_variants() {
        for g in [
            PlanGranularity::Coarse,
            PlanGranularity::Detailed,
            PlanGranularity::Auto,
        ] {
            let json = serde_json::to_string(&g).unwrap();
            let parsed: PlanGranularity = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, g);
        }
    }

    /// The default granularity is `Auto`.
    #[test]
    fn test_plan_granularity_default() {
        let g: PlanGranularity = Default::default();
        assert_eq!(g, PlanGranularity::Auto);
    }

    // ---- TaskPlan / PlanStep / StepComplexity --------------------------

    /// Request text and token estimate appear in the serialized plan.
    #[test]
    fn test_task_plan() {
        let plan = TaskPlan {
            original_request: "Add authentication".into(),
            steps: vec![
                PlanStep {
                    id: "1".into(),
                    description: "Create auth module".into(),
                    expected_outcome: "Working auth".into(),
                    complexity: StepComplexity::Complex,
                },
                PlanStep {
                    id: "2".into(),
                    description: "Add tests".into(),
                    expected_outcome: "Tests passing".into(),
                    complexity: StepComplexity::Moderate,
                },
            ],
            agent_assignments: [("1".into(), AgentRole::Worker)].into(),
            dependencies: vec![("1".into(), "2".into())],
            estimated_tokens: 30_000,
        };
        let json = serde_json::to_string(&plan).unwrap();
        assert!(json.contains("Add authentication"));
        assert!(json.contains("30000"));
    }

    /// Every complexity rating round-trips through JSON.
    #[test]
    fn test_step_complexity_variants() {
        for c in [
            StepComplexity::Simple,
            StepComplexity::Moderate,
            StepComplexity::Complex,
        ] {
            let json = serde_json::to_string(&c).unwrap();
            let parsed: StepComplexity = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, c);
        }
    }

    /// The default complexity is `Moderate`.
    #[test]
    fn test_step_complexity_default() {
        let c: StepComplexity = Default::default();
        assert_eq!(c, StepComplexity::Moderate);
    }

    // ---- AgentTree / CheckpointMeta ------------------------------------

    /// A tree with children serializes the root role and the children key.
    #[test]
    fn test_agent_tree_nested() {
        let tree = AgentTree {
            agent_id: AgentId::new(),
            role: AgentRole::Orchestrator,
            status: AgentStatus::Running,
            task_summary: Some("Managing".into()),
            children: vec![
                AgentTree {
                    agent_id: AgentId::new(),
                    role: AgentRole::Worker,
                    status: AgentStatus::Running,
                    task_summary: Some("Coding".into()),
                    children: vec![],
                },
                AgentTree {
                    agent_id: AgentId::new(),
                    role: AgentRole::Worker,
                    status: AgentStatus::Waiting {
                        reason: "Blocked".into(),
                    },
                    task_summary: Some("Testing".into()),
                    children: vec![],
                },
            ],
        };
        let json = serde_json::to_string(&tree).unwrap();
        assert!(json.contains("orchestrator"));
        assert!(json.contains("children"));
    }

    /// Name and byte size appear in the serialized checkpoint metadata.
    #[test]
    fn test_checkpoint_meta() {
        // `Utc` is already in scope through `use super::*`.
        let meta = CheckpointMeta {
            id: CheckpointId::new(),
            name: Some("Before refactor".into()),
            timestamp: Utc::now(),
            size_bytes: 1024 * 1024,
            task_id: Some(TaskId::new()),
            summary: "Checkpoint before major changes".into(),
        };
        let json = serde_json::to_string(&meta).unwrap();
        assert!(json.contains("Before refactor"));
        assert!(json.contains("1048576"));
    }

    // ---- TokenUsage ----------------------------------------------------

    /// Default usage is all zeros with no cost estimate.
    #[test]
    fn test_token_usage_default() {
        let usage: TokenUsage = Default::default();
        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 0);
        assert_eq!(usage.total_tokens, 0);
        assert!(usage.estimated_cost_usd.is_none());
    }

    /// Totals and cost appear in the serialized usage.
    #[test]
    fn test_token_usage_with_cost() {
        let usage = TokenUsage {
            input_tokens: 10_000,
            output_tokens: 5_000,
            total_tokens: 15_000,
            estimated_cost_usd: Some(0.15),
        };
        let json = serde_json::to_string(&usage).unwrap();
        assert!(json.contains("15000"));
        assert!(json.contains("0.15"));
    }

    // ---- MessageType / ImageAttachment ---------------------------------

    /// Every message type round-trips through JSON.
    #[test]
    fn test_message_type_variants() {
        for t in [
            MessageType::Text,
            MessageType::Thinking,
            MessageType::Code,
            MessageType::Error,
            MessageType::Status,
            MessageType::Progress,
        ] {
            let json = serde_json::to_string(&t).unwrap();
            let parsed: MessageType = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, t);
        }
    }

    /// The default message type is `Text`.
    #[test]
    fn test_message_type_default() {
        let t: MessageType = Default::default();
        assert_eq!(t, MessageType::Text);
    }

    /// MIME type and file name appear in the serialized attachment.
    #[test]
    fn test_image_attachment() {
        let attachment = ImageAttachment {
            data: "iVBORw0KGgo=".into(),
            mime_type: "image/png".into(),
            filename: Some("screenshot.png".into()),
        };
        let json = serde_json::to_string(&attachment).unwrap();
        assert!(json.contains("image/png"));
        assert!(json.contains("screenshot.png"));
    }

    // ---- SessionSettings -----------------------------------------------

    /// Default settings hide the rate limit and use automatic granularity.
    #[test]
    fn test_session_settings_defaults() {
        let settings: SessionSettings = Default::default();
        assert!(!settings.show_rate_limit);
        assert!(settings.subagent_concurrency.is_none());
        assert_eq!(settings.plan_granularity, PlanGranularity::Auto);
    }

    /// Explicit settings survive a JSON round-trip.
    #[test]
    fn test_session_settings_custom() {
        let settings = SessionSettings {
            show_rate_limit: true,
            subagent_concurrency: Some(8),
            plan_granularity: PlanGranularity::Detailed,
        };
        let json = serde_json::to_string(&settings).unwrap();
        let parsed: SessionSettings = serde_json::from_str(&json).unwrap();
        assert!(parsed.show_rate_limit);
        assert_eq!(parsed.subagent_concurrency, Some(8));
    }
}