quantrs2_ml/tensorflow_compatibility/
functions.rs1use crate::error::{MLError, Result};
6use crate::simulator_backends::{DynamicCircuit, Observable, SimulationResult, SimulatorBackend};
7use quantrs2_circuit::prelude::*;
8use quantrs2_core::prelude::*;
9use scirs2_core::ndarray::{s, Array1, Array2, Array3, Array4, ArrayD, Axis};
10use std::collections::HashMap;
11
12use super::types::{DataEncodingType, TFQCircuitFormat, TFQGate};
13
14pub trait TFQLayer: Send + Sync {
16 fn forward(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>>;
18 fn backward(&self, upstream_gradients: &ArrayD<f64>) -> Result<ArrayD<f64>>;
20 fn get_parameters(&self) -> Vec<Array1<f64>>;
22 fn set_parameters(&mut self, params: Vec<Array1<f64>>) -> Result<()>;
24 fn name(&self) -> &str;
26}
27pub mod tfq_utils {
29 use super::*;
30 pub fn circuit_to_tfq_format(circuit: &DynamicCircuit) -> Result<TFQCircuitFormat> {
32 let tfq_gates: Vec<TFQGate> = Vec::new();
33 Ok(TFQCircuitFormat {
34 gates: tfq_gates,
35 num_qubits: circuit.num_qubits(),
36 })
37 }
38 pub fn create_data_encoding_circuit(
40 num_qubits: usize,
41 encoding_type: DataEncodingType,
42 ) -> Result<DynamicCircuit> {
43 let mut builder: Circuit<8> = CircuitBuilder::new();
44 match encoding_type {
45 DataEncodingType::Amplitude => {
46 for qubit in 0..num_qubits {
47 builder.ry(qubit, 0.0)?;
48 }
49 }
50 DataEncodingType::Angle => {
51 for qubit in 0..num_qubits {
52 builder.rz(qubit, 0.0)?;
53 }
54 }
55 DataEncodingType::Basis => {
56 for qubit in 0..num_qubits {
57 builder.x(qubit)?;
58 }
59 }
60 }
61 let circuit = builder.build();
62 DynamicCircuit::from_circuit(circuit)
63 }
64 pub fn create_hardware_efficient_ansatz(
66 num_qubits: usize,
67 layers: usize,
68 ) -> Result<DynamicCircuit> {
69 let mut builder: Circuit<8> = CircuitBuilder::new();
70 for layer in 0..layers {
71 for qubit in 0..num_qubits {
72 builder.ry(qubit, 0.0)?;
73 builder.rz(qubit, 0.0)?;
74 }
75 for qubit in 0..num_qubits - 1 {
76 builder.cnot(qubit, qubit + 1)?;
77 }
78 if layer < layers - 1 && num_qubits > 2 {
79 builder.cnot(num_qubits - 1, 0)?;
80 }
81 }
82 let circuit = builder.build();
83 DynamicCircuit::from_circuit(circuit)
84 }
85 pub fn batch_execute_circuits(
87 circuits: &[DynamicCircuit],
88 parameters: &Array2<f64>,
89 observables: &[Observable],
90 backend: &dyn SimulatorBackend,
91 ) -> Result<Array2<f64>> {
92 let batch_size = circuits.len();
93 let num_observables = observables.len();
94 let mut results = Array2::zeros((batch_size, num_observables));
95 for (circuit_idx, circuit) in circuits.iter().enumerate() {
96 let params = parameters.row(circuit_idx % parameters.nrows());
97 let params_slice = params.as_slice().ok_or_else(|| {
98 MLError::InvalidConfiguration("Parameters must be contiguous in memory".to_string())
99 })?;
100 for (obs_idx, observable) in observables.iter().enumerate() {
101 let expectation = backend.expectation_value(circuit, params_slice, observable)?;
102 results[[circuit_idx, obs_idx]] = expectation;
103 }
104 }
105 Ok(results)
106 }
107}
108pub trait Differentiator: Send + Sync {
110 fn differentiate(
112 &self,
113 circuit: &DynamicCircuit,
114 parameters: &[f64],
115 observable: &Observable,
116 backend: &dyn SimulatorBackend,
117 ) -> Result<Vec<f64>>;
118 fn name(&self) -> &str;
120}
121pub fn resolve_symbols(
126 circuit: &DynamicCircuit,
127 symbols: &[String],
128 values: &[f64],
129) -> Result<DynamicCircuit> {
130 if symbols.len() != values.len() {
131 return Err(MLError::InvalidConfiguration(
132 "Number of symbols must match number of values".to_string(),
133 ));
134 }
135 let mut _symbol_map = HashMap::new();
136 for (sym, &val) in symbols.iter().zip(values.iter()) {
137 _symbol_map.insert(sym.clone(), val);
138 }
139 Ok(circuit.clone())
140}
141pub fn tensor_to_circuits(tensor: &Array1<String>) -> Result<Vec<DynamicCircuit>> {
143 tensor
144 .iter()
145 .map(|_| DynamicCircuit::from_circuit::<8>(Circuit::<8>::new()))
146 .collect()
147}
148pub fn circuits_to_tensor(circuits: &[DynamicCircuit]) -> Array1<String> {
150 Array1::from_vec(
151 circuits
152 .iter()
153 .map(|c| format!("circuit_{}_qubits", c.num_qubits()))
154 .collect(),
155 )
156}
157pub mod cirq_converter {
163 use super::*;
164 use quantrs2_circuit::prelude::*;
165 use std::collections::HashMap;
/// Gate vocabulary accepted by the Cirq-style circuit description.
///
/// Qubits are referred to by index.
#[derive(Debug, Clone)]
pub enum CirqGate {
    /// `X` gate on a single qubit.
    X { qubit: usize },
    /// `Y` gate on a single qubit.
    Y { qubit: usize },
    /// `Z` gate on a single qubit.
    Z { qubit: usize },
    /// `H` gate on a single qubit.
    H { qubit: usize },
    /// `S` gate on a single qubit.
    S { qubit: usize },
    /// `T` gate on a single qubit.
    T { qubit: usize },
    /// Two-qubit `CNOT` with explicit control and target.
    CNOT { control: usize, target: usize },
    /// Two-qubit `CZ` with explicit control and target.
    CZ { control: usize, target: usize },
    /// `SWAP` of two qubits.
    SWAP { qubit1: usize, qubit2: usize },
    /// `Rx` rotation by `angle`.
    Rx { qubit: usize, angle: f64 },
    /// `Ry` rotation by `angle`.
    Ry { qubit: usize, angle: f64 },
    /// `Rz` rotation by `angle`.
    Rz { qubit: usize, angle: f64 },
    /// Single-qubit gate with three angles.
    U3 {
        qubit: usize,
        theta: f64,
        phi: f64,
        lambda: f64,
    },
    /// `X` raised to `exponent`, with an additional `global_shift`.
    XPowGate {
        qubit: usize,
        exponent: f64,
        global_shift: f64,
    },
    /// `Y` raised to `exponent`, with an additional `global_shift`.
    YPowGate {
        qubit: usize,
        exponent: f64,
        global_shift: f64,
    },
    /// `Z` raised to `exponent`, with an additional `global_shift`.
    ZPowGate {
        qubit: usize,
        exponent: f64,
        global_shift: f64,
    },
    /// Measurement over the listed qubits.
    Measure { qubits: Vec<usize> },
}
221 #[derive(Debug, Clone)]
223 pub struct CirqCircuit {
224 pub num_qubits: usize,
226 pub gates: Vec<CirqGate>,
228 pub param_symbols: HashMap<String, usize>,
230 }
231 impl CirqCircuit {
232 pub fn new(num_qubits: usize) -> Self {
234 Self {
235 num_qubits,
236 gates: Vec::new(),
237 param_symbols: HashMap::new(),
238 }
239 }
240 pub fn add_gate(&mut self, gate: CirqGate) {
242 self.gates.push(gate);
243 }
244 pub fn add_param_symbol(&mut self, symbol: String, index: usize) {
246 self.param_symbols.insert(symbol, index);
247 }
248 pub fn to_quantrs2_circuit<const N: usize>(&self) -> Result<Circuit<N>> {
250 if self.num_qubits != N {
251 return Err(MLError::ValidationError(format!(
252 "Circuit has {} qubits but expected {}",
253 self.num_qubits, N
254 )));
255 }
256 let mut builder = CircuitBuilder::new();
257 for gate in &self.gates {
258 match gate {
259 CirqGate::X { qubit } => {
260 builder.x(*qubit)?;
261 }
262 CirqGate::Y { qubit } => {
263 builder.y(*qubit)?;
264 }
265 CirqGate::Z { qubit } => {
266 builder.z(*qubit)?;
267 }
268 CirqGate::H { qubit } => {
269 builder.h(*qubit)?;
270 }
271 CirqGate::S { qubit } => {
272 builder.s(*qubit)?;
273 }
274 CirqGate::T { qubit } => {
275 builder.t(*qubit)?;
276 }
277 CirqGate::CNOT { control, target } => {
278 builder.cnot(*control, *target)?;
279 }
280 CirqGate::CZ { control, target } => {
281 builder.cz(*control, *target)?;
282 }
283 CirqGate::SWAP { qubit1, qubit2 } => {
284 builder.swap(*qubit1, *qubit2)?;
285 }
286 CirqGate::Rx { qubit, angle } => {
287 builder.rx(*qubit, *angle)?;
288 }
289 CirqGate::Ry { qubit, angle } => {
290 builder.ry(*qubit, *angle)?;
291 }
292 CirqGate::Rz { qubit, angle } => {
293 builder.rz(*qubit, *angle)?;
294 }
295 CirqGate::U3 {
296 qubit,
297 theta,
298 phi,
299 lambda,
300 } => {
301 builder.u(*qubit, *theta, *phi, *lambda)?;
302 }
303 CirqGate::XPowGate {
304 qubit,
305 exponent,
306 global_shift,
307 } => {
308 let angle = std::f64::consts::PI * exponent;
309 builder.rx(*qubit, angle)?;
310 let _ = global_shift;
311 }
312 CirqGate::YPowGate {
313 qubit,
314 exponent,
315 global_shift,
316 } => {
317 let angle = std::f64::consts::PI * exponent;
318 builder.ry(*qubit, angle)?;
319 let _ = global_shift;
320 }
321 CirqGate::ZPowGate {
322 qubit,
323 exponent,
324 global_shift,
325 } => {
326 let angle = std::f64::consts::PI * exponent;
327 builder.rz(*qubit, angle)?;
328 let _ = global_shift;
329 }
330 CirqGate::Measure { qubits: _ } => {}
331 }
332 }
333 Ok(builder.build())
334 }
335 pub fn to_dynamic_circuit(&self) -> Result<DynamicCircuit> {
337 match self.num_qubits {
338 1 => DynamicCircuit::from_circuit(self.to_quantrs2_circuit::<1>()?),
339 2 => DynamicCircuit::from_circuit(self.to_quantrs2_circuit::<2>()?),
340 3 => DynamicCircuit::from_circuit(self.to_quantrs2_circuit::<3>()?),
341 4 => DynamicCircuit::from_circuit(self.to_quantrs2_circuit::<4>()?),
342 5 => DynamicCircuit::from_circuit(self.to_quantrs2_circuit::<5>()?),
343 6 => DynamicCircuit::from_circuit(self.to_quantrs2_circuit::<6>()?),
344 7 => DynamicCircuit::from_circuit(self.to_quantrs2_circuit::<7>()?),
345 8 => DynamicCircuit::from_circuit(self.to_quantrs2_circuit::<8>()?),
346 n => Err(MLError::ValidationError(format!(
347 "Unsupported qubit count: {}. Supported: 1-8",
348 n
349 ))),
350 }
351 }
352 }
353 pub fn create_bell_circuit() -> CirqCircuit {
355 let mut circuit = CirqCircuit::new(2);
356 circuit.add_gate(CirqGate::H { qubit: 0 });
357 circuit.add_gate(CirqGate::CNOT {
358 control: 0,
359 target: 1,
360 });
361 circuit
362 }
363 pub fn create_parametric_circuit(num_qubits: usize, depth: usize) -> CirqCircuit {
365 let mut circuit = CirqCircuit::new(num_qubits);
366 for layer in 0..depth {
367 for qubit in 0..num_qubits {
368 let symbol = format!("theta_{}_{}", layer, qubit);
369 circuit.add_param_symbol(symbol.clone(), layer * num_qubits + qubit);
370 circuit.add_gate(CirqGate::Ry { qubit, angle: 0.5 });
371 }
372 for qubit in 0..num_qubits - 1 {
373 circuit.add_gate(CirqGate::CNOT {
374 control: qubit,
375 target: qubit + 1,
376 });
377 }
378 }
379 circuit
380 }
/// Converts a power-gate `exponent` into a rotation angle: `PI * exponent`.
pub fn pow_gate_to_angle(exponent: f64) -> f64 {
    use std::f64::consts::PI;
    exponent * PI
}
385}