Skip to main content

sgr_agent/
union_schema.rs

//! Dynamic discriminated union schema builder — generates a `oneOf` JSON Schema
//! from tool definitions at runtime, and parses the model's structured reply back
//! into tool calls. Used by SgrAgent for structured output.
3
4use crate::tool::ToolDef;
5use crate::types::ToolCall;
6use serde_json::Value;
7
8/// Build a JSON Schema `oneOf` from tool definitions.
9/// Each variant has a `tool_name` const discriminator.
10pub fn build_action_schema(tools: &[ToolDef]) -> Value {
11    let variants: Vec<Value> = tools
12        .iter()
13        .map(|t| {
14            let mut properties = serde_json::Map::new();
15
16            // Discriminator
17            properties.insert(
18                "tool_name".to_string(),
19                serde_json::json!({ "type": "string", "const": t.name }),
20            );
21
22            // Merge tool parameters into properties
23            if let Some(props) = t.parameters.get("properties").and_then(|p| p.as_object()) {
24                for (k, v) in props {
25                    properties.insert(k.clone(), v.clone());
26                }
27            }
28
29            // Required fields: tool_name + tool's required
30            let mut required = vec![serde_json::json!("tool_name")];
31            if let Some(req) = t.parameters.get("required").and_then(|r| r.as_array()) {
32                required.extend(req.iter().cloned());
33            }
34
35            serde_json::json!({
36                "type": "object",
37                "properties": properties,
38                "required": required,
39            })
40        })
41        .collect();
42
43    serde_json::json!({
44        "type": "object",
45        "properties": {
46            "situation": { "type": "string", "description": "Current assessment" },
47            "task": {
48                "type": "array",
49                "items": { "type": "string" },
50                "description": "Reasoning steps"
51            },
52            "actions": {
53                "type": "array",
54                "items": { "oneOf": variants },
55                "description": "Tool calls to execute"
56            }
57        },
58        "required": ["situation", "task", "actions"]
59    })
60}
61
/// Keys under which Gemini may nest a call's arguments, e.g.
/// `{"tool_name": "bash", "args": {"command": "ls"}}`.
///
/// A lone key from this list whose value is an object is treated as a wrapper and unwrapped.
const WRAPPER_KEYS: &[&str] = &[
    "parameters",
    "params",
    "args",
    "arguments",
];
64
65/// Parse raw LLM output into tool calls using flexible_parser.
66/// Extracts `actions` array and maps each to a ToolCall.
67pub fn parse_action(raw: &str, _tools: &[ToolDef]) -> Result<(String, Vec<ToolCall>), ParseError> {
68    // Try to parse as JSON via flexible parser, fall back to direct serde
69    let value: Value = match crate::flexible_parser::parse_flexible::<Value>(raw) {
70        Ok(r) => r.value,
71        Err(_) => serde_json::from_str::<Value>(raw).map_err(|e| ParseError(e.to_string()))?,
72    };
73
74    let situation = match value.get("situation") {
75        Some(Value::String(s)) => s.clone(),
76        _ => String::new(),
77    };
78
79    let actions: Vec<Value> = match value.get("actions") {
80        Some(Value::Array(arr)) => arr.clone(),
81        _ => Vec::new(),
82    };
83
84    let mut tool_calls: Vec<ToolCall> = Vec::new();
85    for (i, action) in actions.into_iter().enumerate() {
86        let name = match action.get("tool_name") {
87            Some(Value::String(s)) => s.clone(),
88            _ => continue,
89        };
90
91        // Remove tool_name from args, unwrap "parameters" wrapper if present
92        let arguments = if let Value::Object(mut obj) = action {
93            obj.remove("tool_name");
94            // Gemini sometimes wraps args: {"parameters": {...}}, {"args": {...}}, etc.
95            // Unwrap only known wrapper keys that contain an object value.
96            if obj.len() == 1 {
97                let key = obj.keys().next().unwrap().clone();
98                if WRAPPER_KEYS.contains(&key.as_str()) && obj[&key].is_object() {
99                    obj.remove(&key).unwrap()
100                } else {
101                    Value::Object(obj)
102                }
103            } else {
104                Value::Object(obj)
105            }
106        } else {
107            action
108        };
109
110        tool_calls.push(ToolCall {
111            id: format!("call_{}", i),
112            name,
113            arguments,
114        });
115    }
116
117    Ok((situation, tool_calls))
118}
119
120/// Parse error for action extraction.
121#[derive(Debug, thiserror::Error)]
122#[error("{0}")]
123pub struct ParseError(pub String);
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::ToolDef;

    /// Two simple tools: `read_file(path)` and `bash(command)`, each with one required param.
    fn mock_tools() -> Vec<ToolDef> {
        vec![
            ToolDef {
                name: "read_file".into(),
                description: "Read a file".into(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "path": { "type": "string" }
                    },
                    "required": ["path"]
                }),
            },
            ToolDef {
                name: "bash".into(),
                description: "Run command".into(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "command": { "type": "string" }
                    },
                    "required": ["command"]
                }),
            },
        ]
    }

    #[test]
    fn build_schema_has_one_of() {
        let schema = build_action_schema(&mock_tools());
        let items = &schema["properties"]["actions"]["items"];
        let one_of = items["oneOf"].as_array().unwrap();
        assert_eq!(one_of.len(), 2);

        // First variant has tool_name const
        let first = &one_of[0];
        assert_eq!(first["properties"]["tool_name"]["const"], "read_file");
        assert!(first["properties"]["path"].is_object());
    }

    #[test]
    fn build_schema_has_situation_and_task() {
        let schema = build_action_schema(&mock_tools());
        assert!(schema["properties"]["situation"].is_object());
        assert!(schema["properties"]["task"].is_object());
        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&serde_json::json!("situation")));
    }

    #[test]
    fn build_schema_variant_requires_discriminator_and_tool_fields() {
        // Variants follow tool order; each requires tool_name plus the tool's own required list.
        let schema = build_action_schema(&mock_tools());
        let bash = &schema["properties"]["actions"]["items"]["oneOf"][1];
        assert_eq!(bash["properties"]["tool_name"]["const"], "bash");
        let required = bash["required"].as_array().unwrap();
        assert_eq!(required.len(), 2);
        assert!(required.contains(&serde_json::json!("tool_name")));
        assert!(required.contains(&serde_json::json!("command")));
    }

    #[test]
    fn parse_action_extracts_calls() {
        let raw = r#"{
            "situation": "need to read a file",
            "task": ["read main.rs"],
            "actions": [
                {"tool_name": "read_file", "path": "/src/main.rs"},
                {"tool_name": "bash", "command": "ls -la"}
            ]
        }"#;
        let (situation, calls) = parse_action(raw, &mock_tools()).unwrap();
        assert_eq!(situation, "need to read a file");
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "read_file");
        assert_eq!(calls[0].arguments["path"], "/src/main.rs");
        assert_eq!(calls[1].name, "bash");
        // tool_name should be stripped from args
        assert!(calls[0].arguments.get("tool_name").is_none());
    }

    #[test]
    fn parse_action_empty_actions() {
        let raw = r#"{"situation": "done", "task": [], "actions": []}"#;
        let (_, calls) = parse_action(raw, &mock_tools()).unwrap();
        assert!(calls.is_empty());
    }

    #[test]
    fn parse_action_markdown_wrapped() {
        let raw = "```json\n{\"situation\": \"ok\", \"task\": [], \"actions\": [{\"tool_name\": \"bash\", \"command\": \"pwd\"}]}\n```";
        let (_, calls) = parse_action(raw, &mock_tools()).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "bash");
    }

    #[test]
    fn parse_action_unwraps_parameters_wrapper() {
        // Gemini wraps args in {"parameters": {...}}
        let raw = r#"{"situation": "reading", "task": [], "actions": [
            {"tool_name": "read_file", "parameters": {"path": "/main.rs"}},
            {"tool_name": "bash", "params": {"command": "ls"}}
        ]}"#;
        let (_, calls) = parse_action(raw, &mock_tools()).unwrap();
        assert_eq!(calls[0].arguments["path"], "/main.rs");
        assert_eq!(calls[1].arguments["command"], "ls");
    }

    #[test]
    fn parse_action_keeps_single_real_arg() {
        // Single real arg should NOT be unwrapped
        let raw = r#"{"situation": "ok", "task": [], "actions": [
            {"tool_name": "bash", "command": "ls"}
        ]}"#;
        let (_, calls) = parse_action(raw, &mock_tools()).unwrap();
        assert_eq!(calls[0].arguments["command"], "ls");
    }

    #[test]
    fn parse_action_keeps_wrapper_key_with_non_object_value() {
        // A wrapper key is only unwrapped when its value is an object.
        let raw = r#"{"situation": "ok", "task": [], "actions": [
            {"tool_name": "bash", "args": "ls"}
        ]}"#;
        let (_, calls) = parse_action(raw, &mock_tools()).unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].arguments["args"], "ls");
    }

    #[test]
    fn parse_action_does_not_unwrap_when_other_keys_present() {
        // Unwrapping only applies when the wrapper is the sole remaining key.
        let raw = r#"{"situation": "ok", "task": [], "actions": [
            {"tool_name": "bash", "params": {"command": "ls"}, "cwd": "/tmp"}
        ]}"#;
        let (_, calls) = parse_action(raw, &mock_tools()).unwrap();
        assert_eq!(calls[0].arguments["params"]["command"], "ls");
        assert_eq!(calls[0].arguments["cwd"], "/tmp");
    }

    #[test]
    fn parse_action_ids_follow_action_index() {
        // Ids use the position in `actions`, so skipped entries leave gaps.
        let raw = r#"{"situation": "ok", "task": [], "actions": [
            "skipped",
            {"tool_name": "bash", "command": "ls"},
            {"tool_name": "read_file", "path": "/a"}
        ]}"#;
        let (_, calls) = parse_action(raw, &mock_tools()).unwrap();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].id, "call_1");
        assert_eq!(calls[1].id, "call_2");
    }

    #[test]
    fn parse_action_skips_non_object_and_missing_tool_name() {
        // Non-objects and missing tool_name are silently skipped
        let raw = r#"{"situation": "ok", "task": [], "actions": [
            "just a string",
            42,
            null,
            {"no_tool_name": true},
            {"tool_name": "bash", "command": "ls"}
        ]}"#;
        let (_, calls) = parse_action(raw, &mock_tools()).unwrap();
        // Only the valid one should survive
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "bash");
    }

    #[test]
    fn parse_action_missing_situation_defaults_empty() {
        let raw = r#"{"actions": [{"tool_name": "bash", "command": "ls"}]}"#;
        let (situation, calls) = parse_action(raw, &mock_tools()).unwrap();
        assert_eq!(situation, "");
        assert_eq!(calls.len(), 1);
    }

    #[test]
    fn parse_action_missing_actions_returns_empty() {
        let raw = r#"{"situation": "thinking"}"#;
        let (situation, calls) = parse_action(raw, &mock_tools()).unwrap();
        assert_eq!(situation, "thinking");
        assert!(calls.is_empty());
    }
}