1use crate::agent;
7use indexmap::IndexMap;
8use serde::{Deserialize, Serialize};
9use schemars::JsonSchema;
10use super::InputValue;
11
// Schema describing the shape of a function input value.
//
// Serialized without a tag (`#[serde(untagged)]`): when deserializing, serde
// tries the variants in declaration order and keeps the first that matches.
// Every variant except `AnyOf` has a required single-value `type` field
// ("object", "array", "string", ...), which is what tells them apart. Because
// `AnyOf` is listed first, a value carrying an `anyOf` key becomes `AnyOf`.
//
// Plain `//` comments are used on purpose: schemars copies `///` doc comments
// into the generated JSON Schema as descriptions.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(untagged)]
#[schemars(rename = "functions.expression.InputSchema")]
pub enum InputSchema {
    // Matches when at least one alternative matches.
    #[schemars(title = "AnyOf")]
    AnyOf(AnyOfInputSchema),
    // Map of named properties, each with its own schema.
    #[schemars(title = "Object")]
    Object(ObjectInputSchema),
    // Homogeneous list with optional length bounds.
    #[schemars(title = "Array")]
    Array(ArrayInputSchema),
    // String, optionally restricted to an enumerated set.
    #[schemars(title = "String")]
    String(StringInputSchema),
    // Integer with optional inclusive bounds.
    #[schemars(title = "Integer")]
    Integer(IntegerInputSchema),
    // Number (integer or float) with optional inclusive bounds.
    #[schemars(title = "Number")]
    Number(NumberInputSchema),
    // Boolean value.
    #[schemars(title = "Boolean")]
    Boolean(BooleanInputSchema),
    // Image rich content part.
    #[schemars(title = "Image")]
    Image(ImageInputSchema),
    // Audio rich content part.
    #[schemars(title = "Audio")]
    Audio(AudioInputSchema),
    // Video rich content part.
    #[schemars(title = "Video")]
    Video(VideoInputSchema),
    // File rich content part.
    #[schemars(title = "File")]
    File(FileInputSchema),
}
54
55impl InputSchema {
56 pub fn modalities(&self) -> Modalities {
58 match self {
59 InputSchema::Image(_) => Modalities { image: true, ..Modalities::default() },
60 InputSchema::Audio(_) => Modalities { audio: true, ..Modalities::default() },
61 InputSchema::Video(_) => Modalities { video: true, ..Modalities::default() },
62 InputSchema::File(_) => Modalities { file: true, ..Modalities::default() },
63 InputSchema::Object(s) => s.modalities(),
64 InputSchema::Array(s) => s.modalities(),
65 InputSchema::AnyOf(s) => s.modalities(),
66 InputSchema::String(_) | InputSchema::Integer(_)
67 | InputSchema::Number(_) | InputSchema::Boolean(_) => Modalities::default(),
68 }
69 }
70
71 pub fn validate_input(&self, input: &InputValue) -> bool {
73 match self {
74 InputSchema::Object(schema) => schema.validate_input(input),
75 InputSchema::Array(schema) => schema.validate_input(input),
76 InputSchema::String(schema) => schema.validate_input(input),
77 InputSchema::Integer(schema) => schema.validate_input(input),
78 InputSchema::Number(schema) => schema.validate_input(input),
79 InputSchema::Boolean(schema) => schema.validate_input(input),
80 InputSchema::Image(schema) => schema.validate_input(input),
81 InputSchema::Audio(schema) => schema.validate_input(input),
82 InputSchema::Video(schema) => schema.validate_input(input),
83 InputSchema::File(schema) => schema.validate_input(input),
84 InputSchema::AnyOf(schema) => schema.validate_input(input),
85 }
86 }
87}
88
89
/// Flags recording which kinds of media input a schema contains.
///
/// Produced by the `modalities()` methods on the schema types. Combine two
/// values with [`Modalities::merge`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Modalities {
    /// Set when an image input appears in the schema.
    pub image: bool,
    /// Set when an audio input appears in the schema.
    pub audio: bool,
    /// Set when a video input appears in the schema.
    pub video: bool,
    /// Set when a file input appears in the schema.
    pub file: bool,
}

impl Modalities {
    /// Returns the union of `self` and `other`: each flag in the result is
    /// set if it is set in either operand.
    pub fn merge(self, other: Self) -> Self {
        let Self {
            image,
            audio,
            video,
            file,
        } = other;
        Self {
            image: self.image || image,
            audio: self.audio || audio,
            video: self.video || video,
            file: self.file || file,
        }
    }
}
110
// Schema that accepts a value when at least one listed alternative accepts
// it. An empty list accepts nothing. Serialized as `{"anyOf": [...]}`
// (camelCase field names).
// Plain `//` comments: schemars would copy `///` into the generated schema.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.AnyOfInputSchema")]
pub struct AnyOfInputSchema {
    // Candidate schemas, tried in order by `validate_input`.
    pub any_of: Vec<InputSchema>,
}
119
120impl AnyOfInputSchema {
121 pub fn modalities(&self) -> Modalities {
123 self.any_of.iter().fold(Modalities::default(), |acc, s| acc.merge(s.modalities()))
124 }
125
126 pub fn validate_input(&self, input: &InputValue) -> bool {
128 self.any_of
129 .iter()
130 .any(|schema| schema.validate_input(input))
131 }
132}
133
// Discriminator for `ObjectInputSchema`: its single variant serializes as the
// string "object" (rename_all = "lowercase"); `#[default]` makes it the
// derived `Default`.
// Plain `//` comments: schemars would copy `///` into the generated schema.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.ObjectInputSchemaType")]
pub enum ObjectInputSchemaType {
    #[default]
    Object,
}

// Schema for an object input: named properties, each with its own schema,
// plus an optional list of required keys. Field names serialize in camelCase;
// optional fields that are `None` are omitted from the output.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.ObjectInputSchema")]
pub struct ObjectInputSchema {
    // Always "object". It has no serde default, so it is required when
    // deserializing, which is how the untagged `InputSchema` selects this variant.
    pub r#type: ObjectInputSchemaType,
    // Optional human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub description: Option<String>,
    // Property name -> schema; `IndexMap` preserves declaration order.
    // Generated for fuzzing via a project-provided `arbitrary` helper.
    #[arbitrary(with = crate::arbitrary_util::arbitrary_indexmap)]
    pub properties: IndexMap<String, InputSchema>,
    // Keys that must be present in an input object (see `validate_input`).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub required: Option<Vec<String>>,
}
160
161impl ObjectInputSchema {
162 pub fn modalities(&self) -> Modalities {
164 self.properties.values().fold(Modalities::default(), |acc, s| acc.merge(s.modalities()))
165 }
166
167 pub fn validate_input(&self, input: &InputValue) -> bool {
169 match input {
170 InputValue::Object(map) => {
171 let required = self.required.as_deref().unwrap_or(&[]);
172 self.properties.iter().all(|(key, schema)| {
173 match map.get(key) {
174 Some(value) => schema.validate_input(value),
175 None => !required.contains(key),
176 }
177 })
178 }
179 _ => false,
180 }
181 }
182}
183
// Discriminator for `ArrayInputSchema`: its single variant serializes as the
// string "array" (rename_all = "lowercase").
// Plain `//` comments: schemars would copy `///` into the generated schema.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.ArrayInputSchemaType")]
pub enum ArrayInputSchemaType {
    #[default]
    Array,
}

// Schema for an array input whose elements all share one schema, with
// optional inclusive bounds on the element count. Field names serialize in
// camelCase (`minItems`, `maxItems`); `None` options are omitted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.ArrayInputSchema")]
pub struct ArrayInputSchema {
    // Always "array"; required on deserialization (selects this variant of
    // the untagged `InputSchema`).
    pub r#type: ArrayInputSchemaType,
    // Optional human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub description: Option<String>,
    // Inclusive minimum number of elements.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    #[arbitrary(with = crate::arbitrary_util::arbitrary_option_u64)]
    pub min_items: Option<u64>,
    // Inclusive maximum number of elements.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    #[arbitrary(with = crate::arbitrary_util::arbitrary_option_u64)]
    pub max_items: Option<u64>,
    // Schema every element must satisfy; boxed because `InputSchema` is a
    // recursive type.
    pub items: Box<InputSchema>,
}
215
216impl ArrayInputSchema {
217 pub fn modalities(&self) -> Modalities {
219 self.items.modalities()
220 }
221
222 pub fn validate_input(&self, input: &InputValue) -> bool {
224 match input {
225 InputValue::Array(array) => {
226 if let Some(min_items) = self.min_items
227 && (array.len() as u64) < min_items
228 {
229 false
230 } else if let Some(max_items) = self.max_items
231 && (array.len() as u64) > max_items
232 {
233 false
234 } else {
235 array.iter().all(|item| self.items.validate_input(item))
236 }
237 }
238 _ => false,
239 }
240 }
241}
242
// Discriminator for `StringInputSchema`: its single variant serializes as the
// string "string" (rename_all = "lowercase").
// Plain `//` comments: schemars would copy `///` into the generated schema.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.StringInputSchemaType")]
pub enum StringInputSchemaType {
    #[default]
    String,
}

// Schema for a string input, optionally restricted to a fixed set of values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.StringInputSchema")]
pub struct StringInputSchema {
    // Always "string"; required on deserialization (selects this variant of
    // the untagged `InputSchema`).
    pub r#type: StringInputSchemaType,
    // Optional human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub description: Option<String>,
    // Allowed values (serialized as `enum`); when set, the input must equal
    // one of them exactly. An empty list therefore accepts no string.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub r#enum: Option<Vec<String>>,
}
266
267impl StringInputSchema {
268 pub fn validate_input(&self, input: &InputValue) -> bool {
270 match input {
271 InputValue::String(s) => {
272 if let Some(r#enum) = &self.r#enum {
273 r#enum.contains(s)
274 } else {
275 true
276 }
277 }
278 _ => false,
279 }
280 }
281}
282
// Discriminator for `IntegerInputSchema`: its single variant serializes as
// the string "integer" (rename_all = "lowercase").
// Plain `//` comments: schemars would copy `///` into the generated schema.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.IntegerInputSchemaType")]
pub enum IntegerInputSchemaType {
    #[default]
    Integer,
}

// Schema for an integer input with optional inclusive bounds.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.IntegerInputSchema")]
pub struct IntegerInputSchema {
    // Always "integer"; required on deserialization (selects this variant of
    // the untagged `InputSchema`).
    pub r#type: IntegerInputSchemaType,
    // Optional human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub description: Option<String>,
    // Inclusive lower bound.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    #[arbitrary(with = crate::arbitrary_util::arbitrary_option_i64)]
    pub minimum: Option<i64>,
    // Inclusive upper bound.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    #[arbitrary(with = crate::arbitrary_util::arbitrary_option_i64)]
    pub maximum: Option<i64>,
}
312
313impl IntegerInputSchema {
314 pub fn validate_input(&self, input: &InputValue) -> bool {
316 match input {
317 InputValue::Integer(integer) => {
318 if let Some(minimum) = self.minimum
319 && *integer < minimum
320 {
321 false
322 } else if let Some(maximum) = self.maximum
323 && *integer > maximum
324 {
325 false
326 } else {
327 true
328 }
329 }
330 InputValue::Number(number)
331 if number.is_finite() && number.fract() == 0.0 =>
332 {
333 let integer = *number as i64;
334 if let Some(minimum) = self.minimum
335 && integer < minimum
336 {
337 false
338 } else if let Some(maximum) = self.maximum
339 && integer > maximum
340 {
341 false
342 } else {
343 true
344 }
345 }
346 _ => false,
347 }
348 }
349}
350
// Discriminator for `NumberInputSchema`: its single variant serializes as the
// string "number" (rename_all = "lowercase").
// Plain `//` comments: schemars would copy `///` into the generated schema.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.NumberInputSchemaType")]
pub enum NumberInputSchemaType {
    #[default]
    Number,
}

// Schema for a numeric input (integer or float) with optional inclusive
// floating-point bounds.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.NumberInputSchema")]
pub struct NumberInputSchema {
    // Always "number"; required on deserialization (selects this variant of
    // the untagged `InputSchema`).
    pub r#type: NumberInputSchemaType,
    // Optional human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub description: Option<String>,
    // Inclusive lower bound.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    #[arbitrary(with = crate::arbitrary_util::arbitrary_option_f64)]
    pub minimum: Option<f64>,
    // Inclusive upper bound.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    #[arbitrary(with = crate::arbitrary_util::arbitrary_option_f64)]
    pub maximum: Option<f64>,
}
380
381impl NumberInputSchema {
382 pub fn validate_input(&self, input: &InputValue) -> bool {
384 match input {
385 InputValue::Integer(integer) => {
386 let number = *integer as f64;
387 if let Some(minimum) = self.minimum
388 && number < minimum
389 {
390 false
391 } else if let Some(maximum) = self.maximum
392 && number > maximum
393 {
394 false
395 } else {
396 true
397 }
398 }
399 InputValue::Number(number) => {
400 if let Some(minimum) = self.minimum
401 && *number < minimum
402 {
403 false
404 } else if let Some(maximum) = self.maximum
405 && *number > maximum
406 {
407 false
408 } else {
409 true
410 }
411 }
412 _ => false,
413 }
414 }
415}
416
// Discriminator for `BooleanInputSchema`: its single variant serializes as
// the string "boolean" (rename_all = "lowercase").
// Plain `//` comments: schemars would copy `///` into the generated schema.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.BooleanInputSchemaType")]
pub enum BooleanInputSchemaType {
    #[default]
    Boolean,
}

// Schema for a boolean input; it has no constraints beyond its type.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.BooleanInputSchema")]
pub struct BooleanInputSchema {
    // Always "boolean"; required on deserialization (selects this variant of
    // the untagged `InputSchema`).
    pub r#type: BooleanInputSchemaType,
    // Optional human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub description: Option<String>,
}
436
437impl BooleanInputSchema {
438 pub fn validate_input(&self, input: &InputValue) -> bool {
440 match input {
441 InputValue::Boolean(_) => true,
442 _ => false,
443 }
444 }
445}
446
// Discriminator for `ImageInputSchema`: its single variant serializes as the
// string "image" (rename_all = "lowercase").
// Plain `//` comments: schemars would copy `///` into the generated schema.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.ImageInputSchemaType")]
pub enum ImageInputSchemaType {
    #[default]
    Image,
}

// Schema for an image input. It matches an `ImageUrl` rich content part
// (see `validate_input`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.ImageInputSchema")]
pub struct ImageInputSchema {
    // Always "image"; required on deserialization (selects this variant of
    // the untagged `InputSchema`).
    pub r#type: ImageInputSchemaType,
    // Optional human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub description: Option<String>,
}
466
467impl ImageInputSchema {
468 pub fn validate_input(&self, input: &InputValue) -> bool {
470 match input {
471 InputValue::RichContentPart(
472 agent::completions::message::RichContentPart::ImageUrl {
473 ..
474 },
475 ) => true,
476 _ => false,
477 }
478 }
479}
480
// Discriminator for `AudioInputSchema`: its single variant serializes as the
// string "audio" (rename_all = "lowercase").
// Plain `//` comments: schemars would copy `///` into the generated schema.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.AudioInputSchemaType")]
pub enum AudioInputSchemaType {
    #[default]
    Audio,
}

// Schema for an audio input. It matches an `InputAudio` rich content part
// (see `validate_input`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.AudioInputSchema")]
pub struct AudioInputSchema {
    // Always "audio"; required on deserialization (selects this variant of
    // the untagged `InputSchema`).
    pub r#type: AudioInputSchemaType,
    // Optional human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub description: Option<String>,
}
500
501impl AudioInputSchema {
502 pub fn validate_input(&self, input: &InputValue) -> bool {
504 match input {
505 InputValue::RichContentPart(
506 agent::completions::message::RichContentPart::InputAudio {
507 ..
508 },
509 ) => true,
510 _ => false,
511 }
512 }
513}
514
// Discriminator for `VideoInputSchema`: its single variant serializes as the
// string "video" (rename_all = "lowercase").
// Plain `//` comments: schemars would copy `///` into the generated schema.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.VideoInputSchemaType")]
pub enum VideoInputSchemaType {
    #[default]
    Video,
}

// Schema for a video input. It matches either an `InputVideo` or a
// `VideoUrl` rich content part (see `validate_input`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.VideoInputSchema")]
pub struct VideoInputSchema {
    // Always "video"; required on deserialization (selects this variant of
    // the untagged `InputSchema`).
    pub r#type: VideoInputSchemaType,
    // Optional human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub description: Option<String>,
}
534
535impl VideoInputSchema {
536 pub fn validate_input(&self, input: &InputValue) -> bool {
538 match input {
539 InputValue::RichContentPart(
540 agent::completions::message::RichContentPart::InputVideo {
541 ..
542 },
543 ) => true,
544 InputValue::RichContentPart(
545 agent::completions::message::RichContentPart::VideoUrl {
546 ..
547 },
548 ) => true,
549 _ => false,
550 }
551 }
552}
553
// Discriminator for `FileInputSchema`: its single variant serializes as the
// string "file" (rename_all = "lowercase").
// Plain `//` comments: schemars would copy `///` into the generated schema.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.FileInputSchemaType")]
pub enum FileInputSchemaType {
    #[default]
    File,
}

// Schema for a file input. It matches a `File` rich content part
// (see `validate_input`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.FileInputSchema")]
pub struct FileInputSchema {
    // Always "file"; required on deserialization (selects this variant of
    // the untagged `InputSchema`).
    pub r#type: FileInputSchemaType,
    // Optional human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(extend("omitempty" = true))]
    pub description: Option<String>,
}
573
574impl FileInputSchema {
575 pub fn validate_input(&self, input: &InputValue) -> bool {
577 match input {
578 InputValue::RichContentPart(
579 agent::completions::message::RichContentPart::File { .. },
580 ) => true,
581 _ => false,
582 }
583 }
584}