use std::collections::HashMap;
use std::sync::Arc;
use schemars::JsonSchema;
use serde_json::Value;
use super::structured_parser::{build_structured_output_instructions, extract_structured_output};
use crate::errors::{AgentError, AgentResult};
use crate::models::LLMOutputTrait;
use crate::models::{BaseLlm, ContentPart, Event, Thread};
use crate::tools::{
BaseTool, BaseToolset, CombinedToolset, DefaultExecutionState, SimpleToolset, ToolContext,
ToolResponse,
};
use crate::{compat::MaybeSend, compat::MaybeSync};
/// Default cap on model calls per run before `run_tool_loop` errors out.
const DEFAULT_MAX_TOOL_ITERATIONS: usize = 20;
/// Drives an LLM conversation until the model emits a structured output of
/// type `T`, optionally executing tool calls the model requests along the way.
///
/// Construct via [`LlmWorker::builder`] / [`LlmWorker::builder_shared`].
pub struct LlmWorker<T> {
    /// Underlying language model used for every generation call.
    model: Arc<dyn BaseLlm>,
    /// Optional user-supplied system prompt; prepended to the
    /// structured-output instructions derived from `T`'s JSON schema.
    system_instructions: Option<String>,
    /// Combined toolset available to the model, if any tools were configured.
    toolset: Option<Arc<dyn BaseToolset>>,
    /// Maximum number of model calls per run (see `run_tool_loop`).
    max_iterations: usize,
    /// Ties the worker to its output type `T` without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}
impl<T> LlmWorker<T>
where
    T: LLMOutputTrait + JsonSchema + MaybeSend + MaybeSync + 'static,
{
    /// Starts a builder that takes ownership of a concrete model.
    pub fn builder(model: impl BaseLlm + 'static) -> LlmWorkerBuilder<T> {
        LlmWorkerBuilder::new(model)
    }

    /// Starts a builder from an already-shared model handle.
    pub fn builder_shared(model: Arc<dyn BaseLlm>) -> LlmWorkerBuilder<T> {
        LlmWorkerBuilder::from_shared(model)
    }

    /// Runs the conversation to completion and returns the structured value.
    ///
    /// # Errors
    /// Propagates failures from instruction building, model calls, tool
    /// lookup/execution, or the iteration cap.
    pub async fn run<IT>(&self, input: IT) -> AgentResult<T>
    where
        IT: Into<Thread>,
    {
        let thread = self.apply_defaults(input.into())?;
        let outcome = self.execute(thread).await?;
        Ok(outcome.value)
    }

    /// Same as [`Self::run`], but also returns the final [`Thread`]
    /// (including assistant and tool events) so the caller can continue it.
    pub async fn run_and_continue<IT>(&self, input: IT) -> AgentResult<(T, Thread)>
    where
        IT: Into<Thread>,
    {
        let thread = self.apply_defaults(input.into())?;
        let outcome = self.execute(thread).await?;
        Ok((outcome.value, outcome.thread))
    }

    /// Returns `true` when a toolset was configured at build time.
    #[must_use]
    pub fn has_tools(&self) -> bool {
        self.toolset.is_some()
    }

    /// Borrows the configured toolset, if any.
    #[must_use]
    pub fn toolset(&self) -> Option<&Arc<dyn BaseToolset>> {
        self.toolset.as_ref()
    }

    /// Installs the system message: the user-supplied instructions (if any),
    /// a blank line, then the structured-output instructions built from `T`.
    ///
    /// NOTE(review): `with_system` presumably replaces any system message
    /// already present on the thread — confirm against `Thread`'s impl.
    fn apply_defaults(&self, mut thread: Thread) -> AgentResult<Thread> {
        let structured_instructions = build_structured_output_instructions::<T>()?;
        let combined_instructions = if let Some(user_instructions) = &self.system_instructions {
            format!("{user_instructions}\n\n{structured_instructions}")
        } else {
            structured_instructions
        };
        thread = thread.with_system(combined_instructions);
        Ok(thread)
    }

    /// Sets up the tool lookup cache and execution context, runs the tool
    /// loop, and closes the toolset afterwards on both success and error.
    async fn execute(&self, thread: Thread) -> AgentResult<WorkerOutcome<T>> {
        let toolset = self.toolset.clone();
        // Resolve the tool list once up front; the loop then does name
        // lookups against this cache instead of re-querying the toolset.
        let tool_cache = if let Some(ref ts) = toolset {
            load_tool_map(ts).await?
        } else {
            HashMap::new()
        };
        let execution_state = DefaultExecutionState::new();
        let tool_context = ToolContext::builder()
            .with_state(&execution_state)
            .build()
            .map_err(|err| AgentError::ToolSetupFailed {
                tool_name: "tool_context".to_string(),
                reason: err.to_string(),
            })?;
        // Hold the loop's result aside so the toolset is closed before
        // returning, regardless of whether the loop succeeded.
        let result = self
            .run_tool_loop(
                thread,
                toolset.clone(),
                &tool_cache,
                &tool_context,
                self.max_iterations,
            )
            .await;
        if let Some(ts) = toolset {
            ts.close().await;
        }
        result
    }

    /// Core model/tool interaction loop.
    ///
    /// Each iteration performs exactly one model call. If the response
    /// contains tool calls, they are executed in order and their responses
    /// appended to the thread for the next iteration. Otherwise the response
    /// is parsed as structured output; a response that parses returns, while
    /// an unparseable one is recorded and the model is re-asked — so
    /// malformed output consumes iterations until the budget runs out.
    ///
    /// # Errors
    /// - [`AgentError::Internal`] once more than `max_iterations` model
    ///   calls would be made.
    /// - [`AgentError::ToolNotFound`] when the model names a tool absent
    ///   from `tool_cache`.
    /// - Any error from the model, argument conversion, or tool setup.
    async fn run_tool_loop(
        &self,
        mut thread: Thread,
        toolset: Option<Arc<dyn BaseToolset>>,
        tool_cache: &HashMap<String, &dyn BaseTool>,
        tool_context: &ToolContext<'_>,
        max_iterations: usize,
    ) -> AgentResult<WorkerOutcome<T>> {
        let mut iterations = 0usize;
        loop {
            iterations += 1;
            // Budget is checked before the model call, so at most
            // `max_iterations` calls are ever made.
            if iterations > max_iterations {
                return Err(AgentError::Internal {
                    component: "llm_worker".to_string(),
                    reason: format!("Exceeded tool interaction iterations (max: {max_iterations})"),
                });
            }
            let response = self
                .model
                .generate_content(thread.clone(), toolset.clone())
                .await?;
            let content = response.into_content();
            // Collect every tool-call part; all parts (including these) stay
            // in `content` and are recorded with the assistant event below.
            let tool_calls: Vec<_> = content
                .parts()
                .iter()
                .filter_map(|part| match part {
                    ContentPart::ToolCall(call) => Some(call.clone()),
                    _ => None,
                })
                .collect();
            if tool_calls.is_empty() {
                // No tools requested: attempt structured-output extraction.
                if let Ok(value) = extract_structured_output::<T>(&content) {
                    thread = thread.add_event(Event::assistant(content));
                    return Ok(WorkerOutcome { value, thread });
                }
                // Unparseable response: record it and retry. The parse
                // error itself is discarded; this retry costs an iteration.
                thread = thread.add_event(Event::assistant(content));
                continue;
            }
            thread = thread.add_event(Event::assistant(content));
            // Execute the requested tools sequentially, appending one
            // response event per call so the model sees results next turn.
            for call in tool_calls {
                let tool =
                    *tool_cache
                        .get(call.name())
                        .ok_or_else(|| AgentError::ToolNotFound {
                            tool_name: call.name().to_string(),
                        })?;
                let args = value_to_arguments(call.name(), call.arguments())?;
                let result = tool.run_async(args, tool_context).await;
                let response = ToolResponse::new(call.id().to_string(), result);
                thread = thread.add_event(Event::from(response));
            }
        }
    }
}
/// Internal result bundle: the parsed structured value plus the full thread
/// (including the final assistant event) that produced it.
struct WorkerOutcome<T> {
    value: T,
    thread: Thread,
}
/// Builds a name → tool lookup table from the toolset's current tools.
///
/// Tools appearing later in the list silently shadow earlier ones that share
/// the same name (standard `HashMap` insert semantics).
async fn load_tool_map(
    toolset: &Arc<dyn BaseToolset>,
) -> AgentResult<HashMap<String, &dyn BaseTool>> {
    let entries = toolset
        .get_tools()
        .await
        .into_iter()
        .map(|tool| (tool.name().to_string(), tool));
    Ok(entries.collect())
}
/// Converts a tool call's raw JSON arguments into the `name -> value` map
/// tools expect.
///
/// `Null` maps to an empty argument set; any non-object value is rejected
/// with a validation error naming the offending tool.
fn value_to_arguments(tool_name: &str, value: &Value) -> AgentResult<HashMap<String, Value>> {
    if value.is_null() {
        return Ok(HashMap::new());
    }
    match value.as_object() {
        Some(object) => Ok(object.clone().into_iter().collect()),
        None => Err(AgentError::ToolValidationError {
            tool_name: tool_name.to_string(),
            reason: "Tool arguments must be a JSON object".to_string(),
        }),
    }
}
/// Fluent builder for [`LlmWorker`]; accumulates instructions, individual
/// tools, and whole toolsets before `build` folds them into one worker.
pub struct LlmWorkerBuilder<T> {
    /// Shared model handle the worker will call.
    model: Arc<dyn BaseLlm>,
    /// Optional system prompt prepended to the structured-output instructions.
    system_instructions: Option<String>,
    /// Individually registered tools; wrapped in a `SimpleToolset` at build time.
    tools: Vec<Box<dyn BaseTool>>,
    /// Whole toolsets; combined pairwise (and with `tools`) at build time.
    toolsets: Vec<Arc<dyn BaseToolset>>,
    /// Cap on model calls per run; defaults to `DEFAULT_MAX_TOOL_ITERATIONS`.
    max_iterations: usize,
    /// Ties the builder to the worker's output type `T`.
    _phantom: std::marker::PhantomData<T>,
}
impl<T> LlmWorkerBuilder<T>
where
    T: LLMOutputTrait + JsonSchema + MaybeSend + MaybeSync + 'static,
{
    /// Creates a builder that boxes `model` behind a shared handle.
    pub fn new(model: impl BaseLlm + 'static) -> Self {
        Self::from_shared(Arc::new(model) as Arc<dyn BaseLlm>)
    }

    /// Creates a builder from an already-shared model handle.
    pub fn from_shared(model: Arc<dyn BaseLlm>) -> Self {
        Self {
            model,
            system_instructions: None,
            tools: Vec::new(),
            toolsets: Vec::new(),
            max_iterations: DEFAULT_MAX_TOOL_ITERATIONS,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Sets the user portion of the system prompt.
    #[must_use]
    pub fn with_system_instructions(mut self, instructions: impl Into<String>) -> Self {
        self.system_instructions = Some(instructions.into());
        self
    }

    /// Registers a single tool.
    #[must_use]
    pub fn with_tool<U>(mut self, tool: U) -> Self
    where
        U: BaseTool + 'static,
    {
        self.tools.push(Box::new(tool));
        self
    }

    /// Registers every tool yielded by `tools`.
    #[must_use]
    pub fn with_tools<I, U>(mut self, tools: I) -> Self
    where
        I: IntoIterator<Item = U>,
        U: BaseTool + 'static,
    {
        self.tools
            .extend(tools.into_iter().map(|tool| Box::new(tool) as Box<dyn BaseTool>));
        self
    }

    /// Registers a whole toolset.
    #[must_use]
    pub fn with_toolset(mut self, toolset: Arc<dyn BaseToolset>) -> Self {
        self.toolsets.push(toolset);
        self
    }

    /// Sets the model-call budget per run; clamped to at least 1.
    #[must_use]
    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations.max(1);
        self
    }

    /// Registers every toolset yielded by `toolsets`.
    #[must_use]
    pub fn with_toolsets<I>(mut self, toolsets: I) -> Self
    where
        I: IntoIterator<Item = Arc<dyn BaseToolset>>,
    {
        self.toolsets.extend(toolsets);
        self
    }

    /// Consumes the builder and produces the worker.
    ///
    /// Individually registered tools are wrapped in one `SimpleToolset` and
    /// appended after the explicit toolsets; everything is then combined
    /// pairwise into a single toolset (or `None` when nothing was added).
    #[must_use]
    pub fn build(self) -> LlmWorker<T> {
        let mut sources = self.toolsets;
        if !self.tools.is_empty() {
            sources.push(Arc::new(SimpleToolset::new(self.tools)) as Arc<dyn BaseToolset>);
        }
        // `reduce` covers all three cases at once: an empty list yields
        // `None`, a singleton yields that toolset unchanged, and multiple
        // toolsets fold left-to-right into nested `CombinedToolset`s.
        let toolset = sources.into_iter().reduce(|accumulated, next| {
            Arc::new(CombinedToolset::new(accumulated, next)) as Arc<dyn BaseToolset>
        });
        LlmWorker {
            model: self.model,
            system_instructions: self.system_instructions,
            toolset,
            max_iterations: self.max_iterations,
            _phantom: std::marker::PhantomData,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::{AgentError, AgentResult};
    use crate::macros::LLMOutput;
    use crate::models::{Content, ContentPart, LlmResponse};
    use crate::test_support::{FakeLlm, RecordingTool};
    use crate::tools::tool::ToolCall;
    use crate::tools::ToolResult;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use std::collections::VecDeque;

    /// Minimal structured-output type used by every test in this module.
    #[derive(Debug, Deserialize, LLMOutput, Serialize, JsonSchema, PartialEq)]
    struct Sample {
        value: i32,
    }

    /// Builds a fake model response whose content is `{"value": <value>}`,
    /// i.e. a directly parseable `Sample`.
    fn structured_response(value: i32) -> AgentResult<LlmResponse> {
        let json_str = format!(r#"{{"value": {value}}}"#);
        FakeLlm::content_response(Content::from_text(json_str))
    }

    /// Happy path: a single structured response is parsed into `Sample`.
    #[tokio::test(flavor = "current_thread")]
    async fn run_with_structured_output_returns_value() {
        let model = FakeLlm::with_responses("fake-model", [structured_response(7)]);
        let worker = LlmWorker::<Sample>::builder(model).build();
        let result = worker.run("Hello").await.expect("worker result");
        assert_eq!(result, Sample { value: 7 });
    }

    /// Tool round-trip: first response requests a tool call, second response
    /// is the structured output; the tool must run exactly once and the
    /// returned thread must record the exchange.
    #[tokio::test(flavor = "current_thread")]
    async fn run_with_tool_executes_before_final_response() {
        let initial_tool_call =
            ToolCall::new("call-1", "recording_tool", json!({ "input": "ping" }));
        let tool_request = Content::from_parts(vec![ContentPart::ToolCall(initial_tool_call)]);
        let final_response = structured_response(42);
        let model =
            FakeLlm::with_responses("fake-model", [FakeLlm::content_response(tool_request)]);
        model.push_response(Ok(final_response.expect("final response")));
        let results = VecDeque::from([ToolResult::success(json!({"ok": true}))]);
        let recorder = RecordingTool::new("recording_tool", "Records usage", results);
        let worker = LlmWorker::<Sample>::builder(model)
            .with_tool(recorder.clone())
            .build();
        let (result, thread) = worker
            .run_and_continue("Need help")
            .await
            .expect("worker call");
        assert_eq!(result, Sample { value: 42 });
        assert!(
            thread.events().len() >= 2,
            "thread should capture tool exchange"
        );
        assert_eq!(
            recorder.call_count(),
            1,
            "tool should have been called once"
        );
    }

    /// A tool call naming an unregistered tool must surface `ToolNotFound`.
    #[tokio::test(flavor = "current_thread")]
    async fn run_fails_when_tool_missing() {
        let tool_call = ToolCall::new("call-1", "unknown_tool", json!({ "value": "data" }));
        let model = FakeLlm::with_responses(
            "fake-model",
            [FakeLlm::content_response(Content::from_parts(vec![
                ContentPart::ToolCall(tool_call),
            ]))],
        );
        let worker = LlmWorker::<Sample>::builder(model).build();
        let err = worker.run("Test").await.expect_err("should fail");
        assert!(matches!(err, AgentError::ToolNotFound { .. }));
    }

    /// Builder wiring: individual tools plus explicit toolsets combine into
    /// one toolset, and the iteration budget is stored as configured.
    #[tokio::test(flavor = "current_thread")]
    async fn builder_composes_tools_and_toolsets() {
        let model = FakeLlm::with_responses("fake", [structured_response(1)]);
        let tool = Box::new(RecordingTool::default()) as Box<dyn BaseTool>;
        let toolset = Arc::new(SimpleToolset::new(vec![tool])) as Arc<dyn BaseToolset>;
        let worker = LlmWorker::<Sample>::builder(model)
            .with_tool(RecordingTool::default())
            .with_toolset(toolset.clone())
            .with_max_iterations(5)
            .build();
        assert!(worker.has_tools());
        assert!(worker.toolset().is_some());
        assert_eq!(worker.max_iterations, 5);
    }

    /// A model that requests tool calls forever must hit the iteration cap
    /// and return `AgentError::Internal` rather than loop indefinitely.
    #[tokio::test(flavor = "current_thread")]
    async fn exceeding_iteration_budget_returns_error() {
        let model = FakeLlm::with_responses(
            "fake-model",
            [FakeLlm::content_response(Content::from_parts(vec![
                ContentPart::ToolCall(ToolCall::new("call-1", "recording_tool", json!({}))),
            ]))],
        );
        let worker = LlmWorker::<Sample>::builder(model)
            .with_tool(RecordingTool::default())
            .with_max_iterations(1)
            .build();
        let err = worker
            .run("loop")
            .await
            .expect_err("should fail on iteration cap");
        assert!(matches!(err, AgentError::Internal { .. }));
    }
}