qsdr_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use std::str::FromStr;
4use syn::{
5    parse_macro_input, Data, DeriveInput, Expr, Field, Fields, GenericParam, Lit, Meta, Path,
6    PathArguments, TypeParam,
7};
8
9#[proc_macro_derive(Block, attributes(port, work, qsdr_crate))]
10pub fn block_derive(input: TokenStream) -> TokenStream {
11    let ast = parse_macro_input!(input as DeriveInput);
12    //dbg!(&ast);
13    let qsdr = qsdr_crate(&ast);
14    let vis = &ast.vis;
15
16    let work = work_type(&ast);
17
18    let block_ident = &ast.ident;
19    let block_generics = struct_generic_types(&ast);
20    let block_where = &ast.generics.where_clause;
21
22    let Data::Struct(data) = &ast.data else {
23        panic!("derive(Block) only works for struct");
24    };
25    let Fields::Named(fields) = &data.fields else {
26        panic!("struct fields should be be named fields");
27    };
28    let ports = fields
29        .named
30        .iter()
31        .filter(|field| field_is_port(field))
32        .collect::<Vec<_>>();
33
34    let work_impl = match work {
35        WorkType::WorkInPlace => {
36            check_required_ports(&ports, &["input", "output"], "WorkInPlace");
37            quote! {
38                async fn block_work(&mut self, channels: &mut Self::Channels) -> Result<#qsdr::BlockWorkStatus> {
39                    use #qsdr::{Receiver, Sender};
40                    let Some(mut item) = channels.input.recv().await else {
41                        return Ok(#qsdr::BlockWorkStatus::Done);
42                    };
43                    let status = self.work_in_place(&mut item).await?;
44                    if status.produces_output() {
45                        channels.output.send(item);
46                    }
47                    Ok(status.into())
48                }
49            }
50        }
51        WorkType::WorkSink => {
52            check_required_ports(&ports, &["input"], "WorkSink");
53            quote! {
54                async fn block_work(&mut self, channels: &mut Self::Channels) -> Result<#qsdr::BlockWorkStatus> {
55                    use #qsdr::{Receiver, RefReceiver, Sender};
56                    use ::std::borrow::Borrow;
57                    let Some(item) = channels.input.ref_recv().await else {
58                         return Ok(#qsdr::BlockWorkStatus::Done);
59                    };
60                    self.work_sink(item.borrow()).await
61                }
62            }
63        }
64        WorkType::WorkWithRef => {
65            check_required_ports(&ports, &["input", "source", "output"], "WorkWithRef");
66            quote! {
67                async fn block_work(&mut self, channels: &mut Self::Channels) -> Result<#qsdr::BlockWorkStatus> {
68                    use #qsdr::{Receiver, RefReceiver, Sender};
69                    use ::std::borrow::Borrow;
70                    let Some(mut output_item) = channels.source.recv().await else {
71                        return Ok(#qsdr::BlockWorkStatus::Done);
72                    };
73                    let Some(input_item) = channels.input.ref_recv().await else {
74                         return Ok(#qsdr::BlockWorkStatus::Done);
75                    };
76                    let status = self.work_with_ref(input_item.borrow(), &mut output_item).await?;
77                    // drop the input item reference, which potentially causes it to be returned
78                    drop(input_item);
79                    if status.produces_output() {
80                        channels.output.send(output_item);
81                    }
82                    Ok(status.into())
83                }
84            }
85        }
86        WorkType::WorkCustom => {
87            quote! {
88                fn block_work(&mut self, channels: &mut Self::Channels)
89                              -> impl ::std::future::Future<Output = Result<#qsdr::BlockWorkStatus>> {
90                    #qsdr::WorkCustom::work_custom(self, channels)
91                }
92            }
93        }
94    };
95
96    let mut channels = Vec::new();
97    let mut channel_idents = Vec::new();
98    let mut seeds = Vec::new();
99    let mut seeds_defaults = Vec::new();
100    let mut port_ids = Vec::new();
101    for (port_id, port) in ports.iter().enumerate() {
102        let ident = port.ident.as_ref().expect("port should have ident");
103        channel_idents.push(ident);
104        let ty = &port.ty;
105        channels.push(quote! {
106            #ident: <#ty as #qsdr::__private::Port>::Channel
107        });
108        seeds.push(quote! {
109            #ident: ::std::cell::RefCell<<#ty as #qsdr::__private::Port>::Seed>
110        });
111        seeds_defaults.push(quote! {
112            #ident: ::std::cell::RefCell::new(Default::default())
113        });
114        let port_id = u32::try_from(port_id).unwrap();
115        port_ids.push(quote! {
116            #vis fn #ident(&self) -> #qsdr::ports::Endpoint<'_, #ty> {
117                // Use this to remove a "field is never read" warning. With
118                // this, the warning will typically show iff this function is
119                // never called.
120                let _ = &self.as_ref().#ident;
121                let port = #qsdr::__private::PortId::from(#port_id);
122                let seed = self.seeds.#ident.borrow_mut();
123                #qsdr::ports::Endpoint::new(self.flowgraph_id, self.node_id, port, seed)
124            }
125        });
126    }
127
128    let block_channels_ident = format_ident!("__{block_ident}BlockChannels");
129    let block_seeds_ident = format_ident!("__{block_ident}BlockSeeds");
130    let block_generic_types = block_generics.iter().map(|ty| &ty.ident);
131    let block_generic_types = quote! {
132        #(#block_generic_types),*
133    };
134    let block_generic_list = quote! {
135        #(#block_generics),*
136    };
137
138    let block_channels = quote! {
139        #qsdr::__private::pin_project_lite::pin_project! {
140            #vis struct #block_channels_ident<#block_generic_list>
141            #block_where
142        {
143            #(
144                #[pin]
145                #channels
146            ),*,
147            __qsdr__phantom: ::std::marker::PhantomData<(#block_generic_types)>,
148        }
149        }
150
151        impl<#block_generic_list> TryFrom<#block_seeds_ident<#block_generic_types>>
152            for #block_channels_ident<#block_generic_types>
153            #block_where
154        {
155            type Error = anyhow::Error;
156
157            fn try_from(value: #block_seeds_ident<#block_generic_types>) -> anyhow::Result<Self> {
158                Ok(Self {
159                    #(#channel_idents: value.#channel_idents.into_inner().try_into()?),*,
160                    __qsdr__phantom: ::std::marker::PhantomData,
161                })
162            }
163        }
164
165        impl<#block_generic_list> ::std::fmt::Debug for #block_channels_ident<#block_generic_types>
166            #block_where
167        {
168            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> Result<(), ::std::fmt::Error> {
169                f.debug_struct("BlockChannels")
170                    #(
171                        .field(stringify!(#channel_idents), &std::any::type_name_of_val(&self.#channel_idents))
172                    )*
173                    .field("__qsdr__phantom", &self.__qsdr__phantom)
174                    .finish()
175            }
176        }
177    };
178
179    let block_seeds = quote! {
180        #vis struct #block_seeds_ident<#block_generic_list>
181            #block_where
182        {
183            #(#seeds),*,
184            __qsdr__phantom: ::std::marker::PhantomData<(#block_generic_types)>,
185        }
186
187        impl<#block_generic_list> Default for #block_seeds_ident<#block_generic_types>
188            #block_where
189        {
190            fn default() -> Self {
191                Self {
192                    #(#channel_idents: Default::default()),*,
193                    __qsdr__phantom: Default::default(),
194                }
195            }
196        }
197
198        impl<#block_generic_list> ::std::fmt::Debug for #block_seeds_ident<#block_generic_types>
199            #block_where
200        {
201            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> Result<(), ::std::fmt::Error> {
202                f.debug_struct("BlockSeeds")
203                    #(
204                        .field(stringify!(#channel_idents), &std::any::type_name_of_val(&self.#channel_idents))
205                    )*
206                    .field("__qsdr__phantom", &self.__qsdr__phantom)
207                    .finish()
208            }
209        }
210    };
211
212    let flowgraph_node_ident = format_ident!("__{block_ident}FlowgraphNode");
213    let flowgraph_node = quote! {
214        #[derive(Debug)]
215        #vis struct #flowgraph_node_ident<#block_generic_list>
216            #block_where
217        {
218            flowgraph_id: #qsdr::__private::FlowgraphId,
219            node_id: #qsdr::__private::NodeId,
220            block: #block_ident<#block_generic_types>,
221            seeds: #block_seeds_ident<#block_generic_types>,
222        }
223
224        impl<#block_generic_list> #qsdr::__private::FlowgraphNode for #flowgraph_node_ident<#block_generic_types>
225            #block_where
226        {
227            type B = #block_ident<#block_generic_types>;
228
229            fn flowgraph_id(&self) -> #qsdr::__private::FlowgraphId {
230                self.flowgraph_id
231            }
232
233            fn node_id(&self) -> #qsdr::__private::NodeId {
234                self.node_id
235            }
236
237            fn wrap_block(flowgraph_id: #qsdr::__private::FlowgraphId,
238                          node_id: #qsdr::__private::NodeId, block: Self::B) -> Self {
239                Self { flowgraph_id, node_id, block, seeds: Default::default() }
240            }
241
242            fn try_into_object(self, _fg: &mut #qsdr::ValidatedFlowgraph) ->
243                Result<#qsdr::BlockObject<#block_ident<#block_generic_types>>, anyhow::Error> {
244                    Ok(#qsdr::BlockObject::new(self.block, self.seeds.try_into()?))
245                }
246        }
247
248        impl<#block_generic_list> ::std::convert::AsRef<#block_ident<#block_generic_types>>
249            for #flowgraph_node_ident<#block_generic_types>
250            #block_where
251        {
252            fn as_ref(&self) -> &#block_ident<#block_generic_types> {
253                &self.block
254            }
255        }
256
257        impl<#block_generic_list> ::std::convert::AsMut<#block_ident<#block_generic_types>>
258            for #flowgraph_node_ident<#block_generic_types>
259            #block_where
260        {
261            fn as_mut(&mut self) -> &mut #block_ident<#block_generic_types> {
262                &mut self.block
263            }
264        }
265    };
266
267    let block_impl = quote! {
268        impl<#block_generic_list> #qsdr::Block for #block_ident<#block_generic_types>
269            #block_where
270        {
271            type Channels = #block_channels_ident<#block_generic_types>;
272
273            type Seeds = #block_seeds_ident<#block_generic_types>;
274
275            type Node = #flowgraph_node_ident<#block_generic_types>;
276
277            #work_impl
278        }
279    };
280
281    let ports_impl = quote! {
282        impl<#block_generic_list> #flowgraph_node_ident<#block_generic_types>
283            #block_where
284        {
285            #(#port_ids)*
286        }
287    };
288
289    let gen = quote! {
290        const _: () =  {
291            #block_channels
292            #block_seeds
293            #flowgraph_node
294            #block_impl
295            #ports_impl
296        };
297    };
298    //println!("{}", pretty_print(&gen));
299    gen.into()
300}
301
302// https://stackoverflow.com/a/74360109
303#[allow(dead_code)]
304fn pretty_print(ts: &proc_macro2::TokenStream) -> String {
305    let file = syn::parse_file(&ts.to_string()).unwrap();
306    prettyplease::unparse(&file)
307}
308
/// Strategy used to generate `block_work`, chosen with `#[work(...)]`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
// Each variant shares the `Work` prefix because it is spelled exactly like the
// value accepted inside `#[work(...)]`.
#[allow(clippy::enum_variant_names)]
enum WorkType {
    /// Receive an item on `input`, hand it to `work_in_place` by `&mut`, and
    /// send it on `output` when the returned status produces output.
    WorkInPlace,
    /// Receive a reference from `input` via `ref_recv` and hand it to
    /// `work_sink`.
    WorkSink,
    /// Take an output item from `source` and an input reference from `input`,
    /// call `work_with_ref`, then send the output item on `output` when the
    /// returned status produces output.
    WorkWithRef,
    /// Delegate entirely to `WorkCustom::work_custom`; no ports are required.
    WorkCustom,
}

impl FromStr for WorkType {
    type Err = String;

    /// Looks `s` up in the table of accepted names (exact, case-sensitive).
    ///
    /// Unknown names produce the error `"invalid work type: <s>"`.
    fn from_str(s: &str) -> Result<WorkType, String> {
        const NAMES: [(&str, WorkType); 4] = [
            ("WorkInPlace", WorkType::WorkInPlace),
            ("WorkSink", WorkType::WorkSink),
            ("WorkWithRef", WorkType::WorkWithRef),
            ("WorkCustom", WorkType::WorkCustom),
        ];
        NAMES
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, work)| work)
            .ok_or_else(|| format!("invalid work type: {s}"))
    }
}
330
331fn qsdr_crate(ast: &DeriveInput) -> proc_macro2::TokenStream {
332    let qsdr_crate_attrs = ast
333        .attrs
334        .iter()
335        .filter_map(|attr| {
336            let Meta::NameValue(name_value) = &attr.meta else {
337                return None;
338            };
339            let segments = &name_value.path.segments;
340            if segments.len() != 1 {
341                return None;
342            }
343            let segment = segments.first().unwrap();
344            if segment.ident == "qsdr_crate" && matches!(segment.arguments, PathArguments::None) {
345                let Expr::Lit(lit) = &name_value.value else {
346                    panic!("qsdr_crate value is not a literal");
347                };
348                let Lit::Str(s) = &lit.lit else {
349                    panic!("qsdr_crate value is not a string literal");
350                };
351                Some(s.parse().unwrap())
352            } else {
353                None
354            }
355        })
356        .collect::<Vec<_>>();
357    if qsdr_crate_attrs.is_empty() {
358        return "::qsdr".parse().unwrap();
359    }
360    if qsdr_crate_attrs.len() > 1 {
361        panic!("qsdr_crate attribute present multiple times");
362    }
363    qsdr_crate_attrs.into_iter().next().unwrap()
364}
365
366fn work_type(ast: &DeriveInput) -> WorkType {
367    let work_attrs = ast
368        .attrs
369        .iter()
370        .filter_map(|attr| {
371            let Meta::List(list) = &attr.meta else {
372                return None;
373            };
374            let segments = &list.path.segments;
375            if segments.len() != 1 {
376                return None;
377            }
378            let segment = segments.first().unwrap();
379            if segment.ident == "work" && matches!(segment.arguments, PathArguments::None) {
380                Some(&list.tokens)
381            } else {
382                None
383            }
384        })
385        .collect::<Vec<_>>();
386    if work_attrs.is_empty() {
387        panic!("work attribute missing");
388    }
389    if work_attrs.len() > 1 {
390        panic!("work attribute present multiple times");
391    }
392    let attr = work_attrs[0].clone().into_iter().collect::<Vec<_>>();
393    if attr.len() != 1 {
394        panic!("work attribute does not have a single argument");
395    }
396    let proc_macro2::TokenTree::Ident(ident) = &attr[0] else {
397        panic!("work attribute is not an ident");
398    };
399    match ident.to_string().parse() {
400        Ok(w) => w,
401        Err(err) => panic!("{}", err),
402    }
403}
404
405fn struct_generic_types(ast: &DeriveInput) -> Vec<TypeParam> {
406    ast.generics
407        .params
408        .iter()
409        .filter_map(|param| {
410            if let GenericParam::Type(ty) = param {
411                let mut ty = ty.clone();
412                // remove any possible default, since it interferes with code
413                // generation
414                ty.default = None;
415                Some(ty)
416            } else {
417                None
418            }
419        })
420        .collect()
421}
422
423fn field_is_port(field: &Field) -> bool {
424    field.attrs.iter().any(|attr| match &attr.meta {
425        Meta::Path(Path { segments, .. }) => {
426            if segments.len() != 1 {
427                return false;
428            }
429            let segment = segments.first().unwrap();
430            segment.ident == "port" && matches!(segment.arguments, PathArguments::None)
431        }
432        _ => false,
433    })
434}
435
436fn has_port_with_name(ports: &[&Field], name: &str) -> bool {
437    ports.iter().any(|field| {
438        if let Some(ident) = &field.ident {
439            ident == name
440        } else {
441            false
442        }
443    })
444}
445
446fn check_required_ports(ports: &[&Field], required: &[&str], work_name: &str) {
447    for req in required {
448        if !has_port_with_name(ports, req) {
449            panic!("{} requires a port called {}", work_name, req);
450        }
451    }
452}