roqoqo_derive/
involve_qubits.rs

1// Copyright © 2021-2024 HQS Quantum Simulations GmbH. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4// in compliance with the License. You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software distributed under the
9// License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
10// express or implied. See the License for the specific language governing permissions and
11// limitations under the License.
12
13use crate::{extract_fields_with_types, extract_variants_with_types};
14use proc_macro2::TokenStream;
15use quote::quote;
16use syn::{Data, DataEnum, DataStruct, DeriveInput, Ident};
17
18/// Dispatch to derive InvolveQubits for enums and structs
19pub fn dispatch_struct_enum(input: DeriveInput) -> TokenStream {
20    let ident = input.ident;
21    match input.data {
22        Data::Struct(ds) => involve_qubits_struct(ds, ident),
23        Data::Enum(de) => involve_qubits_enum(de, ident),
24        _ => panic!("InvolveQubits can only be derived on structs and enums"),
25    }
26}
27
28/// Create the TokenStream of the InvolveQubits trait for enums
29/// This derive delegates the invocations of the involved_qubits function to all possible variants via match arms
30fn involve_qubits_enum(de: DataEnum, ident: Ident) -> TokenStream {
31    let variants_with_type = extract_variants_with_types(de).into_iter();
32    let match_quotes = variants_with_type.clone().map(|(vident, _, _)| {
33        quote! {
34            &#ident::#vident(ref inner) => {InvolveQubits::involved_qubits(&(*inner))},
35        }
36    });
37
38    let match_quotes_classical = variants_with_type.map(|(vident, _, _)| {
39        quote! {
40            &#ident::#vident(ref inner) => {InvolveQubits::involved_classical(&(*inner))},
41        }
42    });
43    quote! {
44        #[automatically_derived]
45        /// Implements [InvolveQubits] trait for the qubits involved in this Operation.
46        impl InvolveQubits for #ident{
47            fn involved_qubits(&self) -> InvolvedQubits {
48                match self{
49                    #(#match_quotes)*
50                    _ => panic!("Unexpectedly cannot match variant")
51                }
52            }
53
54            fn involved_classical(&self) -> InvolvedClassical {
55                match self{
56                    #(#match_quotes_classical)*
57                    _ => panic!("Unexpectedly cannot match variant")
58                }
59            }
60        }
61    }
62}
63
64/// Generate the TokenStream of the implementation of InvolvedQubits for structs
65fn involve_qubits_struct(ds: DataStruct, ident: Ident) -> TokenStream {
66    // We only allow structs with named fields
67    // Extract named fields with match and panic if fields are not named
68    let fields_with_type = extract_fields_with_types(ds).into_iter();
69
70    // Bool values that show if there is a qubit field, control field etc. in the struct
71    let mut qubit: bool = false;
72    let mut control: bool = false;
73    let mut control_0: bool = false;
74    let mut control_1: bool = false;
75    let mut control_2: bool = false;
76    let mut target: bool = false;
77    let mut qubits: bool = false;
78
79    // Iterating over the fields in the struct and setting the bool values to true if the field is
80    // in the struct
81    for (fid, type_string, _) in fields_with_type {
82        // Matching the name to see if we need to set one to true
83        match fid.clone().to_string().as_str() {
84            "qubit" => {
85                if type_string == Some("usize".to_string()) {
86                    qubit = true
87                } else {
88                    panic!("Field  qubit must have type usize")
89                }
90            }
91            "target" => {
92                if type_string == Some("usize".to_string()) {
93                    target = true;
94                } else {
95                    panic!("Field target must have type usize")
96                }
97            }
98            "control" => {
99                if type_string == Some("usize".to_string()) {
100                    control = true;
101                } else {
102                    panic!("Field control must have type usize")
103                }
104            }
105            "control_0" => {
106                if type_string == Some("usize".to_string()) {
107                    control_0 = true;
108                } else {
109                    panic!("Field control_0 must have type usize")
110                }
111            }
112            "control_1" => {
113                if type_string == Some("usize".to_string()) {
114                    control_1 = true;
115                } else {
116                    panic!("Field control_1 must have type usize")
117                }
118            }
119            "control_2" => {
120                if type_string == Some("usize".to_string()) {
121                    control_2 = true;
122                } else {
123                    panic!("Field control_2 must have type usize")
124                }
125            }
126            "qubits" => {
127                qubits = true;
128            }
129            _ => {}
130        };
131    }
132    if qubit {
133        if control || target || qubits {
134            panic!("When deriving InvolveQubits, qubit field is not compatible with control, target or qubits fields");
135        };
136        // Creating a function that puts exactly one qubit `qubit` into the InvolvedQubits HashSet
137        quote! {
138            /// Implements [InvolveQubits] trait for the qubits involved in this Operation.
139            #[automatically_derived]
140            impl InvolveQubits for #ident{
141                /// Returns a list of all involved qubits.
142                fn involved_qubits(&self ) -> InvolvedQubits {
143                    let mut new_hash_set: std::collections::HashSet<usize> = std::collections::HashSet::new();
144                    new_hash_set.insert(self.qubit);
145                    InvolvedQubits::Set(new_hash_set)
146                }
147            }
148        }
149    } else if control_2 {
150        if !(control_0 && control_1 && target) {
151            panic!("When deriving InvolveQubits for a four-qubit operation control_0, control_1, control_2 and target have to be present");
152        };
153        // Creating a function that puts qubits `control_0`, `control_1` `control_2` and `target` into the InvolvedQubits HashSet
154        quote! {
155            /// Implements [InvolveQubits] trait for the qubits involved in this Operation.
156            #[automatically_derived]
157            impl InvolveQubits for #ident{
158                /// Returns a list of all involed qubits.
159                fn involved_qubits(&self ) -> InvolvedQubits {
160                    let mut new_hash_set: std::collections::HashSet<usize> = std::collections::HashSet::new();
161                    new_hash_set.insert(self.control_0);
162                    new_hash_set.insert(self.control_1);
163                    new_hash_set.insert(self.control_2);
164                    new_hash_set.insert(self.target);
165                    InvolvedQubits::Set(new_hash_set)
166                }
167            }
168        }
169    } else if control_0 || control_1 {
170        if control {
171            panic!("When deriving InvolveQubits for a three-qubit operation, control field is not compatible with control_0 and control_1 fields");
172        }
173        if !(control_0 && control_1 && target) {
174            panic!("When deriving InvolveQubits for a three-qubit operation control_0, control_1 and target have to be present");
175        };
176        if qubits {
177            panic!("When deriving InvolveQubits for a three-qubit operation, control_0 and control_1 fields are not compatible with qubits fields");
178        };
179        // Creating a function that puts qubits `control_0`, `control_1` and `target` into the InvolvedQubits HashSet
180        quote! {
181            /// Implements [InvolveQubits] trait for the qubits involved in this Operation.
182            #[automatically_derived]
183            impl InvolveQubits for #ident{
184                /// Returns a list of all involed qubits.
185                fn involved_qubits(&self ) -> InvolvedQubits {
186                    let mut new_hash_set: std::collections::HashSet<usize> = std::collections::HashSet::new();
187                    new_hash_set.insert(self.control_0);
188                    new_hash_set.insert(self.control_1);
189                    new_hash_set.insert(self.target);
190                    InvolvedQubits::Set(new_hash_set)
191                }
192            }
193        }
194    } else if target || control {
195        if !(control && target) {
196            panic!("When deriving InvolveQubits control and target fields have to both be present");
197        };
198        if qubits {
199            panic!("When deriving InvolveQubits, control and target fields are not compatible with qubits fields");
200        };
201        // Creating a function that puts qubits `control` and `target` into the InvolvedQubits HashSet
202        quote! {
203            /// Implements [InvolveQubits] trait for the qubits involved in this Operation.
204            #[automatically_derived]
205            impl InvolveQubits for #ident{
206                /// Returns a list of all involved qubits.
207                fn involved_qubits(&self ) -> InvolvedQubits {
208                    let mut new_hash_set: std::collections::HashSet<usize> = std::collections::HashSet::new();
209                    new_hash_set.insert(self.control);
210                    new_hash_set.insert(self.target);
211                    InvolvedQubits::Set(new_hash_set)
212                }
213            }
214        }
215    } else if qubits {
216        // Creating a function that puts all qubits in the vector `qubits` into the InvolvedQubits HashSet
217        quote! {
218            /// Implements [InvolveQubits] trait for the qubits involved in this Operation.
219            #[automatically_derived]
220            impl InvolveQubits for #ident{
221                /// Returns a list of all involved qubits.
222                fn involved_qubits(&self ) -> InvolvedQubits {
223                    let mut new_hash_set: std::collections::HashSet<usize> = std::collections::HashSet::new();
224                    for qubit in self.qubits.iter(){
225                        new_hash_set.insert(*qubit);
226                    }
227                    InvolvedQubits::Set(new_hash_set)
228                }
229            }
230        }
231    } else {
232        panic!("To derive InvolveQubits qubit or control or target or qubits fields need to be present in struct")
233    }
234}