Skip to main content

agentik_sdk/types/
tools.rs

1//! Tool use types for function calling with Claude.
2//!
3//! This module provides comprehensive support for tool use (function calling),
4//! allowing Claude to interact with external functions and APIs.
5//!
6//! # Examples
7//!
8//! ```rust
//! use agentik_sdk::types::tools::{Tool, ToolBuilder, ToolChoice, ToolUse, ToolResult};
//! use serde_json::json;
//!
//! // Define a tool
//! let weather_tool = ToolBuilder::new("get_weather", "Get the current weather in a given location")
//!     .parameter("location", "string", "The city and state, e.g. San Francisco, CA")
//!     .required("location")
//!     .build();
17//!
18//! // Tool use request from Claude
19//! let tool_use = ToolUse {
20//!     id: "toolu_123".to_string(),
21//!     name: "get_weather".to_string(),
22//!     input: json!({"location": "San Francisco, CA"}),
23//! };
24//!
25//! // Tool result after execution
26//! let tool_result = ToolResult::success(
27//!     tool_use.id.clone(),
28//!     "The weather in San Francisco is 72°F and sunny"
29//! );
30//! ```
31
32use serde::{Deserialize, Serialize};
33use serde_json::{Map, Value};
34
35/// A tool definition for function calling.
36///
37/// Tools allow Claude to call external functions with structured inputs.
38/// Each tool has a name, description, and JSON schema for input validation.
39#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
40pub struct Tool {
41    /// The name of the tool. Must be unique within the request.
42    pub name: String,
43
44    /// A detailed description of what the tool does.
45    pub description: String,
46
47    /// JSON schema definition for the tool's input parameters.
48    pub input_schema: ToolInputSchema,
49}
50
51/// JSON schema for tool input parameters.
52///
53/// Defines the structure, types, and constraints for tool parameters.
54#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
55pub struct ToolInputSchema {
56    /// The schema type (always "object" for tool inputs).
57    #[serde(rename = "type")]
58    pub schema_type: String,
59
60    /// Property definitions for the input parameters.
61    pub properties: Map<String, Value>,
62
63    /// List of required parameter names.
64    #[serde(skip_serializing_if = "Vec::is_empty")]
65    pub required: Vec<String>,
66
67    /// Additional schema properties.
68    #[serde(flatten)]
69    pub additional: Map<String, Value>,
70}
71
72/// Tool choice strategy for controlling which tools Claude can use.
73///
74/// This determines how Claude selects and uses available tools.
75#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
76#[serde(tag = "type")]
77pub enum ToolChoice {
78    /// Let Claude automatically decide whether and which tools to use.
79    #[serde(rename = "auto")]
80    Auto,
81
82    /// Claude must use one of the available tools.
83    #[serde(rename = "any")]
84    Any,
85
86    /// Force Claude to use a specific tool.
87    #[serde(rename = "tool")]
88    Tool {
89        /// The name of the tool that must be used.
90        name: String,
91    },
92}
93
94/// A tool use request from Claude.
95///
96/// When Claude decides to use a tool, it returns this structure with
97/// the tool name and input parameters.
98#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
99pub struct ToolUse {
100    /// Unique identifier for this tool use request.
101    pub id: String,
102
103    /// The name of the tool to call.
104    pub name: String,
105
106    /// Input parameters for the tool call.
107    pub input: Value,
108}
109
110/// Result of a tool execution.
111///
112/// After executing a tool, return the result using this structure
113/// so Claude can incorporate it into its response.
114#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
115pub struct ToolResult {
116    /// The ID of the tool use request this result corresponds to.
117    pub tool_use_id: String,
118
119    /// The result content from the tool execution.
120    pub content: ToolResultContent,
121
122    /// Whether the tool execution was successful.
123    #[serde(skip_serializing_if = "Option::is_none")]
124    pub is_error: Option<bool>,
125}
126
127/// Content of a tool result.
128#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
129#[serde(untagged)]
130pub enum ToolResultContent {
131    /// Simple text result.
132    Text(String),
133
134    /// Structured JSON result.
135    Json(Value),
136
137    /// Multiple content blocks (text, images, etc.).
138    Blocks(Vec<ToolResultBlock>),
139}
140
141/// A content block in a tool result.
142#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
143#[serde(tag = "type")]
144pub enum ToolResultBlock {
145    /// Text content block.
146    #[serde(rename = "text")]
147    Text {
148        /// The text content.
149        text: String,
150    },
151
152    /// Image content block.
153    #[serde(rename = "image")]
154    Image {
155        /// Image source information.
156        source: ImageSource,
157    },
158}
159
160/// Image source for tool results.
161#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
162#[serde(tag = "type")]
163pub enum ImageSource {
164    /// Base64-encoded image data.
165    #[serde(rename = "base64")]
166    Base64 {
167        /// MIME type of the image.
168        media_type: String,
169        /// Base64-encoded image data.
170        data: String,
171    },
172}
173
174/// Builder for creating tool definitions.
175///
176/// Provides a fluent API for constructing tools with parameters and validation.
177#[derive(Debug, Clone)]
178pub struct ToolBuilder {
179    name: String,
180    description: String,
181    properties: Map<String, Value>,
182    required: Vec<String>,
183    additional: Map<String, Value>,
184}
185
186impl ToolBuilder {
187    /// Create a new tool builder.
188    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
189        Self {
190            name: name.into(),
191            description: description.into(),
192            properties: Map::new(),
193            required: Vec::new(),
194            additional: Map::new(),
195        }
196    }
197
198    /// Add a parameter to the tool.
199    ///
200    /// # Arguments
201    /// * `name` - Parameter name
202    /// * `param_type` - Parameter type (e.g., "string", "number", "boolean")
203    /// * `description` - Parameter description
204    pub fn parameter(
205        mut self,
206        name: impl Into<String>,
207        param_type: impl Into<String>,
208        description: impl Into<String>,
209    ) -> Self {
210        let param_name = name.into();
211        let param_schema = serde_json::json!({
212            "type": param_type.into(),
213            "description": description.into()
214        });
215        self.properties.insert(param_name, param_schema);
216        self
217    }
218
219    /// Add an enum parameter with specific allowed values.
220    pub fn enum_parameter(
221        mut self,
222        name: impl Into<String>,
223        description: impl Into<String>,
224        values: Vec<String>,
225    ) -> Self {
226        let param_name = name.into();
227        let param_schema = serde_json::json!({
228            "type": "string",
229            "description": description.into(),
230            "enum": values
231        });
232        self.properties.insert(param_name, param_schema);
233        self
234    }
235
236    /// Add an array parameter.
237    pub fn array_parameter(
238        mut self,
239        name: impl Into<String>,
240        description: impl Into<String>,
241        item_type: impl Into<String>,
242    ) -> Self {
243        let param_name = name.into();
244        let param_schema = serde_json::json!({
245            "type": "array",
246            "description": description.into(),
247            "items": {
248                "type": item_type.into()
249            }
250        });
251        self.properties.insert(param_name, param_schema);
252        self
253    }
254
255    /// Add an object parameter with nested properties.
256    pub fn object_parameter(
257        mut self,
258        name: impl Into<String>,
259        description: impl Into<String>,
260        properties: Map<String, Value>,
261    ) -> Self {
262        let param_name = name.into();
263        let param_schema = serde_json::json!({
264            "type": "object",
265            "description": description.into(),
266            "properties": properties
267        });
268        self.properties.insert(param_name, param_schema);
269        self
270    }
271
272    /// Mark a parameter as required.
273    pub fn required(mut self, name: impl Into<String>) -> Self {
274        let param_name = name.into();
275        if !self.required.contains(&param_name) {
276            self.required.push(param_name);
277        }
278        self
279    }
280
281    /// Add additional schema properties.
282    pub fn additional_property(mut self, key: impl Into<String>, value: Value) -> Self {
283        self.additional.insert(key.into(), value);
284        self
285    }
286
287    /// Build the tool definition.
288    pub fn build(self) -> Tool {
289        Tool {
290            name: self.name,
291            description: self.description,
292            input_schema: ToolInputSchema {
293                schema_type: "object".to_string(),
294                properties: self.properties,
295                required: self.required,
296                additional: self.additional,
297            },
298        }
299    }
300}
301
302impl Tool {
303    /// Create a new tool builder.
304    pub fn builder() -> ToolBuilder {
305        ToolBuilder {
306            name: String::new(),
307            description: String::new(),
308            properties: Map::new(),
309            required: Vec::new(),
310            additional: Map::new(),
311        }
312    }
313
314    /// Validate if the given input matches this tool's schema.
315    pub fn validate_input(&self, input: &Value) -> Result<(), ToolValidationError> {
316        // Basic validation - check required fields
317        if let Value::Object(input_obj) = input {
318            for required_field in &self.input_schema.required {
319                if !input_obj.contains_key(required_field) {
320                    return Err(ToolValidationError::MissingRequiredField {
321                        field: required_field.clone(),
322                        tool: self.name.clone(),
323                    });
324                }
325            }
326
327            // Check field types
328            for (field_name, field_value) in input_obj {
329                if let Some(property_schema) = self.input_schema.properties.get(field_name) {
330                    self.validate_field_type(field_name, field_value, property_schema)?;
331                }
332            }
333
334            Ok(())
335        } else {
336            Err(ToolValidationError::InvalidInputType {
337                expected: "object".to_string(),
338                actual: input.to_string(),
339                tool: self.name.clone(),
340            })
341        }
342    }
343
344    fn validate_field_type(
345        &self,
346        field_name: &str,
347        value: &Value,
348        schema: &Value,
349    ) -> Result<(), ToolValidationError> {
350        if let Some(expected_type) = schema.get("type").and_then(|t| t.as_str()) {
351            let actual_type = match value {
352                Value::Null => "null",
353                Value::Bool(_) => "boolean",
354                Value::Number(_) => "number",
355                Value::String(_) => "string",
356                Value::Array(_) => "array",
357                Value::Object(_) => "object",
358            };
359
360            if expected_type != actual_type {
361                return Err(ToolValidationError::InvalidFieldType {
362                    field: field_name.to_string(),
363                    expected: expected_type.to_string(),
364                    actual: actual_type.to_string(),
365                    tool: self.name.clone(),
366                });
367            }
368        }
369
370        Ok(())
371    }
372}
373
374impl ToolChoice {
375    /// Create an auto tool choice.
376    pub fn auto() -> Self {
377        Self::Auto
378    }
379
380    /// Create an any tool choice.
381    pub fn any() -> Self {
382        Self::Any
383    }
384
385    /// Create a specific tool choice.
386    pub fn tool(name: impl Into<String>) -> Self {
387        Self::Tool { name: name.into() }
388    }
389}
390
391impl ToolResult {
392    /// Create a successful tool result with text content.
393    pub fn success(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
394        Self {
395            tool_use_id: tool_use_id.into(),
396            content: ToolResultContent::Text(content.into()),
397            is_error: None,
398        }
399    }
400
401    /// Create a successful tool result with JSON content.
402    pub fn success_json(tool_use_id: impl Into<String>, content: Value) -> Self {
403        Self {
404            tool_use_id: tool_use_id.into(),
405            content: ToolResultContent::Json(content),
406            is_error: None,
407        }
408    }
409
410    /// Create an error tool result.
411    pub fn error(tool_use_id: impl Into<String>, error_message: impl Into<String>) -> Self {
412        Self {
413            tool_use_id: tool_use_id.into(),
414            content: ToolResultContent::Text(error_message.into()),
415            is_error: Some(true),
416        }
417    }
418
419    /// Create a tool result with multiple content blocks.
420    pub fn with_blocks(tool_use_id: impl Into<String>, blocks: Vec<ToolResultBlock>) -> Self {
421        Self {
422            tool_use_id: tool_use_id.into(),
423            content: ToolResultContent::Blocks(blocks),
424            is_error: None,
425        }
426    }
427}
428
429impl ToolResultBlock {
430    /// Create a text content block.
431    pub fn text(text: impl Into<String>) -> Self {
432        Self::Text { text: text.into() }
433    }
434
435    /// Create an image content block from base64 data.
436    pub fn image_base64(media_type: impl Into<String>, data: impl Into<String>) -> Self {
437        Self::Image {
438            source: ImageSource::Base64 {
439                media_type: media_type.into(),
440                data: data.into(),
441            },
442        }
443    }
444}
445
446/// Tool validation errors.
447#[derive(Debug, Clone, PartialEq, thiserror::Error)]
448pub enum ToolValidationError {
449    /// A required field is missing from the input.
450    #[error("Missing required field '{field}' for tool '{tool}'")]
451    MissingRequiredField { field: String, tool: String },
452
453    /// Invalid input type (expected object).
454    #[error("Invalid input type for tool '{tool}': expected {expected}, got {actual}")]
455    InvalidInputType {
456        expected: String,
457        actual: String,
458        tool: String,
459    },
460
461    /// Invalid field type.
462    #[error("Invalid type for field '{field}' in tool '{tool}': expected {expected}, got {actual}")]
463    InvalidFieldType {
464        field: String,
465        expected: String,
466        actual: String,
467        tool: String,
468    },
469}
470
471/// Server-side tools provided by Anthropic.
472///
473/// These tools are executed on Anthropic's servers and don't require
474/// client-side implementation.
475#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
476#[serde(tag = "type")]
477pub enum ServerTool {
478    /// Web search tool for retrieving current information.
479    #[serde(rename = "web_search_20250305")]
480    WebSearch {
481        /// Optional search parameters.
482        #[serde(skip_serializing_if = "Option::is_none")]
483        parameters: Option<WebSearchParameters>,
484    },
485}
486
487/// Parameters for the web search server tool.
488#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
489pub struct WebSearchParameters {
490    /// Maximum number of search results to return.
491    #[serde(skip_serializing_if = "Option::is_none")]
492    max_results: Option<u32>,
493
494    /// Search language preference.
495    #[serde(skip_serializing_if = "Option::is_none")]
496    language: Option<String>,
497
498    /// Geographic region for search results.
499    #[serde(skip_serializing_if = "Option::is_none")]
500    region: Option<String>,
501}
502
503impl ServerTool {
504    /// Create a web search tool with default parameters.
505    pub fn web_search() -> Self {
506        Self::WebSearch { parameters: None }
507    }
508
509    /// Create a web search tool with custom parameters.
510    pub fn web_search_with_params(parameters: WebSearchParameters) -> Self {
511        Self::WebSearch {
512            parameters: Some(parameters),
513        }
514    }
515}
516
517impl WebSearchParameters {
518    /// Create web search parameters with maximum results.
519    pub fn with_max_results(max_results: u32) -> Self {
520        Self {
521            max_results: Some(max_results),
522            language: None,
523            region: None,
524        }
525    }
526
527    /// Set the search language.
528    pub fn language(mut self, language: impl Into<String>) -> Self {
529        self.language = Some(language.into());
530        self
531    }
532
533    /// Set the search region.
534    pub fn region(mut self, region: impl Into<String>) -> Self {
535        self.region = Some(region.into());
536        self
537    }
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_tool_builder() {
        let formats = vec!["json".to_string(), "text".to_string()];
        let built = ToolBuilder::new("get_weather", "Get the current weather")
            .parameter("location", "string", "The location to get weather for")
            .parameter("unit", "string", "Temperature unit")
            .enum_parameter("format", "Response format", formats)
            .required("location")
            .build();

        assert_eq!(built.name, "get_weather");
        assert_eq!(built.description, "Get the current weather");
        assert_eq!(built.input_schema.required, vec!["location"]);
        assert_eq!(built.input_schema.properties.len(), 3);
    }

    #[test]
    fn test_tool_validation() {
        let schema_tool = ToolBuilder::new("test_tool", "Test tool")
            .parameter("required_field", "string", "Required field")
            .parameter("optional_field", "number", "Optional field")
            .required("required_field")
            .build();

        // (input, whether validation should succeed)
        let cases = [
            (json!({"required_field": "test", "optional_field": 42}), true),
            // Missing required field
            (json!({"optional_field": 42}), false),
            // Wrong type
            (json!({"required_field": 123}), false),
        ];

        for (input, should_pass) in &cases {
            assert_eq!(
                schema_tool.validate_input(input).is_ok(),
                *should_pass,
                "input: {}",
                input
            );
        }
    }

    #[test]
    fn test_tool_choice_serialization() {
        let expectations = [
            (ToolChoice::auto(), json!({"type": "auto"})),
            (
                ToolChoice::tool("get_weather"),
                json!({"type": "tool", "name": "get_weather"}),
            ),
        ];

        for (choice, expected) in &expectations {
            let encoded = serde_json::to_value(choice).unwrap();
            assert_eq!(&encoded, expected);
        }
    }

    #[test]
    fn test_tool_result_creation() {
        let ok = ToolResult::success("tool_123", "Success message");
        assert_eq!(ok.tool_use_id, "tool_123");
        assert_eq!(ok.is_error, None);

        let failed = ToolResult::error("tool_456", "Error message");
        assert_eq!(failed.tool_use_id, "tool_456");
        assert_eq!(failed.is_error, Some(true));

        let structured = ToolResult::success_json("tool_789", json!({"temperature": 72}));
        match structured.content {
            ToolResultContent::Json(value) => assert_eq!(value["temperature"], 72),
            _ => panic!("Expected JSON content"),
        }
    }

    #[test]
    fn test_server_tool_creation() {
        assert!(matches!(
            ServerTool::web_search(),
            ServerTool::WebSearch { parameters: None }
        ));

        let params = WebSearchParameters::with_max_results(10)
            .language("en")
            .region("US");

        match ServerTool::web_search_with_params(params) {
            ServerTool::WebSearch {
                parameters: Some(p),
            } => {
                assert_eq!(p.max_results, Some(10));
                assert_eq!(p.language.as_deref(), Some("en"));
                assert_eq!(p.region.as_deref(), Some("US"));
            }
            _ => panic!("Expected web search with parameters"),
        }
    }

    #[test]
    fn test_tool_serialization() {
        let original = ToolBuilder::new("calculate", "Perform mathematical calculations")
            .parameter(
                "expression",
                "string",
                "Mathematical expression to evaluate",
            )
            .required("expression")
            .build();

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Tool = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_tool_use_deserialization() {
        let payload = r#"
        {
            "id": "toolu_123456",
            "name": "get_weather",
            "input": {
                "location": "San Francisco, CA",
                "unit": "celsius"
            }
        }"#;

        let parsed: ToolUse = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.id, "toolu_123456");
        assert_eq!(parsed.name, "get_weather");
        assert_eq!(parsed.input["location"], "San Francisco, CA");
        assert_eq!(parsed.input["unit"], "celsius");
    }
}