1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Attribute, Data, DeriveInput, Fields, Lit, Meta, parse_macro_input};
4
5#[proc_macro_derive(Task, attributes(task))]
23pub fn derive_task(input: TokenStream) -> TokenStream {
24 let input = parse_macro_input!(input as DeriveInput);
25 let name = &input.ident;
26
27 let (description, retry) = match parse_task_attributes(&input.attrs) {
28 Ok(attrs) => attrs,
29 Err(e) => panic!("Failed to parse task attributes: {}", e),
30 };
31
32 let expanded = quote! {
33 #[async_trait::async_trait]
34 impl axum_tasks::TaskHandler for #name {
35 async fn handle(&self, app_tasks: &axum_tasks::AppTasks, job_id: &str) -> axum_tasks::TaskResult {
36 let output = self.execute().await;
38
39 match output {
40 axum_tasks::TaskOutput::Success(data) => {
41 app_tasks.store_success(
43 job_id.to_string(),
44 data,
45 Some(std::time::Duration::from_secs(3600)) ).await;
47 axum_tasks::TaskResult::Success
48 }
49 axum_tasks::TaskOutput::RetryableError(error) => {
50 app_tasks.store_failure(
52 job_id.to_string(),
53 error.clone(),
54 Some(std::time::Duration::from_secs(3600))
55 ).await;
56 axum_tasks::TaskResult::RetryableError(error)
57 }
58 axum_tasks::TaskOutput::PermanentError(error) => {
59 app_tasks.store_failure(
61 job_id.to_string(),
62 error.clone(),
63 Some(std::time::Duration::from_secs(3600))
64 ).await;
65 axum_tasks::TaskResult::PermanentError(error)
66 }
67 }
68 }
69
70 fn description(&self) -> String {
71 #description.to_string()
72 }
73
74 fn is_retryable(&self, _error: &str) -> bool {
75 #retry
76 }
77 }
78
79 ::axum_tasks::inventory::submit! {
81 axum_tasks::TaskRegistration {
82 name: stringify!(#name),
83 handler: |task_data: &[u8], app_tasks: &axum_tasks::AppTasks, job_id: &str| {
84 let task_data = task_data.to_vec();
85 let app_tasks = app_tasks.clone();
86 let job_id = job_id.to_string();
87 Box::pin(async move {
88 let task: #name = serde_json::from_slice(&task_data)
89 .map_err(|e| axum_tasks::TaskResult::PermanentError(
90 format!("Deserialization failed: {}", e)
91 ))?;
92 Ok(task.handle(&app_tasks, &job_id).await)
93 })
94 }
95 }
96 }
97 };
98
99 TokenStream::from(expanded)
100}
101
102#[proc_macro_derive(HasTasks)]
106pub fn derive_has_tasks(input: TokenStream) -> TokenStream {
107 let input = parse_macro_input!(input as DeriveInput);
108 let name = &input.ident;
109
110 let tasks_field = find_tasks_field(&input.data);
111
112 let expanded = quote! {
113 impl axum_tasks::HasTasks for #name {
114 fn tasks(&self) -> &axum_tasks::AppTasks {
115 &self.#tasks_field
116 }
117
118 fn tasks_mut(&mut self) -> &mut axum_tasks::AppTasks {
119 &mut self.#tasks_field
120 }
121 }
122 };
123
124 TokenStream::from(expanded)
125}
126
127fn parse_task_attributes(attrs: &[Attribute]) -> Result<(String, bool), String> {
128 let mut description = "Processing task".to_string();
129 let mut retry = true;
130
131 for attr in attrs {
132 if attr.path().is_ident("task") {
133 match &attr.meta {
134 Meta::List(_meta_list) => {
135 attr.parse_nested_meta(|nested| {
136 if nested.path.is_ident("description") {
137 let value = nested.value()?;
138 if let Ok(lit_str) = value.parse::<Lit>() {
139 if let Lit::Str(s) = lit_str {
140 description = s.value();
141 }
142 }
143 } else if nested.path.is_ident("retry") {
144 let value = nested.value()?;
145 if let Ok(lit_bool) = value.parse::<Lit>() {
146 if let Lit::Bool(b) = lit_bool {
147 retry = b.value();
148 }
149 }
150 }
151 Ok(())
152 })
153 .map_err(|e| format!("Error parsing nested meta: {}", e))?;
154 }
155 Meta::Path(_) => {
156 }
158 Meta::NameValue(_) => {
159 return Err(
160 "task attribute should be a list: #[task(description = \"...\")]"
161 .to_string(),
162 );
163 }
164 }
165 }
166 }
167
168 Ok((description, retry))
169}
170
171fn find_tasks_field(data: &Data) -> proc_macro2::TokenStream {
172 match data {
173 Data::Struct(data_struct) => {
174 match &data_struct.fields {
175 Fields::Named(fields) => {
176 for field in &fields.named {
178 if let Some(ident) = &field.ident {
179 if ident == "tasks" {
180 return quote! { #ident };
181 }
182 }
183 }
184
185 for field in &fields.named {
186 if let Some(ident) = &field.ident {
187 let type_str = quote! { #(&field.ty) }.to_string();
188 if type_str.contains("AppTasks") {
189 return quote! { #ident };
190 }
191 }
192 }
193
194 panic!("No field named 'tasks' or of type 'AppTasks' found");
195 }
196 _ => panic!("HasTasks can only be derived for structs with named fields"),
197 }
198 }
199 _ => panic!("HasTasks can only be derived for structs"),
200 }
201}