// aster/agents/final_output_tool.rs

1use crate::agents::tool_execution::ToolCallResult;
2use crate::recipe::Response;
3use indoc::formatdoc;
4use rmcp::model::{CallToolRequestParam, Content, ErrorCode, ErrorData, Tool, ToolAnnotations};
5use serde_json::Value;
6use std::borrow::Cow;
7
/// Name under which the final-output tool is exposed to the model; `execute_tool_call`
/// only accepts calls addressed to this name.
pub const FINAL_OUTPUT_TOOL_NAME: &'static str = "recipe__final_output";

/// Instruction telling the model to call the final-output tool immediately.
/// NOTE(review): presumably injected as a continuation prompt when the model has not yet
/// produced its final output — the call site is outside this file; confirm.
pub const FINAL_OUTPUT_CONTINUATION_MESSAGE: &'static str =
    "You MUST call the `final_output` tool NOW with the final output for the user.";
11
12pub struct FinalOutputTool {
13    pub response: Response,
14    /// The final output collected for the user. It will be a single line string for easy script extraction from output.
15    pub final_output: Option<String>,
16}
17
18impl FinalOutputTool {
19    pub fn new(response: Response) -> Self {
20        if response.json_schema.is_none() {
21            panic!("Cannot create FinalOutputTool: json_schema is required");
22        }
23        let schema = response.json_schema.as_ref().unwrap();
24
25        if let Some(obj) = schema.as_object() {
26            if obj.is_empty() {
27                panic!("Cannot create FinalOutputTool: empty json_schema is not allowed");
28            }
29        }
30
31        jsonschema::meta::validate(schema).unwrap();
32        Self {
33            response,
34            final_output: None,
35        }
36    }
37
38    pub fn tool(&self) -> Tool {
39        let instructions = formatdoc! {r#"
40            The final_output tool collects the final output for the user and provides validation for structured JSON final output against a predefined schema.
41
42            This final_output tool MUST be called with the final output for the user.
43            
44            Purpose:
45            - Collects the final output for the user
46            - Ensures that final outputs conform to the expected JSON structure
47            - Provides clear validation feedback when outputs don't match the schema
48            
49            Usage:
50            - Call the `final_output` tool with your JSON final output passed as the argument.
51            
52            The expected JSON schema format is:
53
54            {}
55            
56            When validation fails, you'll receive:
57            - Specific validation errors
58            - The expected format
59        "#, serde_json::to_string_pretty(self.response.json_schema.as_ref().unwrap()).unwrap()};
60
61        Tool::new(
62            FINAL_OUTPUT_TOOL_NAME.to_string(),
63            instructions,
64            self.response
65                .json_schema
66                .as_ref()
67                .unwrap()
68                .as_object()
69                .unwrap()
70                .clone(),
71        )
72        .annotate(ToolAnnotations {
73            title: Some("Final Output".to_string()),
74            read_only_hint: Some(false),
75            destructive_hint: Some(false),
76            idempotent_hint: Some(true),
77            open_world_hint: Some(false),
78        })
79    }
80
81    pub fn system_prompt(&self) -> String {
82        formatdoc! {r#"
83            # Final Output Instructions
84
85            You MUST use the `final_output` tool to collect the final output for the user rather than providing the output directly in your response.
86            The final output MUST be a valid JSON object that is provided to the `final_output` tool when called and it must match the following schema:
87
88            {}
89
90            ----
91        "#, serde_json::to_string_pretty(self.response.json_schema.as_ref().unwrap()).unwrap()}
92    }
93
94    async fn validate_json_output(&self, output: &Value) -> Result<Value, String> {
95        let compiled_schema =
96            match jsonschema::validator_for(self.response.json_schema.as_ref().unwrap()) {
97                Ok(schema) => schema,
98                Err(e) => {
99                    return Err(format!("Internal error: Failed to compile schema: {}", e));
100                }
101            };
102
103        let validation_errors: Vec<String> = compiled_schema
104            .iter_errors(output)
105            .map(|error| format!("- {}: {}", error.instance_path, error))
106            .collect();
107
108        if validation_errors.is_empty() {
109            Ok(output.clone())
110        } else {
111            Err(format!(
112                "Validation failed:\n{}\n\nExpected format:\n{}\n\nPlease correct your output to match the expected JSON schema and try again.",
113                validation_errors.join("\n"),
114                serde_json::to_string_pretty(self.response.json_schema.as_ref().unwrap()).unwrap_or_else(|_| "Invalid schema".to_string())
115            ))
116        }
117    }
118
119    pub async fn execute_tool_call(&mut self, tool_call: CallToolRequestParam) -> ToolCallResult {
120        match tool_call.name.to_string().as_str() {
121            FINAL_OUTPUT_TOOL_NAME => {
122                let result = self.validate_json_output(&tool_call.arguments.into()).await;
123                match result {
124                    Ok(parsed_value) => {
125                        self.final_output = Some(Self::parsed_final_output_string(parsed_value));
126                        ToolCallResult::from(Ok(rmcp::model::CallToolResult {
127                            content: vec![Content::text(
128                                "Final output successfully collected.".to_string(),
129                            )],
130                            structured_content: None,
131                            is_error: Some(false),
132                            meta: None,
133                        }))
134                    }
135                    Err(error) => ToolCallResult::from(Err(ErrorData {
136                        code: ErrorCode::INVALID_PARAMS,
137                        message: Cow::from(error),
138                        data: None,
139                    })),
140                }
141            }
142            _ => ToolCallResult::from(Err(ErrorData {
143                code: ErrorCode::INVALID_REQUEST,
144                message: Cow::from(format!("Unknown tool: {}", tool_call.name)),
145                data: None,
146            })),
147        }
148    }
149
150    // Formats the parsed JSON as a single line string so its easy to extract from the output
151    fn parsed_final_output_string(parsed_json: Value) -> String {
152        serde_json::to_string(&parsed_json).unwrap()
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::recipe::Response;
    use rmcp::model::CallToolRequestParam;
    use rmcp::object;
    use serde_json::json;

    /// Schema with a required nested `user` object and a required `tags` string array.
    fn create_complex_test_schema() -> Value {
        json!({
            "type": "object",
            "properties": {
                "user": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string"},
                        "age": {"type": "number"}
                    },
                    "required": ["name", "age"]
                },
                "tags": {
                    "type": "array",
                    "items": {"type": "string"}
                }
            },
            "required": ["user", "tags"]
        })
    }

    /// Builds a `FinalOutputTool` over [`create_complex_test_schema`].
    fn create_complex_tool() -> FinalOutputTool {
        FinalOutputTool::new(Response {
            json_schema: Some(create_complex_test_schema()),
        })
    }

    #[test]
    #[should_panic(expected = "Cannot create FinalOutputTool: json_schema is required")]
    fn test_new_with_missing_schema() {
        let response = Response { json_schema: None };
        FinalOutputTool::new(response);
    }

    #[test]
    #[should_panic(expected = "Cannot create FinalOutputTool: empty json_schema is not allowed")]
    fn test_new_with_empty_schema() {
        let response = Response {
            json_schema: Some(json!({})),
        };
        FinalOutputTool::new(response);
    }

    #[test]
    #[should_panic]
    fn test_new_with_invalid_schema() {
        let response = Response {
            json_schema: Some(json!({
                "type": "invalid_type",
                "properties": {
                    "message": {
                        "type": "unknown_type"
                    }
                }
            })),
        };
        FinalOutputTool::new(response);
    }

    #[test]
    fn test_tool_definition_exposes_name_and_schema() {
        let schema = create_complex_test_schema();
        let tool = FinalOutputTool::new(Response {
            json_schema: Some(schema.clone()),
        });

        let definition = tool.tool();
        assert_eq!(definition.name, FINAL_OUTPUT_TOOL_NAME);
        assert_eq!(&*definition.input_schema, schema.as_object().unwrap());
    }

    #[test]
    fn test_system_prompt_embeds_schema() {
        let schema = create_complex_test_schema();
        let tool = FinalOutputTool::new(Response {
            json_schema: Some(schema.clone()),
        });

        let prompt = tool.system_prompt();
        assert!(prompt.starts_with("# Final Output Instructions"));
        assert!(prompt.contains(&serde_json::to_string_pretty(&schema).unwrap()));
    }

    #[tokio::test]
    async fn test_execute_tool_call_schema_validation_failure() {
        let response = Response {
            json_schema: Some(json!({
                "type": "object",
                "properties": {
                    "message": {
                        "type": "string"
                    },
                    "count": {
                        "type": "number"
                    }
                },
                "required": ["message", "count"]
            })),
        };

        let mut tool = FinalOutputTool::new(response);
        let tool_call = CallToolRequestParam {
            name: FINAL_OUTPUT_TOOL_NAME.into(),
            arguments: Some(object!({
                "message": "Hello"  // Missing required "count" field
            })),
        };

        let result = tool.execute_tool_call(tool_call).await;
        match result.result.await {
            Err(error) => assert!(error.to_string().contains("Validation failed")),
            Ok(_) => panic!("expected a schema validation error"),
        }
        assert!(tool.final_output.is_none());
    }

    #[tokio::test]
    async fn test_execute_tool_call_complex_valid_json() {
        let mut tool = create_complex_tool();
        let tool_call = CallToolRequestParam {
            name: FINAL_OUTPUT_TOOL_NAME.into(),
            arguments: Some(object!({
                "user": {
                    "name": "John",
                    "age": 30
                },
                "tags": ["developer", "rust"]
            })),
        };

        let result = tool.execute_tool_call(tool_call).await;
        let tool_result = result.result.await;
        assert!(tool_result.is_ok());
        assert!(tool.final_output.is_some());

        let final_output = tool.final_output.unwrap();
        assert!(serde_json::from_str::<Value>(&final_output).is_ok());
        assert!(!final_output.contains('\n'));
    }

    #[tokio::test]
    async fn test_execute_tool_call_missing_arguments_fails_validation() {
        let mut tool = create_complex_tool();
        let tool_call = CallToolRequestParam {
            name: FINAL_OUTPUT_TOOL_NAME.into(),
            arguments: None,
        };

        let result = tool.execute_tool_call(tool_call).await;
        match result.result.await {
            Err(error) => assert!(error.to_string().contains("Validation failed")),
            Ok(_) => panic!("expected missing arguments to fail validation"),
        }
        assert!(tool.final_output.is_none());
    }

    #[tokio::test]
    async fn test_execute_tool_call_unknown_tool_is_rejected() {
        let mut tool = create_complex_tool();
        let tool_call = CallToolRequestParam {
            name: "not_the_final_output_tool".into(),
            arguments: None,
        };

        let result = tool.execute_tool_call(tool_call).await;
        match result.result.await {
            Err(error) => assert!(error.to_string().contains("Unknown tool")),
            Ok(_) => panic!("expected an unknown-tool error"),
        }
        assert!(tool.final_output.is_none());
    }
}