quantrs2_sim/mixed_precision_impl/
state_vector.rs1use scirs2_core::ndarray::Array1;
8use scirs2_core::{Complex32, Complex64};
9
10use super::config::QuantumPrecision;
11use crate::error::{Result, SimulatorError};
12
13pub enum MixedPrecisionStateVector {
15 Half(Array1<Complex32>),
17 Single(Array1<Complex32>),
19 Double(Array1<Complex64>),
21 Adaptive {
23 primary: Box<Self>,
24 secondary: Option<Box<Self>>,
25 precision_map: Vec<QuantumPrecision>,
26 },
27}
28
29impl MixedPrecisionStateVector {
30 #[must_use]
32 pub fn new(size: usize, precision: QuantumPrecision) -> Self {
33 match precision {
34 QuantumPrecision::Half => Self::Half(Array1::zeros(size)),
35 QuantumPrecision::Single => Self::Single(Array1::zeros(size)),
36 QuantumPrecision::Double => Self::Double(Array1::zeros(size)),
37 QuantumPrecision::Adaptive => {
38 let primary = Box::new(Self::Single(Array1::zeros(size)));
40 Self::Adaptive {
41 primary,
42 secondary: None,
43 precision_map: vec![QuantumPrecision::Single; size],
44 }
45 }
46 }
47 }
48
49 #[must_use]
51 pub fn computational_basis(num_qubits: usize, precision: QuantumPrecision) -> Self {
52 let size = 1 << num_qubits;
53 let mut state = Self::new(size, precision);
54
55 match &mut state {
57 Self::Half(ref mut arr) => arr[0] = Complex32::new(1.0, 0.0),
58 Self::Single(ref mut arr) => arr[0] = Complex32::new(1.0, 0.0),
59 Self::Double(ref mut arr) => arr[0] = Complex64::new(1.0, 0.0),
60 Self::Adaptive {
61 ref mut primary, ..
62 } => {
63 **primary = Self::computational_basis(num_qubits, QuantumPrecision::Single);
64 }
65 }
66
67 state
68 }
69
70 #[must_use]
72 pub fn len(&self) -> usize {
73 match self {
74 Self::Half(arr) => arr.len(),
75 Self::Single(arr) => arr.len(),
76 Self::Double(arr) => arr.len(),
77 Self::Adaptive { primary, .. } => primary.len(),
78 }
79 }
80
81 #[must_use]
83 pub fn is_empty(&self) -> bool {
84 self.len() == 0
85 }
86
87 #[must_use]
89 pub const fn precision(&self) -> QuantumPrecision {
90 match self {
91 Self::Half(_) => QuantumPrecision::Half,
92 Self::Single(_) => QuantumPrecision::Single,
93 Self::Double(_) => QuantumPrecision::Double,
94 Self::Adaptive { .. } => QuantumPrecision::Adaptive,
95 }
96 }
97
98 pub fn to_precision(&self, target_precision: QuantumPrecision) -> Result<Self> {
100 if self.precision() == target_precision {
101 return Ok(self.clone());
102 }
103
104 let size = self.len();
105 let mut result = Self::new(size, target_precision);
106
107 match (self, &mut result) {
108 (Self::Single(src), Self::Double(dst)) => {
109 for (i, &val) in src.iter().enumerate() {
110 dst[i] = Complex64::new(f64::from(val.re), f64::from(val.im));
111 }
112 }
113 (Self::Double(src), Self::Single(dst)) => {
114 for (i, &val) in src.iter().enumerate() {
115 dst[i] = Complex32::new(val.re as f32, val.im as f32);
116 }
117 }
118 (Self::Half(src), Self::Single(dst)) => {
119 dst.clone_from(src);
120 }
121 (Self::Single(src), Self::Half(dst)) => {
122 dst.clone_from(src);
123 }
124 (Self::Half(src), Self::Double(dst)) => {
125 for (i, &val) in src.iter().enumerate() {
126 dst[i] = Complex64::new(f64::from(val.re), f64::from(val.im));
127 }
128 }
129 (Self::Double(src), Self::Half(dst)) => {
130 for (i, &val) in src.iter().enumerate() {
131 dst[i] = Complex32::new(val.re as f32, val.im as f32);
132 }
133 }
134 _ => {
135 return Err(SimulatorError::UnsupportedOperation(
136 "Complex precision conversion not supported".to_string(),
137 ));
138 }
139 }
140
141 Ok(result)
142 }
143
144 pub fn normalize(&mut self) -> Result<()> {
146 let norm = self.norm();
147 if norm == 0.0 {
148 return Err(SimulatorError::InvalidInput(
149 "Cannot normalize zero vector".to_string(),
150 ));
151 }
152
153 match self {
154 Self::Half(arr) => {
155 let norm_f32 = norm as f32;
156 for val in arr.iter_mut() {
157 *val /= norm_f32;
158 }
159 }
160 Self::Single(arr) => {
161 let norm_f32 = norm as f32;
162 for val in arr.iter_mut() {
163 *val /= norm_f32;
164 }
165 }
166 Self::Double(arr) => {
167 for val in arr.iter_mut() {
168 *val /= norm;
169 }
170 }
171 Self::Adaptive {
172 ref mut primary, ..
173 } => {
174 primary.normalize()?;
175 }
176 }
177
178 Ok(())
179 }
180
181 #[must_use]
183 pub fn norm(&self) -> f64 {
184 match self {
185 Self::Half(arr) => arr
186 .iter()
187 .map(|x| f64::from(x.norm_sqr()))
188 .sum::<f64>()
189 .sqrt(),
190 Self::Single(arr) => arr
191 .iter()
192 .map(|x| f64::from(x.norm_sqr()))
193 .sum::<f64>()
194 .sqrt(),
195 Self::Double(arr) => arr
196 .iter()
197 .map(scirs2_core::Complex::norm_sqr)
198 .sum::<f64>()
199 .sqrt(),
200 Self::Adaptive { primary, .. } => primary.norm(),
201 }
202 }
203
204 pub fn probability(&self, index: usize) -> Result<f64> {
206 if index >= self.len() {
207 return Err(SimulatorError::InvalidInput(format!(
208 "Index {} out of bounds for state vector of length {}",
209 index,
210 self.len()
211 )));
212 }
213
214 let prob = match self {
215 Self::Half(arr) => f64::from(arr[index].norm_sqr()),
216 Self::Single(arr) => f64::from(arr[index].norm_sqr()),
217 Self::Double(arr) => arr[index].norm_sqr(),
218 Self::Adaptive { primary, .. } => primary.probability(index)?,
219 };
220
221 Ok(prob)
222 }
223
224 pub fn amplitude(&self, index: usize) -> Result<Complex64> {
226 if index >= self.len() {
227 return Err(SimulatorError::InvalidInput(format!(
228 "Index {} out of bounds for state vector of length {}",
229 index,
230 self.len()
231 )));
232 }
233
234 let amplitude = match self {
235 Self::Half(arr) => {
236 let val = arr[index];
237 Complex64::new(f64::from(val.re), f64::from(val.im))
238 }
239 Self::Single(arr) => {
240 let val = arr[index];
241 Complex64::new(f64::from(val.re), f64::from(val.im))
242 }
243 Self::Double(arr) => arr[index],
244 Self::Adaptive { primary, .. } => primary.amplitude(index)?,
245 };
246
247 Ok(amplitude)
248 }
249
250 pub fn set_amplitude(&mut self, index: usize, amplitude: Complex64) -> Result<()> {
252 if index >= self.len() {
253 return Err(SimulatorError::InvalidInput(format!(
254 "Index {} out of bounds for state vector of length {}",
255 index,
256 self.len()
257 )));
258 }
259
260 match self {
261 Self::Half(arr) => {
262 arr[index] = Complex32::new(amplitude.re as f32, amplitude.im as f32);
263 }
264 Self::Single(arr) => {
265 arr[index] = Complex32::new(amplitude.re as f32, amplitude.im as f32);
266 }
267 Self::Double(arr) => {
268 arr[index] = amplitude;
269 }
270 Self::Adaptive {
271 ref mut primary, ..
272 } => {
273 primary.set_amplitude(index, amplitude)?;
274 }
275 }
276
277 Ok(())
278 }
279
280 pub fn fidelity(&self, other: &Self) -> Result<f64> {
282 if self.len() != other.len() {
283 return Err(SimulatorError::InvalidInput(
284 "State vectors must have the same length for fidelity calculation".to_string(),
285 ));
286 }
287
288 let mut inner_product = Complex64::new(0.0, 0.0);
289
290 for i in 0..self.len() {
291 let amp1 = self.amplitude(i)?;
292 let amp2 = other.amplitude(i)?;
293 inner_product += amp1.conj() * amp2;
294 }
295
296 Ok(inner_product.norm_sqr())
297 }
298
299 pub fn clone_to_precision(&self, precision: QuantumPrecision) -> Result<Self> {
301 self.to_precision(precision)
302 }
303
304 #[must_use]
306 pub fn memory_usage(&self) -> usize {
307 match self {
308 Self::Half(arr) => arr.len() * std::mem::size_of::<Complex32>(),
309 Self::Single(arr) => arr.len() * std::mem::size_of::<Complex32>(),
310 Self::Double(arr) => arr.len() * std::mem::size_of::<Complex64>(),
311 Self::Adaptive {
312 primary, secondary, ..
313 } => {
314 let mut usage = primary.memory_usage();
315 if let Some(sec) = secondary {
316 usage += sec.memory_usage();
317 }
318 usage += std::mem::size_of::<QuantumPrecision>() * primary.len(); usage
320 }
321 }
322 }
323
324 #[must_use]
326 pub fn is_normalized(&self, tolerance: f64) -> bool {
327 (self.norm() - 1.0).abs() < tolerance
328 }
329
330 #[must_use]
332 pub fn num_qubits(&self) -> usize {
333 (self.len() as f64).log2() as usize
334 }
335}
336
337impl Clone for MixedPrecisionStateVector {
338 fn clone(&self) -> Self {
339 match self {
340 Self::Half(arr) => Self::Half(arr.clone()),
341 Self::Single(arr) => Self::Single(arr.clone()),
342 Self::Double(arr) => Self::Double(arr.clone()),
343 Self::Adaptive {
344 primary,
345 secondary,
346 precision_map,
347 } => Self::Adaptive {
348 primary: primary.clone(),
349 secondary: secondary.clone(),
350 precision_map: precision_map.clone(),
351 },
352 }
353 }
354}
355
356impl std::fmt::Debug for MixedPrecisionStateVector {
357 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358 match self {
359 Self::Half(arr) => write!(f, "Half({} elements)", arr.len()),
360 Self::Single(arr) => write!(f, "Single({} elements)", arr.len()),
361 Self::Double(arr) => write!(f, "Double({} elements)", arr.len()),
362 Self::Adaptive {
363 primary, secondary, ..
364 } => {
365 write!(
366 f,
367 "Adaptive(primary: {primary:?}, secondary: {secondary:?})"
368 )
369 }
370 }
371 }
372}