// custos_macro/lib.rs

1use quote::{quote, ToTokens};
2use syn::{parse_macro_input, ItemImpl, ItemFn};
3
4#[proc_macro_attribute]
5/// Expands a `CPU` implementation to a `Stack` and `CPU` implementation.
6///
7/// # Example
8///
9/// ```ignore
10/// #[impl_stack]
11/// impl<T, D, S> ElementWise<T, D, S> for CPU
12/// where
13///     T: Number,
14///     D: MainMemory,
15///     S: Shape
16/// {
17///     fn add(&self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<T, CPU, S> {
18///         let mut out = self.retrieve(lhs.len, (lhs, rhs));
19///         cpu_element_wise(lhs, rhs, &mut out, |o, a, b| *o = a + b);
20///         out
21///     }
22/// 
23///     fn mul(&self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<T, CPU, S> {
24///         let mut out = self.retrieve(lhs.len, (lhs, rhs));
25///         cpu_element_wise(lhs, rhs, &mut out, |o, a, b| *o = a * b);
26///         out
27///     }
28/// }
29/// 
30/// '#[impl_stack]' expands the implementation above to the following 'Stack' implementation:
31/// 
32/// impl<T, D, S> ElementWise<T, D, S> for Stack
33/// where
34///     T: Number,
35///     D: MainMemory,
36///     S: Shape
37/// {
38///     fn add(&self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<T, Stack, S> {
39///         let mut out = self.retrieve(lhs.len, (lhs, rhs));
40///         cpu_element_wise(lhs, rhs, &mut out, |o, a, b| *o = a + b);
41///         out
42///     }
43/// 
44///     fn mul(&self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<T, Stack, S> {
45///         let mut out = self.retrieve(lhs.len, (lhs, rhs));
46///         cpu_element_wise(lhs, rhs, &mut out, |o, a, b| *o = a * b);
47///         out
48///     }
49/// }
50/// 
51/// // Now is it possible to execute this operations with a CPU and Stack device.
52///
53/// ```
54pub fn impl_stack(
55    _attr: proc_macro::TokenStream,
56    item: proc_macro::TokenStream,
57) -> proc_macro::TokenStream {
58    let input = parse_macro_input!(item as ItemImpl);
59    proc_macro::TokenStream::from(add_stack_impl_simpl(input))
60}
61
62const ERROR_MSG: &str = "Can't use #[impl_stack] on this implement block.";
63
64fn add_stack_impl_simpl(impl_block: ItemImpl) -> proc_macro2::TokenStream {
65    let stack_impl_block = impl_block
66        .to_token_stream()
67        .to_string()
68        .replace("CPU", "Stack");
69
70    let stack_impl_block: proc_macro2::TokenStream =
71        syn::parse_str(&stack_impl_block).expect(ERROR_MSG);
72
73    quote!(
74        #[cfg(feature = "cpu")]
75        #impl_block
76
77        #[cfg(feature = "stack")]
78        #stack_impl_block
79    )
80}
81
82#[proc_macro_attribute]
83pub fn stack_cpu_test(
84    _attr: proc_macro::TokenStream,
85    item: proc_macro::TokenStream,
86) -> proc_macro::TokenStream {
87    let input = parse_macro_input!(item as ItemFn);
88    proc_macro::TokenStream::from(add_stack_cpu_test(input))
89}
90
91const STACK_CPU_TEST_ERROR_MSG: &str = "Can't use #[stack_cpu_test] on this implement block.";
92
93fn add_stack_cpu_test(input: ItemFn) -> proc_macro2::TokenStream {
94    let stack_test_block = input
95        .to_token_stream()
96        .to_string()
97        .replace("cpu", "stack")
98        .replace("CPU :: new()", "custos::Stack");
99
100    let stack_test_block: proc_macro2::TokenStream =
101        syn::parse_str(&stack_test_block).expect(STACK_CPU_TEST_ERROR_MSG);
102
103    quote! {
104        #[cfg(feature = "cpu")]
105        #input
106
107        #[cfg(feature = "stack")]
108        #stack_test_block
109    }
110}
111
112/*
113
114fn add_stack_impl(impl_block: ItemImpl) -> proc_macro2::TokenStream {
115    let attrs = impl_block.attrs.iter().fold(quote!(), |mut acc, attr| {
116        acc.extend(attr.to_token_stream());
117        acc
118    });
119    let spawn_generics = impl_block.generics.params.to_token_stream();
120    let where_clause = impl_block.generics.where_clause.as_ref().unwrap();
121
122    if let Some(generic_type) = impl_block.generics.type_params().next() {
123        let generic_ident = &generic_type.ident;
124        /*if generic_type.ident != "T" {
125            panic!("{ERROR_MSG}");
126            //panic!("--> should use the datatype provided from ...? e.g. #[impl_stack(f32)]");
127        }*/
128
129        let impl_trait = &impl_block
130            .trait_
131            .as_ref()
132            .expect(ERROR_MSG)
133            .1
134            .to_token_stream()
135            .to_string();
136        let mut path_generics = impl_trait.split('<');
137
138        let trait_name = path_generics.next().expect(ERROR_MSG);
139        let generics_no_const = path_generics.next().expect(ERROR_MSG);
140        let trait_generics = format!(
141            "{}<{}, N >",
142            trait_name,
143            &generics_no_const[..generics_no_const.len() - 2]
144        );
145
146        let trait_path: Path = syn::parse_str(&trait_generics).expect(ERROR_MSG);
147
148        //let generics = remove_lit(generics);
149
150        let methods_updated = impl_block
151            .items
152            .clone()
153            .into_iter()
154            .flat_map(|item| match item {
155                syn::ImplItem::Method(method) => Some(method),
156                _ => None,
157            })
158            .fold(quote!(), |mut acc, mut meth| {
159                if let ReturnType::Type(_, output) = &mut meth.sig.output {
160                    *output = insert_const_n_to_buf(output.to_token_stream());
161                }
162
163                meth.sig.inputs = meth
164                    .sig
165                    .inputs
166                    .iter_mut()
167                    .map(|input| {
168                        match input.clone() {
169                            // self
170                            syn::FnArg::Receiver(_) => input.clone(),
171                            // other args
172                            syn::FnArg::Typed(typed) => {
173                                insert_const_n_to_buf(typed.to_token_stream())
174                            }
175                        }
176                    })
177                    .collect();
178
179                acc.extend(meth.to_token_stream());
180                acc
181            });
182
183        //panic!("methods: {}", methods_updated.to_token_stream().to_string());
184
185        return quote! {
186            #impl_block
187
188            #[cfg(feature = "stack")]
189            #attrs
190            impl<#spawn_generics, const N: usize> #trait_path for custos::stack::Stack
191            #where_clause
192            custos::stack::Stack: custos::Alloc<#generic_ident, N>
193            {
194                #methods_updated
195            }
196        };
197        //panic!("x: {}", x.to_string());
198    }
199    panic!("{ERROR_MSG}")
200}
201
202fn insert_const_n_to_buf<R: syn::parse::Parse + Clone>(tokens: proc_macro2::TokenStream) -> R {
203    let tokens = tokens.to_string();
204    if !tokens.contains("Buffer") {
205        return syn::parse_str(&tokens).unwrap();
206    }
207    let mut tokens = tokens.replace("CPU", "Stack");
208
209    let idx = tokens.find('>').unwrap();
210    tokens.insert_str(idx - 1, ", N ");
211    syn::parse_str(&tokens).unwrap()
212}
213
214*/