modelfox_macro 0.8.0

ModelFox makes it easy for programmers to train, deploy, and monitor machine learning models.
Documentation
use quote::quote;
use syn::spanned::Spanned;

#[proc_macro]
pub fn predict_input(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
	predict_input_impl(input.into())
		.unwrap_or_else(|e| e.to_compile_error())
		.into()
}

struct PredictInput {
	entries: Vec<PredictInputEntry>,
}

impl syn::parse::Parse for PredictInput {
	fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
		let entries =
			<syn::punctuated::Punctuated<PredictInputEntry, syn::Token![,]>>::parse_terminated(
				input,
			)?;
		let entries = entries.into_iter().collect();
		Ok(PredictInput { entries })
	}
}

struct PredictInputEntry {
	column_name: syn::LitStr,
	value: syn::Expr,
}

impl syn::parse::Parse for PredictInputEntry {
	fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
		let column_name = input.parse()?;
		input.parse::<syn::Token![:]>()?;
		let value = input.parse()?;
		Ok(PredictInputEntry { column_name, value })
	}
}

fn predict_input_impl(input: proc_macro2::TokenStream) -> syn::Result<proc_macro2::TokenStream> {
	let input: PredictInput = syn::parse2(input)?;
	let column_names = input
		.entries
		.iter()
		.map(|entry| &entry.column_name)
		.collect::<Vec<_>>();
	let values = input
		.entries
		.iter()
		.map(|entry| &entry.value)
		.collect::<Vec<_>>();
	let code = quote! {{
		let mut map = std::collections::BTreeMap::new();
		#(
			map.insert(#column_names.to_owned(), #values.into());
		)*
		modelfox::PredictInput(map)
	}};
	Ok(code)
}

#[proc_macro_derive(PredictInput, attributes(modelfox))]
pub fn predict_input_derive_macro(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
	predict_input_derive_macro_impl(input.into())
		.unwrap_or_else(|e| e.to_compile_error())
		.into()
}

fn predict_input_derive_macro_impl(
	input: proc_macro2::TokenStream,
) -> syn::Result<proc_macro2::TokenStream> {
	let input: syn::DeriveInput = syn::parse2(input)?;
	let ident = &input.ident;
	let data = match &input.data {
		syn::Data::Struct(data) => data,
		_ => {
			return Err(syn::Error::new_spanned(
				input,
				"this macro can only be used on a struct",
			))
		}
	};
	let insert_statements = data
		.fields
		.iter()
		.map(|field| {
			let field_ident = field
				.ident
				.as_ref()
				.ok_or_else(|| syn::Error::new(field.span(), "field must have ident"))?;
			let column_name =
				predict_input_field_rename(field)?.unwrap_or_else(|| field_ident.to_string());
			let code = quote! {
				map.insert(#column_name.to_owned(), value.#field_ident.into());
			};
			Ok(code)
		})
		.collect::<syn::Result<Vec<_>>>()?;
	let code = quote! {
		impl From<#ident> for modelfox::PredictInput {
			fn from(value: #ident) -> modelfox::PredictInput {
				let mut map = std::collections::BTreeMap::new();
				#(#insert_statements)*
				modelfox::PredictInput(map)
			}
		}
	};
	Ok(code)
}

fn predict_input_field_rename(field: &syn::Field) -> syn::Result<Option<String>> {
	let attr = field
		.attrs
		.iter()
		.find(|attr| attr.path.is_ident("modelfox"));
	let attr = if let Some(attr) = attr {
		attr
	} else {
		return Ok(None);
	};
	let meta = attr.parse_meta()?;
	let list = match meta {
		syn::Meta::List(list) => list,
		_ => {
			return Err(syn::Error::new_spanned(
				attr,
				"modelfox attribute must contain a list",
			))
		}
	};
	let mut rename = None;
	for item in list.nested.iter() {
		match item {
			syn::NestedMeta::Meta(syn::Meta::NameValue(item)) if item.path.is_ident("rename") => {
				let value = if let syn::Lit::Str(value) = &item.lit {
					Some(value)
				} else {
					None
				};
				let value = value.ok_or_else(|| {
					syn::Error::new_spanned(&item, "value for attribute \"value\" must be a string")
				})?;
				rename = Some(value);
			}
			_ => {}
		}
	}
	let rename = rename.ok_or_else(|| {
		syn::Error::new_spanned(&list.nested, "an attribute with key \"value\" is required")
	})?;
	let rename = rename.value();
	Ok(Some(rename))
}

#[proc_macro_derive(PredictInputValue, attributes(modelfox))]
pub fn predict_input_value_derive_macro(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
	predict_input_value_derive_macro_impl(input.into())
		.unwrap_or_else(|e| e.to_compile_error())
		.into()
}

fn predict_input_value_derive_macro_impl(
	input: proc_macro2::TokenStream,
) -> syn::Result<proc_macro2::TokenStream> {
	let input: syn::DeriveInput = syn::parse2(input)?;
	let ident = &input.ident;
	let data = match &input.data {
		syn::Data::Enum(data) => data,
		_ => {
			return Err(syn::Error::new(
				input.span(),
				"this macro can only be used on an enum",
			))
		}
	};
	let match_arms = data
		.variants
		.iter()
		.map(|variant| {
			let variant_ident = &variant.ident;
			let variant_value = predict_input_value_variant_value(variant)?
				.unwrap_or_else(|| variant_ident.to_string());
			let code = quote! { #ident::#variant_ident => #variant_value };
			Ok(code)
		})
		.collect::<syn::Result<Vec<_>>>()?;
	let code = quote! {
		impl From<#ident> for modelfox::PredictInputValue {
			fn from(value: #ident) -> modelfox::PredictInputValue {
				let value = match value {
					#(#match_arms,)*
				};
				modelfox::PredictInputValue::String(value.to_owned())
			}
		}
	};
	Ok(code)
}

fn predict_input_value_variant_value(variant: &syn::Variant) -> syn::Result<Option<String>> {
	let attr = variant
		.attrs
		.iter()
		.find(|attr| attr.path.is_ident("modelfox"));
	let attr = if let Some(attr) = attr {
		attr
	} else {
		return Ok(None);
	};
	let meta = attr.parse_meta()?;
	let list = match meta {
		syn::Meta::List(list) => list,
		_ => {
			return Err(syn::Error::new_spanned(
				attr,
				"modelfox attribute must contain a list",
			))
		}
	};
	let mut input_value = None;
	for item in list.nested.iter() {
		match item {
			syn::NestedMeta::Meta(syn::Meta::NameValue(item)) if item.path.is_ident("value") => {
				let value = if let syn::Lit::Str(value) = &item.lit {
					Some(value)
				} else {
					None
				};
				let value = value.ok_or_else(|| {
					syn::Error::new_spanned(&item, "value for attribute \"value\" must be a string")
				})?;
				input_value = Some(value);
			}
			_ => {}
		}
	}
	let input_value = input_value.ok_or_else(|| {
		syn::Error::new_spanned(&list.nested, "an attribute with key \"value\" is required")
	})?;
	let input_value = input_value.value();
	Ok(Some(input_value))
}

#[proc_macro_derive(ClassificationOutputValue, attributes(modelfox))]
pub fn classification_output_value_derive_macro(
	input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
	classification_output_value_derive_macro_impl(input.into())
		.unwrap_or_else(|e| e.to_compile_error())
		.into()
}

fn classification_output_value_derive_macro_impl(
	input: proc_macro2::TokenStream,
) -> syn::Result<proc_macro2::TokenStream> {
	let input: syn::DeriveInput = syn::parse2(input)?;
	let ident = &input.ident;
	let data = match &input.data {
		syn::Data::Enum(data) => data,
		_ => {
			return Err(syn::Error::new(
				input.span(),
				"this macro can only be used on an enum",
			))
		}
	};
	let from_str_match_arms = data
		.variants
		.iter()
		.map(|variant| {
			let variant_ident = &variant.ident;
			let variant_value = classification_output_value_variant_value(variant)?
				.unwrap_or_else(|| variant_ident.to_string());
			let code = quote! { #variant_value => #ident::#variant_ident };
			Ok(code)
		})
		.collect::<syn::Result<Vec<_>>>()?;
	let as_str_match_arms = data
		.variants
		.iter()
		.map(|variant| {
			let variant_ident = &variant.ident;
			let variant_value = classification_output_value_variant_value(variant)?
				.unwrap_or_else(|| variant_ident.to_string());
			let code = quote! { #ident::#variant_ident => #variant_value };
			Ok(code)
		})
		.collect::<syn::Result<Vec<_>>>()?;
	let code = quote! {
		impl modelfox::ClassificationOutputValue for #ident {
			fn from_str(value: &str) -> Self {
				match value {
					#(#from_str_match_arms,)*
					_ => panic!("unexpected value"),
				}
			}
			fn as_str(&self) -> &str {
				match self {
					#(#as_str_match_arms,)*
				}
			}
		}
	};
	Ok(code)
}

fn classification_output_value_variant_value(
	variant: &syn::Variant,
) -> syn::Result<Option<String>> {
	let attr = variant
		.attrs
		.iter()
		.find(|attr| attr.path.is_ident("modelfox"));
	let attr = if let Some(attr) = attr {
		attr
	} else {
		return Ok(None);
	};
	let meta = attr.parse_meta()?;
	let list = match meta {
		syn::Meta::List(list) => list,
		_ => {
			return Err(syn::Error::new_spanned(
				attr,
				"modelfox attribute must contain a list",
			))
		}
	};
	let mut input_value = None;
	for item in list.nested.iter() {
		match item {
			syn::NestedMeta::Meta(syn::Meta::NameValue(item)) if item.path.is_ident("value") => {
				let value = if let syn::Lit::Str(value) = &item.lit {
					Some(value)
				} else {
					None
				};
				let value = value.ok_or_else(|| {
					syn::Error::new_spanned(&item, "value for attribute \"value\" must be a string")
				})?;
				input_value = Some(value);
			}
			_ => {}
		}
	}
	let input_value = input_value.ok_or_else(|| {
		syn::Error::new_spanned(&list.nested, "an attribute with key \"value\" is required")
	})?;
	let input_value = input_value.value();
	Ok(Some(input_value))
}