axum_tasks_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Attribute, Data, DeriveInput, Fields, Lit, Meta, parse_macro_input};
4
5/// Derive macro that automatically implements TaskHandler
6///
7/// Usage:
8/// ```rust
9/// #[derive(Task)]
10/// #[task(description = "Analyzing telemetry data", retry = true)]
11/// pub struct TelemetryAnalysis {
12///     pub session_id: String,
13///     pub file_path: String,
14/// }
15///
16/// impl TelemetryAnalysis {
17///     async fn execute(&self) -> TaskResult {
18///         // Your implementation here
19///     }
20/// }
21/// ```
22#[proc_macro_derive(Task, attributes(task))]
23pub fn derive_task(input: TokenStream) -> TokenStream {
24    let input = parse_macro_input!(input as DeriveInput);
25    let name = &input.ident;
26
27    let (description, retry) = match parse_task_attributes(&input.attrs) {
28        Ok(attrs) => attrs,
29        Err(e) => panic!("Failed to parse task attributes: {}", e),
30    };
31
32    let expanded = quote! {
33        #[async_trait::async_trait]
34        impl axum_tasks::TaskHandler for #name {
35            async fn handle(&self, app_tasks: &axum_tasks::AppTasks, job_id: &str) -> axum_tasks::TaskResult {
36                // Call execute and handle the output
37                let output = self.execute().await;
38
39                match output {
40                    axum_tasks::TaskOutput::Success(data) => {
41                        // Store the successful result
42                        app_tasks.store_success(
43                            job_id.to_string(),
44                            data,
45                            Some(std::time::Duration::from_secs(3600))  // 1 hour TTL
46                        ).await;
47                        axum_tasks::TaskResult::Success
48                    }
49                    axum_tasks::TaskOutput::RetryableError(error) => {
50                        // Store the failure
51                        app_tasks.store_failure(
52                            job_id.to_string(),
53                            error.clone(),
54                            Some(std::time::Duration::from_secs(3600))
55                        ).await;
56                        axum_tasks::TaskResult::RetryableError(error)
57                    }
58                    axum_tasks::TaskOutput::PermanentError(error) => {
59                        // Store the failure
60                        app_tasks.store_failure(
61                            job_id.to_string(),
62                            error.clone(),
63                            Some(std::time::Duration::from_secs(3600))
64                        ).await;
65                        axum_tasks::TaskResult::PermanentError(error)
66                    }
67                }
68            }
69
70            fn description(&self) -> String {
71                #description.to_string()
72            }
73
74            fn is_retryable(&self, _error: &str) -> bool {
75                #retry
76            }
77        }
78
79        // Auto-register with task registry
80        ::axum_tasks::inventory::submit! {
81            axum_tasks::TaskRegistration {
82                name: stringify!(#name),
83                handler: |task_data: &[u8], app_tasks: &axum_tasks::AppTasks, job_id: &str| {
84                    let task_data = task_data.to_vec();
85                    let app_tasks = app_tasks.clone();
86                    let job_id = job_id.to_string();
87                    Box::pin(async move {
88                        let task: #name = serde_json::from_slice(&task_data)
89                            .map_err(|e| axum_tasks::TaskResult::PermanentError(
90                                format!("Deserialization failed: {}", e)
91                            ))?;
92                        Ok(task.handle(&app_tasks, &job_id).await)
93                    })
94                }
95            }
96        }
97    };
98
99    TokenStream::from(expanded)
100}
101
102/// Derive macro that automatically implements HasTasks trait
103///
104/// Looks for a field named `tasks` or of type `AppTasks`
105#[proc_macro_derive(HasTasks)]
106pub fn derive_has_tasks(input: TokenStream) -> TokenStream {
107    let input = parse_macro_input!(input as DeriveInput);
108    let name = &input.ident;
109
110    let tasks_field = find_tasks_field(&input.data);
111
112    let expanded = quote! {
113        impl axum_tasks::HasTasks for #name {
114            fn tasks(&self) -> &axum_tasks::AppTasks {
115                &self.#tasks_field
116            }
117
118            fn tasks_mut(&mut self) -> &mut axum_tasks::AppTasks {
119                &mut self.#tasks_field
120            }
121        }
122    };
123
124    TokenStream::from(expanded)
125}
126
127fn parse_task_attributes(attrs: &[Attribute]) -> Result<(String, bool), String> {
128    let mut description = "Processing task".to_string();
129    let mut retry = true;
130
131    for attr in attrs {
132        if attr.path().is_ident("task") {
133            match &attr.meta {
134                Meta::List(_meta_list) => {
135                    attr.parse_nested_meta(|nested| {
136                        if nested.path.is_ident("description") {
137                            let value = nested.value()?;
138                            if let Ok(lit_str) = value.parse::<Lit>() {
139                                if let Lit::Str(s) = lit_str {
140                                    description = s.value();
141                                }
142                            }
143                        } else if nested.path.is_ident("retry") {
144                            let value = nested.value()?;
145                            if let Ok(lit_bool) = value.parse::<Lit>() {
146                                if let Lit::Bool(b) = lit_bool {
147                                    retry = b.value();
148                                }
149                            }
150                        }
151                        Ok(())
152                    })
153                    .map_err(|e| format!("Error parsing nested meta: {}", e))?;
154                }
155                Meta::Path(_) => {
156                    // #[task] with no arguments - use defaults
157                }
158                Meta::NameValue(_) => {
159                    return Err(
160                        "task attribute should be a list: #[task(description = \"...\")]"
161                            .to_string(),
162                    );
163                }
164            }
165        }
166    }
167
168    Ok((description, retry))
169}
170
171fn find_tasks_field(data: &Data) -> proc_macro2::TokenStream {
172    match data {
173        Data::Struct(data_struct) => {
174            match &data_struct.fields {
175                Fields::Named(fields) => {
176                    // TODO: type checks
177                    for field in &fields.named {
178                        if let Some(ident) = &field.ident {
179                            if ident == "tasks" {
180                                return quote! { #ident };
181                            }
182                        }
183                    }
184
185                    for field in &fields.named {
186                        if let Some(ident) = &field.ident {
187                            let type_str = quote! { #(&field.ty) }.to_string();
188                            if type_str.contains("AppTasks") {
189                                return quote! { #ident };
190                            }
191                        }
192                    }
193
194                    panic!("No field named 'tasks' or of type 'AppTasks' found");
195                }
196                _ => panic!("HasTasks can only be derived for structs with named fields"),
197            }
198        }
199        _ => panic!("HasTasks can only be derived for structs"),
200    }
201}