// rs_adk/text_agent_tool.rs

//! TextAgentTool — wraps a TextAgent as a ToolFunction for voice orchestration.
//!
//! When the live model calls this tool, the wrapped TextAgent runs via
//! `BaseLlm::generate()` (request/response), not over a WebSocket. The agent's
//! text output is returned as the tool result, under the `result` key. State is
//! shared with the parent session, so mutations are visible to watchers and
//! phase transitions.
//!
//! This bridges live↔text: the voice model dispatches complex multi-step
//! reasoning to specialist text agent pipelines.

use std::fmt;
use std::sync::Arc;

use async_trait::async_trait;
use serde_json::json;

use crate::error::ToolError;
use crate::state::State;
use crate::text::TextAgent;
use crate::tool::ToolFunction;

21/// Wraps a [`TextAgent`] as a [`ToolFunction`] for live session tool dispatch.
22///
23/// Unlike [`AgentTool`](crate::AgentTool) (which wraps a live `Agent`),
24/// `TextAgentTool` wraps a text-mode agent that uses `BaseLlm::generate()`.
25/// This enables multi-step LLM reasoning pipelines to be invoked as tools
26/// from a voice session.
27///
28/// # State Sharing
29///
30/// The text agent operates on the **same shared `State`** as the voice session.
31/// This means:
32/// - The agent can read live-extracted values (emotional_state, risk_level)
33/// - Agent state mutations are visible to watchers and phase transitions
34/// - No explicit "promote state" step is needed
35///
36/// # Example
37///
38/// ```ignore
39/// let verifier = LlmTextAgent::new("verifier", flash)
40///     .instruction("Cross-reference identity against account record")
41///     .tools(Arc::new(db_dispatcher));
42///
43/// let tool = TextAgentTool::new("verify_identity", "Verify caller identity", verifier, state);
44/// dispatcher.register(tool);
45/// ```
46pub struct TextAgentTool {
47    name: String,
48    description: String,
49    agent: Arc<dyn TextAgent>,
50    parameters: serde_json::Value,
51    state: State,
52}
53
54impl TextAgentTool {
55    /// Create a new TextAgentTool wrapping the given text agent.
56    ///
57    /// `state` should be the session's shared State so mutations flow
58    /// bidirectionally between the voice session and the text agent.
59    pub fn new(
60        name: impl Into<String>,
61        description: impl Into<String>,
62        agent: impl TextAgent + 'static,
63        state: State,
64    ) -> Self {
65        Self {
66            name: name.into(),
67            description: description.into(),
68            agent: Arc::new(agent),
69            parameters: json!({
70                "type": "object",
71                "properties": {
72                    "request": {
73                        "type": "string",
74                        "description": "The request to process"
75                    }
76                },
77                "required": ["request"]
78            }),
79            state,
80        }
81    }
82
83    /// Create from an already-Arc'd text agent.
84    pub fn from_arc(
85        name: impl Into<String>,
86        description: impl Into<String>,
87        agent: Arc<dyn TextAgent>,
88        state: State,
89    ) -> Self {
90        Self {
91            name: name.into(),
92            description: description.into(),
93            agent,
94            parameters: json!({
95                "type": "object",
96                "properties": {
97                    "request": {
98                        "type": "string",
99                        "description": "The request to process"
100                    }
101                },
102                "required": ["request"]
103            }),
104            state,
105        }
106    }
107
108    /// Override the tool parameters schema.
109    pub fn with_parameters(mut self, params: serde_json::Value) -> Self {
110        self.parameters = params;
111        self
112    }
113}
114
115#[async_trait]
116impl ToolFunction for TextAgentTool {
117    fn name(&self) -> &str {
118        &self.name
119    }
120
121    fn description(&self) -> &str {
122        &self.description
123    }
124
125    fn parameters(&self) -> Option<serde_json::Value> {
126        Some(self.parameters.clone())
127    }
128
129    async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
130        // 1. Inject tool call args into state
131        if let Some(request) = args.get("request").and_then(|r| r.as_str()) {
132            self.state.set("input", request);
133        }
134        self.state.set("agent_tool_args", &args);
135
136        // 2. Run the text agent pipeline
137        let result = self
138            .agent
139            .run(&self.state)
140            .await
141            .map_err(|e| ToolError::ExecutionFailed(format!("{e}")))?;
142
143        // 3. Return result as tool response
144        Ok(json!({"result": result}))
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AgentError;

    /// Echo agent: reads "input" from state, returns it prefixed.
    struct EchoTextAgent;

    #[async_trait]
    impl TextAgent for EchoTextAgent {
        fn name(&self) -> &str {
            "echo"
        }
        async fn run(&self, state: &State) -> Result<String, AgentError> {
            let input = state
                .get::<String>("input")
                .unwrap_or_else(|| "no input".into());
            Ok(format!("Echo: {input}"))
        }
    }

    /// Agent that reads a parent-provided key and writes keys back.
    struct StatefulAgent;

    #[async_trait]
    impl TextAgent for StatefulAgent {
        fn name(&self) -> &str {
            "stateful"
        }
        async fn run(&self, state: &State) -> Result<String, AgentError> {
            // Read a value set by the parent
            let parent_val = state
                .get::<String>("parent_key")
                .unwrap_or_else(|| "missing".into());

            // Write values visible to the parent
            state.set("child_wrote", true);
            state.set("child_output", "from child agent");

            Ok(format!("Parent said: {parent_val}"))
        }
    }

    /// Agent that always fails.
    struct FailingTextAgent;

    #[async_trait]
    impl TextAgent for FailingTextAgent {
        fn name(&self) -> &str {
            "failing"
        }
        async fn run(&self, _state: &State) -> Result<String, AgentError> {
            Err(AgentError::Other("intentional failure".into()))
        }
    }

    #[tokio::test]
    async fn basic_dispatch() {
        let state = State::new();
        let tool = TextAgentTool::new("echo", "Echo tool", EchoTextAgent, state);

        let result = tool.call(json!({"request": "hello"})).await.unwrap();
        assert_eq!(result["result"], "Echo: hello");
    }

    #[tokio::test]
    async fn tool_metadata() {
        let state = State::new();
        let tool = TextAgentTool::new("my_tool", "Does things", EchoTextAgent, state);

        assert_eq!(tool.name(), "my_tool");
        assert_eq!(tool.description(), "Does things");
        let params = tool
            .parameters()
            .expect("TextAgentTool always advertises a schema");
        assert_eq!(params["type"], "object");
        assert!(params["properties"]["request"].is_object());
        assert_eq!(params["required"], json!(["request"]));
    }

    #[tokio::test]
    async fn constructors_share_default_schema() {
        let by_value = TextAgentTool::new("a", "A", EchoTextAgent, State::new());
        let by_arc = TextAgentTool::from_arc("b", "B", Arc::new(EchoTextAgent), State::new());

        assert_eq!(by_value.parameters(), by_arc.parameters());
    }

    #[tokio::test]
    async fn state_shared_bidirectionally() {
        let state = State::new();
        state.set("parent_key", "hello from parent");

        let tool = TextAgentTool::new("stateful", "Stateful tool", StatefulAgent, state.clone());

        let result = tool.call(json!({"request": "test"})).await.unwrap();
        assert_eq!(result["result"], "Parent said: hello from parent");

        // Verify child's state mutations are visible to parent
        assert_eq!(state.get::<bool>("child_wrote"), Some(true));
        assert_eq!(
            state.get::<String>("child_output"),
            Some("from child agent".into())
        );
    }

    #[tokio::test]
    async fn error_propagation() {
        let state = State::new();
        let tool = TextAgentTool::new("failing", "Fails", FailingTextAgent, state);

        match tool.call(json!({"request": "test"})).await {
            Err(ToolError::ExecutionFailed(msg)) => {
                assert!(msg.contains("intentional failure"));
            }
            other => panic!("expected Err(ExecutionFailed), got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn custom_parameters() {
        let state = State::new();
        let params = json!({
            "type": "object",
            "properties": {
                "query": { "type": "string" },
                "limit": { "type": "integer" }
            }
        });
        let tool = TextAgentTool::new("custom", "Custom params", EchoTextAgent, state)
            .with_parameters(params.clone());

        assert_eq!(tool.parameters().unwrap(), params);
    }

    #[tokio::test]
    async fn args_injected_into_state() {
        let state = State::new();
        let tool = TextAgentTool::new("echo", "Echo", EchoTextAgent, state.clone());

        tool.call(json!({"request": "injected"})).await.unwrap();

        // Verify args were injected
        assert_eq!(state.get::<String>("input"), Some("injected".into()));
        let args = state.get::<serde_json::Value>("agent_tool_args").unwrap();
        assert_eq!(args["request"], "injected");
    }

    #[tokio::test]
    async fn each_call_sees_its_own_request() {
        let state = State::new();
        let tool = TextAgentTool::new("echo", "Echo", EchoTextAgent, state);

        let first = tool.call(json!({"request": "first"})).await.unwrap();
        let second = tool.call(json!({"request": "second"})).await.unwrap();

        assert_eq!(first["result"], "Echo: first");
        assert_eq!(second["result"], "Echo: second");
    }

    #[tokio::test]
    async fn from_arc_constructor() {
        let state = State::new();
        let agent: Arc<dyn TextAgent> = Arc::new(EchoTextAgent);
        let tool = TextAgentTool::from_arc("echo", "Echo tool", agent, state);

        let result = tool.call(json!({"request": "arc test"})).await.unwrap();
        assert_eq!(result["result"], "Echo: arc test");
    }
}