use std::collections::BTreeMap;
use serde_json::{Map, Number, Value};
use crate::vllm_tool_parser::Tool;
/// Parameter schemas for every known tool, keyed by tool name.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub(super) struct ToolSchemas {
    /// One schema per tool name.
    tools: BTreeMap<String, ToolSchema>,
}
/// Expected types of a single tool's parameters.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub(super) struct ToolSchema {
    /// Parameter type keyed by parameter name; parameters whose schema could not be
    /// interpreted are simply absent.
    params: BTreeMap<String, JsonParamType>,
}
/// A raw parameter value awaiting schema-driven conversion into JSON.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(super) enum ParamInput {
    /// Plain text content.
    Text(String),
    /// Named child elements; converted to a JSON object or array depending on the schema.
    // Nothing in this file constructs this variant, so the dead-code lint is silenced here.
    // NOTE(review): confirm whether any other module builds `Elements` before removing the allow.
    #[allow(dead_code)]
    Elements(Vec<ParamElement>),
}
impl From<String> for ParamInput {
fn from(value: String) -> Self {
Self::Text(value)
}
}
/// A named child element inside a structured parameter value.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(super) struct ParamElement {
    /// Element name; used as the object key (or ignored for array items).
    pub name: String,
    /// Element content, which may itself be text or further nested elements.
    pub value: ParamInput,
}
/// The JSON type a parameter is expected to have, derived from its JSON Schema.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(super) enum JsonParamType {
    /// JSON string.
    String,
    /// Integral JSON number.
    Integer,
    /// Any JSON number.
    Number,
    /// JSON boolean.
    Boolean,
    /// JSON object with optional per-property types.
    Object {
        /// Types of the declared properties, keyed by property name.
        properties: BTreeMap<String, JsonParamType>,
        /// Type applied to properties not listed in `properties`, when the schema gives one.
        additional_properties: Option<Box<JsonParamType>>,
    },
    /// JSON array with an optional element type.
    Array {
        /// Element type, when the schema gives one.
        items: Option<Box<JsonParamType>>,
    },
    /// JSON null.
    Null,
    /// Several acceptable types, tried in order during conversion.
    OneOf(Vec<JsonParamType>),
}
impl ToolSchemas {
pub(super) fn from_tools(tools: &[Tool]) -> Self {
let tools = tools
.iter()
.map(|tool| (tool.name.clone(), ToolSchema::from_schema(&tool.parameters)))
.collect();
Self { tools }
}
pub(super) fn convert_params_with_schema<P>(
&self,
function_name: &str,
params: Vec<(String, P)>,
) -> Map<String, Value>
where
P: Into<ParamInput>,
{
let tool_schema = self.tools.get(function_name).unwrap_or(ToolSchema::empty());
let mut converted = Map::with_capacity(params.len());
for (name, value) in params {
let value = tool_schema.convert(&name, value.into());
converted.insert(name, value);
}
converted
}
pub(super) fn convert_param_with_schema<P>(
&self,
function_name: &str,
name: &str,
value: P,
) -> Value
where
P: Into<ParamInput>,
{
let tool_schema = self.tools.get(function_name).unwrap_or(ToolSchema::empty());
tool_schema.convert(name, value.into())
}
}
impl ToolSchema {
    /// Returns a shared schema with no parameters, used for tools that are not known.
    const fn empty() -> &'static Self {
        static NO_PARAMS: ToolSchema = ToolSchema {
            params: BTreeMap::new(),
        };
        &NO_PARAMS
    }

    /// Reads parameter types from the `properties` object of a tool's parameter schema.
    ///
    /// A missing or non-object `properties` yields an empty schema; individual properties
    /// whose type cannot be inferred are skipped.
    fn from_schema(parameters: &Value) -> Self {
        let mut params = BTreeMap::new();
        if let Some(properties) = parameters.get("properties").and_then(Value::as_object) {
            for (param_name, param_schema) in properties {
                if let Some(param_type) = JsonParamType::from_schema(param_schema) {
                    params.insert(param_name.clone(), param_type);
                }
            }
        }
        Self { params }
    }

    /// Converts the raw value of parameter `name`, using its declared type if there is one.
    fn convert(&self, name: &str, input: ParamInput) -> Value {
        let declared = self.params.get(name);
        convert_with_optional_schema(declared, &input)
    }
}
impl JsonParamType {
fn from_schema(schema: &Value) -> Option<Self> {
let schema = schema.as_object()?;
if let Some(type_value) = schema.get("type") {
return Self::from_type_value(type_value, schema);
}
if let Some(composite) = schema.get("anyOf").or_else(|| schema.get("oneOf")) {
let param_type = composite
.as_array()
.map(|schemas| {
schemas
.iter()
.filter_map(Self::from_schema)
.collect::<Vec<_>>()
})
.filter(|types| !types.is_empty())
.map(Self::one_of)
.unwrap_or_else(|| Self::object_from_schema(Some(schema)));
return Some(param_type);
}
if schema.contains_key("enum") {
return Some(Self::String);
}
if schema.contains_key("items") {
return Some(Self::array_from_schema(Some(schema)));
}
if schema.contains_key("properties") || schema.contains_key("additionalProperties") {
return Some(Self::object_from_schema(Some(schema)));
}
None
}
fn from_type_value(type_value: &Value, schema: &Map<String, Value>) -> Option<Self> {
match type_value {
Value::String(kind) => Self::from_type_name(kind, Some(schema)),
Value::Array(kinds) => {
let types = kinds
.iter()
.filter_map(Value::as_str)
.filter_map(|kind| Self::from_type_name(kind, Some(schema)))
.collect::<Vec<_>>();
if types.is_empty() {
None
} else {
Some(Self::one_of(types))
}
}
_ => None,
}
}
fn from_type_name(kind: &str, schema: Option<&Map<String, Value>>) -> Option<Self> {
let kind = kind.trim().to_ascii_lowercase();
match kind.as_str() {
"string" | "str" | "text" | "varchar" | "char" | "enum" => Some(Self::String),
"integer" | "int" => Some(Self::Integer),
"number" | "float" | "double" => Some(Self::Number),
"boolean" | "bool" | "binary" => Some(Self::Boolean),
"object" | "dict" | "map" => Some(Self::object_from_schema(schema)),
"array" | "arr" | "list" | "sequence" => Some(Self::array_from_schema(schema)),
"null" => Some(Self::Null),
_ if kind.starts_with("int")
|| kind.starts_with("uint")
|| kind.starts_with("long")
|| kind.starts_with("short")
|| kind.starts_with("unsigned") =>
{
Some(Self::Integer)
}
_ if kind.starts_with("num") || kind.starts_with("float") => Some(Self::Number),
_ if kind.starts_with("dict") => Some(Self::object_from_schema(schema)),
_ if kind.starts_with("list") => Some(Self::array_from_schema(schema)),
_ => None,
}
}
fn object_from_schema(schema: Option<&Map<String, Value>>) -> Self {
let properties = schema
.and_then(|schema| schema.get("properties"))
.and_then(Value::as_object)
.map(|properties| {
properties
.iter()
.filter_map(|(name, schema)| {
Self::from_schema(schema).map(|param_type| (name.clone(), param_type))
})
.collect()
})
.unwrap_or_default();
let additional_properties = schema
.and_then(|schema| schema.get("additionalProperties"))
.and_then(|schema| {
if schema.is_object() {
Self::from_schema(schema).map(Box::new)
} else {
None
}
});
Self::Object {
properties,
additional_properties,
}
}
fn array_from_schema(schema: Option<&Map<String, Value>>) -> Self {
let items = schema
.and_then(|schema| schema.get("items"))
.and_then(Self::from_schema)
.map(Box::new);
Self::Array { items }
}
fn one_of(mut types: Vec<Self>) -> Self {
if types.len() == 1 {
types.remove(0)
} else {
Self::OneOf(types)
}
}
}
/// Converts a raw parameter value to JSON, guided by `param_type` when one is known.
///
/// Text equal to `null` (ASCII case-insensitive) always becomes JSON null. Otherwise the typed
/// conversion is attempted; if there is no type or the value does not fit it, text is kept
/// as a JSON string and elements become an untyped JSON object.
fn convert_with_optional_schema(param_type: Option<&JsonParamType>, input: &ParamInput) -> Value {
    let is_null_literal =
        matches!(input, ParamInput::Text(text) if text.eq_ignore_ascii_case("null"));
    if is_null_literal {
        return Value::Null;
    }

    param_type
        .and_then(|expected| try_convert_value(expected, input))
        .unwrap_or_else(|| match input {
            ParamInput::Text(text) => Value::String(text.clone()),
            ParamInput::Elements(children) => Value::Object(convert_elements_to_object(
                children,
                &BTreeMap::new(),
                None,
            )),
        })
}
/// Attempts a typed conversion, dispatching on whether `input` is text or nested elements.
///
/// Returns `None` when the value cannot be represented as `param_type`.
fn try_convert_value(param_type: &JsonParamType, input: &ParamInput) -> Option<Value> {
    match input {
        ParamInput::Elements(children) => try_convert_elements_value(param_type, children),
        ParamInput::Text(raw) => try_convert_text_value(param_type, raw),
    }
}
/// Converts raw text into a JSON value of `param_type`, or returns `None` if it does not fit.
///
/// * `String` accepts any text verbatim.
/// * `Integer` ignores surrounding whitespace and accepts the full `i64` and `u64` ranges.
/// * `Number` and `Boolean` are delegated to `try_convert_number` / `try_convert_boolean`.
/// * `Object` / `Array` accept blank (empty or whitespace-only) text as an empty container;
///   otherwise the text must parse as JSON *of the matching kind*.
/// * `Null` accepts only the literal `null` (ASCII case-insensitive).
/// * `OneOf` tries each alternative in order and returns the first successful conversion.
///
/// `OneOf` relies on `None` meaning "wrong kind, try the next alternative", so every arm
/// rejects values of another kind instead of returning them (e.g. `"42"` is not an array).
fn try_convert_text_value(param_type: &JsonParamType, value: &str) -> Option<Value> {
    match param_type {
        JsonParamType::String => Some(Value::String(value.to_string())),
        JsonParamType::Integer => {
            // Trim like the boolean path does; fall back to u64 for values above i64::MAX.
            let trimmed = value.trim();
            trimmed
                .parse::<i64>()
                .map(Number::from)
                .or_else(|_| trimmed.parse::<u64>().map(Number::from))
                .ok()
                .map(Value::Number)
        }
        JsonParamType::Number => try_convert_number(value),
        JsonParamType::Boolean => try_convert_boolean(value),
        JsonParamType::Object { .. } if value.trim().is_empty() => {
            Some(Value::Object(Map::new()))
        }
        JsonParamType::Array { .. } if value.trim().is_empty() => Some(Value::Array(Vec::new())),
        // Only accept a parsed value of the requested container kind; anything else must
        // fall through so a later `OneOf` alternative (or the string fallback) can apply.
        JsonParamType::Object { .. } => serde_json::from_str::<Value>(value)
            .ok()
            .filter(Value::is_object),
        JsonParamType::Array { .. } => serde_json::from_str::<Value>(value)
            .ok()
            .filter(Value::is_array),
        JsonParamType::Null => value.eq_ignore_ascii_case("null").then_some(Value::Null),
        JsonParamType::OneOf(types) => types
            .iter()
            .find_map(|param_type| try_convert_text_value(param_type, value)),
    }
}
/// Converts nested elements into a JSON object or array according to `param_type`.
///
/// Objects key each element by its name and use the declared property types; arrays keep
/// element order and ignore names. `OneOf` returns the first alternative that accepts the
/// elements. Scalar and null types cannot hold elements, so they yield `None`.
fn try_convert_elements_value(
    param_type: &JsonParamType,
    elements: &[ParamElement],
) -> Option<Value> {
    let converted = match param_type {
        JsonParamType::Object {
            properties,
            additional_properties,
        } => {
            let object =
                convert_elements_to_object(elements, properties, additional_properties.as_deref());
            Value::Object(object)
        }
        JsonParamType::Array { items } => {
            let item_type = items.as_deref();
            let values: Vec<Value> = elements
                .iter()
                .map(|child| convert_with_optional_schema(item_type, &child.value))
                .collect();
            Value::Array(values)
        }
        JsonParamType::OneOf(candidates) => {
            return candidates
                .iter()
                .find_map(|candidate| try_convert_elements_value(candidate, elements));
        }
        JsonParamType::String
        | JsonParamType::Integer
        | JsonParamType::Number
        | JsonParamType::Boolean
        | JsonParamType::Null => return None,
    };
    Some(converted)
}
/// Builds a JSON object from named child elements.
///
/// Each element is converted with the type of its matching entry in `properties`, falling
/// back to `additional_properties`, and then to schema-less conversion. Element names that
/// occur more than once are collected into an array of their values in document order.
fn convert_elements_to_object(
    elements: &[ParamElement],
    properties: &BTreeMap<String, JsonParamType>,
    additional_properties: Option<&JsonParamType>,
) -> Map<String, Value> {
    let mut object = Map::with_capacity(elements.len());
    // Keys already turned into an array because their name repeated.
    let mut repeated = std::collections::BTreeSet::new();
    for element in elements {
        let param_type = properties.get(&element.name).or(additional_properties);
        let value = convert_with_optional_schema(param_type, &element.value);
        insert_object_value(&mut object, &mut repeated, element.name.clone(), value);
    }
    object
}

/// Inserts `value` under `key`, turning repeated keys into an array of all their values.
///
/// `repeated` records the keys this function has already promoted to an array. Only those
/// arrays are appended to; a first value that merely *is* an array gets wrapped instead.
/// Without this, `["a"]` followed by `["b"]` would become `["a", ["b"]]` rather than
/// `[["a"], ["b"]]`.
fn insert_object_value(
    object: &mut Map<String, Value>,
    repeated: &mut std::collections::BTreeSet<String>,
    key: String,
    value: Value,
) {
    match object.get_mut(&key) {
        None => {
            object.insert(key, value);
        }
        Some(Value::Array(values)) if repeated.contains(&key) => values.push(value),
        Some(existing) => {
            let first = std::mem::take(existing);
            *existing = Value::Array(vec![first, value]);
            repeated.insert(key);
        }
    }
}
/// Parses `value` as a JSON number, ignoring surrounding whitespace.
///
/// Strict JSON number syntax is tried first. Next comes Rust's more lenient integer syntax
/// (e.g. a leading `+` or leading zeros), and then float syntax (e.g. `.5`). Non-finite
/// floats such as `inf` or `NaN` are rejected because JSON cannot represent them.
fn try_convert_number(value: &str) -> Option<Value> {
    // Only the serde_json path tolerated surrounding whitespace; trim so the fallbacks agree.
    let value = value.trim();
    serde_json::from_str::<Number>(value)
        .ok()
        .or_else(|| value.parse::<i64>().ok().map(Number::from))
        .or_else(|| value.parse::<f64>().ok().and_then(Number::from_f64))
        .map(Value::Number)
}
/// Parses `true`/`1` and `false`/`0` (ASCII case-insensitive, surrounding whitespace
/// ignored) as a JSON boolean; any other text yields `None`.
fn try_convert_boolean(value: &str) -> Option<Value> {
    let normalized = value.trim().to_ascii_lowercase();
    let flag = match normalized.as_str() {
        "true" | "1" => true,
        "false" | "0" => false,
        _ => return None,
    };
    Some(Value::Bool(flag))
}