// brainwires_agents/chat_agent.rs

//! A simple chat agent that processes messages through an LLM provider with tool support.
//!
//! [`ChatAgent`] is the framework's ready-to-use agent for text message to response
//! flows, including automatic tool call dispatch via [`BuiltinToolExecutor`].

use std::sync::Arc;

use anyhow::Result;
use futures::StreamExt;

use brainwires_core::{
    ChatOptions, ContentBlock, Message, MessageContent, Provider, Role, StreamChunk, Tool,
    ToolContext, ToolUse, Usage,
};
use brainwires_tools::{BuiltinToolExecutor, PreHookDecision, ToolPreHook};
/// A simple chat agent that processes messages through an LLM provider with tool support.
///
/// This is the framework's ready-to-use agent for text message -> response flows.
/// It manages conversation history, streams responses from the provider, and
/// automatically dispatches tool calls through a [`BuiltinToolExecutor`].
///
/// # Example
///
/// ```rust,ignore
/// use brainwires_agents::ChatAgent;
/// use brainwires_tools::{BuiltinToolExecutor, ToolRegistry};
/// use brainwires_core::{ChatOptions, ToolContext};
/// use std::sync::Arc;
///
/// let provider = /* create a provider */;
/// let registry = ToolRegistry::with_builtins();
/// let context = ToolContext::default();
/// let executor = Arc::new(BuiltinToolExecutor::new(registry, context));
/// let options = ChatOptions::default();
///
/// let mut agent = ChatAgent::new(provider, executor, options)
///     .with_system_prompt("You are a helpful assistant.")
///     .with_max_tool_rounds(5);
///
/// let response = agent.process_message("Hello!").await?;
/// println!("{}", response);
/// ```
pub struct ChatAgent {
    /// The LLM backend used for every completion.
    provider: Arc<dyn Provider>,
    /// Executes tool calls requested by the model during a completion.
    executor: Arc<BuiltinToolExecutor>,
    /// Full conversation history; an optional system prompt lives at index 0.
    messages: Vec<Message>,
    /// Options forwarded to the provider on every call.
    options: ChatOptions,
    /// Upper bound on provider round-trips containing tool calls per `process_message`.
    max_tool_rounds: usize,
    /// Optional hook consulted before each tool execution; may reject a call.
    pre_execute_hook: Option<Arc<dyn ToolPreHook>>,
    /// Accumulated token usage across all completions in this session.
    cumulative_usage: Usage,
}
54
impl ChatAgent {
    /// Create a new `ChatAgent`.
    ///
    /// Defaults `max_tool_rounds` to 10.
    pub fn new(
        provider: Arc<dyn Provider>,
        executor: Arc<BuiltinToolExecutor>,
        options: ChatOptions,
    ) -> Self {
        Self {
            provider,
            executor,
            messages: Vec::new(),
            options,
            max_tool_rounds: 10,
            pre_execute_hook: None,
            cumulative_usage: Usage::default(),
        }
    }

    /// Set the maximum number of tool-call rounds before the agent stops.
    pub fn with_max_tool_rounds(mut self, rounds: usize) -> Self {
        self.max_tool_rounds = rounds;
        self
    }

    /// Attach a pre-execution hook that can allow or reject tool calls before they run.
    pub fn with_pre_execute_hook(mut self, hook: Arc<dyn ToolPreHook>) -> Self {
        self.pre_execute_hook = Some(hook);
        self
    }

    /// Add a system prompt as the first message in the conversation.
    ///
    /// If messages already exist, the system message is inserted at position 0.
    /// Calling this twice replaces the previous system prompt rather than
    /// stacking a second one.
    pub fn with_system_prompt(mut self, prompt: &str) -> Self {
        // Remove any existing system message at position 0
        if let Some(first) = self.messages.first()
            && first.role == Role::System
        {
            self.messages.remove(0);
        }
        self.messages.insert(0, Message::system(prompt));
        self
    }

    /// Process a user message and return the final assistant text response.
    ///
    /// This is the core completion loop:
    /// 1. Adds the user message to history
    /// 2. Streams the provider response, collecting text and tool calls
    /// 3. If tool calls are present, executes them and loops
    /// 4. Returns the final accumulated text once no more tool calls remain
    ///    (or `max_tool_rounds` is reached)
    pub async fn process_message(&mut self, input: &str) -> Result<String> {
        self.messages.push(Message::user(input));
        // Turbofish pins the unused callback type parameter; no per-chunk
        // callback is invoked for the non-streaming entry point.
        self.run_completion(None::<fn(&str)>).await
    }

    /// Process a user message with streaming — calls `on_chunk` for each text
    /// fragment as it arrives from the provider.
    ///
    /// Returns the full accumulated text once the completion loop finishes.
    pub async fn process_message_streaming<F>(&mut self, input: &str, on_chunk: F) -> Result<String>
    where
        F: Fn(&str) + Send + Sync,
    {
        self.messages.push(Message::user(input));
        self.run_completion(Some(on_chunk)).await
    }

    /// Access the conversation history.
    pub fn messages(&self) -> &[Message] {
        &self.messages
    }

    /// Replace the entire message history with the provided messages.
    ///
    /// This is used by session persistence to restore a previously saved
    /// conversation when an agent session is recreated.
    pub fn restore_messages(&mut self, messages: Vec<Message>) {
        self.messages = messages;
    }

    /// Clear all messages (including any system prompt).
    pub fn clear_history(&mut self) {
        self.messages.clear();
    }

    /// Keep only the last `max_messages` messages, preserving the system prompt
    /// at position 0 if one exists.
    ///
    /// When a system prompt is preserved it counts towards `max_messages`, so
    /// at most `max_messages - 1` recent messages are retained alongside it.
    pub fn trim_history(&mut self, max_messages: usize) {
        if self.messages.len() <= max_messages {
            return;
        }

        let has_system = self
            .messages
            .first()
            .map(|m| m.role == Role::System)
            .unwrap_or(false);

        if has_system && max_messages > 0 {
            // Pull the system message out, then rebuild the history as
            // [system] + the most recent `max_messages - 1` messages.
            let system = self.messages.remove(0);
            let keep = max_messages.saturating_sub(1);
            let start = self.messages.len().saturating_sub(keep);
            self.messages = std::iter::once(system)
                .chain(self.messages.drain(start..))
                .collect();
        } else {
            // No system prompt (or max_messages == 0): keep only the tail.
            let start = self.messages.len().saturating_sub(max_messages);
            self.messages = self.messages.drain(start..).collect();
        }
    }

    /// Return the number of messages in the conversation.
    pub fn message_count(&self) -> usize {
        self.messages.len()
    }

    /// Return the accumulated token usage for this agent session.
    ///
    /// Counts prompt + completion tokens across all completions. Updated
    /// whenever the provider emits a `StreamChunk::Usage` event.
    pub fn cumulative_usage(&self) -> &Usage {
        &self.cumulative_usage
    }

    /// Reset the cumulative token usage counter.
    pub fn reset_usage(&mut self) {
        self.cumulative_usage = Usage::default();
    }

    /// Compact conversation history by trimming older messages.
    ///
    /// This is a simple, LLM-free compaction that keeps the system prompt
    /// (if any) and the most recent `keep` messages. For LLM-powered
    /// summarisation, use the `DreamSummarizer` from `brainwires-autonomy`.
    pub async fn compact_history(&mut self) -> Result<()> {
        // Default: keep system prompt + last 20 messages
        self.trim_history(20);
        Ok(())
    }

    // ── Internal completion loop ─────────────────────────────────────────

    /// Drive up to `max_tool_rounds` provider round-trips, executing tool
    /// calls between rounds, and return the final assistant text.
    ///
    /// `on_chunk`, when present, is invoked for every text fragment as it
    /// streams in. If the round limit is reached while the model is still
    /// requesting tools, the text from the last round is returned as-is.
    async fn run_completion<F>(&mut self, on_chunk: Option<F>) -> Result<String>
    where
        F: Fn(&str) + Send + Sync,
    {
        let mut final_text = String::new();

        for _ in 0..self.max_tool_rounds {
            // Advertise the executor's tools to the provider; pass None when
            // the registry is empty.
            let tool_defs: Vec<Tool> = self.executor.tools();
            let tools_opt = if tool_defs.is_empty() {
                None
            } else {
                Some(tool_defs.as_slice())
            };

            let (text_buf, tool_uses, response_id, compaction) =
                self.collect_stream(tools_opt, &on_chunk).await?;

            // Apply context compaction if the model summarised the history.
            // Must happen after collect_stream returns so the stream's borrow
            // on self.messages is released.
            if let Some((summary, tokens_freed)) = compaction {
                tracing::info!(
                    tokens_freed = ?tokens_freed,
                    "context compaction triggered; replacing history with model summary"
                );
                // Preserve the system prompt (wherever it sits), then replace
                // everything else with the model-provided summary.
                let system_msg = self
                    .messages
                    .iter()
                    .find(|m| m.role == Role::System)
                    .cloned();
                self.messages.clear();
                if let Some(sys) = system_msg {
                    self.messages.push(sys);
                }
                self.messages.push(Message::assistant(&summary));
            }

            if tool_uses.is_empty() {
                // No tool calls — this is the final response
                self.messages.push(Message::assistant(&text_buf));
                final_text = text_buf;
                break;
            }

            // Build assistant message with text + tool use blocks
            let mut blocks = Vec::new();
            if !text_buf.is_empty() {
                blocks.push(ContentBlock::Text {
                    text: text_buf.clone(),
                });
            }
            for tu in &tool_uses {
                blocks.push(ContentBlock::ToolUse {
                    id: tu.id.clone(),
                    name: tu.name.clone(),
                    input: tu.input.clone(),
                });
            }
            // Stash the provider response id (if the provider sent one) in
            // message metadata so later rounds can reference it.
            let metadata = response_id.map(|rid| serde_json::json!({"response_id": rid}));
            self.messages.push(Message {
                role: Role::Assistant,
                content: MessageContent::Blocks(blocks),
                name: None,
                metadata,
            });

            // Execute each tool call and add results as a user message
            let mut result_blocks = Vec::new();
            for tu in &tool_uses {
                // Run pre-execute hook if configured
                if let Some(ref hook) = self.pre_execute_hook {
                    let ctx = ToolContext::default();
                    match hook.before_execute(tu, &ctx).await {
                        Ok(PreHookDecision::Allow) => {}
                        Ok(PreHookDecision::Reject(reason)) => {
                            // Surface the rejection to the model as an error
                            // tool result instead of executing the tool.
                            result_blocks.push(ContentBlock::ToolResult {
                                tool_use_id: tu.id.clone(),
                                content: reason,
                                is_error: Some(true),
                            });
                            continue;
                        }
                        Err(e) => {
                            // Hook failure is non-fatal: log and fall through
                            // to execute the tool anyway.
                            tracing::warn!(tool = %tu.name, error = %e, "Pre-execute hook error");
                        }
                    }
                }

                let result = self
                    .executor
                    .execute_tool(&tu.name, &tu.id, &tu.input)
                    .await;
                result_blocks.push(ContentBlock::ToolResult {
                    tool_use_id: tu.id.clone(),
                    content: result.content,
                    is_error: Some(result.is_error),
                });
            }

            self.messages.push(Message {
                role: Role::User,
                content: MessageContent::Blocks(result_blocks),
                name: None,
                metadata: None,
            });

            // Keep the last text in case we hit max rounds
            final_text = text_buf;
        }

        Ok(final_text)
    }

    /// Collect the stream into accumulated text + tool uses.
    ///
    /// Returns `(text, tool_uses, response_id, compaction)`.
    /// `compaction` is `Some((summary, tokens_freed))` when the model emitted a
    /// `context_window_management_event` during the stream.  The caller is
    /// responsible for applying compaction to `self.messages` after the borrow
    /// on `self.messages` (held by the stream) is released.
    ///
    /// Two tool-call wire styles are handled: delta style, where a
    /// `ToolUse` header is followed by `ToolInputDelta` JSON fragments, and
    /// complete style, where a single `ToolCall` chunk carries everything.
    async fn collect_stream<F>(
        &mut self,
        tools_opt: Option<&[Tool]>,
        on_chunk: &Option<F>,
    ) -> Result<(
        String,
        Vec<ToolUse>,
        Option<String>,
        Option<(String, Option<u32>)>,
    )>
    where
        F: Fn(&str) + Send + Sync,
    {
        let mut stream = self
            .provider
            .stream_chat(&self.messages, tools_opt, &self.options);

        let mut text_buf = String::new();
        let mut tool_uses: Vec<ToolUse> = Vec::new();
        // In-flight delta-style tool call being assembled from fragments.
        let mut current_tool_id = String::new();
        let mut current_tool_name = String::new();
        let mut current_tool_input = String::new();
        let mut last_response_id: Option<String> = None;
        let mut compaction: Option<(String, Option<u32>)> = None;

        while let Some(chunk) = stream.next().await {
            match chunk? {
                StreamChunk::Text(t) => {
                    if let Some(cb) = on_chunk {
                        cb(&t);
                    }
                    text_buf.push_str(&t);
                }
                StreamChunk::ToolUse { id, name } => {
                    // Flush previous tool if any
                    if !current_tool_id.is_empty() {
                        // Unparseable accumulated JSON degrades to Null rather
                        // than aborting the stream.
                        let input: serde_json::Value = serde_json::from_str(&current_tool_input)
                            .unwrap_or(serde_json::Value::Null);
                        tool_uses.push(ToolUse {
                            id: std::mem::take(&mut current_tool_id),
                            name: std::mem::take(&mut current_tool_name),
                            input,
                        });
                        current_tool_input.clear();
                    }
                    current_tool_id = id;
                    current_tool_name = name;
                }
                StreamChunk::ToolInputDelta { partial_json, .. } => {
                    current_tool_input.push_str(&partial_json);
                }
                StreamChunk::ToolCall {
                    call_id,
                    response_id,
                    tool_name,
                    parameters,
                    ..
                } => {
                    // Complete-style call: parameters arrive fully formed.
                    last_response_id = Some(response_id);
                    tool_uses.push(ToolUse {
                        id: call_id,
                        name: tool_name,
                        input: parameters,
                    });
                }
                StreamChunk::Usage(u) => {
                    // Accumulate session-wide token counts.
                    self.cumulative_usage.prompt_tokens += u.prompt_tokens;
                    self.cumulative_usage.completion_tokens += u.completion_tokens;
                    self.cumulative_usage.total_tokens += u.total_tokens;
                }
                StreamChunk::Done => {}
                StreamChunk::ContextCompacted {
                    summary,
                    tokens_freed,
                } => {
                    // Record compaction info; applied to self.messages after the stream
                    // borrow is released (see run_completion).
                    compaction = Some((summary, tokens_freed));
                }
            }
        }

        // Flush last tool if any
        if !current_tool_id.is_empty() {
            let input: serde_json::Value =
                serde_json::from_str(&current_tool_input).unwrap_or(serde_json::Value::Null);
            tool_uses.push(ToolUse {
                id: current_tool_id,
                name: current_tool_name,
                input,
            });
        }

        Ok((text_buf, tool_uses, last_response_id, compaction))
    }
}
417
#[cfg(test)]
mod tests {
    use super::*;
    use brainwires_core::{ToolContext, ToolInputSchema};
    use brainwires_tools::ToolRegistry;
    use futures::stream;
    use std::collections::HashMap;

    /// Provider stub that always answers with one fixed text reply.
    struct MockProvider {
        response_text: String,
    }

    impl MockProvider {
        fn new(text: &str) -> Self {
            MockProvider {
                response_text: text.to_owned(),
            }
        }
    }

    #[async_trait::async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn chat(
            &self,
            _messages: &[Message],
            _tools: Option<&[Tool]>,
            _options: &ChatOptions,
        ) -> Result<brainwires_core::ChatResponse> {
            let reply = brainwires_core::ChatResponse {
                message: Message::assistant(&self.response_text),
                usage: brainwires_core::Usage::new(10, 20),
                finish_reason: Some("stop".to_string()),
            };
            Ok(reply)
        }

        fn stream_chat<'a>(
            &'a self,
            _messages: &'a [Message],
            _tools: Option<&'a [Tool]>,
            _options: &'a ChatOptions,
        ) -> futures::stream::BoxStream<'a, Result<StreamChunk>> {
            // Emit the whole reply as a single Text chunk, then Done.
            let chunks = vec![
                Ok(StreamChunk::Text(self.response_text.clone())),
                Ok(StreamChunk::Done),
            ];
            Box::pin(stream::iter(chunks))
        }
    }

    /// Build an executor whose registry holds a single no-op tool definition.
    fn make_executor() -> Arc<BuiltinToolExecutor> {
        let mut registry = ToolRegistry::new();
        let tool = Tool {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
            ..Default::default()
        };
        registry.register(tool);
        Arc::new(BuiltinToolExecutor::new(registry, ToolContext::default()))
    }

    /// Agent wired to the mock provider and the test executor.
    fn make_agent() -> ChatAgent {
        ChatAgent::new(
            Arc::new(MockProvider::new("Hello from mock!")),
            make_executor(),
            ChatOptions::default(),
        )
    }

    #[test]
    fn test_new_creates_successfully() {
        let agent = make_agent();
        assert_eq!(agent.message_count(), 0);
        assert_eq!(agent.max_tool_rounds, 10);
    }

    #[test]
    fn test_with_system_prompt_adds_system_message() {
        let agent = make_agent().with_system_prompt("You are helpful.");
        assert_eq!(agent.message_count(), 1);
        let first = &agent.messages()[0];
        assert_eq!(first.role, Role::System);
        assert_eq!(first.text(), Some("You are helpful."));
    }

    #[test]
    fn test_with_system_prompt_replaces_existing() {
        // The second call must replace, not stack, the system prompt.
        let agent = make_agent()
            .with_system_prompt("First prompt")
            .with_system_prompt("Second prompt");
        assert_eq!(agent.message_count(), 1);
        assert_eq!(agent.messages()[0].text(), Some("Second prompt"));
    }

    #[test]
    fn test_with_max_tool_rounds() {
        let agent = make_agent().with_max_tool_rounds(5);
        assert_eq!(agent.max_tool_rounds, 5);
    }

    #[test]
    fn test_messages_returns_history() {
        let mut agent = make_agent();
        assert!(agent.messages().is_empty());
        // Push directly onto the private field to exercise the accessor.
        agent.messages.push(Message::user("test"));
        assert_eq!(agent.messages().len(), 1);
    }

    #[test]
    fn test_clear_history() {
        let mut agent = make_agent().with_system_prompt("sys");
        agent.messages.push(Message::user("hello"));
        assert_eq!(agent.message_count(), 2);
        agent.clear_history();
        assert_eq!(agent.message_count(), 0);
    }

    #[test]
    fn test_trim_history_no_system() {
        let mut agent = make_agent();
        for i in 0..10 {
            agent.messages.push(Message::user(format!("msg {i}")));
        }
        assert_eq!(agent.message_count(), 10);
        agent.trim_history(3);
        // Only the newest 3 messages survive.
        assert_eq!(agent.message_count(), 3);
        assert_eq!(agent.messages()[0].text(), Some("msg 7"));
        assert_eq!(agent.messages()[1].text(), Some("msg 8"));
        assert_eq!(agent.messages()[2].text(), Some("msg 9"));
    }

    #[test]
    fn test_trim_history_preserves_system() {
        let mut agent = make_agent().with_system_prompt("system prompt");
        for i in 0..10 {
            agent.messages.push(Message::user(format!("msg {i}")));
        }
        assert_eq!(agent.message_count(), 11); // 1 system + 10 user
        agent.trim_history(4);
        // System prompt stays at index 0, followed by the newest 3 messages.
        assert_eq!(agent.message_count(), 4);
        assert_eq!(agent.messages()[0].role, Role::System);
        assert_eq!(agent.messages()[0].text(), Some("system prompt"));
        assert_eq!(agent.messages()[1].text(), Some("msg 7"));
        assert_eq!(agent.messages()[2].text(), Some("msg 8"));
        assert_eq!(agent.messages()[3].text(), Some("msg 9"));
    }

    #[test]
    fn test_trim_history_under_limit_is_noop() {
        let mut agent = make_agent();
        agent.messages.push(Message::user("only one"));
        agent.trim_history(10);
        assert_eq!(agent.message_count(), 1);
    }

    #[test]
    fn test_message_count() {
        let mut agent = make_agent();
        assert_eq!(agent.message_count(), 0);
        agent.messages.push(Message::user("a"));
        assert_eq!(agent.message_count(), 1);
        agent.messages.push(Message::assistant("b"));
        assert_eq!(agent.message_count(), 2);
    }

    #[tokio::test]
    async fn test_process_message_returns_text() {
        let mut agent = make_agent();
        let reply = agent.process_message("Hi").await.unwrap();
        assert_eq!(reply, "Hello from mock!");
        // History gains the user message plus the assistant response.
        assert_eq!(agent.message_count(), 2);
        assert_eq!(agent.messages()[0].role, Role::User);
        assert_eq!(agent.messages()[1].role, Role::Assistant);
    }

    #[tokio::test]
    async fn test_process_message_streaming() {
        let mut agent = make_agent();
        let collected = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
        let sink = collected.clone();

        let reply = agent
            .process_message_streaming("Hi", move |chunk| {
                sink.lock().unwrap().push(chunk.to_string());
            })
            .await
            .unwrap();

        assert_eq!(reply, "Hello from mock!");
        // The mock streams its whole reply as one chunk.
        let seen = collected.lock().unwrap();
        assert_eq!(seen.len(), 1);
        assert_eq!(seen[0], "Hello from mock!");
    }
}