1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use std::str::FromStr;
4use syn::{
5 parse_macro_input, Data, DeriveInput, Expr, Field, Fields, GenericParam, Lit, Meta, Path,
6 PathArguments, TypeParam,
7};
8
9#[proc_macro_derive(Block, attributes(port, work, qsdr_crate))]
10pub fn block_derive(input: TokenStream) -> TokenStream {
11 let ast = parse_macro_input!(input as DeriveInput);
12 let qsdr = qsdr_crate(&ast);
14 let vis = &ast.vis;
15
16 let work = work_type(&ast);
17
18 let block_ident = &ast.ident;
19 let block_generics = struct_generic_types(&ast);
20 let block_where = &ast.generics.where_clause;
21
22 let Data::Struct(data) = &ast.data else {
23 panic!("derive(Block) only works for struct");
24 };
25 let Fields::Named(fields) = &data.fields else {
26 panic!("struct fields should be be named fields");
27 };
28 let ports = fields
29 .named
30 .iter()
31 .filter(|field| field_is_port(field))
32 .collect::<Vec<_>>();
33
34 let work_impl = match work {
35 WorkType::WorkInPlace => {
36 check_required_ports(&ports, &["input", "output"], "WorkInPlace");
37 quote! {
38 async fn block_work(&mut self, channels: &mut Self::Channels) -> Result<#qsdr::BlockWorkStatus> {
39 use #qsdr::{Receiver, Sender};
40 let Some(mut item) = channels.input.recv().await else {
41 return Ok(#qsdr::BlockWorkStatus::Done);
42 };
43 let status = self.work_in_place(&mut item).await?;
44 if status.produces_output() {
45 channels.output.send(item);
46 }
47 Ok(status.into())
48 }
49 }
50 }
51 WorkType::WorkSink => {
52 check_required_ports(&ports, &["input"], "WorkSink");
53 quote! {
54 async fn block_work(&mut self, channels: &mut Self::Channels) -> Result<#qsdr::BlockWorkStatus> {
55 use #qsdr::{Receiver, RefReceiver, Sender};
56 use ::std::borrow::Borrow;
57 let Some(item) = channels.input.ref_recv().await else {
58 return Ok(#qsdr::BlockWorkStatus::Done);
59 };
60 self.work_sink(item.borrow()).await
61 }
62 }
63 }
64 WorkType::WorkWithRef => {
65 check_required_ports(&ports, &["input", "source", "output"], "WorkWithRef");
66 quote! {
67 async fn block_work(&mut self, channels: &mut Self::Channels) -> Result<#qsdr::BlockWorkStatus> {
68 use #qsdr::{Receiver, RefReceiver, Sender};
69 use ::std::borrow::Borrow;
70 let Some(mut output_item) = channels.source.recv().await else {
71 return Ok(#qsdr::BlockWorkStatus::Done);
72 };
73 let Some(input_item) = channels.input.ref_recv().await else {
74 return Ok(#qsdr::BlockWorkStatus::Done);
75 };
76 let status = self.work_with_ref(input_item.borrow(), &mut output_item).await?;
77 drop(input_item);
79 if status.produces_output() {
80 channels.output.send(output_item);
81 }
82 Ok(status.into())
83 }
84 }
85 }
86 WorkType::WorkCustom => {
87 quote! {
88 fn block_work(&mut self, channels: &mut Self::Channels)
89 -> impl ::std::future::Future<Output = Result<#qsdr::BlockWorkStatus>> {
90 #qsdr::WorkCustom::work_custom(self, channels)
91 }
92 }
93 }
94 };
95
96 let mut channels = Vec::new();
97 let mut channel_idents = Vec::new();
98 let mut seeds = Vec::new();
99 let mut seeds_defaults = Vec::new();
100 let mut port_ids = Vec::new();
101 for (port_id, port) in ports.iter().enumerate() {
102 let ident = port.ident.as_ref().expect("port should have ident");
103 channel_idents.push(ident);
104 let ty = &port.ty;
105 channels.push(quote! {
106 #ident: <#ty as #qsdr::__private::Port>::Channel
107 });
108 seeds.push(quote! {
109 #ident: ::std::cell::RefCell<<#ty as #qsdr::__private::Port>::Seed>
110 });
111 seeds_defaults.push(quote! {
112 #ident: ::std::cell::RefCell::new(Default::default())
113 });
114 let port_id = u32::try_from(port_id).unwrap();
115 port_ids.push(quote! {
116 #vis fn #ident(&self) -> #qsdr::ports::Endpoint<'_, #ty> {
117 let _ = &self.as_ref().#ident;
121 let port = #qsdr::__private::PortId::from(#port_id);
122 let seed = self.seeds.#ident.borrow_mut();
123 #qsdr::ports::Endpoint::new(self.flowgraph_id, self.node_id, port, seed)
124 }
125 });
126 }
127
128 let block_channels_ident = format_ident!("__{block_ident}BlockChannels");
129 let block_seeds_ident = format_ident!("__{block_ident}BlockSeeds");
130 let block_generic_types = block_generics.iter().map(|ty| &ty.ident);
131 let block_generic_types = quote! {
132 #(#block_generic_types),*
133 };
134 let block_generic_list = quote! {
135 #(#block_generics),*
136 };
137
138 let block_channels = quote! {
139 #qsdr::__private::pin_project_lite::pin_project! {
140 #vis struct #block_channels_ident<#block_generic_list>
141 #block_where
142 {
143 #(
144 #[pin]
145 #channels
146 ),*,
147 __qsdr__phantom: ::std::marker::PhantomData<(#block_generic_types)>,
148 }
149 }
150
151 impl<#block_generic_list> TryFrom<#block_seeds_ident<#block_generic_types>>
152 for #block_channels_ident<#block_generic_types>
153 #block_where
154 {
155 type Error = anyhow::Error;
156
157 fn try_from(value: #block_seeds_ident<#block_generic_types>) -> anyhow::Result<Self> {
158 Ok(Self {
159 #(#channel_idents: value.#channel_idents.into_inner().try_into()?),*,
160 __qsdr__phantom: ::std::marker::PhantomData,
161 })
162 }
163 }
164
165 impl<#block_generic_list> ::std::fmt::Debug for #block_channels_ident<#block_generic_types>
166 #block_where
167 {
168 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> Result<(), ::std::fmt::Error> {
169 f.debug_struct("BlockChannels")
170 #(
171 .field(stringify!(#channel_idents), &std::any::type_name_of_val(&self.#channel_idents))
172 )*
173 .field("__qsdr__phantom", &self.__qsdr__phantom)
174 .finish()
175 }
176 }
177 };
178
179 let block_seeds = quote! {
180 #vis struct #block_seeds_ident<#block_generic_list>
181 #block_where
182 {
183 #(#seeds),*,
184 __qsdr__phantom: ::std::marker::PhantomData<(#block_generic_types)>,
185 }
186
187 impl<#block_generic_list> Default for #block_seeds_ident<#block_generic_types>
188 #block_where
189 {
190 fn default() -> Self {
191 Self {
192 #(#channel_idents: Default::default()),*,
193 __qsdr__phantom: Default::default(),
194 }
195 }
196 }
197
198 impl<#block_generic_list> ::std::fmt::Debug for #block_seeds_ident<#block_generic_types>
199 #block_where
200 {
201 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> Result<(), ::std::fmt::Error> {
202 f.debug_struct("BlockSeeds")
203 #(
204 .field(stringify!(#channel_idents), &std::any::type_name_of_val(&self.#channel_idents))
205 )*
206 .field("__qsdr__phantom", &self.__qsdr__phantom)
207 .finish()
208 }
209 }
210 };
211
212 let flowgraph_node_ident = format_ident!("__{block_ident}FlowgraphNode");
213 let flowgraph_node = quote! {
214 #[derive(Debug)]
215 #vis struct #flowgraph_node_ident<#block_generic_list>
216 #block_where
217 {
218 flowgraph_id: #qsdr::__private::FlowgraphId,
219 node_id: #qsdr::__private::NodeId,
220 block: #block_ident<#block_generic_types>,
221 seeds: #block_seeds_ident<#block_generic_types>,
222 }
223
224 impl<#block_generic_list> #qsdr::__private::FlowgraphNode for #flowgraph_node_ident<#block_generic_types>
225 #block_where
226 {
227 type B = #block_ident<#block_generic_types>;
228
229 fn flowgraph_id(&self) -> #qsdr::__private::FlowgraphId {
230 self.flowgraph_id
231 }
232
233 fn node_id(&self) -> #qsdr::__private::NodeId {
234 self.node_id
235 }
236
237 fn wrap_block(flowgraph_id: #qsdr::__private::FlowgraphId,
238 node_id: #qsdr::__private::NodeId, block: Self::B) -> Self {
239 Self { flowgraph_id, node_id, block, seeds: Default::default() }
240 }
241
242 fn try_into_object(self, _fg: &mut #qsdr::ValidatedFlowgraph) ->
243 Result<#qsdr::BlockObject<#block_ident<#block_generic_types>>, anyhow::Error> {
244 Ok(#qsdr::BlockObject::new(self.block, self.seeds.try_into()?))
245 }
246 }
247
248 impl<#block_generic_list> ::std::convert::AsRef<#block_ident<#block_generic_types>>
249 for #flowgraph_node_ident<#block_generic_types>
250 #block_where
251 {
252 fn as_ref(&self) -> &#block_ident<#block_generic_types> {
253 &self.block
254 }
255 }
256
257 impl<#block_generic_list> ::std::convert::AsMut<#block_ident<#block_generic_types>>
258 for #flowgraph_node_ident<#block_generic_types>
259 #block_where
260 {
261 fn as_mut(&mut self) -> &mut #block_ident<#block_generic_types> {
262 &mut self.block
263 }
264 }
265 };
266
267 let block_impl = quote! {
268 impl<#block_generic_list> #qsdr::Block for #block_ident<#block_generic_types>
269 #block_where
270 {
271 type Channels = #block_channels_ident<#block_generic_types>;
272
273 type Seeds = #block_seeds_ident<#block_generic_types>;
274
275 type Node = #flowgraph_node_ident<#block_generic_types>;
276
277 #work_impl
278 }
279 };
280
281 let ports_impl = quote! {
282 impl<#block_generic_list> #flowgraph_node_ident<#block_generic_types>
283 #block_where
284 {
285 #(#port_ids)*
286 }
287 };
288
289 let gen = quote! {
290 const _: () = {
291 #block_channels
292 #block_seeds
293 #flowgraph_node
294 #block_impl
295 #ports_impl
296 };
297 };
298 gen.into()
300}
301
302#[allow(dead_code)]
304fn pretty_print(ts: &proc_macro2::TokenStream) -> String {
305 let file = syn::parse_file(&ts.to_string()).unwrap();
306 prettyplease::unparse(&file)
307}
308
/// Kind of work function selected with the `#[work(...)]` attribute of the
/// `Block` derive macro.
///
/// The variant decides which `block_work` implementation the derive macro
/// generates and which port names it requires on the struct.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
// Each variant name is spelled exactly like the identifier accepted inside
// `#[work(...)]`, so the shared `Work` prefix is kept on purpose.
#[allow(clippy::enum_variant_names)]
enum WorkType {
    /// Generated work calls `work_in_place`; requires the ports `input` and
    /// `output`.
    WorkInPlace,
    /// Generated work calls `work_sink`; requires the port `input`.
    WorkSink,
    /// Generated work calls `work_with_ref`; requires the ports `input`,
    /// `source` and `output`.
    WorkWithRef,
    /// Generated work forwards to `WorkCustom::work_custom`; no port names
    /// are required.
    WorkCustom,
}

impl FromStr for WorkType {
    type Err = String;

    /// Parses the exact, case-sensitive variant name.
    ///
    /// # Errors
    ///
    /// Returns `invalid work type: <s>` for any other string.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "WorkInPlace" => Ok(Self::WorkInPlace),
            "WorkSink" => Ok(Self::WorkSink),
            "WorkWithRef" => Ok(Self::WorkWithRef),
            "WorkCustom" => Ok(Self::WorkCustom),
            _ => Err(format!("invalid work type: {s}")),
        }
    }
}
330
331fn qsdr_crate(ast: &DeriveInput) -> proc_macro2::TokenStream {
332 let qsdr_crate_attrs = ast
333 .attrs
334 .iter()
335 .filter_map(|attr| {
336 let Meta::NameValue(name_value) = &attr.meta else {
337 return None;
338 };
339 let segments = &name_value.path.segments;
340 if segments.len() != 1 {
341 return None;
342 }
343 let segment = segments.first().unwrap();
344 if segment.ident == "qsdr_crate" && matches!(segment.arguments, PathArguments::None) {
345 let Expr::Lit(lit) = &name_value.value else {
346 panic!("qsdr_crate value is not a literal");
347 };
348 let Lit::Str(s) = &lit.lit else {
349 panic!("qsdr_crate value is not a string literal");
350 };
351 Some(s.parse().unwrap())
352 } else {
353 None
354 }
355 })
356 .collect::<Vec<_>>();
357 if qsdr_crate_attrs.is_empty() {
358 return "::qsdr".parse().unwrap();
359 }
360 if qsdr_crate_attrs.len() > 1 {
361 panic!("qsdr_crate attribute present multiple times");
362 }
363 qsdr_crate_attrs.into_iter().next().unwrap()
364}
365
366fn work_type(ast: &DeriveInput) -> WorkType {
367 let work_attrs = ast
368 .attrs
369 .iter()
370 .filter_map(|attr| {
371 let Meta::List(list) = &attr.meta else {
372 return None;
373 };
374 let segments = &list.path.segments;
375 if segments.len() != 1 {
376 return None;
377 }
378 let segment = segments.first().unwrap();
379 if segment.ident == "work" && matches!(segment.arguments, PathArguments::None) {
380 Some(&list.tokens)
381 } else {
382 None
383 }
384 })
385 .collect::<Vec<_>>();
386 if work_attrs.is_empty() {
387 panic!("work attribute missing");
388 }
389 if work_attrs.len() > 1 {
390 panic!("work attribute present multiple times");
391 }
392 let attr = work_attrs[0].clone().into_iter().collect::<Vec<_>>();
393 if attr.len() != 1 {
394 panic!("work attribute does not have a single argument");
395 }
396 let proc_macro2::TokenTree::Ident(ident) = &attr[0] else {
397 panic!("work attribute is not an ident");
398 };
399 match ident.to_string().parse() {
400 Ok(w) => w,
401 Err(err) => panic!("{}", err),
402 }
403}
404
405fn struct_generic_types(ast: &DeriveInput) -> Vec<TypeParam> {
406 ast.generics
407 .params
408 .iter()
409 .filter_map(|param| {
410 if let GenericParam::Type(ty) = param {
411 let mut ty = ty.clone();
412 ty.default = None;
415 Some(ty)
416 } else {
417 None
418 }
419 })
420 .collect()
421}
422
423fn field_is_port(field: &Field) -> bool {
424 field.attrs.iter().any(|attr| match &attr.meta {
425 Meta::Path(Path { segments, .. }) => {
426 if segments.len() != 1 {
427 return false;
428 }
429 let segment = segments.first().unwrap();
430 segment.ident == "port" && matches!(segment.arguments, PathArguments::None)
431 }
432 _ => false,
433 })
434}
435
436fn has_port_with_name(ports: &[&Field], name: &str) -> bool {
437 ports.iter().any(|field| {
438 if let Some(ident) = &field.ident {
439 ident == name
440 } else {
441 false
442 }
443 })
444}
445
446fn check_required_ports(ports: &[&Field], required: &[&str], work_name: &str) {
447 for req in required {
448 if !has_port_with_name(ports, req) {
449 panic!("{} requires a port called {}", work_name, req);
450 }
451 }
452}