1use serde::{Deserialize, Serialize};
33use serde_json::{Map, Value};
34
35#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
40pub struct Tool {
41 pub name: String,
43
44 pub description: String,
46
47 pub input_schema: ToolInputSchema,
49}
50
51#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
55pub struct ToolInputSchema {
56 #[serde(rename = "type")]
58 pub schema_type: String,
59
60 pub properties: Map<String, Value>,
62
63 #[serde(skip_serializing_if = "Vec::is_empty")]
65 pub required: Vec<String>,
66
67 #[serde(flatten)]
69 pub additional: Map<String, Value>,
70}
71
72#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
76#[serde(tag = "type")]
77pub enum ToolChoice {
78 #[serde(rename = "auto")]
80 Auto,
81
82 #[serde(rename = "any")]
84 Any,
85
86 #[serde(rename = "tool")]
88 Tool {
89 name: String,
91 },
92}
93
94#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
99pub struct ToolUse {
100 pub id: String,
102
103 pub name: String,
105
106 pub input: Value,
108}
109
110#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
115pub struct ToolResult {
116 pub tool_use_id: String,
118
119 pub content: ToolResultContent,
121
122 #[serde(skip_serializing_if = "Option::is_none")]
124 pub is_error: Option<bool>,
125}
126
127#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
129#[serde(untagged)]
130pub enum ToolResultContent {
131 Text(String),
133
134 Json(Value),
136
137 Blocks(Vec<ToolResultBlock>),
139}
140
141#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
143#[serde(tag = "type")]
144pub enum ToolResultBlock {
145 #[serde(rename = "text")]
147 Text {
148 text: String,
150 },
151
152 #[serde(rename = "image")]
154 Image {
155 source: ImageSource,
157 },
158}
159
160#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
162#[serde(tag = "type")]
163pub enum ImageSource {
164 #[serde(rename = "base64")]
166 Base64 {
167 media_type: String,
169 data: String,
171 },
172}
173
174#[derive(Debug, Clone)]
178pub struct ToolBuilder {
179 name: String,
180 description: String,
181 properties: Map<String, Value>,
182 required: Vec<String>,
183 additional: Map<String, Value>,
184}
185
186impl ToolBuilder {
187 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
189 Self {
190 name: name.into(),
191 description: description.into(),
192 properties: Map::new(),
193 required: Vec::new(),
194 additional: Map::new(),
195 }
196 }
197
198 pub fn parameter(
205 mut self,
206 name: impl Into<String>,
207 param_type: impl Into<String>,
208 description: impl Into<String>,
209 ) -> Self {
210 let param_name = name.into();
211 let param_schema = serde_json::json!({
212 "type": param_type.into(),
213 "description": description.into()
214 });
215 self.properties.insert(param_name, param_schema);
216 self
217 }
218
219 pub fn enum_parameter(
221 mut self,
222 name: impl Into<String>,
223 description: impl Into<String>,
224 values: Vec<String>,
225 ) -> Self {
226 let param_name = name.into();
227 let param_schema = serde_json::json!({
228 "type": "string",
229 "description": description.into(),
230 "enum": values
231 });
232 self.properties.insert(param_name, param_schema);
233 self
234 }
235
236 pub fn array_parameter(
238 mut self,
239 name: impl Into<String>,
240 description: impl Into<String>,
241 item_type: impl Into<String>,
242 ) -> Self {
243 let param_name = name.into();
244 let param_schema = serde_json::json!({
245 "type": "array",
246 "description": description.into(),
247 "items": {
248 "type": item_type.into()
249 }
250 });
251 self.properties.insert(param_name, param_schema);
252 self
253 }
254
255 pub fn object_parameter(
257 mut self,
258 name: impl Into<String>,
259 description: impl Into<String>,
260 properties: Map<String, Value>,
261 ) -> Self {
262 let param_name = name.into();
263 let param_schema = serde_json::json!({
264 "type": "object",
265 "description": description.into(),
266 "properties": properties
267 });
268 self.properties.insert(param_name, param_schema);
269 self
270 }
271
272 pub fn required(mut self, name: impl Into<String>) -> Self {
274 let param_name = name.into();
275 if !self.required.contains(¶m_name) {
276 self.required.push(param_name);
277 }
278 self
279 }
280
281 pub fn additional_property(mut self, key: impl Into<String>, value: Value) -> Self {
283 self.additional.insert(key.into(), value);
284 self
285 }
286
287 pub fn build(self) -> Tool {
289 Tool {
290 name: self.name,
291 description: self.description,
292 input_schema: ToolInputSchema {
293 schema_type: "object".to_string(),
294 properties: self.properties,
295 required: self.required,
296 additional: self.additional,
297 },
298 }
299 }
300}
301
302impl Tool {
303 pub fn builder() -> ToolBuilder {
305 ToolBuilder {
306 name: String::new(),
307 description: String::new(),
308 properties: Map::new(),
309 required: Vec::new(),
310 additional: Map::new(),
311 }
312 }
313
314 pub fn validate_input(&self, input: &Value) -> Result<(), ToolValidationError> {
316 if let Value::Object(input_obj) = input {
318 for required_field in &self.input_schema.required {
319 if !input_obj.contains_key(required_field) {
320 return Err(ToolValidationError::MissingRequiredField {
321 field: required_field.clone(),
322 tool: self.name.clone(),
323 });
324 }
325 }
326
327 for (field_name, field_value) in input_obj {
329 if let Some(property_schema) = self.input_schema.properties.get(field_name) {
330 self.validate_field_type(field_name, field_value, property_schema)?;
331 }
332 }
333
334 Ok(())
335 } else {
336 Err(ToolValidationError::InvalidInputType {
337 expected: "object".to_string(),
338 actual: input.to_string(),
339 tool: self.name.clone(),
340 })
341 }
342 }
343
344 fn validate_field_type(
345 &self,
346 field_name: &str,
347 value: &Value,
348 schema: &Value,
349 ) -> Result<(), ToolValidationError> {
350 if let Some(expected_type) = schema.get("type").and_then(|t| t.as_str()) {
351 let actual_type = match value {
352 Value::Null => "null",
353 Value::Bool(_) => "boolean",
354 Value::Number(_) => "number",
355 Value::String(_) => "string",
356 Value::Array(_) => "array",
357 Value::Object(_) => "object",
358 };
359
360 if expected_type != actual_type {
361 return Err(ToolValidationError::InvalidFieldType {
362 field: field_name.to_string(),
363 expected: expected_type.to_string(),
364 actual: actual_type.to_string(),
365 tool: self.name.clone(),
366 });
367 }
368 }
369
370 Ok(())
371 }
372}
373
374impl ToolChoice {
375 pub fn auto() -> Self {
377 Self::Auto
378 }
379
380 pub fn any() -> Self {
382 Self::Any
383 }
384
385 pub fn tool(name: impl Into<String>) -> Self {
387 Self::Tool { name: name.into() }
388 }
389}
390
391impl ToolResult {
392 pub fn success(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
394 Self {
395 tool_use_id: tool_use_id.into(),
396 content: ToolResultContent::Text(content.into()),
397 is_error: None,
398 }
399 }
400
401 pub fn success_json(tool_use_id: impl Into<String>, content: Value) -> Self {
403 Self {
404 tool_use_id: tool_use_id.into(),
405 content: ToolResultContent::Json(content),
406 is_error: None,
407 }
408 }
409
410 pub fn error(tool_use_id: impl Into<String>, error_message: impl Into<String>) -> Self {
412 Self {
413 tool_use_id: tool_use_id.into(),
414 content: ToolResultContent::Text(error_message.into()),
415 is_error: Some(true),
416 }
417 }
418
419 pub fn with_blocks(tool_use_id: impl Into<String>, blocks: Vec<ToolResultBlock>) -> Self {
421 Self {
422 tool_use_id: tool_use_id.into(),
423 content: ToolResultContent::Blocks(blocks),
424 is_error: None,
425 }
426 }
427}
428
429impl ToolResultBlock {
430 pub fn text(text: impl Into<String>) -> Self {
432 Self::Text { text: text.into() }
433 }
434
435 pub fn image_base64(media_type: impl Into<String>, data: impl Into<String>) -> Self {
437 Self::Image {
438 source: ImageSource::Base64 {
439 media_type: media_type.into(),
440 data: data.into(),
441 },
442 }
443 }
444}
445
446#[derive(Debug, Clone, PartialEq, thiserror::Error)]
448pub enum ToolValidationError {
449 #[error("Missing required field '{field}' for tool '{tool}'")]
451 MissingRequiredField { field: String, tool: String },
452
453 #[error("Invalid input type for tool '{tool}': expected {expected}, got {actual}")]
455 InvalidInputType {
456 expected: String,
457 actual: String,
458 tool: String,
459 },
460
461 #[error("Invalid type for field '{field}' in tool '{tool}': expected {expected}, got {actual}")]
463 InvalidFieldType {
464 field: String,
465 expected: String,
466 actual: String,
467 tool: String,
468 },
469}
470
471#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
476#[serde(tag = "type")]
477pub enum ServerTool {
478 #[serde(rename = "web_search_20250305")]
480 WebSearch {
481 #[serde(skip_serializing_if = "Option::is_none")]
483 parameters: Option<WebSearchParameters>,
484 },
485}
486
487#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
489pub struct WebSearchParameters {
490 #[serde(skip_serializing_if = "Option::is_none")]
492 max_results: Option<u32>,
493
494 #[serde(skip_serializing_if = "Option::is_none")]
496 language: Option<String>,
497
498 #[serde(skip_serializing_if = "Option::is_none")]
500 region: Option<String>,
501}
502
503impl ServerTool {
504 pub fn web_search() -> Self {
506 Self::WebSearch { parameters: None }
507 }
508
509 pub fn web_search_with_params(parameters: WebSearchParameters) -> Self {
511 Self::WebSearch {
512 parameters: Some(parameters),
513 }
514 }
515}
516
517impl WebSearchParameters {
518 pub fn with_max_results(max_results: u32) -> Self {
520 Self {
521 max_results: Some(max_results),
522 language: None,
523 region: None,
524 }
525 }
526
527 pub fn language(mut self, language: impl Into<String>) -> Self {
529 self.language = Some(language.into());
530 self
531 }
532
533 pub fn region(mut self, region: impl Into<String>) -> Self {
535 self.region = Some(region.into());
536 self
537 }
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_tool_builder() {
        let format_values = vec!["json".to_string(), "text".to_string()];
        let built = ToolBuilder::new("get_weather", "Get the current weather")
            .parameter("location", "string", "The location to get weather for")
            .parameter("unit", "string", "Temperature unit")
            .enum_parameter("format", "Response format", format_values)
            .required("location")
            .build();

        let Tool {
            name,
            description,
            input_schema,
        } = built;
        assert_eq!(name, "get_weather");
        assert_eq!(description, "Get the current weather");
        assert_eq!(input_schema.required, vec!["location"]);
        assert_eq!(input_schema.properties.len(), 3);
    }

    #[test]
    fn test_tool_validation() {
        let tool = ToolBuilder::new("test_tool", "Test tool")
            .parameter("required_field", "string", "Required field")
            .parameter("optional_field", "number", "Optional field")
            .required("required_field")
            .build();

        // (input, whether validation should succeed)
        let cases = vec![
            (json!({"required_field": "test", "optional_field": 42}), true),
            (json!({"optional_field": 42}), false),
            (json!({"required_field": 123}), false),
        ];
        for (input, should_pass) in cases {
            assert_eq!(
                tool.validate_input(&input).is_ok(),
                should_pass,
                "input: {}",
                input
            );
        }
    }

    #[test]
    fn test_tool_choice_serialization() {
        let cases = vec![
            (ToolChoice::auto(), json!({"type": "auto"})),
            (
                ToolChoice::tool("get_weather"),
                json!({"type": "tool", "name": "get_weather"}),
            ),
        ];
        for (choice, expected) in cases {
            let actual = serde_json::to_value(&choice).unwrap();
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn test_tool_result_creation() {
        let succeeded = ToolResult::success("tool_123", "Success message");
        assert_eq!(succeeded.tool_use_id, "tool_123");
        assert_eq!(succeeded.is_error, None);

        let failed = ToolResult::error("tool_456", "Error message");
        assert_eq!(failed.tool_use_id, "tool_456");
        assert_eq!(failed.is_error, Some(true));

        let structured = ToolResult::success_json("tool_789", json!({"temperature": 72}));
        match structured.content {
            ToolResultContent::Json(value) => assert_eq!(value["temperature"], 72),
            _ => panic!("Expected JSON content"),
        }
    }

    #[test]
    fn test_server_tool_creation() {
        assert!(matches!(
            ServerTool::web_search(),
            ServerTool::WebSearch { parameters: None }
        ));

        let params = WebSearchParameters::with_max_results(10)
            .language("en")
            .region("US");
        match ServerTool::web_search_with_params(params) {
            ServerTool::WebSearch {
                parameters: Some(p),
            } => {
                assert_eq!(p.max_results, Some(10));
                assert_eq!(p.language.as_deref(), Some("en"));
                assert_eq!(p.region.as_deref(), Some("US"));
            }
            _ => panic!("Expected web search with parameters"),
        }
    }

    #[test]
    fn test_tool_serialization() {
        let original = ToolBuilder::new("calculate", "Perform mathematical calculations")
            .parameter(
                "expression",
                "string",
                "Mathematical expression to evaluate",
            )
            .required("expression")
            .build();

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Tool = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_tool_use_deserialization() {
        let raw = r#"
            {
                "id": "toolu_123456",
                "name": "get_weather",
                "input": {
                    "location": "San Francisco, CA",
                    "unit": "celsius"
                }
            }"#;

        let parsed: ToolUse = serde_json::from_str(raw).unwrap();
        assert_eq!(parsed.id, "toolu_123456");
        assert_eq!(parsed.name, "get_weather");
        assert_eq!(parsed.input["location"], "San Francisco, CA");
        assert_eq!(parsed.input["unit"], "celsius");
    }
}