use std::path::PathBuf;
use std::sync::Arc;
use futures::StreamExt;
use rig::completion::{GetTokenUsage, Message};
use rig::prelude::*;
use rig::streaming::StreamingChat;
use tokio::sync::mpsc::UnboundedSender;
use crate::action::Action;
use crate::ai::provider::AiProvider;
use crate::approval::ApprovalHook;
use crate::session::db::Database;
use crate::session::memory_tool::SaveMemoryTool;
use crate::tools::{
EditTool, GlobTool, GrepTool, LsTool, ReadTool, ShellTool, WebFetchTool, WebSearchTool,
WriteTool,
};
/// Inputs for [`spawn_streaming_chat`], bundled into one struct so call sites
/// pass them by name instead of as a long positional argument list.
pub struct StreamChatParams {
    /// Prior conversation messages, handed to the agent as chat history.
    pub history: Vec<Message>,
    /// Message sent to the agent for this request.
    pub prompt: String,
    /// System prompt, installed as the agent's preamble.
    pub system_prompt: String,
    /// Channel on which streaming progress (text chunks, tool calls/results,
    /// token usage, completion or error) is reported as [`Action`]s.
    pub tx: UnboundedSender<Action>,
    /// Directory handed to the shell, grep and glob tools.
    pub working_dir: PathBuf,
    /// Optional Brave API key passed to the web-search tool.
    pub brave_api_key: Option<String>,
    /// Upper bound on agentic turns (passed to `multi_turn`); also quoted in
    /// the system message shown when the limit is reached.
    pub max_turns: usize,
    /// Hook attached to the stream via `with_hook`.
    /// NOTE(review): presumably gates tool execution on user approval —
    /// confirm in `crate::approval`.
    pub approval_hook: ApprovalHook,
    /// Database backing the save-memory tool. When `None`, a fresh in-memory
    /// database is opened per call, so memories saved during this request
    /// are presumably not persisted.
    pub db: Option<Arc<Database>>,
    /// Project path passed to the save-memory tool alongside `db`.
    pub project_path: String,
}
pub fn spawn_streaming_chat(
provider: &AiProvider,
params: StreamChatParams,
) -> tokio::task::JoinHandle<()> {
let StreamChatParams {
history,
prompt,
system_prompt,
tx,
working_dir,
brave_api_key,
max_turns,
approval_hook,
db,
project_path,
} = params;
let memory_db = db.unwrap_or_else(|| {
Arc::new(Database::open_in_memory().expect("in-memory DB for memory tool"))
});
let save_memory = SaveMemoryTool::new(memory_db, project_path);
match provider {
AiProvider::Bedrock { client, model } => {
let agent = client
.agent(model)
.preamble(&system_prompt)
.max_tokens(4096)
.tool(ShellTool::new(working_dir.clone()))
.tool(ReadTool)
.tool(WriteTool)
.tool(EditTool)
.tool(GrepTool::new(working_dir.clone()))
.tool(GlobTool::new(working_dir))
.tool(LsTool)
.tool(WebFetchTool::new())
.tool(WebSearchTool::new(brave_api_key))
.tool(save_memory)
.build();
spawn_stream_task(agent, history, prompt, tx, max_turns, approval_hook)
}
AiProvider::OpenRouter { client, model } => {
let agent = client
.agent(model)
.preamble(&system_prompt)
.max_tokens(4096)
.tool(ShellTool::new(working_dir.clone()))
.tool(ReadTool)
.tool(WriteTool)
.tool(EditTool)
.tool(GrepTool::new(working_dir.clone()))
.tool(GlobTool::new(working_dir))
.tool(LsTool)
.tool(WebFetchTool::new())
.tool(WebSearchTool::new(brave_api_key))
.tool(save_memory)
.build();
spawn_stream_task(agent, history, prompt, tx, max_turns, approval_hook)
}
}
}
/// Runs a multi-turn streaming chat on `agent` in a spawned Tokio task and
/// forwards its progress to `tx` as [`Action`]s.
///
/// Actions sent, in stream order:
/// - `StreamChunk` for each streamed text fragment;
/// - `ToolCallStart` when the model issues a tool call;
/// - `ToolResult` when a tool result arrives, carrying the time elapsed since
///   the matching `ToolCallStart` (name `"unknown"` and ~0 ms if none matches);
/// - `TokenUpdate` after each per-turn `Final` item that reports usage;
/// - one terminal action, which is either:
///   - `StreamComplete` for a normal end, or for the turn limit (followed by
///     a `ShowSystemMessage`), or
///   - `StreamError` for any other stream error.
///
/// If the receiving side of `tx` is dropped, the task stops polling the
/// stream and exits without a terminal action. The agent is then no longer
/// driven: no further model turns or tool runs happen for a consumer that is
/// gone.
// The body is one long dispatch `match` over rig's stream items; splitting it
// further would mostly move those item types between functions.
#[allow(clippy::too_many_lines)]
fn spawn_stream_task<M>(
    agent: rig::agent::Agent<M>,
    history: Vec<Message>,
    prompt: String,
    tx: UnboundedSender<Action>,
    max_turns: usize,
    hook: ApprovalHook,
) -> tokio::task::JoinHandle<()>
where
    M: rig::completion::CompletionModel + 'static,
    M::StreamingResponse: Send + GetTokenUsage,
{
    tokio::spawn(async move {
        use rig::agent::MultiTurnStreamItem;
        use rig::streaming::{StreamedAssistantContent, StreamedUserContent};
        use std::collections::HashMap;
        use std::time::Instant;

        let mut stream = agent
            .stream_chat(&prompt, history)
            .multi_turn(max_turns)
            .with_hook(hook)
            .await;

        // Totals reported with the terminal `StreamComplete`.
        let mut input_tokens = 0u64;
        let mut output_tokens = 0u64;
        // Input tokens of the most recent turn, reported as `context_tokens`.
        let mut last_turn_input = 0u64;
        // Tool name and start time of each in-flight call, keyed by rig's
        // internal call id, so the matching result can report both.
        let mut tool_call_names: HashMap<String, (String, Instant)> = HashMap::new();

        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    let _ = tx.send(Action::StreamChunk(text.text));
                }
                Ok(MultiTurnStreamItem::StreamAssistantItem(
                    StreamedAssistantContent::ToolCall {
                        tool_call,
                        internal_call_id,
                    },
                )) => {
                    let name = tool_call.function.name.clone();
                    let args_json = tool_call.function.arguments.to_string();
                    tool_call_names.insert(internal_call_id, (name.clone(), Instant::now()));
                    let _ = tx.send(Action::ToolCallStart { name, args_json });
                }
                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
                    tool_result,
                    internal_call_id,
                })) => {
                    let (name, start_time) = tool_call_names
                        .remove(&internal_call_id)
                        .unwrap_or_else(|| ("unknown".to_string(), Instant::now()));
                    let duration_ms =
                        u64::try_from(start_time.elapsed().as_millis()).unwrap_or(u64::MAX);
                    // Text parts may arrive JSON-encoded as a string literal;
                    // decode those, keep anything else verbatim, and skip
                    // non-text parts.
                    let result_text: String = tool_result
                        .content
                        .iter()
                        .filter_map(|c| {
                            if let rig::message::ToolResultContent::Text(t) = c {
                                Some(
                                    serde_json::from_str::<String>(&t.text)
                                        .unwrap_or_else(|_| t.text.clone()),
                                )
                            } else {
                                None
                            }
                        })
                        .collect::<Vec<_>>()
                        .join("\n");
                    let _ = tx.send(Action::ToolResult {
                        name,
                        result: result_text,
                        duration_ms,
                    });
                }
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Final(
                    ref res,
                ))) => {
                    if let Some(usage) = res.token_usage() {
                        input_tokens += usage.input_tokens;
                        output_tokens += usage.output_tokens;
                        last_turn_input = usage.input_tokens;
                        let _ = tx.send(Action::TokenUpdate {
                            output_tokens,
                            context_tokens: last_turn_input,
                        });
                    }
                }
                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                    // Replaces (not adds to) the running totals.
                    // NOTE(review): presumably rig's aggregate usage over all
                    // turns — confirm against the rig version in use.
                    let usage = res.usage();
                    input_tokens = usage.input_tokens;
                    output_tokens = usage.output_tokens;
                }
                Err(e) => {
                    let err_str = e.to_string();
                    if is_max_turn_error(&err_str) {
                        // Hitting the turn limit is not a failure: close the
                        // stream normally and tell the user how to continue.
                        let _ = tx.send(Action::StreamComplete {
                            input_tokens,
                            output_tokens,
                            context_tokens: last_turn_input,
                        });
                        let _ = tx.send(Action::ShowSystemMessage(format!(
                            "Agentic loop reached the {max_turns}-turn limit. \
                             You can continue by sending another message."
                        )));
                    } else {
                        let _ = tx.send(Action::StreamError(format_error(&err_str)));
                    }
                    return;
                }
                // Any other stream item variants are ignored.
                _ => {}
            }

            // Nobody is listening any more: stop polling, which stops driving
            // further model turns and tool executions.
            if tx.is_closed() {
                return;
            }
        }

        let _ = tx.send(Action::StreamComplete {
            input_tokens,
            output_tokens,
            context_tokens: last_turn_input,
        });
    })
}

/// Heuristically recognises the stream error that signals the multi-turn limit
/// was reached.
///
/// Matches on the error's text rather than its type, so it depends on the
/// wording used by the underlying library.
fn is_max_turn_error(err: &str) -> bool {
    err.contains("MaxTurnError") || err.contains("max turn limit")
}
/// Turns a raw provider error string into a message suitable for the UI.
///
/// Some errors look like authentication or authorization failures: they
/// mention HTTP 401/403, "unauthorized" or "authentication", or contain
/// "invalid" followed later by "key". Those get a hint pointing at the API key
/// in `~/.seval/config.toml`, with the original error appended. All other
/// errors are returned unchanged.
///
/// Detection is a case-insensitive substring heuristic. It can misfire on
/// unrelated text that happens to contain, e.g., "401".
fn format_error(error: &str) -> String {
    if looks_like_auth_error(&error.to_lowercase()) {
        format!(
            "Authentication failed: check your API key in ~/.seval/config.toml. Original error: {error}"
        )
    } else {
        error.to_string()
    }
}

/// Returns `true` if the already-lowercased `lower` looks like an auth failure.
fn looks_like_auth_error(lower: &str) -> bool {
    const MARKERS: [&str; 4] = ["401", "403", "unauthorized", "authentication"];
    // `str::contains` does not understand regex syntax, so the intended
    // "invalid.*key" pattern is spelled out: "invalid" with "key" after it.
    let invalid_key =
        matches!(lower.split_once("invalid"), Some((_, rest)) if rest.contains("key"));
    invalid_key || MARKERS.iter().any(|marker| lower.contains(marker))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `format_error` rewrote `raw` into the authentication hint.
    fn flagged_as_auth(raw: &str) -> bool {
        format_error(raw).contains("Authentication failed")
    }

    #[test]
    fn format_error_detects_auth_failures() {
        assert!(flagged_as_auth("HTTP 401 Unauthorized"));
    }

    #[test]
    fn format_error_detects_403() {
        assert!(flagged_as_auth("HTTP 403 Forbidden"));
    }

    #[test]
    fn format_error_passes_through_normal_errors() {
        let raw = "Connection refused";
        assert_eq!(format_error(raw), raw);
    }
}