use super::{
field::{Choice, FieldSchemaAttr},
json::JsonType,
};
use proc_macro::TokenStream;
use quote::quote;
use serde::Serialize;
use std::collections::HashMap;
use strum::{Display, EnumString};
use syn::{
Attribute, Data, DataStruct, DeriveInput, Error, Field, GenericArgument, Ident, LitStr,
PathArguments, Result, Type, parse_macro_input,
};
/// Name of the helper attribute read from struct fields (`#[input(...)]`).
///
/// `strum` derives both directions of the conversion, so
/// `InputAttrIdent::Input.to_string()` yields `"input"`, which is what the
/// parser compares attribute paths against.
#[derive(Display, EnumString)]
enum InputAttrIdent {
    #[strum(serialize = "input")]
    Input,
}
/// JSON-schema entry describing a single field of the tool input struct.
#[derive(Debug, Serialize)]
pub(crate) struct InputToolProperty {
    /// Human-readable text taken from `#[input(description = ...)]`.
    /// Omitted from the JSON when absent: JSON Schema requires `description`
    /// to be a string, so emitting `"description": null` would make the
    /// generated schema invalid.
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// JSON type name of the field (e.g. `"string"`); serialized as `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Allowed values taken from `#[input(choice = [...])]`; serialized as
    /// `enum` and omitted when no choices were given.
    #[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
    _enum: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Default)]
pub(crate) struct InputToolParseData {
properties: HashMap<String, InputToolProperty>,
required: Vec<String>,
#[serde(rename = "type")]
arg_type: String,
}
impl InputToolParseData {
    /// Appends `field` to the list of required property names.
    fn add_required_field(&mut self, field: String) {
        let Self { required, .. } = self;
        required.push(field);
    }
    /// Stores the schema entry for property `name`, replacing any earlier
    /// entry registered under the same name.
    fn add_property(&mut self, name: String, property: InputToolProperty) {
        let Self { properties, .. } = self;
        properties.insert(name, property);
    }
    /// Overwrites the top-level JSON `type` of the schema.
    fn set_type(&mut self, arg_type: String) {
        let Self {
            arg_type: current, ..
        } = self;
        *current = arg_type;
    }
}
/// Stateful helper behind the tool-input derive.
/// It accumulates the JSON schema while walking the struct's fields.
#[derive(Default, Debug)]
pub(crate) struct InputParser {
    /// Schema built up field by field during parsing.
    tool_parse_data: InputToolParseData,
    /// Identifier of the struct being derived, recorded by `parse`.
    /// NOTE(review): not read anywhere in the code visible here; confirm
    /// whether it is still needed.
    ident: Option<Ident>,
}
impl InputParser {
pub fn parse(&mut self, input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let struct_ident = input.ident.clone();
self.ident = Some(input.ident);
self.parse_data(input.data).unwrap();
let serialized_data =
serde_json::to_string::<InputToolParseData>(&self.tool_parse_data).unwrap();
let schema_literal = LitStr::new(&serialized_data, struct_ident.span());
let expanded = quote! {
impl ToolInputT for #struct_ident {
fn io_schema() -> &'static str {
#schema_literal
}
}
};
TokenStream::from(expanded)
}
fn parse_data(&mut self, input: Data) -> Result<()> {
match &input {
Data::Struct(struct_data) => self.parse_struct(struct_data)?,
_ => {
return Err(Error::new(
proc_macro2::Span::call_site(),
"Uninon or Enums not yet supported!",
));
}
};
Ok(())
}
fn parse_struct(&mut self, input: &DataStruct) -> Result<()> {
match &input.fields {
syn::Fields::Named(fields) => {
for field in fields.named.iter() {
let field_name = field
.ident
.as_ref()
.expect("Couldn't get the field name!")
.to_string();
let input_property = self.parse_field(field_name.clone(), field)?;
self.tool_parse_data
.add_property(field_name, input_property);
}
}
_ => {
return Err(Error::new(
proc_macro2::Span::call_site(),
"Uninon or Enums not yet supported!",
));
}
}
self.tool_parse_data.set_type(JsonType::Object.to_string());
Ok(())
}
fn parse_field(&mut self, name: String, field: &Field) -> Result<InputToolProperty> {
let (json_type, optional) = self.get_json_type(&field.ty)?;
if !optional {
self.tool_parse_data.add_required_field(name.clone());
}
let mut tool_property: Option<FieldSchemaAttr> = None;
for attr in &field.attrs {
if attr
.path()
.is_ident(InputAttrIdent::Input.to_string().as_str())
{
tool_property = Some(self.parse_macro_attributes(attr, &json_type)?);
}
}
if let Some(property) = tool_property {
Ok(InputToolProperty {
description: property
.description
.map_or_else(|| None, |f| Some(f.value())),
_enum: property.choice.map_or_else(
|| None,
|f| Some(f.iter().map(|f| f.to_string()).collect::<Vec<String>>()),
),
_type: json_type.to_string(),
})
} else {
Err(Error::new(
proc_macro2::Span::call_site(),
"Coudn't Create the tool arg property",
))
}
}
fn get_json_type(&mut self, field_type: &Type) -> Result<(JsonType, bool)> {
match field_type {
Type::Path(path) => {
let Some(segment) = path.path.segments.last() else {
return Err(Error::new(
proc_macro2::Span::call_site(),
"Invalid type path",
));
};
if segment.ident == "Option"
&& let PathArguments::AngleBracketed(args) = &segment.arguments
&& let Some(GenericArgument::Type(inner)) = args.args.first()
{
let (json_type, _) = self.get_json_type(inner)?;
return Ok((json_type, true));
}
if segment.ident == "Option" {
return Err(Error::new(
proc_macro2::Span::call_site(),
"Unsupported Option type",
));
}
let json_type = self.get_base_json_type(&segment.ident.to_string());
Ok((json_type, false))
}
Type::Reference(reference) => self.get_json_type(&reference.elem),
Type::Group(group) => self.get_json_type(&group.elem),
Type::Paren(paren) => self.get_json_type(&paren.elem),
_ => Ok((JsonType::String, false)),
}
}
fn get_base_json_type(&self, type_str: &str) -> JsonType {
match type_str {
"String" | "str" => JsonType::String,
"i32" | "u32" | "f64" | "f32" | "u8" | "i64" | "u16" | "usize" | "isize" => {
JsonType::Number
}
"bool" => JsonType::Boolean,
_ => JsonType::String,
}
}
fn parse_macro_attributes(
&mut self,
attribute: &Attribute,
field_type: &JsonType,
) -> Result<FieldSchemaAttr> {
let attributes = attribute.parse_args::<FieldSchemaAttr>()?;
if let Some(ref enum_vals) = attributes.choice {
let invalid_choice = enum_vals.iter().find(|c| match (c, field_type) {
(Choice::String(_), JsonType::String) => false,
(Choice::Number(_), JsonType::Number) => false,
_ => true, });
if invalid_choice.is_some() {
return Err(Error::new(
proc_macro2::Span::call_site(),
"Choices must be of the same type as the field",
));
}
}
Ok(attributes)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` as a derive input and feeds its body to a fresh
    /// `InputParser`, returning the parser together with the parse outcome.
    fn run_parser(src: &str) -> (InputParser, Result<()>) {
        let derive_input: DeriveInput = syn::parse_str(src).unwrap();
        let mut parser = InputParser::default();
        let outcome = parser.parse_data(derive_input.data);
        (parser, outcome)
    }

    #[test]
    fn parse_input_struct_required_and_optional_fields() {
        let (parser, outcome) = run_parser(
            r#"
            struct ToolArgs {
                #[input(description = "Id")]
                id: String,
                #[input(description = "Count")]
                count: Option<u32>,
                #[input(description = "Mode", choice = ["fast", "slow"])]
                mode: String,
            }
            "#,
        );
        outcome.unwrap();

        let data = &parser.tool_parse_data;
        assert_eq!(data.arg_type, "object");

        let is_required = |name: &str| data.required.iter().any(|r| r == name);
        assert!(is_required("id"));
        assert!(!is_required("count"));

        let mode = &data.properties["mode"];
        assert_eq!(mode._type, "string");
        let choices = mode._enum.as_deref().unwrap();
        assert_eq!(choices.len(), 2);
    }

    #[test]
    fn missing_input_attribute_errors() {
        let (_, outcome) = run_parser(
            r#"
            struct ToolArgs {
                id: String,
            }
            "#,
        );
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Coudn't Create the tool arg property"));
    }

    #[test]
    fn tuple_struct_errors() {
        let (_, outcome) = run_parser("struct ToolArgs(u32);");
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Uninon or Enums not yet supported"));
    }
}