1use super::InputValue;
7use crate::agent;
8use indexmap::IndexMap;
9use schemars::JsonSchema;
10use serde::{Deserialize, Serialize};
11
12#[derive(
17 Debug,
18 Clone,
19 PartialEq,
20 Serialize,
21 Deserialize,
22 JsonSchema,
23 arbitrary::Arbitrary,
24)]
25#[serde(untagged)]
26#[schemars(rename = "functions.expression.InputSchema")]
27pub enum InputSchema {
28 #[schemars(title = "AnyOf")]
30 AnyOf(AnyOfInputSchema),
31 #[schemars(title = "Object")]
33 Object(ObjectInputSchema),
34 #[schemars(title = "Array")]
36 Array(ArrayInputSchema),
37 #[schemars(title = "String")]
39 String(StringInputSchema),
40 #[schemars(title = "Integer")]
42 Integer(IntegerInputSchema),
43 #[schemars(title = "Number")]
45 Number(NumberInputSchema),
46 #[schemars(title = "Boolean")]
48 Boolean(BooleanInputSchema),
49 #[schemars(title = "Image")]
51 Image(ImageInputSchema),
52 #[schemars(title = "Audio")]
54 Audio(AudioInputSchema),
55 #[schemars(title = "Video")]
57 Video(VideoInputSchema),
58 #[schemars(title = "File")]
60 File(FileInputSchema),
61}
62
63impl InputSchema {
64 pub fn modalities(&self) -> Modalities {
66 match self {
67 InputSchema::Image(_) => Modalities {
68 image: true,
69 ..Modalities::default()
70 },
71 InputSchema::Audio(_) => Modalities {
72 audio: true,
73 ..Modalities::default()
74 },
75 InputSchema::Video(_) => Modalities {
76 video: true,
77 ..Modalities::default()
78 },
79 InputSchema::File(_) => Modalities {
80 file: true,
81 ..Modalities::default()
82 },
83 InputSchema::Object(s) => s.modalities(),
84 InputSchema::Array(s) => s.modalities(),
85 InputSchema::AnyOf(s) => s.modalities(),
86 InputSchema::String(_)
87 | InputSchema::Integer(_)
88 | InputSchema::Number(_)
89 | InputSchema::Boolean(_) => Modalities::default(),
90 }
91 }
92
93 pub fn validate_input(&self, input: &InputValue) -> bool {
95 match self {
96 InputSchema::Object(schema) => schema.validate_input(input),
97 InputSchema::Array(schema) => schema.validate_input(input),
98 InputSchema::String(schema) => schema.validate_input(input),
99 InputSchema::Integer(schema) => schema.validate_input(input),
100 InputSchema::Number(schema) => schema.validate_input(input),
101 InputSchema::Boolean(schema) => schema.validate_input(input),
102 InputSchema::Image(schema) => schema.validate_input(input),
103 InputSchema::Audio(schema) => schema.validate_input(input),
104 InputSchema::Video(schema) => schema.validate_input(input),
105 InputSchema::File(schema) => schema.validate_input(input),
106 InputSchema::AnyOf(schema) => schema.validate_input(input),
107 }
108 }
109}
110
/// Set of flags recording which media kinds are present.
///
/// The derived `Default` has every flag cleared.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Modalities {
    /// Image content is present.
    pub image: bool,
    /// Audio content is present.
    pub audio: bool,
    /// Video content is present.
    pub video: bool,
    /// File content is present.
    pub file: bool,
}

impl Modalities {
    /// Combines two flag sets: a flag is set in the result when it is set in
    /// either `self` or `other`.
    pub fn merge(self, other: Self) -> Self {
        let Self {
            image,
            audio,
            video,
            file,
        } = self;
        Self {
            image: image | other.image,
            audio: audio | other.audio,
            video: video | other.video,
            file: file | other.file,
        }
    }
}
131
132#[derive(
134 Debug,
135 Clone,
136 PartialEq,
137 Serialize,
138 Deserialize,
139 JsonSchema,
140 arbitrary::Arbitrary,
141)]
142#[serde(rename_all = "camelCase")]
143#[schemars(rename = "functions.expression.AnyOfInputSchema")]
144pub struct AnyOfInputSchema {
145 pub any_of: Vec<InputSchema>,
147}
148
149impl AnyOfInputSchema {
150 pub fn modalities(&self) -> Modalities {
152 self.any_of
153 .iter()
154 .fold(Modalities::default(), |acc, s| acc.merge(s.modalities()))
155 }
156
157 pub fn validate_input(&self, input: &InputValue) -> bool {
159 self.any_of
160 .iter()
161 .any(|schema| schema.validate_input(input))
162 }
163}
164
/// Type of the `type` field of [`ObjectInputSchema`].
///
/// It has a single variant which, because of `rename_all = "lowercase"`,
/// is serialized as the string `"object"`.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    PartialEq,
    Serialize,
    Deserialize,
    JsonSchema,
    arbitrary::Arbitrary,
)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.ObjectInputSchemaType")]
pub enum ObjectInputSchemaType {
    /// The only value; also the `Default`.
    #[default]
    Object,
}
182
183#[derive(
185 Debug,
186 Clone,
187 PartialEq,
188 Serialize,
189 Deserialize,
190 JsonSchema,
191 arbitrary::Arbitrary,
192)]
193#[serde(rename_all = "camelCase")]
194#[schemars(rename = "functions.expression.ObjectInputSchema")]
195pub struct ObjectInputSchema {
196 pub r#type: ObjectInputSchemaType,
197 #[serde(skip_serializing_if = "Option::is_none")]
199 #[schemars(extend("omitempty" = true))]
200 pub description: Option<String>,
201 #[arbitrary(with = crate::arbitrary_util::arbitrary_indexmap)]
203 pub properties: IndexMap<String, InputSchema>,
204 #[serde(skip_serializing_if = "Option::is_none")]
206 #[schemars(extend("omitempty" = true))]
207 pub required: Option<Vec<String>>,
208}
209
210impl ObjectInputSchema {
211 pub fn modalities(&self) -> Modalities {
213 self.properties
214 .values()
215 .fold(Modalities::default(), |acc, s| acc.merge(s.modalities()))
216 }
217
218 pub fn validate_input(&self, input: &InputValue) -> bool {
220 match input {
221 InputValue::Object(map) => {
222 let required = self.required.as_deref().unwrap_or(&[]);
223 self.properties
224 .iter()
225 .all(|(key, schema)| match map.get(key) {
226 Some(value) => schema.validate_input(value),
227 None => !required.contains(key),
228 })
229 }
230 _ => false,
231 }
232 }
233}
234
/// Type of the `type` field of [`ArrayInputSchema`].
///
/// It has a single variant which, because of `rename_all = "lowercase"`,
/// is serialized as the string `"array"`.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    PartialEq,
    Serialize,
    Deserialize,
    JsonSchema,
    arbitrary::Arbitrary,
)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.ArrayInputSchemaType")]
pub enum ArrayInputSchemaType {
    /// The only value; also the `Default`.
    #[default]
    Array,
}
252
253#[derive(
255 Debug,
256 Clone,
257 PartialEq,
258 Serialize,
259 Deserialize,
260 JsonSchema,
261 arbitrary::Arbitrary,
262)]
263#[serde(rename_all = "camelCase")]
264#[schemars(rename = "functions.expression.ArrayInputSchema")]
265pub struct ArrayInputSchema {
266 pub r#type: ArrayInputSchemaType,
267 #[serde(skip_serializing_if = "Option::is_none")]
269 #[schemars(extend("omitempty" = true))]
270 pub description: Option<String>,
271 #[serde(skip_serializing_if = "Option::is_none")]
273 #[schemars(extend("omitempty" = true))]
274 #[arbitrary(with = crate::arbitrary_util::arbitrary_option_u64)]
275 pub min_items: Option<u64>,
276 #[serde(skip_serializing_if = "Option::is_none")]
278 #[schemars(extend("omitempty" = true))]
279 #[arbitrary(with = crate::arbitrary_util::arbitrary_option_u64)]
280 pub max_items: Option<u64>,
281 pub items: Box<InputSchema>,
283}
284
285impl ArrayInputSchema {
286 pub fn modalities(&self) -> Modalities {
288 self.items.modalities()
289 }
290
291 pub fn validate_input(&self, input: &InputValue) -> bool {
293 match input {
294 InputValue::Array(array) => {
295 if let Some(min_items) = self.min_items
296 && (array.len() as u64) < min_items
297 {
298 false
299 } else if let Some(max_items) = self.max_items
300 && (array.len() as u64) > max_items
301 {
302 false
303 } else {
304 array.iter().all(|item| self.items.validate_input(item))
305 }
306 }
307 _ => false,
308 }
309 }
310}
311
/// Type of the `type` field of [`StringInputSchema`].
///
/// It has a single variant which, because of `rename_all = "lowercase"`,
/// is serialized as the string `"string"`.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    PartialEq,
    Serialize,
    Deserialize,
    JsonSchema,
    arbitrary::Arbitrary,
)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.StringInputSchemaType")]
pub enum StringInputSchemaType {
    /// The only value; also the `Default`.
    #[default]
    String,
}
329
330#[derive(
332 Debug,
333 Clone,
334 PartialEq,
335 Serialize,
336 Deserialize,
337 JsonSchema,
338 arbitrary::Arbitrary,
339)]
340#[serde(rename_all = "camelCase")]
341#[schemars(rename = "functions.expression.StringInputSchema")]
342pub struct StringInputSchema {
343 pub r#type: StringInputSchemaType,
344 #[serde(skip_serializing_if = "Option::is_none")]
346 #[schemars(extend("omitempty" = true))]
347 pub description: Option<String>,
348 #[serde(skip_serializing_if = "Option::is_none")]
350 #[schemars(extend("omitempty" = true))]
351 pub r#enum: Option<Vec<String>>,
352}
353
354impl StringInputSchema {
355 pub fn validate_input(&self, input: &InputValue) -> bool {
357 match input {
358 InputValue::String(s) => {
359 if let Some(r#enum) = &self.r#enum {
360 r#enum.contains(s)
361 } else {
362 true
363 }
364 }
365 _ => false,
366 }
367 }
368}
369
/// Type of the `type` field of [`IntegerInputSchema`].
///
/// It has a single variant which, because of `rename_all = "lowercase"`,
/// is serialized as the string `"integer"`.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    PartialEq,
    Serialize,
    Deserialize,
    JsonSchema,
    arbitrary::Arbitrary,
)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.IntegerInputSchemaType")]
pub enum IntegerInputSchemaType {
    /// The only value; also the `Default`.
    #[default]
    Integer,
}
387
388#[derive(
390 Debug,
391 Clone,
392 PartialEq,
393 Serialize,
394 Deserialize,
395 JsonSchema,
396 arbitrary::Arbitrary,
397)]
398#[serde(rename_all = "camelCase")]
399#[schemars(rename = "functions.expression.IntegerInputSchema")]
400pub struct IntegerInputSchema {
401 pub r#type: IntegerInputSchemaType,
402 #[serde(skip_serializing_if = "Option::is_none")]
404 #[schemars(extend("omitempty" = true))]
405 pub description: Option<String>,
406 #[serde(skip_serializing_if = "Option::is_none")]
408 #[schemars(extend("omitempty" = true))]
409 #[arbitrary(with = crate::arbitrary_util::arbitrary_option_i64)]
410 pub minimum: Option<i64>,
411 #[serde(skip_serializing_if = "Option::is_none")]
413 #[schemars(extend("omitempty" = true))]
414 #[arbitrary(with = crate::arbitrary_util::arbitrary_option_i64)]
415 pub maximum: Option<i64>,
416}
417
418impl IntegerInputSchema {
419 pub fn validate_input(&self, input: &InputValue) -> bool {
421 match input {
422 InputValue::Integer(integer) => {
423 if let Some(minimum) = self.minimum
424 && *integer < minimum
425 {
426 false
427 } else if let Some(maximum) = self.maximum
428 && *integer > maximum
429 {
430 false
431 } else {
432 true
433 }
434 }
435 InputValue::Number(number)
436 if number.is_finite() && number.fract() == 0.0 =>
437 {
438 let integer = *number as i64;
439 if let Some(minimum) = self.minimum
440 && integer < minimum
441 {
442 false
443 } else if let Some(maximum) = self.maximum
444 && integer > maximum
445 {
446 false
447 } else {
448 true
449 }
450 }
451 _ => false,
452 }
453 }
454}
455
/// Type of the `type` field of [`NumberInputSchema`].
///
/// It has a single variant which, because of `rename_all = "lowercase"`,
/// is serialized as the string `"number"`.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    PartialEq,
    Serialize,
    Deserialize,
    JsonSchema,
    arbitrary::Arbitrary,
)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.NumberInputSchemaType")]
pub enum NumberInputSchemaType {
    /// The only value; also the `Default`.
    #[default]
    Number,
}
473
474#[derive(
476 Debug,
477 Clone,
478 PartialEq,
479 Serialize,
480 Deserialize,
481 JsonSchema,
482 arbitrary::Arbitrary,
483)]
484#[serde(rename_all = "camelCase")]
485#[schemars(rename = "functions.expression.NumberInputSchema")]
486pub struct NumberInputSchema {
487 pub r#type: NumberInputSchemaType,
488 #[serde(skip_serializing_if = "Option::is_none")]
490 #[schemars(extend("omitempty" = true))]
491 pub description: Option<String>,
492 #[serde(skip_serializing_if = "Option::is_none")]
494 #[schemars(extend("omitempty" = true))]
495 #[arbitrary(with = crate::arbitrary_util::arbitrary_option_f64)]
496 pub minimum: Option<f64>,
497 #[serde(skip_serializing_if = "Option::is_none")]
499 #[schemars(extend("omitempty" = true))]
500 #[arbitrary(with = crate::arbitrary_util::arbitrary_option_f64)]
501 pub maximum: Option<f64>,
502}
503
504impl NumberInputSchema {
505 pub fn validate_input(&self, input: &InputValue) -> bool {
507 match input {
508 InputValue::Integer(integer) => {
509 let number = *integer as f64;
510 if let Some(minimum) = self.minimum
511 && number < minimum
512 {
513 false
514 } else if let Some(maximum) = self.maximum
515 && number > maximum
516 {
517 false
518 } else {
519 true
520 }
521 }
522 InputValue::Number(number) => {
523 if let Some(minimum) = self.minimum
524 && *number < minimum
525 {
526 false
527 } else if let Some(maximum) = self.maximum
528 && *number > maximum
529 {
530 false
531 } else {
532 true
533 }
534 }
535 _ => false,
536 }
537 }
538}
539
/// Type of the `type` field of [`BooleanInputSchema`].
///
/// It has a single variant which, because of `rename_all = "lowercase"`,
/// is serialized as the string `"boolean"`.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    PartialEq,
    Serialize,
    Deserialize,
    JsonSchema,
    arbitrary::Arbitrary,
)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.BooleanInputSchemaType")]
pub enum BooleanInputSchemaType {
    /// The only value; also the `Default`.
    #[default]
    Boolean,
}
557
558#[derive(
560 Debug,
561 Clone,
562 PartialEq,
563 Serialize,
564 Deserialize,
565 JsonSchema,
566 arbitrary::Arbitrary,
567)]
568#[serde(rename_all = "camelCase")]
569#[schemars(rename = "functions.expression.BooleanInputSchema")]
570pub struct BooleanInputSchema {
571 pub r#type: BooleanInputSchemaType,
572 #[serde(skip_serializing_if = "Option::is_none")]
574 #[schemars(extend("omitempty" = true))]
575 pub description: Option<String>,
576}
577
578impl BooleanInputSchema {
579 pub fn validate_input(&self, input: &InputValue) -> bool {
581 match input {
582 InputValue::Boolean(_) => true,
583 _ => false,
584 }
585 }
586}
587
/// Type of the `type` field of [`ImageInputSchema`].
///
/// It has a single variant which, because of `rename_all = "lowercase"`,
/// is serialized as the string `"image"`.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    PartialEq,
    Serialize,
    Deserialize,
    JsonSchema,
    arbitrary::Arbitrary,
)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.ImageInputSchemaType")]
pub enum ImageInputSchemaType {
    /// The only value; also the `Default`.
    #[default]
    Image,
}
605
606#[derive(
608 Debug,
609 Clone,
610 PartialEq,
611 Serialize,
612 Deserialize,
613 JsonSchema,
614 arbitrary::Arbitrary,
615)]
616#[serde(rename_all = "camelCase")]
617#[schemars(rename = "functions.expression.ImageInputSchema")]
618pub struct ImageInputSchema {
619 pub r#type: ImageInputSchemaType,
620 #[serde(skip_serializing_if = "Option::is_none")]
622 #[schemars(extend("omitempty" = true))]
623 pub description: Option<String>,
624}
625
626impl ImageInputSchema {
627 pub fn validate_input(&self, input: &InputValue) -> bool {
629 match input {
630 InputValue::RichContentPart(
631 agent::completions::message::RichContentPart::ImageUrl {
632 ..
633 },
634 ) => true,
635 _ => false,
636 }
637 }
638}
639
/// Type of the `type` field of [`AudioInputSchema`].
///
/// It has a single variant which, because of `rename_all = "lowercase"`,
/// is serialized as the string `"audio"`.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    PartialEq,
    Serialize,
    Deserialize,
    JsonSchema,
    arbitrary::Arbitrary,
)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.AudioInputSchemaType")]
pub enum AudioInputSchemaType {
    /// The only value; also the `Default`.
    #[default]
    Audio,
}
657
658#[derive(
660 Debug,
661 Clone,
662 PartialEq,
663 Serialize,
664 Deserialize,
665 JsonSchema,
666 arbitrary::Arbitrary,
667)]
668#[serde(rename_all = "camelCase")]
669#[schemars(rename = "functions.expression.AudioInputSchema")]
670pub struct AudioInputSchema {
671 pub r#type: AudioInputSchemaType,
672 #[serde(skip_serializing_if = "Option::is_none")]
674 #[schemars(extend("omitempty" = true))]
675 pub description: Option<String>,
676}
677
678impl AudioInputSchema {
679 pub fn validate_input(&self, input: &InputValue) -> bool {
681 match input {
682 InputValue::RichContentPart(
683 agent::completions::message::RichContentPart::InputAudio {
684 ..
685 },
686 ) => true,
687 _ => false,
688 }
689 }
690}
691
/// Type of the `type` field of [`VideoInputSchema`].
///
/// It has a single variant which, because of `rename_all = "lowercase"`,
/// is serialized as the string `"video"`.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    PartialEq,
    Serialize,
    Deserialize,
    JsonSchema,
    arbitrary::Arbitrary,
)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.VideoInputSchemaType")]
pub enum VideoInputSchemaType {
    /// The only value; also the `Default`.
    #[default]
    Video,
}
709
710#[derive(
712 Debug,
713 Clone,
714 PartialEq,
715 Serialize,
716 Deserialize,
717 JsonSchema,
718 arbitrary::Arbitrary,
719)]
720#[serde(rename_all = "camelCase")]
721#[schemars(rename = "functions.expression.VideoInputSchema")]
722pub struct VideoInputSchema {
723 pub r#type: VideoInputSchemaType,
724 #[serde(skip_serializing_if = "Option::is_none")]
726 #[schemars(extend("omitempty" = true))]
727 pub description: Option<String>,
728}
729
730impl VideoInputSchema {
731 pub fn validate_input(&self, input: &InputValue) -> bool {
733 match input {
734 InputValue::RichContentPart(
735 agent::completions::message::RichContentPart::InputVideo {
736 ..
737 },
738 ) => true,
739 InputValue::RichContentPart(
740 agent::completions::message::RichContentPart::VideoUrl {
741 ..
742 },
743 ) => true,
744 _ => false,
745 }
746 }
747}
748
/// Type of the `type` field of [`FileInputSchema`].
///
/// It has a single variant which, because of `rename_all = "lowercase"`,
/// is serialized as the string `"file"`.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    PartialEq,
    Serialize,
    Deserialize,
    JsonSchema,
    arbitrary::Arbitrary,
)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.FileInputSchemaType")]
pub enum FileInputSchemaType {
    /// The only value; also the `Default`.
    #[default]
    File,
}
766
767#[derive(
769 Debug,
770 Clone,
771 PartialEq,
772 Serialize,
773 Deserialize,
774 JsonSchema,
775 arbitrary::Arbitrary,
776)]
777#[serde(rename_all = "camelCase")]
778#[schemars(rename = "functions.expression.FileInputSchema")]
779pub struct FileInputSchema {
780 pub r#type: FileInputSchemaType,
781 #[serde(skip_serializing_if = "Option::is_none")]
783 #[schemars(extend("omitempty" = true))]
784 pub description: Option<String>,
785}
786
787impl FileInputSchema {
788 pub fn validate_input(&self, input: &InputValue) -> bool {
790 match input {
791 InputValue::RichContentPart(
792 agent::completions::message::RichContentPart::File { .. },
793 ) => true,
794 _ => false,
795 }
796 }
797}