use crate::ast::AgentParams;
use crate::core::mcp_config::{load_merged_config, McpServer};
use crate::error::NikaError;
use crate::event::EventLog;
use crate::mcp::types::McpConfig as McpClientConfig;
use crate::mcp::McpClient;
use crate::provider::rig::{RigProvider, StreamChunk};
use crate::runtime::RigAgentLoop;
use crate::tui::command::ModelProvider;
use rustc_hash::FxHashMap;
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::mpsc;
/// Accumulates a streamed response while it is being received.
///
/// `tokens_received` counts calls to [`StreamingState::append`] (one per
/// chunk); it is not necessarily the provider's token count.
#[derive(Clone, Debug, Default)]
pub struct StreamingState {
    /// `true` between [`StreamingState::start`] and [`StreamingState::finish`].
    pub is_streaming: bool,
    /// Text received so far in the current stream.
    pub partial_response: String,
    /// Number of chunks appended since the last `start`.
    pub tokens_received: usize,
}

impl StreamingState {
    /// Creates an idle state with an empty buffer.
    pub fn new() -> Self {
        Self {
            is_streaming: false,
            partial_response: String::new(),
            tokens_received: 0,
        }
    }

    /// Marks a new stream as active, discarding any leftover text and count.
    pub fn start(&mut self) {
        self.tokens_received = 0;
        self.partial_response.clear();
        self.is_streaming = true;
    }

    /// Appends one chunk of streamed text and bumps the chunk counter.
    pub fn append(&mut self, chunk: &str) {
        self.tokens_received += 1;
        self.partial_response += chunk;
    }

    /// Ends the stream and hands back the accumulated text, leaving the buffer
    /// empty. The chunk counter is left untouched.
    pub fn finish(&mut self) -> String {
        let completed = std::mem::take(&mut self.partial_response);
        self.is_streaming = false;
        completed
    }
}
/// Author of a [`ChatMessage`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ChatRole {
    User,
    Assistant,
    System,
    Tool,
}

impl ChatRole {
    /// Label shown next to a message of this role in the chat view.
    pub fn display_name(&self) -> &'static str {
        match *self {
            Self::User => "You",
            Self::Assistant => "Nika",
            Self::System => "System",
            Self::Tool => "Tool",
        }
    }
}

/// One entry of the chat transcript, stamped with its creation instant.
#[derive(Clone, Debug)]
pub struct ChatMessage {
    pub role: ChatRole,
    pub content: String,
    pub timestamp: std::time::Instant,
}

impl ChatMessage {
    /// Builds a message with the given role, timestamped now.
    fn stamped(role: ChatRole, content: String) -> Self {
        Self {
            role,
            content,
            timestamp: std::time::Instant::now(),
        }
    }

    /// Message typed by the user.
    pub fn user(content: impl Into<String>) -> Self {
        Self::stamped(ChatRole::User, content.into())
    }

    /// Reply produced by the assistant.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::stamped(ChatRole::Assistant, content.into())
    }

    /// System/status message.
    pub fn system(content: impl Into<String>) -> Self {
        Self::stamped(ChatRole::System, content.into())
    }

    /// Output of a tool invocation.
    pub fn tool(content: impl Into<String>) -> Self {
        Self::stamped(ChatRole::Tool, content.into())
    }
}
/// Chat front-end over a [`RigProvider`].
///
/// Holds the local conversation history, optional output channels, the
/// streaming state, an HTTP client for `fetch`, and running token totals.
pub struct ChatAgent {
/// Provider used for completions in `infer`.
provider: RigProvider,
/// Model requested instead of the provider's default, if set (see `model_name`).
model_override: Option<String>,
/// Conversation transcript: `infer` appends the user prompt and, on success,
/// the assistant reply.
history: Vec<ChatMessage>,
/// Optional channel that receives a status line before each `infer` request
/// and the full response afterwards.
streaming_tx: Option<mpsc::Sender<String>>,
/// When set, `infer` streams the response through this channel and then sends
/// a `StreamChunk::Metrics` chunk with the token counts.
stream_chunk_tx: Option<mpsc::Sender<StreamChunk>>,
/// Started at the beginning of `infer` and finished when the response arrives.
streaming_state: StreamingState,
/// HTTP client reused by `fetch`.
http_client: reqwest::Client,
/// Cumulative input tokens reported by streaming `infer` calls.
pub total_input_tokens: u64,
/// Cumulative output tokens reported by streaming `infer` calls.
pub total_output_tokens: u64,
}
impl ChatAgent {
/// Creates an agent backed by whichever provider [`RigProvider::auto`] picks
/// from the API keys available in the environment.
///
/// # Errors
/// Returns [`NikaError::MissingApiKey`] when no supported API key is set.
pub fn new() -> Result<Self, NikaError> {
    let Some(provider) = RigProvider::auto() else {
        return Err(NikaError::MissingApiKey {
            provider: "any (ANTHROPIC_API_KEY, OPENAI_API_KEY, MISTRAL_API_KEY, GROQ_API_KEY, DEEPSEEK_API_KEY, or GEMINI_API_KEY)".to_string(),
        });
    };
    let agent = Self {
        provider,
        model_override: None,
        history: Vec::new(),
        streaming_tx: None,
        stream_chunk_tx: None,
        streaming_state: StreamingState::new(),
        http_client: reqwest::Client::new(),
        total_input_tokens: 0,
        total_output_tokens: 0,
    };
    Ok(agent)
}

/// Creates an agent like [`ChatAgent::new`], then optionally forces a provider
/// (by name) and/or a model.
///
/// # Errors
/// Fails if [`ChatAgent::new`] fails or the named provider cannot be built.
pub fn with_overrides(provider: Option<&str>, model: Option<&str>) -> Result<Self, NikaError> {
    let mut agent = Self::new()?;
    if let Some(name) = provider {
        agent.provider = RigProvider::from_name(name)?;
    }
    // `new` leaves the override unset, so mapping `None` keeps it unset.
    agent.model_override = model.map(str::to_owned);
    Ok(agent)
}
/// Builder: attaches the status/response text channel used by `infer`.
pub fn with_streaming(self, tx: mpsc::Sender<String>) -> Self {
    let mut agent = self;
    agent.streaming_tx = Some(tx);
    agent
}

/// Builder: attaches the stream-chunk channel, enabling streamed inference.
pub fn with_stream_chunks(mut self, tx: mpsc::Sender<StreamChunk>) -> Self {
    self.set_stream_chunk_tx(tx);
    self
}

/// Attaches (or replaces) the stream-chunk channel on an existing agent.
pub fn set_stream_chunk_tx(&mut self, tx: mpsc::Sender<StreamChunk>) {
    let _ = self.stream_chunk_tx.replace(tx);
}
/// Switches the active provider.
///
/// `ModelProvider::List` is not a concrete provider (it appears to be the
/// "list providers" command) and leaves the agent unchanged.
///
/// When the provider actually changes, any model override is cleared. The
/// override was chosen for the previous provider (e.g. via
/// [`ChatAgent::with_overrides`]), and sending that model name to a different
/// provider's API would fail. [`ChatAgent::model_name`] then reports the new
/// provider's default model.
///
/// # Errors
/// Propagates the error from [`RigProvider::from_name`] (e.g. a missing API
/// key). On error, both the previous provider and the model override are kept.
pub fn set_provider(&mut self, provider: ModelProvider) -> Result<(), NikaError> {
    if matches!(provider, ModelProvider::List) {
        return Ok(());
    }
    let next = RigProvider::from_name(provider.command_name())?;
    // Re-selecting the same provider keeps the user's chosen model.
    if next.name() != self.provider.name() {
        self.model_override = None;
    }
    self.provider = next;
    Ok(())
}
/// Short name of the active provider (e.g. `"openai"`).
pub fn provider_name(&self) -> &'static str {
    self.provider.name()
}

/// Model used for requests: the override if one is set, otherwise the
/// provider's default model.
pub fn model_name(&self) -> &str {
    match self.model_override.as_deref() {
        Some(model) => model,
        None => self.provider.default_model(),
    }
}

/// Sum of the input and output token totals accumulated so far.
pub fn total_tokens(&self) -> u64 {
    let (input, output) = (self.total_input_tokens, self.total_output_tokens);
    input + output
}
/// Sends `prompt` to the active provider and records the exchange in history.
///
/// The user message is appended to the history before the request. The
/// assistant reply is appended only on success.
///
/// - With a stream-chunk channel attached, the response is streamed through
///   it, token usage is added to the running totals, and a final
///   `StreamChunk::Metrics` chunk is sent.
/// - Otherwise a single non-streamed completion is requested.
///
/// In both modes the model override (if any) is forwarded to the provider.
///
/// The streaming state is always finished before returning, including on
/// error, so [`ChatAgent::is_streaming`] never stays `true` after a failed
/// request.
///
/// # Errors
/// Returns [`NikaError::ProviderApiError`] if the provider call fails.
pub async fn infer(&mut self, prompt: &str) -> Result<String, NikaError> {
    self.history.push(ChatMessage::user(prompt));
    self.streaming_state.start();
    if let Some(tx) = &self.streaming_tx {
        // Best effort: a closed receiver must not abort the request.
        let _ = tx
            .send(format!("Sending to {}...", self.provider.name()))
            .await;
    }
    let outcome = self.request_completion(prompt).await;
    // Reset before propagating any error so the UI does not keep spinning.
    self.streaming_state.finish();
    let response = outcome?;
    self.history.push(ChatMessage::assistant(&response));
    if let Some(tx) = &self.streaming_tx {
        let _ = tx.send(response.clone()).await;
    }
    Ok(response)
}

/// Performs the provider call for [`ChatAgent::infer`], streaming through
/// `stream_chunk_tx` when it is set and updating the token totals.
async fn request_completion(&mut self, prompt: &str) -> Result<String, NikaError> {
    let to_api_error = |e: &dyn std::fmt::Display| NikaError::ProviderApiError {
        message: e.to_string(),
    };
    match self.stream_chunk_tx.clone() {
        Some(tx) => {
            let metrics_tx = tx.clone();
            let result = self
                .provider
                .infer_stream(prompt, tx, self.model_override.as_deref())
                .await
                .map_err(|e| to_api_error(&e))?;
            self.total_input_tokens += result.input_tokens;
            self.total_output_tokens += result.output_tokens;
            let _ = metrics_tx
                .send(StreamChunk::Metrics {
                    input_tokens: result.input_tokens,
                    output_tokens: result.output_tokens,
                })
                .await;
            Ok(result.text)
        }
        // Previously `None` was passed here, silently ignoring the model
        // override whenever no chunk channel was attached.
        None => self
            .provider
            .infer(prompt, self.model_override.as_deref())
            .await
            .map_err(|e| to_api_error(&e)),
    }
}
/// Runs `command` through `sh -c` and returns its output as text.
///
/// - On a successful exit status, the trimmed stdout is returned.
/// - On any other status, the result is still `Ok`, formatted as
///   `"Exit code: N\n<stdout>\n<stderr>"`. `N` is -1 when no exit code is
///   available (e.g. the process was terminated by a signal on Unix).
///
/// NOTE(review): the string is executed verbatim by the shell; callers must
/// only pass trusted or user-confirmed input.
///
/// # Errors
/// Returns [`NikaError::Execution`] if the shell process cannot be spawned.
pub async fn exec_command(&self, command: &str) -> Result<String, NikaError> {
    let output = tokio::process::Command::new("sh")
        .args(["-c", command])
        .output()
        .await
        .map_err(|e| NikaError::Execution(format!("Failed to execute command: {}", e)))?;
    let stdout = String::from_utf8_lossy(&output.stdout);
    if output.status.success() {
        return Ok(stdout.trim().to_string());
    }
    let stderr = String::from_utf8_lossy(&output.stderr);
    let code = output.status.code().unwrap_or(-1);
    Ok(format!(
        "Exit code: {}\n{}\n{}",
        code,
        stdout.trim(),
        stderr.trim()
    ))
}
/// Performs an HTTP request and returns the response body as text.
///
/// `method` is matched case-insensitively against POST, PUT, DELETE, PATCH
/// and HEAD; anything else (including GET or an empty string) issues a GET.
/// Non-2xx responses are still returned as `Ok`, prefixed with a status line
/// such as `"HTTP 404 Not Found"`.
///
/// # Errors
/// Returns [`NikaError::Execution`] if the request cannot be sent or the body
/// cannot be read.
pub async fn fetch(&self, url: &str, method: &str) -> Result<String, NikaError> {
    let request = match method.to_uppercase().as_str() {
        "POST" => self.http_client.post(url),
        "PUT" => self.http_client.put(url),
        "DELETE" => self.http_client.delete(url),
        "PATCH" => self.http_client.patch(url),
        "HEAD" => self.http_client.head(url),
        // Unknown or empty methods fall back to GET.
        _ => self.http_client.get(url),
    };
    let response = request
        .send()
        .await
        .map_err(|e| NikaError::Execution(format!("HTTP request failed: {}", e)))?;
    let status = response.status();
    let text = response
        .text()
        .await
        .map_err(|e| NikaError::Execution(format!("Failed to read response: {}", e)))?;
    if status.is_success() {
        return Ok(text);
    }
    // `StatusCode::as_str()` is the numeric code ("404"), which duplicated the
    // number; use the canonical reason phrase when one is known.
    let status_line = match status.canonical_reason() {
        Some(reason) => format!("HTTP {} {}", status.as_u16(), reason),
        None => format!("HTTP {}", status.as_u16()),
    };
    Ok(format!("{}\n{}", status_line, text))
}
/// Read-only view of the conversation transcript.
pub fn history(&self) -> &[ChatMessage] {
    self.history.as_slice()
}

/// Drops every message from the transcript.
pub fn clear_history(&mut self) {
    self.history.truncate(0);
}

/// Creates an agent like [`ChatAgent::new`] pre-seeded with `messages`.
///
/// # Errors
/// Fails exactly when [`ChatAgent::new`] fails.
pub fn with_history(messages: Vec<ChatMessage>) -> Result<Self, NikaError> {
    Self::new().map(|mut agent| {
        agent.history = messages;
        agent
    })
}

/// Moves the transcript out, leaving the agent with an empty history.
pub fn take_history(&mut self) -> Vec<ChatMessage> {
    let taken = std::mem::take(&mut self.history);
    taken
}

/// Replaces the transcript wholesale.
pub fn set_history(&mut self, messages: Vec<ChatMessage>) {
    self.history = messages;
}

/// Current streaming bookkeeping.
pub fn streaming_state(&self) -> &StreamingState {
    &self.streaming_state
}

/// Whether a request is currently being streamed.
pub fn is_streaming(&self) -> bool {
    self.streaming_state().is_streaming
}
/// Calls a single MCP tool and returns its textual result.
///
/// When `server_name` is `None`, the first enabled server yielded by iterating
/// the merged config's `servers` is used.
/// NOTE(review): if `servers` is a hash map, that choice is not stable across
/// runs — confirm.
///
/// Each content block of the tool result is rendered as text (non-text blocks
/// become short bracketed placeholders) and the blocks are joined with
/// newlines.
///
/// # Errors
/// - [`NikaError::InvalidConfig`] if:
///   - the MCP config cannot be loaded,
///   - the named server does not exist or is disabled, or
///   - no enabled server exists.
/// - Errors from creating/connecting the MCP client or calling the tool.
/// - [`NikaError::McpToolError`] if the server reports the call as an error.
pub async fn invoke(
    &self,
    tool_name: &str,
    server_name: Option<&str>,
    params: Value,
) -> Result<String, NikaError> {
    use crate::mcp::types::ContentBlock;

    let config = load_merged_config().map_err(|e| NikaError::InvalidConfig {
        message: format!("Failed to load MCP config: {}", e),
    })?;
    let (resolved_server_name, server): (String, &McpServer) = match server_name {
        Some(name) => {
            let server = config
                .servers
                .get(name)
                .ok_or_else(|| NikaError::InvalidConfig {
                    message: format!(
                        "MCP server '{}' not found. Available: {:?}",
                        name,
                        config.servers.keys().collect::<Vec<_>>()
                    ),
                })?;
            (name.to_string(), server)
        }
        None => {
            let (name, server) = config
                .servers
                .iter()
                .find(|(_, s)| s.enabled)
                .ok_or_else(|| {
                    // Distinguish "nothing configured" from "all disabled";
                    // the old message claimed none were configured in both cases.
                    let message = if config.servers.is_empty() {
                        "No MCP servers configured. Use 'nika mcp add' to add one.".to_string()
                    } else {
                        format!(
                            "No MCP servers enabled. Configured but disabled: {:?}",
                            config.servers.keys().collect::<Vec<_>>()
                        )
                    };
                    NikaError::InvalidConfig { message }
                })?;
            (name.clone(), server)
        }
    };
    if !server.enabled {
        return Err(NikaError::InvalidConfig {
            message: format!("MCP server '{}' is disabled", resolved_server_name),
        });
    }
    let client_config = McpClientConfig {
        name: resolved_server_name.clone(),
        command: server.command.clone(),
        args: server.args.clone(),
        env: server
            .env
            .iter()
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect(),
        cwd: None,
    };
    let client = McpClient::new(client_config)?;
    client.connect().await?;
    let result = client.call_tool(tool_name, params).await?;

    let render = |block: ContentBlock| -> String {
        match block {
            ContentBlock::Text { text } => text,
            ContentBlock::Image { data, mime_type } => {
                format!("[Image: {} bytes, {}]", data.len(), mime_type)
            }
            ContentBlock::Audio { data, mime_type } => {
                format!("[Audio: {} bytes, {}]", data.len(), mime_type)
            }
            ContentBlock::Resource(res) => res
                .text
                .unwrap_or_else(|| format!("[Resource: {}]", res.uri)),
            ContentBlock::ResourceLink { uri, name, .. } => {
                if let Some(n) = name {
                    format!("[ResourceLink: {} ({})]", uri, n)
                } else {
                    format!("[ResourceLink: {}]", uri)
                }
            }
        }
    };
    // Read the flag before `content` is moved out of `result`.
    let is_error = result.is_error;
    let text = result
        .content
        .into_iter()
        .map(render)
        .collect::<Vec<_>>()
        .join("\n");
    if is_error {
        return Err(NikaError::McpToolError {
            tool: tool_name.to_string(),
            reason: format!(
                "MCP server '{}' returned error: {}",
                resolved_server_name, text
            ),
            error_code: None,
        });
    }
    Ok(text)
}
pub async fn run_agent(
&self,
goal: String,
max_turns: Option<u32>,
extended_thinking: bool,
servers: Vec<String>,
) -> Result<String, NikaError> {
let config = load_merged_config().map_err(|e| NikaError::InvalidConfig {
message: format!("Failed to load MCP config: {}", e),
})?;
let mut mcp_clients: FxHashMap<String, Arc<McpClient>> = FxHashMap::default();
let servers_to_use = if servers.is_empty() {
config
.servers
.iter()
.filter(|(_, s)| s.enabled)
.map(|(name, _)| name.clone())
.collect::<Vec<_>>()
} else {
servers
};
if servers_to_use.is_empty() {
return Err(NikaError::InvalidConfig {
message: "No MCP servers configured. Use 'nika mcp add' to add one.".to_string(),
});
}
for server_name in &servers_to_use {
let server =
config
.servers
.get(server_name)
.ok_or_else(|| NikaError::InvalidConfig {
message: format!(
"MCP server '{}' not found. Available: {:?}",
server_name,
config.servers.keys().collect::<Vec<_>>()
),
})?;
if !server.enabled {
tracing::warn!("Skipping disabled MCP server: {}", server_name);
continue;
}
let client_config = McpClientConfig {
name: server_name.clone(),
command: server.command.clone(),
args: server.args.clone(),
env: server
.env
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect(),
cwd: None,
};
let client = McpClient::new(client_config)?;
client.connect().await?;
mcp_clients.insert(server_name.clone(), Arc::new(client));
}
let params = AgentParams {
prompt: goal,
system: None,
provider: None, model: None, mcp: servers_to_use.clone(),
tools: vec![], max_turns: Some(max_turns.unwrap_or(10)),
token_budget: None,
stop_sequences: vec![],
scope: None,
extended_thinking: if extended_thinking { Some(true) } else { None },
thinking_budget: None,
depth_limit: Some(3), ..Default::default()
};
let event_log = EventLog::new();
let task_id = format!("chat-agent-{}", uuid::Uuid::new_v4());
let mut agent_loop = RigAgentLoop::new(task_id, params, event_log, mcp_clients)?;
let result = agent_loop.run_auto().await?;
let final_response = if let Some(response) = result.final_output.get("response") {
response.as_str().unwrap_or_default().to_string()
} else if let Some(output) = result.final_output.get("output") {
output.as_str().unwrap_or_default().to_string()
} else {
serde_json::to_string_pretty(&result.final_output).unwrap_or_else(|_| {
format!(
"[Agent completed in {} turns, {} tokens used]",
result.turns, result.total_tokens
)
})
};
Ok(final_response)
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the chat agent.
    //!
    //! Tests that construct a `ChatAgent` depend on process-wide API-key env
    //! vars. Every such test is `#[serial]`, because `std::env::set_var` while
    //! another thread reads the environment is a data race.
    use super::*;
    use serial_test::serial;

    #[test]
    fn test_streaming_state_default() {
        let state = StreamingState::default();
        assert!(!state.is_streaming);
        assert!(state.partial_response.is_empty());
        assert_eq!(state.tokens_received, 0);
    }

    #[test]
    fn test_streaming_state_start() {
        let mut state = StreamingState::new();
        state.partial_response = "leftover".to_string();
        state.tokens_received = 10;
        state.start();
        assert!(state.is_streaming);
        assert!(state.partial_response.is_empty());
        assert_eq!(state.tokens_received, 0);
    }

    #[test]
    fn test_streaming_state_append() {
        let mut state = StreamingState::new();
        state.start();
        state.append("Hello");
        state.append(", ");
        state.append("world!");
        assert_eq!(state.partial_response, "Hello, world!");
        assert_eq!(state.tokens_received, 3);
    }

    #[test]
    fn test_streaming_state_finish() {
        let mut state = StreamingState::new();
        state.start();
        state.append("Complete response");
        let result = state.finish();
        assert_eq!(result, "Complete response");
        assert!(!state.is_streaming);
        assert!(state.partial_response.is_empty());
    }

    #[test]
    fn test_chat_role_display_names() {
        assert_eq!(ChatRole::User.display_name(), "You");
        assert_eq!(ChatRole::Assistant.display_name(), "Nika");
        assert_eq!(ChatRole::System.display_name(), "System");
        assert_eq!(ChatRole::Tool.display_name(), "Tool");
    }

    #[test]
    fn test_chat_role_equality() {
        assert_eq!(ChatRole::User, ChatRole::User);
        assert_ne!(ChatRole::User, ChatRole::Assistant);
    }

    #[test]
    fn test_chat_message_user() {
        let msg = ChatMessage::user("Hello");
        assert_eq!(msg.role, ChatRole::User);
        assert_eq!(msg.content, "Hello");
    }

    #[test]
    fn test_chat_message_assistant() {
        let msg = ChatMessage::assistant("Hi there!");
        assert_eq!(msg.role, ChatRole::Assistant);
        assert_eq!(msg.content, "Hi there!");
    }

    #[test]
    fn test_chat_message_system() {
        let msg = ChatMessage::system("You are a helpful assistant.");
        assert_eq!(msg.role, ChatRole::System);
        assert_eq!(msg.content, "You are a helpful assistant.");
    }

    #[test]
    fn test_chat_message_tool() {
        let msg = ChatMessage::tool("{\"result\": \"success\"}");
        assert_eq!(msg.role, ChatRole::Tool);
        assert_eq!(msg.content, "{\"result\": \"success\"}");
    }

    // `#[serial]`: reads the API-key env vars that sibling tests mutate.
    #[tokio::test]
    #[serial]
    async fn test_chat_agent_creation() {
        let agent = ChatAgent::new();
        match agent {
            Ok(a) => {
                let valid_providers = ["claude", "openai", "mistral", "groq", "deepseek", "gemini"];
                assert!(
                    valid_providers.contains(&a.provider_name()),
                    "Expected valid provider, got: {}",
                    a.provider_name()
                );
            }
            Err(e) => {
                assert!(
                    e.to_string().contains("API key"),
                    "Expected API key error, got: {}",
                    e
                );
            }
        }
    }

    #[test]
    #[serial]
    fn test_chat_agent_initial_state() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let agent = ChatAgent::new().expect("Should create agent");
        assert!(agent.history().is_empty());
        assert!(!agent.is_streaming());
        let valid_providers = ["claude", "openai", "mistral", "groq", "deepseek", "gemini"];
        assert!(
            valid_providers.contains(&agent.provider_name()),
            "Expected valid provider, got: {}",
            agent.provider_name()
        );
    }

    #[test]
    #[serial]
    fn test_chat_agent_with_claude_fallback() {
        std::env::set_var("ANTHROPIC_API_KEY", "test-key-for-unit-test");
        let agent = ChatAgent::new().expect("Should create agent");
        assert!(agent.provider_name() == "openai" || agent.provider_name() == "claude");
    }

    #[test]
    #[serial]
    fn test_set_provider_openai() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let mut agent = ChatAgent::new().expect("Should create agent");
        let result = agent.set_provider(ModelProvider::OpenAI);
        assert!(result.is_ok());
        assert_eq!(agent.provider_name(), "openai");
    }

    #[test]
    #[serial]
    fn test_set_provider_claude() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        std::env::set_var("ANTHROPIC_API_KEY", "test-key-for-unit-test");
        let mut agent = ChatAgent::new().expect("Should create agent");
        if std::env::var("ANTHROPIC_API_KEY").is_ok() {
            let result = agent.set_provider(ModelProvider::Claude);
            assert!(result.is_ok());
            assert_eq!(agent.provider_name(), "claude");
        }
    }

    #[test]
    #[serial]
    fn test_set_provider_missing_key() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let mut agent = ChatAgent::new().expect("Should create agent");
        if std::env::var("ANTHROPIC_API_KEY").is_err() {
            let result = agent.set_provider(ModelProvider::Claude);
            assert!(result.is_err());
            if let Err(NikaError::MissingApiKey { provider }) = result {
                assert_eq!(provider, "Claude");
            } else {
                panic!("Expected MissingApiKey error");
            }
        } else {
            let result = agent.set_provider(ModelProvider::Claude);
            assert!(result.is_ok());
        }
    }

    #[test]
    #[serial]
    fn test_set_provider_list_does_not_change() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let mut agent = ChatAgent::new().expect("Should create agent");
        let original = agent.provider_name();
        let result = agent.set_provider(ModelProvider::List);
        assert!(result.is_ok());
        assert_eq!(agent.provider_name(), original);
    }

    #[test]
    #[serial]
    fn test_set_provider_mistral() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        std::env::set_var("MISTRAL_API_KEY", "test-key-for-unit-test");
        let mut agent = ChatAgent::new().expect("Should create agent");
        let result = agent.set_provider(ModelProvider::Mistral);
        if std::env::var("MISTRAL_API_KEY").is_ok_and(|v| !v.is_empty()) {
            assert!(result.is_ok());
            assert_eq!(agent.provider_name(), "mistral");
        }
    }

    #[test]
    #[serial]
    fn test_set_provider_groq() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        std::env::set_var("GROQ_API_KEY", "test-key-for-unit-test");
        let mut agent = ChatAgent::new().expect("Should create agent");
        let result = agent.set_provider(ModelProvider::Groq);
        if std::env::var("GROQ_API_KEY").is_ok_and(|v| !v.is_empty()) {
            assert!(result.is_ok());
            assert_eq!(agent.provider_name(), "groq");
        }
    }

    #[test]
    #[serial]
    fn test_set_provider_deepseek() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        std::env::set_var("DEEPSEEK_API_KEY", "test-key-for-unit-test");
        let mut agent = ChatAgent::new().expect("Should create agent");
        let result = agent.set_provider(ModelProvider::DeepSeek);
        if std::env::var("DEEPSEEK_API_KEY").is_ok_and(|v| !v.is_empty()) {
            assert!(result.is_ok());
            assert_eq!(agent.provider_name(), "deepseek");
        }
    }

    #[test]
    #[serial]
    fn test_with_overrides_mistral() {
        std::env::set_var("MISTRAL_API_KEY", "test-key-for-unit-test");
        let agent = ChatAgent::with_overrides(Some("mistral"), None);
        if std::env::var("MISTRAL_API_KEY").is_ok_and(|v| !v.is_empty()) {
            assert!(agent.is_ok());
            assert_eq!(agent.unwrap().provider_name(), "mistral");
        }
    }

    #[test]
    #[serial]
    fn test_with_overrides_invalid_provider() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let agent = ChatAgent::with_overrides(Some("invalid_provider"), None);
        assert!(agent.is_err());
        if let Err(NikaError::InvalidConfig { message }) = agent {
            assert!(message.contains("Unknown provider"));
        }
    }

    #[test]
    #[serial]
    fn test_history_starts_empty() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let agent = ChatAgent::new().expect("Should create agent");
        assert!(agent.history().is_empty());
    }

    #[test]
    #[serial]
    fn test_clear_history() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let mut agent = ChatAgent::new().expect("Should create agent");
        agent.history.push(ChatMessage::user("Hello"));
        agent.history.push(ChatMessage::assistant("Hi!"));
        assert_eq!(agent.history().len(), 2);
        agent.clear_history();
        assert!(agent.history().is_empty());
    }

    #[test]
    #[serial]
    fn test_with_history() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let history = vec![
            ChatMessage::user("Hello"),
            ChatMessage::assistant("Hi there!"),
            ChatMessage::user("How are you?"),
        ];
        let agent = ChatAgent::with_history(history).expect("Should create agent with history");
        assert_eq!(agent.history().len(), 3);
        assert_eq!(agent.history()[0].role, ChatRole::User);
        assert_eq!(agent.history()[0].content, "Hello");
        assert_eq!(agent.history()[1].role, ChatRole::Assistant);
        assert_eq!(agent.history()[2].content, "How are you?");
    }

    #[test]
    #[serial]
    fn test_take_history() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let mut agent = ChatAgent::new().expect("Should create agent");
        agent.history.push(ChatMessage::user("Hello"));
        agent.history.push(ChatMessage::assistant("Hi!"));
        let taken = agent.take_history();
        assert_eq!(taken.len(), 2);
        assert!(agent.history().is_empty());
        assert_eq!(taken[0].content, "Hello");
        assert_eq!(taken[1].content, "Hi!");
    }

    #[test]
    #[serial]
    fn test_set_history() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let mut agent = ChatAgent::new().expect("Should create agent");
        agent.history.push(ChatMessage::user("Old message"));
        let new_history = vec![
            ChatMessage::user("New conversation"),
            ChatMessage::assistant("Fresh start!"),
        ];
        agent.set_history(new_history);
        assert_eq!(agent.history().len(), 2);
        assert_eq!(agent.history()[0].content, "New conversation");
        assert_eq!(agent.history()[1].content, "Fresh start!");
    }

    #[tokio::test]
    #[serial]
    async fn test_exec_command_echo() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let agent = ChatAgent::new().expect("Should create agent");
        let result = agent.exec_command("echo hello").await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello");
    }

    #[tokio::test]
    #[serial]
    async fn test_exec_command_with_args() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let agent = ChatAgent::new().expect("Should create agent");
        let result = agent.exec_command("echo -n 'test output'").await;
        assert!(result.is_ok());
        assert!(result.unwrap().contains("test output"));
    }

    #[tokio::test]
    #[serial]
    async fn test_exec_command_failure() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let agent = ChatAgent::new().expect("Should create agent");
        let result = agent.exec_command("exit 1").await;
        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(output.contains("Exit code: 1"));
    }

    #[tokio::test]
    #[serial]
    async fn test_exec_command_pipe() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let agent = ChatAgent::new().expect("Should create agent");
        let result = agent
            .exec_command("echo 'hello world' | tr 'a-z' 'A-Z'")
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "HELLO WORLD");
    }

    #[test]
    #[serial]
    fn test_streaming_state_access() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let agent = ChatAgent::new().expect("Should create agent");
        assert!(!agent.is_streaming());
        assert!(!agent.streaming_state().is_streaming);
    }

    #[tokio::test]
    #[serial]
    async fn test_with_streaming_channel() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let (tx, _rx) = mpsc::channel::<String>(10);
        let agent = ChatAgent::new()
            .expect("Should create agent")
            .with_streaming(tx);
        assert!(agent.streaming_tx.is_some());
    }

    #[tokio::test]
    #[serial]
    async fn test_invoke_unknown_server() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let agent = ChatAgent::new().expect("Should create agent");
        let result = agent
            .invoke(
                "some_tool",
                Some("nonexistent_server"),
                serde_json::json!({}),
            )
            .await;
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("not found") || err_msg.contains("No MCP servers"),
            "Expected 'not found' or 'No MCP servers' in error, got: {}",
            err_msg
        );
    }

    #[tokio::test]
    #[serial]
    async fn test_invoke_no_servers_configured() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let agent = ChatAgent::new().expect("Should create agent");
        let result = agent
            .invoke(
                "test_tool",
                Some("definitely_not_configured"),
                serde_json::json!({}),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, NikaError::InvalidConfig { .. }),
            "Expected InvalidConfig error, got: {:?}",
            err
        );
    }

    #[test]
    fn test_invoke_params_serialization() {
        let params = serde_json::json!({
            "entity": "qr-code",
            "locale": "fr-FR",
            "count": 5,
            "nested": {
                "key": "value"
            },
            "array": [1, 2, 3]
        });
        assert_eq!(params["entity"], "qr-code");
        assert_eq!(params["locale"], "fr-FR");
        assert_eq!(params["count"], 5);
        assert_eq!(params["nested"]["key"], "value");
        assert_eq!(params["array"][0], 1);
    }

    #[tokio::test]
    #[serial]
    async fn test_run_agent_no_servers_configured() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let agent = ChatAgent::new().expect("Should create agent");
        let result = agent
            .run_agent(
                "Test goal".to_string(),
                Some(3),
                false,
                vec!["nonexistent_server".to_string()],
            )
            .await;
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("not found") || err_msg.contains("No MCP servers"),
            "Expected 'not found' or 'No MCP servers' in error, got: {}",
            err_msg
        );
    }

    #[tokio::test]
    #[serial]
    async fn test_run_agent_empty_goal_validation() {
        std::env::set_var("OPENAI_API_KEY", "test-key-for-unit-test");
        let agent = ChatAgent::new().expect("Should create agent");
        let result = agent
            .run_agent(
                "".to_string(),
                Some(5),
                false,
                vec!["fake_server".to_string()],
            )
            .await;
        assert!(result.is_err());
    }

    #[test]
    fn test_agent_params_construction() {
        use crate::ast::AgentParams;
        let params = AgentParams {
            prompt: "Test goal".to_string(),
            system: None,
            provider: None,
            model: None,
            mcp: vec!["novanet".to_string()],
            tools: vec![],
            max_turns: Some(10),
            token_budget: None,
            stop_sequences: vec![],
            scope: None,
            extended_thinking: Some(true),
            thinking_budget: None,
            depth_limit: Some(3),
            ..Default::default()
        };
        assert_eq!(params.prompt, "Test goal");
        assert_eq!(params.max_turns, Some(10));
        assert_eq!(params.extended_thinking, Some(true));
        assert_eq!(params.depth_limit, Some(3));
        assert_eq!(params.mcp, vec!["novanet"]);
    }
}