1use proc_macro::TokenStream;
6use quote::quote;
7use syn::{DeriveInput, Expr, Field, Meta, Type, parse_macro_input};
8
9#[proc_macro_derive(ConfigValidator, attributes(validate))]
47pub fn derive_config_validator(input: TokenStream) -> TokenStream {
48 let input = parse_macro_input!(input as DeriveInput);
49 impl_config_validator(&input)
50 .unwrap_or_else(|err| err.to_compile_error())
51 .into()
52}
53
54fn impl_config_validator(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
55 let name = &input.ident;
56
57 let fields = match &input.data {
58 syn::Data::Struct(data) => match &data.fields {
59 syn::Fields::Named(fields) => &fields.named,
60 _ => {
61 return Err(syn::Error::new_spanned(
62 input,
63 "ConfigValidator can only be derived for structs with named fields",
64 ));
65 }
66 },
67 _ => {
68 return Err(syn::Error::new_spanned(
69 input,
70 "ConfigValidator can only be derived for structs",
71 ));
72 }
73 };
74
75 let validations = fields
76 .iter()
77 .filter_map(|field| generate_field_validation(field).transpose())
78 .collect::<syn::Result<Vec<_>>>()?;
79
80 Ok(quote! {
81 impl crate::core::config::ConfigValidator for #name {
82 fn validate(&self) -> Result<(), crate::core::config::ConfigError> {
83 #(#validations)*
84 Ok(())
85 }
86
87 fn get_defaults() -> Self
88 where
89 Self: Sized,
90 {
91 Self::default()
92 }
93 }
94 })
95}
96
97fn generate_field_validation(field: &Field) -> syn::Result<Option<proc_macro2::TokenStream>> {
98 let field_name = field
99 .ident
100 .as_ref()
101 .ok_or_else(|| syn::Error::new_spanned(field, "Expected named field"))?;
102
103 let field_name_str = field_name.to_string();
104
105 let mut validations = Vec::new();
106
107 for attr in &field.attrs {
108 if !attr.path().is_ident("validate") {
109 continue;
110 }
111
112 let meta = attr.parse_args::<Meta>()?;
113 validations.push(generate_validation_code(
114 field_name,
115 &field_name_str,
116 &meta,
117 &field.ty,
118 )?);
119 }
120
121 if validations.is_empty() {
122 Ok(None)
123 } else {
124 Ok(Some(quote! { #(#validations)* }))
125 }
126}
127
128fn generate_validation_code(
129 field_name: &syn::Ident,
130 field_name_str: &str,
131 meta: &Meta,
132 _field_ty: &Type,
133) -> syn::Result<proc_macro2::TokenStream> {
134 match meta {
135 Meta::List(list) => {
136 let validator_name = list
137 .path
138 .get_ident()
139 .ok_or_else(|| syn::Error::new_spanned(&list.path, "Expected validator name"))?;
140 let validator_str = validator_name.to_string();
141
142 match validator_str.as_str() {
143 "range" => {
144 let args = parse_two_args(&list.tokens)?;
145 let (min_expr, max_expr) = args;
146 Ok(quote! {
147 if !(#min_expr..=#max_expr).contains(&self.#field_name) {
148 return Err(crate::core::config::ConfigError::InvalidConfig {
149 message: format!(
150 "{} must be between {} and {}",
151 #field_name_str,
152 #min_expr,
153 #max_expr
154 ),
155 });
156 }
157 })
158 }
159 "min" => {
160 let min_expr = parse_one_arg(&list.tokens)?;
161 Ok(quote! {
162 if self.#field_name < #min_expr {
163 return Err(crate::core::config::ConfigError::InvalidConfig {
164 message: format!("{} must be at least {}", #field_name_str, #min_expr),
165 });
166 }
167 })
168 }
169 "max" => {
170 let max_expr = parse_one_arg(&list.tokens)?;
171 Ok(quote! {
172 if self.#field_name > #max_expr {
173 return Err(crate::core::config::ConfigError::InvalidConfig {
174 message: format!("{} must be at most {}", #field_name_str, #max_expr),
175 });
176 }
177 })
178 }
179 "optional_range" => {
180 let args = parse_two_args(&list.tokens)?;
181 let (min_expr, max_expr) = args;
182 Ok(quote! {
183 if let Some(value) = self.#field_name {
184 if !(#min_expr..=#max_expr).contains(&value) {
185 return Err(crate::core::config::ConfigError::InvalidConfig {
186 message: format!(
187 "{} must be between {} and {}",
188 #field_name_str,
189 #min_expr,
190 #max_expr
191 ),
192 });
193 }
194 }
195 })
196 }
197 "path" => Ok(generate_path_validation(field_name, false)),
198 "optional_path" => Ok(generate_path_validation(field_name, true)),
199 other => Err(syn::Error::new_spanned(
200 validator_name,
201 format!("Unknown validator: {}", other),
202 )),
203 }
204 }
205 Meta::Path(path) => {
206 let validator_name = path
207 .get_ident()
208 .ok_or_else(|| syn::Error::new_spanned(path, "Expected validator name"))?;
209 let validator_str = validator_name.to_string();
210
211 match validator_str.as_str() {
212 "path" => Ok(generate_path_validation(field_name, false)),
213 "optional_path" => Ok(generate_path_validation(field_name, true)),
214 other => Err(syn::Error::new_spanned(
215 validator_name,
216 format!("Unknown validator without arguments: {}", other),
217 )),
218 }
219 }
220 _ => Err(syn::Error::new_spanned(meta, "Invalid validator format")),
221 }
222}
223
224fn generate_path_validation(field_name: &syn::Ident, optional: bool) -> proc_macro2::TokenStream {
225 if optional {
226 quote! {
227 if let Some(ref path) = self.#field_name {
228 self.validate_model_path(path)?;
229 }
230 }
231 } else {
232 quote! {
233 self.validate_model_path(&self.#field_name)?;
234 }
235 }
236}
237
238fn parse_one_arg(tokens: &proc_macro2::TokenStream) -> syn::Result<Expr> {
239 syn::parse2(tokens.clone())
240}
241
242fn parse_two_args(tokens: &proc_macro2::TokenStream) -> syn::Result<(Expr, Expr)> {
243 use syn::Token;
244 use syn::parse::Parser;
245 use syn::punctuated::Punctuated;
246
247 let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
248 let args = parser.parse2(tokens.clone())?;
249 let mut iter = args.into_iter();
250
251 let first = iter
252 .next()
253 .ok_or_else(|| syn::Error::new_spanned(tokens, "Expected two arguments"))?;
254 let second = iter
255 .next()
256 .ok_or_else(|| syn::Error::new_spanned(tokens, "Expected two arguments"))?;
257
258 if iter.next().is_some() {
259 return Err(syn::Error::new_spanned(
260 tokens,
261 "Expected exactly two arguments",
262 ));
263 }
264
265 Ok((first, second))
266}
267
268#[proc_macro_derive(TaskPredictorBuilder, attributes(builder))]
291pub fn derive_task_predictor_builder(input: TokenStream) -> TokenStream {
292 let input = parse_macro_input!(input as DeriveInput);
293 impl_task_predictor_builder(&input)
294 .unwrap_or_else(|err| err.to_compile_error())
295 .into()
296}
297
298fn impl_task_predictor_builder(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
299 let name = &input.ident;
300
301 let config_type = find_builder_config_type(input)?;
303
304 verify_state_field(input)?;
306
307 Ok(quote! {
308 impl crate::predictors::builder::TaskPredictorBuilder for #name {
309 type Config = #config_type;
310
311 fn state_mut(
312 &mut self,
313 ) -> &mut crate::predictors::builder::PredictorBuilderState<Self::Config> {
314 &mut self.state
315 }
316 }
317
318 impl #name {
319 pub fn with_config(self, config: #config_type) -> Self {
321 <Self as crate::predictors::builder::TaskPredictorBuilder>::with_config(
322 self, config,
323 )
324 }
325
326 pub fn with_ort_config(self, config: crate::core::config::OrtSessionConfig) -> Self {
328 <Self as crate::predictors::builder::TaskPredictorBuilder>::with_ort_config(
329 self, config,
330 )
331 }
332 }
333 })
334}
335
336fn find_builder_config_type(input: &DeriveInput) -> syn::Result<Type> {
337 for attr in &input.attrs {
338 if !attr.path().is_ident("builder") {
339 continue;
340 }
341
342 let meta = attr.parse_args::<Meta>()?;
343
344 if let Meta::NameValue(nv) = meta
345 && nv.path.is_ident("config")
346 {
347 if let Expr::Path(expr_path) = nv.value {
348 return Ok(Type::Path(syn::TypePath {
349 qself: None,
350 path: expr_path.path,
351 }));
352 } else {
353 return Err(syn::Error::new_spanned(
354 nv.value,
355 "Expected a type path, e.g., #[builder(config = MyConfigType)]",
356 ));
357 }
358 }
359 }
360
361 Err(syn::Error::new_spanned(
362 input,
363 "Missing #[builder(config = ConfigType)] attribute",
364 ))
365}
366
367fn verify_state_field(input: &DeriveInput) -> syn::Result<()> {
368 let fields = match &input.data {
369 syn::Data::Struct(data) => match &data.fields {
370 syn::Fields::Named(fields) => &fields.named,
371 _ => {
372 return Err(syn::Error::new_spanned(
373 input,
374 "TaskPredictorBuilder can only be derived for structs with named fields",
375 ));
376 }
377 },
378 _ => {
379 return Err(syn::Error::new_spanned(
380 input,
381 "TaskPredictorBuilder can only be derived for structs",
382 ));
383 }
384 };
385
386 let state_field = fields
387 .iter()
388 .find(|f| f.ident.as_ref().is_some_and(|ident| ident == "state"));
389
390 let state_field = match state_field {
391 Some(field) => field,
392 None => {
393 return Err(syn::Error::new_spanned(
394 input,
395 "Struct must have a `state` field of type PredictorBuilderState<Config>",
396 ));
397 }
398 };
399
400 if !is_predictor_builder_state_type(&state_field.ty) {
402 return Err(syn::Error::new_spanned(
403 &state_field.ty,
404 "Field `state` must be of type PredictorBuilderState<Config>",
405 ));
406 }
407
408 Ok(())
409}
410
411fn is_predictor_builder_state_type(ty: &Type) -> bool {
412 let Type::Path(type_path) = ty else {
413 return false;
414 };
415
416 let last_segment = match type_path.path.segments.last() {
417 Some(seg) => seg,
418 None => return false,
419 };
420
421 if last_segment.ident != "PredictorBuilderState" {
422 return false;
423 }
424
425 matches!(
427 &last_segment.arguments,
428 syn::PathArguments::AngleBracketed(args) if args.args.len() == 1
429 )
430}