quantrs2_ml/tensorflow_compatibility/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::error::{MLError, Result};
6use crate::simulator_backends::{DynamicCircuit, Observable, SimulationResult, SimulatorBackend};
7use quantrs2_circuit::prelude::*;
8use quantrs2_core::prelude::*;
9use scirs2_core::ndarray::{s, Array1, Array2, Array3, Array4, ArrayD, Axis};
10use std::collections::HashMap;
11
12use super::types::{DataEncodingType, TFQCircuitFormat, TFQGate};
13
14/// TensorFlow Quantum layer trait
15pub trait TFQLayer: Send + Sync {
16    /// Forward pass
17    fn forward(&self, inputs: &ArrayD<f64>) -> Result<ArrayD<f64>>;
18    /// Backward pass
19    fn backward(&self, upstream_gradients: &ArrayD<f64>) -> Result<ArrayD<f64>>;
20    /// Get trainable parameters
21    fn get_parameters(&self) -> Vec<Array1<f64>>;
22    /// Set trainable parameters
23    fn set_parameters(&mut self, params: Vec<Array1<f64>>) -> Result<()>;
24    /// Layer name
25    fn name(&self) -> &str;
26}
27/// TensorFlow Quantum-style utilities
28pub mod tfq_utils {
29    use super::*;
30    /// Convert QuantRS2 circuit to TFQ-compatible format
31    pub fn circuit_to_tfq_format(circuit: &DynamicCircuit) -> Result<TFQCircuitFormat> {
32        let tfq_gates: Vec<TFQGate> = Vec::new();
33        Ok(TFQCircuitFormat {
34            gates: tfq_gates,
35            num_qubits: circuit.num_qubits(),
36        })
37    }
38    /// Create quantum data encoding circuit
39    pub fn create_data_encoding_circuit(
40        num_qubits: usize,
41        encoding_type: DataEncodingType,
42    ) -> Result<DynamicCircuit> {
43        let mut builder: Circuit<8> = CircuitBuilder::new();
44        match encoding_type {
45            DataEncodingType::Amplitude => {
46                for qubit in 0..num_qubits {
47                    builder.ry(qubit, 0.0)?;
48                }
49            }
50            DataEncodingType::Angle => {
51                for qubit in 0..num_qubits {
52                    builder.rz(qubit, 0.0)?;
53                }
54            }
55            DataEncodingType::Basis => {
56                for qubit in 0..num_qubits {
57                    builder.x(qubit)?;
58                }
59            }
60        }
61        let circuit = builder.build();
62        DynamicCircuit::from_circuit(circuit)
63    }
64    /// Create hardware-efficient ansatz
65    pub fn create_hardware_efficient_ansatz(
66        num_qubits: usize,
67        layers: usize,
68    ) -> Result<DynamicCircuit> {
69        let mut builder: Circuit<8> = CircuitBuilder::new();
70        for layer in 0..layers {
71            for qubit in 0..num_qubits {
72                builder.ry(qubit, 0.0)?;
73                builder.rz(qubit, 0.0)?;
74            }
75            for qubit in 0..num_qubits - 1 {
76                builder.cnot(qubit, qubit + 1)?;
77            }
78            if layer < layers - 1 && num_qubits > 2 {
79                builder.cnot(num_qubits - 1, 0)?;
80            }
81        }
82        let circuit = builder.build();
83        DynamicCircuit::from_circuit(circuit)
84    }
85    /// Batch quantum circuit execution
86    pub fn batch_execute_circuits(
87        circuits: &[DynamicCircuit],
88        parameters: &Array2<f64>,
89        observables: &[Observable],
90        backend: &dyn SimulatorBackend,
91    ) -> Result<Array2<f64>> {
92        let batch_size = circuits.len();
93        let num_observables = observables.len();
94        let mut results = Array2::zeros((batch_size, num_observables));
95        for (circuit_idx, circuit) in circuits.iter().enumerate() {
96            let params = parameters.row(circuit_idx % parameters.nrows());
97            let params_slice = params.as_slice().ok_or_else(|| {
98                MLError::InvalidConfiguration("Parameters must be contiguous in memory".to_string())
99            })?;
100            for (obs_idx, observable) in observables.iter().enumerate() {
101                let expectation = backend.expectation_value(circuit, params_slice, observable)?;
102                results[[circuit_idx, obs_idx]] = expectation;
103            }
104        }
105        Ok(results)
106    }
107}
108/// Differentiator trait for computing gradients of quantum circuits
109pub trait Differentiator: Send + Sync {
110    /// Compute gradients of expectation values with respect to parameters
111    fn differentiate(
112        &self,
113        circuit: &DynamicCircuit,
114        parameters: &[f64],
115        observable: &Observable,
116        backend: &dyn SimulatorBackend,
117    ) -> Result<Vec<f64>>;
118    /// Get the name of the differentiator
119    fn name(&self) -> &str;
120}
121/// Resolve symbols in a parameterized circuit
122///
123/// This creates a new DynamicCircuit with the symbol values bound.
124/// In TFQ, this is used to convert parameterized circuits to concrete circuits.
125pub fn resolve_symbols(
126    circuit: &DynamicCircuit,
127    symbols: &[String],
128    values: &[f64],
129) -> Result<DynamicCircuit> {
130    if symbols.len() != values.len() {
131        return Err(MLError::InvalidConfiguration(
132            "Number of symbols must match number of values".to_string(),
133        ));
134    }
135    let mut _symbol_map = HashMap::new();
136    for (sym, &val) in symbols.iter().zip(values.iter()) {
137        _symbol_map.insert(sym.clone(), val);
138    }
139    Ok(circuit.clone())
140}
141/// Convert tensor to circuits (TFQ-compatible utility)
142pub fn tensor_to_circuits(tensor: &Array1<String>) -> Result<Vec<DynamicCircuit>> {
143    tensor
144        .iter()
145        .map(|_| DynamicCircuit::from_circuit::<8>(Circuit::<8>::new()))
146        .collect()
147}
148/// Convert circuits to tensor (TFQ-compatible utility)
149pub fn circuits_to_tensor(circuits: &[DynamicCircuit]) -> Array1<String> {
150    Array1::from_vec(
151        circuits
152            .iter()
153            .map(|c| format!("circuit_{}_qubits", c.num_qubits()))
154            .collect(),
155    )
156}
157/// Cirq circuit converter module
158///
159/// Provides conversion from Cirq-style circuit representations to QuantRS2 circuits.
160/// Since Cirq is a Python library, this module provides Rust data structures that
161/// represent Cirq circuits and can be converted to QuantRS2 circuits.
162pub mod cirq_converter {
163    use super::*;
164    use quantrs2_circuit::prelude::*;
165    use std::collections::HashMap;
166    /// Cirq gate types
167    #[derive(Debug, Clone)]
168    pub enum CirqGate {
169        /// Pauli X gate
170        X { qubit: usize },
171        /// Pauli Y gate
172        Y { qubit: usize },
173        /// Pauli Z gate
174        Z { qubit: usize },
175        /// Hadamard gate
176        H { qubit: usize },
177        /// S gate (√Z)
178        S { qubit: usize },
179        /// T gate (√S)
180        T { qubit: usize },
181        /// CNOT gate
182        CNOT { control: usize, target: usize },
183        /// CZ gate
184        CZ { control: usize, target: usize },
185        /// SWAP gate
186        SWAP { qubit1: usize, qubit2: usize },
187        /// Rotation around X axis
188        Rx { qubit: usize, angle: f64 },
189        /// Rotation around Y axis
190        Ry { qubit: usize, angle: f64 },
191        /// Rotation around Z axis
192        Rz { qubit: usize, angle: f64 },
193        /// Arbitrary single-qubit rotation (U3)
194        U3 {
195            qubit: usize,
196            theta: f64,
197            phi: f64,
198            lambda: f64,
199        },
200        /// Parametric X rotation
201        XPowGate {
202            qubit: usize,
203            exponent: f64,
204            global_shift: f64,
205        },
206        /// Parametric Y rotation
207        YPowGate {
208            qubit: usize,
209            exponent: f64,
210            global_shift: f64,
211        },
212        /// Parametric Z rotation
213        ZPowGate {
214            qubit: usize,
215            exponent: f64,
216            global_shift: f64,
217        },
218        /// Measurement
219        Measure { qubits: Vec<usize> },
220    }
221    /// Cirq circuit representation
222    #[derive(Debug, Clone)]
223    pub struct CirqCircuit {
224        /// Number of qubits
225        pub num_qubits: usize,
226        /// Gates in the circuit
227        pub gates: Vec<CirqGate>,
228        /// Parameter symbols used in the circuit
229        pub param_symbols: HashMap<String, usize>,
230    }
231    impl CirqCircuit {
232        /// Create a new Cirq circuit
233        pub fn new(num_qubits: usize) -> Self {
234            Self {
235                num_qubits,
236                gates: Vec::new(),
237                param_symbols: HashMap::new(),
238            }
239        }
240        /// Add a gate to the circuit
241        pub fn add_gate(&mut self, gate: CirqGate) {
242            self.gates.push(gate);
243        }
244        /// Add a parameter symbol
245        pub fn add_param_symbol(&mut self, symbol: String, index: usize) {
246            self.param_symbols.insert(symbol, index);
247        }
248        /// Convert to QuantRS2 circuit (const generic version)
249        pub fn to_quantrs2_circuit<const N: usize>(&self) -> Result<Circuit<N>> {
250            if self.num_qubits != N {
251                return Err(MLError::ValidationError(format!(
252                    "Circuit has {} qubits but expected {}",
253                    self.num_qubits, N
254                )));
255            }
256            let mut builder = CircuitBuilder::new();
257            for gate in &self.gates {
258                match gate {
259                    CirqGate::X { qubit } => {
260                        builder.x(*qubit)?;
261                    }
262                    CirqGate::Y { qubit } => {
263                        builder.y(*qubit)?;
264                    }
265                    CirqGate::Z { qubit } => {
266                        builder.z(*qubit)?;
267                    }
268                    CirqGate::H { qubit } => {
269                        builder.h(*qubit)?;
270                    }
271                    CirqGate::S { qubit } => {
272                        builder.s(*qubit)?;
273                    }
274                    CirqGate::T { qubit } => {
275                        builder.t(*qubit)?;
276                    }
277                    CirqGate::CNOT { control, target } => {
278                        builder.cnot(*control, *target)?;
279                    }
280                    CirqGate::CZ { control, target } => {
281                        builder.cz(*control, *target)?;
282                    }
283                    CirqGate::SWAP { qubit1, qubit2 } => {
284                        builder.swap(*qubit1, *qubit2)?;
285                    }
286                    CirqGate::Rx { qubit, angle } => {
287                        builder.rx(*qubit, *angle)?;
288                    }
289                    CirqGate::Ry { qubit, angle } => {
290                        builder.ry(*qubit, *angle)?;
291                    }
292                    CirqGate::Rz { qubit, angle } => {
293                        builder.rz(*qubit, *angle)?;
294                    }
295                    CirqGate::U3 {
296                        qubit,
297                        theta,
298                        phi,
299                        lambda,
300                    } => {
301                        builder.u(*qubit, *theta, *phi, *lambda)?;
302                    }
303                    CirqGate::XPowGate {
304                        qubit,
305                        exponent,
306                        global_shift,
307                    } => {
308                        let angle = std::f64::consts::PI * exponent;
309                        builder.rx(*qubit, angle)?;
310                        let _ = global_shift;
311                    }
312                    CirqGate::YPowGate {
313                        qubit,
314                        exponent,
315                        global_shift,
316                    } => {
317                        let angle = std::f64::consts::PI * exponent;
318                        builder.ry(*qubit, angle)?;
319                        let _ = global_shift;
320                    }
321                    CirqGate::ZPowGate {
322                        qubit,
323                        exponent,
324                        global_shift,
325                    } => {
326                        let angle = std::f64::consts::PI * exponent;
327                        builder.rz(*qubit, angle)?;
328                        let _ = global_shift;
329                    }
330                    CirqGate::Measure { qubits: _ } => {}
331                }
332            }
333            Ok(builder.build())
334        }
335        /// Convert to dynamic circuit (runtime qubit count)
336        pub fn to_dynamic_circuit(&self) -> Result<DynamicCircuit> {
337            match self.num_qubits {
338                1 => DynamicCircuit::from_circuit(self.to_quantrs2_circuit::<1>()?),
339                2 => DynamicCircuit::from_circuit(self.to_quantrs2_circuit::<2>()?),
340                3 => DynamicCircuit::from_circuit(self.to_quantrs2_circuit::<3>()?),
341                4 => DynamicCircuit::from_circuit(self.to_quantrs2_circuit::<4>()?),
342                5 => DynamicCircuit::from_circuit(self.to_quantrs2_circuit::<5>()?),
343                6 => DynamicCircuit::from_circuit(self.to_quantrs2_circuit::<6>()?),
344                7 => DynamicCircuit::from_circuit(self.to_quantrs2_circuit::<7>()?),
345                8 => DynamicCircuit::from_circuit(self.to_quantrs2_circuit::<8>()?),
346                n => Err(MLError::ValidationError(format!(
347                    "Unsupported qubit count: {}. Supported: 1-8",
348                    n
349                ))),
350            }
351        }
352    }
353    /// Create a Bell state circuit (Cirq-style)
354    pub fn create_bell_circuit() -> CirqCircuit {
355        let mut circuit = CirqCircuit::new(2);
356        circuit.add_gate(CirqGate::H { qubit: 0 });
357        circuit.add_gate(CirqGate::CNOT {
358            control: 0,
359            target: 1,
360        });
361        circuit
362    }
363    /// Create a parametric circuit (Cirq-style)
364    pub fn create_parametric_circuit(num_qubits: usize, depth: usize) -> CirqCircuit {
365        let mut circuit = CirqCircuit::new(num_qubits);
366        for layer in 0..depth {
367            for qubit in 0..num_qubits {
368                let symbol = format!("theta_{}_{}", layer, qubit);
369                circuit.add_param_symbol(symbol.clone(), layer * num_qubits + qubit);
370                circuit.add_gate(CirqGate::Ry { qubit, angle: 0.5 });
371            }
372            for qubit in 0..num_qubits - 1 {
373                circuit.add_gate(CirqGate::CNOT {
374                    control: qubit,
375                    target: qubit + 1,
376                });
377            }
378        }
379        circuit
380    }
381    /// Convert Cirq PowGate to angle (helper)
382    pub fn pow_gate_to_angle(exponent: f64) -> f64 {
383        std::f64::consts::PI * exponent
384    }
385}