1use std::sync::Arc;
2
3use async_trait::async_trait;
4use rs_genai::prelude::{Content, FunctionCall, FunctionResponse, Part, Role};
5
6use super::TextAgent;
7use crate::error::AgentError;
8use crate::llm::{BaseLlm, LlmRequest};
9use crate::state::State;
10use crate::tool::ToolDispatcher;
11
/// Upper bound on LLM round-trips spent resolving tool calls in a single
/// `run`; prevents an unbounded model/tool feedback loop.
const MAX_TOOL_ROUNDS: usize = 10;
14
/// A text agent backed by an LLM, with an optional system instruction,
/// optional tool dispatch, and optional sampling configuration.
pub struct LlmTextAgent {
    // Agent identifier, reported through `TextAgent::name()`.
    name: String,
    // Model used for generation; shared trait object so agents can share one LLM.
    llm: Arc<dyn BaseLlm>,
    // System instruction attached to every request when set.
    instruction: Option<String>,
    // When set, tool declarations are attached to requests and the model's
    // function calls are routed through this dispatcher.
    dispatcher: Option<Arc<ToolDispatcher>>,
    // Sampling temperature forwarded verbatim to the LLM request when set.
    temperature: Option<f32>,
    // Cap on tokens the model may generate per call when set.
    max_output_tokens: Option<u32>,
}
25
26impl LlmTextAgent {
27 pub fn new(name: impl Into<String>, llm: Arc<dyn BaseLlm>) -> Self {
29 Self {
30 name: name.into(),
31 llm,
32 instruction: None,
33 dispatcher: None,
34 temperature: None,
35 max_output_tokens: None,
36 }
37 }
38
39 pub fn instruction(mut self, inst: impl Into<String>) -> Self {
41 self.instruction = Some(inst.into());
42 self
43 }
44
45 pub fn tools(mut self, dispatcher: Arc<ToolDispatcher>) -> Self {
47 self.dispatcher = Some(dispatcher);
48 self
49 }
50
51 pub fn temperature(mut self, t: f32) -> Self {
53 self.temperature = Some(t);
54 self
55 }
56
57 pub fn max_output_tokens(mut self, n: u32) -> Self {
59 self.max_output_tokens = Some(n);
60 self
61 }
62
63 fn build_request(&self, contents: Vec<Content>) -> LlmRequest {
65 let mut req = LlmRequest::from_contents(contents);
66 req.system_instruction = self.instruction.clone();
67 req.temperature = self.temperature;
68 req.max_output_tokens = self.max_output_tokens;
69
70 if let Some(dispatcher) = &self.dispatcher {
71 req.tools = dispatcher.to_tool_declarations();
72 }
73
74 req
75 }
76
77 async fn dispatch_tools(&self, calls: &[FunctionCall]) -> Vec<FunctionResponse> {
79 let dispatcher = match &self.dispatcher {
80 Some(d) => d,
81 None => return Vec::new(),
82 };
83
84 let mut responses = Vec::with_capacity(calls.len());
85 for call in calls {
86 let result = dispatcher
87 .call_function(&call.name, call.args.clone())
88 .await;
89 responses.push(ToolDispatcher::build_response(call, result));
90 }
91 responses
92 }
93}
94
95#[async_trait]
96impl TextAgent for LlmTextAgent {
97 fn name(&self) -> &str {
98 &self.name
99 }
100
101 async fn run(&self, state: &State) -> Result<String, AgentError> {
102 let input = state.get::<String>("input").unwrap_or_default();
104
105 let mut contents = vec![Content::user(&input)];
106
107 for _round in 0..MAX_TOOL_ROUNDS {
108 let request = self.build_request(contents.clone());
109 let response = self
110 .llm
111 .generate(request)
112 .await
113 .map_err(|e| AgentError::Other(format!("LLM error: {e}")))?;
114
115 let calls: Vec<FunctionCall> = response.function_calls().into_iter().cloned().collect();
116
117 if calls.is_empty() {
118 let text = response.text();
120 state.set("output", &text);
121 return Ok(text);
122 }
123
124 contents.push(response.content);
126
127 let tool_responses = self.dispatch_tools(&calls).await;
129 let response_parts: Vec<Part> = tool_responses
130 .into_iter()
131 .map(|fr| Part::FunctionResponse {
132 function_response: fr,
133 })
134 .collect();
135
136 contents.push(Content {
137 role: Some(Role::User),
138 parts: response_parts,
139 });
140 }
141
142 Err(AgentError::Other(format!(
143 "Agent '{}' exceeded max tool rounds ({})",
144 self.name, MAX_TOOL_ROUNDS
145 )))
146 }
147}