use crate::backend_trait::LlmBackend;
use crate::message::Message;
use crate::observer::{NoOpObserver, Observer, StepContext, ToolResult};
use crate::store_trait::{MessageStore, ToolLog};
use crate::tool::Registry;
use std::io::Write;
use std::sync::Arc;
use tracing::{debug, error, info_span, warn};
/// Default byte cap (64 KiB) applied to a single tool result before it is
/// framed into the conversation; see `Agent::frame_tool_output`.
pub const DEFAULT_MAX_TOOL_RESULT_BYTES: usize = 64 * 1024;
/// An LLM-driven agent: owns the conversation history and tool registry,
/// plus optional hooks for persistence, observation, and token streaming.
pub struct Agent<B: LlmBackend> {
    /// Backend used for chat completions.
    pub backend: B,
    /// Full message history; index 0 is the system message (see `new`).
    pub messages: Vec<Message>,
    /// Tools the model may call during `step`.
    pub tools: Registry,
    /// Maximum backend rounds per `step` before erroring out.
    pub max_steps: usize,
    /// Maximum number of messages sent to the backend per round; when the
    /// history is longer, a window (system message + recent tail) is sent.
    pub max_window: usize,
    /// Byte cap applied to each tool result before framing.
    pub max_tool_result_bytes: usize,
    /// Optional persistent store for messages and tool logs.
    pub store: Option<Arc<dyn MessageStore>>,
    /// Session key used when reading/writing `store`.
    pub session: String,
    /// Observer notified of step and tool lifecycle events.
    pub observer: Arc<dyn Observer>,
    /// User-supplied token sink; when set it takes precedence over `stream`.
    pub on_token: Option<Box<dyn FnMut(&str) + Send>>,
    #[deprecated(
        since = "0.2.0",
        note = "Use `Agent::on_token` for a user-controlled token sink. `stream = true` still prints to stdout when `on_token` is None."
    )]
    pub stream: bool,
}
impl<B: LlmBackend> Agent<B> {
/// Create an agent over `backend`, seeding the history with `system` as
/// the system prompt (always at index 0). Defaults: 10 steps, a window
/// of 40 messages, 64 KiB tool-result cap, no store, no-op observer,
/// legacy stdout streaming enabled.
#[allow(deprecated)]
pub fn new(backend: B, system: &str) -> Self {
    // The system prompt is the very first message and is preserved by
    // the windowing logic.
    let system_message = Message {
        role: "system".into(),
        content: Some(system.into()),
        tool_calls: None,
        tool_call_id: None,
        name: None,
    };
    Self {
        backend,
        messages: vec![system_message],
        tools: Registry::new(),
        max_steps: 10,
        max_window: 40,
        max_tool_result_bytes: DEFAULT_MAX_TOOL_RESULT_BYTES,
        store: None,
        session: "default".into(),
        observer: Arc::new(NoOpObserver),
        on_token: None,
        stream: true,
    }
}
/// Bind a persistent message store to this agent under `session`.
///
/// If the session already has history in the store, it replaces the
/// in-memory messages; otherwise the current in-memory messages (the
/// system prompt) are appended to seed the session.
pub fn attach_store(
    &mut self,
    store: Arc<dyn MessageStore>,
    session: &str,
) -> Result<(), String> {
    let existing = store.load(session).map_err(|e| e.to_string())?;
    if existing.is_empty() {
        // Fresh session: seed the store with what we hold in memory.
        self.messages.iter().try_for_each(|m| {
            store
                .append(session, m)
                .map(|_| ())
                .map_err(|e| e.to_string())
        })?;
    } else {
        // Resume: adopt the persisted history wholesale.
        self.messages = existing;
    }
    self.session = session.into();
    self.store = Some(store);
    Ok(())
}
/// Best-effort append of `msg` to the attached store (if any).
/// Persistence failures are reported on stderr and never abort a step.
fn persist(&self, msg: &Message) {
    match &self.store {
        None => {}
        Some(store) => {
            if let Err(e) = store.append(&self.session, msg) {
                eprintln!("persist: {}", e);
            }
        }
    }
}
/// Compute the index at which the sliding context window starts, or
/// `None` when the whole history already fits within `max_window`.
///
/// One slot is always reserved for the system message at index 0, and
/// the start is advanced until it lands on a "user" message so the
/// window never opens on an orphaned assistant/tool message. The
/// returned index may equal `messages.len()` (empty tail) when no user
/// message remains in range.
fn window_start(&self) -> Option<usize> {
    if self.messages.len() <= self.max_window {
        return None;
    }
    // Reserve one slot for the system message. `saturating_sub` guards
    // the `max_window == 0` case, where `max_window - 1` would
    // underflow and panic in debug builds.
    let tail = self.max_window.saturating_sub(1);
    let mut start = self.messages.len() - tail;
    while start < self.messages.len() && self.messages[start].role != "user" {
        start += 1;
    }
    Some(start)
}
/// Materialize the windowed history: the system message (index 0)
/// followed by every message from `start` onward.
fn windowed_truncated(&self, start: usize) -> Vec<Message> {
    let mut window = vec![self.messages[0].clone()];
    window.reserve(self.messages.len() - start);
    window.extend_from_slice(&self.messages[start..]);
    window
}
/// Test-only helper: the exact message slice `step` would send, as an
/// owned vector.
#[cfg(test)]
fn windowed(&self) -> Vec<Message> {
    if let Some(start) = self.window_start() {
        self.windowed_truncated(start)
    } else {
        self.messages.clone()
    }
}
/// Wrap a raw tool result in a `<tool_output>` envelope.
///
/// The body is truncated to `max_tool_result_bytes`, backing off to a
/// UTF-8 character boundary; when truncation occurs the envelope also
/// carries `truncated="true"` and the original byte count. Attribute
/// values are escaped via `escape_attr`; the body is passed verbatim.
fn frame_tool_output(&self, name: &str, id: &str, raw: &str) -> String {
    let cap = self.max_tool_result_bytes;
    if raw.len() <= cap {
        // Fits under the cap: plain envelope, body untouched.
        return format!(
            "<tool_output name=\"{}\" id=\"{}\">{}</tool_output>",
            escape_attr(name),
            escape_attr(id),
            raw
        );
    }
    // Walk the cut point back until it lands on a char boundary so the
    // truncated body remains valid UTF-8.
    let mut cut = cap;
    while cut > 0 && !raw.is_char_boundary(cut) {
        cut -= 1;
    }
    format!(
        "<tool_output name=\"{}\" id=\"{}\" truncated=\"true\" raw_bytes=\"{}\">{}</tool_output>",
        escape_attr(name),
        escape_attr(id),
        raw.len(),
        &raw[..cut]
    )
}
/// Run one user turn of the agent loop and return the assistant's final
/// text reply.
///
/// Flow: the input is recorded (persisted, pushed, reported to the
/// observer), then up to `max_steps` backend rounds run. Each round
/// sends the (possibly windowed) history; if the reply carries tool
/// calls, all of them are dispatched in parallel on scoped threads and
/// their framed results are appended as "tool" messages before the next
/// round. A reply without tool calls ends the turn.
///
/// Token streaming: a user-set `on_token` sink takes precedence;
/// otherwise the deprecated `stream` flag prints tokens to stdout.
///
/// # Errors
/// Returns the backend's error string, or "max steps exceeded" when no
/// tool-call-free reply arrives within `max_steps` rounds; the
/// observer's `on_step_error` is notified in either case.
#[allow(deprecated)]
pub fn step(&mut self, user_input: &str) -> Result<String, String> {
    let _span = info_span!(
        "agnt.step",
        session = %self.session,
        input_len = user_input.len(),
    )
    .entered();
    debug!(user_input_len = user_input.len(), "agent.step start");
    let ctx = StepContext {
        session: self.session.clone(),
        user_input: user_input.into(),
    };
    self.observer.on_step_start(&ctx);
    let user = Message {
        role: "user".into(),
        content: Some(user_input.into()),
        tool_calls: None,
        tool_call_id: None,
        name: None,
    };
    // Persist before pushing so the store never lags the in-memory history.
    self.persist(&user);
    self.messages.push(user);
    let tools = self.tools.as_openai_tools();
    for _ in 0..self.max_steps {
        // Window only when the history exceeds max_window; the common
        // (short-history) path borrows self.messages without copying.
        let window_start = self.window_start();
        let truncated_buf: Vec<Message> = match window_start {
            Some(start) => self.windowed_truncated(start),
            None => Vec::new(),
        };
        let send: &[Message] = match window_start {
            Some(_) => &truncated_buf,
            None => &self.messages,
        };
        // Sink selection: `on_token` wins; the deprecated `stream` flag
        // only echoes to stdout when no callback is installed.
        let use_on_token = self.on_token.is_some();
        let use_legacy_stream = !use_on_token && self.stream;
        let _backend_span = info_span!(
            "agnt.backend.chat",
            model = %self.backend.model(),
            window_size = send.len(),
        )
        .entered();
        let resp = if use_on_token {
            // Move the callback out of `self` for the duration of the
            // call so the mutable borrow of the callback does not
            // conflict with the borrows of `self` taken by the backend
            // call; it is restored before `?` can return, so the sink
            // survives backend errors.
            let mut cb = self.on_token.take().expect("on_token is_some");
            let mut sink = |s: &str| cb(s);
            let r = self
                .backend
                .chat(send, &tools, Some(&mut sink))
                .map_err(|e| {
                    let es = e.to_string();
                    error!(error = %es, "backend chat error");
                    self.observer.on_step_error(&es);
                    es
                });
            self.on_token = Some(cb);
            r?
        } else if use_legacy_stream {
            // Legacy path: print tokens as they arrive, flushing so
            // partial lines show up immediately.
            let mut sink = |s: &str| {
                print!("{}", s);
                std::io::stdout().flush().ok();
            };
            let r = self
                .backend
                .chat(send, &tools, Some(&mut sink))
                .map_err(|e| {
                    let es = e.to_string();
                    error!(error = %es, "backend chat error");
                    self.observer.on_step_error(&es);
                    es
                })?;
            println!();
            r
        } else {
            // No streaming: plain request/response.
            self.backend
                .chat(send, &tools, None)
                .map_err(|e| {
                    let es = e.to_string();
                    error!(error = %es, "backend chat error");
                    self.observer.on_step_error(&es);
                    es
                })?
        };
        // Close the backend span before tool dispatch begins so tool
        // spans are not recorded under it.
        drop(_backend_span);
        self.persist(&resp);
        let resp_idx = self.messages.len();
        self.messages.push(resp);
        let has_calls = self.messages[resp_idx]
            .tool_calls
            .as_ref()
            .map(|c| !c.is_empty())
            .unwrap_or(false);
        if !has_calls {
            // No tool calls: this is the final assistant reply.
            let out = self.messages[resp_idx]
                .content
                .clone()
                .unwrap_or_default();
            // A fresh content-only assistant message is handed to the
            // observer as the step's outcome.
            let final_msg = Message {
                role: "assistant".into(),
                content: Some(out.clone()),
                tool_calls: None,
                tool_call_id: None,
                name: None,
            };
            self.observer.on_step_end(&final_msg);
            return Ok(out);
        }
        let calls = self.messages[resp_idx]
            .tool_calls
            .as_ref()
            .expect("has_calls checked above")
            .clone();
        let registry = &self.tools;
        let observer = self.observer.clone();
        // Dispatch every tool call of this round in parallel; scoped
        // threads let the workers borrow `registry` without 'static.
        // Each worker yields (id, name, args, result, duration_us).
        let results: Vec<(String, String, String, String, u64)> =
            std::thread::scope(|s| {
                let handles: Vec<_> = calls
                    .iter()
                    .map(|call| {
                        let name = call.function.name.clone();
                        let id = call.id.clone();
                        let args_str = call.function.arguments.clone();
                        let observer = observer.clone();
                        let call_clone = call.clone();
                        s.spawn(move || {
                            let _tool_span = info_span!(
                                "agnt.tool",
                                name = %name,
                                id = %id,
                            )
                            .entered();
                            observer.on_tool_start(&call_clone);
                            // Unparseable arguments degrade to Null
                            // rather than failing the round.
                            let args: serde_json::Value =
                                serde_json::from_str(&args_str)
                                    .unwrap_or(serde_json::Value::Null);
                            let t0 = std::time::Instant::now();
                            // Dispatch errors are folded into the result
                            // text so the model can react to them.
                            let result = registry
                                .dispatch(&name, args)
                                .unwrap_or_else(|e| {
                                    warn!(tool = %name, error = %e, "tool dispatch failed");
                                    format!("error: {}", e)
                                });
                            let dur = t0.elapsed().as_micros() as u64;
                            debug!(tool = %name, duration_us = dur, "tool completed");
                            let tool_result = ToolResult {
                                name: name.clone(),
                                output: Ok(result.clone()),
                                duration_us: dur,
                            };
                            observer.on_tool_end(&call_clone, &tool_result);
                            (id, name, args_str, result, dur)
                        })
                    })
                    .collect();
                // A panicking tool thread becomes an error result
                // (with a placeholder name) instead of propagating.
                handles
                    .into_iter()
                    .map(|h| {
                        h.join().unwrap_or_else(|panic_payload| {
                            let msg = panic_to_string(panic_payload);
                            (
                                String::new(),
                                "<panicked>".to_string(),
                                String::new(),
                                format!("error: tool thread panicked: {}", msg),
                                0,
                            )
                        })
                    })
                    .collect()
            });
        for (id, name, args_str, result, dur_us) in results {
            if use_legacy_stream {
                println!("[tool: {} ({:.2}ms)]", name, dur_us as f64 / 1000.0);
            }
            // Best-effort tool logging; failures only go to stderr.
            if let Some(s) = &self.store {
                let log = ToolLog {
                    name: &name,
                    args: &args_str,
                    result: &result,
                    duration_us: dur_us,
                };
                if let Err(e) = s.log_tool(&self.session, &log) {
                    eprintln!("log_tool: {}", e);
                }
            }
            // Frame (and possibly truncate) the raw output before it
            // re-enters the conversation as a "tool" message.
            let framed = self.frame_tool_output(&name, &id, &result);
            let msg = Message {
                role: "tool".into(),
                content: Some(framed),
                tool_calls: None,
                tool_call_id: Some(id),
                name: Some(name),
            };
            self.persist(&msg);
            self.messages.push(msg);
        }
    }
    let err = "max steps exceeded".to_string();
    self.observer.on_step_error(&err);
    Err(err)
}
}
/// Convert a panic payload into a readable message.
///
/// Panics carry `&'static str` (from `panic!("literal")`) or `String`
/// (from `panic!("{}", x)`); anything else yields a generic placeholder.
fn panic_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
    payload
        .downcast_ref::<&'static str>()
        .map(|s| (*s).to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "unknown panic payload".to_string())
}
/// Escape a string for use inside a double-quoted XML attribute value,
/// replacing the XML-significant characters `&`, `"`, `<`, `>` with
/// their entity references.
///
/// Note: the replacement strings here must be the entity forms; the
/// previous text had them collapsed to the literal characters (making
/// the function a no-op at best and a syntax error at worst).
fn escape_attr(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '&' => out.push_str("&amp;"),
            '"' => out.push_str("&quot;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            _ => out.push(c),
        }
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend_trait::BackendError;
    use crate::message::{FunctionCall, ToolCall};
    use serde_json::Value;

    /// Minimal backend stub: always answers with a fixed assistant message.
    struct MockBackend;

    impl LlmBackend for MockBackend {
        fn model(&self) -> &str {
            "mock"
        }
        fn chat(
            &self,
            _messages: &[Message],
            _tools: &Value,
            _on_token: Option<&mut dyn FnMut(&str)>,
        ) -> Result<Message, BackendError> {
            Ok(Message {
                role: "assistant".into(),
                content: Some("mock response".into()),
                tool_calls: None,
                tool_call_id: None,
                name: None,
            })
        }
    }

    /// Shorthand for a plain text message with the given role.
    fn msg(role: &str, content: &str) -> Message {
        Message {
            role: role.into(),
            content: Some(content.into()),
            tool_calls: None,
            tool_call_id: None,
            name: None,
        }
    }

    #[test]
    fn windowing_empty_session_returns_all() {
        let mut a = Agent::new(MockBackend, "sys");
        a.max_window = 10;
        a.messages.push(msg("user", "hi"));
        a.messages.push(msg("assistant", "hello"));
        let w = a.windowed();
        assert_eq!(w.len(), 3);
        assert_eq!(w[0].role, "system");
    }

    #[test]
    fn windowing_preserves_system_and_starts_at_user() {
        let mut a = Agent::new(MockBackend, "sys");
        a.max_window = 5;
        for i in 0..20 {
            let role = if i % 2 == 0 { "user" } else { "assistant" };
            a.messages.push(msg(role, &format!("m{}", i)));
        }
        let w = a.windowed();
        assert_eq!(w[0].role, "system", "system slot preserved");
        assert!(w.len() <= 5, "window respects max_window: {}", w.len());
        assert_eq!(w[1].role, "user", "first post-system must be user");
    }

    #[test]
    fn windowing_skips_orphan_tool_results() {
        let mut a = Agent::new(MockBackend, "sys");
        a.max_window = 4;
        a.messages.push(msg("user", "do thing"));
        // Assistant message carrying a tool call, then its tool result.
        a.messages.push(Message {
            role: "assistant".into(),
            content: None,
            tool_calls: Some(vec![ToolCall {
                id: "c1".into(),
                call_type: "function".into(),
                function: FunctionCall {
                    name: "t".into(),
                    arguments: "{}".into(),
                },
            }]),
            tool_call_id: None,
            name: None,
        });
        a.messages.push(Message {
            role: "tool".into(),
            content: Some("result".into()),
            tool_calls: None,
            tool_call_id: Some("c1".into()),
            name: Some("t".into()),
        });
        a.messages.push(msg("assistant", "done"));
        a.messages.push(msg("user", "next"));
        a.messages.push(msg("assistant", "ok"));
        let w = a.windowed();
        assert_eq!(w[0].role, "system");
        // The window must not open on the orphaned tool/assistant tail.
        assert_eq!(w[1].role, "user");
    }

    #[test]
    fn window_start_is_none_when_history_fits() {
        let mut a = Agent::new(MockBackend, "sys");
        a.max_window = 10;
        a.messages.push(msg("user", "hi"));
        assert!(
            a.window_start().is_none(),
            "short history must not allocate a window vec"
        );
    }

    #[test]
    fn frame_tool_output_wraps_and_escapes() {
        #[allow(deprecated)]
        let a = Agent::new(MockBackend, "sys");
        let framed = a.frame_tool_output("fetch", "call_1", "hello");
        assert_eq!(
            framed,
            r#"<tool_output name="fetch" id="call_1">hello</tool_output>"#
        );
    }

    #[test]
    fn frame_tool_output_truncates_past_cap() {
        #[allow(deprecated)]
        let mut a = Agent::new(MockBackend, "sys");
        a.max_tool_result_bytes = 8;
        let framed = a.frame_tool_output("t", "id", "0123456789ABCDEF");
        assert!(framed.contains("truncated=\"true\""));
        assert!(framed.contains("raw_bytes=\"16\""));
        assert!(framed.contains("01234567"));
        assert!(!framed.contains("89ABCDEF"));
    }

    #[test]
    fn frame_tool_output_respects_utf8_boundary() {
        #[allow(deprecated)]
        let mut a = Agent::new(MockBackend, "sys");
        // A 3-byte cap lands inside the second character ("é" is 2 bytes,
        // "中" is 3), so the cut must back off to a char boundary.
        a.max_tool_result_bytes = 3;
        let framed = a.frame_tool_output("t", "id", "é中");
        assert!(framed.contains("truncated=\"true\""));
    }

    #[test]
    fn frame_tool_output_escapes_attrs() {
        #[allow(deprecated)]
        let a = Agent::new(MockBackend, "sys");
        let framed = a.frame_tool_output("na\"me", "id&1", "x");
        // Attribute values are entity-escaped by `escape_attr`.
        assert!(framed.contains("name=\"na&quot;me\""));
        assert!(framed.contains("id=\"id&amp;1\""));
    }
}