use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use serde_json::{json, Value};
use tokio::sync::{mpsc, Mutex};
use crate::agent::Agent;
use crate::channels::{ChannelHub, SessionMap};
use crate::config::{IterationLimitConfig, ModelsConfig, ProviderKind};
use crate::events::EventStore;
use crate::llm_runtime::{router_from_models, SharedLlmRuntime};
use crate::memory::embeddings::EmbeddingService;
use crate::providers::ProviderError;
use crate::state::SqliteStateStore;
use crate::tools::command_risk::{PermissionMode, RiskLevel};
use crate::tools::memory::RememberFactTool;
use crate::tools::{SystemInfoTool, TerminalTool};
use crate::traits::{
Channel, ChannelCapabilities, ChatOptions, ModelProvider, ProviderResponse, TokenUsage, Tool,
ToolCall, ToolRole,
};
use crate::types::{ApprovalResponse, MediaMessage};
/// Record of a single call that reached [`MockProvider`]'s `chat_with_options`.
///
/// One entry is pushed onto `MockProvider::call_log` per non-skipped call, holding
/// owned copies of the arguments the provider was invoked with.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct MockChatCall {
/// Model name the call was made with.
pub model: String,
/// Messages passed to the call, copied as-is.
pub messages: Vec<Value>,
/// Tool definitions passed to the call, copied as-is.
pub tools: Vec<Value>,
/// Chat options passed to the call (the default options when invoked through `chat`).
pub options: ChatOptions,
}
/// Scripted [`ModelProvider`] used by tests instead of a real LLM backend.
pub struct MockProvider {
// Responses handed out front-to-back; when empty, a fixed "Mock response" text is returned.
responses: Mutex<Vec<ProviderResponse>>,
// Delays consumed front-to-back, one per logged call, before the response is returned.
response_delays: Mutex<Vec<Duration>>,
/// Every call that was not skipped as a planning call, in arrival order.
pub call_log: Mutex<Vec<MockChatCall>>,
// When true, calls whose options differ from `ChatOptions::default()` fail with a 400 error.
reject_non_default_options: bool,
/// When true, calls whose message content mentions "task planner" or
/// "progress evaluator" get an empty response and are not logged.
pub skip_planning_calls: bool,
}
impl MockProvider {
pub fn new() -> Self {
Self {
responses: Mutex::new(Vec::new()),
response_delays: Mutex::new(Vec::new()),
call_log: Mutex::new(Vec::new()),
reject_non_default_options: false,
skip_planning_calls: true,
}
}
pub fn with_responses(responses: Vec<ProviderResponse>) -> Self {
Self {
responses: Mutex::new(responses),
response_delays: Mutex::new(Vec::new()),
call_log: Mutex::new(Vec::new()),
reject_non_default_options: false,
skip_planning_calls: true,
}
}
pub fn with_delayed_responses(
responses: Vec<ProviderResponse>,
response_delays: Vec<Duration>,
) -> Self {
Self {
responses: Mutex::new(responses),
response_delays: Mutex::new(response_delays),
call_log: Mutex::new(Vec::new()),
reject_non_default_options: false,
skip_planning_calls: true,
}
}
pub fn rejecting_non_default_options(mut self) -> Self {
self.reject_non_default_options = true;
self
}
pub fn text_response(text: &str) -> ProviderResponse {
ProviderResponse {
content: Some(text.to_string()),
tool_calls: vec![],
usage: Some(TokenUsage {
input_tokens: 10,
output_tokens: 5,
cached_input_tokens: None,
cache_creation_input_tokens: None,
model: "mock".to_string(),
}),
thinking: None,
response_note: None,
}
}
pub fn tool_call_response(tool_name: &str, args: &str) -> ProviderResponse {
ProviderResponse {
content: None,
tool_calls: vec![ToolCall {
id: format!("call_{}", uuid::Uuid::new_v4()),
name: tool_name.to_string(),
arguments: args.to_string(),
extra_content: None,
}],
usage: Some(TokenUsage {
input_tokens: 10,
output_tokens: 5,
cached_input_tokens: None,
cache_creation_input_tokens: None,
model: "mock".to_string(),
}),
thinking: None,
response_note: None,
}
}
pub async fn call_count(&self) -> usize {
self.call_log.lock().await.len()
}
}
#[async_trait]
impl ModelProvider for MockProvider {
    /// Same as [`chat_with_options`](Self::chat_with_options) with default options.
    async fn chat(
        &self,
        model: &str,
        messages: &[Value],
        tools: &[Value],
    ) -> anyhow::Result<ProviderResponse> {
        let default_options = ChatOptions::default();
        self.chat_with_options(model, messages, tools, &default_options)
            .await
    }

    /// Serves the next scripted response.
    ///
    /// Order of operations:
    /// 1. If planning calls are skipped and any message's string `content`
    ///    contains one of the planning markers, return an empty response
    ///    without logging the call.
    /// 2. Log the call.
    /// 3. If option rejection is enabled and `options` is not the default,
    ///    fail with a 400 provider error.
    /// 4. Sleep for the next queued delay, if any.
    /// 5. Return the next queued response, or a "Mock response" text when the
    ///    queue is empty.
    async fn chat_with_options(
        &self,
        model: &str,
        messages: &[Value],
        tools: &[Value],
        options: &ChatOptions,
    ) -> anyhow::Result<ProviderResponse> {
        const PLANNING_MARKERS: [&str; 2] = ["task planner", "progress evaluator"];

        let looks_like_planning = self.skip_planning_calls
            && messages
                .iter()
                .filter_map(|message| message.get("content")?.as_str())
                .any(|text| PLANNING_MARKERS.iter().any(|marker| text.contains(*marker)));
        if looks_like_planning {
            return Ok(ProviderResponse {
                content: None,
                tool_calls: vec![],
                usage: None,
                thinking: None,
                response_note: None,
            });
        }

        let record = MockChatCall {
            model: model.to_owned(),
            messages: messages.to_vec(),
            tools: tools.to_vec(),
            options: options.clone(),
        };
        self.call_log.lock().await.push(record);

        if self.reject_non_default_options && *options != ChatOptions::default() {
            return Err(ProviderError::from_status(400, "unsupported chat options").into());
        }

        // Take the delay inside its own scope so the lock is released before sleeping.
        let next_delay = {
            let mut pending = self.response_delays.lock().await;
            (!pending.is_empty()).then(|| pending.remove(0))
        };
        if let Some(pause) = next_delay {
            tokio::time::sleep(pause).await;
        }

        let mut queued = self.responses.lock().await;
        let reply = if queued.is_empty() {
            MockProvider::text_response("Mock response")
        } else {
            queued.remove(0)
        };
        Ok(reply)
    }

    /// Always reports a single model named "mock-model".
    async fn list_models(&self) -> anyhow::Result<Vec<String>> {
        Ok(vec![String::from("mock-model")])
    }
}
/// One text message captured by [`TestChannel`]'s `send_text`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct SentMessage {
/// Session id the message was sent to.
pub session_id: String,
/// Message text exactly as passed to `send_text`.
pub text: String,
}
/// In-memory [`Channel`] that records outgoing text instead of delivering it.
pub struct TestChannel {
/// Every message passed to `send_text`, in send order.
pub messages: Mutex<Vec<SentMessage>>,
/// Answer returned (cloned) by every `request_approval` call.
pub default_approval: Mutex<ApprovalResponse>,
}
impl TestChannel {
    /// Creates a channel with no recorded messages that answers approval
    /// requests with `ApprovalResponse::AllowOnce`.
    pub fn new() -> Self {
        let messages = Mutex::new(Vec::new());
        let default_approval = Mutex::new(ApprovalResponse::AllowOnce);
        Self {
            messages,
            default_approval,
        }
    }

    /// Returns the texts recorded for `session_id`, in send order.
    #[allow(dead_code)]
    pub async fn messages_for(&self, session_id: &str) -> Vec<String> {
        let sent = self.messages.lock().await;
        let mut texts = Vec::new();
        for entry in sent.iter() {
            if entry.session_id == session_id {
                texts.push(entry.text.clone());
            }
        }
        texts
    }

    /// Returns the total number of recorded messages across all sessions.
    #[allow(dead_code)]
    pub async fn message_count(&self) -> usize {
        let sent = self.messages.lock().await;
        sent.len()
    }
}
#[async_trait]
impl Channel for TestChannel {
    /// Channel name; always "test".
    fn name(&self) -> String {
        String::from("test")
    }

    /// Reports markdown support only, with a 4096-character message limit.
    fn capabilities(&self) -> ChannelCapabilities {
        ChannelCapabilities {
            max_message_len: 4096,
            markdown: true,
            inline_buttons: false,
            media: false,
        }
    }

    /// Records the message in `messages`; never fails.
    async fn send_text(&self, session_id: &str, text: &str) -> anyhow::Result<()> {
        let record = SentMessage {
            session_id: session_id.to_owned(),
            text: text.to_owned(),
        };
        let mut log = self.messages.lock().await;
        log.push(record);
        Ok(())
    }

    /// Media is accepted and discarded.
    async fn send_media(&self, _session_id: &str, _media: &MediaMessage) -> anyhow::Result<()> {
        Ok(())
    }

    /// Ignores every argument and answers with a clone of `default_approval`.
    async fn request_approval(
        &self,
        _session_id: &str,
        _command: &str,
        _risk_level: RiskLevel,
        _warnings: &[String],
        _permission_mode: PermissionMode,
    ) -> anyhow::Result<ApprovalResponse> {
        let verdict = self.default_approval.lock().await.clone();
        Ok(verdict)
    }
}
/// Everything a test needs after one of the `setup_test_agent*` helpers ran.
#[allow(dead_code)]
pub struct TestHarness {
/// Agent under test.
pub agent: Agent,
/// State store backing the agent (a SQLite store on a temporary file).
pub state: Arc<SqliteStateStore>,
/// The mock provider the agent's LLM runtime was built from; use it to inspect the call log.
pub provider: Arc<MockProvider>,
/// A fresh recording channel created alongside the agent.
pub channel: Arc<TestChannel>,
// Kept in the harness, presumably so the temporary database file outlives the agent.
_db_file: tempfile::NamedTempFile,
// Kept in the harness, presumably so the temporary skills directory outlives the agent.
_skills_dir: tempfile::TempDir,
}
/// Builds a [`TestHarness`] with the default tools only, `llm_call_timeout_secs`
/// left as `None`, and test executor mode enabled.
pub async fn setup_test_agent(provider: MockProvider) -> anyhow::Result<TestHarness> {
setup_test_agent_internal(provider, vec![], None, true, true).await
}
/// Like [`setup_test_agent`], but does NOT enable test executor mode on the agent.
#[allow(dead_code)]
pub async fn setup_test_agent_root(provider: MockProvider) -> anyhow::Result<TestHarness> {
setup_test_agent_internal(provider, vec![], None, false, true).await
}
/// Builds a [`TestHarness`] without test executor mode, registering `extra_tools`
/// after the default tools and forwarding `llm_call_timeout_secs` to the agent.
#[allow(dead_code)]
pub async fn setup_test_agent_root_with_extra_tools_and_llm_timeout(
provider: MockProvider,
extra_tools: Vec<Arc<dyn Tool>>,
llm_call_timeout_secs: Option<u64>,
) -> anyhow::Result<TestHarness> {
setup_test_agent_internal(provider, extra_tools, llm_call_timeout_secs, false, true).await
}
/// Builds a [`TestHarness`] without test executor mode whose tool list is exactly
/// `tools` (the default tools are not added), forwarding `llm_call_timeout_secs`
/// to the agent.
#[allow(dead_code)]
pub async fn setup_test_agent_root_with_only_tools_and_llm_timeout(
provider: MockProvider,
tools: Vec<Arc<dyn Tool>>,
llm_call_timeout_secs: Option<u64>,
) -> anyhow::Result<TestHarness> {
setup_test_agent_internal(provider, tools, llm_call_timeout_secs, false, false).await
}
/// Builds a [`TestHarness`] with test executor mode enabled, registering
/// `extra_tools` after the default tools and forwarding `llm_call_timeout_secs`
/// to the agent.
pub async fn setup_test_agent_with_extra_tools_and_llm_timeout(
provider: MockProvider,
extra_tools: Vec<Arc<dyn Tool>>,
llm_call_timeout_secs: Option<u64>,
) -> anyhow::Result<TestHarness> {
setup_test_agent_internal(provider, extra_tools, llm_call_timeout_secs, true, true).await
}
/// Shared builder behind the single-model `setup_test_agent*` helpers.
///
/// Creates a SQLite state store on a temporary file, an event store on the
/// same pool, a temporary skills directory, and an [`Agent`] whose every model
/// slot is "mock-model" and whose LLM runtime wraps `provider`.
///
/// # Parameters
/// - `provider`: mock provider that will answer all LLM calls.
/// - `extra_tools`: tools appended after the default tools (or the full tool
///   list when `include_default_tools` is false).
/// - `llm_call_timeout_secs`: forwarded unchanged to `Agent::new`.
/// - `use_test_executor_mode`: when true, `set_test_executor_mode` is called on
///   the agent before it is returned.
/// - `include_default_tools`: when true, `SystemInfoTool` and
///   `RememberFactTool` are registered first.
///
/// # Errors
/// Propagates failures from creating the temporary file/directory, the
/// embedding service, the state store, or the event store, and fails (instead
/// of panicking) when the temporary database path is not valid UTF-8.
///
/// The returned `channel` is created here but is not attached to the agent by
/// this function.
async fn setup_test_agent_internal(
    provider: MockProvider,
    extra_tools: Vec<Arc<dyn Tool>>,
    llm_call_timeout_secs: Option<u64>,
    use_test_executor_mode: bool,
    include_default_tools: bool,
) -> anyhow::Result<TestHarness> {
    let db_file = tempfile::NamedTempFile::new()?;
    // The store takes the path as a string; report a non-UTF-8 path as an error
    // rather than panicking inside a fallible setup function.
    let db_path = db_file
        .path()
        .to_str()
        .ok_or_else(|| anyhow::anyhow!("temporary database path is not valid UTF-8"))?
        .to_string();
    let skills_dir = tempfile::TempDir::new()?;
    let embedding_service = Arc::new(EmbeddingService::new()?);
    let state = Arc::new(SqliteStateStore::new(&db_path, 100, None, embedding_service).await?);
    let event_store = Arc::new(EventStore::new(state.pool()).await?);
    let provider = Arc::new(provider);

    let mut tools: Vec<Arc<dyn Tool>> = if include_default_tools {
        vec![
            Arc::new(SystemInfoTool),
            Arc::new(RememberFactTool::new(
                state.clone() as Arc<dyn crate::traits::StateStore>
            )),
        ]
    } else {
        Vec::new()
    };
    tools.extend(extra_tools);

    // Every routing slot points at the same mock model.
    let models_config = ModelsConfig {
        default_model: "mock-model".to_string(),
        fallback_models: Vec::new(),
        primary: "mock-model".to_string(),
        fast: "mock-model".to_string(),
        smart: "mock-model".to_string(),
    };
    let llm_runtime = SharedLlmRuntime::new(
        provider.clone() as Arc<dyn ModelProvider>,
        router_from_models(models_config.clone()),
        ProviderKind::OpenaiCompatible,
        models_config.primary.clone(),
    );
    let goal_token_registry = crate::goal_tokens::GoalTokenRegistry::new();

    // `Agent::new` is positional; the argument order below must match its signature exactly.
    let mut agent = Agent::new(
        llm_runtime,
        state.clone() as Arc<dyn crate::traits::StateStore>,
        event_store,
        tools,
        "mock-model".to_string(),
        "You are a helpful test assistant.".to_string(),
        PathBuf::from("config.toml"),
        skills_dir.path().to_path_buf(),
        3, 50, 100, 8000, 30, 20, None, IterationLimitConfig::Unlimited,
        None, None, llm_call_timeout_secs, None, Some(goal_token_registry), None, true,
        crate::config::ContextWindowConfig {
            progressive_facts: false,
            ..Default::default()
        },
        crate::config::PolicyConfig::default(),
        crate::config::PathAliasConfig::default(),
        None,
        Arc::new(crate::agent::specialists::SpecialistRegistry::load(None)),
        None,
        crate::config::VisionConfig::from_files(&crate::config::FilesConfig::default()),
        crate::config::AudioConfig::from_files(&crate::config::FilesConfig::default()),
        crate::config::SttConfig::from_files(&crate::config::FilesConfig::default()),
        (&crate::config::DiagnosticsHarnessEvalConfig::default()).into(),
    );
    if use_test_executor_mode {
        agent.set_test_executor_mode();
    }

    let channel = Arc::new(TestChannel::new());
    Ok(TestHarness {
        agent,
        state,
        provider,
        channel,
        _db_file: db_file,
        _skills_dir: skills_dir,
    })
}
#[allow(dead_code)]
pub async fn setup_test_agent_with_models(
provider: MockProvider,
primary_model: &str,
smart_model: &str,
) -> anyhow::Result<TestHarness> {
let db_file = tempfile::NamedTempFile::new()?;
let db_path = db_file.path().to_str().unwrap().to_string();
let skills_dir = tempfile::TempDir::new()?;
let embedding_service = Arc::new(EmbeddingService::new()?);
let state = Arc::new(SqliteStateStore::new(&db_path, 100, None, embedding_service).await?);
let event_store = Arc::new(EventStore::new(state.pool()).await?);
let provider = Arc::new(provider);
let tools: Vec<Arc<dyn Tool>> = vec![
Arc::new(SystemInfoTool),
Arc::new(RememberFactTool::new(
state.clone() as Arc<dyn crate::traits::StateStore>
)),
];
let models_config = ModelsConfig {
default_model: primary_model.to_string(),
fallback_models: vec![smart_model.to_string()],
primary: primary_model.to_string(),
fast: smart_model.to_string(),
smart: smart_model.to_string(),
};
let llm_runtime = SharedLlmRuntime::new(
provider.clone() as Arc<dyn ModelProvider>,
router_from_models(models_config.clone()),
ProviderKind::OpenaiCompatible,
models_config.primary.clone(),
);
let goal_token_registry = crate::goal_tokens::GoalTokenRegistry::new();
let agent = Agent::new(
llm_runtime,
state.clone() as Arc<dyn crate::traits::StateStore>,
event_store,
tools,
primary_model.to_string(),
"You are a helpful test assistant.".to_string(),
PathBuf::from("config.toml"),
skills_dir.path().to_path_buf(),
3, 50, 100, 8000, 30, 20, None, IterationLimitConfig::Unlimited,
None, None, None, None, Some(goal_token_registry), None, true, crate::config::ContextWindowConfig {
progressive_facts: false,
..Default::default()
},
crate::config::PolicyConfig::default(),
crate::config::PathAliasConfig::default(),
None,
Arc::new(crate::agent::specialists::SpecialistRegistry::load(None)),
None, crate::config::VisionConfig::from_files(&crate::config::FilesConfig::default()),
crate::config::AudioConfig::from_files(&crate::config::FilesConfig::default()),
crate::config::SttConfig::from_files(&crate::config::FilesConfig::default()),
(&crate::config::DiagnosticsHarnessEvalConfig::default()).into(),
);
let channel = Arc::new(TestChannel::new());
Ok(TestHarness {
agent,
state,
provider,
channel,
_db_file: db_file,
_skills_dir: skills_dir,
})
}
#[allow(dead_code)]
pub async fn setup_test_agent_orchestrator(provider: MockProvider) -> anyhow::Result<TestHarness> {
let db_file = tempfile::NamedTempFile::new()?;
let db_path = db_file.path().to_str().unwrap().to_string();
let skills_dir = tempfile::TempDir::new()?;
let embedding_service = Arc::new(EmbeddingService::new()?);
let state = Arc::new(SqliteStateStore::new(&db_path, 100, None, embedding_service).await?);
let event_store = Arc::new(EventStore::new(state.pool()).await?);
let provider = Arc::new(provider);
let tools: Vec<Arc<dyn Tool>> = vec![
Arc::new(SystemInfoTool),
Arc::new(RememberFactTool::new(
state.clone() as Arc<dyn crate::traits::StateStore>
)),
];
let models_config = ModelsConfig {
default_model: "primary-model".to_string(),
fallback_models: vec!["fast-model".to_string(), "smart-model".to_string()],
primary: "primary-model".to_string(),
fast: "fast-model".to_string(),
smart: "smart-model".to_string(),
};
let llm_runtime = SharedLlmRuntime::new(
provider.clone() as Arc<dyn ModelProvider>,
router_from_models(models_config.clone()),
ProviderKind::OpenaiCompatible,
models_config.primary.clone(),
);
let goal_token_registry = crate::goal_tokens::GoalTokenRegistry::new();
let agent = Agent::new(
llm_runtime,
state.clone() as Arc<dyn crate::traits::StateStore>,
event_store,
tools,
"primary-model".to_string(),
"You are a helpful test assistant.".to_string(),
PathBuf::from("config.toml"),
skills_dir.path().to_path_buf(),
3, 50, 100, 8000, 30, 20, None, IterationLimitConfig::Unlimited,
None, None, None, None, Some(goal_token_registry), None, true, crate::config::ContextWindowConfig {
progressive_facts: false,
..Default::default()
},
crate::config::PolicyConfig::default(),
crate::config::PathAliasConfig::default(),
None,
Arc::new(crate::agent::specialists::SpecialistRegistry::load(None)),
None, crate::config::VisionConfig::from_files(&crate::config::FilesConfig::default()),
crate::config::AudioConfig::from_files(&crate::config::FilesConfig::default()),
crate::config::SttConfig::from_files(&crate::config::FilesConfig::default()),
(&crate::config::DiagnosticsHarnessEvalConfig::default()).into(),
);
let channel = Arc::new(TestChannel::new());
Ok(TestHarness {
agent,
state,
provider,
channel,
_db_file: db_file,
_skills_dir: skills_dir,
})
}
/// Alias for [`setup_test_agent_orchestrator`]; it forwards `provider` unchanged
/// and currently adds no configuration of its own.
#[allow(dead_code)]
pub async fn setup_test_agent_orchestrator_task_leads(
provider: MockProvider,
) -> anyhow::Result<TestHarness> {
setup_test_agent_orchestrator(provider).await
}
/// Configurable [`Tool`] stub that returns a fixed string from every call.
#[allow(dead_code)]
pub struct MockTool {
// Name reported by `name()` and embedded in the schema.
tool_name: String,
// Description reported by `description()` and embedded in the schema.
tool_description: String,
// String returned by `call`, regardless of the arguments.
return_value: String,
// Role reported by `tool_role()`; `ToolRole::Action` unless overridden.
role: ToolRole,
// Value reported by `is_available()`; true unless overridden.
available: bool,
}
#[allow(dead_code)]
impl MockTool {
    /// Creates an available tool with role `ToolRole::Action` that reports the
    /// given name and description and always returns `return_value`.
    pub fn new(name: &str, description: &str, return_value: &str) -> Self {
        Self {
            tool_name: name.to_owned(),
            tool_description: description.to_owned(),
            return_value: return_value.to_owned(),
            role: ToolRole::Action,
            available: true,
        }
    }

    /// Builder-style override of the role reported by `tool_role()`.
    pub fn with_role(self, role: ToolRole) -> Self {
        Self { role, ..self }
    }

    /// Builder-style override of the value reported by `is_available()`.
    pub fn with_availability(self, available: bool) -> Self {
        Self { available, ..self }
    }
}
#[async_trait]
impl Tool for MockTool {
    /// Configured tool name.
    fn name(&self) -> &str {
        self.tool_name.as_str()
    }

    /// Configured tool description.
    fn description(&self) -> &str {
        self.tool_description.as_str()
    }

    /// Schema with the configured name and description and an object
    /// parameter block that declares no properties and forbids extras.
    fn schema(&self) -> Value {
        let parameters = json!({
            "type": "object",
            "properties": {},
            "additionalProperties": false
        });
        json!({
            "name": self.tool_name,
            "description": self.tool_description,
            "parameters": parameters
        })
    }

    /// Ignores the arguments and returns a copy of the configured return value.
    async fn call(&self, _args: &str) -> anyhow::Result<String> {
        let output = self.return_value.clone();
        Ok(output)
    }

    /// Configured role.
    fn tool_role(&self) -> ToolRole {
        self.role
    }

    /// Configured availability flag.
    fn is_available(&self) -> bool {
        self.available
    }
}
/// Harness returned by the `setup_full_stack_test_agent*` helpers: the agent plus
/// the channel hub, session map, and background approval listener built with it.
#[allow(dead_code)]
pub struct FullStackTestHarness {
/// Agent under test.
pub agent: Agent,
/// State store backing the agent (a SQLite store on a temporary file).
pub state: Arc<SqliteStateStore>,
/// The mock provider the agent's LLM runtime was built from.
pub provider: Arc<MockProvider>,
/// Recording channel registered with `hub`.
pub channel: Arc<TestChannel>,
/// Hub constructed over `channel` and `session_map`.
pub hub: Arc<ChannelHub>,
/// Session map shared with `hub`; seeded with "telegram_test" -> "test".
pub session_map: SessionMap,
// Kept in the harness, presumably so the temporary database file outlives the agent.
_db_file: tempfile::NamedTempFile,
// Kept in the harness, presumably so the temporary skills directory outlives the agent.
_skills_dir: tempfile::TempDir,
// Handle of the spawned task running the hub's approval listener.
_approval_task: tokio::task::JoinHandle<()>,
}
/// Builds a [`FullStackTestHarness`] with no extra tools.
#[allow(dead_code)]
pub async fn setup_full_stack_test_agent(
provider: MockProvider,
) -> anyhow::Result<FullStackTestHarness> {
setup_full_stack_test_agent_with_extra_tools(provider, vec![]).await
}
/// Builds a [`FullStackTestHarness`]: an agent with a terminal tool wired to an
/// approval broker, plus a channel hub that services approval requests.
///
/// The tool list is `SystemInfoTool`, `RememberFactTool`, a `TerminalTool`
/// constructed with `PermissionMode::Yolo`, then `extra_tools`. Every model
/// slot is "mock-model". Test executor mode is enabled on the agent. The hub is
/// built over a single [`TestChannel`] and a session map seeded with
/// "telegram_test" -> "test", and the hub's approval listener is spawned as a
/// background task fed by the terminal tool's approval channel.
///
/// # Errors
/// Propagates failures from creating the temporary file/directory, the
/// embedding service, the state store, or the event store, and fails (instead
/// of panicking) when the temporary database path is not valid UTF-8.
#[allow(dead_code)]
pub async fn setup_full_stack_test_agent_with_extra_tools(
    provider: MockProvider,
    extra_tools: Vec<Arc<dyn Tool>>,
) -> anyhow::Result<FullStackTestHarness> {
    let db_file = tempfile::NamedTempFile::new()?;
    // Report a non-UTF-8 temp path as an error rather than panicking.
    let db_path = db_file
        .path()
        .to_str()
        .ok_or_else(|| anyhow::anyhow!("temporary database path is not valid UTF-8"))?
        .to_string();
    let skills_dir = tempfile::TempDir::new()?;
    let embedding_service = Arc::new(EmbeddingService::new()?);
    let state = Arc::new(SqliteStateStore::new(&db_path, 100, None, embedding_service).await?);
    let pool = state.pool();
    let event_store = Arc::new(EventStore::new(pool.clone()).await?);

    // Approval requests from the terminal tool flow through this channel to the hub listener.
    let (approval_tx, approval_rx) = mpsc::channel(16);
    let approval_tx = crate::tools::ApprovalBroker::new(approval_tx);
    let terminal_tool = Arc::new(
        TerminalTool::new(
            vec!["*".to_string()],
            approval_tx,
            30,
            8000,
            PermissionMode::Yolo,
            pool.clone(),
        )
        .await,
    );

    let mut tools: Vec<Arc<dyn Tool>> = vec![
        Arc::new(SystemInfoTool),
        Arc::new(RememberFactTool::new(
            state.clone() as Arc<dyn crate::traits::StateStore>
        )),
        terminal_tool,
    ];
    tools.extend(extra_tools);

    let provider = Arc::new(provider);
    let models_config = ModelsConfig {
        default_model: "mock-model".to_string(),
        fallback_models: Vec::new(),
        primary: "mock-model".to_string(),
        fast: "mock-model".to_string(),
        smart: "mock-model".to_string(),
    };
    let llm_runtime = SharedLlmRuntime::new(
        provider.clone() as Arc<dyn ModelProvider>,
        router_from_models(models_config.clone()),
        ProviderKind::OpenaiCompatible,
        models_config.primary.clone(),
    );
    let goal_token_registry = crate::goal_tokens::GoalTokenRegistry::new();

    // `Agent::new` is positional; the argument order below must match its signature exactly.
    let mut agent = Agent::new(
        llm_runtime,
        state.clone() as Arc<dyn crate::traits::StateStore>,
        event_store,
        tools,
        "mock-model".to_string(),
        "You are a helpful test assistant.".to_string(),
        PathBuf::from("config.toml"),
        skills_dir.path().to_path_buf(),
        3, 50, 100, 8000, 30, 20, None, IterationLimitConfig::Unlimited,
        None, None, None, None, Some(goal_token_registry), None, true,
        crate::config::ContextWindowConfig {
            progressive_facts: false,
            ..Default::default()
        },
        crate::config::PolicyConfig::default(),
        crate::config::PathAliasConfig::default(),
        None,
        Arc::new(crate::agent::specialists::SpecialistRegistry::load(None)),
        None,
        crate::config::VisionConfig::from_files(&crate::config::FilesConfig::default()),
        crate::config::AudioConfig::from_files(&crate::config::FilesConfig::default()),
        crate::config::SttConfig::from_files(&crate::config::FilesConfig::default()),
        (&crate::config::DiagnosticsHarnessEvalConfig::default()).into(),
    );
    agent.set_test_executor_mode();

    let channel = Arc::new(TestChannel::new());
    let mut map = HashMap::new();
    map.insert("telegram_test".to_string(), "test".to_string());
    let session_map: SessionMap = Arc::new(tokio::sync::RwLock::new(map));
    let hub = Arc::new(ChannelHub::new(
        vec![channel.clone() as Arc<dyn Channel>],
        session_map.clone(),
    ));

    let hub_for_approvals = hub.clone();
    let approval_task = tokio::spawn(async move {
        hub_for_approvals.approval_listener(approval_rx).await;
    });

    Ok(FullStackTestHarness {
        agent,
        state,
        provider,
        channel,
        hub,
        session_map,
        _db_file: db_file,
        _skills_dir: skills_dir,
        _approval_task: approval_task,
    })
}