// ai_providers/openai/common/tool.rs

1use std::str::FromStr;
2
3use serde::{Deserialize, Serialize};
4
5use crate::openai::errors::ConversionError;
6
7#[derive(Debug, PartialEq, Serialize, Deserialize)]
8#[serde(rename_all = "lowercase")]
9pub enum ComparisonOperator {
10    Eq,
11    Ne,
12    Gt,
13    Gte,
14    Lt,
15    Lte,
16}
17
18impl FromStr for ComparisonOperator {
19    type Err = ConversionError;
20
21    fn from_str(s: &str) -> Result<Self, Self::Err> {
22        match s {
23            "eq" => Ok(ComparisonOperator::Eq),
24            "ne" => Ok(ComparisonOperator::Ne),
25            "gt" => Ok(ComparisonOperator::Gt),
26            "gte" => Ok(ComparisonOperator::Gte),
27            "lt" => Ok(ComparisonOperator::Lt),
28            "lte" => Ok(ComparisonOperator::Lte),
29            _ => Err(ConversionError::FromStr(s.to_string())),
30        }
31    }
32}
33
34#[derive(Debug, PartialEq, Serialize, Deserialize)]
35#[serde(untagged)]
36pub enum FilterValue {
37    String(String),
38    Boolean(bool),
39    Number(f64),
40}
41
42impl FilterValue {
43    pub fn string(filter: impl Into<String>) -> Self {
44        Self::String(filter.into())
45    }
46
47    pub fn boolean(filter: bool) -> Self {
48        Self::Boolean(filter)
49    }
50
51    pub fn number(filter: f64) -> Self {
52        Self::Number(filter)
53    }
54}
55
56impl From<String> for FilterValue {
57    fn from(value: String) -> Self {
58        FilterValue::String(value)
59    }
60}
61
62impl From<&str> for FilterValue {
63    fn from(value: &str) -> Self {
64        FilterValue::String(value.to_string())
65    }
66}
67
68impl From<bool> for FilterValue {
69    fn from(value: bool) -> Self {
70        FilterValue::Boolean(value)
71    }
72}
73
74impl From<f64> for FilterValue {
75    fn from(value: f64) -> Self {
76        FilterValue::Number(value)
77    }
78}
79
80#[derive(Debug, PartialEq, Serialize, Deserialize)]
81pub struct ComparisonFilter {
82    key: String,
83    #[serde(rename = "type")]
84    type_field: ComparisonOperator,
85    value: FilterValue,
86}
87
88impl ComparisonFilter {
89    pub fn build<V: Into<FilterValue>>(
90        key: impl Into<String>,
91        comparison_operator: impl AsRef<str>,
92        value: V,
93    ) -> Self {
94        Self {
95            key: key.into(),
96            type_field: ComparisonOperator::from_str(comparison_operator.as_ref()).unwrap(),
97            value: value.into(),
98        }
99    }
100}
101
102#[derive(Debug, PartialEq, Serialize, Deserialize)]
103#[serde(rename_all = "lowercase")]
104pub enum CompoundOperator {
105    And,
106    Or,
107}
108
109impl FromStr for CompoundOperator {
110    type Err = ConversionError;
111
112    fn from_str(s: &str) -> Result<Self, Self::Err> {
113        match s {
114            "and" => Ok(CompoundOperator::And),
115            "or" => Ok(CompoundOperator::Or),
116            _ => Err(ConversionError::FromStr(s.to_string())),
117        }
118    }
119}
120
121#[derive(Debug, PartialEq, Serialize, Deserialize)]
122pub struct CompoundFilter {
123    filters: Vec<FileSearchFilter>,
124    #[serde(rename = "type")]
125    type_field: CompoundOperator,
126}
127
128impl CompoundFilter {
129    pub fn build(filters: Vec<FileSearchFilter>, compound_operator: impl AsRef<str>) -> Self {
130        Self {
131            filters,
132            type_field: CompoundOperator::from_str(compound_operator.as_ref()).unwrap(),
133        }
134    }
135}
136
137#[derive(Debug, PartialEq, Serialize, Deserialize)]
138#[serde(untagged)]
139pub enum FileSearchFilter {
140    Comparison(ComparisonFilter),
141    Compound(CompoundFilter),
142}
143
144impl FileSearchFilter {
145    pub fn build_comparison_filter<V: Into<FilterValue>>(
146        key: impl Into<String>,
147        comparison_operator: impl AsRef<str>,
148        value: V,
149    ) -> Self {
150        Self::Comparison(ComparisonFilter::build(key, comparison_operator, value))
151    }
152
153    pub fn build_compound_filter(
154        filters: Vec<FileSearchFilter>,
155        compound_operator: impl AsRef<str>,
156    ) -> Self {
157        Self::Compound(CompoundFilter::build(filters, compound_operator))
158    }
159}
160
161#[derive(Debug, PartialEq, Serialize, Deserialize)]
162pub struct RankingOptions {
163    #[serde(skip_serializing_if = "Option::is_none")]
164    ranker: Option<String>,
165    #[serde(skip_serializing_if = "Option::is_none")]
166    score_threshold: Option<f32>,
167}
168
169impl RankingOptions {
170    pub fn new() -> Self {
171        Self {
172            ranker: None,
173            score_threshold: None,
174        }
175    }
176
177    pub fn ranker(mut self, value: impl Into<String>) -> Self {
178        self.ranker = Some(value.into());
179        self
180    }
181
182    pub fn score_threshold(mut self, value: f32) -> Self {
183        self.score_threshold = Some(value);
184        self
185    }
186}
187
188impl Default for RankingOptions {
189    fn default() -> Self {
190        Self::new()
191    }
192}
193
194#[derive(Debug, PartialEq, Serialize, Deserialize)]
195pub struct FileSearchTool {
196    #[serde(rename = "type")]
197    type_field: String,
198    vector_store_ids: Vec<String>,
199    #[serde(skip_serializing_if = "Option::is_none")]
200    filters: Option<FileSearchFilter>,
201    #[serde(skip_serializing_if = "Option::is_none")]
202    max_num_results: Option<u8>,
203    #[serde(skip_serializing_if = "Option::is_none")]
204    ranking_options: Option<RankingOptions>,
205}
206
207impl FileSearchTool {
208    pub fn new(vector_store_ids: Vec<impl Into<String>>) -> Self {
209        Self {
210            type_field: "file_search".to_string(),
211            vector_store_ids: vector_store_ids.into_iter().map(|id| id.into()).collect(),
212            filters: None,
213            max_num_results: None,
214            ranking_options: None,
215        }
216    }
217
218    pub fn filters(mut self, filters: FileSearchFilter) -> Self {
219        self.filters = Some(filters);
220        self
221    }
222
223    pub fn max_num_results(mut self, value: u8) -> Self {
224        self.max_num_results = Some(value);
225        self
226    }
227
228    pub fn ranking_options(mut self, value: RankingOptions) -> Self {
229        self.ranking_options = Some(value);
230        self
231    }
232}
233
234#[derive(Debug, PartialEq, Serialize, Deserialize)]
235pub struct FunctionTool {
236    name: String,
237    parameters: serde_json::Value,
238    strict: bool,
239    #[serde(rename = "type")]
240    type_field: String,
241    #[serde(skip_serializing_if = "Option::is_none")]
242    description: Option<String>,
243}
244
245impl FunctionTool {
246    pub fn new(name: impl Into<String>, parameters: serde_json::Value) -> Self {
247        Self {
248            name: name.into(),
249            parameters,
250            strict: true,
251            type_field: "function".to_string(),
252            description: None,
253        }
254    }
255
256    pub fn strict(mut self, value: bool) -> Self {
257        self.strict = value;
258        self
259    }
260
261    pub fn description(mut self, value: impl Into<String>) -> Self {
262        self.description = Some(value.into());
263        self
264    }
265}
266
267#[derive(Debug, PartialEq, Serialize, Deserialize)]
268pub struct ComputerUseTool {
269    display_height: f32,
270    display_width: f32,
271    environment: String,
272    #[serde(rename = "type")]
273    type_field: String,
274}
275
276impl ComputerUseTool {
277    pub fn new(display_height: f32, display_width: f32, environment: impl Into<String>) -> Self {
278        Self {
279            display_height,
280            display_width,
281            environment: environment.into(),
282            type_field: "computer_use_preview".to_string(),
283        }
284    }
285}
286
287#[derive(Debug, PartialEq, Serialize, Deserialize)]
288#[serde(rename_all = "lowercase")]
289pub enum SearchContextSize {
290    Low,
291    Medium,
292    High,
293}
294
295impl FromStr for SearchContextSize {
296    type Err = ConversionError;
297
298    fn from_str(s: &str) -> Result<Self, Self::Err> {
299        match s {
300            "low" => Ok(SearchContextSize::Low),
301            "medium" => Ok(SearchContextSize::Medium),
302            "high" => Ok(SearchContextSize::High),
303            _ => Err(ConversionError::FromStr(s.to_string())),
304        }
305    }
306}
307
308#[derive(Debug, PartialEq, Serialize, Deserialize)]
309pub struct UserLocation {
310    #[serde(rename = "type")]
311    type_field: String, // NOTE: this is always "approximate" value
312    #[serde(skip_serializing_if = "Option::is_none")]
313    city: Option<String>,
314    #[serde(skip_serializing_if = "Option::is_none")]
315    country: Option<String>, // NOTE: this is ISO-3166 country code
316    #[serde(skip_serializing_if = "Option::is_none")]
317    region: Option<String>,
318    #[serde(skip_serializing_if = "Option::is_none")]
319    timezone: Option<String>, // NOTE: this is IANA timezone
320}
321
322impl UserLocation {
323    pub fn new() -> Self {
324        Self {
325            type_field: "approximate".to_string(),
326            city: None,
327            country: None,
328            region: None,
329            timezone: None,
330        }
331    }
332
333    pub fn city(mut self, value: impl Into<String>) -> Self {
334        self.city = Some(value.into());
335        self
336    }
337
338    pub fn country(mut self, value: impl Into<String>) -> Self {
339        self.country = Some(value.into());
340        self
341    }
342
343    pub fn region(mut self, value: impl Into<String>) -> Self {
344        self.region = Some(value.into());
345        self
346    }
347
348    pub fn timezone(mut self, value: impl Into<String>) -> Self {
349        self.timezone = Some(value.into());
350        self
351    }
352}
353
354impl Default for UserLocation {
355    fn default() -> Self {
356        Self::new()
357    }
358}
359
360#[derive(Debug, PartialEq, Serialize, Deserialize)]
361pub struct WebSearchTool {
362    #[serde(rename = "type")]
363    type_field: String, // NOTE: this is either web_search_preview or web_search_preview_2025_03_11C
364    #[serde(skip_serializing_if = "Option::is_none")]
365    search_context_size: Option<SearchContextSize>,
366    #[serde(skip_serializing_if = "Option::is_none")]
367    user_location: Option<UserLocation>,
368}
369
370impl WebSearchTool {
371    pub fn new(type_field: impl Into<String>) -> Self {
372        Self {
373            type_field: type_field.into(),
374            search_context_size: None,
375            user_location: None,
376        }
377    }
378
379    pub fn search_context_size(mut self, value: SearchContextSize) -> Self {
380        self.search_context_size = Some(value);
381        self
382    }
383
384    pub fn user_location(mut self, value: UserLocation) -> Self {
385        self.user_location = Some(value);
386        self
387    }
388}
389
390#[derive(Debug, PartialEq, Serialize, Deserialize)]
391#[serde(untagged)]
392pub enum Tool {
393    FileSearch(FileSearchTool),
394    Function(FunctionTool),
395    ComputerUse(ComputerUseTool),
396    WebSearch(WebSearchTool),
397}
398
399impl From<FileSearchTool> for Tool {
400    fn from(tool: FileSearchTool) -> Self {
401        Tool::FileSearch(tool)
402    }
403}
404
405impl TryFrom<Tool> for FileSearchTool {
406    type Error = ConversionError;
407
408    fn try_from(tool: Tool) -> Result<Self, Self::Error> {
409        match tool {
410            Tool::FileSearch(inner) => Ok(inner),
411            _ => Err(ConversionError::TryFrom("Tool".to_string())),
412        }
413    }
414}
415
416impl From<FunctionTool> for Tool {
417    fn from(tool: FunctionTool) -> Self {
418        Tool::Function(tool)
419    }
420}
421
422impl TryFrom<Tool> for FunctionTool {
423    type Error = ConversionError;
424
425    fn try_from(tool: Tool) -> Result<Self, Self::Error> {
426        match tool {
427            Tool::Function(inner) => Ok(inner),
428            _ => Err(ConversionError::TryFrom("Tool".to_string())),
429        }
430    }
431}
432
433impl From<ComputerUseTool> for Tool {
434    fn from(tool: ComputerUseTool) -> Self {
435        Tool::ComputerUse(tool)
436    }
437}
438
439impl TryFrom<Tool> for ComputerUseTool {
440    type Error = ConversionError;
441
442    fn try_from(tool: Tool) -> Result<Self, Self::Error> {
443        match tool {
444            Tool::ComputerUse(inner) => Ok(inner),
445            _ => Err(ConversionError::TryFrom("Tool".to_string())),
446        }
447    }
448}
449
450impl From<WebSearchTool> for Tool {
451    fn from(tool: WebSearchTool) -> Self {
452        Tool::WebSearch(tool)
453    }
454}
455
456impl TryFrom<Tool> for WebSearchTool {
457    type Error = ConversionError;
458
459    fn try_from(tool: Tool) -> Result<Self, Self::Error> {
460        match tool {
461            Tool::WebSearch(inner) => Ok(inner),
462            _ => Err(ConversionError::TryFrom("Tool".to_string())),
463        }
464    }
465}
466
467#[cfg(test)]
468mod tests {
469    use super::*;
470    use serde_json::json;
471
472    #[test]
473    fn it_creates_file_search_tool_with_comparison_operator() {
474        let vector_store_ids = vec![
475            "id_1".to_string(),
476            "id_2".to_string(),
477            "id_3".to_string(),
478            "id_4".to_string(),
479        ];
480        let tool: Tool = FileSearchTool::new(vector_store_ids.clone()).into();
481        let tool: Tool = FileSearchTool::try_from(tool)
482            .unwrap()
483            .ranking_options(
484                RankingOptions::new()
485                    .ranker("test_ranker")
486                    .score_threshold(1.0),
487            )
488            .filters(FileSearchFilter::build_comparison_filter(
489                "test_key",
490                "eq",
491                "test_value",
492            ))
493            .max_num_results(1)
494            .into();
495
496        let expected = Tool::FileSearch(FileSearchTool {
497            type_field: "file_search".to_string(),
498            vector_store_ids,
499            ranking_options: Some(RankingOptions {
500                ranker: Some("test_ranker".to_string()),
501                score_threshold: Some(1.0),
502            }),
503            filters: Some(FileSearchFilter::Comparison(ComparisonFilter {
504                key: "test_key".to_string(),
505                type_field: ComparisonOperator::Eq,
506                value: FilterValue::String("test_value".to_string()),
507            })),
508            max_num_results: Some(1),
509        });
510
511        assert_eq!(tool, expected);
512    }
513
514    #[test]
515    fn it_creates_file_search_tool_with_compound_operator() {
516        let vector_store_ids = vec![
517            "id_1".to_string(),
518            "id_2".to_string(),
519            "id_3".to_string(),
520            "id_4".to_string(),
521        ];
522        let tool: Tool = FileSearchTool::new(vector_store_ids.clone())
523            .filters(FileSearchFilter::build_compound_filter(
524                vec![FileSearchFilter::build_comparison_filter(
525                    "test_key",
526                    "eq",
527                    "test_value",
528                )],
529                "and",
530            ))
531            .ranking_options(
532                RankingOptions::new()
533                    .ranker("test_ranker")
534                    .score_threshold(1.0),
535            )
536            .into();
537
538        let expected = Tool::FileSearch(FileSearchTool {
539            type_field: "file_search".to_string(),
540            vector_store_ids,
541            ranking_options: Some(RankingOptions {
542                ranker: Some("test_ranker".to_string()),
543                score_threshold: Some(1.0),
544            }),
545            filters: Some(FileSearchFilter::Compound(CompoundFilter {
546                type_field: CompoundOperator::And,
547                filters: vec![FileSearchFilter::Comparison(ComparisonFilter {
548                    key: "test_key".to_string(),
549                    type_field: ComparisonOperator::Eq,
550                    value: FilterValue::String("test_value".to_string()),
551                })],
552            })),
553            max_num_results: None,
554        });
555
556        assert_eq!(tool, expected);
557    }
558
559    #[test]
560    fn it_creates_function_tool() {
561        let tool: Tool = FunctionTool::new(
562            "function_tool_test",
563            json!({
564                "name": "test"
565            }),
566        )
567        .description("this is description")
568        .into();
569
570        let expected = Tool::Function(FunctionTool {
571            description: Some("this is description".to_string()),
572            type_field: "function".to_string(),
573            strict: true,
574            parameters: json!({"name": "test"}),
575            name: "function_tool_test".to_string(),
576        });
577
578        assert_eq!(tool, expected);
579    }
580
581    #[test]
582    fn it_creates_computer_use_tool() {
583        let tool: Tool = ComputerUseTool::new(64.0, 64.0, "test_environment").into();
584
585        let expected = Tool::ComputerUse(ComputerUseTool {
586            type_field: "computer_use_preview".to_string(),
587            environment: "test_environment".to_string(),
588            display_width: 64.0,
589            display_height: 64.0,
590        });
591
592        assert_eq!(tool, expected);
593    }
594
595    #[test]
596    fn it_creates_web_search_tool() {
597        let tool: Tool = WebSearchTool::new("web_search_preview".to_string())
598            .search_context_size(SearchContextSize::Low)
599            .user_location(
600                UserLocation::new()
601                    .city("Istanbul")
602                    .country("TR")
603                    .region("Marmara")
604                    .timezone("Europe/Istanbul"),
605            )
606            .into();
607
608        let expected = Tool::WebSearch(WebSearchTool {
609            user_location: Some(UserLocation {
610                type_field: "approximate".to_string(),
611                city: Some("Istanbul".to_string()),
612                country: Some("TR".to_string()),
613                region: Some("Marmara".to_string()),
614                timezone: Some("Europe/Istanbul".to_string()),
615            }),
616            search_context_size: Some(SearchContextSize::Low),
617            type_field: "web_search_preview".to_string(),
618        });
619
620        assert_eq!(tool, expected);
621    }
622
623    // test the json values of the tool
624    #[test]
625    fn test_json_values() {
626        // FileSearchTool test
627        let tool: Tool = FileSearchTool::new(vec!["id_1", "id_2"])
628            .filters(FileSearchFilter::build_comparison_filter(
629                "test_key",
630                "eq",
631                "test_value".to_string(),
632            ))
633            .max_num_results(1)
634            .ranking_options(
635                RankingOptions::new()
636                    .ranker("test_ranker")
637                    .score_threshold(1.0),
638            )
639            .into();
640        let json_value = serde_json::to_value(&tool).unwrap();
641
642        assert_eq!(
643            json_value,
644            serde_json::json!({
645                "type": "file_search",
646                "vector_store_ids": ["id_1", "id_2"],
647                "filters": {
648                    "type": "comparison",
649                    "key": "test_key",
650                    "type": "eq",
651                    "value": "test_value"
652                },
653                "max_num_results": 1,
654                "ranking_options": {
655                    "ranker": "test_ranker",
656                    "score_threshold": 1.0
657                }
658            })
659        );
660
661        // FunctionTool test
662        let tool: Tool = FunctionTool::new("test", json!({}))
663            .description("this is description")
664            .into();
665        let json_value = serde_json::to_value(&tool).unwrap();
666
667        assert_eq!(
668            json_value,
669            serde_json::json!({
670                "type": "function",
671                "name": "test",
672                "parameters": {},
673                "strict": true,
674                "description": "this is description"
675            })
676        );
677
678        // ComputerUseTool test
679        let tool: Tool = ComputerUseTool::new(64.0, 64.0, "test_environment").into();
680        let json_value = serde_json::to_value(&tool).unwrap();
681
682        assert_eq!(
683            json_value,
684            serde_json::json!({
685                "type": "computer_use_preview",
686                "environment": "test_environment",
687                "display_width": 64.0,
688                "display_height": 64.0
689            })
690        );
691
692        // WebSearchTool test with web_search_preview
693        let tool: Tool = WebSearchTool::new("web_search_preview".to_string())
694            .search_context_size(SearchContextSize::Low)
695            .user_location(
696                UserLocation::new()
697                    .city("Istanbul")
698                    .country("TR")
699                    .region("Marmara")
700                    .timezone("Europe/Istanbul"),
701            )
702            .into();
703        let json_value = serde_json::to_value(&tool).unwrap();
704
705        assert_eq!(
706            json_value,
707            serde_json::json!({
708                "type": "web_search_preview",
709                "search_context_size": "low",
710                "user_location": {
711                    "type": "approximate",
712                    "city": "Istanbul",
713                    "country": "TR",
714                    "region": "Marmara",
715                    "timezone": "Europe/Istanbul"
716                }
717            })
718        );
719
720        // WebSearchTool test with web_search_preview_2025_03_11C
721        let tool: Tool = WebSearchTool::new("web_search_preview_2025_03_11C".to_string())
722            .search_context_size(SearchContextSize::Low)
723            .user_location(
724                UserLocation::new()
725                    .city("Istanbul")
726                    .country("TR")
727                    .region("Marmara")
728                    .timezone("Europe/Istanbul"),
729            )
730            .into();
731        let json_value = serde_json::to_value(&tool).unwrap();
732
733        assert_eq!(
734            json_value,
735            serde_json::json!({
736                "type": "web_search_preview_2025_03_11C",
737                "search_context_size": "low",
738                "user_location": {
739                    "type": "approximate",
740                    "city": "Istanbul",
741                    "country": "TR",
742                    "region": "Marmara",
743                    "timezone": "Europe/Istanbul"
744                }
745            })
746        );
747    }
748}