use crate::agent;
use indexmap::IndexMap;
use serde::{Deserialize, Serialize};
use schemars::JsonSchema;
use super::InputValue;
// Schema for a single function input value.
//
// `#[serde(untagged)]` means no variant tag is written. When deserializing, serde
// tries the variants from top to bottom and keeps the first one that matches, so
// the declaration order below is significant. `AnyOf` is recognized by its `anyOf`
// field; every other variant wraps a struct whose `type` field accepts exactly one
// lowercase string (`object`, `array`, `string`, ...).
//
// Plain `//` comments are used on purpose: schemars copies `///` doc comments into
// the generated JSON Schema, which would change its output.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(untagged)]
#[schemars(rename = "functions.expression.InputSchema")]
pub enum InputSchema {
// Union of alternative schemas.
#[schemars(title = "AnyOf")]
AnyOf(AnyOfInputSchema),
// Object with named properties.
#[schemars(title = "Object")]
Object(ObjectInputSchema),
// Array whose elements all share one item schema.
#[schemars(title = "Array")]
Array(ArrayInputSchema),
// String, optionally restricted to an enumerated set.
#[schemars(title = "String")]
String(StringInputSchema),
// Integer, optionally range-limited.
#[schemars(title = "Integer")]
Integer(IntegerInputSchema),
// Floating-point number, optionally range-limited.
#[schemars(title = "Number")]
Number(NumberInputSchema),
// Boolean.
#[schemars(title = "Boolean")]
Boolean(BooleanInputSchema),
// Image rich-content part.
#[schemars(title = "Image")]
Image(ImageInputSchema),
// Audio rich-content part.
#[schemars(title = "Audio")]
Audio(AudioInputSchema),
// Video rich-content part.
#[schemars(title = "Video")]
Video(VideoInputSchema),
// File rich-content part.
#[schemars(title = "File")]
File(FileInputSchema),
}
impl InputSchema {
    /// Reports which rich-content modalities appear anywhere inside this schema.
    ///
    /// Media leaves (`Image`, `Audio`, `Video`, `File`) set exactly their own flag,
    /// containers (`AnyOf`, `Object`, `Array`) delegate to their own `modalities`,
    /// and scalar schemas report no modality at all.
    pub fn modalities(&self) -> Modalities {
        let none = Modalities::default();
        match self {
            Self::String(_) | Self::Integer(_) | Self::Number(_) | Self::Boolean(_) => none,
            Self::AnyOf(inner) => inner.modalities(),
            Self::Object(inner) => inner.modalities(),
            Self::Array(inner) => inner.modalities(),
            Self::Image(_) => Modalities { image: true, ..none },
            Self::Audio(_) => Modalities { audio: true, ..none },
            Self::Video(_) => Modalities { video: true, ..none },
            Self::File(_) => Modalities { file: true, ..none },
        }
    }

    /// Returns `true` when `input` satisfies this schema.
    ///
    /// This is a pure dispatcher: the check itself is performed by the
    /// `validate_input` method of whichever schema struct this variant wraps.
    pub fn validate_input(&self, input: &InputValue) -> bool {
        match self {
            Self::AnyOf(inner) => inner.validate_input(input),
            Self::Object(inner) => inner.validate_input(input),
            Self::Array(inner) => inner.validate_input(input),
            Self::String(inner) => inner.validate_input(input),
            Self::Integer(inner) => inner.validate_input(input),
            Self::Number(inner) => inner.validate_input(input),
            Self::Boolean(inner) => inner.validate_input(input),
            Self::Image(inner) => inner.validate_input(input),
            Self::Audio(inner) => inner.validate_input(input),
            Self::Video(inner) => inner.validate_input(input),
            Self::File(inner) => inner.validate_input(input),
        }
    }
}
/// Set of flags recording which rich-content modalities a schema contains.
///
/// The derived `Default` has every flag cleared.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Modalities {
    /// At least one image schema is present.
    pub image: bool,
    /// At least one audio schema is present.
    pub audio: bool,
    /// At least one video schema is present.
    pub video: bool,
    /// At least one file schema is present.
    pub file: bool,
}

impl Modalities {
    /// Returns the field-wise union of `self` and `other`: a flag is set in the
    /// result when it is set in either operand.
    pub fn merge(self, other: Self) -> Self {
        let Self {
            image,
            audio,
            video,
            file,
        } = self;
        Self {
            image: image | other.image,
            audio: audio | other.audio,
            video: video | other.video,
            file: file | other.file,
        }
    }
}
// Union schema: holds a list of alternative schemas.
//
// With `rename_all = "camelCase"` the single field is (de)serialized as `anyOf`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.AnyOfInputSchema")]
pub struct AnyOfInputSchema {
// The alternative schemas.
pub any_of: Vec<InputSchema>,
}
impl AnyOfInputSchema {
    /// Returns the union of the modalities reported by every alternative.
    ///
    /// An empty `any_of` list yields `Modalities::default()`.
    pub fn modalities(&self) -> Modalities {
        let mut combined = Modalities::default();
        for alternative in &self.any_of {
            combined = combined.merge(alternative.modalities());
        }
        combined
    }

    /// Returns `true` as soon as one alternative accepts `input`.
    ///
    /// An empty `any_of` list accepts nothing.
    pub fn validate_input(&self, input: &InputValue) -> bool {
        for alternative in &self.any_of {
            if alternative.validate_input(input) {
                return true;
            }
        }
        false
    }
}
// Discriminator type for the `type` field of `ObjectInputSchema`.
//
// The only variant (de)serializes as the lowercase string `object`; any other
// value fails to deserialize into this enum.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.ObjectInputSchemaType")]
pub enum ObjectInputSchemaType {
#[default]
Object,
}
// Schema for an object input with named properties.
//
// Field names are (de)serialized in camelCase. Optional fields that are `None` are
// skipped when serializing, and their generated JSON Schema entries carry an extra
// `"omitempty": true` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.ObjectInputSchema")]
pub struct ObjectInputSchema {
// Always `object`.
pub r#type: ObjectInputSchemaType,
// Free-form description text; not consulted by `validate_input`.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub description: Option<String>,
// Schema for each named property. `IndexMap` preserves insertion order.
#[arbitrary(with = crate::arbitrary_util::arbitrary_indexmap)]
pub properties: IndexMap<String, InputSchema>,
// Property names that `validate_input` requires to be present in the input;
// `None` is treated as "no required names".
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub required: Option<Vec<String>>,
}
impl ObjectInputSchema {
    /// Returns the union of the modalities reported by every property schema.
    ///
    /// An object without properties yields `Modalities::default()`.
    pub fn modalities(&self) -> Modalities {
        self.properties
            .values()
            .fold(Modalities::default(), |acc, schema| {
                acc.merge(schema.modalities())
            })
    }

    /// Returns `true` when `input` is an object that satisfies this schema.
    ///
    /// The rules are:
    /// * any non-object input is rejected;
    /// * every name listed in `required` must be present in the input, whether or
    ///   not that name also appears in `properties`;
    /// * every declared property that is present must satisfy its own schema;
    /// * declared properties that are absent and not required are fine;
    /// * keys in the input that are not declared in `properties` are ignored.
    pub fn validate_input(&self, input: &InputValue) -> bool {
        let InputValue::Object(map) = input else {
            return false;
        };

        // `required` is enforced on its own rather than while walking `properties`,
        // so a required name that has no property schema is still checked.
        let required = self.required.as_deref().unwrap_or(&[]);
        if !required.iter().all(|key| map.get(key).is_some()) {
            return false;
        }

        self.properties
            .iter()
            .all(|(key, schema)| match map.get(key) {
                Some(value) => schema.validate_input(value),
                // Absent and, per the check above, not required.
                None => true,
            })
    }
}
// Discriminator type for the `type` field of `ArrayInputSchema`.
//
// The only variant (de)serializes as the lowercase string `array`; any other
// value fails to deserialize into this enum.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.ArrayInputSchemaType")]
pub enum ArrayInputSchemaType {
#[default]
Array,
}
// Schema for an array input whose elements all share one item schema.
//
// Field names are (de)serialized in camelCase (`minItems`, `maxItems`). Optional
// fields that are `None` are skipped when serializing, and their generated JSON
// Schema entries carry an extra `"omitempty": true` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.ArrayInputSchema")]
pub struct ArrayInputSchema {
// Always `array`.
pub r#type: ArrayInputSchemaType,
// Free-form description text; not consulted by `validate_input`.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub description: Option<String>,
// Inclusive lower bound on the number of elements; `None` means no lower bound.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
#[arbitrary(with = crate::arbitrary_util::arbitrary_option_u64)]
pub min_items: Option<u64>,
// Inclusive upper bound on the number of elements; `None` means no upper bound.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
#[arbitrary(with = crate::arbitrary_util::arbitrary_option_u64)]
pub max_items: Option<u64>,
// Schema every element must satisfy. Boxed because `InputSchema` is recursive.
pub items: Box<InputSchema>,
}
impl ArrayInputSchema {
    /// Returns the modalities of the item schema; the array adds none of its own.
    pub fn modalities(&self) -> Modalities {
        let item_schema = &self.items;
        item_schema.modalities()
    }

    /// Returns `true` when `input` is an array whose length lies within
    /// `min_items..=max_items` (each bound optional) and whose every element
    /// satisfies the `items` schema.
    ///
    /// The length bounds are checked first, so elements are never inspected for an
    /// array of the wrong size. Non-array inputs are rejected.
    pub fn validate_input(&self, input: &InputValue) -> bool {
        let InputValue::Array(elements) = input else {
            return false;
        };

        let count = elements.len() as u64;
        if self.min_items.is_some_and(|lower| count < lower) {
            return false;
        }
        if self.max_items.is_some_and(|upper| count > upper) {
            return false;
        }

        elements
            .iter()
            .all(|element| self.items.validate_input(element))
    }
}
// Discriminator type for the `type` field of `StringInputSchema`.
//
// The only variant (de)serializes as the lowercase string `string`; any other
// value fails to deserialize into this enum.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.StringInputSchemaType")]
pub enum StringInputSchemaType {
#[default]
String,
}
// Schema for a string input, optionally restricted to an enumerated set.
//
// Optional fields that are `None` are skipped when serializing, and their
// generated JSON Schema entries carry an extra `"omitempty": true` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.StringInputSchema")]
pub struct StringInputSchema {
// Always `string`.
pub r#type: StringInputSchemaType,
// Free-form description text; not consulted by `validate_input`.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub description: Option<String>,
// When set, the input must equal one of these strings exactly; `None` allows
// any string.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub r#enum: Option<Vec<String>>,
}
impl StringInputSchema {
    /// Returns `true` when `input` is a string and, if `enum` is set, that string
    /// is one of the listed values.
    ///
    /// Non-string inputs are rejected. An `enum` that is `Some` but empty rejects
    /// every string.
    pub fn validate_input(&self, input: &InputValue) -> bool {
        let InputValue::String(text) = input else {
            return false;
        };
        match &self.r#enum {
            Some(allowed) => allowed.contains(text),
            None => true,
        }
    }
}
// Discriminator type for the `type` field of `IntegerInputSchema`.
//
// The only variant (de)serializes as the lowercase string `integer`; any other
// value fails to deserialize into this enum.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.IntegerInputSchemaType")]
pub enum IntegerInputSchemaType {
#[default]
Integer,
}
// Schema for an integer input, optionally range-limited.
//
// Optional fields that are `None` are skipped when serializing, and their
// generated JSON Schema entries carry an extra `"omitempty": true` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.IntegerInputSchema")]
pub struct IntegerInputSchema {
// Always `integer`.
pub r#type: IntegerInputSchemaType,
// Free-form description text; not consulted by `validate_input`.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub description: Option<String>,
// Inclusive lower bound; `None` means no lower bound.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
#[arbitrary(with = crate::arbitrary_util::arbitrary_option_i64)]
pub minimum: Option<i64>,
// Inclusive upper bound; `None` means no upper bound.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
#[arbitrary(with = crate::arbitrary_util::arbitrary_option_i64)]
pub maximum: Option<i64>,
}
impl IntegerInputSchema {
    /// Returns `true` when `input` is an integer within `minimum..=maximum`
    /// (each bound optional).
    ///
    /// Two input shapes are accepted:
    /// * `InputValue::Integer`, compared against the bounds directly;
    /// * `InputValue::Number` that is finite and has no fractional part, treated
    ///   as the integer it represents.
    ///
    /// Everything else, including NaN, infinities and fractional numbers, is
    /// rejected.
    pub fn validate_input(&self, input: &InputValue) -> bool {
        match input {
            InputValue::Integer(integer) => self.within_bounds(*integer),
            InputValue::Number(number) if number.is_finite() && number.fract() == 0.0 => {
                self.integral_float_within_bounds(*number)
            }
            _ => false,
        }
    }

    /// Checks `value` against the optional inclusive `minimum` / `maximum`.
    fn within_bounds(&self, value: i64) -> bool {
        self.minimum.is_none_or(|minimum| value >= minimum)
            && self.maximum.is_none_or(|maximum| value <= maximum)
    }

    /// Range check for a finite float with no fractional part.
    ///
    /// `value as i64` saturates, so a float beyond the `i64` range would be clamped
    /// to `i64::MAX` / `i64::MIN` and could wrongly pass a bound equal to that
    /// extreme. Such values are handled explicitly: they exceed every possible
    /// `i64` bound on their side, so they pass only when that bound is absent.
    fn integral_float_within_bounds(&self, value: f64) -> bool {
        // -2^63 is exactly representable as f64; +2^63 is the first value past i64::MAX.
        const LOWER_INCLUSIVE: f64 = i64::MIN as f64;
        const UPPER_EXCLUSIVE: f64 = -LOWER_INCLUSIVE;

        if value >= UPPER_EXCLUSIVE {
            self.maximum.is_none()
        } else if value < LOWER_INCLUSIVE {
            self.minimum.is_none()
        } else {
            // Within [-2^63, 2^63) an integral f64 converts to i64 exactly.
            self.within_bounds(value as i64)
        }
    }
}
// Discriminator type for the `type` field of `NumberInputSchema`.
//
// The only variant (de)serializes as the lowercase string `number`; any other
// value fails to deserialize into this enum.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.NumberInputSchemaType")]
pub enum NumberInputSchemaType {
#[default]
Number,
}
// Schema for a floating-point number input, optionally range-limited.
//
// Optional fields that are `None` are skipped when serializing, and their
// generated JSON Schema entries carry an extra `"omitempty": true` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.NumberInputSchema")]
pub struct NumberInputSchema {
// Always `number`.
pub r#type: NumberInputSchemaType,
// Free-form description text; not consulted by `validate_input`.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub description: Option<String>,
// Inclusive lower bound; `None` means no lower bound.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
#[arbitrary(with = crate::arbitrary_util::arbitrary_option_f64)]
pub minimum: Option<f64>,
// Inclusive upper bound; `None` means no upper bound.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
#[arbitrary(with = crate::arbitrary_util::arbitrary_option_f64)]
pub maximum: Option<f64>,
}
impl NumberInputSchema {
    /// Returns `true` when `input` is numeric and lies within `minimum..=maximum`
    /// (each bound optional).
    ///
    /// Both `InputValue::Number` and `InputValue::Integer` are accepted; an integer
    /// is converted to `f64` before the comparison. Any other input is rejected.
    ///
    /// A NaN input is rejected whenever at least one bound is configured, because
    /// NaN cannot lie inside any range. With no bounds configured every number is
    /// accepted, NaN included.
    pub fn validate_input(&self, input: &InputValue) -> bool {
        match input {
            InputValue::Integer(integer) => self.within_bounds(*integer as f64),
            InputValue::Number(number) => self.within_bounds(*number),
            _ => false,
        }
    }

    /// Checks `value` against the optional inclusive `minimum` / `maximum`.
    fn within_bounds(&self, value: f64) -> bool {
        let bounded = self.minimum.is_some() || self.maximum.is_some();
        if bounded && value.is_nan() {
            // NaN compares false against everything, so without this guard it would
            // slip past both the `<` and the `>` test below.
            return false;
        }

        let below_minimum = self.minimum.is_some_and(|minimum| value < minimum);
        let above_maximum = self.maximum.is_some_and(|maximum| value > maximum);
        !below_minimum && !above_maximum
    }
}
// Discriminator type for the `type` field of `BooleanInputSchema`.
//
// The only variant (de)serializes as the lowercase string `boolean`; any other
// value fails to deserialize into this enum.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.BooleanInputSchemaType")]
pub enum BooleanInputSchemaType {
#[default]
Boolean,
}
// Schema for a boolean input.
//
// `description` is skipped when serializing if it is `None`, and its generated
// JSON Schema entry carries an extra `"omitempty": true` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.BooleanInputSchema")]
pub struct BooleanInputSchema {
// Always `boolean`.
pub r#type: BooleanInputSchemaType,
// Free-form description text; not consulted by `validate_input`.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub description: Option<String>,
}
impl BooleanInputSchema {
    /// Returns `true` exactly when `input` is an `InputValue::Boolean`; the
    /// boolean's own value is irrelevant.
    pub fn validate_input(&self, input: &InputValue) -> bool {
        matches!(input, InputValue::Boolean(_))
    }
}
// Discriminator type for the `type` field of `ImageInputSchema`.
//
// The only variant (de)serializes as the lowercase string `image`; any other
// value fails to deserialize into this enum.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.ImageInputSchemaType")]
pub enum ImageInputSchemaType {
#[default]
Image,
}
// Schema for an image input.
//
// `description` is skipped when serializing if it is `None`, and its generated
// JSON Schema entry carries an extra `"omitempty": true` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.ImageInputSchema")]
pub struct ImageInputSchema {
// Always `image`.
pub r#type: ImageInputSchemaType,
// Free-form description text; not consulted by `validate_input`.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub description: Option<String>,
}
impl ImageInputSchema {
    /// Returns `true` exactly when `input` is a rich content part of the
    /// `ImageUrl` kind; the part's contents are not inspected.
    pub fn validate_input(&self, input: &InputValue) -> bool {
        matches!(
            input,
            InputValue::RichContentPart(
                agent::completions::message::RichContentPart::ImageUrl { .. }
            )
        )
    }
}
// Discriminator type for the `type` field of `AudioInputSchema`.
//
// The only variant (de)serializes as the lowercase string `audio`; any other
// value fails to deserialize into this enum.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.AudioInputSchemaType")]
pub enum AudioInputSchemaType {
#[default]
Audio,
}
// Schema for an audio input.
//
// `description` is skipped when serializing if it is `None`, and its generated
// JSON Schema entry carries an extra `"omitempty": true` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.AudioInputSchema")]
pub struct AudioInputSchema {
// Always `audio`.
pub r#type: AudioInputSchemaType,
// Free-form description text; not consulted by `validate_input`.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub description: Option<String>,
}
impl AudioInputSchema {
    /// Returns `true` exactly when `input` is a rich content part of the
    /// `InputAudio` kind; the part's contents are not inspected.
    pub fn validate_input(&self, input: &InputValue) -> bool {
        matches!(
            input,
            InputValue::RichContentPart(
                agent::completions::message::RichContentPart::InputAudio { .. }
            )
        )
    }
}
// Discriminator type for the `type` field of `VideoInputSchema`.
//
// The only variant (de)serializes as the lowercase string `video`; any other
// value fails to deserialize into this enum.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.VideoInputSchemaType")]
pub enum VideoInputSchemaType {
#[default]
Video,
}
// Schema for a video input.
//
// `description` is skipped when serializing if it is `None`, and its generated
// JSON Schema entry carries an extra `"omitempty": true` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.VideoInputSchema")]
pub struct VideoInputSchema {
// Always `video`.
pub r#type: VideoInputSchemaType,
// Free-form description text; not consulted by `validate_input`.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub description: Option<String>,
}
impl VideoInputSchema {
    /// Returns `true` exactly when `input` is a rich content part of either the
    /// `InputVideo` or the `VideoUrl` kind; the part's contents are not inspected.
    pub fn validate_input(&self, input: &InputValue) -> bool {
        matches!(
            input,
            InputValue::RichContentPart(
                agent::completions::message::RichContentPart::InputVideo { .. }
                    | agent::completions::message::RichContentPart::VideoUrl { .. }
            )
        )
    }
}
// Discriminator type for the `type` field of `FileInputSchema`.
//
// The only variant (de)serializes as the lowercase string `file`; any other
// value fails to deserialize into this enum.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "lowercase")]
#[schemars(rename = "functions.expression.FileInputSchemaType")]
pub enum FileInputSchemaType {
#[default]
File,
}
// Schema for a file input.
//
// `description` is skipped when serializing if it is `None`, and its generated
// JSON Schema entry carries an extra `"omitempty": true` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(rename_all = "camelCase")]
#[schemars(rename = "functions.expression.FileInputSchema")]
pub struct FileInputSchema {
// Always `file`.
pub r#type: FileInputSchemaType,
// Free-form description text; not consulted by `validate_input`.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub description: Option<String>,
}
impl FileInputSchema {
    /// Returns `true` exactly when `input` is a rich content part of the `File`
    /// kind; the part's contents are not inspected.
    pub fn validate_input(&self, input: &InputValue) -> bool {
        matches!(
            input,
            InputValue::RichContentPart(agent::completions::message::RichContentPart::File { .. })
        )
    }
}