use crate::chat;
use indexmap::IndexMap;
use serde::de::Error as _;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// Either a single expression or a list of expressions; compiled by
/// `InputMaps::compile` into a flat list of input lists (`Vec<Vec<Input>>`).
///
/// Deserialized untagged: serde tries `One` first and falls back to `Many`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum InputMaps {
/// A single expression.
One(super::Expression),
/// Several expressions; their compiled results are concatenated in order.
Many(Vec<super::Expression>),
}
impl InputMaps {
    /// Compiles every expression into lists of [`Input`]s and flattens them
    /// into a single list, preserving expression order.
    ///
    /// An expression that compiles to one list contributes that list; an
    /// expression that compiles to many lists contributes each of them in order.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while compiling an expression.
    pub fn compile(
        self,
        params: &super::Params,
    ) -> Result<Vec<Vec<Input>>, super::ExpressionError> {
        // Treat the single-expression form as a one-element list so both
        // variants share the same flattening loop.
        let expressions = match self {
            InputMaps::One(expression) => vec![expression],
            InputMaps::Many(expressions) => expressions,
        };
        let mut flattened = Vec::with_capacity(expressions.len());
        for expression in expressions {
            match expression.compile_one_or_many::<Vec<Input>>(params)? {
                super::OneOrMany::One(list) => flattened.push(list),
                super::OneOrMany::Many(lists) => flattened.extend(lists),
            }
        }
        Ok(flattened)
    }
}
/// A concrete input value: JSON-like data (objects, arrays, strings, numbers,
/// booleans) that may also embed chat rich content parts.
///
/// Deserialized untagged, so serde tries the variants in declaration order: a
/// value that parses as a `RichContentPart` is taken as one before it is
/// considered as a plain `Object`, and `Integer` is tried before `Number`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Input {
/// An embedded chat content part (e.g. an image, audio, video or file).
RichContentPart(chat::completions::request::RichContentPart),
/// Key/value pairs; `IndexMap` keeps insertion order.
Object(IndexMap<String, Input>),
Array(Vec<Input>),
String(String),
Integer(i64),
Number(f64),
Boolean(bool),
}
impl Input {
pub fn to_rich_content_parts(
self,
depth: usize,
) -> impl Iterator<Item = chat::completions::request::RichContentPart> {
enum Iter {
RichContentPart(RichContentPartIter),
Object(Box<ObjectIter>),
Array(Box<ArrayIter>),
Primitive(Option<String>),
}
impl Iter {
pub fn new(input: Input, depth: usize) -> Self {
match input {
Input::RichContentPart(rich_content_part) => {
Iter::RichContentPart(RichContentPartIter {
first: true,
part: Some(rich_content_part),
last: true,
})
}
Input::Object(object) => {
Iter::Object(Box::new(ObjectIter {
object: object.into_iter(),
first: true,
child: None,
depth,
}))
}
Input::Array(array) => Iter::Array(Box::new(ArrayIter {
array: array.into_iter(),
first: true,
child: None,
depth,
})),
Input::String(string) => Iter::Primitive(Some(format!(
"\"{}\"",
json_escape::escape_str(&string),
))),
Input::Integer(integer) => {
Iter::Primitive(Some(integer.to_string()))
}
Input::Number(number) => {
Iter::Primitive(Some(number.to_string()))
}
Input::Boolean(boolean) => {
Iter::Primitive(Some(boolean.to_string()))
}
}
}
}
impl Iterator for Iter {
type Item = chat::completions::request::RichContentPart;
fn next(&mut self) -> Option<Self::Item> {
match self {
Iter::RichContentPart(rich_content_part_iter) => {
rich_content_part_iter.next()
}
Iter::Object(object_iter) => object_iter.next(),
Iter::Array(array_iter) => array_iter.next(),
Iter::Primitive(primitive_option) => {
primitive_option.take().map(|text| {
chat::completions::request::RichContentPart::Text {
text,
}
})
}
}
}
}
struct RichContentPartIter {
first: bool,
part: Option<chat::completions::request::RichContentPart>,
last: bool,
}
impl Iterator for RichContentPartIter {
type Item = chat::completions::request::RichContentPart;
fn next(&mut self) -> Option<Self::Item> {
if self.first {
self.first = false;
Some(chat::completions::request::RichContentPart::Text {
text: '"'.to_string(),
})
} else if let Some(part) = self.part.take() {
Some(part)
} else if self.last {
self.last = false;
Some(chat::completions::request::RichContentPart::Text {
text: '"'.to_string(),
})
} else {
None
}
}
}
struct ObjectIter {
object: indexmap::map::IntoIter<String, Input>,
first: bool,
child: Option<Iter>,
depth: usize,
}
impl Iterator for ObjectIter {
type Item = chat::completions::request::RichContentPart;
fn next(&mut self) -> Option<Self::Item> {
if self.first {
self.first = false;
if let Some((key, input)) = self.object.next() {
self.child = Some(Iter::new(input, self.depth + 1));
Some(
chat::completions::request::RichContentPart::Text {
text: format!(
"{{\n{}\"{}\": ",
" ".repeat(self.depth + 1),
key,
),
},
)
} else {
Some(
chat::completions::request::RichContentPart::Text {
text: format!("{{}}"),
},
)
}
} else if let Some(child) = &mut self.child {
if let Some(part) = child.next() {
Some(part)
} else if let Some((key, input)) = self.object.next() {
self.child = Some(Iter::new(input, self.depth + 1));
Some(
chat::completions::request::RichContentPart::Text {
text: format!(
",\n{}\"{}\": ",
" ".repeat(self.depth + 1),
key,
),
},
)
} else {
self.child = None;
Some(
chat::completions::request::RichContentPart::Text {
text: format!(
"\n{}}}",
" ".repeat(self.depth)
),
},
)
}
} else {
None
}
}
}
struct ArrayIter {
array: std::vec::IntoIter<Input>,
first: bool,
child: Option<Iter>,
depth: usize,
}
impl Iterator for ArrayIter {
type Item = chat::completions::request::RichContentPart;
fn next(&mut self) -> Option<Self::Item> {
if self.first {
self.first = false;
if let Some(input) = self.array.next() {
self.child = Some(Iter::new(input, self.depth + 1));
Some(
chat::completions::request::RichContentPart::Text {
text: format!(
"[\n{}",
" ".repeat(self.depth + 1)
),
},
)
} else {
Some(
chat::completions::request::RichContentPart::Text {
text: format!("[]"),
},
)
}
} else if let Some(child) = &mut self.child {
if let Some(part) = child.next() {
Some(part)
} else if let Some(input) = self.array.next() {
self.child = Some(Iter::new(input, self.depth + 1));
Some(
chat::completions::request::RichContentPart::Text {
text: format!(
",\n{}",
" ".repeat(self.depth + 1),
),
},
)
} else {
self.child = None;
Some(
chat::completions::request::RichContentPart::Text {
text: format!(
"\n{}]",
" ".repeat(self.depth)
),
},
)
}
} else {
None
}
}
}
Iter::new(self, depth)
}
}
/// Unevaluated form of [`Input`]. It has the same shape, but object values and
/// array items are `super::WithExpression` values that are resolved by
/// `InputExpression::compile`.
///
/// Deserialized untagged, so serde tries the variants in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum InputExpression {
RichContentPart(chat::completions::request::RichContentPart),
/// Object entries; each value compiles to exactly one value.
Object(IndexMap<String, super::WithExpression<InputExpression>>),
/// Array items; each may compile to one or many elements, spliced in order.
Array(Vec<super::WithExpression<InputExpression>>),
String(String),
Integer(i64),
Number(f64),
Boolean(bool),
}
impl InputExpression {
pub fn compile(
self,
params: &super::Params,
) -> Result<Input, super::ExpressionError> {
match self {
InputExpression::RichContentPart(rich_content_part) => {
Ok(Input::RichContentPart(rich_content_part))
}
InputExpression::Object(object) => {
let mut compiled_object = IndexMap::with_capacity(object.len());
for (key, value) in object {
compiled_object.insert(
key,
value.compile_one(params)?.compile(params)?,
);
}
Ok(Input::Object(compiled_object))
}
InputExpression::Array(array) => {
let mut compiled_array = Vec::with_capacity(array.len());
for item in array {
match item.compile_one_or_many(params)? {
super::OneOrMany::One(one_item) => {
compiled_array.push(one_item.compile(params)?);
}
super::OneOrMany::Many(many_items) => {
for item in many_items {
compiled_array.push(item.compile(params)?);
}
}
}
}
Ok(Input::Array(compiled_array))
}
InputExpression::String(string) => Ok(Input::String(string)),
InputExpression::Integer(integer) => Ok(Input::Integer(integer)),
InputExpression::Number(number) => Ok(Input::Number(number)),
InputExpression::Boolean(boolean) => Ok(Input::Boolean(boolean)),
}
}
}
/// Schema describing which [`Input`] values are acceptable.
///
/// Serde support is hand-written. Non-`AnyOf` variants are represented as the
/// variant's fields plus a `"type"` tag (`"object"`, `"array"`, ...). `AnyOf`
/// is represented as `{"anyOf": [...]}` with no tag.
#[derive(Debug, Clone)]
pub enum InputSchema {
Object(ObjectInputSchema),
Array(ArrayInputSchema),
String(StringInputSchema),
Integer(IntegerInputSchema),
Number(NumberInputSchema),
Boolean(BooleanInputSchema),
Image(ImageInputSchema),
Audio(AudioInputSchema),
Video(VideoInputSchema),
File(FileInputSchema),
AnyOf(AnyOfInputSchema),
}
impl InputSchema {
pub fn validate_input(&self, input: &Input) -> bool {
match self {
InputSchema::Object(schema) => schema.validate_input(input),
InputSchema::Array(schema) => schema.validate_input(input),
InputSchema::String(schema) => schema.validate_input(input),
InputSchema::Integer(schema) => schema.validate_input(input),
InputSchema::Number(schema) => schema.validate_input(input),
InputSchema::Boolean(schema) => schema.validate_input(input),
InputSchema::Image(schema) => schema.validate_input(input),
InputSchema::Audio(schema) => schema.validate_input(input),
InputSchema::Video(schema) => schema.validate_input(input),
InputSchema::File(schema) => schema.validate_input(input),
InputSchema::AnyOf(schema) => schema.validate_input(input),
}
}
}
/// Deserialization helper for every [`InputSchema`] variant except `AnyOf`.
///
/// It is internally tagged by a `"type"` field whose value is the camelCase
/// variant name (`"object"`, `"array"`, `"string"`, `"integer"`, `"number"`,
/// `"boolean"`, `"image"`, `"audio"`, `"video"`, `"file"`).
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")]
enum TypedInputSchema {
Object(ObjectInputSchema),
Array(ArrayInputSchema),
String(StringInputSchema),
Integer(IntegerInputSchema),
Number(NumberInputSchema),
Boolean(BooleanInputSchema),
Image(ImageInputSchema),
Audio(AudioInputSchema),
Video(VideoInputSchema),
File(FileInputSchema),
}
impl From<TypedInputSchema> for InputSchema {
    /// Maps each tagged helper variant onto the same-named schema variant.
    fn from(typed: TypedInputSchema) -> Self {
        match typed {
            TypedInputSchema::Object(schema) => Self::Object(schema),
            TypedInputSchema::Array(schema) => Self::Array(schema),
            TypedInputSchema::String(schema) => Self::String(schema),
            TypedInputSchema::Integer(schema) => Self::Integer(schema),
            TypedInputSchema::Number(schema) => Self::Number(schema),
            TypedInputSchema::Boolean(schema) => Self::Boolean(schema),
            TypedInputSchema::Image(schema) => Self::Image(schema),
            TypedInputSchema::Audio(schema) => Self::Audio(schema),
            TypedInputSchema::Video(schema) => Self::Video(schema),
            TypedInputSchema::File(schema) => Self::File(schema),
        }
    }
}
impl<'de> Deserialize<'de> for InputSchema {
    /// Deserializes a schema by first buffering it as a `serde_json::Value`.
    ///
    /// If the value has an `anyOf` key, it is read as an [`AnyOfInputSchema`].
    /// Otherwise it must be a `type`-tagged schema, read via `TypedInputSchema`.
    /// Parse failures are reported through `D::Error::custom`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let value = serde_json::Value::deserialize(deserializer)?;
        // Decide the representation before `value` is moved into `from_value`.
        let is_any_of = value.get("anyOf").is_some();
        let schema: InputSchema = if is_any_of {
            InputSchema::AnyOf(
                serde_json::from_value::<AnyOfInputSchema>(value)
                    .map_err(D::Error::custom)?,
            )
        } else {
            serde_json::from_value::<TypedInputSchema>(value)
                .map_err(D::Error::custom)?
                .into()
        };
        Ok(schema)
    }
}
impl Serialize for InputSchema {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match self {
InputSchema::AnyOf(schema) => schema.serialize(serializer),
InputSchema::Object(schema) => {
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Tagged<'a> {
r#type: &'static str,
#[serde(flatten)]
schema: &'a ObjectInputSchema,
}
Tagged {
r#type: "object",
schema,
}
.serialize(serializer)
}
InputSchema::Array(schema) => {
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Tagged<'a> {
r#type: &'static str,
#[serde(flatten)]
schema: &'a ArrayInputSchema,
}
Tagged {
r#type: "array",
schema,
}
.serialize(serializer)
}
InputSchema::String(schema) => {
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Tagged<'a> {
r#type: &'static str,
#[serde(flatten)]
schema: &'a StringInputSchema,
}
Tagged {
r#type: "string",
schema,
}
.serialize(serializer)
}
InputSchema::Integer(schema) => {
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Tagged<'a> {
r#type: &'static str,
#[serde(flatten)]
schema: &'a IntegerInputSchema,
}
Tagged {
r#type: "integer",
schema,
}
.serialize(serializer)
}
InputSchema::Number(schema) => {
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Tagged<'a> {
r#type: &'static str,
#[serde(flatten)]
schema: &'a NumberInputSchema,
}
Tagged {
r#type: "number",
schema,
}
.serialize(serializer)
}
InputSchema::Boolean(schema) => {
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Tagged<'a> {
r#type: &'static str,
#[serde(flatten)]
schema: &'a BooleanInputSchema,
}
Tagged {
r#type: "boolean",
schema,
}
.serialize(serializer)
}
InputSchema::Image(schema) => {
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Tagged<'a> {
r#type: &'static str,
#[serde(flatten)]
schema: &'a ImageInputSchema,
}
Tagged {
r#type: "image",
schema,
}
.serialize(serializer)
}
InputSchema::Audio(schema) => {
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Tagged<'a> {
r#type: &'static str,
#[serde(flatten)]
schema: &'a AudioInputSchema,
}
Tagged {
r#type: "audio",
schema,
}
.serialize(serializer)
}
InputSchema::Video(schema) => {
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Tagged<'a> {
r#type: &'static str,
#[serde(flatten)]
schema: &'a VideoInputSchema,
}
Tagged {
r#type: "video",
schema,
}
.serialize(serializer)
}
InputSchema::File(schema) => {
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Tagged<'a> {
r#type: &'static str,
#[serde(flatten)]
schema: &'a FileInputSchema,
}
Tagged {
r#type: "file",
schema,
}
.serialize(serializer)
}
}
}
}
/// Schema that accepts an input when any one of its alternatives does.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AnyOfInputSchema {
/// Alternative schemas; (de)serialized under the key `anyOf`.
pub any_of: Vec<InputSchema>,
}
impl AnyOfInputSchema {
pub fn validate_input(&self, input: &Input) -> bool {
self.any_of
.iter()
.any(|schema| schema.validate_input(input))
}
}
/// Schema for `Input::Object` values.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ObjectInputSchema {
/// Optional description; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Schema for each named property, kept in declaration order.
pub properties: IndexMap<String, InputSchema>,
/// Keys that must be present in the object; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<Vec<String>>,
}
impl ObjectInputSchema {
    /// Returns `true` when `input` is an object that satisfies this schema.
    ///
    /// Following JSON Schema semantics, a property listed in `properties` is
    /// optional unless its key appears in `required`:
    /// * every property that *is* present must match its schema, and
    /// * every key listed in `required` must be present.
    ///
    /// Keys that are not listed in `properties` are ignored.
    pub fn validate_input(&self, input: &Input) -> bool {
        let Input::Object(map) = input else {
            return false;
        };
        // An absent property is accepted here; presence is enforced only for
        // the keys named in `required`. Otherwise `required` would be
        // meaningless, because every property would be mandatory.
        let properties_match = self.properties.iter().all(|(key, schema)| {
            map.get(key)
                .is_none_or(|value| schema.validate_input(value))
        });
        let required_present = self
            .required
            .as_deref()
            .is_none_or(|required| required.iter().all(|key| map.contains_key(key)));
        properties_match && required_present
    }
}
/// Schema for `Input::Array` values.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ArrayInputSchema {
/// Optional description; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Inclusive minimum length (`minItems`); no lower bound when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub min_items: Option<u64>,
/// Inclusive maximum length (`maxItems`); no upper bound when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_items: Option<u64>,
/// Schema that every element must satisfy.
pub items: Box<InputSchema>,
}
impl ArrayInputSchema {
    /// Returns `true` when `input` is an array whose length lies within the
    /// inclusive `min_items`/`max_items` bounds and whose every element
    /// satisfies `items`.
    pub fn validate_input(&self, input: &Input) -> bool {
        let Input::Array(elements) = input else {
            return false;
        };
        let len = elements.len() as u64;
        let too_short = self.min_items.is_some_and(|min_items| len < min_items);
        let too_long = self.max_items.is_some_and(|max_items| len > max_items);
        !too_short
            && !too_long
            && elements.iter().all(|element| self.items.validate_input(element))
    }
}
/// Schema for `Input::String` values.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StringInputSchema {
/// Optional description; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Allowed values (serialized as `enum`); any string is accepted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub r#enum: Option<Vec<String>>,
}
impl StringInputSchema {
    /// Returns `true` when `input` is a string and, if an `enum` list is set,
    /// that string is one of the allowed values.
    pub fn validate_input(&self, input: &Input) -> bool {
        let Input::String(value) = input else {
            return false;
        };
        match &self.r#enum {
            Some(allowed) => allowed.contains(value),
            None => true,
        }
    }
}
/// Schema for integer values. `validate_input` also accepts finite `Input::Number`
/// values that have no fractional part.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct IntegerInputSchema {
/// Optional description; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Inclusive lower bound; no lower bound when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub minimum: Option<i64>,
/// Inclusive upper bound; no upper bound when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub maximum: Option<i64>,
}
impl IntegerInputSchema {
    /// Returns `true` when `input` is an integer within the inclusive
    /// `minimum`/`maximum` bounds.
    ///
    /// Finite `Input::Number` values with no fractional part are also treated
    /// as integers. Such values outside the `i64` range are compared exactly:
    /// they exceed every `maximum` (or fall below every `minimum`), rather
    /// than being clamped to `i64::MAX` / `i64::MIN` by a saturating cast.
    pub fn validate_input(&self, input: &Input) -> bool {
        // Inclusive bounds check on an exact `i64` value.
        let within_bounds = |value: i64| {
            self.minimum.is_none_or(|minimum| value >= minimum)
                && self.maximum.is_none_or(|maximum| value <= maximum)
        };
        match input {
            Input::Integer(integer) => within_bounds(*integer),
            Input::Number(number) if number.is_finite() && number.fract() == 0.0 => {
                // -2^63 is exactly representable as f64. `as i64` saturates
                // outside [-2^63, 2^63), so those values are handled without
                // the cast.
                let i64_min = i64::MIN as f64;
                if *number >= -i64_min {
                    // Larger than every i64: any `maximum` is exceeded.
                    self.maximum.is_none()
                } else if *number < i64_min {
                    // Smaller than every i64: any `minimum` is unmet.
                    self.minimum.is_none()
                } else {
                    within_bounds(*number as i64)
                }
            }
            _ => false,
        }
    }
}
/// Schema for numeric values; `validate_input` accepts both `Input::Integer`
/// and `Input::Number`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NumberInputSchema {
/// Optional description; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Inclusive lower bound, compared as `f64`; no lower bound when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub minimum: Option<f64>,
/// Inclusive upper bound, compared as `f64`; no upper bound when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub maximum: Option<f64>,
}
impl NumberInputSchema {
    /// Returns `true` when `input` is numeric and not outside the inclusive
    /// `minimum`/`maximum` bounds. Integers are converted to `f64` before the
    /// comparison.
    pub fn validate_input(&self, input: &Input) -> bool {
        let value = match input {
            Input::Integer(integer) => *integer as f64,
            Input::Number(number) => *number,
            _ => return false,
        };
        // Phrased as "not below / not above" so that any comparison involving
        // NaN (which is always false) never rejects the value.
        let below_minimum = self.minimum.is_some_and(|minimum| value < minimum);
        let above_maximum = self.maximum.is_some_and(|maximum| value > maximum);
        !below_minimum && !above_maximum
    }
}
/// Schema for `Input::Boolean` values.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BooleanInputSchema {
/// Optional description; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}
impl BooleanInputSchema {
    /// Returns `true` exactly when `input` is a boolean.
    pub fn validate_input(&self, input: &Input) -> bool {
        matches!(input, Input::Boolean(_))
    }
}
/// Schema accepting an embedded `ImageUrl` rich content part.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ImageInputSchema {
/// Optional description; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}
impl ImageInputSchema {
    /// Returns `true` exactly when `input` is an `ImageUrl` rich content part.
    pub fn validate_input(&self, input: &Input) -> bool {
        matches!(
            input,
            Input::RichContentPart(
                chat::completions::request::RichContentPart::ImageUrl { .. }
            )
        )
    }
}
/// Schema accepting an embedded `InputAudio` rich content part.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AudioInputSchema {
/// Optional description; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}
impl AudioInputSchema {
    /// Returns `true` exactly when `input` is an `InputAudio` rich content part.
    pub fn validate_input(&self, input: &Input) -> bool {
        matches!(
            input,
            Input::RichContentPart(
                chat::completions::request::RichContentPart::InputAudio { .. }
            )
        )
    }
}
/// Schema accepting an embedded `InputVideo` or `VideoUrl` rich content part.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct VideoInputSchema {
/// Optional description; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}
impl VideoInputSchema {
    /// Returns `true` exactly when `input` is an `InputVideo` or `VideoUrl`
    /// rich content part.
    pub fn validate_input(&self, input: &Input) -> bool {
        matches!(
            input,
            Input::RichContentPart(
                chat::completions::request::RichContentPart::InputVideo { .. }
                    | chat::completions::request::RichContentPart::VideoUrl { .. }
            )
        )
    }
}
/// Schema accepting an embedded `File` rich content part.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FileInputSchema {
/// Optional description; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}
impl FileInputSchema {
    /// Returns `true` exactly when `input` is a `File` rich content part.
    pub fn validate_input(&self, input: &Input) -> bool {
        matches!(
            input,
            Input::RichContentPart(chat::completions::request::RichContentPart::File { .. })
        )
    }
}