1use std::str::FromStr;
2
3use serde::{Deserialize, Serialize};
4
5use crate::openai::errors::ConversionError;
6
7#[derive(Debug, PartialEq, Serialize, Deserialize)]
8#[serde(rename_all = "lowercase")]
9pub enum ComparisonOperator {
10 Eq,
11 Ne,
12 Gt,
13 Gte,
14 Lt,
15 Lte,
16}
17
18impl FromStr for ComparisonOperator {
19 type Err = ConversionError;
20
21 fn from_str(s: &str) -> Result<Self, Self::Err> {
22 match s {
23 "eq" => Ok(ComparisonOperator::Eq),
24 "ne" => Ok(ComparisonOperator::Ne),
25 "gt" => Ok(ComparisonOperator::Gt),
26 "gte" => Ok(ComparisonOperator::Gte),
27 "lt" => Ok(ComparisonOperator::Lt),
28 "lte" => Ok(ComparisonOperator::Lte),
29 _ => Err(ConversionError::FromStr(s.to_string())),
30 }
31 }
32}
33
34#[derive(Debug, PartialEq, Serialize, Deserialize)]
35#[serde(untagged)]
36pub enum FilterValue {
37 String(String),
38 Boolean(bool),
39 Number(f64),
40}
41
42impl FilterValue {
43 pub fn string(filter: impl Into<String>) -> Self {
44 Self::String(filter.into())
45 }
46
47 pub fn boolean(filter: bool) -> Self {
48 Self::Boolean(filter)
49 }
50
51 pub fn number(filter: f64) -> Self {
52 Self::Number(filter)
53 }
54}
55
56impl From<String> for FilterValue {
57 fn from(value: String) -> Self {
58 FilterValue::String(value)
59 }
60}
61
62impl From<&str> for FilterValue {
63 fn from(value: &str) -> Self {
64 FilterValue::String(value.to_string())
65 }
66}
67
68impl From<bool> for FilterValue {
69 fn from(value: bool) -> Self {
70 FilterValue::Boolean(value)
71 }
72}
73
74impl From<f64> for FilterValue {
75 fn from(value: f64) -> Self {
76 FilterValue::Number(value)
77 }
78}
79
80#[derive(Debug, PartialEq, Serialize, Deserialize)]
81pub struct ComparisonFilter {
82 key: String,
83 #[serde(rename = "type")]
84 type_field: ComparisonOperator,
85 value: FilterValue,
86}
87
88impl ComparisonFilter {
89 pub fn build<V: Into<FilterValue>>(
90 key: impl Into<String>,
91 comparison_operator: impl AsRef<str>,
92 value: V,
93 ) -> Self {
94 Self {
95 key: key.into(),
96 type_field: ComparisonOperator::from_str(comparison_operator.as_ref()).unwrap(),
97 value: value.into(),
98 }
99 }
100}
101
102#[derive(Debug, PartialEq, Serialize, Deserialize)]
103#[serde(rename_all = "lowercase")]
104pub enum CompoundOperator {
105 And,
106 Or,
107}
108
109impl FromStr for CompoundOperator {
110 type Err = ConversionError;
111
112 fn from_str(s: &str) -> Result<Self, Self::Err> {
113 match s {
114 "and" => Ok(CompoundOperator::And),
115 "or" => Ok(CompoundOperator::Or),
116 _ => Err(ConversionError::FromStr(s.to_string())),
117 }
118 }
119}
120
121#[derive(Debug, PartialEq, Serialize, Deserialize)]
122pub struct CompoundFilter {
123 filters: Vec<FileSearchFilter>,
124 #[serde(rename = "type")]
125 type_field: CompoundOperator,
126}
127
128impl CompoundFilter {
129 pub fn build(filters: Vec<FileSearchFilter>, compound_operator: impl AsRef<str>) -> Self {
130 Self {
131 filters,
132 type_field: CompoundOperator::from_str(compound_operator.as_ref()).unwrap(),
133 }
134 }
135}
136
137#[derive(Debug, PartialEq, Serialize, Deserialize)]
138#[serde(untagged)]
139pub enum FileSearchFilter {
140 Comparison(ComparisonFilter),
141 Compound(CompoundFilter),
142}
143
144impl FileSearchFilter {
145 pub fn build_comparison_filter<V: Into<FilterValue>>(
146 key: impl Into<String>,
147 comparison_operator: impl AsRef<str>,
148 value: V,
149 ) -> Self {
150 Self::Comparison(ComparisonFilter::build(key, comparison_operator, value))
151 }
152
153 pub fn build_compound_filter(
154 filters: Vec<FileSearchFilter>,
155 compound_operator: impl AsRef<str>,
156 ) -> Self {
157 Self::Compound(CompoundFilter::build(filters, compound_operator))
158 }
159}
160
161#[derive(Debug, PartialEq, Serialize, Deserialize)]
162pub struct RankingOptions {
163 #[serde(skip_serializing_if = "Option::is_none")]
164 ranker: Option<String>,
165 #[serde(skip_serializing_if = "Option::is_none")]
166 score_threshold: Option<f32>,
167}
168
169impl RankingOptions {
170 pub fn new() -> Self {
171 Self {
172 ranker: None,
173 score_threshold: None,
174 }
175 }
176
177 pub fn ranker(mut self, value: impl Into<String>) -> Self {
178 self.ranker = Some(value.into());
179 self
180 }
181
182 pub fn score_threshold(mut self, value: f32) -> Self {
183 self.score_threshold = Some(value);
184 self
185 }
186}
187
188impl Default for RankingOptions {
189 fn default() -> Self {
190 Self::new()
191 }
192}
193
194#[derive(Debug, PartialEq, Serialize, Deserialize)]
195pub struct FileSearchTool {
196 #[serde(rename = "type")]
197 type_field: String,
198 vector_store_ids: Vec<String>,
199 #[serde(skip_serializing_if = "Option::is_none")]
200 filters: Option<FileSearchFilter>,
201 #[serde(skip_serializing_if = "Option::is_none")]
202 max_num_results: Option<u8>,
203 #[serde(skip_serializing_if = "Option::is_none")]
204 ranking_options: Option<RankingOptions>,
205}
206
207impl FileSearchTool {
208 pub fn new(vector_store_ids: Vec<impl Into<String>>) -> Self {
209 Self {
210 type_field: "file_search".to_string(),
211 vector_store_ids: vector_store_ids.into_iter().map(|id| id.into()).collect(),
212 filters: None,
213 max_num_results: None,
214 ranking_options: None,
215 }
216 }
217
218 pub fn filters(mut self, filters: FileSearchFilter) -> Self {
219 self.filters = Some(filters);
220 self
221 }
222
223 pub fn max_num_results(mut self, value: u8) -> Self {
224 self.max_num_results = Some(value);
225 self
226 }
227
228 pub fn ranking_options(mut self, value: RankingOptions) -> Self {
229 self.ranking_options = Some(value);
230 self
231 }
232}
233
234#[derive(Debug, PartialEq, Serialize, Deserialize)]
235pub struct FunctionTool {
236 name: String,
237 parameters: serde_json::Value,
238 strict: bool,
239 #[serde(rename = "type")]
240 type_field: String,
241 #[serde(skip_serializing_if = "Option::is_none")]
242 description: Option<String>,
243}
244
245impl FunctionTool {
246 pub fn new(name: impl Into<String>, parameters: serde_json::Value) -> Self {
247 Self {
248 name: name.into(),
249 parameters,
250 strict: true,
251 type_field: "function".to_string(),
252 description: None,
253 }
254 }
255
256 pub fn strict(mut self, value: bool) -> Self {
257 self.strict = value;
258 self
259 }
260
261 pub fn description(mut self, value: impl Into<String>) -> Self {
262 self.description = Some(value.into());
263 self
264 }
265}
266
267#[derive(Debug, PartialEq, Serialize, Deserialize)]
268pub struct ComputerUseTool {
269 display_height: f32,
270 display_width: f32,
271 environment: String,
272 #[serde(rename = "type")]
273 type_field: String,
274}
275
276impl ComputerUseTool {
277 pub fn new(display_height: f32, display_width: f32, environment: impl Into<String>) -> Self {
278 Self {
279 display_height,
280 display_width,
281 environment: environment.into(),
282 type_field: "computer_use_preview".to_string(),
283 }
284 }
285}
286
287#[derive(Debug, PartialEq, Serialize, Deserialize)]
288#[serde(rename_all = "lowercase")]
289pub enum SearchContextSize {
290 Low,
291 Medium,
292 High,
293}
294
295impl FromStr for SearchContextSize {
296 type Err = ConversionError;
297
298 fn from_str(s: &str) -> Result<Self, Self::Err> {
299 match s {
300 "low" => Ok(SearchContextSize::Low),
301 "medium" => Ok(SearchContextSize::Medium),
302 "high" => Ok(SearchContextSize::High),
303 _ => Err(ConversionError::FromStr(s.to_string())),
304 }
305 }
306}
307
308#[derive(Debug, PartialEq, Serialize, Deserialize)]
309pub struct UserLocation {
310 #[serde(rename = "type")]
311 type_field: String, #[serde(skip_serializing_if = "Option::is_none")]
313 city: Option<String>,
314 #[serde(skip_serializing_if = "Option::is_none")]
315 country: Option<String>, #[serde(skip_serializing_if = "Option::is_none")]
317 region: Option<String>,
318 #[serde(skip_serializing_if = "Option::is_none")]
319 timezone: Option<String>, }
321
322impl UserLocation {
323 pub fn new() -> Self {
324 Self {
325 type_field: "approximate".to_string(),
326 city: None,
327 country: None,
328 region: None,
329 timezone: None,
330 }
331 }
332
333 pub fn city(mut self, value: impl Into<String>) -> Self {
334 self.city = Some(value.into());
335 self
336 }
337
338 pub fn country(mut self, value: impl Into<String>) -> Self {
339 self.country = Some(value.into());
340 self
341 }
342
343 pub fn region(mut self, value: impl Into<String>) -> Self {
344 self.region = Some(value.into());
345 self
346 }
347
348 pub fn timezone(mut self, value: impl Into<String>) -> Self {
349 self.timezone = Some(value.into());
350 self
351 }
352}
353
354impl Default for UserLocation {
355 fn default() -> Self {
356 Self::new()
357 }
358}
359
360#[derive(Debug, PartialEq, Serialize, Deserialize)]
361pub struct WebSearchTool {
362 #[serde(rename = "type")]
363 type_field: String, #[serde(skip_serializing_if = "Option::is_none")]
365 search_context_size: Option<SearchContextSize>,
366 #[serde(skip_serializing_if = "Option::is_none")]
367 user_location: Option<UserLocation>,
368}
369
370impl WebSearchTool {
371 pub fn new(type_field: impl Into<String>) -> Self {
372 Self {
373 type_field: type_field.into(),
374 search_context_size: None,
375 user_location: None,
376 }
377 }
378
379 pub fn search_context_size(mut self, value: SearchContextSize) -> Self {
380 self.search_context_size = Some(value);
381 self
382 }
383
384 pub fn user_location(mut self, value: UserLocation) -> Self {
385 self.user_location = Some(value);
386 self
387 }
388}
389
390#[derive(Debug, PartialEq, Serialize, Deserialize)]
391#[serde(untagged)]
392pub enum Tool {
393 FileSearch(FileSearchTool),
394 Function(FunctionTool),
395 ComputerUse(ComputerUseTool),
396 WebSearch(WebSearchTool),
397}
398
399impl From<FileSearchTool> for Tool {
400 fn from(tool: FileSearchTool) -> Self {
401 Tool::FileSearch(tool)
402 }
403}
404
405impl TryFrom<Tool> for FileSearchTool {
406 type Error = ConversionError;
407
408 fn try_from(tool: Tool) -> Result<Self, Self::Error> {
409 match tool {
410 Tool::FileSearch(inner) => Ok(inner),
411 _ => Err(ConversionError::TryFrom("Tool".to_string())),
412 }
413 }
414}
415
416impl From<FunctionTool> for Tool {
417 fn from(tool: FunctionTool) -> Self {
418 Tool::Function(tool)
419 }
420}
421
422impl TryFrom<Tool> for FunctionTool {
423 type Error = ConversionError;
424
425 fn try_from(tool: Tool) -> Result<Self, Self::Error> {
426 match tool {
427 Tool::Function(inner) => Ok(inner),
428 _ => Err(ConversionError::TryFrom("Tool".to_string())),
429 }
430 }
431}
432
433impl From<ComputerUseTool> for Tool {
434 fn from(tool: ComputerUseTool) -> Self {
435 Tool::ComputerUse(tool)
436 }
437}
438
439impl TryFrom<Tool> for ComputerUseTool {
440 type Error = ConversionError;
441
442 fn try_from(tool: Tool) -> Result<Self, Self::Error> {
443 match tool {
444 Tool::ComputerUse(inner) => Ok(inner),
445 _ => Err(ConversionError::TryFrom("Tool".to_string())),
446 }
447 }
448}
449
450impl From<WebSearchTool> for Tool {
451 fn from(tool: WebSearchTool) -> Self {
452 Tool::WebSearch(tool)
453 }
454}
455
456impl TryFrom<Tool> for WebSearchTool {
457 type Error = ConversionError;
458
459 fn try_from(tool: Tool) -> Result<Self, Self::Error> {
460 match tool {
461 Tool::WebSearch(inner) => Ok(inner),
462 _ => Err(ConversionError::TryFrom("Tool".to_string())),
463 }
464 }
465}
466
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn it_creates_file_search_tool_with_comparison_operator() {
        let vector_store_ids = vec![
            "id_1".to_string(),
            "id_2".to_string(),
            "id_3".to_string(),
            "id_4".to_string(),
        ];
        let tool: Tool = FileSearchTool::new(vector_store_ids.clone()).into();
        let tool: Tool = FileSearchTool::try_from(tool)
            .unwrap()
            .ranking_options(
                RankingOptions::new()
                    .ranker("test_ranker")
                    .score_threshold(1.0),
            )
            .filters(FileSearchFilter::build_comparison_filter(
                "test_key",
                "eq",
                "test_value",
            ))
            .max_num_results(1)
            .into();

        let expected = Tool::FileSearch(FileSearchTool {
            type_field: "file_search".to_string(),
            vector_store_ids,
            ranking_options: Some(RankingOptions {
                ranker: Some("test_ranker".to_string()),
                score_threshold: Some(1.0),
            }),
            filters: Some(FileSearchFilter::Comparison(ComparisonFilter {
                key: "test_key".to_string(),
                type_field: ComparisonOperator::Eq,
                value: FilterValue::String("test_value".to_string()),
            })),
            max_num_results: Some(1),
        });

        assert_eq!(tool, expected);
    }

    #[test]
    fn it_creates_file_search_tool_with_compound_operator() {
        let vector_store_ids = vec![
            "id_1".to_string(),
            "id_2".to_string(),
            "id_3".to_string(),
            "id_4".to_string(),
        ];
        let tool: Tool = FileSearchTool::new(vector_store_ids.clone())
            .filters(FileSearchFilter::build_compound_filter(
                vec![FileSearchFilter::build_comparison_filter(
                    "test_key",
                    "eq",
                    "test_value",
                )],
                "and",
            ))
            .ranking_options(
                RankingOptions::new()
                    .ranker("test_ranker")
                    .score_threshold(1.0),
            )
            .into();

        let expected = Tool::FileSearch(FileSearchTool {
            type_field: "file_search".to_string(),
            vector_store_ids,
            ranking_options: Some(RankingOptions {
                ranker: Some("test_ranker".to_string()),
                score_threshold: Some(1.0),
            }),
            filters: Some(FileSearchFilter::Compound(CompoundFilter {
                type_field: CompoundOperator::And,
                filters: vec![FileSearchFilter::Comparison(ComparisonFilter {
                    key: "test_key".to_string(),
                    type_field: ComparisonOperator::Eq,
                    value: FilterValue::String("test_value".to_string()),
                })],
            })),
            max_num_results: None,
        });

        assert_eq!(tool, expected);
    }

    #[test]
    fn it_creates_function_tool() {
        let tool: Tool = FunctionTool::new(
            "function_tool_test",
            json!({
                "name": "test"
            }),
        )
        .description("this is description")
        .into();

        let expected = Tool::Function(FunctionTool {
            description: Some("this is description".to_string()),
            type_field: "function".to_string(),
            strict: true,
            parameters: json!({"name": "test"}),
            name: "function_tool_test".to_string(),
        });

        assert_eq!(tool, expected);
    }

    #[test]
    fn it_creates_computer_use_tool() {
        let tool: Tool = ComputerUseTool::new(64.0, 64.0, "test_environment").into();

        let expected = Tool::ComputerUse(ComputerUseTool {
            type_field: "computer_use_preview".to_string(),
            environment: "test_environment".to_string(),
            display_width: 64.0,
            display_height: 64.0,
        });

        assert_eq!(tool, expected);
    }

    #[test]
    fn it_creates_web_search_tool() {
        let tool: Tool = WebSearchTool::new("web_search_preview".to_string())
            .search_context_size(SearchContextSize::Low)
            .user_location(
                UserLocation::new()
                    .city("Istanbul")
                    .country("TR")
                    .region("Marmara")
                    .timezone("Europe/Istanbul"),
            )
            .into();

        let expected = Tool::WebSearch(WebSearchTool {
            user_location: Some(UserLocation {
                type_field: "approximate".to_string(),
                city: Some("Istanbul".to_string()),
                country: Some("TR".to_string()),
                region: Some("Marmara".to_string()),
                timezone: Some("Europe/Istanbul".to_string()),
            }),
            search_context_size: Some(SearchContextSize::Low),
            type_field: "web_search_preview".to_string(),
        });

        assert_eq!(tool, expected);
    }

    #[test]
    fn it_serializes_compound_filter() {
        let filter = FileSearchFilter::build_compound_filter(
            vec![
                FileSearchFilter::build_comparison_filter("year", "gte", 2020.0),
                FileSearchFilter::build_comparison_filter("draft", "ne", true),
            ],
            "or",
        );

        assert_eq!(
            serde_json::to_value(&filter).unwrap(),
            json!({
                "filters": [
                    {"key": "year", "type": "gte", "value": 2020.0},
                    {"key": "draft", "type": "ne", "value": true}
                ],
                "type": "or"
            })
        );
    }

    #[test]
    fn it_rejects_unknown_operator_strings() {
        assert!("between".parse::<ComparisonOperator>().is_err());
        assert!("xor".parse::<CompoundOperator>().is_err());
        assert!("extreme".parse::<SearchContextSize>().is_err());
        assert_eq!(
            "gte".parse::<ComparisonOperator>().ok(),
            Some(ComparisonOperator::Gte)
        );
    }

    #[test]
    fn test_json_values() {
        let tool: Tool = FileSearchTool::new(vec!["id_1", "id_2"])
            .filters(FileSearchFilter::build_comparison_filter(
                "test_key",
                "eq",
                "test_value".to_string(),
            ))
            .max_num_results(1)
            .ranking_options(
                RankingOptions::new()
                    .ranker("test_ranker")
                    .score_threshold(1.0),
            )
            .into();
        let json_value = serde_json::to_value(&tool).unwrap();

        // A comparison filter carries its operator under "type"; there is no separate
        // "comparison" marker. (A duplicate "type" key here was silently overwritten by
        // `json!`, so it only misdocumented the wire format.)
        assert_eq!(
            json_value,
            serde_json::json!({
                "type": "file_search",
                "vector_store_ids": ["id_1", "id_2"],
                "filters": {
                    "key": "test_key",
                    "type": "eq",
                    "value": "test_value"
                },
                "max_num_results": 1,
                "ranking_options": {
                    "ranker": "test_ranker",
                    "score_threshold": 1.0
                }
            })
        );

        let tool: Tool = FunctionTool::new("test", json!({}))
            .description("this is description")
            .into();
        let json_value = serde_json::to_value(&tool).unwrap();

        assert_eq!(
            json_value,
            serde_json::json!({
                "type": "function",
                "name": "test",
                "parameters": {},
                "strict": true,
                "description": "this is description"
            })
        );

        let tool: Tool = ComputerUseTool::new(64.0, 64.0, "test_environment").into();
        let json_value = serde_json::to_value(&tool).unwrap();

        assert_eq!(
            json_value,
            serde_json::json!({
                "type": "computer_use_preview",
                "environment": "test_environment",
                "display_width": 64.0,
                "display_height": 64.0
            })
        );

        let tool: Tool = WebSearchTool::new("web_search_preview".to_string())
            .search_context_size(SearchContextSize::Low)
            .user_location(
                UserLocation::new()
                    .city("Istanbul")
                    .country("TR")
                    .region("Marmara")
                    .timezone("Europe/Istanbul"),
            )
            .into();
        let json_value = serde_json::to_value(&tool).unwrap();

        assert_eq!(
            json_value,
            serde_json::json!({
                "type": "web_search_preview",
                "search_context_size": "low",
                "user_location": {
                    "type": "approximate",
                    "city": "Istanbul",
                    "country": "TR",
                    "region": "Marmara",
                    "timezone": "Europe/Istanbul"
                }
            })
        );

        let tool: Tool = WebSearchTool::new("web_search_preview_2025_03_11C".to_string())
            .search_context_size(SearchContextSize::Low)
            .user_location(
                UserLocation::new()
                    .city("Istanbul")
                    .country("TR")
                    .region("Marmara")
                    .timezone("Europe/Istanbul"),
            )
            .into();
        let json_value = serde_json::to_value(&tool).unwrap();

        assert_eq!(
            json_value,
            serde_json::json!({
                "type": "web_search_preview_2025_03_11C",
                "search_context_size": "low",
                "user_location": {
                    "type": "approximate",
                    "city": "Istanbul",
                    "country": "TR",
                    "region": "Marmara",
                    "timezone": "Europe/Istanbul"
                }
            })
        );
    }
}