use std::sync::Arc;
use anyhow::Result;
use futures::StreamExt;
use brainwires_call_policy::BudgetGuard;
use brainwires_core::{
ChatOptions, ContentBlock, Message, MessageContent, Provider, Role, StreamChunk, Tool,
ToolContext, ToolUse, Usage,
};
use brainwires_tool_runtime::{PreHookDecision, ToolExecutor, ToolPreHook};
use crate::summarization::Summarizer;
/// Chat agent that couples a [`Provider`] with a [`ToolExecutor`] and owns the
/// conversation history exchanged between them.
pub struct ChatAgent {
/// Backend whose `stream_chat` is invoked for every completion round.
provider: Arc<dyn Provider>,
/// Supplies the available tool definitions and executes requested tool calls.
executor: Arc<dyn ToolExecutor>,
/// Conversation history; a system prompt set through `with_system_prompt` sits at index 0.
messages: Vec<Message>,
/// Options forwarded unchanged to the provider on each streaming call.
options: ChatOptions,
/// Upper bound on provider round-trips per processed message.
max_tool_rounds: usize,
/// Optional hook consulted before each tool execution; it may reject the call.
pre_execute_hook: Option<Arc<dyn ToolPreHook>>,
/// Running total of the usage reported by the provider's `Usage` chunks.
cumulative_usage: Usage,
/// Optional guard that is ticked before each round and fed every usage report.
budget: Option<BudgetGuard>,
/// Maximum number of non-serialized tool calls in flight at once.
tool_concurrency: usize,
/// Optional summarizer used by `compact_history`.
summarizer: Option<Arc<dyn Summarizer>>,
/// Number of trailing messages `compact_history` leaves unsummarized.
summarization_keep_tail: usize,
}
impl ChatAgent {
/// Builds an agent around `provider` and `executor` with an empty history.
///
/// Defaults: at most 10 tool rounds per message, at most 4 tool calls running
/// concurrently, a 6-message tail preserved by summarization, and no
/// pre-execute hook, budget guard, or summarizer installed.
pub fn new(
    provider: Arc<dyn Provider>,
    executor: Arc<dyn ToolExecutor>,
    options: ChatOptions,
) -> Self {
    const DEFAULT_MAX_TOOL_ROUNDS: usize = 10;
    const DEFAULT_TOOL_CONCURRENCY: usize = 4;
    const DEFAULT_SUMMARIZATION_KEEP_TAIL: usize = 6;

    let messages = Vec::new();
    let cumulative_usage = Usage::default();

    Self {
        provider,
        executor,
        messages,
        options,
        max_tool_rounds: DEFAULT_MAX_TOOL_ROUNDS,
        pre_execute_hook: None,
        cumulative_usage,
        budget: None,
        tool_concurrency: DEFAULT_TOOL_CONCURRENCY,
        summarizer: None,
        summarization_keep_tail: DEFAULT_SUMMARIZATION_KEEP_TAIL,
    }
}
/// Installs the summarizer that `compact_history` uses, replacing any
/// previously configured one.
pub fn with_summarizer(mut self, summarizer: Arc<dyn Summarizer>) -> Self {
    let slot = &mut self.summarizer;
    *slot = Some(summarizer);
    self
}
/// Sets how many trailing messages stay untouched when history is
/// summarized. A value of 0 is raised to 1.
pub fn with_summarization_keep_tail(mut self, keep: usize) -> Self {
    self.summarization_keep_tail = match keep {
        0 => 1,
        positive => positive,
    };
    self
}
/// Sets how many non-serialized tool calls may run at the same time.
/// A value of 0 is raised to 1.
pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
    let at_least_one = std::cmp::max(concurrency, 1);
    self.tool_concurrency = at_least_one;
    self
}
pub fn with_max_tool_rounds(mut self, rounds: usize) -> Self {
self.max_tool_rounds = rounds;
self
}
/// Attaches a budget guard, replacing any guard set earlier.
pub fn with_budget(mut self, guard: BudgetGuard) -> Self {
    let slot = &mut self.budget;
    *slot = Some(guard);
    self
}
/// Registers the hook consulted before every tool execution, replacing any
/// hook set earlier.
pub fn with_pre_execute_hook(mut self, hook: Arc<dyn ToolPreHook>) -> Self {
    let slot = &mut self.pre_execute_hook;
    *slot = Some(hook);
    self
}
/// Puts `prompt` at the front of the history as the system message.
///
/// When the history already starts with a system message it is replaced;
/// otherwise the new message is inserted ahead of everything else.
pub fn with_system_prompt(mut self, prompt: &str) -> Self {
    let starts_with_system = self
        .messages
        .first()
        .is_some_and(|first| first.role == Role::System);
    let system = Message::system(prompt);

    if starts_with_system {
        // Overwriting slot 0 is the same as removing it and inserting anew.
        self.messages[0] = system;
    } else {
        self.messages.insert(0, system);
    }
    self
}
/// Appends `input` to the history as a user message and runs the completion
/// loop without a streaming callback, returning the resulting text.
///
/// # Errors
/// Returns whatever error the completion loop produces.
pub async fn process_message(&mut self, input: &str) -> Result<String> {
    let user_turn = Message::user(input);
    self.messages.push(user_turn);

    let no_callback: Option<fn(&str)> = None;
    self.run_completion(no_callback).await
}
/// Appends `input` to the history as a user message and runs the completion
/// loop, passing `on_chunk` along so it receives each text chunk as it is
/// streamed. Returns the resulting text.
///
/// # Errors
/// Returns whatever error the completion loop produces.
pub async fn process_message_streaming<F>(&mut self, input: &str, on_chunk: F) -> Result<String>
where
    F: Fn(&str) + Send + Sync,
{
    let user_turn = Message::user(input);
    self.messages.push(user_turn);

    let callback = Some(on_chunk);
    self.run_completion(callback).await
}
/// Borrows the full conversation history in order.
pub fn messages(&self) -> &[Message] {
    self.messages.as_slice()
}
/// Replaces the entire conversation history with `messages`; the previous
/// history is dropped.
pub fn restore_messages(&mut self, messages: Vec<Message>) {
    let previous = std::mem::replace(&mut self.messages, messages);
    drop(previous);
}
/// Removes every message from the history, including a system prompt.
pub fn clear_history(&mut self) {
    self.messages.truncate(0);
}
/// Shrinks the history to at most `max_messages` entries, keeping the most
/// recent ones.
///
/// When the history starts with a system message and `max_messages` is
/// non-zero, that system message is retained and counts toward the limit.
/// With `max_messages == 0` the history ends up empty. A history already
/// within the limit is left untouched.
pub fn trim_history(&mut self, max_messages: usize) {
    let total = self.messages.len();
    if total <= max_messages {
        return;
    }

    let pins_system = max_messages > 0
        && self
            .messages
            .first()
            .is_some_and(|first| first.role == Role::System);

    // `head` leading messages are pinned; `recent` trailing ones survive.
    let head = usize::from(pins_system);
    let recent = max_messages - head;
    // `total > max_messages >= recent`, so this cannot underflow.
    let cut = total - recent;

    // Discard everything between the pinned head and the surviving tail.
    self.messages.drain(head..cut);
}
pub fn message_count(&self) -> usize {
self.messages.len()
}
pub fn cumulative_usage(&self) -> &Usage {
&self.cumulative_usage
}
pub fn reset_usage(&mut self) {
self.cumulative_usage = Usage::default();
}
/// Shrinks the history by replacing older messages with a summary.
///
/// Without a configured summarizer this falls back to
/// `trim_history(20)`. Otherwise every message between an optional leading
/// system message and the last `summarization_keep_tail` messages is handed
/// to the summarizer, and that span is replaced by a single assistant message
/// of the form `[Prior conversation summary] <summary>`. Nothing happens when
/// the history holds at most `summarization_keep_tail + 1` messages.
///
/// # Errors
/// Propagates the summarizer's error; the history is left unchanged in that
/// case because it is only modified after summarization succeeds.
pub async fn compact_history(&mut self) -> Result<()> {
    /// Message cap applied when no summarizer is configured.
    const FALLBACK_MAX_MESSAGES: usize = 20;

    let Some(summarizer) = self.summarizer.clone() else {
        self.trim_history(FALLBACK_MAX_MESSAGES);
        return Ok(());
    };

    // `tail_start <= 1` is the overflow-free form of `len <= keep_tail + 1`;
    // the addition would panic in debug builds for `keep_tail == usize::MAX`.
    // It also guarantees `tail_start > head_end`, since `head_end` is 0 or 1.
    let tail_start = self
        .messages
        .len()
        .saturating_sub(self.summarization_keep_tail);
    if tail_start <= 1 {
        return Ok(());
    }

    let has_system = self
        .messages
        .first()
        .is_some_and(|first| first.role == Role::System);
    let head_end = usize::from(has_system);

    let to_summarize: Vec<Message> = self.messages[head_end..tail_start].to_vec();
    let summary = summarizer.summarize(&to_summarize).await?;
    let synthetic = Message::assistant(format!("[Prior conversation summary] {summary}"));

    // Swap the summarized span for the synthetic message in place; the system
    // message and the tail stay where they are and need no cloning.
    self.messages
        .splice(head_end..tail_start, std::iter::once(synthetic));
    Ok(())
}
/// Drives the provider/tool loop for the history as it currently stands.
///
/// Each round streams one provider response. A response without tool calls
/// is appended to the history as an assistant message and ends the loop.
/// Otherwise the assistant's text and tool-use blocks are recorded, every
/// requested tool is executed, the results are appended as a single user
/// message, and the next round begins.
///
/// At most `max_tool_rounds` rounds are run. If the limit is reached while the
/// model is still requesting tools, the text of the last round is returned;
/// with a limit of 0 the provider is never called and an empty string is
/// returned.
///
/// # Errors
/// Fails when the budget guard's `check_and_tick` fails or when
/// `collect_stream` returns an error.
async fn run_completion<F>(&mut self, on_chunk: Option<F>) -> Result<String>
where
F: Fn(&str) + Send + Sync,
{
let mut final_text = String::new();
for _ in 0..self.max_tool_rounds {
// Consult the optional budget guard before spending another provider call.
if let Some(ref guard) = self.budget {
guard.check_and_tick().map_err(anyhow::Error::from)?;
}
// Tool definitions are re-read every round; an empty list is passed as `None`.
let tool_defs: Vec<Tool> = self.executor.available_tools();
let tools_opt = if tool_defs.is_empty() {
None
} else {
Some(tool_defs.as_slice())
};
let (text_buf, tool_uses, response_id, compaction) =
self.collect_stream(tools_opt, &on_chunk).await?;
// On a compaction report the whole history is dropped and rebuilt as
// [first system message, if any] + [summary as an assistant message].
// NOTE(review): this also discards the pending user message — confirm intended.
if let Some((summary, tokens_freed)) = compaction {
tracing::info!(
tokens_freed = ?tokens_freed,
"context compaction triggered; replacing history with model summary"
);
let system_msg = self
.messages
.iter()
.find(|m| m.role == Role::System)
.cloned();
self.messages.clear();
if let Some(sys) = system_msg {
self.messages.push(sys);
}
self.messages.push(Message::assistant(&summary));
}
// No tool requests: this is the final answer for the turn.
if tool_uses.is_empty() {
self.messages.push(Message::assistant(&text_buf));
final_text = text_buf;
break;
}
// Record the assistant turn as blocks: optional text first, then one
// tool-use block per requested call.
let mut blocks = Vec::new();
if !text_buf.is_empty() {
blocks.push(ContentBlock::Text {
text: text_buf.clone(),
});
}
for tu in &tool_uses {
blocks.push(ContentBlock::ToolUse {
id: tu.id.clone(),
name: tu.name.clone(),
input: tu.input.clone(),
});
}
// The response id, when one was reported, travels in the message metadata.
let metadata = response_id.map(|rid| serde_json::json!({"response_id": rid}));
self.messages.push(Message {
role: Role::Assistant,
content: MessageContent::Blocks(blocks),
name: None,
metadata,
});
// Split the calls by the tool's `serialize` flag; tools missing from
// `tool_defs` are treated as not serialized.
let serialize_map: std::collections::HashMap<&str, bool> = tool_defs
.iter()
.map(|t| (t.name.as_str(), t.serialize))
.collect();
let (serial_idx, parallel_idx): (Vec<usize>, Vec<usize>) = (0..tool_uses.len())
.partition(|&i| {
serialize_map
.get(tool_uses[i].name.as_str())
.copied()
.unwrap_or(false)
});
// One slot per tool use so results keep the order of the requests
// regardless of completion order.
let mut slots: Vec<Option<ContentBlock>> = (0..tool_uses.len()).map(|_| None).collect();
// Serialized tools run first, one after another, in request order.
for i in serial_idx {
let tu = &tool_uses[i];
let block =
execute_one_tool(tu, self.executor.clone(), self.pre_execute_hook.clone())
.await;
slots[i] = Some(block);
}
// The remaining tools run concurrently, at most `tool_concurrency` at a time.
if !parallel_idx.is_empty() {
use futures::StreamExt as _;
use futures::future::BoxFuture;
let concurrency = self.tool_concurrency.max(1);
let executor = self.executor.clone();
let hook = self.pre_execute_hook.clone();
// Each future owns clones of everything it needs, hence the `'static` bound.
let futures: Vec<BoxFuture<'static, (usize, ContentBlock)>> = parallel_idx
.into_iter()
.map(|i| {
let tu = tool_uses[i].clone();
let exec = executor.clone();
let hk = hook.clone();
Box::pin(async move { (i, execute_one_tool(&tu, exec, hk).await) })
as BoxFuture<'static, (usize, ContentBlock)>
})
.collect();
let results: Vec<(usize, ContentBlock)> = futures::stream::iter(futures)
.buffer_unordered(concurrency)
.collect()
.await;
// Results arrive in completion order; the carried index restores request order.
for (i, block) in results {
slots[i] = Some(block);
}
}
// Every index landed in exactly one of the two partitions, so all slots are filled.
let result_blocks: Vec<ContentBlock> = slots
.into_iter()
.map(|b| b.expect("every tool use produced a result"))
.collect();
// Tool results go back to the model as a single user message.
self.messages.push(Message {
role: Role::User,
content: MessageContent::Blocks(result_blocks),
name: None,
metadata: None,
});
// Remember this round's text in case the round limit ends the loop here.
final_text = text_buf;
}
Ok(final_text)
}
async fn collect_stream<F>(
&mut self,
tools_opt: Option<&[Tool]>,
on_chunk: &Option<F>,
) -> Result<(
String,
Vec<ToolUse>,
Option<String>,
Option<(String, Option<u32>)>,
)>
where
F: Fn(&str) + Send + Sync,
{
let mut stream = self
.provider
.stream_chat(&self.messages, tools_opt, &self.options);
let mut text_buf = String::new();
let mut tool_uses: Vec<ToolUse> = Vec::new();
let mut current_tool_id = String::new();
let mut current_tool_name = String::new();
let mut current_tool_input = String::new();
let mut last_response_id: Option<String> = None;
let mut compaction: Option<(String, Option<u32>)> = None;
while let Some(chunk) = stream.next().await {
match chunk? {
StreamChunk::Text(t) => {
if let Some(cb) = on_chunk {
cb(&t);
}
text_buf.push_str(&t);
}
StreamChunk::ToolUse { id, name } => {
if !current_tool_id.is_empty() {
let input: serde_json::Value = serde_json::from_str(¤t_tool_input)
.unwrap_or(serde_json::Value::Null);
tool_uses.push(ToolUse {
id: std::mem::take(&mut current_tool_id),
name: std::mem::take(&mut current_tool_name),
input,
});
current_tool_input.clear();
}
current_tool_id = id;
current_tool_name = name;
}
StreamChunk::ToolInputDelta { partial_json, .. } => {
current_tool_input.push_str(&partial_json);
}
StreamChunk::ToolCall {
call_id,
response_id,
tool_name,
parameters,
..
} => {
last_response_id = Some(response_id);
tool_uses.push(ToolUse {
id: call_id,
name: tool_name,
input: parameters,
});
}
StreamChunk::Usage(u) => {
self.cumulative_usage.prompt_tokens += u.prompt_tokens;
self.cumulative_usage.completion_tokens += u.completion_tokens;
self.cumulative_usage.total_tokens += u.total_tokens;
if let Some(ref guard) = self.budget {
guard.record_usage(&u);
}
}
StreamChunk::Done => {}
StreamChunk::ContextCompacted {
summary,
tokens_freed,
} => {
compaction = Some((summary, tokens_freed));
}
}
}
if !current_tool_id.is_empty() {
let input: serde_json::Value =
serde_json::from_str(¤t_tool_input).unwrap_or(serde_json::Value::Null);
tool_uses.push(ToolUse {
id: current_tool_id,
name: current_tool_name,
input,
});
}
Ok((text_buf, tool_uses, last_response_id, compaction))
}
}
/// Executes a single tool call and packages the outcome as a
/// `ContentBlock::ToolResult` tied to the call's id.
///
/// When a pre-execute hook is supplied it runs first: a rejection is returned
/// as an error result carrying the hook's reason and the tool is not run; a
/// hook failure is logged as a warning and execution proceeds. An executor
/// error is converted into an error result rather than propagated.
async fn execute_one_tool(
    tu: &ToolUse,
    executor: Arc<dyn ToolExecutor>,
    pre_execute_hook: Option<Arc<dyn ToolPreHook>>,
) -> ContentBlock {
    if let Some(hook) = &pre_execute_hook {
        let hook_ctx = ToolContext::default();
        let verdict = hook.before_execute(tu, &hook_ctx).await;
        match verdict {
            Err(e) => {
                // A failing hook does not block the tool; it is only reported.
                tracing::warn!(tool = %tu.name, error = %e, "Pre-execute hook error");
            }
            Ok(PreHookDecision::Reject(reason)) => {
                return ContentBlock::ToolResult {
                    tool_use_id: tu.id.clone(),
                    content: reason,
                    is_error: Some(true),
                };
            }
            Ok(PreHookDecision::Allow) => {}
        }
    }

    let exec_ctx = ToolContext::default();
    let outcome = executor.execute(tu, &exec_ctx).await.unwrap_or_else(|e| {
        brainwires_core::ToolResult::error(tu.id.clone(), format!("tool executor error: {e}"))
    });

    let tool_use_id = tu.id.clone();
    let is_error = Some(outcome.is_error);
    ContentBlock::ToolResult {
        tool_use_id,
        content: outcome.content,
        is_error,
    }
}
// Unit tests for `ChatAgent`, run against a canned provider that never
// requests tools.
#[cfg(test)]
mod tests {
use super::*;
use brainwires_core::{ToolContext, ToolInputSchema};
use brainwires_tool_builtins::BuiltinToolExecutor;
use brainwires_tool_runtime::ToolRegistry;
use futures::stream;
use std::collections::HashMap;
/// Provider stub that always answers with the same fixed text.
struct MockProvider {
response_text: String,
}
impl MockProvider {
/// Creates a stub that replies with `text`.
fn new(text: &str) -> Self {
Self {
response_text: text.to_string(),
}
}
}
#[async_trait::async_trait]
impl Provider for MockProvider {
fn name(&self) -> &str {
"mock"
}
/// Non-streaming reply: the fixed text with a "stop" finish reason.
async fn chat(
&self,
_messages: &[Message],
_tools: Option<&[Tool]>,
_options: &ChatOptions,
) -> Result<brainwires_core::ChatResponse> {
Ok(brainwires_core::ChatResponse {
message: Message::assistant(&self.response_text),
usage: brainwires_core::Usage::new(10, 20),
finish_reason: Some("stop".to_string()),
})
}
/// Streaming reply: one `Text` chunk with the fixed text, then `Done`.
fn stream_chat<'a>(
&'a self,
_messages: &'a [Message],
_tools: Option<&'a [Tool]>,
_options: &'a ChatOptions,
) -> futures::stream::BoxStream<'a, Result<StreamChunk>> {
let text = self.response_text.clone();
Box::pin(stream::iter(vec![
Ok(StreamChunk::Text(text)),
Ok(StreamChunk::Done),
]))
}
}
/// Builds an executor whose registry holds a single tool named `test_tool`.
fn make_executor() -> Arc<dyn ToolExecutor> {
let mut registry = ToolRegistry::new();
registry.register(Tool {
name: "test_tool".to_string(),
description: "A test tool".to_string(),
input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
..Default::default()
});
let context = ToolContext::default();
Arc::new(BuiltinToolExecutor::new(registry, context))
}
/// Builds an agent backed by the mock provider and default chat options.
fn make_agent() -> ChatAgent {
let provider = Arc::new(MockProvider::new("Hello from mock!"));
let executor = make_executor();
ChatAgent::new(provider, executor, ChatOptions::default())
}
// A fresh agent has no history and the default round limit.
#[test]
fn test_new_creates_successfully() {
let agent = make_agent();
assert_eq!(agent.message_count(), 0);
assert_eq!(agent.max_tool_rounds, 10);
}
// Setting a system prompt on an empty history adds exactly one system message.
#[test]
fn test_with_system_prompt_adds_system_message() {
let agent = make_agent().with_system_prompt("You are helpful.");
assert_eq!(agent.message_count(), 1);
assert_eq!(agent.messages()[0].role, Role::System);
assert_eq!(agent.messages()[0].text(), Some("You are helpful."));
}
// A second system prompt replaces the first instead of stacking.
#[test]
fn test_with_system_prompt_replaces_existing() {
let agent = make_agent()
.with_system_prompt("First prompt")
.with_system_prompt("Second prompt");
assert_eq!(agent.message_count(), 1);
assert_eq!(agent.messages()[0].text(), Some("Second prompt"));
}
// The builder stores the requested round limit.
#[test]
fn test_with_max_tool_rounds() {
let agent = make_agent().with_max_tool_rounds(5);
assert_eq!(agent.max_tool_rounds, 5);
}
// `messages()` reflects what is stored in the history.
#[test]
fn test_messages_returns_history() {
let mut agent = make_agent();
assert!(agent.messages().is_empty());
agent.messages.push(Message::user("test"));
assert_eq!(agent.messages().len(), 1);
}
// Clearing removes everything, the system prompt included.
#[test]
fn test_clear_history() {
let mut agent = make_agent().with_system_prompt("sys");
agent.messages.push(Message::user("hello"));
assert_eq!(agent.message_count(), 2);
agent.clear_history();
assert_eq!(agent.message_count(), 0);
}
// Without a system message, trimming keeps only the most recent entries.
#[test]
fn test_trim_history_no_system() {
let mut agent = make_agent();
for i in 0..10 {
agent.messages.push(Message::user(format!("msg {}", i)));
}
assert_eq!(agent.message_count(), 10);
agent.trim_history(3);
assert_eq!(agent.message_count(), 3);
assert_eq!(agent.messages()[0].text(), Some("msg 7"));
assert_eq!(agent.messages()[1].text(), Some("msg 8"));
assert_eq!(agent.messages()[2].text(), Some("msg 9"));
}
// With a system message, trimming keeps it plus the most recent entries,
// and the system message counts toward the limit.
#[test]
fn test_trim_history_preserves_system() {
let mut agent = make_agent().with_system_prompt("system prompt");
for i in 0..10 {
agent.messages.push(Message::user(format!("msg {}", i)));
}
assert_eq!(agent.message_count(), 11); agent.trim_history(4);
assert_eq!(agent.message_count(), 4);
assert_eq!(agent.messages()[0].role, Role::System);
assert_eq!(agent.messages()[0].text(), Some("system prompt"));
assert_eq!(agent.messages()[1].text(), Some("msg 7"));
assert_eq!(agent.messages()[2].text(), Some("msg 8"));
assert_eq!(agent.messages()[3].text(), Some("msg 9"));
}
// Trimming a history already within the limit changes nothing.
#[test]
fn test_trim_history_under_limit_is_noop() {
let mut agent = make_agent();
agent.messages.push(Message::user("only one"));
agent.trim_history(10);
assert_eq!(agent.message_count(), 1);
}
// The count follows every push.
#[test]
fn test_message_count() {
let mut agent = make_agent();
assert_eq!(agent.message_count(), 0);
agent.messages.push(Message::user("a"));
assert_eq!(agent.message_count(), 1);
agent.messages.push(Message::assistant("b"));
assert_eq!(agent.message_count(), 2);
}
// A plain exchange returns the provider text and records a user message
// followed by an assistant message.
#[tokio::test]
async fn test_process_message_returns_text() {
let mut agent = make_agent();
let result = agent.process_message("Hi").await.unwrap();
assert_eq!(result, "Hello from mock!");
assert_eq!(agent.message_count(), 2);
assert_eq!(agent.messages()[0].role, Role::User);
assert_eq!(agent.messages()[1].role, Role::Assistant);
}
// The streaming variant delivers the mock's single text chunk to the
// callback and returns the same text.
#[tokio::test]
async fn test_process_message_streaming() {
let mut agent = make_agent();
let chunks = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
let chunks_clone = chunks.clone();
let result = agent
.process_message_streaming("Hi", move |chunk| {
chunks_clone.lock().unwrap().push(chunk.to_string());
})
.await
.unwrap();
assert_eq!(result, "Hello from mock!");
let received = chunks.lock().unwrap();
assert_eq!(received.len(), 1);
assert_eq!(received[0], "Hello from mock!");
}
}