struct_llm/
tool.rs

1//! Tool definition and parsing utilities
2
3use crate::{Error, Provider, Result, StructuredOutput};
4use serde::{Deserialize, Serialize};
5
6/// Tool definition for LLM function calling
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct ToolDefinition {
9    /// Name of the tool
10    pub name: String,
11
12    /// Description of what the tool does
13    pub description: String,
14
15    /// JSON Schema for the tool's parameters
16    pub parameters: serde_json::Value,
17}
18
19/// A tool call made by an LLM
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ToolCall {
22    /// Unique identifier for this tool call
23    pub id: String,
24
25    /// Name of the tool being called
26    pub name: String,
27
28    /// Arguments passed to the tool (JSON)
29    pub arguments: serde_json::Value,
30}
31
32/// Extract tool calls from an API response
33///
34/// This function parses the raw response text from an LLM API and extracts
35/// any tool calls that were made. The format varies by provider.
36pub fn extract_tool_calls(response: &str, provider: Provider) -> Result<Vec<ToolCall>> {
37    match provider {
38        Provider::OpenAI | Provider::Local => extract_openai_tool_calls(response),
39        Provider::Anthropic => extract_anthropic_tool_calls(response),
40    }
41}
42
43fn extract_openai_tool_calls(response: &str) -> Result<Vec<ToolCall>> {
44    #[derive(Deserialize)]
45    struct OpenAIResponse {
46        choices: Vec<OpenAIChoice>,
47    }
48
49    #[derive(Deserialize)]
50    struct OpenAIChoice {
51        message: OpenAIMessage,
52    }
53
54    #[derive(Deserialize)]
55    struct OpenAIMessage {
56        tool_calls: Option<Vec<OpenAIToolCall>>,
57    }
58
59    #[derive(Deserialize)]
60    struct OpenAIToolCall {
61        id: String,
62        function: OpenAIFunction,
63    }
64
65    #[derive(Deserialize)]
66    struct OpenAIFunction {
67        name: String,
68        arguments: String,
69    }
70
71    let parsed: OpenAIResponse = serde_json::from_str(response)?;
72
73    let choice = parsed
74        .choices
75        .first()
76        .ok_or_else(|| Error::InvalidResponseFormat("No choices in response".to_string()))?;
77
78    let tool_calls = match &choice.message.tool_calls {
79        Some(calls) => calls
80            .iter()
81            .map(|tc| {
82                let arguments: serde_json::Value =
83                    serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::json!({}));
84
85                ToolCall {
86                    id: tc.id.clone(),
87                    name: tc.function.name.clone(),
88                    arguments,
89                }
90            })
91            .collect(),
92        None => return Err(Error::NoToolCalls),
93    };
94
95    Ok(tool_calls)
96}
97
98fn extract_anthropic_tool_calls(response: &str) -> Result<Vec<ToolCall>> {
99    #[derive(Deserialize)]
100    struct AnthropicResponse {
101        content: Vec<AnthropicContent>,
102    }
103
104    #[derive(Deserialize)]
105    #[serde(tag = "type")]
106    enum AnthropicContent {
107        #[serde(rename = "tool_use")]
108        ToolUse {
109            id: String,
110            name: String,
111            input: serde_json::Value,
112        },
113        #[serde(other)]
114        Other,
115    }
116
117    let parsed: AnthropicResponse = serde_json::from_str(response)?;
118
119    let tool_calls: Vec<ToolCall> = parsed
120        .content
121        .into_iter()
122        .filter_map(|content| match content {
123            AnthropicContent::ToolUse { id, name, input } => Some(ToolCall {
124                id,
125                name,
126                arguments: input,
127            }),
128            AnthropicContent::Other => None,
129        })
130        .collect();
131
132    if tool_calls.is_empty() {
133        return Err(Error::NoToolCalls);
134    }
135
136    Ok(tool_calls)
137}
138
139/// Parse a tool call response into a structured type
140///
141/// This validates the tool call arguments against the expected schema
142/// and deserializes them into the target type.
143pub fn parse_tool_response<T: StructuredOutput>(tool_call: &ToolCall) -> Result<T> {
144    // Validate that the tool name matches
145    if tool_call.name != T::tool_name() {
146        return Err(Error::ToolMismatch(
147            tool_call.name.clone(),
148            T::tool_name().to_string(),
149        ));
150    }
151
152    // Validate arguments against the schema
153    let schema = T::json_schema();
154    crate::schema::validate(&tool_call.arguments, &schema)?;
155
156    // Deserialize the arguments
157    let result: T = serde_json::from_value(tool_call.arguments.clone())?;
158
159    Ok(result)
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_openai_tool_calls() {
        let response = r#"{
            "choices": [{
                "message": {
                    "tool_calls": [{
                        "id": "call_123",
                        "function": {
                            "name": "final_answer",
                            "arguments": "{\"response\": \"Hello\", \"confidence\": 0.95}"
                        }
                    }]
                }
            }]
        }"#;

        let tool_calls = extract_openai_tool_calls(response).unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_123");
        assert_eq!(tool_calls[0].name, "final_answer");
        assert_eq!(tool_calls[0].arguments["response"], "Hello");
    }

    #[test]
    fn test_extract_anthropic_tool_calls() {
        let response = r#"{
            "content": [{
                "type": "tool_use",
                "id": "call_123",
                "name": "final_answer",
                "input": {
                    "response": "Hello",
                    "confidence": 0.95
                }
            }]
        }"#;

        let tool_calls = extract_anthropic_tool_calls(response).unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_123");
        assert_eq!(tool_calls[0].name, "final_answer");
        assert_eq!(tool_calls[0].arguments["response"], "Hello");
    }

    #[test]
    fn test_no_tool_calls() {
        let response = r#"{
            "choices": [{
                "message": {
                    "content": "Just a regular response"
                }
            }]
        }"#;

        let result = extract_openai_tool_calls(response);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), Error::NoToolCalls));
    }

    #[test]
    fn test_openai_empty_choices() {
        let result = extract_openai_tool_calls(r#"{"choices": []}"#);
        assert!(matches!(result, Err(Error::InvalidResponseFormat(_))));
    }

    #[test]
    fn test_anthropic_skips_non_tool_content() {
        let response = r#"{
            "content": [
                {"type": "text", "text": "Let me answer that."},
                {
                    "type": "tool_use",
                    "id": "call_456",
                    "name": "final_answer",
                    "input": {"response": "Hi"}
                }
            ]
        }"#;

        let tool_calls = extract_anthropic_tool_calls(response).unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_456");
        assert_eq!(tool_calls[0].arguments["response"], "Hi");
    }

    #[test]
    fn test_anthropic_no_tool_calls() {
        let response = r#"{"content": [{"type": "text", "text": "Just a regular response"}]}"#;

        let result = extract_anthropic_tool_calls(response);
        assert!(matches!(result, Err(Error::NoToolCalls)));
    }

    #[test]
    fn test_extract_tool_calls_dispatches_by_provider() {
        let openai = r#"{"choices": [{"message": {"tool_calls": [{"id": "call_1", "function": {"name": "final_answer", "arguments": "{}"}}]}}]}"#;
        let anthropic = r#"{"content": [{"type": "tool_use", "id": "call_2", "name": "final_answer", "input": {}}]}"#;

        assert_eq!(extract_tool_calls(openai, Provider::OpenAI).unwrap()[0].id, "call_1");
        assert_eq!(extract_tool_calls(openai, Provider::Local).unwrap()[0].id, "call_1");
        assert_eq!(extract_tool_calls(anthropic, Provider::Anthropic).unwrap()[0].id, "call_2");
        // Payload shapes are provider-specific; a mismatched provider must fail, not guess.
        assert!(extract_tool_calls(anthropic, Provider::OpenAI).is_err());
    }
}
224}