use axum::{routing::get, Router};
use axum_test::TestServer;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;
/// Builds an address of the form `<prefix>+<uuid-v4>@test.example.com`.
///
/// The random v4 UUID suffix makes every call distinct, so tests that register
/// users do not trip over an address left behind by another test or a previous run.
fn unique_email(prefix: &str) -> String {
    let nonce = Uuid::new_v4();
    format!("{prefix}+{nonce}@test.example.com")
}
use ares::{
auth::jwt::AuthService,
db::PostgresClient,
llm::LLMClient,
types::{ToolCall, ToolDefinition},
utils::toml_config::{
AgentConfig, AresConfig, AuthConfig as TomlAuthConfig,
DatabaseConfig as TomlDatabaseConfig, DynamicConfigPaths, ModelConfig, ProviderConfig,
RagConfig, ServerConfig as TomlServerConfig,
},
AgentRegistry, AppState, AresConfigManager, ConfigBasedLLMFactory, DynamicConfigManager,
ProviderRegistry, ToolRegistry,
};
use futures::StreamExt;
mod common;
use common::mocks::MockLLMClient;
#[allow(unused_imports)]
use common::mocks::MockLLMFactory;
async fn create_test_app() -> Router {
let db = common::test_db::create_test_db().await;
let auth_service = AuthService::new(
"test_jwt_secret_key_for_testing_only".to_string(),
900, 604800, );
let mut providers = HashMap::new();
providers.insert(
"ollama-local".to_string(),
ProviderConfig::Ollama {
base_url: "http://localhost:11434".to_string(),
default_model: "ministral-3:3b".to_string(),
},
);
let mut models = HashMap::new();
models.insert(
"default".to_string(),
ModelConfig {
provider: "ollama-local".to_string(),
model: "ministral-3:3b".to_string(),
temperature: 0.7,
max_tokens: 512,
top_p: None,
frequency_penalty: None,
presence_penalty: None,
},
);
let mut agents = HashMap::new();
agents.insert(
"router".to_string(),
AgentConfig {
model: "default".to_string(),
system_prompt: Some("You are a routing agent.".to_string()),
tools: vec![],
max_tool_iterations: 10,
parallel_tools: false,
extra: HashMap::new(),
},
);
agents.insert(
"product".to_string(),
AgentConfig {
model: "default".to_string(),
system_prompt: Some("You are a product support agent.".to_string()),
tools: vec![],
max_tool_iterations: 10,
parallel_tools: false,
extra: HashMap::new(),
},
);
let ares_config = AresConfig {
server: TomlServerConfig {
host: "127.0.0.1".to_string(),
port: 3000,
log_level: "debug".to_string(),
cors_origins: vec!["*".to_string()],
rate_limit_per_second: 0, rate_limit_burst: 0,
},
auth: TomlAuthConfig {
jwt_secret_env: "TEST_JWT_SECRET".to_string(),
jwt_access_expiry: 900,
jwt_refresh_expiry: 604800,
api_key_env: "TEST_API_KEY".to_string(),
},
database: TomlDatabaseConfig {
url: "postgres://postgres:postgres@localhost:5432/ares_test".to_string(),
qdrant: None,
},
config: DynamicConfigPaths::default(),
providers,
models,
tools: HashMap::new(),
agents,
workflows: HashMap::new(),
rag: RagConfig::default(),
#[cfg(feature = "skills")]
skills: None,
};
let config_manager = Arc::new(AresConfigManager::from_config(ares_config));
let provider_registry = Arc::new(ProviderRegistry::from_config(&config_manager.config()));
let llm_factory = Arc::new(ConfigBasedLLMFactory::new(
provider_registry.clone(),
"default",
));
let tool_registry = Arc::new(ToolRegistry::with_config(&config_manager.config()));
let agent_registry = Arc::new(AgentRegistry::from_config(
&config_manager.config(),
provider_registry.clone(),
tool_registry.clone(),
));
let temp_dir = tempfile::TempDir::new().expect("Failed to create temp dir");
let base = temp_dir.path();
std::fs::create_dir_all(base.join("agents")).unwrap();
std::fs::create_dir_all(base.join("models")).unwrap();
std::fs::create_dir_all(base.join("tools")).unwrap();
std::fs::create_dir_all(base.join("workflows")).unwrap();
std::fs::create_dir_all(base.join("mcps")).unwrap();
let dynamic_config = Arc::new(
DynamicConfigManager::new(
base.join("agents"),
base.join("models"),
base.join("tools"),
base.join("workflows"),
base.join("mcps"),
false, )
.expect("Failed to create DynamicConfigManager"),
);
let db = Arc::new(db);
let state = AppState {
config_manager,
db: db.clone(),
tenant_db: Arc::new(ares::db::TenantDb::new(db)),
llm_factory,
provider_registry,
agent_registry,
tool_registry,
auth_service: Arc::new(auth_service),
dynamic_config,
deploy_registry: ares::api::handlers::deploy::DeployRegistry::default(),
emergency_stop: Arc::new(std::sync::atomic::AtomicBool::new(false)),
context_provider: Arc::new(ares::agents::context_provider::NoOpContextProvider),
#[cfg(feature = "mcp")]
mcp_registry: None,
};
Router::new()
.route("/health", get(|| async { "OK" }))
.nest(
"/api",
ares::api::routes::create_router(state.auth_service.clone(), state.tenant_db.clone()),
)
.with_state(state)
}
/// Wraps the router from `create_test_app` in an in-process `axum_test::TestServer`.
///
/// # Panics
/// Panics if the test server cannot be constructed.
async fn create_test_server() -> TestServer {
    let router = create_test_app().await;
    TestServer::new(router).expect("Failed to create test server")
}
#[tokio::test]
async fn test_health_check() {
    let test_server = create_test_server().await;

    // Liveness probe: 200 with the literal body "OK".
    let health = test_server.get("/health").await;
    health.assert_status_ok();
    health.assert_text("OK");
}
#[tokio::test]
async fn test_health_check_multiple_times() {
    let test_server = create_test_server().await;

    // Repeated probes against the same server instance must all succeed.
    for _attempt in 1..=5 {
        let health = test_server.get("/health").await;
        health.assert_status_ok();
    }
}
#[tokio::test]
async fn test_register_user() {
    let test_server = create_test_server().await;
    let payload = json!({
        "email": unique_email("register"),
        "password": "password123",
        "name": "Test User"
    });

    let registered = test_server.post("/api/auth/register").json(&payload).await;
    registered.assert_status_ok();

    // A successful registration immediately returns a token pair plus an expiry.
    let tokens: serde_json::Value = registered.json();
    assert!(tokens["access_token"].is_string());
    assert!(tokens["refresh_token"].is_string());
    assert!(tokens["expires_in"].is_number());
}
#[tokio::test]
async fn test_register_and_login() {
    let test_server = create_test_server().await;
    let address = unique_email("login");
    let password = "password123";

    let signup = test_server
        .post("/api/auth/register")
        .json(&json!({ "email": address, "password": password, "name": "Test User" }))
        .await;
    signup.assert_status_ok();
    let signup_body: serde_json::Value = signup.json();
    assert!(signup_body["access_token"].is_string());

    // The same credentials must then authenticate through the login endpoint.
    let signin = test_server
        .post("/api/auth/login")
        .json(&json!({ "email": address, "password": password }))
        .await;
    signin.assert_status_ok();
    let signin_body: serde_json::Value = signin.json();
    assert!(signin_body["access_token"].is_string());
    assert!(signin_body["refresh_token"].is_string());
}
#[tokio::test]
async fn test_register_duplicate_user() {
    let test_server = create_test_server().await;
    let shared_email = unique_email("dup");

    let first_attempt = test_server
        .post("/api/auth/register")
        .json(&json!({
            "email": shared_email,
            "password": "password123",
            "name": "Test User"
        }))
        .await;
    first_attempt.assert_status_ok();

    // Reusing the address is rejected, even with a different password and name.
    let second_attempt = test_server
        .post("/api/auth/register")
        .json(&json!({
            "email": shared_email,
            "password": "password456",
            "name": "Another User"
        }))
        .await;
    second_attempt.assert_status_bad_request();
}
#[tokio::test]
async fn test_login_invalid_credentials() {
    let test_server = create_test_server().await;

    // A freshly generated address has never been registered.
    let credentials = json!({
        "email": unique_email("nonexistent"),
        "password": "password123"
    });
    test_server
        .post("/api/auth/login")
        .json(&credentials)
        .await
        .assert_status_unauthorized();
}
#[tokio::test]
async fn test_login_wrong_password() {
    let test_server = create_test_server().await;
    let account = unique_email("wrongpass");

    test_server
        .post("/api/auth/register")
        .json(&json!({
            "email": account,
            "password": "correct_password",
            "name": "Test User"
        }))
        .await
        .assert_status_ok();

    // Existing account with the wrong password: the login must be refused.
    test_server
        .post("/api/auth/login")
        .json(&json!({ "email": account, "password": "wrong_password" }))
        .await
        .assert_status_unauthorized();
}
#[tokio::test]
async fn test_register_short_password() {
    let test_server = create_test_server().await;

    // A five-character password is expected to be rejected.
    let payload = json!({
        "email": unique_email("shortpass"),
        "password": "short",
        "name": "Test User"
    });
    test_server
        .post("/api/auth/register")
        .json(&payload)
        .await
        .assert_status_bad_request();
}
#[tokio::test]
async fn test_register_invalid_email() {
    let test_server = create_test_server().await;

    // NOTE(review): despite its name, this test submits a well-formed, unique address
    // produced by `unique_email` and expects 200 OK. It therefore does not exercise
    // invalid-email handling at all. If the server is supposed to validate addresses,
    // this should send a malformed one and expect a 4xx — confirm the intended contract.
    let payload = json!({
        "email": unique_email("invalidemail"),
        "password": "password123",
        "name": "Test User"
    });
    test_server
        .post("/api/auth/register")
        .json(&payload)
        .await
        .assert_status_ok();
}
#[tokio::test]
async fn test_register_empty_name() {
    let test_server = create_test_server().await;

    // An empty display name is accepted by registration.
    let payload = json!({
        "email": unique_email("emptyname"),
        "password": "password123",
        "name": ""
    });
    test_server
        .post("/api/auth/register")
        .json(&payload)
        .await
        .assert_status_ok();
}
#[tokio::test]
async fn test_refresh_token() {
    let test_server = create_test_server().await;

    let registered = test_server
        .post("/api/auth/register")
        .json(&json!({
            "email": unique_email("refresh"),
            "password": "password123",
            "name": "Test User"
        }))
        .await;
    registered.assert_status_ok();
    let issued: serde_json::Value = registered.json();
    let refresh_token = issued["refresh_token"].as_str().unwrap();

    // Exchanging the refresh token returns an access/refresh token pair.
    let refreshed = test_server
        .post("/api/auth/refresh")
        .json(&json!({ "refresh_token": refresh_token }))
        .await;
    refreshed.assert_status_ok();
    let renewed: serde_json::Value = refreshed.json();
    assert!(renewed["access_token"].is_string());
    assert!(renewed["refresh_token"].is_string());
}
#[tokio::test]
async fn test_refresh_token_invalid() {
    let test_server = create_test_server().await;

    // A token string the server never issued must not be exchangeable.
    let bogus = json!({ "refresh_token": "invalid_token_here" });
    test_server
        .post("/api/auth/refresh")
        .json(&bogus)
        .await
        .assert_status_unauthorized();
}
#[tokio::test]
async fn test_multiple_logins() {
    let test_server = create_test_server().await;
    let account = unique_email("multilogin");

    test_server
        .post("/api/auth/register")
        .json(&json!({
            "email": account,
            "password": "password123",
            "name": "Test User"
        }))
        .await
        .assert_status_ok();

    // Logging in repeatedly with the same credentials must keep succeeding.
    let credentials = json!({ "email": account, "password": "password123" });
    for attempt in 1..=3 {
        let login = test_server.post("/api/auth/login").json(&credentials).await;
        login.assert_status_ok();
        let tokens: serde_json::Value = login.json();
        assert!(tokens["access_token"].is_string(), "Login {} failed", attempt);
    }
}
#[tokio::test]
async fn test_agents_list() {
    let test_server = create_test_server().await;
    let listing = test_server.get("/api/agents").await;
    listing.assert_status_ok();

    let agents: Vec<serde_json::Value> = listing.json();
    assert!(!agents.is_empty());

    // NOTE(review): `create_test_app` only configures `router` and `product` agents, yet
    // these display names are expected. Presumably the endpoint serves a built-in
    // catalogue independent of the config — confirm.
    let names: Vec<&str> = agents
        .iter()
        .filter_map(|entry| entry["name"].as_str())
        .collect();
    for expected in [
        "Product Agent",
        "Invoice Agent",
        "Sales Agent",
        "Finance Agent",
        "HR Agent",
    ] {
        assert!(names.contains(&expected));
    }
}
/// Every agent returned by `GET /api/agents` must expose a string `name` and `description`.
#[tokio::test]
async fn test_agents_list_structure() {
    let server = create_test_server().await;
    let response = server.get("/api/agents").await;
    response.assert_status_ok();
    let body: Vec<serde_json::Value> = response.json();

    // Guard against a vacuous pass: on an empty list the loop below would check nothing.
    assert!(
        !body.is_empty(),
        "GET /api/agents returned no agents to validate"
    );
    for agent in &body {
        assert!(
            agent["name"].is_string(),
            "agent is missing a string `name`: {agent}"
        );
        assert!(
            agent["description"].is_string(),
            "agent is missing a string `description`: {agent}"
        );
    }
}
#[tokio::test]
#[ignore = "requires running Ollama server"]
async fn test_chat_endpoint_with_live_ollama() {
    let test_server = create_test_server().await;

    // Obtain a bearer token by registering a fresh user.
    let signup = test_server
        .post("/api/auth/register")
        .json(&json!({
            "email": unique_email("chatuser"),
            "password": "password123",
            "name": "Chat User"
        }))
        .await;
    signup.assert_status_ok();
    let credentials: serde_json::Value = signup.json();
    let bearer = format!("Bearer {}", credentials["access_token"].as_str().unwrap());

    // Route a message to the product agent through the real LLM backend.
    let reply = test_server
        .post("/api/chat")
        .add_header("Authorization", bearer)
        .json(&json!({
            "message": "Hello agent!",
            "agent_type": "product"
        }))
        .await;
    reply.assert_status_ok();

    let payload: serde_json::Value = reply.json();
    let agent = payload["agent"].as_str().unwrap();
    assert!(
        agent.contains("Product"),
        "Agent should contain 'Product', got: {}",
        agent
    );

    let answer = &payload["response"];
    assert!(answer.is_string(), "Response should be a string");
    assert!(
        !answer.as_str().unwrap().is_empty(),
        "Response should not be empty"
    );
    assert!(
        payload["context_id"].is_string(),
        "context_id should be a string"
    );
}
#[tokio::test]
async fn test_mock_llm_generate() {
    let mock = MockLLMClient::new("Hello, world!");

    // The mock answers a plain prompt with its canned text.
    let reply = mock.generate("test prompt").await;
    assert!(reply.is_ok());
    let text = reply.unwrap();
    assert_eq!(text, "Hello, world!");
}
#[tokio::test]
async fn test_mock_llm_with_system() {
    let mock = MockLLMClient::new("System response");

    // A system + user prompt pair still yields the canned text.
    let reply = mock.generate_with_system("You are helpful", "Hello").await;
    assert!(reply.is_ok());
    let text = reply.unwrap();
    assert_eq!(text, "System response");
}
#[tokio::test]
async fn test_mock_llm_with_history() {
    let mock = MockLLMClient::new("History response");

    // A two-turn conversation expressed as (role, content) pairs.
    let conversation = vec![
        ("user".to_string(), "Hello".to_string()),
        ("assistant".to_string(), "Hi!".to_string()),
    ];
    let reply = mock.generate_with_history(&conversation).await;
    assert!(reply.is_ok());
    let message = reply.unwrap();
    assert_eq!(message.content, "History response");
}
#[tokio::test]
async fn test_mock_llm_with_tools_no_calls() {
    let mock = MockLLMClient::new("Tool response");
    let available = vec![ToolDefinition {
        name: "calculator".to_string(),
        description: "Math operations".to_string(),
        parameters: json!({}),
    }];

    // With no scripted tool calls, the mock finishes normally with plain content.
    let reply = mock.generate_with_tools("Calculate 2+2", &available).await;
    assert!(reply.is_ok());
    let completion = reply.unwrap();
    assert_eq!(completion.content, "Tool response");
    assert_eq!(completion.finish_reason, "stop");
    assert!(completion.tool_calls.is_empty());
}
#[tokio::test]
async fn test_mock_llm_with_tools_with_calls() {
    let scripted = vec![ToolCall {
        id: "call-1".to_string(),
        name: "calculator".to_string(),
        arguments: json!({"operation": "add", "a": 2, "b": 2}),
    }];
    let mock = MockLLMClient::with_tool_calls("I'll calculate that", scripted);
    let available = vec![ToolDefinition {
        name: "calculator".to_string(),
        description: "Math operations".to_string(),
        parameters: json!({}),
    }];

    // Scripted calls are surfaced verbatim with the `tool_calls` finish reason.
    let reply = mock.generate_with_tools("Calculate 2+2", &available).await;
    assert!(reply.is_ok());
    let completion = reply.unwrap();
    assert_eq!(completion.finish_reason, "tool_calls");
    assert_eq!(completion.tool_calls.len(), 1);
    assert_eq!(completion.tool_calls[0].name, "calculator");
}
#[tokio::test]
async fn test_mock_llm_streaming() {
let client = MockLLMClient::new("Hello streaming world!");
let result = client.stream("test").await;
assert!(result.is_ok());
let mut stream = result.unwrap();
let mut collected = String::new();
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => collected.push_str(&chunk),
Err(_) => break,
}
}
assert_eq!(collected, "Hello streaming world!");
}
#[tokio::test]
async fn test_mock_llm_failure() {
    let broken = MockLLMClient::failing();

    // Every entry point of a failing mock must report an error.
    assert!(broken.generate("test").await.is_err());
    assert!(broken.generate_with_system("sys", "test").await.is_err());
    assert!(broken.generate_with_history(&[]).await.is_err());
    assert!(broken.generate_with_tools("test", &[]).await.is_err());
    assert!(broken.stream("test").await.is_err());
}
#[tokio::test]
async fn test_mock_llm_model_name() {
    // The mock reports a fixed model identifier regardless of its canned response.
    let mock = MockLLMClient::new("test");
    let reported = mock.model_name();
    assert_eq!(reported, "mock-model");
}
#[tokio::test]
async fn test_multiple_tool_calls() {
let tool_calls = vec![
ToolCall {
id: "call-1".to_string(),
name: "get_weather".to_string(),
arguments: json!({"city": "London"}),
},
ToolCall {
id: "call-2".to_string(),
name: "get_time".to_string(),
arguments: json!({"timezone": "UTC"}),
},
ToolCall {
id: "call-3".to_string(),
name: "search".to_string(),
arguments: json!({"query": "news"}),
},
];
let client = MockLLMClient::with_tool_calls("Processing multiple tools", tool_calls);
let tools: Vec<ToolDefinition> = vec![];
let result = client
.generate_with_tools("What's the weather, time, and news?", &tools)
.await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.tool_calls.len(), 3);
assert_eq!(response.tool_calls[0].name, "get_weather");
assert_eq!(response.tool_calls[1].name, "get_time");
assert_eq!(response.tool_calls[2].name, "search");
}
#[tokio::test]
async fn test_tool_definition_structure() {
    // Schema-style parameters with a nested object property.
    let schema = json!({
        "type": "object",
        "properties": {
            "name": {"type": "string", "description": "The name"},
            "count": {"type": "integer", "minimum": 0},
            "options": {
                "type": "object",
                "properties": {
                    "verbose": {"type": "boolean"},
                    "format": {"type": "string", "enum": ["json", "text"]}
                }
            }
        },
        "required": ["name"]
    });
    let definition = ToolDefinition {
        name: "complex_tool".to_string(),
        description: "A complex tool with nested parameters".to_string(),
        parameters: schema,
    };

    assert_eq!(definition.name, "complex_tool");
    assert!(definition.parameters["properties"]["options"].is_object());
}
#[tokio::test]
async fn test_tool_call_complex_arguments() {
    let call = ToolCall {
        id: "call-complex".to_string(),
        name: "complex_tool".to_string(),
        arguments: json!({
            "string_arg": "hello",
            "number_arg": 42,
            "float_arg": 2.75,
            "bool_arg": true,
            "null_arg": null,
            "array_arg": [1, 2, 3],
            "object_arg": {"nested": "value", "deep": {"deeper": true}}
        }),
    };

    // Each JSON value kind must be preserved inside the arguments payload.
    let args = &call.arguments;
    assert_eq!(args["string_arg"], "hello");
    assert_eq!(args["number_arg"], 42);
    let float_value = args["float_arg"].as_f64().unwrap();
    assert!((float_value - 2.75).abs() < 0.001);
    assert!(args["bool_arg"].as_bool().unwrap());
    assert!(args["null_arg"].is_null());
    assert_eq!(args["array_arg"].as_array().unwrap().len(), 3);
    assert_eq!(args["object_arg"]["deep"]["deeper"], true);
}
#[tokio::test]
async fn test_empty_prompt() {
    let mock = MockLLMClient::new("Response to empty");
    // An empty prompt is not an error for the mock.
    let reply = mock.generate("").await;
    assert!(reply.is_ok());
}
#[tokio::test]
async fn test_very_long_prompt() {
    let mock = MockLLMClient::new("Response to long prompt");
    // 10 000 repetitions of "test " — a 50 000-byte prompt.
    let oversized: String = std::iter::repeat("test ").take(10000).collect();
    let reply = mock.generate(&oversized).await;
    assert!(reply.is_ok());
}
#[tokio::test]
async fn test_unicode_content() {
    let mock = MockLLMClient::new("Response with unicode: 你好世界 🌍 مرحبا");
    let reply = mock
        .generate("Hello in multiple languages: 你好 مرحبا")
        .await;
    assert!(reply.is_ok());

    // Multi-byte CJK characters and emoji must come back intact.
    let text = reply.unwrap();
    for fragment in ["你好世界", "🌍"] {
        assert!(text.contains(fragment));
    }
}
/// Quotes, apostrophes, backslashes and HTML-significant characters survive a round trip
/// through the mock client unchanged.
#[tokio::test]
async fn test_special_characters() {
    let expected = "Response with special chars: <>&\"'\\";
    let client = MockLLMClient::new(expected);
    let prompt = r#"Test with "quotes", 'apostrophes', \backslash, <angle>, &ersand"#;
    let result = client.generate(prompt).await;
    assert!(result.is_ok());
    // The configured reply must come back verbatim, with no escaping or entity encoding.
    // Without this check, the test only proved that the call didn't error.
    assert_eq!(result.unwrap(), expected);
}
#[tokio::test]
async fn test_newlines_in_content() {
    let mock = MockLLMClient::new("Line 1\nLine 2\nLine 3");
    let reply = mock.generate("Give me multiple lines").await;
    assert!(reply.is_ok());

    // Embedded line breaks are preserved in the returned text.
    let text = reply.unwrap();
    assert!(text.contains('\n'));
}
#[tokio::test]
async fn test_empty_history() {
    let mock = MockLLMClient::new("Response to empty history");
    // A conversation with zero turns is still accepted.
    let no_turns: Vec<(String, String)> = Vec::new();
    let reply = mock.generate_with_history(&no_turns).await;
    assert!(reply.is_ok());
}
#[tokio::test]
async fn test_large_history() {
    let mock = MockLLMClient::new("Response after long history");

    // 100 alternating turns: even indices are user messages, odd ones assistant replies.
    let transcript: Vec<(String, String)> = (0..100)
        .map(|turn| match turn % 2 {
            0 => ("user".to_string(), format!("Message {}", turn)),
            _ => ("assistant".to_string(), format!("Response {}", turn)),
        })
        .collect();

    let reply = mock.generate_with_history(&transcript).await;
    assert!(reply.is_ok());
}
#[tokio::test]
async fn test_auth_response_structure() {
    let test_server = create_test_server().await;
    let registered = test_server
        .post("/api/auth/register")
        .json(&json!({
            "email": unique_email("structure"),
            "password": "password123",
            "name": "Test User"
        }))
        .await;
    registered.assert_status_ok();

    let tokens: serde_json::Value = registered.json();
    let access = &tokens["access_token"];
    let refresh = &tokens["refresh_token"];
    let expires_in = &tokens["expires_in"];

    // Field types first, then values: both tokens non-empty, expiry strictly positive.
    assert!(access.is_string());
    assert!(refresh.is_string());
    assert!(expires_in.is_number());
    assert!(!access.as_str().unwrap().is_empty());
    assert!(!refresh.as_str().unwrap().is_empty());
    assert!(expires_in.as_i64().unwrap() > 0);
}
#[tokio::test]
async fn test_missing_required_fields() {
    let test_server = create_test_server().await;

    // A registration body lacking `password`, then one lacking `email`; each must be
    // rejected with 422 Unprocessable Entity.
    let missing_password = json!({
        "email": unique_email("missing"),
        "name": "Test User"
    });
    let missing_email = json!({
        "password": "password123",
        "name": "Test User"
    });
    for incomplete in [missing_password, missing_email] {
        test_server
            .post("/api/auth/register")
            .json(&incomplete)
            .await
            .assert_status_unprocessable_entity();
    }
}
#[tokio::test]
async fn test_extra_fields_ignored() {
    let test_server = create_test_server().await;

    // Unknown keys alongside the required ones must not cause registration to fail.
    let payload = json!({
        "email": unique_email("extrafields"),
        "password": "password123",
        "name": "Test User",
        "extra_field": "should be ignored",
        "another_extra": 12345
    });
    test_server
        .post("/api/auth/register")
        .json(&payload)
        .await
        .assert_status_ok();
}