1use proc_macro::TokenStream;
2use proc_macro2::{Delimiter, Group, Punct, Spacing, Span, TokenStream as TokenStream2, TokenTree};
3use quote::quote;
4use syn::{Error, Result};
5
6#[derive(Debug, Clone)]
7enum TensorTokens {
8 OpenBracket,
9 ClosedBracket,
10 Comma,
11 Number(Vec<TokenTree>),
12}
13
14#[derive(Debug, Clone)]
15enum RecursiveTensor {
16 Scalar(Vec<TokenTree>),
17 SubTensor(Vec<RecursiveTensor>),
18}
19
20fn j_processing32(input: TokenStream2) -> TokenStream2 {
36 let tokens: Vec<TokenTree> = input.into_iter().collect();
37 let mut new_stream = Vec::<TokenTree>::new();
38 let mut stream_accumulator: Vec<TokenTree>;
39
40 let mut token: &TokenTree;
41 let mut lit_str: String;
42 let mut pass: bool;
43 for index in 0..tokens.len() {
44 stream_accumulator = Vec::<TokenTree>::new();
45 pass = false;
46 token = &tokens[index];
47
48 match &token {
50 TokenTree::Literal(literal) => {
51 lit_str = literal.to_string();
52 if let Some(num_part) = lit_str.strip_suffix('j') {
53 if let Ok(number_val) = num_part.parse::<f32>() {
54 stream_accumulator.extend(quote! {#number_val * c32::new(0.0, 1.0)});
55 pass = true;
56 } else {
57 panic!("Not a valid number")
58 }
59 } else if let Some(num_part) = lit_str.strip_prefix('j') {
60 if let Ok(number_val) = num_part.parse::<f32>() {
61 stream_accumulator.extend(quote! {#number_val * c32::new(0.0, 1.0)});
62 pass = true;
63 } else {
64 panic!("Not a valid number")
65 }
66 }
67 }
68
69 TokenTree::Ident(identifier) => {
70 if identifier.to_string().as_str() == "j" {
71 stream_accumulator.extend(quote! {c32::new(0.0, 1.0)});
72 pass = true;
73 }
74 }
75
76 TokenTree::Group(group) => {
77 let processed_inner_stream = j_processing32(group.stream());
78 let new_group = Group::new(group.delimiter(), processed_inner_stream);
79 new_stream.push(TokenTree::Group(new_group));
80 continue;
81 }
82
83 _ => (),
84 }
85
86 if !pass {
87 new_stream.push(token.clone());
88 continue;
89 }
90
91 pass = false;
93 if index > 0 {
94 if let TokenTree::Punct(punct) = &tokens[index - 1] {
95 match punct.as_char() {
96 ']' | ',' | '+' | '-' | '*' | '/' => pass = true,
97 _ => (),
98 }
99 }
100 if !pass {
101 new_stream.push(TokenTree::Punct(Punct::new('*', Spacing::Alone)));
102 }
103 }
104
105 new_stream.extend(stream_accumulator);
106
107 pass = false;
108 if index < tokens.len() - 1 {
109 if let TokenTree::Punct(punct) = &tokens[index + 1] {
110 match punct.as_char() {
111 ']' | ',' | '+' | '-' | '*' | '/' => pass = true,
112 _ => (),
113 }
114 }
115 if !pass {
116 new_stream.push(TokenTree::Punct(Punct::new('*', Spacing::Alone)));
117 }
118 }
119 }
120 TokenStream2::from_iter(new_stream)
121}
122
123fn j_processing64(input: TokenStream2) -> TokenStream2 {
139 let tokens: Vec<TokenTree> = input.into_iter().collect();
140 let mut new_stream = Vec::<TokenTree>::new();
141 let mut stream_accumulator: Vec<TokenTree>;
142
143 let mut token: &TokenTree;
144 let mut lit_str: String;
145 let mut pass: bool;
146 for index in 0..tokens.len() {
147 stream_accumulator = Vec::<TokenTree>::new();
148 pass = false;
149 token = &tokens[index];
150
151 match &token {
153 TokenTree::Literal(literal) => {
154 lit_str = literal.to_string();
155 if let Some(num_part) = lit_str.strip_suffix('j') {
156 if let Ok(number_val) = num_part.parse::<f64>() {
157 stream_accumulator.extend(quote! {#number_val * c64::new(0.0, 1.0)});
158 pass = true;
159 } else {
160 panic!("Not a valid number")
161 }
162 } else if let Some(num_part) = lit_str.strip_prefix('j') {
163 if let Ok(number_val) = num_part.parse::<f64>() {
164 stream_accumulator.extend(quote! {#number_val * c64::new(0.0, 1.0)});
165 pass = true;
166 } else {
167 panic!("Not a valid number")
168 }
169 }
170 }
171
172 TokenTree::Ident(identifier) => {
173 if identifier.to_string().as_str() == "j" {
174 stream_accumulator.extend(quote! {c64::new(0.0, 1.0)});
175 pass = true;
176 }
177 }
178
179 TokenTree::Group(group) => {
180 let processed_inner_stream = j_processing64(group.stream());
181 let new_group = Group::new(group.delimiter(), processed_inner_stream);
182 new_stream.push(TokenTree::Group(new_group));
183 continue;
184 }
185
186 _ => (),
187 }
188
189 if !pass {
190 new_stream.push(token.clone());
191 continue;
192 }
193
194 pass = false;
196 if index > 0 {
197 if let TokenTree::Punct(punct) = &tokens[index - 1] {
198 match punct.as_char() {
199 ']' | ',' | '+' | '-' | '*' | '/' => pass = true,
200 _ => (),
201 }
202 }
203 if !pass {
204 new_stream.push(TokenTree::Punct(Punct::new('*', Spacing::Alone)));
205 }
206 }
207
208 new_stream.extend(stream_accumulator);
209
210 pass = false;
211 if index < tokens.len() - 1 {
212 if let TokenTree::Punct(punct) = &tokens[index + 1] {
213 match punct.as_char() {
214 ']' | ',' | '+' | '-' | '*' | '/' => pass = true,
215 _ => (),
216 }
217 }
218 if !pass {
219 new_stream.push(TokenTree::Punct(Punct::new('*', Spacing::Alone)));
220 }
221 }
222 }
223 TokenStream2::from_iter(new_stream)
224}
225
226fn tensor_lexer(token_stream: TokenStream2) -> Result<Vec<TensorTokens>> {
237 let mut processed_tokens = Vec::new();
238 let token_iterator = token_stream.into_iter();
239
240 let mut number_token_accumulator = Vec::new();
241
242 for token in token_iterator {
243 match &token {
244 TokenTree::Literal(_literal) => {
245 number_token_accumulator.push(token);
246 }
247
248 TokenTree::Ident(_identifier) => {
249 number_token_accumulator.push(token);
250 }
251
252 TokenTree::Punct(punctuation) => match punctuation.as_char() {
253 ',' => {
254 if !number_token_accumulator.is_empty() {
255 processed_tokens.push(TensorTokens::Number(number_token_accumulator));
256 number_token_accumulator = Vec::new();
257 }
258
259 processed_tokens.push(TensorTokens::Comma)
260 }
261 _ => number_token_accumulator.push(token),
262 },
263
264 TokenTree::Group(group) => match group.delimiter() {
265 Delimiter::Bracket => {
266 if !number_token_accumulator.is_empty() {
267 processed_tokens.push(TensorTokens::Number(number_token_accumulator));
268 number_token_accumulator = Vec::new();
269 }
270
271 processed_tokens.push(TensorTokens::OpenBracket);
272 processed_tokens.extend(tensor_lexer(group.stream())?);
273 processed_tokens.push(TensorTokens::ClosedBracket);
274 }
275 _ => number_token_accumulator.push(token),
276 },
277 }
278 }
279 if !number_token_accumulator.is_empty() {
280 processed_tokens.push(TensorTokens::Number(number_token_accumulator));
281 }
282 Ok(processed_tokens)
283}
284
285fn tensor_parser(tokens: &[TensorTokens]) -> Result<(RecursiveTensor, usize)> {
302 let mut index = 0;
303
304 if let Some(TensorTokens::OpenBracket) = tokens.get(index) {
306 index += 1;
307
308 let mut children = Vec::new();
310 loop {
311 if index > tokens.len() {
312 return Err(Error::new(Span::call_site(), "Missing closing bracket"));
313 }
314
315 if let Some(TensorTokens::ClosedBracket) = tokens.get(index) {
316 index += 1;
317 return Ok((RecursiveTensor::SubTensor(children), index));
318 }
319
320 let (child, delta_index) = tensor_parser(&tokens[index..])?;
321 children.push(child);
322 index += delta_index;
323
324 if let Some(TensorTokens::Comma) = tokens.get(index) {
325 index += 1;
326 }
327 }
328 } else if let Some(TensorTokens::Number(token_tree_vec)) = tokens.get(index) {
329 Ok((RecursiveTensor::Scalar(token_tree_vec.clone()), 1))
330 } else {
331 Err(Error::new(Span::call_site(), "Not a valid tensor"))
332 }
333}
334
335fn flatten_tensor_data(input: &RecursiveTensor) -> (Vec<TokenStream2>, Vec<usize>) {
347 match input {
348 RecursiveTensor::Scalar(token_vec) => {
349 let stream = TokenStream2::from_iter(token_vec.clone());
350 (vec![stream.clone()], vec![])
351 }
352
353 RecursiveTensor::SubTensor(subtensors) => {
354 let mut flat_data = Vec::new();
356 for subtensor in subtensors {
357 let (mut subtensor_data, _) = flatten_tensor_data(subtensor);
358 flat_data.append(&mut subtensor_data);
359 }
360
361 let (_, sub_shape) = flatten_tensor_data(&subtensors[0]);
363 let mut final_shape = vec![subtensors.len()];
364 final_shape.extend(sub_shape);
365
366 (flat_data, final_shape)
367 }
368 }
369}
370
371#[proc_macro]
388pub fn complex_tensor64(input: TokenStream) -> TokenStream {
389 let input_processed = j_processing64(input.into());
390
391 let tokens = match tensor_lexer(input_processed) {
392 Ok(inner_val) => inner_val,
393 Err(error) => return error.to_compile_error().into(),
394 };
395
396 let (recursive_tensor, _) = match tensor_parser(&tokens) {
397 Ok(inner_val) => inner_val,
398 Err(error) => return error.to_compile_error().into(),
399 };
400
401 let (flat_data, shape) = flatten_tensor_data(&recursive_tensor);
402
403 let quoted_shape = quote! {[#(#shape),*]};
404
405 let d = shape.len();
406 quote! {{
407 let data_vec: Vec<c64> = vec![#(#flat_data),*];
408 Tensor::<c64, #d>::from_slice(&data_vec, #quoted_shape)
409 }}
410 .into()
411}
412
413#[proc_macro]
430pub fn complex_tensor32(input: TokenStream) -> TokenStream {
431 let input_processed = j_processing32(input.into());
432
433 let tokens = match tensor_lexer(input_processed) {
434 Ok(inner_val) => inner_val,
435 Err(error) => return error.to_compile_error().into(),
436 };
437
438 let (recursive_tensor, _) = match tensor_parser(&tokens) {
439 Ok(inner_val) => inner_val,
440 Err(error) => return error.to_compile_error().into(),
441 };
442
443 let (flat_data, shape) = flatten_tensor_data(&recursive_tensor);
444
445 let quoted_shape = quote! {[#(#shape),*]};
446
447 let d = shape.len();
448 quote! {{
449 let data_vec: Vec<c32> = vec![#(#flat_data),*];
450 Tensor::<c32, #d>::from_slice(&data_vec, #quoted_shape)
451 }}
452 .into()
453}