quantrs2_core/
register.rs

1use scirs2_core::Complex64;
2use std::marker::PhantomData;
3
4use crate::error::{QuantRS2Error, QuantRS2Result};
5use crate::qubit::QubitId;
6
7/// A quantum register that holds the state of N qubits
8#[derive(Debug, Clone)]
9pub struct Register<const N: usize> {
10    /// Complex amplitudes for each basis state
11    ///
12    /// The index corresponds to the integer representation of a basis state.
13    /// For example, for 2 qubits, amplitudes[0] = |00⟩, amplitudes[1] = |01⟩,
14    /// amplitudes[2] = |10⟩, amplitudes[3] = |11⟩
15    amplitudes: Vec<Complex64>,
16
17    /// Marker to enforce the const generic parameter
18    _phantom: PhantomData<[(); N]>,
19}
20
21impl<const N: usize> Register<N> {
22    /// Create a new register with N qubits in the |0...0⟩ state
23    pub fn new() -> Self {
24        let dim = 1 << N;
25        let mut amplitudes = vec![Complex64::new(0.0, 0.0); dim];
26        amplitudes[0] = Complex64::new(1.0, 0.0);
27
28        Self {
29            amplitudes,
30            _phantom: PhantomData,
31        }
32    }
33
34    /// Create a register with custom initial amplitudes
35    ///
36    /// # Errors
37    ///
38    /// Returns an error if the provided amplitudes vector doesn't have
39    /// the correct dimension (2^N) or if the vector isn't properly normalized.
40    pub fn with_amplitudes(amplitudes: Vec<Complex64>) -> QuantRS2Result<Self> {
41        let expected_dim = 1 << N;
42        if amplitudes.len() != expected_dim {
43            return Err(QuantRS2Error::CircuitValidationFailed(format!(
44                "Amplitudes vector has incorrect dimension. Expected {}, got {}",
45                expected_dim,
46                amplitudes.len()
47            )));
48        }
49
50        // Check if the state is properly normalized (within a small epsilon)
51        let norm_squared: f64 = amplitudes.iter().map(|a| a.norm_sqr()).sum();
52
53        if (norm_squared - 1.0).abs() > 1e-10 {
54            return Err(QuantRS2Error::CircuitValidationFailed(format!(
55                "Amplitudes vector is not properly normalized. Norm^2 = {norm_squared}"
56            )));
57        }
58
59        Ok(Self {
60            amplitudes,
61            _phantom: PhantomData,
62        })
63    }
64
65    /// Get the number of qubits in this register
66    #[inline]
67    pub const fn num_qubits(&self) -> usize {
68        N
69    }
70
71    /// Get the dimension of the state space (2^N)
72    #[inline]
73    pub const fn dimension(&self) -> usize {
74        1 << N
75    }
76
77    /// Get access to the raw amplitudes vector
78    pub fn amplitudes(&self) -> &[Complex64] {
79        &self.amplitudes
80    }
81
82    /// Get mutable access to the raw amplitudes vector
83    pub fn amplitudes_mut(&mut self) -> &mut [Complex64] {
84        &mut self.amplitudes
85    }
86
87    /// Get the amplitude for a specific basis state
88    ///
89    /// The bits parameter must be a slice of length N, where each element
90    /// is either 0 or 1 representing the computational basis state.
91    ///
92    /// # Errors
93    ///
94    /// Returns an error if the bits slice has incorrect length or contains
95    /// values other than 0 or 1.
96    pub fn amplitude(&self, bits: &[u8]) -> QuantRS2Result<Complex64> {
97        if bits.len() != N {
98            return Err(QuantRS2Error::CircuitValidationFailed(format!(
99                "Bits slice has incorrect length. Expected {}, got {}",
100                N,
101                bits.len()
102            )));
103        }
104
105        for &bit in bits {
106            if bit > 1 {
107                return Err(QuantRS2Error::CircuitValidationFailed(format!(
108                    "Invalid bit value {bit}. Must be 0 or 1"
109                )));
110            }
111        }
112
113        let index = bits
114            .iter()
115            .fold(0usize, |acc, &bit| (acc << 1) | bit as usize);
116
117        Ok(self.amplitudes[index])
118    }
119
120    /// Calculate the probability of measuring a specific basis state
121    ///
122    /// The bits parameter must be a slice of length N, where each element
123    /// is either 0 or 1 representing the computational basis state.
124    ///
125    /// # Errors
126    ///
127    /// Returns an error if the bits slice has incorrect length or contains
128    /// values other than 0 or 1.
129    pub fn probability(&self, bits: &[u8]) -> QuantRS2Result<f64> {
130        let amplitude = self.amplitude(bits)?;
131        Ok(amplitude.norm_sqr())
132    }
133
134    /// Calculate the probabilities of measuring each basis state
135    pub fn probabilities(&self) -> Vec<f64> {
136        self.amplitudes.iter().map(|a| a.norm_sqr()).collect()
137    }
138
139    /// Calculate the expectation value of a single-qubit Pauli operator
140    pub fn expectation_z(&self, qubit: impl Into<QubitId>) -> QuantRS2Result<f64> {
141        let qubit_id = qubit.into();
142        let q_idx = qubit_id.id() as usize;
143
144        if q_idx >= N {
145            return Err(QuantRS2Error::InvalidQubitId(qubit_id.id()));
146        }
147
148        let dim = 1 << N;
149        let mut result = 0.0;
150
151        for i in 0..dim {
152            // Check if the qubit is 0 or 1 in this basis state
153            let bit_val = (i >> q_idx) & 1;
154
155            // For Z measurement, +1 if bit is 0, -1 if bit is 1
156            let z_val = if bit_val == 0 { 1.0 } else { -1.0 };
157
158            // Add contribution to expectation value
159            result += z_val * self.amplitudes[i].norm_sqr();
160        }
161
162        Ok(result)
163    }
164}
165
166impl<const N: usize> Default for Register<N> {
167    fn default() -> Self {
168        Self::new()
169    }
170}