use crate::builder::Circuit;
use crate::classical::{ClassicalCondition, ClassicalRegister};
use quantrs2_core::{
error::{QuantRS2Error, QuantRS2Result},
gate::GateOp,
qubit::QubitId,
};
use std::collections::HashMap;
use std::sync::Arc;
/// A measurement of one qubit whose outcome is written to one classical bit.
#[derive(Clone, Debug)]
pub struct Measurement {
    /// The qubit being measured.
    pub qubit: QubitId,
    /// Index of the classical bit that receives the outcome.
    pub target_bit: usize,
    /// Optional human-readable label (the circuit assigns labels such as `"m0"`).
    pub label: Option<String>,
}
impl Measurement {
    /// Creates an unlabelled measurement of `qubit` into classical bit `target_bit`.
    #[must_use]
    pub const fn new(qubit: QubitId, target_bit: usize) -> Self {
        Self {
            qubit,
            target_bit,
            label: None,
        }
    }

    /// Returns this measurement with `label` attached, replacing any previous label.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, ...), so
    /// callers with a borrowed label need not allocate it themselves.
    #[must_use]
    pub fn with_label(mut self, label: impl Into<String>) -> Self {
        self.label = Some(label.into());
        self
    }
}
/// A classically controlled operation: `gate` is selected when `condition`
/// holds, and `else_gate` (if any) otherwise.
#[derive(Clone, Debug)]
pub struct FeedForward {
    /// Classical predicate that selects the branch.
    pub condition: ClassicalCondition,
    /// Gate for the branch where the condition holds.
    pub gate: Box<dyn GateOp>,
    /// Optional gate for the branch where the condition does not hold.
    pub else_gate: Option<Box<dyn GateOp>>,
}
impl FeedForward {
    /// Builds a feed-forward op guarded by `condition` with no else branch.
    #[must_use]
    pub fn new(condition: ClassicalCondition, gate: Box<dyn GateOp>) -> Self {
        let else_gate = None;
        Self {
            condition,
            gate,
            else_gate,
        }
    }

    /// Returns this op with `else_gate` as its else branch, replacing any
    /// previously set else branch.
    #[must_use]
    pub fn with_else(self, else_gate: Box<dyn GateOp>) -> Self {
        Self {
            else_gate: Some(else_gate),
            ..self
        }
    }
}
/// One step of a measurement-aware circuit, stored in program order.
#[derive(Clone, Debug)]
pub enum CircuitOp {
    /// An unconditional gate.
    Gate(Box<dyn GateOp>),
    /// A measurement of one qubit into one classical bit.
    Measure(Measurement),
    /// A gate (and optional else-gate) selected by a classical condition.
    FeedForward(FeedForward),
    /// A barrier across the listed qubits.
    Barrier(Vec<QubitId>),
    /// A reset of a single qubit.
    Reset(QubitId),
}
/// A circuit over `N` qubits that can interleave gates with measurements,
/// resets, barriers and classically conditioned (feed-forward) gates.
pub struct MeasurementCircuit<const N: usize> {
    /// All operations, in program order.
    operations: Vec<CircuitOp>,
    /// Classical registers keyed by name; `new` creates one `"default"` register of `N` bits.
    classical_registers: HashMap<String, ClassicalRegister>,
    /// Number of measurements added so far (also used to build `m{k}` labels).
    measurement_count: usize,
    /// Next classical bit index that `measure` will hand out.
    current_bit: usize,
}
impl<const N: usize> Default for MeasurementCircuit<N> {
fn default() -> Self {
Self::new()
}
}
impl<const N: usize> MeasurementCircuit<N> {
    /// Creates an empty circuit with a single `"default"` classical register of `N` bits.
    #[must_use]
    pub fn new() -> Self {
        let mut classical_registers = HashMap::new();
        classical_registers.insert(
            "default".to_string(),
            ClassicalRegister::new("default".to_string(), N),
        );
        Self {
            operations: Vec::new(),
            classical_registers,
            measurement_count: 0,
            current_bit: 0,
        }
    }

    /// Checks that `qubit` addresses one of this circuit's `N` qubits.
    fn validate_qubit(qubit: &QubitId) -> QuantRS2Result<()> {
        if qubit.id() >= N as u32 {
            return Err(QuantRS2Error::InvalidQubitId(qubit.id()));
        }
        Ok(())
    }

    /// Checks every qubit yielded by `qubits`, failing on the first out-of-range one.
    fn validate_qubits<'a>(qubits: impl IntoIterator<Item = &'a QubitId>) -> QuantRS2Result<()> {
        qubits.into_iter().try_for_each(Self::validate_qubit)
    }

    /// Appends an unconditional gate.
    ///
    /// # Errors
    /// `InvalidQubitId` if any qubit the gate acts on is `>= N`; nothing is appended then.
    pub fn add_gate(&mut self, gate: Box<dyn GateOp>) -> QuantRS2Result<()> {
        Self::validate_qubits(gate.qubits().iter())?;
        self.operations.push(CircuitOp::Gate(gate));
        Ok(())
    }

    /// Appends a measurement of `qubit` into the next free classical bit and
    /// returns that bit's index. The measurement is labelled `m{k}`, where `k`
    /// counts earlier measurements.
    ///
    /// # Errors
    /// `InvalidQubitId` if `qubit >= N`; `InvalidInput` once all `N` classical
    /// bits have been assigned.
    pub fn measure(&mut self, qubit: QubitId) -> QuantRS2Result<usize> {
        Self::validate_qubit(&qubit)?;
        if self.current_bit >= N {
            return Err(QuantRS2Error::InvalidInput(
                "Not enough classical bits for measurement".to_string(),
            ));
        }
        let target_bit = self.current_bit;
        let label = format!("m{}", self.measurement_count);
        let measurement = Measurement::new(qubit, target_bit).with_label(label);
        self.operations.push(CircuitOp::Measure(measurement));
        self.current_bit += 1;
        self.measurement_count += 1;
        Ok(target_bit)
    }

    /// Appends `gate` guarded by `condition` (no else branch).
    ///
    /// # Errors
    /// `InvalidQubitId` if any qubit the gate acts on is `>= N`.
    pub fn add_conditional(
        &mut self,
        condition: ClassicalCondition,
        gate: Box<dyn GateOp>,
    ) -> QuantRS2Result<()> {
        Self::validate_qubits(gate.qubits().iter())?;
        self.operations
            .push(CircuitOp::FeedForward(FeedForward::new(condition, gate)));
        Ok(())
    }

    /// Appends an if/else feed-forward: `if_gate` when `condition` holds,
    /// `else_gate` otherwise.
    ///
    /// # Errors
    /// `InvalidQubitId` if any qubit of either gate is `>= N`.
    pub fn add_if_else(
        &mut self,
        condition: ClassicalCondition,
        if_gate: Box<dyn GateOp>,
        else_gate: Box<dyn GateOp>,
    ) -> QuantRS2Result<()> {
        Self::validate_qubits(if_gate.qubits().iter().chain(else_gate.qubits().iter()))?;
        let feed_forward = FeedForward::new(condition, if_gate).with_else(else_gate);
        self.operations.push(CircuitOp::FeedForward(feed_forward));
        Ok(())
    }

    /// Appends a barrier across `qubits`.
    ///
    /// # Errors
    /// `InvalidQubitId` if any listed qubit is `>= N`.
    pub fn barrier(&mut self, qubits: Vec<QubitId>) -> QuantRS2Result<()> {
        Self::validate_qubits(qubits.iter())?;
        self.operations.push(CircuitOp::Barrier(qubits));
        Ok(())
    }

    /// Appends a reset of `qubit`.
    ///
    /// # Errors
    /// `InvalidQubitId` if `qubit >= N`.
    pub fn reset(&mut self, qubit: QubitId) -> QuantRS2Result<()> {
        Self::validate_qubit(&qubit)?;
        self.operations.push(CircuitOp::Reset(qubit));
        Ok(())
    }

    /// Total number of operations of every kind.
    #[must_use]
    pub fn num_operations(&self) -> usize {
        self.operations.len()
    }

    /// Number of measurements added so far.
    #[must_use]
    pub const fn num_measurements(&self) -> usize {
        self.measurement_count
    }

    /// All operations, in program order.
    #[must_use]
    pub fn operations(&self) -> &[CircuitOp] {
        &self.operations
    }

    /// Lowers this circuit to a plain [`Circuit`].
    ///
    /// NOTE(review): every operation kind is currently skipped, so the returned
    /// circuit is always empty. Gates are *not* transferred. The exhaustive
    /// `match` is kept so that adding a `CircuitOp` variant forces this method
    /// to be revisited.
    ///
    /// # Errors
    /// Currently always returns `Ok`.
    pub fn to_circuit(&self) -> QuantRS2Result<Circuit<N>> {
        let circuit = Circuit::<N>::new();
        for op in &self.operations {
            match op {
                CircuitOp::Gate(_)
                | CircuitOp::Measure(_)
                | CircuitOp::FeedForward(_)
                | CircuitOp::Barrier(_)
                | CircuitOp::Reset(_) => {
                    // TODO(review): lower each operation into `circuit`.
                }
            }
        }
        Ok(circuit)
    }

    /// Summarises the measurements and which measurement each feed-forward op
    /// depends on.
    ///
    /// Operations are scanned once, in program order. Each feed-forward op is
    /// linked to the most recent measurement that *precedes* it, recorded as an
    /// index into `measurements`. A feed-forward op with no earlier measurement
    /// gets no entry. Measurements that come after a conditional can never feed it.
    ///
    /// NOTE(review): the link is positional only. The op's `ClassicalCondition`
    /// is not inspected, so it may not be the bit the condition actually reads.
    #[must_use]
    pub fn analyze_dependencies(&self) -> MeasurementDependencies {
        let mut deps = MeasurementDependencies::new();
        for (i, op) in self.operations.iter().enumerate() {
            match op {
                CircuitOp::Measure(m) => deps.measurements.push((i, m.clone())),
                CircuitOp::FeedForward(_) => {
                    if let Some(latest) = deps.measurements.len().checked_sub(1) {
                        deps.feed_forward_deps.push((latest, i));
                    }
                }
                CircuitOp::Gate(_) | CircuitOp::Barrier(_) | CircuitOp::Reset(_) => {}
            }
        }
        deps
    }
}
/// Measurement / feed-forward structure extracted from a measurement circuit.
#[derive(Debug, Clone)]
pub struct MeasurementDependencies {
    /// Every measurement as `(operation index, measurement)`, in program order.
    pub measurements: Vec<(usize, Measurement)>,
    /// Pairs `(measurement slot, feed-forward operation index)`. The first
    /// element indexes into `measurements`. It is presumably the measurement
    /// whose outcome the conditional consumes (confirm against the analysis
    /// that fills it).
    pub feed_forward_deps: Vec<(usize, usize)>,
}
impl MeasurementDependencies {
    /// An empty summary with no measurements and no feed-forward links.
    const fn new() -> Self {
        Self {
            feed_forward_deps: Vec::new(),
            measurements: Vec::new(),
        }
    }

    /// Whether at least one feed-forward dependency was recorded.
    #[must_use]
    pub fn has_feed_forward(&self) -> bool {
        let deps = &self.feed_forward_deps;
        !deps.is_empty()
    }

    /// How many measurements the summary contains.
    #[must_use]
    pub fn num_measurements(&self) -> usize {
        let entries = &self.measurements;
        entries.len()
    }
}
/// By-value fluent builder for a measurement circuit over `N` qubits.
pub struct MeasurementCircuitBuilder<const N: usize> {
    /// The circuit being assembled.
    circuit: MeasurementCircuit<N>,
}
impl<const N: usize> Default for MeasurementCircuitBuilder<N> {
fn default() -> Self {
Self::new()
}
}
impl<const N: usize> MeasurementCircuitBuilder<N> {
    /// Starts a builder over an empty circuit.
    #[must_use]
    pub fn new() -> Self {
        let circuit = MeasurementCircuit::new();
        Self { circuit }
    }

    /// Appends an unconditional gate.
    ///
    /// # Errors
    /// Propagates the circuit's qubit-range validation error.
    pub fn gate(mut self, gate: Box<dyn GateOp>) -> QuantRS2Result<Self> {
        self.circuit.add_gate(gate).map(|()| self)
    }

    /// Appends a measurement and also returns the classical bit it targets.
    ///
    /// # Errors
    /// Propagates the circuit's qubit-range or classical-bit exhaustion error.
    pub fn measure(mut self, qubit: QubitId) -> QuantRS2Result<(Self, usize)> {
        self.circuit.measure(qubit).map(|bit| (self, bit))
    }

    /// Appends `gate` guarded by `condition`.
    ///
    /// # Errors
    /// Propagates the circuit's qubit-range validation error.
    pub fn when(
        mut self,
        condition: ClassicalCondition,
        gate: Box<dyn GateOp>,
    ) -> QuantRS2Result<Self> {
        self.circuit.add_conditional(condition, gate).map(|()| self)
    }

    /// Appends an if/else feed-forward pair.
    ///
    /// # Errors
    /// Propagates the circuit's qubit-range validation error.
    pub fn if_else(
        mut self,
        condition: ClassicalCondition,
        if_gate: Box<dyn GateOp>,
        else_gate: Box<dyn GateOp>,
    ) -> QuantRS2Result<Self> {
        self.circuit
            .add_if_else(condition, if_gate, else_gate)
            .map(|()| self)
    }

    /// Appends a barrier across `qubits`.
    ///
    /// # Errors
    /// Propagates the circuit's qubit-range validation error.
    pub fn barrier(mut self, qubits: Vec<QubitId>) -> QuantRS2Result<Self> {
        self.circuit.barrier(qubits).map(|()| self)
    }

    /// Appends a reset of `qubit`.
    ///
    /// # Errors
    /// Propagates the circuit's qubit-range validation error.
    pub fn reset(mut self, qubit: QubitId) -> QuantRS2Result<Self> {
        self.circuit.reset(qubit).map(|()| self)
    }

    /// Finishes building and hands back the assembled circuit.
    #[must_use]
    pub fn build(self) -> MeasurementCircuit<N> {
        let Self { circuit } = self;
        circuit
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::multi::CNOT;
    use quantrs2_core::gate::single::{Hadamard, PauliX};

    /// Builds the condition `Integer(bit) == Integer(1)` shared by the tests.
    ///
    /// NOTE(review): this wraps the classical bit *index* in
    /// `ClassicalValue::Integer`. That looks like a comparison of two constants
    /// rather than a read of the measured bit. Confirm the intended
    /// `ClassicalValue` variant.
    fn bit_equals_one(bit: usize) -> ClassicalCondition {
        ClassicalCondition::equals(
            crate::classical::ClassicalValue::Integer(bit as u64),
            crate::classical::ClassicalValue::Integer(1),
        )
    }

    #[test]
    fn test_measurement_circuit() {
        let mut circuit: MeasurementCircuit<3> = MeasurementCircuit::new();
        circuit
            .add_gate(Box::new(Hadamard { target: QubitId(0) }))
            .expect("Failed to add Hadamard gate");

        let bit0 = circuit
            .measure(QubitId(0))
            .expect("Failed to measure qubit 0");
        assert_eq!(bit0, 0);

        let flip = Box::new(PauliX { target: QubitId(1) });
        circuit
            .add_conditional(bit_equals_one(bit0), flip)
            .expect("Failed to add conditional X gate");

        assert_eq!(circuit.num_operations(), 3);
        assert_eq!(circuit.num_measurements(), 1);
    }

    #[test]
    fn test_feed_forward() {
        let mut circuit: MeasurementCircuit<2> = MeasurementCircuit::new();
        circuit
            .add_gate(Box::new(Hadamard { target: QubitId(0) }))
            .expect("Failed to add Hadamard gate");
        let entangle = Box::new(CNOT {
            control: QubitId(0),
            target: QubitId(1),
        });
        circuit
            .add_gate(entangle)
            .expect("Failed to add CNOT gate");

        let bit = circuit
            .measure(QubitId(0))
            .expect("Failed to measure qubit 0");
        let flip = Box::new(PauliX { target: QubitId(1) });
        circuit
            .add_conditional(bit_equals_one(bit), flip)
            .expect("Failed to add conditional X gate");

        let deps = circuit.analyze_dependencies();
        assert_eq!(deps.num_measurements(), 1);
        assert!(deps.has_feed_forward());
    }

    #[test]
    fn test_builder_pattern() {
        let builder = MeasurementCircuitBuilder::<2>::new()
            .gate(Box::new(Hadamard { target: QubitId(0) }))
            .expect("Failed to add gate");
        let (builder, bit) = builder
            .measure(QubitId(0))
            .expect("Failed to measure qubit");
        let circuit = builder
            .when(bit_equals_one(bit), Box::new(PauliX { target: QubitId(1) }))
            .expect("Failed to add conditional gate")
            .build();
        assert_eq!(circuit.num_operations(), 3);
    }
}