oar_ocr_derive/
lib.rs

1//! Procedural derive macros for oar-ocr.
2//!
3//! This crate provides derive macros to reduce boilerplate in the oar-ocr library.
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::{DeriveInput, Expr, Field, Meta, Type, parse_macro_input};
8
9/// Derive macro for implementing ConfigValidator trait.
10///
11/// This macro generates a `ConfigValidator` implementation for configuration structs.
12/// Validation rules are specified using the `#[validate(...)]` attribute on fields.
13///
14/// # Supported Validators
15///
16/// - `#[validate(range(min, max))]` - Validates that the field value is within the inclusive range [min, max]
17/// - `#[validate(min(value))]` - Validates that the field value is at least `value`
18/// - `#[validate(max(value))]` - Validates that the field value is at most `value`
19/// - `#[validate(optional_range(min, max))]` - Like `range`, but for `Option<T>` fields (only validates if Some)
20/// - `#[validate(path)]` - Validates that the path exists (for PathBuf fields)
21/// - `#[validate(optional_path)]` - Like `path`, but for `Option<PathBuf>` fields
22///
23/// # Example
24///
25/// ```rust,ignore
26/// use oar_ocr_derive::ConfigValidator;
27///
28/// #[derive(ConfigValidator, Default)]
29/// pub struct TextDetectionConfig {
30///     #[validate(range(0.0, 1.0))]
31///     pub score_threshold: f32,
32///
33///     #[validate(range(0.0, 1.0))]
34///     pub box_threshold: f32,
35///
36///     #[validate(min(0.0))]
37///     pub unclip_ratio: f32,
38///
39///     #[validate(min(1))]
40///     pub max_candidates: usize,
41///
42///     // Fields without #[validate] are not validated
43///     pub limit_side_len: Option<u32>,
44/// }
45/// ```
46#[proc_macro_derive(ConfigValidator, attributes(validate))]
47pub fn derive_config_validator(input: TokenStream) -> TokenStream {
48    let input = parse_macro_input!(input as DeriveInput);
49    impl_config_validator(&input)
50        .unwrap_or_else(|err| err.to_compile_error())
51        .into()
52}
53
54fn impl_config_validator(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
55    let name = &input.ident;
56
57    let fields = match &input.data {
58        syn::Data::Struct(data) => match &data.fields {
59            syn::Fields::Named(fields) => &fields.named,
60            _ => {
61                return Err(syn::Error::new_spanned(
62                    input,
63                    "ConfigValidator can only be derived for structs with named fields",
64                ));
65            }
66        },
67        _ => {
68            return Err(syn::Error::new_spanned(
69                input,
70                "ConfigValidator can only be derived for structs",
71            ));
72        }
73    };
74
75    let validations = fields
76        .iter()
77        .filter_map(|field| generate_field_validation(field).transpose())
78        .collect::<syn::Result<Vec<_>>>()?;
79
80    Ok(quote! {
81        impl crate::core::config::ConfigValidator for #name {
82            fn validate(&self) -> Result<(), crate::core::config::ConfigError> {
83                #(#validations)*
84                Ok(())
85            }
86
87            fn get_defaults() -> Self
88            where
89                Self: Sized,
90            {
91                Self::default()
92            }
93        }
94    })
95}
96
97fn generate_field_validation(field: &Field) -> syn::Result<Option<proc_macro2::TokenStream>> {
98    let field_name = field
99        .ident
100        .as_ref()
101        .ok_or_else(|| syn::Error::new_spanned(field, "Expected named field"))?;
102
103    let field_name_str = field_name.to_string();
104
105    let mut validations = Vec::new();
106
107    for attr in &field.attrs {
108        if !attr.path().is_ident("validate") {
109            continue;
110        }
111
112        let meta = attr.parse_args::<Meta>()?;
113        validations.push(generate_validation_code(
114            field_name,
115            &field_name_str,
116            &meta,
117            &field.ty,
118        )?);
119    }
120
121    if validations.is_empty() {
122        Ok(None)
123    } else {
124        Ok(Some(quote! { #(#validations)* }))
125    }
126}
127
128fn generate_validation_code(
129    field_name: &syn::Ident,
130    field_name_str: &str,
131    meta: &Meta,
132    _field_ty: &Type,
133) -> syn::Result<proc_macro2::TokenStream> {
134    match meta {
135        Meta::List(list) => {
136            let validator_name = list
137                .path
138                .get_ident()
139                .ok_or_else(|| syn::Error::new_spanned(&list.path, "Expected validator name"))?;
140            let validator_str = validator_name.to_string();
141
142            match validator_str.as_str() {
143                "range" => {
144                    let args = parse_two_args(&list.tokens)?;
145                    let (min_expr, max_expr) = args;
146                    Ok(quote! {
147                        if !(#min_expr..=#max_expr).contains(&self.#field_name) {
148                            return Err(crate::core::config::ConfigError::InvalidConfig {
149                                message: format!(
150                                    "{} must be between {} and {}",
151                                    #field_name_str,
152                                    #min_expr,
153                                    #max_expr
154                                ),
155                            });
156                        }
157                    })
158                }
159                "min" => {
160                    let min_expr = parse_one_arg(&list.tokens)?;
161                    Ok(quote! {
162                        if self.#field_name < #min_expr {
163                            return Err(crate::core::config::ConfigError::InvalidConfig {
164                                message: format!("{} must be at least {}", #field_name_str, #min_expr),
165                            });
166                        }
167                    })
168                }
169                "max" => {
170                    let max_expr = parse_one_arg(&list.tokens)?;
171                    Ok(quote! {
172                        if self.#field_name > #max_expr {
173                            return Err(crate::core::config::ConfigError::InvalidConfig {
174                                message: format!("{} must be at most {}", #field_name_str, #max_expr),
175                            });
176                        }
177                    })
178                }
179                "optional_range" => {
180                    let args = parse_two_args(&list.tokens)?;
181                    let (min_expr, max_expr) = args;
182                    Ok(quote! {
183                        if let Some(value) = self.#field_name {
184                            if !(#min_expr..=#max_expr).contains(&value) {
185                                return Err(crate::core::config::ConfigError::InvalidConfig {
186                                    message: format!(
187                                        "{} must be between {} and {}",
188                                        #field_name_str,
189                                        #min_expr,
190                                        #max_expr
191                                    ),
192                                });
193                            }
194                        }
195                    })
196                }
197                "path" => Ok(generate_path_validation(field_name, false)),
198                "optional_path" => Ok(generate_path_validation(field_name, true)),
199                other => Err(syn::Error::new_spanned(
200                    validator_name,
201                    format!("Unknown validator: {}", other),
202                )),
203            }
204        }
205        Meta::Path(path) => {
206            let validator_name = path
207                .get_ident()
208                .ok_or_else(|| syn::Error::new_spanned(path, "Expected validator name"))?;
209            let validator_str = validator_name.to_string();
210
211            match validator_str.as_str() {
212                "path" => Ok(generate_path_validation(field_name, false)),
213                "optional_path" => Ok(generate_path_validation(field_name, true)),
214                other => Err(syn::Error::new_spanned(
215                    validator_name,
216                    format!("Unknown validator without arguments: {}", other),
217                )),
218            }
219        }
220        _ => Err(syn::Error::new_spanned(meta, "Invalid validator format")),
221    }
222}
223
224fn generate_path_validation(field_name: &syn::Ident, optional: bool) -> proc_macro2::TokenStream {
225    if optional {
226        quote! {
227            if let Some(ref path) = self.#field_name {
228                self.validate_model_path(path)?;
229            }
230        }
231    } else {
232        quote! {
233            self.validate_model_path(&self.#field_name)?;
234        }
235    }
236}
237
238fn parse_one_arg(tokens: &proc_macro2::TokenStream) -> syn::Result<Expr> {
239    syn::parse2(tokens.clone())
240}
241
242fn parse_two_args(tokens: &proc_macro2::TokenStream) -> syn::Result<(Expr, Expr)> {
243    use syn::Token;
244    use syn::parse::Parser;
245    use syn::punctuated::Punctuated;
246
247    let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
248    let args = parser.parse2(tokens.clone())?;
249    let mut iter = args.into_iter();
250
251    let first = iter
252        .next()
253        .ok_or_else(|| syn::Error::new_spanned(tokens, "Expected two arguments"))?;
254    let second = iter
255        .next()
256        .ok_or_else(|| syn::Error::new_spanned(tokens, "Expected two arguments"))?;
257
258    if iter.next().is_some() {
259        return Err(syn::Error::new_spanned(
260            tokens,
261            "Expected exactly two arguments",
262        ));
263    }
264
265    Ok((first, second))
266}
267
268/// Derive macro for implementing TaskPredictorBuilder trait.
269///
270/// This macro generates the `TaskPredictorBuilder` trait implementation and
271/// common builder methods (`with_config`, `with_ort_config`).
272///
273/// # Requirements
274///
275/// - The struct must have a field named `state` of type `PredictorBuilderState<Config>`
276/// - The config type must be specified using `#[builder(config = ConfigType)]`
277///
278/// # Example
279///
280/// ```rust,ignore
281/// use oar_ocr_derive::TaskPredictorBuilder;
282/// use oar_ocr::predictors::builder::PredictorBuilderState;
283///
284/// #[derive(TaskPredictorBuilder)]
285/// #[builder(config = TextDetectionConfig)]
286/// pub struct TextDetectionPredictorBuilder {
287///     state: PredictorBuilderState<TextDetectionConfig>,
288/// }
289/// ```
290#[proc_macro_derive(TaskPredictorBuilder, attributes(builder))]
291pub fn derive_task_predictor_builder(input: TokenStream) -> TokenStream {
292    let input = parse_macro_input!(input as DeriveInput);
293    impl_task_predictor_builder(&input)
294        .unwrap_or_else(|err| err.to_compile_error())
295        .into()
296}
297
298fn impl_task_predictor_builder(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
299    let name = &input.ident;
300
301    // Find the #[builder(config = Type)] attribute
302    let config_type = find_builder_config_type(input)?;
303
304    // Verify the struct has a `state` field
305    verify_state_field(input)?;
306
307    Ok(quote! {
308        impl crate::predictors::builder::TaskPredictorBuilder for #name {
309            type Config = #config_type;
310
311            fn state_mut(
312                &mut self,
313            ) -> &mut crate::predictors::builder::PredictorBuilderState<Self::Config> {
314                &mut self.state
315            }
316        }
317
318        impl #name {
319            /// Replace the full task configuration used by this builder.
320            pub fn with_config(self, config: #config_type) -> Self {
321                <Self as crate::predictors::builder::TaskPredictorBuilder>::with_config(
322                    self, config,
323                )
324            }
325
326            /// Configure ONNX Runtime session options.
327            pub fn with_ort_config(self, config: crate::core::config::OrtSessionConfig) -> Self {
328                <Self as crate::predictors::builder::TaskPredictorBuilder>::with_ort_config(
329                    self, config,
330                )
331            }
332        }
333    })
334}
335
336fn find_builder_config_type(input: &DeriveInput) -> syn::Result<Type> {
337    for attr in &input.attrs {
338        if !attr.path().is_ident("builder") {
339            continue;
340        }
341
342        let meta = attr.parse_args::<Meta>()?;
343
344        if let Meta::NameValue(nv) = meta
345            && nv.path.is_ident("config")
346        {
347            if let Expr::Path(expr_path) = nv.value {
348                return Ok(Type::Path(syn::TypePath {
349                    qself: None,
350                    path: expr_path.path,
351                }));
352            } else {
353                return Err(syn::Error::new_spanned(
354                    nv.value,
355                    "Expected a type path, e.g., #[builder(config = MyConfigType)]",
356                ));
357            }
358        }
359    }
360
361    Err(syn::Error::new_spanned(
362        input,
363        "Missing #[builder(config = ConfigType)] attribute",
364    ))
365}
366
367fn verify_state_field(input: &DeriveInput) -> syn::Result<()> {
368    let fields = match &input.data {
369        syn::Data::Struct(data) => match &data.fields {
370            syn::Fields::Named(fields) => &fields.named,
371            _ => {
372                return Err(syn::Error::new_spanned(
373                    input,
374                    "TaskPredictorBuilder can only be derived for structs with named fields",
375                ));
376            }
377        },
378        _ => {
379            return Err(syn::Error::new_spanned(
380                input,
381                "TaskPredictorBuilder can only be derived for structs",
382            ));
383        }
384    };
385
386    let state_field = fields
387        .iter()
388        .find(|f| f.ident.as_ref().is_some_and(|ident| ident == "state"));
389
390    let state_field = match state_field {
391        Some(field) => field,
392        None => {
393            return Err(syn::Error::new_spanned(
394                input,
395                "Struct must have a `state` field of type PredictorBuilderState<Config>",
396            ));
397        }
398    };
399
400    // Verify the type is PredictorBuilderState<...>
401    if !is_predictor_builder_state_type(&state_field.ty) {
402        return Err(syn::Error::new_spanned(
403            &state_field.ty,
404            "Field `state` must be of type PredictorBuilderState<Config>",
405        ));
406    }
407
408    Ok(())
409}
410
411fn is_predictor_builder_state_type(ty: &Type) -> bool {
412    let Type::Path(type_path) = ty else {
413        return false;
414    };
415
416    let last_segment = match type_path.path.segments.last() {
417        Some(seg) => seg,
418        None => return false,
419    };
420
421    if last_segment.ident != "PredictorBuilderState" {
422        return false;
423    }
424
425    // Verify it has exactly one generic argument
426    matches!(
427        &last_segment.arguments,
428        syn::PathArguments::AngleBracketed(args) if args.args.len() == 1
429    )
430}