quantrs2_core/
register.rs

1use num_complex::Complex64;
2use std::marker::PhantomData;
3
4use crate::error::{QuantRS2Error, QuantRS2Result};
5use crate::qubit::QubitId;
6
7/// A quantum register that holds the state of N qubits
8#[derive(Debug, Clone)]
9pub struct Register<const N: usize> {
10    /// Complex amplitudes for each basis state
11    ///
12    /// The index corresponds to the integer representation of a basis state.
13    /// For example, for 2 qubits, amplitudes[0] = |00⟩, amplitudes[1] = |01⟩,
14    /// amplitudes[2] = |10⟩, amplitudes[3] = |11⟩
15    amplitudes: Vec<Complex64>,
16
17    /// Marker to enforce the const generic parameter
18    _phantom: PhantomData<[(); N]>,
19}
20
21impl<const N: usize> Register<N> {
22    /// Create a new register with N qubits in the |0...0⟩ state
23    pub fn new() -> Self {
24        let dim = 1 << N;
25        let mut amplitudes = vec![Complex64::new(0.0, 0.0); dim];
26        amplitudes[0] = Complex64::new(1.0, 0.0);
27
28        Self {
29            amplitudes,
30            _phantom: PhantomData,
31        }
32    }
33
34    /// Create a register with custom initial amplitudes
35    ///
36    /// # Errors
37    ///
38    /// Returns an error if the provided amplitudes vector doesn't have
39    /// the correct dimension (2^N) or if the vector isn't properly normalized.
40    pub fn with_amplitudes(amplitudes: Vec<Complex64>) -> QuantRS2Result<Self> {
41        let expected_dim = 1 << N;
42        if amplitudes.len() != expected_dim {
43            return Err(QuantRS2Error::CircuitValidationFailed(format!(
44                "Amplitudes vector has incorrect dimension. Expected {}, got {}",
45                expected_dim,
46                amplitudes.len()
47            )));
48        }
49
50        // Check if the state is properly normalized (within a small epsilon)
51        let norm_squared: f64 = amplitudes.iter().map(|a| a.norm_sqr()).sum();
52
53        if (norm_squared - 1.0).abs() > 1e-10 {
54            return Err(QuantRS2Error::CircuitValidationFailed(format!(
55                "Amplitudes vector is not properly normalized. Norm^2 = {}",
56                norm_squared
57            )));
58        }
59
60        Ok(Self {
61            amplitudes,
62            _phantom: PhantomData,
63        })
64    }
65
66    /// Get the number of qubits in this register
67    #[inline]
68    pub const fn num_qubits(&self) -> usize {
69        N
70    }
71
72    /// Get the dimension of the state space (2^N)
73    #[inline]
74    pub const fn dimension(&self) -> usize {
75        1 << N
76    }
77
78    /// Get access to the raw amplitudes vector
79    pub fn amplitudes(&self) -> &[Complex64] {
80        &self.amplitudes
81    }
82
83    /// Get mutable access to the raw amplitudes vector
84    pub fn amplitudes_mut(&mut self) -> &mut [Complex64] {
85        &mut self.amplitudes
86    }
87
88    /// Get the amplitude for a specific basis state
89    ///
90    /// The bits parameter must be a slice of length N, where each element
91    /// is either 0 or 1 representing the computational basis state.
92    ///
93    /// # Errors
94    ///
95    /// Returns an error if the bits slice has incorrect length or contains
96    /// values other than 0 or 1.
97    pub fn amplitude(&self, bits: &[u8]) -> QuantRS2Result<Complex64> {
98        if bits.len() != N {
99            return Err(QuantRS2Error::CircuitValidationFailed(format!(
100                "Bits slice has incorrect length. Expected {}, got {}",
101                N,
102                bits.len()
103            )));
104        }
105
106        for &bit in bits {
107            if bit > 1 {
108                return Err(QuantRS2Error::CircuitValidationFailed(format!(
109                    "Invalid bit value {}. Must be 0 or 1",
110                    bit
111                )));
112            }
113        }
114
115        let index = bits
116            .iter()
117            .fold(0usize, |acc, &bit| (acc << 1) | bit as usize);
118
119        Ok(self.amplitudes[index])
120    }
121
122    /// Calculate the probability of measuring a specific basis state
123    ///
124    /// The bits parameter must be a slice of length N, where each element
125    /// is either 0 or 1 representing the computational basis state.
126    ///
127    /// # Errors
128    ///
129    /// Returns an error if the bits slice has incorrect length or contains
130    /// values other than 0 or 1.
131    pub fn probability(&self, bits: &[u8]) -> QuantRS2Result<f64> {
132        let amplitude = self.amplitude(bits)?;
133        Ok(amplitude.norm_sqr())
134    }
135
136    /// Calculate the probabilities of measuring each basis state
137    pub fn probabilities(&self) -> Vec<f64> {
138        self.amplitudes.iter().map(|a| a.norm_sqr()).collect()
139    }
140
141    /// Calculate the expectation value of a single-qubit Pauli operator
142    pub fn expectation_z(&self, qubit: impl Into<QubitId>) -> QuantRS2Result<f64> {
143        let qubit_id = qubit.into();
144        let q_idx = qubit_id.id() as usize;
145
146        if q_idx >= N {
147            return Err(QuantRS2Error::InvalidQubitId(qubit_id.id()));
148        }
149
150        let dim = 1 << N;
151        let mut result = 0.0;
152
153        for i in 0..dim {
154            // Check if the qubit is 0 or 1 in this basis state
155            let bit_val = (i >> q_idx) & 1;
156
157            // For Z measurement, +1 if bit is 0, -1 if bit is 1
158            let z_val = if bit_val == 0 { 1.0 } else { -1.0 };
159
160            // Add contribution to expectation value
161            result += z_val * self.amplitudes[i].norm_sqr();
162        }
163
164        Ok(result)
165    }
166}
167
168impl<const N: usize> Default for Register<N> {
169    fn default() -> Self {
170        Self::new()
171    }
172}