// quantrs2_circuit/classical.rs

//! Classical control flow support for quantum circuits
//!
//! This module provides support for classical registers, measurements,
//! and conditional execution of quantum gates based on classical values.

use quantrs2_core::{
    error::{QuantRS2Error, QuantRS2Result},
    gate::GateOp,
    qubit::QubitId,
};
use std::collections::HashMap;
use std::fmt;

/// A named classical register that can store measurement results.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ClassicalRegister {
    /// Name of the register
    pub name: String,
    /// Number of bits in the register
    pub size: usize,
}

impl ClassicalRegister {
    /// Create a new classical register.
    ///
    /// `name` may be anything convertible into a `String`; `size` is the
    /// number of bits the register holds. No validation is performed here.
    pub fn new(name: impl Into<String>, size: usize) -> Self {
        Self {
            name: name.into(),
            size,
        }
    }
}

/// A reference to a single classical bit, addressed by register name and index.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ClassicalBit {
    /// The register containing this bit
    pub register: String,
    /// Index within the register
    pub index: usize,
}

/// Classical values that can be used as operands in conditions.
///
/// Every payload type supports total equality and hashing, so the enum
/// derives `Eq` and `Hash` in addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ClassicalValue {
    /// A single bit value
    Bit(bool),
    /// A multi-bit integer value
    Integer(u64),
    /// A reference to a classical register, by name
    Register(String),
}

/// Comparison operators available to classical conditions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComparisonOp {
    /// Equal to
    Equal,
    /// Not equal to
    NotEqual,
    /// Less than
    Less,
    /// Less than or equal
    LessEqual,
    /// Greater than
    Greater,
    /// Greater than or equal
    GreaterEqual,
}

70/// A condition that gates execution based on classical values
71#[derive(Debug, Clone)]
72pub struct ClassicalCondition {
73    /// Left-hand side of the comparison
74    pub lhs: ClassicalValue,
75    /// Comparison operator
76    pub op: ComparisonOp,
77    /// Right-hand side of the comparison
78    pub rhs: ClassicalValue,
79}
80
81impl ClassicalCondition {
82    /// Create a new equality condition
83    pub fn equals(lhs: ClassicalValue, rhs: ClassicalValue) -> Self {
84        Self {
85            lhs,
86            op: ComparisonOp::Equal,
87            rhs,
88        }
89    }
90
91    /// Check if a register equals a specific value
92    pub fn register_equals(register: &str, value: u64) -> Self {
93        Self {
94            lhs: ClassicalValue::Register(register.to_string()),
95            op: ComparisonOp::Equal,
96            rhs: ClassicalValue::Integer(value),
97        }
98    }
99}
100
101/// A measurement operation that stores the result in a classical register
102#[derive(Debug, Clone)]
103pub struct MeasureOp {
104    /// Qubit to measure
105    pub qubit: QubitId,
106    /// Classical bit to store the result
107    pub cbit: ClassicalBit,
108}
109
110impl MeasureOp {
111    /// Create a new measurement operation
112    pub fn new(qubit: QubitId, register: &str, bit_index: usize) -> Self {
113        Self {
114            qubit,
115            cbit: ClassicalBit {
116                register: register.to_string(),
117                index: bit_index,
118            },
119        }
120    }
121}
122
123/// A gate that executes conditionally based on classical values
124pub struct ConditionalGate {
125    /// The condition to check
126    pub condition: ClassicalCondition,
127    /// The gate to execute if the condition is true
128    pub gate: Box<dyn GateOp>,
129}
130
131impl fmt::Debug for ConditionalGate {
132    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133        f.debug_struct("ConditionalGate")
134            .field("condition", &self.condition)
135            .field("gate", &self.gate.name())
136            .finish()
137    }
138}
139
140/// Classical control flow operations
141#[derive(Debug)]
142pub enum ClassicalOp {
143    /// Measure a qubit into a classical bit
144    Measure(MeasureOp),
145    /// Reset a classical register
146    Reset(String),
147    /// Conditional gate execution
148    Conditional(ConditionalGate),
149    /// Classical computation (e.g., XOR, AND)
150    Compute {
151        /// Output register
152        output: String,
153        /// Operation type
154        op: String,
155        /// Input registers
156        inputs: Vec<String>,
157    },
158}
159
160/// A circuit with classical control flow support
161pub struct ClassicalCircuit<const N: usize> {
162    /// Classical registers
163    pub classical_registers: HashMap<String, ClassicalRegister>,
164    /// Operations (both quantum and classical)
165    pub operations: Vec<CircuitOp>,
166}
167
168/// Operations that can appear in a classical circuit
169pub enum CircuitOp {
170    /// A quantum gate operation
171    Quantum(Box<dyn GateOp>),
172    /// A classical operation
173    Classical(ClassicalOp),
174}
175
176impl<const N: usize> ClassicalCircuit<N> {
177    /// Create a new circuit with classical control
178    pub fn new() -> Self {
179        Self {
180            classical_registers: HashMap::new(),
181            operations: Vec::new(),
182        }
183    }
184
185    /// Add a classical register
186    pub fn add_classical_register(&mut self, name: &str, size: usize) -> QuantRS2Result<()> {
187        if self.classical_registers.contains_key(name) {
188            return Err(QuantRS2Error::InvalidInput(format!(
189                "Classical register '{}' already exists",
190                name
191            )));
192        }
193
194        self.classical_registers
195            .insert(name.to_string(), ClassicalRegister::new(name, size));
196        Ok(())
197    }
198
199    /// Add a quantum gate
200    pub fn add_gate<G: GateOp + 'static>(&mut self, gate: G) -> QuantRS2Result<()> {
201        // Validate qubits
202        for qubit in gate.qubits() {
203            if qubit.id() as usize >= N {
204                return Err(QuantRS2Error::InvalidQubitId(qubit.id()));
205            }
206        }
207
208        self.operations.push(CircuitOp::Quantum(Box::new(gate)));
209        Ok(())
210    }
211
212    /// Add a measurement
213    pub fn measure(&mut self, qubit: QubitId, register: &str, bit: usize) -> QuantRS2Result<()> {
214        // Validate qubit
215        if qubit.id() as usize >= N {
216            return Err(QuantRS2Error::InvalidQubitId(qubit.id()));
217        }
218
219        // Validate classical register
220        let creg = self.classical_registers.get(register).ok_or_else(|| {
221            QuantRS2Error::InvalidInput(format!("Classical register '{}' not found", register))
222        })?;
223
224        if bit >= creg.size {
225            return Err(QuantRS2Error::InvalidInput(format!(
226                "Bit index {} out of range for register '{}' (size: {})",
227                bit, register, creg.size
228            )));
229        }
230
231        self.operations
232            .push(CircuitOp::Classical(ClassicalOp::Measure(MeasureOp::new(
233                qubit, register, bit,
234            ))));
235        Ok(())
236    }
237
238    /// Add a conditional gate
239    pub fn add_conditional<G: GateOp + 'static>(
240        &mut self,
241        condition: ClassicalCondition,
242        gate: G,
243    ) -> QuantRS2Result<()> {
244        // Validate gate qubits
245        for qubit in gate.qubits() {
246            if qubit.id() as usize >= N {
247                return Err(QuantRS2Error::InvalidQubitId(qubit.id()));
248            }
249        }
250
251        // TODO: Validate condition references valid registers
252
253        self.operations
254            .push(CircuitOp::Classical(ClassicalOp::Conditional(
255                ConditionalGate {
256                    condition,
257                    gate: Box::new(gate),
258                },
259            )));
260        Ok(())
261    }
262
263    /// Reset a classical register to zero
264    pub fn reset_classical(&mut self, register: &str) -> QuantRS2Result<()> {
265        if !self.classical_registers.contains_key(register) {
266            return Err(QuantRS2Error::InvalidInput(format!(
267                "Classical register '{}' not found",
268                register
269            )));
270        }
271
272        self.operations
273            .push(CircuitOp::Classical(ClassicalOp::Reset(
274                register.to_string(),
275            )));
276        Ok(())
277    }
278
279    /// Get the number of operations
280    pub fn num_operations(&self) -> usize {
281        self.operations.len()
282    }
283}
284
285/// Builder pattern for classical circuits
286pub struct ClassicalCircuitBuilder<const N: usize> {
287    circuit: ClassicalCircuit<N>,
288}
289
290impl<const N: usize> ClassicalCircuitBuilder<N> {
291    /// Create a new builder
292    pub fn new() -> Self {
293        Self {
294            circuit: ClassicalCircuit::new(),
295        }
296    }
297
298    /// Add a classical register
299    pub fn classical_register(mut self, name: &str, size: usize) -> QuantRS2Result<Self> {
300        self.circuit.add_classical_register(name, size)?;
301        Ok(self)
302    }
303
304    /// Add a quantum gate
305    pub fn gate<G: GateOp + 'static>(mut self, gate: G) -> QuantRS2Result<Self> {
306        self.circuit.add_gate(gate)?;
307        Ok(self)
308    }
309
310    /// Add a measurement
311    pub fn measure(mut self, qubit: QubitId, register: &str, bit: usize) -> QuantRS2Result<Self> {
312        self.circuit.measure(qubit, register, bit)?;
313        Ok(self)
314    }
315
316    /// Add a conditional gate
317    pub fn conditional<G: GateOp + 'static>(
318        mut self,
319        condition: ClassicalCondition,
320        gate: G,
321    ) -> QuantRS2Result<Self> {
322        self.circuit.add_conditional(condition, gate)?;
323        Ok(self)
324    }
325
326    /// Build the circuit
327    pub fn build(self) -> ClassicalCircuit<N> {
328        self.circuit
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::single::PauliX;

    /// The constructor stores the name and size it is given.
    #[test]
    fn test_classical_register() {
        let reg = ClassicalRegister::new("c", 3);
        assert_eq!(reg.name, "c");
        assert_eq!(reg.size, 3);
    }

    /// `register_equals` builds `Register(name) == Integer(value)`.
    #[test]
    fn test_classical_condition() {
        let cond = ClassicalCondition::register_equals("c", 1);
        assert_eq!(cond.op, ComparisonOp::Equal);

        match &cond.lhs {
            ClassicalValue::Register(name) => assert_eq!(name, "c"),
            _ => panic!("Expected Register variant"),
        }

        match &cond.rhs {
            ClassicalValue::Integer(val) => assert_eq!(*val, 1),
            _ => panic!("Expected Integer variant"),
        }
    }

    /// A full builder chain records one register and three operations.
    #[test]
    fn test_classical_circuit_builder() {
        let circuit = ClassicalCircuitBuilder::<2>::new()
            .classical_register("c", 2)
            .unwrap()
            .gate(PauliX { target: QubitId(0) })
            .unwrap()
            .measure(QubitId(0), "c", 0)
            .unwrap()
            .conditional(
                ClassicalCondition::register_equals("c", 1),
                PauliX { target: QubitId(1) },
            )
            .unwrap()
            .build();

        assert_eq!(circuit.classical_registers.len(), 1);
        assert_eq!(circuit.num_operations(), 3);
    }

    /// Invalid registers, bit indices and qubits are rejected and not recorded.
    #[test]
    fn test_measure_validation() {
        let mut circuit = ClassicalCircuit::<2>::new();
        circuit.add_classical_register("c", 1).unwrap();

        // Duplicate register names are rejected.
        assert!(circuit.add_classical_register("c", 1).is_err());

        // Unknown register, out-of-range bit, out-of-range qubit.
        assert!(circuit.measure(QubitId(0), "missing", 0).is_err());
        assert!(circuit.measure(QubitId(0), "c", 1).is_err());
        assert!(circuit.measure(QubitId(2), "c", 0).is_err());
        assert!(circuit.reset_classical("missing").is_err());
        assert_eq!(circuit.num_operations(), 0);

        assert!(circuit.measure(QubitId(0), "c", 0).is_ok());
        assert!(circuit.reset_classical("c").is_ok());
        assert_eq!(circuit.num_operations(), 2);
    }
}