use crate::backend_trait::LlmBackend;
use crate::message::Message;
use crate::observer::{Disposition, NoOpObserver, Observer, StepContext, ToolResult};
use crate::store_trait::{MessageStore, ToolLog};
use crate::tool::Registry;
use std::collections::HashMap;
use std::io::Write;
use std::sync::Arc;
use tracing::{debug, error, info_span, warn};
pub const DEFAULT_MAX_TOOL_RESULT_BYTES: usize = 64 * 1024;
#[derive(Debug, Clone, Default)]
pub struct ToolQuota {
pub max_calls: Option<u32>,
pub max_duration_us: Option<u64>,
pub max_result_bytes: Option<usize>,
}
#[derive(Default)]
struct QuotaUsage {
calls: u32,
duration_us: u64,
}
pub struct Agent<B: LlmBackend> {
pub backend: B,
pub messages: Vec<Message>,
pub tools: Registry,
pub max_steps: usize,
pub max_window: usize,
pub max_tool_result_bytes: usize,
pub tool_quotas: HashMap<String, ToolQuota>,
pub max_step_duration: Option<std::time::Duration>,
pub store: Option<Arc<dyn MessageStore>>,
pub session: String,
pub observer: Arc<dyn Observer>,
pub on_token: Option<Box<dyn FnMut(&str) + Send>>,
#[deprecated(
since = "0.2.0",
note = "Use `Agent::on_token` for a user-controlled token sink. `stream = true` still prints to stdout when `on_token` is None."
)]
pub stream: bool,
}
impl<B: LlmBackend> Agent<B> {
#[allow(deprecated)]
pub fn new(backend: B, system: &str) -> Self {
Self {
backend,
messages: vec![Message {
role: "system".into(),
content: Some(system.into()),
tool_calls: None,
tool_call_id: None,
name: None,
}],
tools: Registry::new(),
max_steps: 10,
max_window: 40,
max_tool_result_bytes: DEFAULT_MAX_TOOL_RESULT_BYTES,
tool_quotas: HashMap::new(),
max_step_duration: None,
store: None,
session: "default".into(),
observer: Arc::new(NoOpObserver),
on_token: None,
stream: true,
}
}
pub fn attach_store(
&mut self,
store: Arc<dyn MessageStore>,
session: &str,
) -> Result<(), String> {
let loaded = store.load(session).map_err(|e| e.to_string())?;
if loaded.is_empty() {
for m in &self.messages {
store.append(session, m).map_err(|e| e.to_string())?;
}
} else {
self.messages = loaded;
}
self.store = Some(store);
self.session = session.into();
Ok(())
}
fn persist(&self, msg: &Message) {
if let Some(s) = &self.store {
if let Err(e) = s.append(&self.session, msg) {
eprintln!("persist: {}", e);
}
}
}
fn window_start(&self) -> Option<usize> {
if self.messages.len() <= self.max_window {
return None;
}
let n = self.max_window;
let mut start = self.messages.len() - (n - 1);
while start < self.messages.len() && self.messages[start].role != "user" {
start += 1;
}
Some(start)
}
fn windowed_truncated(&self, start: usize) -> Vec<Message> {
let mut out = Vec::with_capacity(self.messages.len() - start + 1);
out.push(self.messages[0].clone());
out.extend(self.messages[start..].iter().cloned());
out
}
#[cfg(test)]
fn windowed(&self) -> Vec<Message> {
match self.window_start() {
None => self.messages.clone(),
Some(start) => self.windowed_truncated(start),
}
}
fn frame_tool_output(&self, name: &str, id: &str, raw: &str) -> String {
let cap = self.max_tool_result_bytes;
let (body, truncated) = if raw.len() > cap {
let mut end = cap;
while end > 0 && !raw.is_char_boundary(end) {
end -= 1;
}
(&raw[..end], true)
} else {
(raw, false)
};
if truncated {
format!(
"<tool_output name=\"{}\" id=\"{}\" truncated=\"true\" raw_bytes=\"{}\">{}</tool_output>",
escape_attr(name),
escape_attr(id),
raw.len(),
body
)
} else {
format!(
"<tool_output name=\"{}\" id=\"{}\">{}</tool_output>",
escape_attr(name),
escape_attr(id),
body
)
}
}
#[allow(deprecated)]
pub fn step(&mut self, user_input: &str) -> Result<String, String> {
let _span = info_span!(
"agnt.step",
session = %self.session,
input_len = user_input.len(),
)
.entered();
debug!(user_input_len = user_input.len(), "agent.step start");
let ctx = StepContext {
session: self.session.clone(),
user_input: user_input.into(),
};
self.observer.on_step_start(&ctx);
let user = Message {
role: "user".into(),
content: Some(user_input.into()),
tool_calls: None,
tool_call_id: None,
name: None,
};
self.persist(&user);
self.messages.push(user);
let tools = self.tools.as_openai_tools();
let mut quota_usage: HashMap<String, QuotaUsage> = HashMap::new();
let step_started = std::time::Instant::now();
let deadline_check = |stage: &str| -> Result<(), String> {
if let Some(limit) = self.max_step_duration {
if step_started.elapsed() >= limit {
return Err(format!(
"step deadline exceeded at {}: {}ms >= {}ms",
stage,
step_started.elapsed().as_millis(),
limit.as_millis()
));
}
}
Ok(())
};
for _ in 0..self.max_steps {
if let Err(e) = deadline_check("turn_start") {
self.observer.on_step_error(&e);
return Err(e);
}
let window_start = self.window_start();
let truncated_buf: Vec<Message> = match window_start {
Some(start) => self.windowed_truncated(start),
None => Vec::new(),
};
let send: &[Message] = match window_start {
Some(_) => &truncated_buf,
None => &self.messages,
};
let use_on_token = self.on_token.is_some();
let use_legacy_stream = !use_on_token && self.stream;
let _backend_span = info_span!(
"agnt.backend.chat",
model = %self.backend.model(),
window_size = send.len(),
)
.entered();
let resp = if use_on_token {
let mut cb = self.on_token.take().expect("on_token is_some");
let mut sink = |s: &str| cb(s);
let r = self
.backend
.chat(send, &tools, Some(&mut sink))
.map_err(|e| {
let es = e.to_string();
error!(error = %es, "backend chat error");
self.observer.on_step_error(&es);
es
});
self.on_token = Some(cb);
r?
} else if use_legacy_stream {
let mut sink = |s: &str| {
print!("{}", s);
std::io::stdout().flush().ok();
};
let r = self
.backend
.chat(send, &tools, Some(&mut sink))
.map_err(|e| {
let es = e.to_string();
error!(error = %es, "backend chat error");
self.observer.on_step_error(&es);
es
})?;
println!();
r
} else {
self.backend
.chat(send, &tools, None)
.map_err(|e| {
let es = e.to_string();
error!(error = %es, "backend chat error");
self.observer.on_step_error(&es);
es
})?
};
drop(_backend_span);
self.persist(&resp);
let resp_idx = self.messages.len();
self.messages.push(resp);
let has_calls = self.messages[resp_idx]
.tool_calls
.as_ref()
.map(|c| !c.is_empty())
.unwrap_or(false);
if !has_calls {
let out = self.messages[resp_idx]
.content
.clone()
.unwrap_or_default();
let final_msg = Message {
role: "assistant".into(),
content: Some(out.clone()),
tool_calls: None,
tool_call_id: None,
name: None,
};
self.observer.on_step_end(&final_msg);
return Ok(out);
}
let calls = self.messages[resp_idx]
.tool_calls
.as_ref()
.expect("has_calls checked above")
.clone();
if let Err(e) = deadline_check("pre_dispatch") {
self.observer.on_step_error(&e);
return Err(e);
}
enum CallDecision {
Allow,
Refused(String),
}
let observer_clone = self.observer.clone();
let decisions: Vec<CallDecision> = calls
.iter()
.map(|call| {
if let Disposition::Refused(msg) = observer_clone.should_dispatch(call) {
warn!(tool = %call.function.name, reason = %msg, "observer refused dispatch");
return CallDecision::Refused(format!(
"refused by observer: {}",
msg
));
}
if let Some(quota) = self.tool_quotas.get(&call.function.name) {
let usage = quota_usage
.entry(call.function.name.clone())
.or_default();
if let Some(max) = quota.max_calls {
if usage.calls >= max {
warn!(
tool = %call.function.name,
max = max,
"tool call quota exceeded"
);
return CallDecision::Refused(format!(
"quota exceeded: {} reached max {} calls this step",
call.function.name, max
));
}
}
if let Some(max_us) = quota.max_duration_us {
if usage.duration_us >= max_us {
warn!(
tool = %call.function.name,
max_us = max_us,
"tool duration quota exceeded"
);
return CallDecision::Refused(format!(
"quota exceeded: {} reached max {}µs wall time this step",
call.function.name, max_us
));
}
}
usage.calls += 1;
}
CallDecision::Allow
})
.collect();
let registry = &self.tools;
let observer = self.observer.clone();
type ToolOutcome = (String, String, String, String, u64);
let results: Vec<ToolOutcome> =
std::thread::scope(|s| {
type Handle<'s> = (
std::thread::ScopedJoinHandle<'s, ToolOutcome>,
String,
String,
String,
);
let handles: Vec<Handle<'_>> = calls
.iter()
.zip(decisions.into_iter())
.map(|(call, decision)| {
let name = call.function.name.clone();
let id = call.id.clone();
let args_str = call.function.arguments.clone();
let sidecar_id = id.clone();
let sidecar_name = name.clone();
let sidecar_args = args_str.clone();
let observer = observer.clone();
let call_clone = call.clone();
let handle = s.spawn(move || {
let _tool_span = info_span!(
"agnt.tool",
name = %name,
id = %id,
)
.entered();
observer.on_tool_start(&call_clone);
let (result, dur) = match decision {
CallDecision::Refused(msg) => (msg, 0u64),
CallDecision::Allow => {
let args: serde_json::Value =
serde_json::from_str(&args_str)
.unwrap_or(serde_json::Value::Null);
let t0 = std::time::Instant::now();
let result = registry
.dispatch(&name, args)
.unwrap_or_else(|e| {
warn!(tool = %name, error = %e, "tool dispatch failed");
format!("error: {}", e)
});
let dur = t0.elapsed().as_micros() as u64;
debug!(
tool = %name,
duration_us = dur,
"tool completed"
);
(result, dur)
}
};
let tool_result = ToolResult {
name: name.clone(),
output: Ok(result.clone()),
duration_us: dur,
};
observer.on_tool_end(&call_clone, &tool_result);
(id, name, args_str, result, dur)
});
(handle, sidecar_id, sidecar_name, sidecar_args)
})
.collect();
handles
.into_iter()
.map(|(h, id, name, args_str)| {
h.join().unwrap_or_else(|panic_payload| {
let msg = panic_to_string(panic_payload);
warn!(
tool = %name,
id = %id,
panic = %msg,
"tool thread panicked"
);
(
id,
name,
args_str,
format!("error: tool thread panicked: {}", msg),
0,
)
})
})
.collect()
});
for (_id, name, _args, _result, dur) in &results {
if self.tool_quotas.contains_key(name) {
let u = quota_usage.entry(name.clone()).or_default();
u.duration_us = u.duration_us.saturating_add(*dur);
}
}
for (id, name, args_str, result, dur_us) in results {
if use_legacy_stream {
println!("[tool: {} ({:.2}ms)]", name, dur_us as f64 / 1000.0);
}
if let Some(s) = &self.store {
let log = ToolLog {
name: &name,
args: &args_str,
result: &result,
duration_us: dur_us,
};
if let Err(e) = s.log_tool(&self.session, &log) {
eprintln!("log_tool: {}", e);
}
}
let result = match self
.tool_quotas
.get(&name)
.and_then(|q| q.max_result_bytes)
{
Some(cap) if result.len() > cap => {
let mut end = cap;
while end > 0 && !result.is_char_boundary(end) {
end -= 1;
}
result[..end].to_string()
}
_ => result,
};
let framed = self.frame_tool_output(&name, &id, &result);
let msg = Message {
role: "tool".into(),
content: Some(framed),
tool_calls: None,
tool_call_id: Some(id),
name: Some(name),
};
self.persist(&msg);
self.messages.push(msg);
}
}
let err = "max steps exceeded".to_string();
self.observer.on_step_error(&err);
Err(err)
}
}
fn panic_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
if let Some(s) = payload.downcast_ref::<&'static str>() {
(*s).to_string()
} else if let Some(s) = payload.downcast_ref::<String>() {
s.clone()
} else {
"unknown panic payload".to_string()
}
}
fn escape_attr(s: &str) -> String {
let mut out = String::with_capacity(s.len());
for c in s.chars() {
match c {
'&' => out.push_str("&"),
'"' => out.push_str("""),
'<' => out.push_str("<"),
'>' => out.push_str(">"),
_ => out.push(c),
}
}
out
}
#[cfg(test)]
mod tests {
use super::*;
use crate::backend_trait::BackendError;
use crate::message::{FunctionCall, ToolCall};
use serde_json::Value;
struct MockBackend;
impl LlmBackend for MockBackend {
fn model(&self) -> &str {
"mock"
}
fn chat(
&self,
_messages: &[Message],
_tools: &Value,
_on_token: Option<&mut dyn FnMut(&str)>,
) -> Result<Message, BackendError> {
Ok(Message {
role: "assistant".into(),
content: Some("mock response".into()),
tool_calls: None,
tool_call_id: None,
name: None,
})
}
}
fn msg(role: &str, content: &str) -> Message {
Message {
role: role.into(),
content: Some(content.into()),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
#[test]
fn windowing_empty_session_returns_all() {
let mut a = Agent::new(MockBackend, "sys");
a.max_window = 10;
a.messages.push(msg("user", "hi"));
a.messages.push(msg("assistant", "hello"));
let w = a.windowed();
assert_eq!(w.len(), 3);
assert_eq!(w[0].role, "system");
}
#[test]
fn windowing_preserves_system_and_starts_at_user() {
let mut a = Agent::new(MockBackend, "sys");
a.max_window = 5;
for i in 0..20 {
let role = if i % 2 == 0 { "user" } else { "assistant" };
a.messages.push(msg(role, &format!("m{}", i)));
}
let w = a.windowed();
assert_eq!(w[0].role, "system", "system slot preserved");
assert!(w.len() <= 5, "window respects max_window: {}", w.len());
assert_eq!(w[1].role, "user", "first post-system must be user");
}
#[test]
fn windowing_skips_orphan_tool_results() {
let mut a = Agent::new(MockBackend, "sys");
a.max_window = 4;
a.messages.push(msg("user", "do thing"));
a.messages.push(Message {
role: "assistant".into(),
content: None,
tool_calls: Some(vec![ToolCall {
id: "c1".into(),
call_type: "function".into(),
function: FunctionCall {
name: "t".into(),
arguments: "{}".into(),
},
}]),
tool_call_id: None,
name: None,
});
a.messages.push(Message {
role: "tool".into(),
content: Some("result".into()),
tool_calls: None,
tool_call_id: Some("c1".into()),
name: Some("t".into()),
});
a.messages.push(msg("assistant", "done"));
a.messages.push(msg("user", "next"));
a.messages.push(msg("assistant", "ok"));
let w = a.windowed();
assert_eq!(w[0].role, "system");
assert_eq!(w[1].role, "user");
}
#[test]
fn window_start_is_none_when_history_fits() {
let mut a = Agent::new(MockBackend, "sys");
a.max_window = 10;
a.messages.push(msg("user", "hi"));
assert!(
a.window_start().is_none(),
"short history must not allocate a window vec"
);
}
#[test]
fn frame_tool_output_wraps_and_escapes() {
#[allow(deprecated)]
let a = Agent::new(MockBackend, "sys");
let framed = a.frame_tool_output("fetch", "call_1", "hello");
assert_eq!(
framed,
r#"<tool_output name="fetch" id="call_1">hello</tool_output>"#
);
}
#[test]
fn frame_tool_output_truncates_past_cap() {
#[allow(deprecated)]
let mut a = Agent::new(MockBackend, "sys");
a.max_tool_result_bytes = 8;
let framed = a.frame_tool_output("t", "id", "0123456789ABCDEF");
assert!(framed.contains("truncated=\"true\""));
assert!(framed.contains("raw_bytes=\"16\""));
assert!(framed.contains("01234567"));
assert!(!framed.contains("89ABCDEF"));
}
#[test]
fn frame_tool_output_respects_utf8_boundary() {
#[allow(deprecated)]
let mut a = Agent::new(MockBackend, "sys");
a.max_tool_result_bytes = 3; let framed = a.frame_tool_output("t", "id", "é中");
assert!(framed.contains("truncated=\"true\""));
}
#[test]
fn frame_tool_output_escapes_attrs() {
#[allow(deprecated)]
let a = Agent::new(MockBackend, "sys");
let framed = a.frame_tool_output("na\"me", "id&1", "x");
assert!(framed.contains("name=\"na"me\""));
assert!(framed.contains("id=\"id&1\""));
}
use crate::tool::Tool;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
struct ScriptedBackend {
script: Mutex<std::collections::VecDeque<Message>>,
}
impl ScriptedBackend {
fn new(script: Vec<Message>) -> Self {
Self { script: Mutex::new(script.into()) }
}
}
impl LlmBackend for ScriptedBackend {
fn model(&self) -> &str { "scripted" }
fn chat(
&self,
_messages: &[Message],
_tools: &Value,
_on_token: Option<&mut dyn FnMut(&str)>,
) -> Result<Message, BackendError> {
let m = self.script.lock().unwrap().pop_front().unwrap_or_else(|| Message {
role: "assistant".into(),
content: Some("done".into()),
tool_calls: None,
tool_call_id: None,
name: None,
});
Ok(m)
}
}
struct CountingTool {
hits: Arc<AtomicUsize>,
payload: String,
}
impl Tool for CountingTool {
fn name(&self) -> &str { "counter" }
fn description(&self) -> &str { "test counter" }
fn schema(&self) -> Value {
serde_json::json!({"type":"object","properties":{}})
}
fn call(&self, _args: Value) -> Result<String, String> {
self.hits.fetch_add(1, Ordering::SeqCst);
Ok(self.payload.clone())
}
}
fn tool_call(id: &str, name: &str) -> Message {
Message {
role: "assistant".into(),
content: None,
tool_calls: Some(vec![ToolCall {
id: id.into(),
call_type: "function".into(),
function: FunctionCall {
name: name.into(),
arguments: "{}".into(),
},
}]),
tool_call_id: None,
name: None,
}
}
#[test]
fn quota_max_calls_refuses_after_limit_within_single_step() {
let script = vec![
tool_call("c1", "counter"),
tool_call("c2", "counter"),
];
let hits = Arc::new(AtomicUsize::new(0));
#[allow(deprecated)]
let mut a = Agent::new(ScriptedBackend::new(script), "sys");
a.tools.register(Box::new(CountingTool {
hits: hits.clone(),
payload: "ok".into(),
}));
a.tool_quotas.insert(
"counter".into(),
ToolQuota { max_calls: Some(1), ..Default::default() },
);
let out = a.step("go").unwrap();
assert_eq!(hits.load(Ordering::SeqCst), 1, "tool must run exactly once");
assert_eq!(out, "done");
let refused = a.messages.iter().any(|m| {
m.role == "tool"
&& m.content.as_deref().map(|c| c.contains("quota exceeded")).unwrap_or(false)
});
assert!(refused, "second call must produce a quota-refused tool message");
}
#[test]
fn quota_max_result_bytes_truncates_before_framing() {
let script = vec![tool_call("c1", "counter")];
let hits = Arc::new(AtomicUsize::new(0));
#[allow(deprecated)]
let mut a = Agent::new(ScriptedBackend::new(script), "sys");
a.tools.register(Box::new(CountingTool {
hits,
payload: "0123456789ABCDEF".into(),
}));
a.tool_quotas.insert(
"counter".into(),
ToolQuota { max_result_bytes: Some(4), ..Default::default() },
);
a.step("go").unwrap();
let tool_msg = a
.messages
.iter()
.find(|m| m.role == "tool")
.expect("tool message present");
let body = tool_msg.content.as_deref().unwrap();
assert!(body.contains("0123"), "kept prefix");
assert!(!body.contains("456789"), "truncated tail");
}
#[test]
fn observer_refuses_dispatch_and_tool_never_runs() {
struct DenyObserver;
impl Observer for DenyObserver {
fn should_dispatch(&self, _call: &ToolCall) -> Disposition {
Disposition::Refused("policy".into())
}
}
let script = vec![tool_call("c1", "counter")];
let hits = Arc::new(AtomicUsize::new(0));
#[allow(deprecated)]
let mut a = Agent::new(ScriptedBackend::new(script), "sys");
a.observer = Arc::new(DenyObserver);
a.tools.register(Box::new(CountingTool {
hits: hits.clone(),
payload: "should not run".into(),
}));
a.step("go").unwrap();
assert_eq!(hits.load(Ordering::SeqCst), 0, "observer must block dispatch");
let refused = a.messages.iter().any(|m| {
m.role == "tool"
&& m.content.as_deref().map(|c| c.contains("refused by observer")).unwrap_or(false)
});
assert!(refused);
}
struct SleepyTool {
dur: std::time::Duration,
}
impl Tool for SleepyTool {
fn name(&self) -> &str { "sleepy" }
fn description(&self) -> &str { "sleeps" }
fn schema(&self) -> Value {
serde_json::json!({"type":"object","properties":{}})
}
fn call(&self, _args: Value) -> Result<String, String> {
std::thread::sleep(self.dur);
Ok("awake".into())
}
}
#[test]
fn max_step_duration_terminates_runaway_loop() {
let script = vec![
tool_call("c1", "sleepy"),
tool_call("c2", "sleepy"),
];
#[allow(deprecated)]
let mut a = Agent::new(ScriptedBackend::new(script), "sys");
a.tools.register(Box::new(SleepyTool {
dur: std::time::Duration::from_millis(80),
}));
a.max_step_duration = Some(std::time::Duration::from_millis(100));
let err = a.step("go").expect_err("deadline must fire");
assert!(err.contains("step deadline"), "got: {}", err);
}
#[test]
fn no_deadline_means_no_deadline() {
let script = vec![tool_call("c1", "sleepy")];
#[allow(deprecated)]
let mut a = Agent::new(ScriptedBackend::new(script), "sys");
a.tools.register(Box::new(SleepyTool {
dur: std::time::Duration::from_millis(20),
}));
let out = a.step("go").unwrap();
assert_eq!(out, "done");
}
struct PanickingTool;
impl Tool for PanickingTool {
fn name(&self) -> &str { "panicker" }
fn description(&self) -> &str { "always panics" }
fn schema(&self) -> Value {
serde_json::json!({"type":"object","properties":{}})
}
fn call(&self, _args: Value) -> Result<String, String> {
panic!("deliberate test panic");
}
}
#[test]
fn panicked_tool_preserves_attribution_in_transcript() {
let script = vec![tool_call("pc1", "panicker")];
#[allow(deprecated)]
let mut a = Agent::new(ScriptedBackend::new(script), "sys");
a.tools.register(Box::new(PanickingTool));
a.step("go").unwrap();
let tool_msg = a
.messages
.iter()
.find(|m| m.role == "tool")
.expect("tool message present");
assert_eq!(tool_msg.name.as_deref(), Some("panicker"));
assert_eq!(tool_msg.tool_call_id.as_deref(), Some("pc1"));
let body = tool_msg.content.as_deref().unwrap();
assert!(
body.contains("panicked") && body.contains("deliberate test panic"),
"panic body: {}",
body
);
}
}