custos_macro/lib.rs
1use quote::{quote, ToTokens};
2use syn::{parse_macro_input, ItemImpl, ItemFn};
3
4#[proc_macro_attribute]
5/// Expands a `CPU` implementation to a `Stack` and `CPU` implementation.
6///
7/// # Example
8///
9/// ```ignore
10/// #[impl_stack]
11/// impl<T, D, S> ElementWise<T, D, S> for CPU
12/// where
13/// T: Number,
14/// D: MainMemory,
15/// S: Shape
16/// {
17/// fn add(&self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<T, CPU, S> {
18/// let mut out = self.retrieve(lhs.len, (lhs, rhs));
19/// cpu_element_wise(lhs, rhs, &mut out, |o, a, b| *o = a + b);
20/// out
21/// }
22///
23/// fn mul(&self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<T, CPU, S> {
24/// let mut out = self.retrieve(lhs.len, (lhs, rhs));
25/// cpu_element_wise(lhs, rhs, &mut out, |o, a, b| *o = a * b);
26/// out
27/// }
28/// }
29///
30/// '#[impl_stack]' expands the implementation above to the following 'Stack' implementation:
31///
32/// impl<T, D, S> ElementWise<T, D, S> for Stack
33/// where
34/// T: Number,
35/// D: MainMemory,
36/// S: Shape
37/// {
38/// fn add(&self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<T, Stack, S> {
39/// let mut out = self.retrieve(lhs.len, (lhs, rhs));
40/// cpu_element_wise(lhs, rhs, &mut out, |o, a, b| *o = a + b);
41/// out
42/// }
43///
44/// fn mul(&self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<T, Stack, S> {
45/// let mut out = self.retrieve(lhs.len, (lhs, rhs));
46/// cpu_element_wise(lhs, rhs, &mut out, |o, a, b| *o = a * b);
47/// out
48/// }
49/// }
50///
51/// // Now is it possible to execute this operations with a CPU and Stack device.
52///
53/// ```
54pub fn impl_stack(
55 _attr: proc_macro::TokenStream,
56 item: proc_macro::TokenStream,
57) -> proc_macro::TokenStream {
58 let input = parse_macro_input!(item as ItemImpl);
59 proc_macro::TokenStream::from(add_stack_impl_simpl(input))
60}
61
62const ERROR_MSG: &str = "Can't use #[impl_stack] on this implement block.";
63
64fn add_stack_impl_simpl(impl_block: ItemImpl) -> proc_macro2::TokenStream {
65 let stack_impl_block = impl_block
66 .to_token_stream()
67 .to_string()
68 .replace("CPU", "Stack");
69
70 let stack_impl_block: proc_macro2::TokenStream =
71 syn::parse_str(&stack_impl_block).expect(ERROR_MSG);
72
73 quote!(
74 #[cfg(feature = "cpu")]
75 #impl_block
76
77 #[cfg(feature = "stack")]
78 #stack_impl_block
79 )
80}
81
82#[proc_macro_attribute]
83pub fn stack_cpu_test(
84 _attr: proc_macro::TokenStream,
85 item: proc_macro::TokenStream,
86) -> proc_macro::TokenStream {
87 let input = parse_macro_input!(item as ItemFn);
88 proc_macro::TokenStream::from(add_stack_cpu_test(input))
89}
90
91const STACK_CPU_TEST_ERROR_MSG: &str = "Can't use #[stack_cpu_test] on this implement block.";
92
93fn add_stack_cpu_test(input: ItemFn) -> proc_macro2::TokenStream {
94 let stack_test_block = input
95 .to_token_stream()
96 .to_string()
97 .replace("cpu", "stack")
98 .replace("CPU :: new()", "custos::Stack");
99
100 let stack_test_block: proc_macro2::TokenStream =
101 syn::parse_str(&stack_test_block).expect(STACK_CPU_TEST_ERROR_MSG);
102
103 quote! {
104 #[cfg(feature = "cpu")]
105 #input
106
107 #[cfg(feature = "stack")]
108 #stack_test_block
109 }
110}
111
112/*
113
114fn add_stack_impl(impl_block: ItemImpl) -> proc_macro2::TokenStream {
115 let attrs = impl_block.attrs.iter().fold(quote!(), |mut acc, attr| {
116 acc.extend(attr.to_token_stream());
117 acc
118 });
119 let spawn_generics = impl_block.generics.params.to_token_stream();
120 let where_clause = impl_block.generics.where_clause.as_ref().unwrap();
121
122 if let Some(generic_type) = impl_block.generics.type_params().next() {
123 let generic_ident = &generic_type.ident;
124 /*if generic_type.ident != "T" {
125 panic!("{ERROR_MSG}");
126 //panic!("--> should use the datatype provided from ...? e.g. #[impl_stack(f32)]");
127 }*/
128
129 let impl_trait = &impl_block
130 .trait_
131 .as_ref()
132 .expect(ERROR_MSG)
133 .1
134 .to_token_stream()
135 .to_string();
136 let mut path_generics = impl_trait.split('<');
137
138 let trait_name = path_generics.next().expect(ERROR_MSG);
139 let generics_no_const = path_generics.next().expect(ERROR_MSG);
140 let trait_generics = format!(
141 "{}<{}, N >",
142 trait_name,
143 &generics_no_const[..generics_no_const.len() - 2]
144 );
145
146 let trait_path: Path = syn::parse_str(&trait_generics).expect(ERROR_MSG);
147
148 //let generics = remove_lit(generics);
149
150 let methods_updated = impl_block
151 .items
152 .clone()
153 .into_iter()
154 .flat_map(|item| match item {
155 syn::ImplItem::Method(method) => Some(method),
156 _ => None,
157 })
158 .fold(quote!(), |mut acc, mut meth| {
159 if let ReturnType::Type(_, output) = &mut meth.sig.output {
160 *output = insert_const_n_to_buf(output.to_token_stream());
161 }
162
163 meth.sig.inputs = meth
164 .sig
165 .inputs
166 .iter_mut()
167 .map(|input| {
168 match input.clone() {
169 // self
170 syn::FnArg::Receiver(_) => input.clone(),
171 // other args
172 syn::FnArg::Typed(typed) => {
173 insert_const_n_to_buf(typed.to_token_stream())
174 }
175 }
176 })
177 .collect();
178
179 acc.extend(meth.to_token_stream());
180 acc
181 });
182
183 //panic!("methods: {}", methods_updated.to_token_stream().to_string());
184
185 return quote! {
186 #impl_block
187
188 #[cfg(feature = "stack")]
189 #attrs
190 impl<#spawn_generics, const N: usize> #trait_path for custos::stack::Stack
191 #where_clause
192 custos::stack::Stack: custos::Alloc<#generic_ident, N>
193 {
194 #methods_updated
195 }
196 };
197 //panic!("x: {}", x.to_string());
198 }
199 panic!("{ERROR_MSG}")
200}
201
202fn insert_const_n_to_buf<R: syn::parse::Parse + Clone>(tokens: proc_macro2::TokenStream) -> R {
203 let tokens = tokens.to_string();
204 if !tokens.contains("Buffer") {
205 return syn::parse_str(&tokens).unwrap();
206 }
207 let mut tokens = tokens.replace("CPU", "Stack");
208
209 let idx = tokens.find('>').unwrap();
210 tokens.insert_str(idx - 1, ", N ");
211 syn::parse_str(&tokens).unwrap()
212}
213
214*/