qip_macros/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::{Delimiter, Spacing, TokenStream, TokenTree};
4use std::collections::HashSet;
5use std::iter::Peekable;
6use std::str::FromStr;
7
/// Consumes a register identifier and, when present, the single token that follows it.
///
/// The trailing token is taken as the register's indices unless it is a `,` or `;`
/// punctuation, in which case it is left in the stream and no indices are returned.
/// No indices are returned either when the stream ends right after the identifier.
///
/// # Panics
/// Panics if the stream is exhausted or its next token is not an identifier.
fn parse_register_and_indices<It: Iterator<Item = TokenTree>>(
    input_stream: &mut Peekable<It>,
) -> (String, Option<TokenTree>) {
    let register = match input_stream.next() {
        Some(TokenTree::Ident(ident)) => ident.to_string(),
        Some(other) => panic!("Expected register identifier, found {:?}", other),
        None => panic!("Expected register identifier, found nothing"),
    };

    // Decide from the lookahead alone whether an index token follows.
    let has_indices = match input_stream.peek() {
        Some(TokenTree::Punct(p)) => !matches!(p.as_char(), ';' | ','),
        Some(_) => true,
        None => false,
    };
    let indices = if has_indices {
        input_stream.next()
    } else {
        None
    };

    (register, indices)
}
34
35/// Consumes a list of registers and punctuation up to the next line.
36fn parse_list_of_registers<It: Iterator<Item = TokenTree>>(
37    input_stream: &mut Peekable<It>,
38) -> (Vec<Vec<String>>, Vec<Vec<Option<TokenTree>>>) {
39    let mut register_groups: Vec<Vec<String>> = Vec::default();
40    let mut index_groups: Vec<Vec<Option<TokenTree>>> = Vec::default();
41
42    while let Some(p) = input_stream.peek() {
43        match p {
44            TokenTree::Group(_) => {
45                // Group of registers coming up
46                if let TokenTree::Group(group) = input_stream.next().unwrap() {
47                    let mut it = group.stream().into_iter().peekable();
48                    let (sub_register_groups, sub_index_groups) = parse_list_of_registers(&mut it);
49                    for v in &sub_register_groups {
50                        if v.len() != 1 {
51                            panic!("Register groups may not be nested");
52                        }
53                    }
54                    let sub_register_groups = sub_register_groups
55                        .into_iter()
56                        .map(|mut v| v.pop().unwrap())
57                        .collect();
58                    let sub_index_groups = sub_index_groups
59                        .into_iter()
60                        .map(|mut v| v.pop().unwrap())
61                        .collect();
62                    register_groups.push(sub_register_groups);
63                    index_groups.push(sub_index_groups);
64                }
65            }
66            TokenTree::Punct(punct) if punct.as_char() == ',' => {
67                input_stream.next();
68            }
69            TokenTree::Punct(punct) if punct.as_char() == ';' => {
70                // Done with this line
71                input_stream.next();
72                break;
73            }
74            TokenTree::Ident(_) => {
75                // A single register
76                let (r, indices) = parse_register_and_indices(input_stream);
77                register_groups.push(vec![r]);
78                index_groups.push(vec![indices]);
79            }
80            p => panic!(
81                "Expected group, identifier, comma, or semicolon, found {:?}",
82                p
83            ),
84        }
85    }
86
87    (register_groups, index_groups)
88}
89
90#[proc_macro]
91pub fn program(input_stream: TokenStream) -> TokenStream {
92    let mut input_stream = input_stream.into_iter().peekable();
93    let mut output_stream = TokenStream::new();
94
95    // First we take the builder expression.
96    let mut builder_stream = TokenStream::from_str("let _program_builder = ").unwrap();
97    for tokentree in input_stream.by_ref() {
98        let done = matches!(&tokentree, TokenTree::Punct(p) if p.as_char() == ';');
99        builder_stream.extend(Some(tokentree));
100        if done {
101            break;
102        }
103    }
104    output_stream.extend(builder_stream);
105
106    // Parse input registers
107    let mut input_registers = Vec::default();
108    for tokentree in input_stream.by_ref() {
109        match tokentree {
110            TokenTree::Punct(p) if p.as_char() == ';' => {
111                break;
112            }
113            TokenTree::Punct(p) if p.as_char() == ',' => {}
114            TokenTree::Ident(ident) => input_registers.push(ident.to_string()),
115            _ => panic!(
116                "Expecting a register ident, a comma, or a semicolon, found {:?}",
117                tokentree
118            ),
119        };
120    }
121
122    let original_list = input_registers.clone();
123    let original_size = original_list.len();
124    input_registers.dedup();
125    if original_size != input_registers.len() {
126        panic!(
127            "Input register list contained duplicates: {:?}",
128            original_list
129        );
130    }
131
132    for input_register in &input_registers {
133        output_stream.extend(TokenStream::from_str(&format!("let mut {} = _program_builder.split_all_register({}).into_iter().map(|r| Some(r)).collect::<Vec<_>>();", input_register, input_register)).unwrap())
134    }
135
136    // Now we parse lines in the program.
137    loop {
138        let mut control = false;
139        let mut control_bits = None;
140        let mut function = String::new();
141        let mut arguments = None;
142
143        // Check if first word is 'control'
144        if let Some(tokentree) = input_stream.next() {
145            match tokentree {
146                TokenTree::Ident(ident) if ident.to_string() == "control" => {
147                    control = true;
148                }
149                TokenTree::Ident(ident) => {
150                    function = ident.to_string();
151                    if let Some(TokenTree::Group(g)) = input_stream.peek() {
152                        if g.delimiter() == Delimiter::Parenthesis {
153                            if let Some(TokenTree::Group(g)) = input_stream.next() {
154                                arguments = Some(g.stream());
155                            }
156                        }
157                    }
158                }
159                _ => {
160                    panic!("Unexpected first token: {:?}", tokentree)
161                }
162            }
163        }
164
165        if control {
166            let mut found_bit_group = false;
167            // Now either we find control bits, or we find a function.
168            if let Some(tokentree) = input_stream.next() {
169                match tokentree {
170                    TokenTree::Ident(ident) => {
171                        function = ident.to_string();
172                        if let Some(TokenTree::Group(g)) = input_stream.peek() {
173                            if g.delimiter() == Delimiter::Parenthesis {
174                                if let Some(TokenTree::Group(g)) = input_stream.next() {
175                                    arguments = Some(g.stream());
176                                }
177                            }
178                        }
179                    }
180                    TokenTree::Group(group) => {
181                        found_bit_group = true;
182                        control_bits = Some(group.stream());
183                    }
184                    _ => {
185                        panic!("Unexpected token after `control`: {:?}", tokentree)
186                    }
187                }
188            }
189            if found_bit_group {
190                // Now it had better be a function.
191                if let Some(tokentree) = input_stream.next() {
192                    match tokentree {
193                        TokenTree::Ident(ident) => {
194                            function = ident.to_string();
195                            if let Some(TokenTree::Group(g)) = input_stream.peek() {
196                                if g.delimiter() == Delimiter::Parenthesis {
197                                    if let Some(TokenTree::Group(g)) = input_stream.next() {
198                                        arguments = Some(g.stream());
199                                    }
200                                }
201                            }
202                        }
203                        _ => {
204                            panic!("Unexpected token after `control(bits)`: {:?}", tokentree)
205                        }
206                    }
207                }
208            }
209        }
210
211        // Now parse each of the register[indices]
212        let (register_list, index_list) = parse_list_of_registers(&mut input_stream);
213
214        let mut line_stream = TokenStream::new();
215
216        // Pull relevant registers out.
217        for (ri, (rs, is)) in register_list.iter().zip(index_list.iter()).enumerate() {
218            let reg_name = format!("_program_register_{}", ri);
219
220            let full_string = Some("None.into_iter()".to_string()).into_iter().chain(rs.iter().zip(is).map(|(r, s)| {
221                if let Some(s) = s {
222                    format!("qip::macros::program::QubitIndices::from({}).into_iter().map(|i| {}[i].take().unwrap())", s, r)
223                } else {
224                    format!("(0..{}.len()).map(|i| {}[i].take().unwrap())", r, r)
225                }
226            }).map(|s| format!(".chain({})", s))).collect::<String>();
227
228            line_stream.extend(
229                TokenStream::from_str(&format!(
230                    "let {} = _program_builder.merge_registers({}).unwrap();",
231                    reg_name, full_string
232                ))
233                .unwrap(),
234            );
235        }
236        output_stream.extend(line_stream);
237
238        // Now use the registers to call the function.
239        let mut start = 0;
240        let mut builder_name = "_program_builder";
241        let mut has_control_bits = false;
242        if control {
243            start = 1;
244
245            if let Some(control_bits) = control_bits {
246                output_stream.extend(TokenStream::from_str("let _control_bitmask = "));
247                output_stream.extend(control_bits.clone());
248                output_stream.extend(TokenStream::from_str(";"));
249                output_stream.extend(TokenStream::from_str("let _program_register_0 = qip::macros::program::negate_bitmask(_program_builder, _program_register_0, _control_bitmask);"));
250                has_control_bits = true;
251            }
252
253            output_stream.extend(TokenStream::from_str("let mut _control_program_builder = _program_builder.condition_with(_program_register_0);"));
254            builder_name = "&mut _control_program_builder";
255        }
256
257        let args_string = if let Some(args) = arguments {
258            format!("{},", args)
259        } else {
260            "".to_string()
261        };
262
263        let subsection = &register_list[start..];
264        if subsection.len() == 1 {
265            let register_name = format!("_program_register_{} ", start);
266            let string = format!(
267                "let {} = {}({}, {} {})?;",
268                register_name, function, builder_name, args_string, register_name
269            );
270            output_stream.extend(TokenStream::from_str(&string).unwrap());
271        } else {
272            let register_names = (start..register_list.len() - 1)
273                .map(|i| format!("_program_register_{}, ", i))
274                .chain(Some(format!(
275                    "_program_register_{} ",
276                    register_list.len() - 1
277                )))
278                .collect::<String>();
279            let string = format!(
280                "let ({}) = {}({}, {} {})?;",
281                register_names, function, builder_name, args_string, register_names
282            );
283            output_stream.extend(TokenStream::from_str(&string).unwrap());
284        }
285
286        // Now use the registers to call the function.
287        if control {
288            output_stream.extend(TokenStream::from_str(
289                "let _program_register_0 = _control_program_builder.dissolve();",
290            ));
291            if has_control_bits {
292                output_stream.extend(TokenStream::from_str("let _program_register_0 = qip::macros::program::negate_bitmask(_program_builder, _program_register_0, _control_bitmask);"));
293            }
294        }
295
296        // Put registers back.
297        let mut replace_qudits_stream = TokenStream::new();
298        for (ri, (rs, is)) in register_list.iter().zip(index_list.iter()).enumerate() {
299            let reg_name = format!("_program_register_{}", ri);
300            replace_qudits_stream.extend(TokenStream::from_str(&format!("let mut {} = _program_builder.split_all_register({}).into_iter().map(|r| Some(r)).collect::<Vec<_>>(); let mut {}_index = 0;", reg_name, reg_name, reg_name)).unwrap());
301            for (r, s) in rs.iter().zip(is.iter()) {
302                let s = if let Some(s) = s {
303                    format!("qip::macros::program::QubitIndices::from({})", s)
304                } else {
305                    format!("0..{}.len()", r)
306                };
307
308                replace_qudits_stream.extend(
309                    TokenStream::from_str(&format!(
310                        "for i in {} {{ {}[i] = {}[{}_index].take(); {}_index += 1;  }}",
311                        s, r, reg_name, reg_name, reg_name
312                    ))
313                    .unwrap(),
314                );
315            }
316        }
317        output_stream.extend(replace_qudits_stream);
318
319        if input_stream.peek().is_none() {
320            break;
321        }
322    }
323
324    // Return the registers
325    for input_register in &input_registers {
326        output_stream.extend(TokenStream::from_str(&format!("let {} = _program_builder.merge_registers({}.into_iter().flat_map(|r| r)).unwrap();", input_register, input_register)).unwrap())
327    }
328
329    let mut tuple_stream = TokenStream::new();
330
331    if input_registers.len() == 1 {
332        output_stream.extend(TokenStream::from_str(&format!(
333            "Ok({})",
334            input_registers[0]
335        )));
336    } else {
337        for input_register in &input_registers {
338            tuple_stream.extend(Some(
339                TokenStream::from_str(&format!("{}, ", input_register)).unwrap(),
340            ))
341        }
342        output_stream.extend(TokenStream::from_str(&format!(
343            "Ok({})",
344            TokenTree::Group(proc_macro::Group::new(Delimiter::Parenthesis, tuple_stream))
345        )));
346    }
347
348    TokenStream::from(TokenTree::Group(proc_macro::Group::new(
349        proc_macro::Delimiter::Brace,
350        output_stream,
351    )))
352}
353
/// Collects the names of the parameters found at the top level of `arg_stream`.
///
/// A parameter name is any identifier immediately followed by a lone `:` (a `:` whose
/// spacing is `Alone`, so the first half of a `::` path separator does not count).
/// Names are appended to `to` in the order they appear; nested groups are not searched.
fn parse_function_args(arg_stream: TokenStream, to: &mut Vec<String>) {
    let mut tokens = arg_stream.into_iter().peekable();
    loop {
        let ident = match tokens.next() {
            Some(TokenTree::Ident(ident)) => ident,
            Some(_) => continue,
            None => break,
        };

        let followed_by_colon = matches!(
            tokens.peek(),
            Some(TokenTree::Punct(punct))
                if punct.as_char() == ':' && punct.spacing() == Spacing::Alone
        );
        if followed_by_colon {
            to.push(ident.to_string());
        }
    }
}
367
368#[proc_macro_attribute]
369pub fn invert(attr: TokenStream, input_stream: TokenStream) -> TokenStream {
370    // Output starts same as input.
371    let mut output_stream = input_stream.clone();
372
373    let mut attr = attr.into_iter().peekable();
374    let new_function_name = attr.next();
375    if let Some(TokenTree::Punct(_)) = attr.peek() {
376        attr.next();
377    }
378
379    let mut non_register_args = Vec::default();
380    while let Some(TokenTree::Ident(ident)) = attr.next() {
381        non_register_args.push(ident.to_string());
382        if let Some(TokenTree::Punct(_)) = attr.peek() {
383            attr.next();
384        }
385    }
386
387    let mut function_name = String::from("foo");
388    let new_function_name = new_function_name.map(|s| s.to_string());
389
390    let mut input_stream = input_stream.into_iter().peekable();
391    // We will draw from the stream until we find the opening parens for the function arguments or generics.
392    while let Some(token) = input_stream.next() {
393        if let TokenTree::Ident(ident) = &token {
394            match input_stream.peek() {
395                Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Parenthesis => {
396                    function_name = ident.to_string();
397
398                    let to_add =
399                        new_function_name.unwrap_or_else(|| format!("{}_inv", function_name));
400                    let to_add = TokenStream::from_str(&to_add).unwrap();
401                    output_stream.extend(to_add);
402
403                    break;
404                }
405                Some(TokenTree::Punct(punct)) if punct.as_char() == '<' => {
406                    function_name = ident.to_string();
407                    let to_add =
408                        new_function_name.unwrap_or_else(|| format!("{}_inv", function_name));
409                    let to_add = TokenStream::from_str(&to_add).unwrap();
410                    output_stream.extend(to_add);
411
412                    break;
413                }
414                _ => {
415                    let to_add = TokenStream::from(token);
416                    output_stream.extend(to_add);
417                }
418            }
419        } else {
420            let to_add = TokenStream::from(token);
421            output_stream.extend(to_add);
422        }
423    }
424
425    // Now get the group with the function arguments.
426    let mut function_args = vec![];
427    for token in input_stream.by_ref() {
428        let should_break = if let TokenTree::Group(group) = &token {
429            if group.delimiter() == Delimiter::Parenthesis {
430                parse_function_args(group.stream().clone(), &mut function_args);
431                true
432            } else {
433                false
434            }
435        } else {
436            false
437        };
438        let to_add = TokenStream::from(token);
439        output_stream.extend(to_add);
440        if should_break {
441            break;
442        }
443    }
444
445    // Now parse until curly braces, throw those out.
446    for token in input_stream {
447        match &token {
448            TokenTree::Group(group) if group.delimiter() == Delimiter::Brace => {
449                break;
450            }
451            _ => {
452                let to_add = TokenStream::from(token);
453                output_stream.extend(to_add);
454            }
455        }
456    }
457
458    let builder = function_args[0].clone();
459    let new_builder = format!("_{builder}_new");
460
461    let mut skip_args = HashSet::new();
462    skip_args.extend(non_register_args.into_iter());
463
464    let regs_only = function_args[1..]
465        .iter()
466        .filter_map(|s| {
467            if !skip_args.contains(s) {
468                Some(s.clone())
469            } else {
470                None
471            }
472        })
473        .collect::<Vec<_>>();
474
475    let regs_list = regs_only.join(",");
476
477    let regs_sizes = regs_only
478        .iter()
479        .map(|reg| format!("{reg}.n()"))
480        .collect::<Vec<String>>()
481        .join(",");
482
483    let make_new_regs = regs_only
484        .iter()
485        .map(|s| format!("let _{s}_new = {new_builder}.register({s}.n_nonzero());"))
486        .collect::<String>();
487
488    let new_regs_args = Some(format!("&mut {new_builder}"))
489        .into_iter()
490        .chain(function_args[1..].iter().map(|s| {
491            if !skip_args.contains(s) {
492                format!("_{s}_new")
493            } else {
494                s.clone()
495            }
496        }))
497        .collect::<Vec<String>>()
498        .join(",");
499
500    let pop_regs = regs_only
501        .iter()
502        .rev()
503        .map(|s| {
504            format!("let {s} = _selected_vec.pop().expect(&format!(\"Register {s} is missing!\"));")
505        })
506        .collect::<String>();
507
508    let to_add = TokenStream::from(TokenTree::Group(proc_macro::Group::new(Delimiter::Brace, TokenStream::from_str(&format!("
509        let _register_sizes = [{regs_sizes}];
510        let mut {new_builder} = {builder}.new_similar();
511        {make_new_regs}
512        {function_name}({new_regs_args})?;
513        let _subcircuit = {new_builder}.make_subcircuit()?;
514        let _combined_r = {builder}.merge_registers([{regs_list}]).expect(\"Must have some registers.\");
515        let _combined_r = {builder}.apply_inverted_subcircuit(_subcircuit, _combined_r)?;
516        let mut _selected_vec = {builder}.split_relative_index_groups(_combined_r, _register_sizes.into_iter().scan(0, |acc, n| {{
517            let range = *acc..*acc+n;
518            *acc += n;
519            Some(range)
520        }})).get_all_selected().expect(\"All registers should have been selected\");
521        {pop_regs}
522        Ok(({regs_list}))
523    ")).unwrap())));
524    output_stream.extend(to_add);
525
526    output_stream
527}