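//! External toolset (`serdes_ai_toolsets/external.rs`): tool definitions that are exposed via
//! `get_tools` but whose calls are always answered with `ToolError::CallDeferred`.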
use std::collections::HashMap;
use std::marker::PhantomData;

use async_trait::async_trait;
use serde_json::Value as JsonValue;
use serdes_ai_tools::{RunContext, ToolDefinition, ToolError, ToolReturn};

use crate::{AbstractToolset, ToolsetTool};
13
14pub struct ExternalToolset<Deps = ()> {
31 id: Option<String>,
32 definitions: Vec<ToolDefinition>,
33 max_retries: u32,
34 _phantom: PhantomData<fn() -> Deps>,
35}
36
37impl<Deps> ExternalToolset<Deps> {
38 #[must_use]
40 pub fn new() -> Self {
41 Self {
42 id: None,
43 definitions: Vec::new(),
44 max_retries: 3,
45 _phantom: PhantomData,
46 }
47 }
48
49 #[must_use]
51 pub fn with_id(mut self, id: impl Into<String>) -> Self {
52 self.id = Some(id.into());
53 self
54 }
55
56 #[must_use]
58 pub fn with_max_retries(mut self, retries: u32) -> Self {
59 self.max_retries = retries;
60 self
61 }
62
63 #[must_use]
65 pub fn definition(mut self, def: ToolDefinition) -> Self {
66 self.definitions.push(def);
67 self
68 }
69
70 #[must_use]
72 pub fn definitions(mut self, defs: impl IntoIterator<Item = ToolDefinition>) -> Self {
73 self.definitions.extend(defs);
74 self
75 }
76
77 #[must_use]
79 pub fn len(&self) -> usize {
80 self.definitions.len()
81 }
82
83 #[must_use]
85 pub fn is_empty(&self) -> bool {
86 self.definitions.is_empty()
87 }
88}
89
90impl<Deps> Default for ExternalToolset<Deps> {
91 fn default() -> Self {
92 Self::new()
93 }
94}
95
96#[async_trait]
97impl<Deps: Send + Sync> AbstractToolset<Deps> for ExternalToolset<Deps> {
98 fn id(&self) -> Option<&str> {
99 self.id.as_deref()
100 }
101
102 fn type_name(&self) -> &'static str {
103 "ExternalToolset"
104 }
105
106 async fn get_tools(
107 &self,
108 _ctx: &RunContext<Deps>,
109 ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
110 Ok(self
111 .definitions
112 .iter()
113 .map(|def| {
114 (
115 def.name.clone(),
116 ToolsetTool {
117 toolset_id: self.id.clone(),
118 tool_def: def.clone(),
119 max_retries: self.max_retries,
120 },
121 )
122 })
123 .collect())
124 }
125
126 async fn call_tool(
127 &self,
128 name: &str,
129 args: JsonValue,
130 _ctx: &RunContext<Deps>,
131 _tool: &ToolsetTool,
132 ) -> Result<ToolReturn, ToolError> {
133 Err(ToolError::CallDeferred {
135 tool_name: name.to_string(),
136 args,
137 })
138 }
139}
140
141impl<Deps> std::fmt::Debug for ExternalToolset<Deps> {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 f.debug_struct("ExternalToolset")
144 .field("id", &self.id)
145 .field("definitions", &self.definitions.len())
146 .field("max_retries", &self.max_retries)
147 .finish()
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_external_toolset_new() {
        let toolset = ExternalToolset::<()>::new();
        assert!(toolset.is_empty());
        assert!(toolset.id().is_none());
    }

    #[test]
    fn test_external_toolset_default_matches_new() {
        let toolset = ExternalToolset::<()>::default();
        assert!(toolset.is_empty());
        assert_eq!(toolset.len(), 0);
        assert!(toolset.id().is_none());
        // Debug renders every configurable field, so equal output means equal configuration.
        assert_eq!(
            format!("{toolset:?}"),
            format!("{:?}", ExternalToolset::<()>::new())
        );
    }

    #[test]
    fn test_external_toolset_with_definitions() {
        let toolset = ExternalToolset::<()>::new()
            .with_id("external")
            .definition(ToolDefinition::new("api_call", "Call API"))
            .definition(ToolDefinition::new("webhook", "Send webhook"));

        assert_eq!(toolset.len(), 2);
        assert_eq!(toolset.id(), Some("external"));
    }

    #[test]
    fn test_external_toolset_definitions_batch() {
        let toolset = ExternalToolset::<()>::new()
            .definition(ToolDefinition::new("first", "First tool"))
            .definitions(vec![
                ToolDefinition::new("second", "Second tool"),
                ToolDefinition::new("third", "Third tool"),
            ]);

        assert_eq!(toolset.len(), 3);
        assert!(!toolset.is_empty());
    }

    #[test]
    fn test_external_toolset_debug_shows_counts() {
        let toolset = ExternalToolset::<()>::new()
            .with_id("ext")
            .with_max_retries(5)
            .definition(ToolDefinition::new("a", "Tool A"));

        assert_eq!(
            format!("{toolset:?}"),
            r#"ExternalToolset { id: Some("ext"), definitions: 1, max_retries: 5 }"#
        );
    }

    #[tokio::test]
    async fn test_external_toolset_get_tools() {
        let toolset =
            ExternalToolset::<()>::new().definition(ToolDefinition::new("test", "Test tool"));

        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("test"));
    }

    #[tokio::test]
    async fn test_external_toolset_get_tools_propagates_settings() {
        let toolset = ExternalToolset::<()>::new()
            .with_id("ext")
            .with_max_retries(7)
            .definitions(vec![
                ToolDefinition::new("a", "Tool A"),
                ToolDefinition::new("b", "Tool B"),
            ]);

        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 2);
        for (name, tool) in &tools {
            assert_eq!(&tool.tool_def.name, name);
            assert_eq!(tool.toolset_id.as_deref(), Some("ext"));
            assert_eq!(tool.max_retries, 7);
        }
    }

    #[tokio::test]
    async fn test_external_toolset_call_deferred() {
        let toolset =
            ExternalToolset::<()>::new().definition(ToolDefinition::new("api_call", "Call API"));

        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();
        let tool = tools.get("api_call").unwrap();

        let result = toolset
            .call_tool(
                "api_call",
                serde_json::json!({"endpoint": "/test"}),
                &ctx,
                tool,
            )
            .await;

        // One exhaustive match: any outcome other than CallDeferred fails loudly.
        match result {
            Err(ToolError::CallDeferred { tool_name, args }) => {
                assert_eq!(tool_name, "api_call");
                assert_eq!(args["endpoint"], "/test");
            }
            Ok(_) => panic!("expected ToolError::CallDeferred, got Ok"),
            Err(other) => panic!("expected ToolError::CallDeferred, got {other:?}"),
        }
    }
}