serdes_ai_output/
toolset.rs1use async_trait::async_trait;
7use parking_lot::RwLock;
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use serde_json::Value as JsonValue;
11use serdes_ai_tools::{RunContext, ToolError, ToolReturn};
12use serdes_ai_toolsets::{AbstractToolset, ToolsetTool};
13use std::collections::HashMap;
14use std::marker::PhantomData;
15use std::sync::Arc;
16
17use crate::schema::OutputSchema;
18use crate::structured::StructuredOutputSchema;
19
20pub struct OutputToolset<T, Deps = ()>
41where
42 T: DeserializeOwned + Send + Sync + 'static,
43{
44 schema: StructuredOutputSchema<T>,
45 captured: Arc<RwLock<Option<T>>>,
46 _phantom: PhantomData<Deps>,
47}
48
49impl<T, Deps> OutputToolset<T, Deps>
50where
51 T: DeserializeOwned + Send + Sync + 'static,
52{
53 pub fn new(schema: StructuredOutputSchema<T>) -> Self {
55 Self {
56 schema,
57 captured: Arc::new(RwLock::new(None)),
58 _phantom: PhantomData,
59 }
60 }
61
62 #[must_use]
64 pub fn has_output(&self) -> bool {
65 self.captured.read().is_some()
66 }
67
68 pub fn take_output(&self) -> Option<T> {
70 self.captured.write().take()
71 }
72
73 pub fn get_output(&self) -> Option<T>
75 where
76 T: Clone,
77 {
78 self.captured.read().clone()
79 }
80
81 pub fn clear(&self) {
83 *self.captured.write() = None;
84 }
85
86 #[must_use]
88 pub fn schema(&self) -> &StructuredOutputSchema<T> {
89 &self.schema
90 }
91
92 #[must_use]
94 pub fn tool_name(&self) -> &str {
95 &self.schema.tool_name
96 }
97}
98
99impl<T, Deps> std::fmt::Debug for OutputToolset<T, Deps>
100where
101 T: DeserializeOwned + Send + Sync + 'static,
102{
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("OutputToolset")
105 .field("tool_name", &self.schema.tool_name)
106 .field("has_output", &self.has_output())
107 .finish()
108 }
109}
110
111#[async_trait]
112impl<T, Deps> AbstractToolset<Deps> for OutputToolset<T, Deps>
113where
114 T: DeserializeOwned + Serialize + Send + Sync + 'static,
115 Deps: Send + Sync + 'static,
116{
117 fn id(&self) -> Option<&str> {
118 Some("__output__")
119 }
120
121 fn type_name(&self) -> &'static str {
122 "OutputToolset"
123 }
124
125 async fn get_tools(
126 &self,
127 _ctx: &RunContext<Deps>,
128 ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
129 let defs = self.schema.tool_definitions();
130 let mut tools = HashMap::with_capacity(defs.len());
131
132 for def in defs {
133 let name = def.name.clone();
134 tools.insert(name, ToolsetTool::new(def).with_toolset_id("__output__"));
135 }
136
137 Ok(tools)
138 }
139
140 async fn call_tool(
141 &self,
142 name: &str,
143 args: JsonValue,
144 _ctx: &RunContext<Deps>,
145 _tool: &ToolsetTool,
146 ) -> Result<ToolReturn, ToolError> {
147 let output: T = self
149 .schema
150 .parse_tool_call(name, &args)
151 .map_err(|e| ToolError::execution_failed(e.to_string()))?;
152
153 *self.captured.write() = Some(output);
155
156 Ok(ToolReturn::json(args))
158 }
159}
160
/// An output value paired with the name of the tool it is associated with.
#[derive(Clone, Debug)]
pub struct OutputCaptured<T> {
    /// The captured value.
    pub value: T,
    /// Name of the tool this value is associated with.
    pub tool_name: String,
}

impl<T> OutputCaptured<T> {
    /// Pairs `value` with `tool_name`, converting the name into an owned `String`.
    pub fn new(value: T, tool_name: impl Into<String>) -> Self {
        let tool_name = tool_name.into();
        Self { value, tool_name }
    }

    /// Unwraps the value and discards the tool name.
    pub fn into_inner(self) -> T {
        let Self { value, .. } = self;
        value
    }
}
184
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use serdes_ai_tools::{ObjectJsonSchema, PropertySchema};

    #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
    struct TestOutput {
        message: String,
        count: i32,
    }

    /// Builds the output schema for `TestOutput`, with both fields required.
    fn test_schema() -> StructuredOutputSchema<TestOutput> {
        let json_schema = ObjectJsonSchema::new()
            .with_property(
                "message",
                PropertySchema::string("The message").build(),
                true,
            )
            .with_property("count", PropertySchema::integer("The count").build(), true);

        StructuredOutputSchema::new(json_schema)
    }

    /// Calls tool `name` on `toolset` with `args`, passing the `final_result`
    /// descriptor obtained from `get_tools` as a real caller would.
    async fn call_tool_named(
        toolset: &OutputToolset<TestOutput, ()>,
        name: &str,
        args: JsonValue,
    ) -> Result<ToolReturn, ToolError> {
        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();
        let tool = tools
            .get("final_result")
            .expect("output schema should expose `final_result`");
        toolset.call_tool(name, args, &ctx, tool).await
    }

    #[test]
    fn test_output_toolset_new() {
        let toolset: OutputToolset<TestOutput> = OutputToolset::new(test_schema());
        assert!(!toolset.has_output());
        assert_eq!(toolset.id(), Some("__output__"));
    }

    #[tokio::test]
    async fn test_output_toolset_get_tools() {
        let toolset: OutputToolset<TestOutput, ()> = OutputToolset::new(test_schema());
        let ctx = RunContext::minimal("test");

        let tools = toolset.get_tools(&ctx).await.unwrap();
        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("final_result"));
    }

    #[tokio::test]
    async fn test_output_toolset_call_and_capture() {
        let toolset: OutputToolset<TestOutput, ()> = OutputToolset::new(test_schema());
        let args = serde_json::json!({
            "message": "Hello, World!",
            "count": 42
        });

        call_tool_named(&toolset, "final_result", args)
            .await
            .expect("valid arguments should be accepted");

        assert!(toolset.has_output());
        let output = toolset.take_output().unwrap();
        assert_eq!(output.message, "Hello, World!");
        assert_eq!(output.count, 42);

        // Taking the output empties the slot.
        assert!(!toolset.has_output());
    }

    #[tokio::test]
    async fn test_output_toolset_get_output_does_not_consume() {
        let toolset: OutputToolset<TestOutput, ()> = OutputToolset::new(test_schema());
        let args = serde_json::json!({"message": "Peek", "count": 7});

        call_tool_named(&toolset, "final_result", args)
            .await
            .expect("valid arguments should be accepted");

        let expected = TestOutput {
            message: "Peek".to_string(),
            count: 7,
        };
        assert_eq!(toolset.get_output(), Some(expected.clone()));
        // `get_output` clones, so the value is still there to take.
        assert_eq!(toolset.take_output(), Some(expected));
        assert_eq!(toolset.take_output(), None);
    }

    #[tokio::test]
    async fn test_output_toolset_later_call_replaces_capture() {
        let toolset: OutputToolset<TestOutput, ()> = OutputToolset::new(test_schema());

        call_tool_named(
            &toolset,
            "final_result",
            serde_json::json!({"message": "first", "count": 1}),
        )
        .await
        .expect("first call should be accepted");
        call_tool_named(
            &toolset,
            "final_result",
            serde_json::json!({"message": "second", "count": 2}),
        )
        .await
        .expect("second call should be accepted");

        let output = toolset.take_output().unwrap();
        assert_eq!(output.message, "second");
        assert_eq!(output.count, 2);
    }

    #[tokio::test]
    async fn test_output_toolset_clear() {
        let toolset: OutputToolset<TestOutput, ()> = OutputToolset::new(test_schema());
        let args = serde_json::json!({
            "message": "Test",
            "count": 1
        });

        call_tool_named(&toolset, "final_result", args)
            .await
            .unwrap();

        assert!(toolset.has_output());
        toolset.clear();
        assert!(!toolset.has_output());
    }

    #[tokio::test]
    async fn test_output_toolset_wrong_tool_name() {
        let toolset: OutputToolset<TestOutput, ()> = OutputToolset::new(test_schema());
        let args = serde_json::json!({"message": "Test", "count": 1});

        let result = call_tool_named(&toolset, "wrong_name", args).await;

        assert!(result.is_err());
        // A rejected call must not leave anything captured.
        assert!(!toolset.has_output());
    }

    #[tokio::test]
    async fn test_output_toolset_rejects_mistyped_args() {
        let toolset: OutputToolset<TestOutput, ()> = OutputToolset::new(test_schema());
        // `count` must be an integer; a string cannot deserialize into `i32`.
        let args = serde_json::json!({"message": "Test", "count": "many"});

        let result = call_tool_named(&toolset, "final_result", args).await;

        assert!(result.is_err());
        assert!(!toolset.has_output());
    }

    #[test]
    fn test_output_toolset_debug() {
        let toolset: OutputToolset<TestOutput, ()> = OutputToolset::new(test_schema());
        let rendered = format!("{toolset:?}");
        assert!(rendered.starts_with("OutputToolset"));
        assert!(rendered.contains("has_output: false"));
    }

    #[test]
    fn test_output_captured() {
        let captured = OutputCaptured::new("test".to_string(), "my_tool");
        assert_eq!(captured.tool_name, "my_tool");
        assert_eq!(captured.into_inner(), "test");
    }
}