use std::sync::Arc;
use async_trait::async_trait;
use rs_genai::prelude::{Content, FunctionCall, FunctionResponse, Part, Role};
use super::TextAgent;
use crate::error::AgentError;
use crate::llm::{BaseLlm, LlmRequest};
use crate::state::State;
use crate::tool::ToolDispatcher;
/// Upper bound on the number of tool-call rounds `LlmTextAgent` performs in a single `run`.
///
/// A "round" is one model response containing function calls followed by dispatching those
/// calls. If the model is still requesting tools when this cap is reached, `run` returns an
/// error instead of looping indefinitely.
const MAX_TOOL_ROUNDS: usize = 10;
/// A text agent that produces its answer by calling an LLM, optionally letting the model
/// invoke tools through a [`ToolDispatcher`].
///
/// Construct it with [`LlmTextAgent::new`] and configure it with the consuming builder
/// methods [`instruction`](Self::instruction), [`tools`](Self::tools),
/// [`temperature`](Self::temperature) and [`max_output_tokens`](Self::max_output_tokens).
pub struct LlmTextAgent {
    /// Name of this agent.
    name: String,
    /// Model used for every generation request.
    llm: Arc<dyn BaseLlm>,
    /// System instruction copied into each request; `None` leaves it unset.
    instruction: Option<String>,
    /// Tool dispatcher; when present its declarations are attached to requests and model
    /// function calls are executed through it.
    dispatcher: Option<Arc<ToolDispatcher>>,
    /// Sampling temperature copied into each request; `None` leaves it unset.
    temperature: Option<f32>,
    /// Output-token limit copied into each request; `None` leaves it unset.
    max_output_tokens: Option<u32>,
}

impl LlmTextAgent {
    /// Creates an agent called `name` backed by `llm`, with no instruction, no tools and no
    /// sampling overrides.
    pub fn new(name: impl Into<String>, llm: Arc<dyn BaseLlm>) -> Self {
        let name = name.into();
        Self {
            name,
            llm,
            instruction: None,
            dispatcher: None,
            temperature: None,
            max_output_tokens: None,
        }
    }

    /// Sets the system instruction sent with every request (replacing any previous one).
    pub fn instruction(self, inst: impl Into<String>) -> Self {
        Self {
            instruction: Some(inst.into()),
            ..self
        }
    }

    /// Attaches a tool dispatcher whose tools the model may call (replacing any previous one).
    pub fn tools(self, dispatcher: Arc<ToolDispatcher>) -> Self {
        Self {
            dispatcher: Some(dispatcher),
            ..self
        }
    }

    /// Sets the sampling temperature sent with every request.
    pub fn temperature(self, t: f32) -> Self {
        Self {
            temperature: Some(t),
            ..self
        }
    }

    /// Sets the maximum number of output tokens sent with every request.
    pub fn max_output_tokens(self, n: u32) -> Self {
        Self {
            max_output_tokens: Some(n),
            ..self
        }
    }

    /// Builds an [`LlmRequest`] over `contents` carrying this agent's system instruction,
    /// temperature and output-token limit, plus the dispatcher's tool declarations when a
    /// dispatcher is configured.
    fn build_request(&self, contents: Vec<Content>) -> LlmRequest {
        let mut request = LlmRequest::from_contents(contents);
        if let Some(d) = self.dispatcher.as_ref() {
            request.tools = d.to_tool_declarations();
        }
        request.max_output_tokens = self.max_output_tokens;
        request.temperature = self.temperature;
        request.system_instruction = self.instruction.clone();
        request
    }

    /// Executes `calls` one after another through the configured dispatcher and returns one
    /// [`FunctionResponse`] per call, in the same order as `calls`.
    ///
    /// Returns an empty vector when no dispatcher is configured.
    async fn dispatch_tools(&self, calls: &[FunctionCall]) -> Vec<FunctionResponse> {
        if let Some(dispatcher) = self.dispatcher.as_ref() {
            let mut out = Vec::with_capacity(calls.len());
            for fc in calls.iter() {
                let outcome = dispatcher.call_function(&fc.name, fc.args.clone()).await;
                out.push(ToolDispatcher::build_response(fc, outcome));
            }
            out
        } else {
            Vec::new()
        }
    }
}
#[async_trait]
impl TextAgent for LlmTextAgent {
    /// Returns the agent's configured name.
    fn name(&self) -> &str {
        &self.name
    }

    /// Runs the agent on the `"input"` string stored in `state`.
    ///
    /// The conversation starts as a single user turn holding the input (an empty string when
    /// `"input"` is absent). The model is then called in a loop:
    /// - if its reply contains function calls, they are executed through the configured
    ///   dispatcher, and the results are appended as a user turn of function responses;
    /// - the first reply without function calls ends the loop, and its text is stored under
    ///   `"output"` in `state` and returned.
    ///
    /// At most [`MAX_TOOL_ROUNDS`] tool rounds are executed, and the model always gets to see
    /// the results of every round that was executed.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Other`] when:
    /// - the LLM call fails;
    /// - the model requests function calls but no tool dispatcher is configured;
    /// - the model is still requesting function calls after [`MAX_TOOL_ROUNDS`] tool rounds.
    async fn run(&self, state: &State) -> Result<String, AgentError> {
        let input = state.get::<String>("input").unwrap_or_default();
        let mut contents = vec![Content::user(&input)];

        // `MAX_TOOL_ROUNDS + 1` model calls: one follow-up call after each permitted tool round.
        for round in 0..=MAX_TOOL_ROUNDS {
            let request = self.build_request(contents.clone());
            let response = self
                .llm
                .generate(request)
                .await
                .map_err(|e| AgentError::Other(format!("LLM error: {e}")))?;

            let calls: Vec<FunctionCall> =
                response.function_calls().into_iter().cloned().collect();
            if calls.is_empty() {
                let text = response.text();
                state.set("output", &text);
                return Ok(text);
            }

            // With no dispatcher nothing can answer these calls; continuing would append a
            // function-response turn with no parts and loop without progress.
            if self.dispatcher.is_none() {
                return Err(AgentError::Other(format!(
                    "Agent '{}' received function calls but has no tools configured",
                    self.name
                )));
            }

            // Budget exhausted: do not execute tools (which may have side effects) whose
            // results the model would never see.
            if round == MAX_TOOL_ROUNDS {
                break;
            }

            contents.push(response.content);
            let response_parts: Vec<Part> = self
                .dispatch_tools(&calls)
                .await
                .into_iter()
                .map(|fr| Part::FunctionResponse {
                    function_response: fr,
                })
                .collect();
            contents.push(Content {
                role: Some(Role::User),
                parts: response_parts,
            });
        }

        Err(AgentError::Other(format!(
            "Agent '{}' exceeded max tool rounds ({})",
            self.name, MAX_TOOL_ROUNDS
        )))
    }
}