1extern crate proc_macro;
2
3use proc_macro::{Delimiter, Spacing, TokenStream, TokenTree};
4use std::collections::HashSet;
5use std::iter::Peekable;
6use std::str::FromStr;
7
/// Parses one register reference of the form `name` or `name <indices>`.
///
/// The first token must be an identifier; its text is returned as the register name.
/// If another token follows and it is neither `;` nor `,`, that token is consumed too and
/// returned as the index expression. Otherwise nothing further is consumed and the index
/// part is `None`.
///
/// # Panics
///
/// Panics if the stream is empty or its first token is not an identifier.
fn parse_register_and_indices<It: Iterator<Item = TokenTree>>(
    input_stream: &mut Peekable<It>,
) -> (String, Option<TokenTree>) {
    let register = match input_stream.next() {
        Some(TokenTree::Ident(ident)) => ident.to_string(),
        Some(other) => panic!("Expected register identifier, found {:?}", other),
        None => panic!("Expected register identifier, found nothing"),
    };

    // A `;` or `,` ends this reference and is left in the stream for the caller.
    let is_separator =
        |tt: &TokenTree| matches!(tt, TokenTree::Punct(p) if matches!(p.as_char(), ';' | ','));
    let indices = input_stream.next_if(|tt| !is_separator(tt));

    (register, indices)
}
34
35fn parse_list_of_registers<It: Iterator<Item = TokenTree>>(
37 input_stream: &mut Peekable<It>,
38) -> (Vec<Vec<String>>, Vec<Vec<Option<TokenTree>>>) {
39 let mut register_groups: Vec<Vec<String>> = Vec::default();
40 let mut index_groups: Vec<Vec<Option<TokenTree>>> = Vec::default();
41
42 while let Some(p) = input_stream.peek() {
43 match p {
44 TokenTree::Group(_) => {
45 if let TokenTree::Group(group) = input_stream.next().unwrap() {
47 let mut it = group.stream().into_iter().peekable();
48 let (sub_register_groups, sub_index_groups) = parse_list_of_registers(&mut it);
49 for v in &sub_register_groups {
50 if v.len() != 1 {
51 panic!("Register groups may not be nested");
52 }
53 }
54 let sub_register_groups = sub_register_groups
55 .into_iter()
56 .map(|mut v| v.pop().unwrap())
57 .collect();
58 let sub_index_groups = sub_index_groups
59 .into_iter()
60 .map(|mut v| v.pop().unwrap())
61 .collect();
62 register_groups.push(sub_register_groups);
63 index_groups.push(sub_index_groups);
64 }
65 }
66 TokenTree::Punct(punct) if punct.as_char() == ',' => {
67 input_stream.next();
68 }
69 TokenTree::Punct(punct) if punct.as_char() == ';' => {
70 input_stream.next();
72 break;
73 }
74 TokenTree::Ident(_) => {
75 let (r, indices) = parse_register_and_indices(input_stream);
77 register_groups.push(vec![r]);
78 index_groups.push(vec![indices]);
79 }
80 p => panic!(
81 "Expected group, identifier, comma, or semicolon, found {:?}",
82 p
83 ),
84 }
85 }
86
87 (register_groups, index_groups)
88}
89
90#[proc_macro]
91pub fn program(input_stream: TokenStream) -> TokenStream {
92 let mut input_stream = input_stream.into_iter().peekable();
93 let mut output_stream = TokenStream::new();
94
95 let mut builder_stream = TokenStream::from_str("let _program_builder = ").unwrap();
97 for tokentree in input_stream.by_ref() {
98 let done = matches!(&tokentree, TokenTree::Punct(p) if p.as_char() == ';');
99 builder_stream.extend(Some(tokentree));
100 if done {
101 break;
102 }
103 }
104 output_stream.extend(builder_stream);
105
106 let mut input_registers = Vec::default();
108 for tokentree in input_stream.by_ref() {
109 match tokentree {
110 TokenTree::Punct(p) if p.as_char() == ';' => {
111 break;
112 }
113 TokenTree::Punct(p) if p.as_char() == ',' => {}
114 TokenTree::Ident(ident) => input_registers.push(ident.to_string()),
115 _ => panic!(
116 "Expecting a register ident, a comma, or a semicolon, found {:?}",
117 tokentree
118 ),
119 };
120 }
121
122 let original_list = input_registers.clone();
123 let original_size = original_list.len();
124 input_registers.dedup();
125 if original_size != input_registers.len() {
126 panic!(
127 "Input register list contained duplicates: {:?}",
128 original_list
129 );
130 }
131
132 for input_register in &input_registers {
133 output_stream.extend(TokenStream::from_str(&format!("let mut {} = _program_builder.split_all_register({}).into_iter().map(|r| Some(r)).collect::<Vec<_>>();", input_register, input_register)).unwrap())
134 }
135
136 loop {
138 let mut control = false;
139 let mut control_bits = None;
140 let mut function = String::new();
141 let mut arguments = None;
142
143 if let Some(tokentree) = input_stream.next() {
145 match tokentree {
146 TokenTree::Ident(ident) if ident.to_string() == "control" => {
147 control = true;
148 }
149 TokenTree::Ident(ident) => {
150 function = ident.to_string();
151 if let Some(TokenTree::Group(g)) = input_stream.peek() {
152 if g.delimiter() == Delimiter::Parenthesis {
153 if let Some(TokenTree::Group(g)) = input_stream.next() {
154 arguments = Some(g.stream());
155 }
156 }
157 }
158 }
159 _ => {
160 panic!("Unexpected first token: {:?}", tokentree)
161 }
162 }
163 }
164
165 if control {
166 let mut found_bit_group = false;
167 if let Some(tokentree) = input_stream.next() {
169 match tokentree {
170 TokenTree::Ident(ident) => {
171 function = ident.to_string();
172 if let Some(TokenTree::Group(g)) = input_stream.peek() {
173 if g.delimiter() == Delimiter::Parenthesis {
174 if let Some(TokenTree::Group(g)) = input_stream.next() {
175 arguments = Some(g.stream());
176 }
177 }
178 }
179 }
180 TokenTree::Group(group) => {
181 found_bit_group = true;
182 control_bits = Some(group.stream());
183 }
184 _ => {
185 panic!("Unexpected token after `control`: {:?}", tokentree)
186 }
187 }
188 }
189 if found_bit_group {
190 if let Some(tokentree) = input_stream.next() {
192 match tokentree {
193 TokenTree::Ident(ident) => {
194 function = ident.to_string();
195 if let Some(TokenTree::Group(g)) = input_stream.peek() {
196 if g.delimiter() == Delimiter::Parenthesis {
197 if let Some(TokenTree::Group(g)) = input_stream.next() {
198 arguments = Some(g.stream());
199 }
200 }
201 }
202 }
203 _ => {
204 panic!("Unexpected token after `control(bits)`: {:?}", tokentree)
205 }
206 }
207 }
208 }
209 }
210
211 let (register_list, index_list) = parse_list_of_registers(&mut input_stream);
213
214 let mut line_stream = TokenStream::new();
215
216 for (ri, (rs, is)) in register_list.iter().zip(index_list.iter()).enumerate() {
218 let reg_name = format!("_program_register_{}", ri);
219
220 let full_string = Some("None.into_iter()".to_string()).into_iter().chain(rs.iter().zip(is).map(|(r, s)| {
221 if let Some(s) = s {
222 format!("qip::macros::program::QubitIndices::from({}).into_iter().map(|i| {}[i].take().unwrap())", s, r)
223 } else {
224 format!("(0..{}.len()).map(|i| {}[i].take().unwrap())", r, r)
225 }
226 }).map(|s| format!(".chain({})", s))).collect::<String>();
227
228 line_stream.extend(
229 TokenStream::from_str(&format!(
230 "let {} = _program_builder.merge_registers({}).unwrap();",
231 reg_name, full_string
232 ))
233 .unwrap(),
234 );
235 }
236 output_stream.extend(line_stream);
237
238 let mut start = 0;
240 let mut builder_name = "_program_builder";
241 let mut has_control_bits = false;
242 if control {
243 start = 1;
244
245 if let Some(control_bits) = control_bits {
246 output_stream.extend(TokenStream::from_str("let _control_bitmask = "));
247 output_stream.extend(control_bits.clone());
248 output_stream.extend(TokenStream::from_str(";"));
249 output_stream.extend(TokenStream::from_str("let _program_register_0 = qip::macros::program::negate_bitmask(_program_builder, _program_register_0, _control_bitmask);"));
250 has_control_bits = true;
251 }
252
253 output_stream.extend(TokenStream::from_str("let mut _control_program_builder = _program_builder.condition_with(_program_register_0);"));
254 builder_name = "&mut _control_program_builder";
255 }
256
257 let args_string = if let Some(args) = arguments {
258 format!("{},", args)
259 } else {
260 "".to_string()
261 };
262
263 let subsection = ®ister_list[start..];
264 if subsection.len() == 1 {
265 let register_name = format!("_program_register_{} ", start);
266 let string = format!(
267 "let {} = {}({}, {} {})?;",
268 register_name, function, builder_name, args_string, register_name
269 );
270 output_stream.extend(TokenStream::from_str(&string).unwrap());
271 } else {
272 let register_names = (start..register_list.len() - 1)
273 .map(|i| format!("_program_register_{}, ", i))
274 .chain(Some(format!(
275 "_program_register_{} ",
276 register_list.len() - 1
277 )))
278 .collect::<String>();
279 let string = format!(
280 "let ({}) = {}({}, {} {})?;",
281 register_names, function, builder_name, args_string, register_names
282 );
283 output_stream.extend(TokenStream::from_str(&string).unwrap());
284 }
285
286 if control {
288 output_stream.extend(TokenStream::from_str(
289 "let _program_register_0 = _control_program_builder.dissolve();",
290 ));
291 if has_control_bits {
292 output_stream.extend(TokenStream::from_str("let _program_register_0 = qip::macros::program::negate_bitmask(_program_builder, _program_register_0, _control_bitmask);"));
293 }
294 }
295
296 let mut replace_qudits_stream = TokenStream::new();
298 for (ri, (rs, is)) in register_list.iter().zip(index_list.iter()).enumerate() {
299 let reg_name = format!("_program_register_{}", ri);
300 replace_qudits_stream.extend(TokenStream::from_str(&format!("let mut {} = _program_builder.split_all_register({}).into_iter().map(|r| Some(r)).collect::<Vec<_>>(); let mut {}_index = 0;", reg_name, reg_name, reg_name)).unwrap());
301 for (r, s) in rs.iter().zip(is.iter()) {
302 let s = if let Some(s) = s {
303 format!("qip::macros::program::QubitIndices::from({})", s)
304 } else {
305 format!("0..{}.len()", r)
306 };
307
308 replace_qudits_stream.extend(
309 TokenStream::from_str(&format!(
310 "for i in {} {{ {}[i] = {}[{}_index].take(); {}_index += 1; }}",
311 s, r, reg_name, reg_name, reg_name
312 ))
313 .unwrap(),
314 );
315 }
316 }
317 output_stream.extend(replace_qudits_stream);
318
319 if input_stream.peek().is_none() {
320 break;
321 }
322 }
323
324 for input_register in &input_registers {
326 output_stream.extend(TokenStream::from_str(&format!("let {} = _program_builder.merge_registers({}.into_iter().flat_map(|r| r)).unwrap();", input_register, input_register)).unwrap())
327 }
328
329 let mut tuple_stream = TokenStream::new();
330
331 if input_registers.len() == 1 {
332 output_stream.extend(TokenStream::from_str(&format!(
333 "Ok({})",
334 input_registers[0]
335 )));
336 } else {
337 for input_register in &input_registers {
338 tuple_stream.extend(Some(
339 TokenStream::from_str(&format!("{}, ", input_register)).unwrap(),
340 ))
341 }
342 output_stream.extend(TokenStream::from_str(&format!(
343 "Ok({})",
344 TokenTree::Group(proc_macro::Group::new(Delimiter::Parenthesis, tuple_stream))
345 )));
346 }
347
348 TokenStream::from(TokenTree::Group(proc_macro::Group::new(
349 proc_macro::Delimiter::Brace,
350 output_stream,
351 )))
352}
353
/// Appends to `to` every identifier in `arg_stream` that is immediately followed by a
/// single `:` (as in `name: Type`).
///
/// The colon must have `Spacing::Alone`. This skips path segments such as `std` in
/// `std::vec::Vec`, whose first `:` is joint with the next one.
///
/// Only the top level of the stream is scanned; nested groups count as single tokens.
/// Entries already in `to` are kept.
fn parse_function_args(arg_stream: TokenStream, to: &mut Vec<String>) {
    let tokens: Vec<TokenTree> = arg_stream.into_iter().collect();
    let names = tokens.windows(2).filter_map(|pair| match pair {
        [TokenTree::Ident(ident), TokenTree::Punct(colon)]
            if colon.as_char() == ':' && colon.spacing() == Spacing::Alone =>
        {
            Some(ident.to_string())
        }
        _ => None,
    });
    to.extend(names);
}
367
368#[proc_macro_attribute]
369pub fn invert(attr: TokenStream, input_stream: TokenStream) -> TokenStream {
370 let mut output_stream = input_stream.clone();
372
373 let mut attr = attr.into_iter().peekable();
374 let new_function_name = attr.next();
375 if let Some(TokenTree::Punct(_)) = attr.peek() {
376 attr.next();
377 }
378
379 let mut non_register_args = Vec::default();
380 while let Some(TokenTree::Ident(ident)) = attr.next() {
381 non_register_args.push(ident.to_string());
382 if let Some(TokenTree::Punct(_)) = attr.peek() {
383 attr.next();
384 }
385 }
386
387 let mut function_name = String::from("foo");
388 let new_function_name = new_function_name.map(|s| s.to_string());
389
390 let mut input_stream = input_stream.into_iter().peekable();
391 while let Some(token) = input_stream.next() {
393 if let TokenTree::Ident(ident) = &token {
394 match input_stream.peek() {
395 Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Parenthesis => {
396 function_name = ident.to_string();
397
398 let to_add =
399 new_function_name.unwrap_or_else(|| format!("{}_inv", function_name));
400 let to_add = TokenStream::from_str(&to_add).unwrap();
401 output_stream.extend(to_add);
402
403 break;
404 }
405 Some(TokenTree::Punct(punct)) if punct.as_char() == '<' => {
406 function_name = ident.to_string();
407 let to_add =
408 new_function_name.unwrap_or_else(|| format!("{}_inv", function_name));
409 let to_add = TokenStream::from_str(&to_add).unwrap();
410 output_stream.extend(to_add);
411
412 break;
413 }
414 _ => {
415 let to_add = TokenStream::from(token);
416 output_stream.extend(to_add);
417 }
418 }
419 } else {
420 let to_add = TokenStream::from(token);
421 output_stream.extend(to_add);
422 }
423 }
424
425 let mut function_args = vec![];
427 for token in input_stream.by_ref() {
428 let should_break = if let TokenTree::Group(group) = &token {
429 if group.delimiter() == Delimiter::Parenthesis {
430 parse_function_args(group.stream().clone(), &mut function_args);
431 true
432 } else {
433 false
434 }
435 } else {
436 false
437 };
438 let to_add = TokenStream::from(token);
439 output_stream.extend(to_add);
440 if should_break {
441 break;
442 }
443 }
444
445 for token in input_stream {
447 match &token {
448 TokenTree::Group(group) if group.delimiter() == Delimiter::Brace => {
449 break;
450 }
451 _ => {
452 let to_add = TokenStream::from(token);
453 output_stream.extend(to_add);
454 }
455 }
456 }
457
458 let builder = function_args[0].clone();
459 let new_builder = format!("_{builder}_new");
460
461 let mut skip_args = HashSet::new();
462 skip_args.extend(non_register_args.into_iter());
463
464 let regs_only = function_args[1..]
465 .iter()
466 .filter_map(|s| {
467 if !skip_args.contains(s) {
468 Some(s.clone())
469 } else {
470 None
471 }
472 })
473 .collect::<Vec<_>>();
474
475 let regs_list = regs_only.join(",");
476
477 let regs_sizes = regs_only
478 .iter()
479 .map(|reg| format!("{reg}.n()"))
480 .collect::<Vec<String>>()
481 .join(",");
482
483 let make_new_regs = regs_only
484 .iter()
485 .map(|s| format!("let _{s}_new = {new_builder}.register({s}.n_nonzero());"))
486 .collect::<String>();
487
488 let new_regs_args = Some(format!("&mut {new_builder}"))
489 .into_iter()
490 .chain(function_args[1..].iter().map(|s| {
491 if !skip_args.contains(s) {
492 format!("_{s}_new")
493 } else {
494 s.clone()
495 }
496 }))
497 .collect::<Vec<String>>()
498 .join(",");
499
500 let pop_regs = regs_only
501 .iter()
502 .rev()
503 .map(|s| {
504 format!("let {s} = _selected_vec.pop().expect(&format!(\"Register {s} is missing!\"));")
505 })
506 .collect::<String>();
507
508 let to_add = TokenStream::from(TokenTree::Group(proc_macro::Group::new(Delimiter::Brace, TokenStream::from_str(&format!("
509 let _register_sizes = [{regs_sizes}];
510 let mut {new_builder} = {builder}.new_similar();
511 {make_new_regs}
512 {function_name}({new_regs_args})?;
513 let _subcircuit = {new_builder}.make_subcircuit()?;
514 let _combined_r = {builder}.merge_registers([{regs_list}]).expect(\"Must have some registers.\");
515 let _combined_r = {builder}.apply_inverted_subcircuit(_subcircuit, _combined_r)?;
516 let mut _selected_vec = {builder}.split_relative_index_groups(_combined_r, _register_sizes.into_iter().scan(0, |acc, n| {{
517 let range = *acc..*acc+n;
518 *acc += n;
519 Some(range)
520 }})).get_all_selected().expect(\"All registers should have been selected\");
521 {pop_regs}
522 Ok(({regs_list}))
523 ")).unwrap())));
524 output_stream.extend(to_add);
525
526 output_stream
527}