1use std::sync::Arc;
7
8use anyhow::Result;
9use futures::StreamExt;
10
11use brainwires_core::{
12 ChatOptions, ContentBlock, Message, MessageContent, Provider, Role, StreamChunk, Tool,
13 ToolContext, ToolUse, Usage,
14};
15use brainwires_tools::{BuiltinToolExecutor, PreHookDecision, ToolPreHook};
16
17pub struct ChatAgent {
45 provider: Arc<dyn Provider>,
46 executor: Arc<BuiltinToolExecutor>,
47 messages: Vec<Message>,
48 options: ChatOptions,
49 max_tool_rounds: usize,
50 pre_execute_hook: Option<Arc<dyn ToolPreHook>>,
51 cumulative_usage: Usage,
53}
54
55impl ChatAgent {
56 pub fn new(
60 provider: Arc<dyn Provider>,
61 executor: Arc<BuiltinToolExecutor>,
62 options: ChatOptions,
63 ) -> Self {
64 Self {
65 provider,
66 executor,
67 messages: Vec::new(),
68 options,
69 max_tool_rounds: 10,
70 pre_execute_hook: None,
71 cumulative_usage: Usage::default(),
72 }
73 }
74
75 pub fn with_max_tool_rounds(mut self, rounds: usize) -> Self {
77 self.max_tool_rounds = rounds;
78 self
79 }
80
81 pub fn with_pre_execute_hook(mut self, hook: Arc<dyn ToolPreHook>) -> Self {
83 self.pre_execute_hook = Some(hook);
84 self
85 }
86
87 pub fn with_system_prompt(mut self, prompt: &str) -> Self {
91 if let Some(first) = self.messages.first()
93 && first.role == Role::System
94 {
95 self.messages.remove(0);
96 }
97 self.messages.insert(0, Message::system(prompt));
98 self
99 }
100
101 pub async fn process_message(&mut self, input: &str) -> Result<String> {
110 self.messages.push(Message::user(input));
111 self.run_completion(None::<fn(&str)>).await
112 }
113
114 pub async fn process_message_streaming<F>(&mut self, input: &str, on_chunk: F) -> Result<String>
119 where
120 F: Fn(&str) + Send + Sync,
121 {
122 self.messages.push(Message::user(input));
123 self.run_completion(Some(on_chunk)).await
124 }
125
126 pub fn messages(&self) -> &[Message] {
128 &self.messages
129 }
130
131 pub fn restore_messages(&mut self, messages: Vec<Message>) {
136 self.messages = messages;
137 }
138
139 pub fn clear_history(&mut self) {
141 self.messages.clear();
142 }
143
144 pub fn trim_history(&mut self, max_messages: usize) {
147 if self.messages.len() <= max_messages {
148 return;
149 }
150
151 let has_system = self
152 .messages
153 .first()
154 .map(|m| m.role == Role::System)
155 .unwrap_or(false);
156
157 if has_system && max_messages > 0 {
158 let system = self.messages.remove(0);
159 let keep = max_messages.saturating_sub(1);
160 let start = self.messages.len().saturating_sub(keep);
161 self.messages = std::iter::once(system)
162 .chain(self.messages.drain(start..))
163 .collect();
164 } else {
165 let start = self.messages.len().saturating_sub(max_messages);
166 self.messages = self.messages.drain(start..).collect();
167 }
168 }
169
170 pub fn message_count(&self) -> usize {
172 self.messages.len()
173 }
174
175 pub fn cumulative_usage(&self) -> &Usage {
180 &self.cumulative_usage
181 }
182
183 pub fn reset_usage(&mut self) {
185 self.cumulative_usage = Usage::default();
186 }
187
188 pub async fn compact_history(&mut self) -> Result<()> {
194 self.trim_history(20);
196 Ok(())
197 }
198
199 async fn run_completion<F>(&mut self, on_chunk: Option<F>) -> Result<String>
202 where
203 F: Fn(&str) + Send + Sync,
204 {
205 let mut final_text = String::new();
206
207 for _ in 0..self.max_tool_rounds {
208 let tool_defs: Vec<Tool> = self.executor.tools();
209 let tools_opt = if tool_defs.is_empty() {
210 None
211 } else {
212 Some(tool_defs.as_slice())
213 };
214
215 let (text_buf, tool_uses, response_id, compaction) =
216 self.collect_stream(tools_opt, &on_chunk).await?;
217
218 if let Some((summary, tokens_freed)) = compaction {
222 tracing::info!(
223 tokens_freed = ?tokens_freed,
224 "context compaction triggered; replacing history with model summary"
225 );
226 let system_msg = self
227 .messages
228 .iter()
229 .find(|m| m.role == Role::System)
230 .cloned();
231 self.messages.clear();
232 if let Some(sys) = system_msg {
233 self.messages.push(sys);
234 }
235 self.messages.push(Message::assistant(&summary));
236 }
237
238 if tool_uses.is_empty() {
239 self.messages.push(Message::assistant(&text_buf));
241 final_text = text_buf;
242 break;
243 }
244
245 let mut blocks = Vec::new();
247 if !text_buf.is_empty() {
248 blocks.push(ContentBlock::Text {
249 text: text_buf.clone(),
250 });
251 }
252 for tu in &tool_uses {
253 blocks.push(ContentBlock::ToolUse {
254 id: tu.id.clone(),
255 name: tu.name.clone(),
256 input: tu.input.clone(),
257 });
258 }
259 let metadata = response_id.map(|rid| serde_json::json!({"response_id": rid}));
260 self.messages.push(Message {
261 role: Role::Assistant,
262 content: MessageContent::Blocks(blocks),
263 name: None,
264 metadata,
265 });
266
267 let mut result_blocks = Vec::new();
269 for tu in &tool_uses {
270 if let Some(ref hook) = self.pre_execute_hook {
272 let ctx = ToolContext::default();
273 match hook.before_execute(tu, &ctx).await {
274 Ok(PreHookDecision::Allow) => {}
275 Ok(PreHookDecision::Reject(reason)) => {
276 result_blocks.push(ContentBlock::ToolResult {
277 tool_use_id: tu.id.clone(),
278 content: reason,
279 is_error: Some(true),
280 });
281 continue;
282 }
283 Err(e) => {
284 tracing::warn!(tool = %tu.name, error = %e, "Pre-execute hook error");
285 }
286 }
287 }
288
289 let result = self
290 .executor
291 .execute_tool(&tu.name, &tu.id, &tu.input)
292 .await;
293 result_blocks.push(ContentBlock::ToolResult {
294 tool_use_id: tu.id.clone(),
295 content: result.content,
296 is_error: Some(result.is_error),
297 });
298 }
299
300 self.messages.push(Message {
301 role: Role::User,
302 content: MessageContent::Blocks(result_blocks),
303 name: None,
304 metadata: None,
305 });
306
307 final_text = text_buf;
309 }
310
311 Ok(final_text)
312 }
313
314 async fn collect_stream<F>(
322 &mut self,
323 tools_opt: Option<&[Tool]>,
324 on_chunk: &Option<F>,
325 ) -> Result<(
326 String,
327 Vec<ToolUse>,
328 Option<String>,
329 Option<(String, Option<u32>)>,
330 )>
331 where
332 F: Fn(&str) + Send + Sync,
333 {
334 let mut stream = self
335 .provider
336 .stream_chat(&self.messages, tools_opt, &self.options);
337
338 let mut text_buf = String::new();
339 let mut tool_uses: Vec<ToolUse> = Vec::new();
340 let mut current_tool_id = String::new();
341 let mut current_tool_name = String::new();
342 let mut current_tool_input = String::new();
343 let mut last_response_id: Option<String> = None;
344 let mut compaction: Option<(String, Option<u32>)> = None;
345
346 while let Some(chunk) = stream.next().await {
347 match chunk? {
348 StreamChunk::Text(t) => {
349 if let Some(cb) = on_chunk {
350 cb(&t);
351 }
352 text_buf.push_str(&t);
353 }
354 StreamChunk::ToolUse { id, name } => {
355 if !current_tool_id.is_empty() {
357 let input: serde_json::Value = serde_json::from_str(¤t_tool_input)
358 .unwrap_or(serde_json::Value::Null);
359 tool_uses.push(ToolUse {
360 id: std::mem::take(&mut current_tool_id),
361 name: std::mem::take(&mut current_tool_name),
362 input,
363 });
364 current_tool_input.clear();
365 }
366 current_tool_id = id;
367 current_tool_name = name;
368 }
369 StreamChunk::ToolInputDelta { partial_json, .. } => {
370 current_tool_input.push_str(&partial_json);
371 }
372 StreamChunk::ToolCall {
373 call_id,
374 response_id,
375 tool_name,
376 parameters,
377 ..
378 } => {
379 last_response_id = Some(response_id);
380 tool_uses.push(ToolUse {
381 id: call_id,
382 name: tool_name,
383 input: parameters,
384 });
385 }
386 StreamChunk::Usage(u) => {
387 self.cumulative_usage.prompt_tokens += u.prompt_tokens;
388 self.cumulative_usage.completion_tokens += u.completion_tokens;
389 self.cumulative_usage.total_tokens += u.total_tokens;
390 }
391 StreamChunk::Done => {}
392 StreamChunk::ContextCompacted {
393 summary,
394 tokens_freed,
395 } => {
396 compaction = Some((summary, tokens_freed));
399 }
400 }
401 }
402
403 if !current_tool_id.is_empty() {
405 let input: serde_json::Value =
406 serde_json::from_str(¤t_tool_input).unwrap_or(serde_json::Value::Null);
407 tool_uses.push(ToolUse {
408 id: current_tool_id,
409 name: current_tool_name,
410 input,
411 });
412 }
413
414 Ok((text_buf, tool_uses, last_response_id, compaction))
415 }
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use brainwires_core::ToolInputSchema;
    use brainwires_tools::ToolRegistry;
    use futures::stream;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// Provider stub whose streamed (and non-streamed) reply is always `response_text`.
    struct MockProvider {
        response_text: String,
    }

    impl MockProvider {
        fn new(text: &str) -> Self {
            Self {
                response_text: text.to_string(),
            }
        }
    }

    #[async_trait::async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn chat(
            &self,
            _messages: &[Message],
            _tools: Option<&[Tool]>,
            _options: &ChatOptions,
        ) -> Result<brainwires_core::ChatResponse> {
            Ok(brainwires_core::ChatResponse {
                message: Message::assistant(&self.response_text),
                usage: Usage::new(10, 20),
                finish_reason: Some("stop".to_string()),
            })
        }

        fn stream_chat<'a>(
            &'a self,
            _messages: &'a [Message],
            _tools: Option<&'a [Tool]>,
            _options: &'a ChatOptions,
        ) -> futures::stream::BoxStream<'a, Result<StreamChunk>> {
            let text = self.response_text.clone();
            Box::pin(stream::iter(vec![
                Ok(StreamChunk::Text(text)),
                Ok(StreamChunk::Done),
            ]))
        }
    }

    fn make_executor() -> Arc<BuiltinToolExecutor> {
        let mut registry = ToolRegistry::new();
        registry.register(Tool {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
            ..Default::default()
        });
        let context = ToolContext::default();
        Arc::new(BuiltinToolExecutor::new(registry, context))
    }

    fn make_agent() -> ChatAgent {
        let provider = Arc::new(MockProvider::new("Hello from mock!"));
        let executor = make_executor();
        ChatAgent::new(provider, executor, ChatOptions::default())
    }

    /// Pushes `count` user messages "msg 0", "msg 1", ... onto the agent's history.
    fn push_numbered(agent: &mut ChatAgent, count: usize) {
        for i in 0..count {
            agent.messages.push(Message::user(format!("msg {i}")));
        }
    }

    #[test]
    fn test_new_creates_successfully() {
        let agent = make_agent();
        assert_eq!(agent.message_count(), 0);
        assert_eq!(agent.max_tool_rounds, 10);
    }

    #[test]
    fn test_with_system_prompt_adds_system_message() {
        let agent = make_agent().with_system_prompt("You are helpful.");
        assert_eq!(agent.message_count(), 1);
        assert_eq!(agent.messages()[0].role, Role::System);
        assert_eq!(agent.messages()[0].text(), Some("You are helpful."));
    }

    #[test]
    fn test_with_system_prompt_replaces_existing() {
        let agent = make_agent()
            .with_system_prompt("First prompt")
            .with_system_prompt("Second prompt");
        assert_eq!(agent.message_count(), 1);
        assert_eq!(agent.messages()[0].text(), Some("Second prompt"));
    }

    #[test]
    fn test_with_max_tool_rounds() {
        let agent = make_agent().with_max_tool_rounds(5);
        assert_eq!(agent.max_tool_rounds, 5);
    }

    #[test]
    fn test_messages_returns_history() {
        let mut agent = make_agent();
        assert!(agent.messages().is_empty());
        agent.messages.push(Message::user("test"));
        assert_eq!(agent.messages().len(), 1);
    }

    #[test]
    fn test_restore_messages_replaces_history() {
        let mut agent = make_agent().with_system_prompt("sys");
        agent.restore_messages(vec![Message::user("a"), Message::assistant("b")]);
        assert_eq!(agent.message_count(), 2);
        assert_eq!(agent.messages()[0].role, Role::User);
        assert_eq!(agent.messages()[1].text(), Some("b"));
    }

    #[test]
    fn test_clear_history() {
        let mut agent = make_agent().with_system_prompt("sys");
        agent.messages.push(Message::user("hello"));
        assert_eq!(agent.message_count(), 2);
        agent.clear_history();
        assert_eq!(agent.message_count(), 0);
    }

    #[test]
    fn test_trim_history_no_system() {
        let mut agent = make_agent();
        push_numbered(&mut agent, 10);
        assert_eq!(agent.message_count(), 10);
        agent.trim_history(3);
        assert_eq!(agent.message_count(), 3);
        assert_eq!(agent.messages()[0].text(), Some("msg 7"));
        assert_eq!(agent.messages()[1].text(), Some("msg 8"));
        assert_eq!(agent.messages()[2].text(), Some("msg 9"));
    }

    #[test]
    fn test_trim_history_preserves_system() {
        let mut agent = make_agent().with_system_prompt("system prompt");
        push_numbered(&mut agent, 10);
        assert_eq!(agent.message_count(), 11);
        agent.trim_history(4);
        assert_eq!(agent.message_count(), 4);
        assert_eq!(agent.messages()[0].role, Role::System);
        assert_eq!(agent.messages()[0].text(), Some("system prompt"));
        assert_eq!(agent.messages()[1].text(), Some("msg 7"));
        assert_eq!(agent.messages()[2].text(), Some("msg 8"));
        assert_eq!(agent.messages()[3].text(), Some("msg 9"));
    }

    #[test]
    fn test_trim_history_zero_removes_everything() {
        let mut agent = make_agent().with_system_prompt("system prompt");
        push_numbered(&mut agent, 3);
        agent.trim_history(0);
        assert_eq!(agent.message_count(), 0);
    }

    #[test]
    fn test_trim_history_under_limit_is_noop() {
        let mut agent = make_agent();
        agent.messages.push(Message::user("only one"));
        agent.trim_history(10);
        assert_eq!(agent.message_count(), 1);
    }

    #[test]
    fn test_message_count() {
        let mut agent = make_agent();
        assert_eq!(agent.message_count(), 0);
        agent.messages.push(Message::user("a"));
        assert_eq!(agent.message_count(), 1);
        agent.messages.push(Message::assistant("b"));
        assert_eq!(agent.message_count(), 2);
    }

    #[tokio::test]
    async fn test_compact_history_keeps_latest_twenty() {
        let mut agent = make_agent().with_system_prompt("system prompt");
        push_numbered(&mut agent, 30);
        agent.compact_history().await.unwrap();
        assert_eq!(agent.message_count(), 20);
        assert_eq!(agent.messages()[0].role, Role::System);
        assert_eq!(agent.messages()[1].text(), Some("msg 11"));
        assert_eq!(agent.messages()[19].text(), Some("msg 29"));
    }

    #[tokio::test]
    async fn test_process_message_returns_text() {
        let mut agent = make_agent();
        let result = agent.process_message("Hi").await.unwrap();
        assert_eq!(result, "Hello from mock!");
        assert_eq!(agent.message_count(), 2);
        assert_eq!(agent.messages()[0].role, Role::User);
        assert_eq!(agent.messages()[1].role, Role::Assistant);
    }

    #[tokio::test]
    async fn test_process_message_streaming() {
        let mut agent = make_agent();
        let chunks = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&chunks);

        let result = agent
            .process_message_streaming("Hi", move |chunk| {
                sink.lock().unwrap().push(chunk.to_string());
            })
            .await
            .unwrap();

        assert_eq!(result, "Hello from mock!");
        let received = chunks.lock().unwrap();
        assert_eq!(received.len(), 1);
        assert_eq!(received[0], "Hello from mock!");
    }
}