1use darling::{FromDeriveInput, FromField, FromMeta, ast};
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{DeriveInput, Expr, Type, parse_macro_input};
9
/// Arguments of a `range(min = .., max = ..)` or `optional_range(min = .., max = ..)` validator.
///
/// Both bounds are arbitrary expressions that are spliced verbatim into the generated code,
/// where they form the inclusive range `min..=max` and are also printed in the error message.
#[derive(Debug, FromMeta)]
struct RangeArgs {
    /// Inclusive lower bound.
    min: Expr,
    /// Inclusive upper bound.
    max: Expr,
}
16
/// Validators that may be attached to a field through `#[validate(...)]`.
///
/// Every option is optional and they may be combined; each enabled option contributes one
/// independent check to the generated `validate()` method.
#[derive(Debug, Default, FromMeta)]
struct Validators {
    /// `range(min = .., max = ..)`: the field's value must lie in `min..=max`.
    #[darling(default)]
    range: Option<RangeArgs>,

    /// `min = ..`: the field's value must not be less than this bound.
    #[darling(default)]
    min: Option<Expr>,

    /// `max = ..`: the field's value must not be greater than this bound.
    #[darling(default)]
    max: Option<Expr>,

    /// `optional_range(min = .., max = ..)`: for an `Option` field, a `Some` value must lie in
    /// `min..=max`; `None` is accepted.
    #[darling(default)]
    optional_range: Option<RangeArgs>,

    /// `path`: a reference to the field is passed to `self.validate_model_path`.
    #[darling(default)]
    path: bool,

    /// `optional_path`: if the field is `Some(path)`, `path` is passed to
    /// `self.validate_model_path`; `None` is accepted.
    #[darling(default)]
    optional_path: bool,
}
44
/// A field of a struct deriving `ConfigValidator`, together with its parsed `#[validate(...)]`
/// options.
#[derive(Debug, FromField)]
#[darling(attributes(validate))]
struct ValidatedField {
    /// Field name; `None` only for unnamed (tuple) fields, which `supports(struct_named)` on
    /// the derive input rules out.
    ident: Option<syn::Ident>,
    // Collected by darling but never read by the code generator, hence the `allow`.
    // NOTE(review): presumably kept for future type-aware checks — confirm or drop it.
    #[allow(dead_code)]
    ty: Type,
    /// Validator options parsed from the field's `#[validate(...)]` attribute.
    #[darling(flatten)]
    validators: Validators,
}
55
/// Parsed input of `#[derive(ConfigValidator)]`; only structs with named fields are accepted.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(validate), supports(struct_named))]
struct ConfigValidatorInput {
    /// Name of the struct the `ConfigValidator` impl is generated for.
    ident: syn::Ident,
    /// The struct's fields; enum inputs are rejected by `supports(struct_named)`.
    data: ast::Data<(), ValidatedField>,
}
63
/// Struct-level `#[builder(config = ...)]` options of `#[derive(TaskPredictorBuilder)]`.
#[derive(Debug, FromMeta)]
struct BuilderAttr {
    /// Path of the config type. It becomes `TaskPredictorBuilder::Config` and the parameter type
    /// of the generated inherent `with_config`.
    config: syn::Path,
}
69
/// A field of a struct deriving `TaskPredictorBuilder`. Only the name and type are inspected,
/// to locate the `state` field and check its type.
#[derive(Debug, FromField)]
struct BuilderField {
    /// Field name; `None` only for unnamed (tuple) fields.
    ident: Option<syn::Ident>,
    /// Declared field type, checked syntactically for the `state` field.
    ty: Type,
}
76
/// Parsed input of `#[derive(TaskPredictorBuilder)]`; only structs with named fields are
/// accepted.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(builder), supports(struct_named))]
struct TaskPredictorBuilderInput {
    /// Name of the builder struct.
    ident: syn::Ident,
    /// The builder's fields, searched for the required `state` field.
    data: ast::Data<(), BuilderField>,
    /// Options taken from the struct-level `#[builder(...)]` attribute.
    #[darling(flatten)]
    builder: BuilderAttr,
}
86
87#[proc_macro_derive(ConfigValidator, attributes(validate))]
125pub fn derive_config_validator(input: TokenStream) -> TokenStream {
126 let input = parse_macro_input!(input as DeriveInput);
127
128 ConfigValidatorInput::from_derive_input(&input)
129 .map(|parsed| generate_config_validator(&parsed))
130 .unwrap_or_else(|err| err.write_errors())
131 .into()
132}
133
134fn generate_config_validator(input: &ConfigValidatorInput) -> proc_macro2::TokenStream {
135 let name = &input.ident;
136
137 let fields = input
138 .data
139 .as_ref()
140 .take_struct()
141 .expect("Only structs are supported");
142
143 let validations: Vec<_> = fields
144 .iter()
145 .filter_map(|field| generate_field_validation(field))
146 .collect();
147
148 quote! {
149 impl crate::core::config::ConfigValidator for #name {
150 fn validate(&self) -> Result<(), crate::core::config::ConfigError> {
151 #(#validations)*
152 Ok(())
153 }
154
155 fn get_defaults() -> Self
156 where
157 Self: Sized,
158 {
159 Self::default()
160 }
161 }
162 }
163}
164
165fn generate_field_validation(field: &ValidatedField) -> Option<proc_macro2::TokenStream> {
166 let field_name = field.ident.as_ref()?;
167 let field_name_str = field_name.to_string();
168 let validators = &field.validators;
169
170 let mut validations = Vec::new();
171
172 if let Some(range) = &validators.range {
174 let min_expr = &range.min;
175 let max_expr = &range.max;
176 validations.push(quote! {
177 if !(#min_expr..=#max_expr).contains(&self.#field_name) {
178 return Err(crate::core::config::ConfigError::InvalidConfig {
179 message: format!(
180 "{} must be between {} and {}",
181 #field_name_str,
182 #min_expr,
183 #max_expr
184 ),
185 });
186 }
187 });
188 }
189
190 if let Some(min_expr) = &validators.min {
192 validations.push(quote! {
193 if self.#field_name < #min_expr {
194 return Err(crate::core::config::ConfigError::InvalidConfig {
195 message: format!("{} must be at least {}", #field_name_str, #min_expr),
196 });
197 }
198 });
199 }
200
201 if let Some(max_expr) = &validators.max {
203 validations.push(quote! {
204 if self.#field_name > #max_expr {
205 return Err(crate::core::config::ConfigError::InvalidConfig {
206 message: format!("{} must be at most {}", #field_name_str, #max_expr),
207 });
208 }
209 });
210 }
211
212 if let Some(range) = &validators.optional_range {
214 let min_expr = &range.min;
215 let max_expr = &range.max;
216 validations.push(quote! {
217 if let Some(value) = self.#field_name {
218 if !(#min_expr..=#max_expr).contains(&value) {
219 return Err(crate::core::config::ConfigError::InvalidConfig {
220 message: format!(
221 "{} must be between {} and {}",
222 #field_name_str,
223 #min_expr,
224 #max_expr
225 ),
226 });
227 }
228 }
229 });
230 }
231
232 if validators.path {
234 validations.push(quote! {
235 self.validate_model_path(&self.#field_name)?;
236 });
237 }
238
239 if validators.optional_path {
241 validations.push(quote! {
242 if let Some(ref path) = self.#field_name {
243 self.validate_model_path(path)?;
244 }
245 });
246 }
247
248 if validations.is_empty() {
249 None
250 } else {
251 Some(quote! { #(#validations)* })
252 }
253}
254
255#[proc_macro_derive(TaskPredictorBuilder, attributes(builder))]
278pub fn derive_task_predictor_builder(input: TokenStream) -> TokenStream {
279 let input = parse_macro_input!(input as DeriveInput);
280
281 TaskPredictorBuilderInput::from_derive_input(&input)
282 .and_then(|parsed| generate_task_predictor_builder(&parsed))
283 .unwrap_or_else(|err| err.write_errors())
284 .into()
285}
286
287fn generate_task_predictor_builder(
288 input: &TaskPredictorBuilderInput,
289) -> darling::Result<proc_macro2::TokenStream> {
290 let name = &input.ident;
291 let config_type = &input.builder.config;
292
293 verify_state_field(input)?;
295
296 Ok(quote! {
297 impl crate::predictors::builder::TaskPredictorBuilder for #name {
298 type Config = #config_type;
299
300 fn state_mut(
301 &mut self,
302 ) -> &mut crate::predictors::builder::PredictorBuilderState<Self::Config> {
303 &mut self.state
304 }
305 }
306
307 impl #name {
308 pub fn with_config(self, config: #config_type) -> Self {
310 <Self as crate::predictors::builder::TaskPredictorBuilder>::with_config(
311 self, config,
312 )
313 }
314
315 pub fn with_ort_config(self, config: crate::core::config::OrtSessionConfig) -> Self {
317 <Self as crate::predictors::builder::TaskPredictorBuilder>::with_ort_config(
318 self, config,
319 )
320 }
321 }
322 })
323}
324
325fn verify_state_field(input: &TaskPredictorBuilderInput) -> darling::Result<()> {
326 let fields = input
327 .data
328 .as_ref()
329 .take_struct()
330 .expect("Only structs are supported");
331
332 let state_field = fields
333 .iter()
334 .find(|f| f.ident.as_ref().is_some_and(|ident| ident == "state"));
335
336 let state_field = match state_field {
337 Some(field) => field,
338 None => {
339 return Err(darling::Error::custom(
340 "Struct must have a `state` field of type PredictorBuilderState<Config>",
341 ));
342 }
343 };
344
345 if !is_predictor_builder_state_type(&state_field.ty) {
347 return Err(darling::Error::custom(
348 "Field `state` must be of type PredictorBuilderState<Config>",
349 )
350 .with_span(&state_field.ty));
351 }
352
353 Ok(())
354}
355
356fn is_predictor_builder_state_type(ty: &Type) -> bool {
357 let Type::Path(type_path) = ty else {
358 return false;
359 };
360
361 let Some(last_segment) = type_path.path.segments.last() else {
362 return false;
363 };
364
365 if last_segment.ident != "PredictorBuilderState" {
366 return false;
367 }
368
369 matches!(
371 &last_segment.arguments,
372 syn::PathArguments::AngleBracketed(args) if args.args.len() == 1
373 )
374}