1use crate::backend_trait::LlmBackend;
30use crate::message::Message;
31use crate::observer::{NoOpObserver, Observer, StepContext, ToolResult};
32use crate::store_trait::{MessageStore, ToolLog};
33use crate::tool::Registry;
34use std::io::Write;
35use std::sync::Arc;
36use tracing::{debug, error, info_span, warn};
37
/// Default value for [`Agent::max_tool_result_bytes`]: 64 KiB of raw tool
/// output is kept before the framed tool message is marked as truncated.
pub const DEFAULT_MAX_TOOL_RESULT_BYTES: usize = 64 * 1024;
40
/// A conversational agent that drives an [`LlmBackend`] through a
/// chat / tool-call loop (see [`Agent::step`]).
pub struct Agent<B: LlmBackend> {
    /// Backend that receives the (possibly windowed) history on every round.
    pub backend: B,
    /// Full conversation history. `Agent::new` seeds index 0 with the system
    /// prompt, and windowing always re-sends index 0.
    pub messages: Vec<Message>,
    /// Registry used to advertise tools to the backend and dispatch tool calls.
    pub tools: Registry,
    /// Maximum number of backend round-trips per `step` call (`new` sets 10).
    pub max_steps: usize,
    /// History length above which only a trailing window is sent to the
    /// backend (`new` sets 40).
    pub max_window: usize,
    /// Byte cap applied to each raw tool result before it is framed
    /// (`new` sets [`DEFAULT_MAX_TOOL_RESULT_BYTES`]).
    pub max_tool_result_bytes: usize,
    /// Optional persistence backend; `None` until `attach_store` succeeds.
    pub store: Option<Arc<dyn MessageStore>>,
    /// Session key passed to the store and recorded on tracing spans
    /// (`new` sets `"default"`).
    pub session: String,
    /// Hooks notified at step / tool boundaries (`new` installs `NoOpObserver`).
    pub observer: Arc<dyn Observer>,
    /// User-controlled token sink. When `Some`, it is handed to the backend as
    /// the streaming callback and takes precedence over `stream`.
    pub on_token: Option<Box<dyn FnMut(&str) + Send>>,
    /// Legacy switch: when `on_token` is `None` and this is `true`, streamed
    /// tokens and tool timings are printed to stdout.
    #[deprecated(
        since = "0.2.0",
        note = "Use `Agent::on_token` for a user-controlled token sink. `stream = true` still prints to stdout when `on_token` is None."
    )]
    pub stream: bool,
}
83
84impl<B: LlmBackend> Agent<B> {
85 #[allow(deprecated)]
90 pub fn new(backend: B, system: &str) -> Self {
91 Self {
92 backend,
93 messages: vec![Message {
94 role: "system".into(),
95 content: Some(system.into()),
96 tool_calls: None,
97 tool_call_id: None,
98 name: None,
99 }],
100 tools: Registry::new(),
101 max_steps: 10,
102 max_window: 40,
103 max_tool_result_bytes: DEFAULT_MAX_TOOL_RESULT_BYTES,
104 store: None,
105 session: "default".into(),
106 observer: Arc::new(NoOpObserver),
107 on_token: None,
108 stream: true,
109 }
110 }
111
112 pub fn attach_store(
118 &mut self,
119 store: Arc<dyn MessageStore>,
120 session: &str,
121 ) -> Result<(), String> {
122 let loaded = store.load(session).map_err(|e| e.to_string())?;
123 if loaded.is_empty() {
124 for m in &self.messages {
125 store.append(session, m).map_err(|e| e.to_string())?;
126 }
127 } else {
128 self.messages = loaded;
129 }
130 self.store = Some(store);
131 self.session = session.into();
132 Ok(())
133 }
134
135 fn persist(&self, msg: &Message) {
136 if let Some(s) = &self.store {
137 if let Err(e) = s.append(&self.session, msg) {
138 eprintln!("persist: {}", e);
139 }
140 }
141 }
142
143 fn window_start(&self) -> Option<usize> {
152 if self.messages.len() <= self.max_window {
153 return None;
154 }
155 let n = self.max_window;
156 let mut start = self.messages.len() - (n - 1);
157 while start < self.messages.len() && self.messages[start].role != "user" {
158 start += 1;
159 }
160 Some(start)
161 }
162
163 fn windowed_truncated(&self, start: usize) -> Vec<Message> {
169 let mut out = Vec::with_capacity(self.messages.len() - start + 1);
170 out.push(self.messages[0].clone());
171 out.extend(self.messages[start..].iter().cloned());
172 out
173 }
174
175 #[cfg(test)]
178 fn windowed(&self) -> Vec<Message> {
179 match self.window_start() {
180 None => self.messages.clone(),
181 Some(start) => self.windowed_truncated(start),
182 }
183 }
184
185 fn frame_tool_output(&self, name: &str, id: &str, raw: &str) -> String {
189 let cap = self.max_tool_result_bytes;
190 let (body, truncated) = if raw.len() > cap {
191 let mut end = cap;
193 while end > 0 && !raw.is_char_boundary(end) {
194 end -= 1;
195 }
196 (&raw[..end], true)
197 } else {
198 (raw, false)
199 };
200 if truncated {
201 format!(
202 "<tool_output name=\"{}\" id=\"{}\" truncated=\"true\" raw_bytes=\"{}\">{}</tool_output>",
203 escape_attr(name),
204 escape_attr(id),
205 raw.len(),
206 body
207 )
208 } else {
209 format!(
210 "<tool_output name=\"{}\" id=\"{}\">{}</tool_output>",
211 escape_attr(name),
212 escape_attr(id),
213 body
214 )
215 }
216 }
217
218 #[allow(deprecated)]
227 pub fn step(&mut self, user_input: &str) -> Result<String, String> {
228 let _span = info_span!(
229 "agnt.step",
230 session = %self.session,
231 input_len = user_input.len(),
232 )
233 .entered();
234 debug!(user_input_len = user_input.len(), "agent.step start");
235
236 let ctx = StepContext {
237 session: self.session.clone(),
238 user_input: user_input.into(),
239 };
240 self.observer.on_step_start(&ctx);
241
242 let user = Message {
243 role: "user".into(),
244 content: Some(user_input.into()),
245 tool_calls: None,
246 tool_call_id: None,
247 name: None,
248 };
249 self.persist(&user);
250 self.messages.push(user);
251
252 let tools = self.tools.as_openai_tools();
253
254 for _ in 0..self.max_steps {
255 let window_start = self.window_start();
259 let truncated_buf: Vec<Message> = match window_start {
260 Some(start) => self.windowed_truncated(start),
261 None => Vec::new(),
262 };
263 let send: &[Message] = match window_start {
264 Some(_) => &truncated_buf,
265 None => &self.messages,
266 };
267
268 let use_on_token = self.on_token.is_some();
270 let use_legacy_stream = !use_on_token && self.stream;
271
272 let _backend_span = info_span!(
273 "agnt.backend.chat",
274 model = %self.backend.model(),
275 window_size = send.len(),
276 )
277 .entered();
278
279 let resp = if use_on_token {
280 let mut cb = self.on_token.take().expect("on_token is_some");
283 let mut sink = |s: &str| cb(s);
284 let r = self
285 .backend
286 .chat(send, &tools, Some(&mut sink))
287 .map_err(|e| {
288 let es = e.to_string();
289 error!(error = %es, "backend chat error");
290 self.observer.on_step_error(&es);
291 es
292 });
293 self.on_token = Some(cb);
294 r?
295 } else if use_legacy_stream {
296 let mut sink = |s: &str| {
297 print!("{}", s);
298 std::io::stdout().flush().ok();
299 };
300 let r = self
301 .backend
302 .chat(send, &tools, Some(&mut sink))
303 .map_err(|e| {
304 let es = e.to_string();
305 error!(error = %es, "backend chat error");
306 self.observer.on_step_error(&es);
307 es
308 })?;
309 println!();
310 r
311 } else {
312 self.backend
313 .chat(send, &tools, None)
314 .map_err(|e| {
315 let es = e.to_string();
316 error!(error = %es, "backend chat error");
317 self.observer.on_step_error(&es);
318 es
319 })?
320 };
321 drop(_backend_span);
322
323 self.persist(&resp);
326 let resp_idx = self.messages.len();
327 self.messages.push(resp);
328
329 let has_calls = self.messages[resp_idx]
334 .tool_calls
335 .as_ref()
336 .map(|c| !c.is_empty())
337 .unwrap_or(false);
338
339 if !has_calls {
340 let out = self.messages[resp_idx]
341 .content
342 .clone()
343 .unwrap_or_default();
344 let final_msg = Message {
345 role: "assistant".into(),
346 content: Some(out.clone()),
347 tool_calls: None,
348 tool_call_id: None,
349 name: None,
350 };
351 self.observer.on_step_end(&final_msg);
352 return Ok(out);
353 }
354
355 let calls = self.messages[resp_idx]
357 .tool_calls
358 .as_ref()
359 .expect("has_calls checked above")
360 .clone();
361
362 let registry = &self.tools;
363 let observer = self.observer.clone();
364 let results: Vec<(String, String, String, String, u64)> =
368 std::thread::scope(|s| {
369 let handles: Vec<_> = calls
370 .iter()
371 .map(|call| {
372 let name = call.function.name.clone();
373 let id = call.id.clone();
374 let args_str = call.function.arguments.clone();
375 let observer = observer.clone();
376 let call_clone = call.clone();
377 s.spawn(move || {
378 let _tool_span = info_span!(
379 "agnt.tool",
380 name = %name,
381 id = %id,
382 )
383 .entered();
384 observer.on_tool_start(&call_clone);
385 let args: serde_json::Value =
386 serde_json::from_str(&args_str)
387 .unwrap_or(serde_json::Value::Null);
388 let t0 = std::time::Instant::now();
389 let result = registry
390 .dispatch(&name, args)
391 .unwrap_or_else(|e| {
392 warn!(tool = %name, error = %e, "tool dispatch failed");
393 format!("error: {}", e)
394 });
395 let dur = t0.elapsed().as_micros() as u64;
396 debug!(tool = %name, duration_us = dur, "tool completed");
397 let tool_result = ToolResult {
398 name: name.clone(),
399 output: Ok(result.clone()),
400 duration_us: dur,
401 };
402 observer.on_tool_end(&call_clone, &tool_result);
403 (id, name, args_str, result, dur)
404 })
405 })
406 .collect();
407 handles
408 .into_iter()
409 .map(|h| {
410 h.join().unwrap_or_else(|panic_payload| {
411 let msg = panic_to_string(panic_payload);
412 (
413 String::new(),
414 "<panicked>".to_string(),
415 String::new(),
416 format!("error: tool thread panicked: {}", msg),
417 0,
418 )
419 })
420 })
421 .collect()
422 });
423
424 for (id, name, args_str, result, dur_us) in results {
425 if use_legacy_stream {
426 println!("[tool: {} ({:.2}ms)]", name, dur_us as f64 / 1000.0);
427 }
428 if let Some(s) = &self.store {
429 let log = ToolLog {
430 name: &name,
431 args: &args_str,
432 result: &result,
433 duration_us: dur_us,
434 };
435 if let Err(e) = s.log_tool(&self.session, &log) {
436 eprintln!("log_tool: {}", e);
437 }
438 }
439 let framed = self.frame_tool_output(&name, &id, &result);
441 let msg = Message {
442 role: "tool".into(),
443 content: Some(framed),
444 tool_calls: None,
445 tool_call_id: Some(id),
446 name: Some(name),
447 };
448 self.persist(&msg);
449 self.messages.push(msg);
450 }
451 }
452
453 let err = "max steps exceeded".to_string();
454 self.observer.on_step_error(&err);
455 Err(err)
456 }
457}
458
/// Extracts a human-readable message from a thread panic payload.
///
/// Handles `&'static str` and `String` payloads; any other payload type
/// yields a fixed placeholder.
fn panic_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
    payload
        .downcast_ref::<&'static str>()
        .map(|text| (*text).to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "unknown panic payload".to_string())
}
470
/// Escapes `s` for use inside a double-quoted attribute value.
///
/// `&`, `"`, `<` and `>` are replaced by `&amp;`, `&quot;`, `&lt;` and
/// `&gt;`; every other character is copied through unchanged. This keeps a
/// tool name or call id from closing the attribute or the surrounding tag.
fn escape_attr(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '&' => out.push_str("&amp;"),
            '"' => out.push_str("&quot;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            _ => out.push(c),
        }
    }
    out
}
488
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend_trait::BackendError;
    use crate::message::{FunctionCall, ToolCall};
    use serde_json::Value;

    /// Backend stub that always answers with a fixed assistant message and
    /// never requests a tool call.
    struct MockBackend;

    impl LlmBackend for MockBackend {
        fn model(&self) -> &str {
            "mock"
        }
        fn chat(
            &self,
            _messages: &[Message],
            _tools: &Value,
            _on_token: Option<&mut dyn FnMut(&str)>,
        ) -> Result<Message, BackendError> {
            Ok(Message {
                role: "assistant".into(),
                content: Some("mock response".into()),
                tool_calls: None,
                tool_call_id: None,
                name: None,
            })
        }
    }

    /// Builds a plain text message with the given role.
    fn msg(role: &str, content: &str) -> Message {
        Message {
            role: role.into(),
            content: Some(content.into()),
            tool_calls: None,
            tool_call_id: None,
            name: None,
        }
    }

    #[test]
    fn windowing_empty_session_returns_all() {
        let mut a = Agent::new(MockBackend, "sys");
        a.max_window = 10;
        a.messages.push(msg("user", "hi"));
        a.messages.push(msg("assistant", "hello"));
        let w = a.windowed();
        assert_eq!(w.len(), 3);
        assert_eq!(w[0].role, "system");
    }

    #[test]
    fn windowing_preserves_system_and_starts_at_user() {
        let mut a = Agent::new(MockBackend, "sys");
        a.max_window = 5;
        for i in 0..20 {
            let role = if i % 2 == 0 { "user" } else { "assistant" };
            a.messages.push(msg(role, &format!("m{}", i)));
        }
        let w = a.windowed();
        assert_eq!(w[0].role, "system", "system slot preserved");
        assert!(w.len() <= 5, "window respects max_window: {}", w.len());
        assert_eq!(w[1].role, "user", "first post-system must be user");
    }

    #[test]
    fn windowing_skips_orphan_tool_results() {
        let mut a = Agent::new(MockBackend, "sys");
        a.max_window = 4;
        a.messages.push(msg("user", "do thing"));
        a.messages.push(Message {
            role: "assistant".into(),
            content: None,
            tool_calls: Some(vec![ToolCall {
                id: "c1".into(),
                call_type: "function".into(),
                function: FunctionCall {
                    name: "t".into(),
                    arguments: "{}".into(),
                },
            }]),
            tool_call_id: None,
            name: None,
        });
        a.messages.push(Message {
            role: "tool".into(),
            content: Some("result".into()),
            tool_calls: None,
            tool_call_id: Some("c1".into()),
            name: Some("t".into()),
        });
        a.messages.push(msg("assistant", "done"));
        a.messages.push(msg("user", "next"));
        a.messages.push(msg("assistant", "ok"));
        let w = a.windowed();
        assert_eq!(w[0].role, "system");
        assert_eq!(w[1].role, "user");
    }

    #[test]
    fn window_start_is_none_when_history_fits() {
        let mut a = Agent::new(MockBackend, "sys");
        a.max_window = 10;
        a.messages.push(msg("user", "hi"));
        assert!(
            a.window_start().is_none(),
            "short history must not allocate a window vec"
        );
    }

    #[test]
    fn frame_tool_output_wraps_and_escapes() {
        #[allow(deprecated)]
        let a = Agent::new(MockBackend, "sys");
        let framed = a.frame_tool_output("fetch", "call_1", "hello");
        assert_eq!(
            framed,
            r#"<tool_output name="fetch" id="call_1">hello</tool_output>"#
        );
    }

    #[test]
    fn frame_tool_output_truncates_past_cap() {
        #[allow(deprecated)]
        let mut a = Agent::new(MockBackend, "sys");
        a.max_tool_result_bytes = 8;
        let framed = a.frame_tool_output("t", "id", "0123456789ABCDEF");
        assert!(framed.contains("truncated=\"true\""));
        assert!(framed.contains("raw_bytes=\"16\""));
        assert!(framed.contains("01234567"));
        assert!(!framed.contains("89ABCDEF"));
    }

    #[test]
    fn frame_tool_output_respects_utf8_boundary() {
        #[allow(deprecated)]
        let mut a = Agent::new(MockBackend, "sys");
        // U+00E9 is 2 bytes and U+4E2D is 3 bytes, so a cap of 3 lands inside
        // the second character and the cut must back off to byte 2.
        a.max_tool_result_bytes = 3;
        let framed = a.frame_tool_output("t", "id", "\u{e9}\u{4e2d}");
        assert!(framed.contains("truncated=\"true\""));
        assert_eq!(
            framed,
            "<tool_output name=\"t\" id=\"id\" truncated=\"true\" raw_bytes=\"5\">\u{e9}</tool_output>"
        );
    }

    #[test]
    fn frame_tool_output_escapes_attrs() {
        #[allow(deprecated)]
        let a = Agent::new(MockBackend, "sys");
        let framed = a.frame_tool_output("na\"me", "id&1", "x");
        assert!(framed.contains("name=\"na&quot;me\""));
        assert!(framed.contains("id=\"id&amp;1\""));
    }
}