mod helpers;
use super::events::{AgentEvent, AgentResult};
use super::state::AgentMetrics;
use super::state_formatter::format_todo_summary;
use super::{AgentConfig, AgentState};
use crate::hooks::{HookContext, HookEvent, HookInput, HookManager, HookOutput};
use crate::session::types::TodoItem;
use crate::session::{Session, SessionConfig};
use crate::tools::{ExecutionContext, ToolOutput, ToolRegistry, ToolResult};
use crate::types::{StopReason, ToolResultBlock, Usage};
use std::sync::Arc;
use std::sync::atomic::Ordering;
#[test]
fn test_agent_result() {
let metrics = AgentMetrics {
iterations: 3,
tool_calls: 2,
..Default::default()
};
let result = AgentResult {
text: "Hello".to_string(),
usage: Usage {
input_tokens: 100,
output_tokens: 50,
..Default::default()
},
tool_calls: 2,
iterations: 3,
stop_reason: StopReason::EndTurn,
state: AgentState::Completed,
metrics,
session_id: "test-session".to_string(),
structured_output: None,
messages: Vec::new(),
uuid: "test-uuid".to_string(),
};
assert_eq!(result.text(), "Hello");
assert_eq!(result.total_tokens(), 150);
assert!(result.state.is_terminal());
assert_eq!(result.metrics().iterations, 3);
}
#[test]
fn test_agent_result_session_id() {
let result = AgentResult {
text: String::new(),
usage: Usage::default(),
tool_calls: 0,
iterations: 1,
stop_reason: StopReason::EndTurn,
state: AgentState::Completed,
metrics: AgentMetrics::default(),
session_id: "my-session-123".to_string(),
structured_output: None,
messages: Vec::new(),
uuid: "test-uuid".to_string(),
};
assert_eq!(result.session_id(), "my-session-123");
}
#[test]
fn test_agent_result_extract_success() {
#[derive(serde::Deserialize, PartialEq, Debug)]
struct TestOutput {
value: i32,
}
let result = AgentResult {
text: String::new(),
usage: Usage::default(),
tool_calls: 0,
iterations: 1,
stop_reason: StopReason::EndTurn,
state: AgentState::Completed,
metrics: AgentMetrics::default(),
session_id: "test".to_string(),
structured_output: Some(serde_json::json!({"value": 42})),
messages: Vec::new(),
uuid: "test-uuid".to_string(),
};
let extracted: TestOutput = result.extract().unwrap();
assert_eq!(extracted.value, 42);
}
#[test]
fn test_agent_result_extract_no_output() {
let result = AgentResult {
text: String::new(),
usage: Usage::default(),
tool_calls: 0,
iterations: 1,
stop_reason: StopReason::EndTurn,
state: AgentState::Completed,
metrics: AgentMetrics::default(),
session_id: "test".to_string(),
structured_output: None,
messages: Vec::new(),
uuid: "test-uuid".to_string(),
};
let extracted: Result<serde_json::Value, _> = result.extract();
assert!(extracted.is_err());
}
#[test]
fn test_agent_event_variants() {
let text_event = AgentEvent::Text("Hello".to_string());
assert!(matches!(text_event, AgentEvent::Text(_)));
let tool_complete = AgentEvent::ToolComplete {
id: "tool_1".to_string(),
name: "Read".to_string(),
output: "file content".to_string(),
is_error: false,
duration_ms: 50,
};
assert!(matches!(
tool_complete,
AgentEvent::ToolComplete {
is_error: false,
..
}
));
let tool_blocked = AgentEvent::ToolBlocked {
id: "tool_2".to_string(),
name: "Bash".to_string(),
reason: "Permission denied".to_string(),
};
assert!(matches!(tool_blocked, AgentEvent::ToolBlocked { .. }));
}
#[test]
fn test_stop_reason_variants() {
assert_eq!(StopReason::EndTurn, StopReason::EndTurn);
assert_ne!(StopReason::EndTurn, StopReason::MaxTokens);
assert_ne!(StopReason::EndTurn, StopReason::ToolUse);
}
#[test]
fn test_tool_result_block_from_tool_result() {
let result = ToolResult::success("content");
let block = ToolResultBlock::from_tool_result("tool_123", &result);
assert_eq!(block.tool_use_id, "tool_123");
assert!(!block.is_error.unwrap_or(false));
}
#[test]
fn test_session_messages_basic() {
let mut session = Session::new(SessionConfig::default());
assert!(session.messages.is_empty());
session.add_user_message("Hello");
assert_eq!(session.messages.len(), 1);
}
#[test]
fn test_session_usage_update() {
let mut session = Session::new(SessionConfig::default());
session.add_user_message("Test");
session.update_usage(&Usage {
input_tokens: 100,
output_tokens: 50,
cache_read_input_tokens: Some(10),
cache_creation_input_tokens: None,
..Default::default()
});
assert_eq!(session.current_input_tokens, 110);
assert_eq!(session.total_usage.input_tokens, 100);
assert_eq!(session.total_usage.output_tokens, 50);
}
#[test]
fn test_agent_metrics_recording() {
let mut metrics = AgentMetrics {
iterations: 5,
..Default::default()
};
metrics.record_api_call();
metrics.record_api_call();
assert_eq!(metrics.api_calls, 2);
metrics.record_tool("tu_1", "Read", 50, false);
metrics.record_tool("tu_2", "Read", 30, false);
metrics.record_tool("tu_3", "Bash", 100, true);
assert_eq!(metrics.tool_calls, 3);
assert_eq!(metrics.errors, 1);
}
#[test]
fn test_agent_state_transitions() {
assert!(AgentState::Initializing.can_continue());
assert!(AgentState::Running.can_continue());
assert!(!AgentState::Completed.can_continue());
assert!(!AgentState::Failed.can_continue());
assert!(AgentState::WaitingForToolResults.is_waiting());
assert!(AgentState::WaitingForUserInput.is_waiting());
assert!(!AgentState::Running.is_waiting());
assert!(AgentState::Completed.is_terminal());
assert!(AgentState::Failed.is_terminal());
assert!(!AgentState::Running.is_terminal());
}
#[test]
fn test_hook_context_builder() {
let hook_context = HookContext::new("session-1")
.cwd(std::path::PathBuf::from("/test/dir"))
.env([("KEY".to_string(), "VALUE".to_string())].into());
assert_eq!(hook_context.session_id, "session-1");
assert_eq!(
hook_context.cwd,
Some(std::path::PathBuf::from("/test/dir"))
);
assert_eq!(hook_context.env.get("KEY"), Some(&"VALUE".to_string()));
}
#[test]
fn test_hook_output_builder() {
let allow = HookOutput::allow();
assert!(allow.continue_execution);
let block = HookOutput::block("reason");
assert!(!block.continue_execution);
assert_eq!(block.stop_reason, Some("reason".to_string()));
}
#[test]
fn test_hook_event_can_block() {
assert!(HookEvent::PreToolUse.can_block());
assert!(HookEvent::UserPromptSubmit.can_block());
assert!(HookEvent::SessionStart.can_block());
assert!(!HookEvent::PreCompact.can_block());
assert!(HookEvent::SubagentStart.can_block());
assert!(!HookEvent::PostToolUse.can_block());
assert!(!HookEvent::SessionEnd.can_block());
}
#[test]
fn test_tool_result_variants() {
let success = ToolResult::success("content");
assert!(!success.is_error());
assert_eq!(success.text(), "content");
let error = ToolResult::error("failed");
assert!(error.is_error());
assert!(error.error_message().contains("failed"));
let empty = ToolResult::empty();
assert!(!empty.is_error());
assert_eq!(empty.text(), "");
}
#[test]
fn test_agent_config_default_values() {
let config = AgentConfig::default();
assert_eq!(config.execution.max_iterations, 100);
assert!(config.execution.auto_compact);
assert!(config.execution.timeout.is_some());
}
#[test]
fn test_usage_accumulation() {
let mut usage = Usage::default();
assert_eq!(usage.total(), 0);
usage.input_tokens = 100;
usage.output_tokens = 50;
assert_eq!(usage.total(), 150);
}
#[test]
fn test_format_todo_summary_empty() {
let todos: Vec<TodoItem> = vec![];
let summary = format_todo_summary(&todos);
assert!(summary.is_empty());
}
#[test]
fn test_format_todo_summary_with_items() {
use crate::session::SessionId;
let session_id = SessionId::new();
let mut todo1 = TodoItem::new(session_id, "Fix bug", "Fixing bug");
todo1.start();
let todo2 = TodoItem::new(session_id, "Write tests", "Writing tests");
let mut todo3 = TodoItem::new(session_id, "Deploy", "Deploying");
todo3.complete();
let todos = vec![todo1, todo2, todo3];
let summary = format_todo_summary(&todos);
assert!(summary.contains("1."));
assert!(summary.contains("Fix bug"));
assert!(summary.contains("Write tests"));
}
#[tokio::test]
async fn test_hook_manager_integration() {
use helpers::TestTrackingHook;
let mut hooks = HookManager::new();
let hook = TestTrackingHook::new("test-hook", vec![HookEvent::PreToolUse]);
let call_count = hook.call_count.clone();
hooks.register(hook);
let input = HookInput::pre_tool_use("session", "Read", serde_json::json!({}));
let hook_context = HookContext::new("session");
let output = hooks
.execute(HookEvent::PreToolUse, input, &hook_context)
.await
.unwrap();
assert!(output.continue_execution);
assert_eq!(call_count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn test_hook_blocking() {
use helpers::BlockingHook;
let mut hooks = HookManager::new();
hooks.register(BlockingHook {
reason: "blocked".to_string(),
});
let input = HookInput::user_prompt_submit("session", "test");
let hook_context = HookContext::new("session");
let output = hooks
.execute(HookEvent::UserPromptSubmit, input, &hook_context)
.await
.unwrap();
assert!(!output.continue_execution);
assert_eq!(output.stop_reason, Some("blocked".to_string()));
}
#[tokio::test]
async fn test_hook_input_modification() {
use helpers::InputModifyingHook;
let mut hooks = HookManager::new();
hooks.register(InputModifyingHook);
let input = HookInput::pre_tool_use(
"session",
"Read",
serde_json::json!({"file_path": "/original/path"}),
);
let hook_context = HookContext::new("session");
let output = hooks
.execute(HookEvent::PreToolUse, input, &hook_context)
.await
.unwrap();
assert!(output.continue_execution);
assert!(output.updated_input.is_some());
let updated = output.updated_input.unwrap();
assert_eq!(updated["file_path"], "/modified/path");
}
#[test]
fn test_tool_registry_with_dummy() {
use helpers::DummyTool;
let mut registry = ToolRegistry::new();
let tool = Arc::new(DummyTool {
name: "TestTool".to_string(),
output: ToolOutput::Success("success".to_string()),
});
registry.register(tool);
assert!(registry.contains("TestTool"));
assert_eq!(registry.names().len(), 1);
}
#[tokio::test]
async fn test_tool_registry_execute() {
use helpers::DummyTool;
let mut registry = ToolRegistry::from_context(ExecutionContext::permissive());
let tool = Arc::new(DummyTool {
name: "TestTool".to_string(),
output: ToolOutput::Success("test output".to_string()),
});
registry.register(tool);
let result = registry.execute("TestTool", serde_json::json!({})).await;
assert!(!result.is_error());
assert_eq!(result.text(), "test output");
}
#[tokio::test]
async fn test_tool_registry_execute_unknown() {
let registry = ToolRegistry::new();
let result = registry.execute("UnknownTool", serde_json::json!({})).await;
assert!(result.is_error());
assert!(result.error_message().contains("unknown tool"));
}