Skip to main content

oar_ocr_derive/
lib.rs

//! Procedural derive macros for oar-ocr.
//!
//! This crate provides derive macros to reduce boilerplate in the oar-ocr library.
4
5use darling::{FromDeriveInput, FromField, FromMeta, ast};
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{DeriveInput, Expr, Type, parse_macro_input};
9
/// Parsed arguments shared by the range validators.
///
/// Both bounds are written as name-value pairs:
/// `range(min = expr, max = expr)` or `optional_range(min = expr, max = expr)`.
#[derive(Debug, FromMeta)]
struct RangeArgs {
    /// Lower bound expression; the generated check treats it as inclusive.
    min: Expr,
    /// Upper bound expression; the generated check treats it as inclusive.
    max: Expr,
}
16
/// All supported validators that can be applied to a field.
///
/// Every validator is optional; a field may carry several of them, in which
/// case the generator emits one check per validator that is present.
#[derive(Debug, Default, FromMeta)]
struct Validators {
    /// `#[validate(range(min = expr, max = expr))]` - value must be in [min, max]
    #[darling(default)]
    range: Option<RangeArgs>,

    /// `#[validate(min = expr)]` - value must be >= expr
    #[darling(default)]
    min: Option<Expr>,

    /// `#[validate(max = expr)]` - value must be <= expr
    #[darling(default)]
    max: Option<Expr>,

    /// `#[validate(optional_range(min = expr, max = expr))]` - for Option<T> fields;
    /// the bounds are only checked when the field is `Some`
    #[darling(default)]
    optional_range: Option<RangeArgs>,

    /// `#[validate(path)]` - the field is passed to `self.validate_model_path`
    /// in the generated code
    #[darling(default)]
    path: bool,

    /// `#[validate(optional_path)]` - like `path`, but for `Option<...>` fields;
    /// `validate_model_path` is only called when the field is `Some`
    #[darling(default)]
    optional_path: bool,
}
44
/// A single field with its validation rules.
///
/// Populated by darling from one struct field and its `#[validate(...)]` attributes.
#[derive(Debug, FromField)]
#[darling(attributes(validate))]
struct ValidatedField {
    /// Field name; fields without a name are skipped by the code generator.
    ident: Option<syn::Ident>,
    // The field type is captured but not read by the code generator in this
    // file, hence the lint suppression.
    #[allow(dead_code)]
    ty: Type,
    /// Validators parsed directly from the contents of `#[validate(...)]`.
    #[darling(flatten)]
    validators: Validators,
}
55
/// The input struct for ConfigValidator derive.
///
/// `supports(struct_named)` restricts the derive to structs with named fields;
/// darling reports an error for any other shape before code generation runs.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(validate), supports(struct_named))]
struct ConfigValidatorInput {
    /// Name of the struct the `ConfigValidator` impl is generated for.
    ident: syn::Ident,
    /// Parsed struct body; the enum-variant slot is `()` because enums are not supported.
    data: ast::Data<(), ValidatedField>,
}
63
/// Builder attribute: `#[builder(config = ConfigType)]`
///
/// `config` has no default, so the attribute is required on every struct that
/// derives `TaskPredictorBuilder`.
#[derive(Debug, FromMeta)]
struct BuilderAttr {
    /// Path of the config type; emitted as the trait's associated `Config` type.
    config: syn::Path,
}
69
/// A field in the builder struct.
///
/// Only the name and type are captured; they are used to locate and check the
/// mandatory `state` field.
#[derive(Debug, FromField)]
struct BuilderField {
    /// Field name, compared against `state`.
    ident: Option<syn::Ident>,
    /// Field type, checked to be `PredictorBuilderState<...>` for the `state` field.
    ty: Type,
}
76
/// The input struct for TaskPredictorBuilder derive.
///
/// `supports(struct_named)` restricts the derive to structs with named fields.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(builder), supports(struct_named))]
struct TaskPredictorBuilderInput {
    /// Name of the builder struct the impls are generated for.
    ident: syn::Ident,
    /// Parsed struct body; the enum-variant slot is `()` because enums are not supported.
    data: ast::Data<(), BuilderField>,
    /// Contents of the struct-level `#[builder(...)]` attribute.
    #[darling(flatten)]
    builder: BuilderAttr,
}
86
87/// Derive macro for implementing ConfigValidator trait.
88///
89/// This macro generates a `ConfigValidator` implementation for configuration structs.
90/// Validation rules are specified using the `#[validate(...)]` attribute on fields.
91///
92/// # Supported Validators
93///
94/// - `#[validate(range(min = value, max = value))]` - Validates that the field value is within [min, max]
95/// - `#[validate(min = value)]` - Validates that the field value is at least `value`
96/// - `#[validate(max = value)]` - Validates that the field value is at most `value`
97/// - `#[validate(optional_range(min = value, max = value))]` - Like `range`, but for `Option<T>` fields
98/// - `#[validate(path)]` - Validates that the path exists (for PathBuf fields)
99/// - `#[validate(optional_path)]` - Like `path`, but for `Option<PathBuf>` fields
100///
101/// # Example
102///
103/// ```rust,ignore
104/// use oar_ocr_derive::ConfigValidator;
105///
106/// #[derive(ConfigValidator, Default)]
107/// pub struct TextDetectionConfig {
108///     #[validate(range(min = 0.0, max = 1.0))]
109///     pub score_threshold: f32,
110///
111///     #[validate(range(min = 0.0, max = 1.0))]
112///     pub box_threshold: f32,
113///
114///     #[validate(min = 0.0)]
115///     pub unclip_ratio: f32,
116///
117///     #[validate(min = 1)]
118///     pub max_candidates: usize,
119///
120///     // Fields without #[validate] are not validated
121///     pub limit_side_len: Option<u32>,
122/// }
123/// ```
124#[proc_macro_derive(ConfigValidator, attributes(validate))]
125pub fn derive_config_validator(input: TokenStream) -> TokenStream {
126    let input = parse_macro_input!(input as DeriveInput);
127
128    ConfigValidatorInput::from_derive_input(&input)
129        .map(|parsed| generate_config_validator(&parsed))
130        .unwrap_or_else(|err| err.write_errors())
131        .into()
132}
133
134fn generate_config_validator(input: &ConfigValidatorInput) -> proc_macro2::TokenStream {
135    let name = &input.ident;
136
137    let fields = input
138        .data
139        .as_ref()
140        .take_struct()
141        .expect("Only structs are supported");
142
143    let validations: Vec<_> = fields
144        .iter()
145        .filter_map(|field| generate_field_validation(field))
146        .collect();
147
148    quote! {
149        impl crate::core::config::ConfigValidator for #name {
150            fn validate(&self) -> Result<(), crate::core::config::ConfigError> {
151                #(#validations)*
152                Ok(())
153            }
154
155            fn get_defaults() -> Self
156            where
157                Self: Sized,
158            {
159                Self::default()
160            }
161        }
162    }
163}
164
165fn generate_field_validation(field: &ValidatedField) -> Option<proc_macro2::TokenStream> {
166    let field_name = field.ident.as_ref()?;
167    let field_name_str = field_name.to_string();
168    let validators = &field.validators;
169
170    let mut validations = Vec::new();
171
172    // Range validation
173    if let Some(range) = &validators.range {
174        let min_expr = &range.min;
175        let max_expr = &range.max;
176        validations.push(quote! {
177            if !(#min_expr..=#max_expr).contains(&self.#field_name) {
178                return Err(crate::core::config::ConfigError::InvalidConfig {
179                    message: format!(
180                        "{} must be between {} and {}",
181                        #field_name_str,
182                        #min_expr,
183                        #max_expr
184                    ),
185                });
186            }
187        });
188    }
189
190    // Min validation
191    if let Some(min_expr) = &validators.min {
192        validations.push(quote! {
193            if self.#field_name < #min_expr {
194                return Err(crate::core::config::ConfigError::InvalidConfig {
195                    message: format!("{} must be at least {}", #field_name_str, #min_expr),
196                });
197            }
198        });
199    }
200
201    // Max validation
202    if let Some(max_expr) = &validators.max {
203        validations.push(quote! {
204            if self.#field_name > #max_expr {
205                return Err(crate::core::config::ConfigError::InvalidConfig {
206                    message: format!("{} must be at most {}", #field_name_str, #max_expr),
207                });
208            }
209        });
210    }
211
212    // Optional range validation
213    if let Some(range) = &validators.optional_range {
214        let min_expr = &range.min;
215        let max_expr = &range.max;
216        validations.push(quote! {
217            if let Some(value) = self.#field_name {
218                if !(#min_expr..=#max_expr).contains(&value) {
219                    return Err(crate::core::config::ConfigError::InvalidConfig {
220                        message: format!(
221                            "{} must be between {} and {}",
222                            #field_name_str,
223                            #min_expr,
224                            #max_expr
225                        ),
226                    });
227                }
228            }
229        });
230    }
231
232    // Path validation
233    if validators.path {
234        validations.push(quote! {
235            self.validate_model_path(&self.#field_name)?;
236        });
237    }
238
239    // Optional path validation
240    if validators.optional_path {
241        validations.push(quote! {
242            if let Some(ref path) = self.#field_name {
243                self.validate_model_path(path)?;
244            }
245        });
246    }
247
248    if validations.is_empty() {
249        None
250    } else {
251        Some(quote! { #(#validations)* })
252    }
253}
254
255/// Derive macro for implementing TaskPredictorBuilder trait.
256///
257/// This macro generates the `TaskPredictorBuilder` trait implementation and
258/// common builder methods (`with_config`, `with_ort_config`).
259///
260/// # Requirements
261///
262/// - The struct must have a field named `state` of type `PredictorBuilderState<Config>`
263/// - The config type must be specified using `#[builder(config = ConfigType)]`
264///
265/// # Example
266///
267/// ```rust,ignore
268/// use oar_ocr_derive::TaskPredictorBuilder;
269/// use oar_ocr::predictors::builder::PredictorBuilderState;
270///
271/// #[derive(TaskPredictorBuilder)]
272/// #[builder(config = TextDetectionConfig)]
273/// pub struct TextDetectionPredictorBuilder {
274///     state: PredictorBuilderState<TextDetectionConfig>,
275/// }
276/// ```
277#[proc_macro_derive(TaskPredictorBuilder, attributes(builder))]
278pub fn derive_task_predictor_builder(input: TokenStream) -> TokenStream {
279    let input = parse_macro_input!(input as DeriveInput);
280
281    TaskPredictorBuilderInput::from_derive_input(&input)
282        .and_then(|parsed| generate_task_predictor_builder(&parsed))
283        .unwrap_or_else(|err| err.write_errors())
284        .into()
285}
286
287fn generate_task_predictor_builder(
288    input: &TaskPredictorBuilderInput,
289) -> darling::Result<proc_macro2::TokenStream> {
290    let name = &input.ident;
291    let config_type = &input.builder.config;
292
293    // Verify the struct has a `state` field with correct type
294    verify_state_field(input)?;
295
296    Ok(quote! {
297        impl crate::predictors::builder::TaskPredictorBuilder for #name {
298            type Config = #config_type;
299
300            fn state_mut(
301                &mut self,
302            ) -> &mut crate::predictors::builder::PredictorBuilderState<Self::Config> {
303                &mut self.state
304            }
305        }
306
307        impl #name {
308            /// Replace the full task configuration used by this builder.
309            pub fn with_config(self, config: #config_type) -> Self {
310                <Self as crate::predictors::builder::TaskPredictorBuilder>::with_config(
311                    self, config,
312                )
313            }
314
315            /// Configure ONNX Runtime session options.
316            pub fn with_ort_config(self, config: crate::core::config::OrtSessionConfig) -> Self {
317                <Self as crate::predictors::builder::TaskPredictorBuilder>::with_ort_config(
318                    self, config,
319                )
320            }
321        }
322    })
323}
324
325fn verify_state_field(input: &TaskPredictorBuilderInput) -> darling::Result<()> {
326    let fields = input
327        .data
328        .as_ref()
329        .take_struct()
330        .expect("Only structs are supported");
331
332    let state_field = fields
333        .iter()
334        .find(|f| f.ident.as_ref().is_some_and(|ident| ident == "state"));
335
336    let state_field = match state_field {
337        Some(field) => field,
338        None => {
339            return Err(darling::Error::custom(
340                "Struct must have a `state` field of type PredictorBuilderState<Config>",
341            ));
342        }
343    };
344
345    // Verify the type is PredictorBuilderState<...>
346    if !is_predictor_builder_state_type(&state_field.ty) {
347        return Err(darling::Error::custom(
348            "Field `state` must be of type PredictorBuilderState<Config>",
349        )
350        .with_span(&state_field.ty));
351    }
352
353    Ok(())
354}
355
356fn is_predictor_builder_state_type(ty: &Type) -> bool {
357    let Type::Path(type_path) = ty else {
358        return false;
359    };
360
361    let Some(last_segment) = type_path.path.segments.last() else {
362        return false;
363    };
364
365    if last_segment.ident != "PredictorBuilderState" {
366        return false;
367    }
368
369    // Verify it has exactly one generic argument
370    matches!(
371        &last_segment.arguments,
372        syn::PathArguments::AngleBracketed(args) if args.args.len() == 1
373    )
374}