//! rs_adk/text/llm.rs — LLM-backed text agent (`LlmTextAgent`).

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use rs_genai::prelude::{Content, FunctionCall, FunctionResponse, Part, Role};
5
6use super::TextAgent;
7use crate::error::AgentError;
8use crate::llm::{BaseLlm, LlmRequest};
9use crate::state::State;
10use crate::tool::ToolDispatcher;
11
/// Upper bound on model ↔ tool round-trips; stops the loop when the model
/// never settles on a final text answer.
const MAX_TOOL_ROUNDS: usize = 10;
14
15/// Core text agent — calls `BaseLlm::generate()`, dispatches tools, loops
16/// until the model produces a final text response.
17pub struct LlmTextAgent {
18    name: String,
19    llm: Arc<dyn BaseLlm>,
20    instruction: Option<String>,
21    dispatcher: Option<Arc<ToolDispatcher>>,
22    temperature: Option<f32>,
23    max_output_tokens: Option<u32>,
24}
25
26impl LlmTextAgent {
27    /// Create a new LLM text agent.
28    pub fn new(name: impl Into<String>, llm: Arc<dyn BaseLlm>) -> Self {
29        Self {
30            name: name.into(),
31            llm,
32            instruction: None,
33            dispatcher: None,
34            temperature: None,
35            max_output_tokens: None,
36        }
37    }
38
39    /// Set the system instruction.
40    pub fn instruction(mut self, inst: impl Into<String>) -> Self {
41        self.instruction = Some(inst.into());
42        self
43    }
44
45    /// Set the tool dispatcher.
46    pub fn tools(mut self, dispatcher: Arc<ToolDispatcher>) -> Self {
47        self.dispatcher = Some(dispatcher);
48        self
49    }
50
51    /// Set temperature.
52    pub fn temperature(mut self, t: f32) -> Self {
53        self.temperature = Some(t);
54        self
55    }
56
57    /// Set max output tokens.
58    pub fn max_output_tokens(mut self, n: u32) -> Self {
59        self.max_output_tokens = Some(n);
60        self
61    }
62
63    /// Build an LlmRequest, taking ownership of contents to avoid cloning.
64    fn build_request(&self, contents: Vec<Content>) -> LlmRequest {
65        let mut req = LlmRequest::from_contents(contents);
66        req.system_instruction = self.instruction.clone();
67        req.temperature = self.temperature;
68        req.max_output_tokens = self.max_output_tokens;
69
70        if let Some(dispatcher) = &self.dispatcher {
71            req.tools = dispatcher.to_tool_declarations();
72        }
73
74        req
75    }
76
77    /// Dispatch function calls and return function responses.
78    async fn dispatch_tools(&self, calls: &[FunctionCall]) -> Vec<FunctionResponse> {
79        let dispatcher = match &self.dispatcher {
80            Some(d) => d,
81            None => return Vec::new(),
82        };
83
84        let mut responses = Vec::with_capacity(calls.len());
85        for call in calls {
86            let result = dispatcher
87                .call_function(&call.name, call.args.clone())
88                .await;
89            responses.push(ToolDispatcher::build_response(call, result));
90        }
91        responses
92    }
93}
94
95#[async_trait]
96impl TextAgent for LlmTextAgent {
97    fn name(&self) -> &str {
98        &self.name
99    }
100
101    async fn run(&self, state: &State) -> Result<String, AgentError> {
102        // Build initial contents from state "input" key, or empty user message.
103        let input = state.get::<String>("input").unwrap_or_default();
104
105        let mut contents = vec![Content::user(&input)];
106
107        for _round in 0..MAX_TOOL_ROUNDS {
108            let request = self.build_request(contents.clone());
109            let response = self
110                .llm
111                .generate(request)
112                .await
113                .map_err(|e| AgentError::Other(format!("LLM error: {e}")))?;
114
115            let calls: Vec<FunctionCall> = response.function_calls().into_iter().cloned().collect();
116
117            if calls.is_empty() {
118                // No tool calls — we have a final text response.
119                let text = response.text();
120                state.set("output", &text);
121                return Ok(text);
122            }
123
124            // Move model response into conversation (no clone needed).
125            contents.push(response.content);
126
127            // Dispatch tools and append responses.
128            let tool_responses = self.dispatch_tools(&calls).await;
129            let response_parts: Vec<Part> = tool_responses
130                .into_iter()
131                .map(|fr| Part::FunctionResponse {
132                    function_response: fr,
133                })
134                .collect();
135
136            contents.push(Content {
137                role: Some(Role::User),
138                parts: response_parts,
139            });
140        }
141
142        Err(AgentError::Other(format!(
143            "Agent '{}' exceeded max tool rounds ({})",
144            self.name, MAX_TOOL_ROUNDS
145        )))
146    }
147}