Skip to main content

quantrs2_core/
register.rs

//! Quantum register types for holding the full state of N qubits.
//!
//! A [`Register<N>`] stores a complex amplitude vector of dimension 2^N and
//! supports computing measurement probabilities, Pauli-Z expectation values,
//! and direct amplitude manipulation.
5
6use scirs2_core::Complex64;
7use std::marker::PhantomData;
8
9use crate::error::{QuantRS2Error, QuantRS2Result};
10use crate::qubit::QubitId;
11
12/// A quantum register that holds the full quantum state of N qubits.
13///
14/// The state is stored as a 2^N complex amplitude vector where index `i`
15/// corresponds to the computational basis state |i⟩.
16///
17/// # Examples
18///
19/// ```rust
20/// use quantrs2_core::register::Register;
21///
22/// // Start in |00⟩ state
23/// let reg: Register<2> = Register::new();
24/// assert_eq!(reg.num_qubits(), 2);
25/// // Ground state has amplitude 1 at index 0
26/// assert!((reg.amplitudes()[0].re - 1.0).abs() < 1e-10);
27/// ```
28#[derive(Debug, Clone)]
29pub struct Register<const N: usize> {
30    /// Complex amplitudes for each basis state
31    ///
32    /// The index corresponds to the integer representation of a basis state.
33    /// For example, for 2 qubits, amplitudes[0] = |00⟩, amplitudes[1] = |01⟩,
34    /// amplitudes[2] = |10⟩, amplitudes[3] = |11⟩
35    amplitudes: Vec<Complex64>,
36
37    /// Marker to enforce the const generic parameter
38    _phantom: PhantomData<[(); N]>,
39}
40
41impl<const N: usize> Register<N> {
42    /// Create a new register with N qubits in the |0...0⟩ state
43    pub fn new() -> Self {
44        let dim = 1 << N;
45        let mut amplitudes = vec![Complex64::new(0.0, 0.0); dim];
46        amplitudes[0] = Complex64::new(1.0, 0.0);
47
48        Self {
49            amplitudes,
50            _phantom: PhantomData,
51        }
52    }
53
54    /// Create a register with custom initial amplitudes
55    ///
56    /// # Errors
57    ///
58    /// Returns an error if the provided amplitudes vector doesn't have
59    /// the correct dimension (2^N) or if the vector isn't properly normalized.
60    pub fn with_amplitudes(amplitudes: Vec<Complex64>) -> QuantRS2Result<Self> {
61        let expected_dim = 1 << N;
62        if amplitudes.len() != expected_dim {
63            return Err(QuantRS2Error::CircuitValidationFailed(format!(
64                "Amplitudes vector has incorrect dimension. Expected {}, got {}",
65                expected_dim,
66                amplitudes.len()
67            )));
68        }
69
70        // Check if the state is properly normalized (within a small epsilon)
71        let norm_squared: f64 = amplitudes.iter().map(|a| a.norm_sqr()).sum();
72
73        if (norm_squared - 1.0).abs() > 1e-10 {
74            return Err(QuantRS2Error::CircuitValidationFailed(format!(
75                "Amplitudes vector is not properly normalized. Norm^2 = {norm_squared}"
76            )));
77        }
78
79        Ok(Self {
80            amplitudes,
81            _phantom: PhantomData,
82        })
83    }
84
85    /// Get the number of qubits in this register
86    #[inline]
87    pub const fn num_qubits(&self) -> usize {
88        N
89    }
90
91    /// Get the dimension of the state space (2^N)
92    #[inline]
93    pub const fn dimension(&self) -> usize {
94        1 << N
95    }
96
97    /// Get access to the raw amplitudes vector
98    pub fn amplitudes(&self) -> &[Complex64] {
99        &self.amplitudes
100    }
101
102    /// Get mutable access to the raw amplitudes vector
103    pub fn amplitudes_mut(&mut self) -> &mut [Complex64] {
104        &mut self.amplitudes
105    }
106
107    /// Get the amplitude for a specific basis state
108    ///
109    /// The bits parameter must be a slice of length N, where each element
110    /// is either 0 or 1 representing the computational basis state.
111    ///
112    /// # Errors
113    ///
114    /// Returns an error if the bits slice has incorrect length or contains
115    /// values other than 0 or 1.
116    pub fn amplitude(&self, bits: &[u8]) -> QuantRS2Result<Complex64> {
117        if bits.len() != N {
118            return Err(QuantRS2Error::CircuitValidationFailed(format!(
119                "Bits slice has incorrect length. Expected {}, got {}",
120                N,
121                bits.len()
122            )));
123        }
124
125        for &bit in bits {
126            if bit > 1 {
127                return Err(QuantRS2Error::CircuitValidationFailed(format!(
128                    "Invalid bit value {bit}. Must be 0 or 1"
129                )));
130            }
131        }
132
133        let index = bits
134            .iter()
135            .fold(0usize, |acc, &bit| (acc << 1) | bit as usize);
136
137        Ok(self.amplitudes[index])
138    }
139
140    /// Calculate the probability of measuring a specific basis state
141    ///
142    /// The bits parameter must be a slice of length N, where each element
143    /// is either 0 or 1 representing the computational basis state.
144    ///
145    /// # Errors
146    ///
147    /// Returns an error if the bits slice has incorrect length or contains
148    /// values other than 0 or 1.
149    pub fn probability(&self, bits: &[u8]) -> QuantRS2Result<f64> {
150        let amplitude = self.amplitude(bits)?;
151        Ok(amplitude.norm_sqr())
152    }
153
154    /// Calculate the probabilities of measuring each basis state
155    pub fn probabilities(&self) -> Vec<f64> {
156        self.amplitudes.iter().map(|a| a.norm_sqr()).collect()
157    }
158
159    /// Calculate the expectation value of a single-qubit Pauli operator
160    pub fn expectation_z(&self, qubit: impl Into<QubitId>) -> QuantRS2Result<f64> {
161        let qubit_id = qubit.into();
162        let q_idx = qubit_id.id() as usize;
163
164        if q_idx >= N {
165            return Err(QuantRS2Error::InvalidQubitId(qubit_id.id()));
166        }
167
168        let dim = 1 << N;
169        let mut result = 0.0;
170
171        for i in 0..dim {
172            // Check if the qubit is 0 or 1 in this basis state
173            let bit_val = (i >> q_idx) & 1;
174
175            // For Z measurement, +1 if bit is 0, -1 if bit is 1
176            let z_val = if bit_val == 0 { 1.0 } else { -1.0 };
177
178            // Add contribution to expectation value
179            result += z_val * self.amplitudes[i].norm_sqr();
180        }
181
182        Ok(result)
183    }
184}
185
186impl<const N: usize> Default for Register<N> {
187    fn default() -> Self {
188        Self::new()
189    }
190}