use darling::{FromDeriveInput, FromField, FromMeta, ast};
use proc_macro::TokenStream;
use quote::quote;
use syn::{DeriveInput, Expr, Type, parse_macro_input};
/// Bounds supplied to the `range` and `optional_range` validators.
///
/// Both bounds are kept as unevaluated expressions and are spliced verbatim
/// into the generated code, where they form an inclusive `min..=max` range.
#[derive(Debug, FromMeta)]
struct RangeArgs {
/// Lower bound expression (inclusive in the generated check).
min: Expr,
/// Upper bound expression (inclusive in the generated check).
max: Expr,
}
/// The set of validators that may be requested for a single field.
///
/// Every entry is optional (`#[darling(default)]`); a field with no entries
/// set produces no validation code at all.
#[derive(Debug, Default, FromMeta)]
struct Validators {
/// Require the field value to lie within the inclusive `min..=max` range.
#[darling(default)]
range: Option<RangeArgs>,
/// Require the field value to be no smaller than this expression.
#[darling(default)]
min: Option<Expr>,
/// Require the field value to be no larger than this expression.
#[darling(default)]
max: Option<Expr>,
/// Like `range`, but for a field holding an `Option`: the check is only
/// performed when the field is `Some`.
#[darling(default)]
optional_range: Option<RangeArgs>,
/// Pass a reference to the field to `self.validate_model_path(..)`.
#[darling(default)]
path: bool,
/// Like `path`, but for a field holding an `Option`: the path is only
/// validated when the field is `Some`.
#[darling(default)]
optional_path: bool,
}
/// One named field of a struct deriving `ConfigValidator`, together with the
/// validators read from its `#[validate(...)]` attribute.
#[derive(Debug, FromField)]
#[darling(attributes(validate))]
struct ValidatedField {
/// Field name; `None` would indicate an unnamed (tuple) field.
ident: Option<syn::Ident>,
// The field type is captured by darling but is not read by any code visible
// in this file, hence the lint suppression.
#[allow(dead_code)]
ty: Type,
/// Validators parsed directly from the attribute's top-level keys.
#[darling(flatten)]
validators: Validators,
}
/// Parsed input of `#[derive(ConfigValidator)]`.
///
/// Only structs with named fields are accepted (`supports(struct_named)`);
/// darling rejects any other shape before code generation runs.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(validate), supports(struct_named))]
struct ConfigValidatorInput {
/// Name of the struct the trait is implemented for.
ident: syn::Ident,
/// The struct body; the enum-variant slot is `()` because enums are not supported.
data: ast::Data<(), ValidatedField>,
}
/// Struct-level options of `#[derive(TaskPredictorBuilder)]`, read from the
/// `builder` attribute.
#[derive(Debug, FromMeta)]
struct BuilderAttr {
/// Path of the configuration type; emitted as the `Config` associated type
/// and as the parameter type of the generated `with_config` method.
config: syn::Path,
}
/// One field of a struct deriving `TaskPredictorBuilder`.
///
/// Only the name and type are captured; they are used to locate and check the
/// required `state` field.
#[derive(Debug, FromField)]
struct BuilderField {
/// Field name; `None` would indicate an unnamed (tuple) field.
ident: Option<syn::Ident>,
/// Declared type of the field.
ty: Type,
}
/// Parsed input of `#[derive(TaskPredictorBuilder)]`.
///
/// Only structs with named fields are accepted (`supports(struct_named)`).
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(builder), supports(struct_named))]
struct TaskPredictorBuilderInput {
/// Name of the builder struct the trait is implemented for.
ident: syn::Ident,
/// The struct body; the enum-variant slot is `()` because enums are not supported.
data: ast::Data<(), BuilderField>,
/// Options taken from the top-level keys of the `builder` attribute.
#[darling(flatten)]
builder: BuilderAttr,
}
/// Entry point of `#[derive(ConfigValidator)]`.
///
/// Parses the annotated item and emits a `ConfigValidator` implementation for
/// it. If darling rejects the input, the collected errors are emitted as
/// compile errors instead.
#[proc_macro_derive(ConfigValidator, attributes(validate))]
pub fn derive_config_validator(input: TokenStream) -> TokenStream {
    let derive_input = parse_macro_input!(input as DeriveInput);

    let expanded = match ConfigValidatorInput::from_derive_input(&derive_input) {
        Ok(parsed) => generate_config_validator(&parsed),
        Err(err) => err.write_errors(),
    };

    expanded.into()
}
/// Generates the `ConfigValidator` implementation for a parsed struct.
///
/// The emitted `validate` method runs the checks of every annotated field in
/// declaration order and returns `Ok(())` when none of them fails. The emitted
/// `get_defaults` returns `Self::default()`, so the target type must implement
/// `Default`. All paths are emitted as `crate::…`, so they resolve relative to
/// the crate in which the derive is used.
///
/// # Panics
///
/// Panics if the input is not a struct. darling's `supports(struct_named)`
/// check rejects such input before this function is reached.
fn generate_config_validator(input: &ConfigValidatorInput) -> proc_macro2::TokenStream {
    let struct_fields = input
        .data
        .as_ref()
        .take_struct()
        .expect("Only structs are supported");

    // Fields without any validator contribute nothing to `validate`.
    let mut field_checks = Vec::new();
    for field in struct_fields.iter() {
        if let Some(check) = generate_field_validation(field) {
            field_checks.push(check);
        }
    }

    let target = &input.ident;
    quote! {
        impl crate::core::config::ConfigValidator for #target {
            fn validate(&self) -> Result<(), crate::core::config::ConfigError> {
                #(#field_checks)*
                Ok(())
            }
            fn get_defaults() -> Self
            where
                Self: Sized,
            {
                Self::default()
            }
        }
    }
}
/// Builds the validation statements for a single field.
///
/// Returns `None` when the field is unnamed or has no validator set, so the
/// caller can skip it. Otherwise returns the concatenation of one check per
/// requested validator, emitted in this fixed order: `range`, `min`, `max`,
/// `optional_range`, `path`, `optional_path`.
///
/// Range and bound checks make the generated `validate` return early with
/// `ConfigError::InvalidConfig`, with a message naming the field and the
/// offending bound(s). Path checks propagate whatever error
/// `self.validate_model_path(..)` returns via `?`.
///
/// The bound expressions are spliced into the output verbatim, both in the
/// comparison and in the error message.
fn generate_field_validation(field: &ValidatedField) -> Option<proc_macro2::TokenStream> {
    let field_name = field.ident.as_ref()?;
    let field_name_str = field_name.to_string();
    let validators = &field.validators;
    let mut validations = Vec::new();

    if let Some(range) = &validators.range {
        let min_expr = &range.min;
        let max_expr = &range.max;
        validations.push(quote! {
            if !(#min_expr..=#max_expr).contains(&self.#field_name) {
                return Err(crate::core::config::ConfigError::InvalidConfig {
                    message: format!(
                        "{} must be between {} and {}",
                        #field_name_str,
                        #min_expr,
                        #max_expr
                    ),
                });
            }
        });
    }

    if let Some(min_expr) = &validators.min {
        validations.push(quote! {
            if self.#field_name < #min_expr {
                return Err(crate::core::config::ConfigError::InvalidConfig {
                    message: format!("{} must be at least {}", #field_name_str, #min_expr),
                });
            }
        });
    }

    if let Some(max_expr) = &validators.max {
        validations.push(quote! {
            if self.#field_name > #max_expr {
                return Err(crate::core::config::ConfigError::InvalidConfig {
                    message: format!("{} must be at most {}", #field_name_str, #max_expr),
                });
            }
        });
    }

    if let Some(range) = &validators.optional_range {
        let min_expr = &range.min;
        let max_expr = &range.max;
        // The generated code matches on a borrow of the field rather than on
        // the field itself: `validate` only has `&self`, so binding the payload
        // by value would try to move it out and fail to compile for any
        // non-`Copy` payload. This mirrors how `optional_path` borrows below.
        validations.push(quote! {
            if let Some(value) = &self.#field_name {
                if !(#min_expr..=#max_expr).contains(value) {
                    return Err(crate::core::config::ConfigError::InvalidConfig {
                        message: format!(
                            "{} must be between {} and {}",
                            #field_name_str,
                            #min_expr,
                            #max_expr
                        ),
                    });
                }
            }
        });
    }

    if validators.path {
        validations.push(quote! {
            self.validate_model_path(&self.#field_name)?;
        });
    }

    if validators.optional_path {
        validations.push(quote! {
            if let Some(ref path) = self.#field_name {
                self.validate_model_path(path)?;
            }
        });
    }

    if validations.is_empty() {
        None
    } else {
        Some(quote! { #(#validations)* })
    }
}
/// Entry point of `#[derive(TaskPredictorBuilder)]`.
///
/// Parses the annotated item and emits the `TaskPredictorBuilder`
/// implementation plus inherent forwarding methods. Parsing and verification
/// failures are emitted as compile errors.
#[proc_macro_derive(TaskPredictorBuilder, attributes(builder))]
pub fn derive_task_predictor_builder(input: TokenStream) -> TokenStream {
    let derive_input = parse_macro_input!(input as DeriveInput);

    let expansion = TaskPredictorBuilderInput::from_derive_input(&derive_input)
        .and_then(|parsed| generate_task_predictor_builder(&parsed));

    match expansion {
        Ok(tokens) => tokens.into(),
        Err(err) => err.write_errors().into(),
    }
}
/// Generates the `TaskPredictorBuilder` implementation for a builder struct.
///
/// The output consists of:
/// - a trait impl whose `Config` is the type named by the `config` option and
///   whose `state_mut` returns `&mut self.state`;
/// - inherent `with_config` / `with_ort_config` methods that forward to the
///   trait methods of the same name.
///
/// # Errors
///
/// Returns the error produced by `verify_state_field` when the struct has no
/// suitable `state` field.
fn generate_task_predictor_builder(
    input: &TaskPredictorBuilderInput,
) -> darling::Result<proc_macro2::TokenStream> {
    // Reject unsuitable structs before building any tokens.
    verify_state_field(input)?;

    let builder_ident = &input.ident;
    let config_ty = &input.builder.config;

    let expanded = quote! {
        impl crate::predictors::builder::TaskPredictorBuilder for #builder_ident {
            type Config = #config_ty;
            fn state_mut(
                &mut self,
            ) -> &mut crate::predictors::builder::PredictorBuilderState<Self::Config> {
                &mut self.state
            }
        }
        impl #builder_ident {
            pub fn with_config(self, config: #config_ty) -> Self {
                <Self as crate::predictors::builder::TaskPredictorBuilder>::with_config(
                    self, config,
                )
            }
            pub fn with_ort_config(self, config: crate::core::config::OrtSessionConfig) -> Self {
                <Self as crate::predictors::builder::TaskPredictorBuilder>::with_ort_config(
                    self, config,
                )
            }
        }
    };

    Ok(expanded)
}
fn verify_state_field(input: &TaskPredictorBuilderInput) -> darling::Result<()> {
let fields = input
.data
.as_ref()
.take_struct()
.expect("Only structs are supported");
let state_field = fields
.iter()
.find(|f| f.ident.as_ref().is_some_and(|ident| ident == "state"));
let state_field = match state_field {
Some(field) => field,
None => {
return Err(darling::Error::custom(
"Struct must have a `state` field of type PredictorBuilderState<Config>",
));
}
};
if !is_predictor_builder_state_type(&state_field.ty) {
return Err(darling::Error::custom(
"Field `state` must be of type PredictorBuilderState<Config>",
)
.with_span(&state_field.ty));
}
Ok(())
}
/// Returns `true` when `ty` is syntactically `PredictorBuilderState<X>`.
///
/// The check is purely syntactic: `ty` must be a path type whose final segment
/// is named `PredictorBuilderState` and carries exactly one angle-bracketed
/// argument. Leading path segments are ignored, so both a bare and a fully
/// qualified spelling are accepted.
fn is_predictor_builder_state_type(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => type_path.path.segments.last().is_some_and(|segment| {
            segment.ident == "PredictorBuilderState"
                && match &segment.arguments {
                    syn::PathArguments::AngleBracketed(generic) => generic.args.len() == 1,
                    _ => false,
                }
        }),
        _ => false,
    }
}