1use proc_macro::TokenStream;
64use quote::quote;
65use syn::{
66 DeriveInput, Expr, ExprArray, ExprLit, FnArg, GenericArgument, ImplItem, ItemImpl, Lit, LitStr,
67 PathArguments, ReturnType, Token, Type, TypePath,
68 parse::{Parse, ParseStream},
69 parse_macro_input,
70};
71
72#[proc_macro_derive(DerefMacro)]
85pub fn derive_deref(input: TokenStream) -> TokenStream {
86 let input = parse_macro_input!(input as DeriveInput);
87 let name = &input.ident;
88 if let Err(error) = reject_generics(&input, "DerefMacro") {
89 return error.to_compile_error().into();
90 }
91 if let Err(error) = require_tuple_wrapper(&input, "DerefMacro") {
92 return error.to_compile_error().into();
93 }
94
95 let expanded = quote! {
96 impl std::ops::Deref for #name {
97 type Target = <#name as DerefTarget>::Target; fn deref(&self) -> &Self::Target {
106 &self.0
107 }
108 }
109 };
110 TokenStream::from(expanded)
111}
112
113#[proc_macro_derive(DerefMutMacro)]
127pub fn derive_deref_mut(input: TokenStream) -> TokenStream {
128 let input = parse_macro_input!(input as DeriveInput);
129 let name = &input.ident;
130 if let Err(error) = reject_generics(&input, "DerefMutMacro") {
131 return error.to_compile_error().into();
132 }
133 if let Err(error) = require_tuple_wrapper(&input, "DerefMutMacro") {
134 return error.to_compile_error().into();
135 }
136
137 let expanded = quote! {
138 impl std::ops::DerefMut for #name {
139 fn deref_mut(&mut self) -> &mut Self::Target {
140 &mut self.0
141 }
142 }
143 };
144
145 TokenStream::from(expanded)
146}
147
148#[proc_macro_attribute]
149pub fn genja_task(args: TokenStream, input: TokenStream) -> TokenStream {
150 let args = parse_macro_input!(args as GenjaTaskArgs);
151 let item_impl = parse_macro_input!(input as ItemImpl);
152
153 match expand_genja_task(args, item_impl) {
154 Ok(tokens) => tokens.into(),
155 Err(error) => error.to_compile_error().into(),
156 }
157}
158
159#[derive(Default)]
160struct GenjaTaskArgs {
161 name: Option<LitStr>,
162 connection_plugin_name: Option<LitStr>,
163 processors: Vec<LitStr>,
164}
165
166impl Parse for GenjaTaskArgs {
167 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
168 let mut args = Self::default();
169
170 while !input.is_empty() {
171 let key: syn::Ident = input.parse()?;
172 input.parse::<Token![=]>()?;
173
174 match key.to_string().as_str() {
175 "name" => {
176 if args.name.is_some() {
177 return Err(syn::Error::new_spanned(key, "duplicate `name`"));
178 }
179 args.name = Some(input.parse()?);
180 }
181 "connection_plugin_name" => {
182 if args.connection_plugin_name.is_some() {
183 return Err(syn::Error::new_spanned(
184 key,
185 "duplicate `connection_plugin_name`",
186 ));
187 }
188 args.connection_plugin_name = Some(input.parse()?);
189 }
190 "processors" => {
191 if !args.processors.is_empty() {
192 return Err(syn::Error::new_spanned(key, "duplicate `processors`"));
193 }
194 let array: ExprArray = input.parse()?;
195 args.processors = parse_processor_exprs(&array)?;
196 }
197 _ => {
198 return Err(syn::Error::new_spanned(
199 key,
200 "unsupported key; expected `name`, `connection_plugin_name`, or `processors`",
201 ));
202 }
203 }
204
205 if input.is_empty() {
206 break;
207 }
208
209 input.parse::<Token![,]>()?;
210 }
211
212 if args.name.is_none() {
213 return Err(syn::Error::new(
214 proc_macro2::Span::call_site(),
215 "`name = \"...\"` is required",
216 ));
217 }
218
219 Ok(args)
220 }
221}
222
223fn expand_genja_task(
224 args: GenjaTaskArgs,
225 item_impl: ItemImpl,
226) -> syn::Result<proc_macro2::TokenStream> {
227 if item_impl.trait_.is_some() {
228 return Err(syn::Error::new_spanned(
229 &item_impl.self_ty,
230 "`#[genja_task(...)]` can only be applied to inherent impl blocks",
231 ));
232 }
233
234 if !item_impl.generics.params.is_empty() || item_impl.generics.where_clause.is_some() {
235 return Err(syn::Error::new_spanned(
236 &item_impl.generics,
237 "`genja_task` does not support generic parameters or where clauses",
238 ));
239 }
240
241 let self_ty = &item_impl.self_ty;
242 let mut has_start = false;
243 let mut has_start_async = false;
244 let mut has_options = false;
245 let mut has_sub_tasks = false;
246
247 for item in &item_impl.items {
248 let ImplItem::Fn(method) = item else {
249 continue;
250 };
251
252 match method.sig.ident.to_string().as_str() {
253 "start" => {
254 validate_start_method(method, false)?;
255 has_start = true;
256 }
257 "start_async" => {
258 validate_start_method(method, true)?;
259 has_start_async = true;
260 }
261 "options" => {
262 validate_options_method(method)?;
263 has_options = true;
264 }
265 "sub_tasks" => {
266 validate_sub_tasks_method(method)?;
267 has_sub_tasks = true;
268 }
269 _ => {}
270 }
271 }
272
273 if has_start == has_start_async {
274 return Err(syn::Error::new_spanned(
275 &item_impl.self_ty,
276 if has_start {
277 "define exactly one of `fn start(...)` or `async fn start_async(...)`"
278 } else {
279 "define one of `fn start(...)` or `async fn start_async(...)`"
280 },
281 ));
282 }
283
284 let name = args.name.expect("validated above");
285 let connection_plugin_name = args.connection_plugin_name;
286 let processors = args.processors;
287
288 let connection_impl = match connection_plugin_name {
289 Some(plugin_name) => quote! { Some(#plugin_name) },
290 None => quote! { None },
291 };
292
293 let options_impl = if has_options {
294 quote! {
295 fn options(&self) -> Option<&serde_json::Value> {
296 #self_ty::options(self)
297 }
298 }
299 } else {
300 quote! {}
301 };
302
303 let sub_tasks_impl = if has_sub_tasks {
304 quote! {
305 fn sub_tasks(&self) -> Vec<std::sync::Arc<dyn genja_core::task::Task>> {
306 #self_ty::sub_tasks(self)
307 }
308 }
309 } else {
310 quote! {}
311 };
312
313 let processor_names_impl = if processors.is_empty() {
314 quote! {}
315 } else {
316 quote! {
317 fn processor_names(&self) -> Vec<&str> {
318 vec![#(#processors),*]
319 }
320 }
321 };
322
323 let task_impl = if has_start {
324 quote! {
325 #[genja_core::async_trait]
326 impl genja_core::task::Task for #self_ty {
327 fn start(
328 &self,
329 host: &genja_core::inventory::Host,
330 context: &genja_core::task::BlockingTaskRuntimeContext,
331 ) -> Result<genja_core::task::HostTaskResult, genja_core::task::TaskError> {
332 #self_ty::start(self, host, context)
333 }
334
335 #sub_tasks_impl
336
337 fn execution_mode(&self) -> genja_core::task::TaskExecutionMode {
338 genja_core::task::TaskExecutionMode::Blocking
339 }
340 }
341 }
342 } else {
343 quote! {
344 #[genja_core::async_trait]
345 impl genja_core::task::Task for #self_ty {
346 async fn start_async(
347 &self,
348 host: &genja_core::inventory::Host,
349 context: &genja_core::task::TaskRuntimeContext,
350 ) -> Result<genja_core::task::HostTaskResult, genja_core::task::TaskError> {
351 #self_ty::start_async(self, host, context).await
352 }
353
354 #sub_tasks_impl
355
356 fn execution_mode(&self) -> genja_core::task::TaskExecutionMode {
357 genja_core::task::TaskExecutionMode::Async
358 }
359 }
360 }
361 };
362
363 Ok(quote! {
364 #item_impl
365
366 impl genja_core::task::TaskInfo for #self_ty {
367 fn name(&self) -> &str {
368 #name
369 }
370
371 fn connection_plugin_name(&self) -> Option<&str> {
372 #connection_impl
373 }
374
375 #options_impl
376
377 #processor_names_impl
378 }
379
380 #task_impl
381 })
382}
383
384fn reject_generics(input: &DeriveInput, macro_name: &str) -> syn::Result<()> {
385 if input.generics.params.is_empty() && input.generics.where_clause.is_none() {
386 return Ok(());
387 }
388
389 Err(syn::Error::new_spanned(
390 &input.generics,
391 format!("`{macro_name}` does not support generic parameters or where clauses"),
392 ))
393}
394
395fn parse_processor_exprs(array: &ExprArray) -> syn::Result<Vec<LitStr>> {
396 array
397 .elems
398 .iter()
399 .map(|expr| match expr {
400 Expr::Lit(ExprLit {
401 lit: Lit::Str(value),
402 ..
403 }) => Ok(value.clone()),
404 _ => Err(syn::Error::new_spanned(
405 expr,
406 "`processors` must be an array of string literals",
407 )),
408 })
409 .collect()
410}
411
412fn validate_start_method(method: &syn::ImplItemFn, is_async: bool) -> syn::Result<()> {
413 if method.sig.asyncness.is_some() != is_async {
414 let expected = if is_async {
415 "`start_async` must be declared as `async fn`"
416 } else {
417 "`start` must be declared as `fn`, not `async fn`"
418 };
419 return Err(syn::Error::new_spanned(&method.sig.ident, expected));
420 }
421
422 validate_shared_method_shape(method)?;
423
424 if method.sig.inputs.len() != 3 {
425 return Err(syn::Error::new_spanned(
426 &method.sig.inputs,
427 "task start methods must take `&self`, `host`, and `context`",
428 ));
429 }
430
431 let mut inputs = method.sig.inputs.iter();
432 validate_receiver(inputs.next().unwrap())?;
433 validate_typed_arg(
434 inputs.next().unwrap(),
435 is_host_ref,
436 "`host` must be `&Host`",
437 )?;
438 validate_typed_arg(
439 inputs.next().unwrap(),
440 if is_async {
441 is_async_context_ref
442 } else {
443 is_blocking_context_ref
444 },
445 if is_async {
446 "`context` must be `&TaskRuntimeContext`"
447 } else {
448 "`context` must be `&BlockingTaskRuntimeContext`"
449 },
450 )?;
451
452 validate_return_type(
453 &method.sig.output,
454 is_result_host_task_error,
455 if is_async {
456 "`start_async` must return `Result<HostTaskResult, TaskError>`"
457 } else {
458 "`start` must return `Result<HostTaskResult, TaskError>`"
459 },
460 )
461}
462
463fn validate_options_method(method: &syn::ImplItemFn) -> syn::Result<()> {
464 if method.sig.asyncness.is_some() {
465 return Err(syn::Error::new_spanned(
466 &method.sig.ident,
467 "`options` must not be async",
468 ));
469 }
470
471 validate_shared_method_shape(method)?;
472
473 if method.sig.inputs.len() != 1 {
474 return Err(syn::Error::new_spanned(
475 &method.sig.inputs,
476 "`options` must take only `&self`",
477 ));
478 }
479
480 validate_receiver(method.sig.inputs.first().unwrap())?;
481 validate_return_type(
482 &method.sig.output,
483 is_option_value_ref,
484 "`options` must return `Option<&serde_json::Value>`",
485 )
486}
487
488fn validate_sub_tasks_method(method: &syn::ImplItemFn) -> syn::Result<()> {
489 if method.sig.asyncness.is_some() {
490 return Err(syn::Error::new_spanned(
491 &method.sig.ident,
492 "`sub_tasks` must not be async",
493 ));
494 }
495
496 validate_shared_method_shape(method)?;
497
498 if method.sig.inputs.len() != 1 {
499 return Err(syn::Error::new_spanned(
500 &method.sig.inputs,
501 "`sub_tasks` must take only `&self`",
502 ));
503 }
504
505 validate_receiver(method.sig.inputs.first().unwrap())?;
506 validate_return_type(
507 &method.sig.output,
508 is_vec_of_task_arcs,
509 "`sub_tasks` must return `Vec<Arc<dyn Task>>`",
510 )
511}
512
513fn validate_shared_method_shape(method: &syn::ImplItemFn) -> syn::Result<()> {
514 if method.sig.constness.is_some()
515 || method.sig.unsafety.is_some()
516 || method.sig.abi.is_some()
517 || method.sig.variadic.is_some()
518 || !method.sig.generics.params.is_empty()
519 || method.sig.generics.where_clause.is_some()
520 {
521 return Err(syn::Error::new_spanned(
522 &method.sig,
523 "Genja task hook methods cannot be const, unsafe, generic, extern, or variadic",
524 ));
525 }
526
527 Ok(())
528}
529
530fn validate_receiver(arg: &FnArg) -> syn::Result<()> {
531 match arg {
532 FnArg::Receiver(receiver)
533 if receiver.reference.is_some() && receiver.mutability.is_none() =>
534 {
535 Ok(())
536 }
537 _ => Err(syn::Error::new_spanned(
538 arg,
539 "first argument must be `&self`",
540 )),
541 }
542}
543
544fn validate_typed_arg(arg: &FnArg, predicate: fn(&Type) -> bool, message: &str) -> syn::Result<()> {
545 match arg {
546 FnArg::Typed(typed) if predicate(&typed.ty) => Ok(()),
547 FnArg::Typed(typed) => Err(syn::Error::new_spanned(&typed.ty, message)),
548 FnArg::Receiver(_) => Err(syn::Error::new_spanned(arg, message)),
549 }
550}
551
552fn validate_return_type(
553 output: &ReturnType,
554 predicate: fn(&Type) -> bool,
555 message: &str,
556) -> syn::Result<()> {
557 match output {
558 ReturnType::Type(_, ty) if predicate(ty) => Ok(()),
559 ReturnType::Type(_, ty) => Err(syn::Error::new_spanned(ty, message)),
560 ReturnType::Default => Err(syn::Error::new(proc_macro2::Span::call_site(), message)),
561 }
562}
563
564fn is_result_host_task_error(ty: &Type) -> bool {
565 let Type::Path(TypePath { path, .. }) = ty else {
566 return false;
567 };
568 let Some(seg) = path.segments.last() else {
569 return false;
570 };
571 if seg.ident != "Result" {
572 return false;
573 }
574 let PathArguments::AngleBracketed(args) = &seg.arguments else {
575 return false;
576 };
577 if args.args.len() != 2 {
578 return false;
579 }
580
581 let mut args_iter = args.args.iter();
582 let ok = match args_iter.next() {
583 Some(GenericArgument::Type(ty)) => type_ends_with(ty, "HostTaskResult"),
584 _ => false,
585 };
586 let err = match args_iter.next() {
587 Some(GenericArgument::Type(ty)) => type_ends_with(ty, "TaskError"),
588 _ => false,
589 };
590 ok && err
591}
592
593fn is_option_value_ref(ty: &Type) -> bool {
594 let Type::Path(TypePath { path, .. }) = ty else {
595 return false;
596 };
597 let Some(seg) = path.segments.last() else {
598 return false;
599 };
600 if seg.ident != "Option" {
601 return false;
602 }
603 let PathArguments::AngleBracketed(args) = &seg.arguments else {
604 return false;
605 };
606 if args.args.len() != 1 {
607 return false;
608 }
609 match args.args.first() {
610 Some(GenericArgument::Type(Type::Reference(reference))) => {
611 type_ends_with(&reference.elem, "Value")
612 }
613 _ => false,
614 }
615}
616
617fn is_vec_of_task_arcs(ty: &Type) -> bool {
618 let Type::Path(TypePath { path, .. }) = ty else {
619 return false;
620 };
621 let Some(seg) = path.segments.last() else {
622 return false;
623 };
624 if seg.ident != "Vec" {
625 return false;
626 }
627 let PathArguments::AngleBracketed(args) = &seg.arguments else {
628 return false;
629 };
630 if args.args.len() != 1 {
631 return false;
632 }
633 match args.args.first() {
634 Some(GenericArgument::Type(inner)) => is_arc_task(inner),
635 _ => false,
636 }
637}
638
639fn is_arc_task(ty: &Type) -> bool {
640 match ty {
641 Type::Path(TypePath { path, .. }) => {
642 let Some(seg) = path.segments.last() else {
643 return false;
644 };
645 if seg.ident != "Arc" {
646 return false;
647 }
648 match &seg.arguments {
649 PathArguments::AngleBracketed(args) => args
650 .args
651 .iter()
652 .filter_map(|arg| match arg {
653 GenericArgument::Type(ty) => Some(ty),
654 _ => None,
655 })
656 .any(is_task_trait_object),
657 _ => false,
658 }
659 }
660 _ => false,
661 }
662}
663
664fn is_task_trait_object(ty: &Type) -> bool {
665 match ty {
666 Type::TraitObject(obj) => obj.bounds.iter().any(|bound| match bound {
667 syn::TypeParamBound::Trait(trait_bound) => trait_bound
668 .path
669 .segments
670 .last()
671 .map(|seg| seg.ident == "Task")
672 .unwrap_or(false),
673 _ => false,
674 }),
675 _ => false,
676 }
677}
678
679fn is_host_ref(ty: &Type) -> bool {
680 matches!(ty, Type::Reference(reference) if type_ends_with(&reference.elem, "Host"))
681}
682
683fn is_async_context_ref(ty: &Type) -> bool {
684 matches!(ty, Type::Reference(reference) if type_ends_with(&reference.elem, "TaskRuntimeContext"))
685}
686
687fn is_blocking_context_ref(ty: &Type) -> bool {
688 matches!(ty, Type::Reference(reference) if type_ends_with(&reference.elem, "BlockingTaskRuntimeContext"))
689}
690
691fn type_ends_with(ty: &Type, ident: &str) -> bool {
692 match ty {
693 Type::Path(TypePath { path, .. }) => path
694 .segments
695 .last()
696 .map(|segment| segment.ident == ident)
697 .unwrap_or(false),
698 _ => false,
699 }
700}
701
702fn require_tuple_wrapper(input: &DeriveInput, macro_name: &str) -> syn::Result<()> {
703 match &input.data {
704 syn::Data::Struct(data) => match &data.fields {
705 syn::Fields::Unnamed(fields) if !fields.unnamed.is_empty() => Ok(()),
706 _ => Err(syn::Error::new_spanned(
707 &input.ident,
708 format!("`{macro_name}` requires a tuple struct with the wrapped value in field 0"),
709 )),
710 },
711 _ => Err(syn::Error::new_spanned(
712 &input.ident,
713 format!("`{macro_name}` can only be derived for tuple structs"),
714 )),
715 }
716}