// qudit_macros/lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Delimiter, Group, Punct, Spacing, Span, TokenStream as TokenStream2, TokenTree};
3use quote::quote;
4use syn::{Error, Result};
5
/// Token categories produced by `tensor_lexer` from the raw macro input.
#[derive(Debug, Clone)]
enum TensorTokens {
    /// A `[` opening a new tensor dimension (from a bracket-delimited group).
    OpenBracket,
    /// A `]` closing the current tensor dimension.
    ClosedBracket,
    /// A `,` separating sibling elements.
    Comma,
    /// The raw tokens making up a single scalar expression.
    Number(Vec<TokenTree>),
}
13
/// Parsed tensor structure: either a single scalar expression or a nested
/// list of sub-tensors, one nesting level per tensor dimension.
#[derive(Debug, Clone)]
enum RecursiveTensor {
    /// The raw tokens of one scalar element.
    Scalar(Vec<TokenTree>),
    /// One dimension containing nested sub-tensors (or scalars at the leaves).
    SubTensor(Vec<RecursiveTensor>),
}
19
20/// Replaces `j` with `c32::new(0.0, 1.0)` in the input token stream.
21/// Also makes implicit multiplication explicit. (e.g. `4j` becomes `4.0 * j`)
22///
23/// # Arguments
24///
25/// * `input` - A tokenstream containing the input tokens.
26///
27/// # Returns
28///
29/// * A tokenstream with the processed tokens.
30///
31/// # Panics
32///
33/// * If a literal with suffix or prefix `j` is not a valid number.
34///
35fn j_processing32(input: TokenStream2) -> TokenStream2 {
36    let tokens: Vec<TokenTree> = input.into_iter().collect();
37    let mut new_stream = Vec::<TokenTree>::new();
38    let mut stream_accumulator: Vec<TokenTree>;
39
40    let mut token: &TokenTree;
41    let mut lit_str: String;
42    let mut pass: bool;
43    for index in 0..tokens.len() {
44        stream_accumulator = Vec::<TokenTree>::new();
45        pass = false;
46        token = &tokens[index];
47
48        // Replaces `j` with c64::new(0.0, 1.0).
49        match &token {
50            TokenTree::Literal(literal) => {
51                lit_str = literal.to_string();
52                if let Some(num_part) = lit_str.strip_suffix('j') {
53                    if let Ok(number_val) = num_part.parse::<f32>() {
54                        stream_accumulator.extend(quote! {#number_val * c32::new(0.0, 1.0)});
55                        pass = true;
56                    } else {
57                        panic!("Not a valid number")
58                    }
59                } else if let Some(num_part) = lit_str.strip_prefix('j') {
60                    if let Ok(number_val) = num_part.parse::<f32>() {
61                        stream_accumulator.extend(quote! {#number_val * c32::new(0.0, 1.0)});
62                        pass = true;
63                    } else {
64                        panic!("Not a valid number")
65                    }
66                }
67            }
68
69            TokenTree::Ident(identifier) => {
70                if identifier.to_string().as_str() == "j" {
71                    stream_accumulator.extend(quote! {c32::new(0.0, 1.0)});
72                    pass = true;
73                }
74            }
75
76            TokenTree::Group(group) => {
77                let processed_inner_stream = j_processing32(group.stream());
78                let new_group = Group::new(group.delimiter(), processed_inner_stream);
79                new_stream.push(TokenTree::Group(new_group));
80                continue;
81            }
82
83            _ => (),
84        }
85
86        if !pass {
87            new_stream.push(token.clone());
88            continue;
89        }
90
91        // Makes implicit multiplication explicit
92        pass = false;
93        if index > 0 {
94            if let TokenTree::Punct(punct) = &tokens[index - 1] {
95                match punct.as_char() {
96                    ']' | ',' | '+' | '-' | '*' | '/' => pass = true,
97                    _ => (),
98                }
99            }
100            if !pass {
101                new_stream.push(TokenTree::Punct(Punct::new('*', Spacing::Alone)));
102            }
103        }
104
105        new_stream.extend(stream_accumulator);
106
107        pass = false;
108        if index < tokens.len() - 1 {
109            if let TokenTree::Punct(punct) = &tokens[index + 1] {
110                match punct.as_char() {
111                    ']' | ',' | '+' | '-' | '*' | '/' => pass = true,
112                    _ => (),
113                }
114            }
115            if !pass {
116                new_stream.push(TokenTree::Punct(Punct::new('*', Spacing::Alone)));
117            }
118        }
119    }
120    TokenStream2::from_iter(new_stream)
121}
122
123/// Replaces `j` with `c64::new(0.0, 1.0)` in the input token stream.
124/// Also makes implicit multiplication explicit. (e.g. `4j` becomes `4.0 * j`)
125///
126/// # Arguments
127///
128/// * `input` - A tokenstream containing the input tokens.
129///
130/// # Returns
131///
132/// * A tokenstream with the processed tokens.
133///
134/// # Panics
135///
136/// * If a literal with suffix or prefix `j` is not a valid number.
137///
138fn j_processing64(input: TokenStream2) -> TokenStream2 {
139    let tokens: Vec<TokenTree> = input.into_iter().collect();
140    let mut new_stream = Vec::<TokenTree>::new();
141    let mut stream_accumulator: Vec<TokenTree>;
142
143    let mut token: &TokenTree;
144    let mut lit_str: String;
145    let mut pass: bool;
146    for index in 0..tokens.len() {
147        stream_accumulator = Vec::<TokenTree>::new();
148        pass = false;
149        token = &tokens[index];
150
151        // Replaces `j` with c64::new(0.0, 1.0).
152        match &token {
153            TokenTree::Literal(literal) => {
154                lit_str = literal.to_string();
155                if let Some(num_part) = lit_str.strip_suffix('j') {
156                    if let Ok(number_val) = num_part.parse::<f64>() {
157                        stream_accumulator.extend(quote! {#number_val * c64::new(0.0, 1.0)});
158                        pass = true;
159                    } else {
160                        panic!("Not a valid number")
161                    }
162                } else if let Some(num_part) = lit_str.strip_prefix('j') {
163                    if let Ok(number_val) = num_part.parse::<f64>() {
164                        stream_accumulator.extend(quote! {#number_val * c64::new(0.0, 1.0)});
165                        pass = true;
166                    } else {
167                        panic!("Not a valid number")
168                    }
169                }
170            }
171
172            TokenTree::Ident(identifier) => {
173                if identifier.to_string().as_str() == "j" {
174                    stream_accumulator.extend(quote! {c64::new(0.0, 1.0)});
175                    pass = true;
176                }
177            }
178
179            TokenTree::Group(group) => {
180                let processed_inner_stream = j_processing64(group.stream());
181                let new_group = Group::new(group.delimiter(), processed_inner_stream);
182                new_stream.push(TokenTree::Group(new_group));
183                continue;
184            }
185
186            _ => (),
187        }
188
189        if !pass {
190            new_stream.push(token.clone());
191            continue;
192        }
193
194        // Makes implicit multiplication explicit
195        pass = false;
196        if index > 0 {
197            if let TokenTree::Punct(punct) = &tokens[index - 1] {
198                match punct.as_char() {
199                    ']' | ',' | '+' | '-' | '*' | '/' => pass = true,
200                    _ => (),
201                }
202            }
203            if !pass {
204                new_stream.push(TokenTree::Punct(Punct::new('*', Spacing::Alone)));
205            }
206        }
207
208        new_stream.extend(stream_accumulator);
209
210        pass = false;
211        if index < tokens.len() - 1 {
212            if let TokenTree::Punct(punct) = &tokens[index + 1] {
213                match punct.as_char() {
214                    ']' | ',' | '+' | '-' | '*' | '/' => pass = true,
215                    _ => (),
216                }
217            }
218            if !pass {
219                new_stream.push(TokenTree::Punct(Punct::new('*', Spacing::Alone)));
220            }
221        }
222    }
223    TokenStream2::from_iter(new_stream)
224}
225
226/// Categorizes the tokens in the input token stream to aid in parsing.
227///
228/// # Arguments
229///
230/// * `token_stream` - A tokenstream containing the input tokens.
231///
232/// # Returns
233///
234/// * A tokenstream with the processed tokens.
235///
236fn tensor_lexer(token_stream: TokenStream2) -> Result<Vec<TensorTokens>> {
237    let mut processed_tokens = Vec::new();
238    let token_iterator = token_stream.into_iter();
239
240    let mut number_token_accumulator = Vec::new();
241
242    for token in token_iterator {
243        match &token {
244            TokenTree::Literal(_literal) => {
245                number_token_accumulator.push(token);
246            }
247
248            TokenTree::Ident(_identifier) => {
249                number_token_accumulator.push(token);
250            }
251
252            TokenTree::Punct(punctuation) => match punctuation.as_char() {
253                ',' => {
254                    if !number_token_accumulator.is_empty() {
255                        processed_tokens.push(TensorTokens::Number(number_token_accumulator));
256                        number_token_accumulator = Vec::new();
257                    }
258
259                    processed_tokens.push(TensorTokens::Comma)
260                }
261                _ => number_token_accumulator.push(token),
262            },
263
264            TokenTree::Group(group) => match group.delimiter() {
265                Delimiter::Bracket => {
266                    if !number_token_accumulator.is_empty() {
267                        processed_tokens.push(TensorTokens::Number(number_token_accumulator));
268                        number_token_accumulator = Vec::new();
269                    }
270
271                    processed_tokens.push(TensorTokens::OpenBracket);
272                    processed_tokens.extend(tensor_lexer(group.stream())?);
273                    processed_tokens.push(TensorTokens::ClosedBracket);
274                }
275                _ => number_token_accumulator.push(token),
276            },
277        }
278    }
279    if !number_token_accumulator.is_empty() {
280        processed_tokens.push(TensorTokens::Number(number_token_accumulator));
281    }
282    Ok(processed_tokens)
283}
284
285/// Organizes a series of custom tokens into a recursive tensor structure.
286///
287/// # Arguments
288///
289/// * `tokens` - A slice of `TensorTokens`, expected from `tensor_lexer`.
290///
291/// # Returns
292///
293/// * A recursive tensor storing the user's tokens.
294/// * The number of tokens consumed from the input slice.
295///
296/// # Panics
297///
298/// * If there is a missing closing bracket.
299/// * If the tensor does not start with an opening bracket or is not a scalar.
300///
301fn tensor_parser(tokens: &[TensorTokens]) -> Result<(RecursiveTensor, usize)> {
302    let mut index = 0;
303
304    // A tensor starts with [ or is a scalar.
305    if let Some(TensorTokens::OpenBracket) = tokens.get(index) {
306        index += 1;
307
308        // Each recursive step is adding one dimension.
309        let mut children = Vec::new();
310        loop {
311            if index > tokens.len() {
312                return Err(Error::new(Span::call_site(), "Missing closing bracket"));
313            }
314
315            if let Some(TensorTokens::ClosedBracket) = tokens.get(index) {
316                index += 1;
317                return Ok((RecursiveTensor::SubTensor(children), index));
318            }
319
320            let (child, delta_index) = tensor_parser(&tokens[index..])?;
321            children.push(child);
322            index += delta_index;
323
324            if let Some(TensorTokens::Comma) = tokens.get(index) {
325                index += 1;
326            }
327        }
328    } else if let Some(TensorTokens::Number(token_tree_vec)) = tokens.get(index) {
329        Ok((RecursiveTensor::Scalar(token_tree_vec.clone()), 1))
330    } else {
331        Err(Error::new(Span::call_site(), "Not a valid tensor"))
332    }
333}
334
335/// Flattens the recursive tensor structure into a single vector of tokens and calculates its shape.
336///
337/// # Arguments
338///
339/// * `input` - A reference to the recursive tensor structure.
340///
341/// # Returns
342///
343/// * A vector containing all elements of the input tensor.
344/// * The shape of the input tensor.
345///
346fn flatten_tensor_data(input: &RecursiveTensor) -> (Vec<TokenStream2>, Vec<usize>) {
347    match input {
348        RecursiveTensor::Scalar(token_vec) => {
349            let stream = TokenStream2::from_iter(token_vec.clone());
350            (vec![stream.clone()], vec![])
351        }
352
353        RecursiveTensor::SubTensor(subtensors) => {
354            // Flattened data calculation
355            let mut flat_data = Vec::new();
356            for subtensor in subtensors {
357                let (mut subtensor_data, _) = flatten_tensor_data(subtensor);
358                flat_data.append(&mut subtensor_data);
359            }
360
361            // Shape calculation
362            let (_, sub_shape) = flatten_tensor_data(&subtensors[0]);
363            let mut final_shape = vec![subtensors.len()];
364            final_shape.extend(sub_shape);
365
366            (flat_data, final_shape)
367        }
368    }
369}
370
371/// Creates a 64-bit complex tensor from nested brackets. Complex numbers
372/// can be created using `j`. (e.g. `4j`, `my_function()j`, or `some_variable * j`)
373///
374/// # Arguments
375///
376/// * `input` - The user's desired tensor written in simplified language.
377///
378/// # Returns
379///
380/// * A 64-bit complex tensor implementing the user's data.
381///
382/// # Panics
383///
384/// * If there is a missing closing bracket.
385/// * If the tensor does not start with an opening bracket or is not a scalar.
386/// * If a literal with suffix or prefix `j` is not a valid number.
387#[proc_macro]
388pub fn complex_tensor64(input: TokenStream) -> TokenStream {
389    let input_processed = j_processing64(input.into());
390
391    let tokens = match tensor_lexer(input_processed) {
392        Ok(inner_val) => inner_val,
393        Err(error) => return error.to_compile_error().into(),
394    };
395
396    let (recursive_tensor, _) = match tensor_parser(&tokens) {
397        Ok(inner_val) => inner_val,
398        Err(error) => return error.to_compile_error().into(),
399    };
400
401    let (flat_data, shape) = flatten_tensor_data(&recursive_tensor);
402
403    let quoted_shape = quote! {[#(#shape),*]};
404
405    let d = shape.len();
406    quote! {{
407            let data_vec: Vec<c64> = vec![#(#flat_data),*];
408            Tensor::<c64, #d>::from_slice(&data_vec, #quoted_shape)
409    }}
410    .into()
411}
412
413/// Creates a 32-bit complex tensor from nested brackets. Complex numbers
414/// can be created using `j`. (e.g. `4j`, `my_function()j`, or `some_variable * j`)
415///
416/// # Arguments
417///
418/// * `input` - The user's desired tensor written in simplified language.
419///
420/// # Returns
421///
422/// * A 32-bit complex tensor implementing the user's data.
423///
424/// # Panics
425///
426/// * If there is a missing closing bracket.
427/// * If the tensor does not start with an opening bracket or is not a scalar.
428/// * If a literal with suffix or prefix `j` is not a valid number.
429#[proc_macro]
430pub fn complex_tensor32(input: TokenStream) -> TokenStream {
431    let input_processed = j_processing32(input.into());
432
433    let tokens = match tensor_lexer(input_processed) {
434        Ok(inner_val) => inner_val,
435        Err(error) => return error.to_compile_error().into(),
436    };
437
438    let (recursive_tensor, _) = match tensor_parser(&tokens) {
439        Ok(inner_val) => inner_val,
440        Err(error) => return error.to_compile_error().into(),
441    };
442
443    let (flat_data, shape) = flatten_tensor_data(&recursive_tensor);
444
445    let quoted_shape = quote! {[#(#shape),*]};
446
447    let d = shape.len();
448    quote! {{
449            let data_vec: Vec<c32> = vec![#(#flat_data),*];
450            Tensor::<c32, #d>::from_slice(&data_vec, #quoted_shape)
451    }}
452    .into()
453}