quantrs2_sim/mixed_precision_impl/
state_vector.rs1use scirs2_core::ndarray::Array1;
8use scirs2_core::{Complex32, Complex64};
9
10use super::config::QuantumPrecision;
11use crate::error::{Result, SimulatorError};
12
13pub enum MixedPrecisionStateVector {
15 Half(Array1<Complex32>),
17 BFloat16(Array1<Complex32>),
19 TF32(Array1<Complex32>),
21 Single(Array1<Complex32>),
23 Double(Array1<Complex64>),
25 Adaptive {
27 primary: Box<Self>,
28 secondary: Option<Box<Self>>,
29 precision_map: Vec<QuantumPrecision>,
30 },
31}
32
33impl MixedPrecisionStateVector {
34 #[must_use]
36 pub fn new(size: usize, precision: QuantumPrecision) -> Self {
37 match precision {
38 QuantumPrecision::Half => Self::Half(Array1::zeros(size)),
39 QuantumPrecision::BFloat16 => Self::BFloat16(Array1::zeros(size)),
40 QuantumPrecision::TF32 => Self::TF32(Array1::zeros(size)),
41 QuantumPrecision::Single => Self::Single(Array1::zeros(size)),
42 QuantumPrecision::Double => Self::Double(Array1::zeros(size)),
43 QuantumPrecision::Adaptive => {
44 let primary = Box::new(Self::Single(Array1::zeros(size)));
46 Self::Adaptive {
47 primary,
48 secondary: None,
49 precision_map: vec![QuantumPrecision::Single; size],
50 }
51 }
52 }
53 }
54
55 #[must_use]
57 pub fn computational_basis(num_qubits: usize, precision: QuantumPrecision) -> Self {
58 let size = 1 << num_qubits;
59 let mut state = Self::new(size, precision);
60
61 match &mut state {
63 Self::Half(ref mut arr)
64 | Self::BFloat16(ref mut arr)
65 | Self::TF32(ref mut arr)
66 | Self::Single(ref mut arr) => arr[0] = Complex32::new(1.0, 0.0),
67 Self::Double(ref mut arr) => arr[0] = Complex64::new(1.0, 0.0),
68 Self::Adaptive {
69 ref mut primary, ..
70 } => {
71 **primary = Self::computational_basis(num_qubits, QuantumPrecision::Single);
72 }
73 }
74
75 state
76 }
77
78 #[must_use]
80 pub fn len(&self) -> usize {
81 match self {
82 Self::Half(arr) | Self::BFloat16(arr) | Self::TF32(arr) | Self::Single(arr) => {
83 arr.len()
84 }
85 Self::Double(arr) => arr.len(),
86 Self::Adaptive { primary, .. } => primary.len(),
87 }
88 }
89
90 #[must_use]
92 pub fn is_empty(&self) -> bool {
93 self.len() == 0
94 }
95
96 #[must_use]
98 pub const fn precision(&self) -> QuantumPrecision {
99 match self {
100 Self::Half(_) => QuantumPrecision::Half,
101 Self::BFloat16(_) => QuantumPrecision::BFloat16,
102 Self::TF32(_) => QuantumPrecision::TF32,
103 Self::Single(_) => QuantumPrecision::Single,
104 Self::Double(_) => QuantumPrecision::Double,
105 Self::Adaptive { .. } => QuantumPrecision::Adaptive,
106 }
107 }
108
109 pub fn to_precision(&self, target_precision: QuantumPrecision) -> Result<Self> {
111 if self.precision() == target_precision {
112 return Ok(self.clone());
113 }
114
115 let size = self.len();
116 let mut result = Self::new(size, target_precision);
117
118 match (self, &mut result) {
119 (Self::Single(src), Self::Double(dst)) => {
120 for (i, &val) in src.iter().enumerate() {
121 dst[i] = Complex64::new(f64::from(val.re), f64::from(val.im));
122 }
123 }
124 (Self::Double(src), Self::Single(dst)) => {
125 for (i, &val) in src.iter().enumerate() {
126 dst[i] = Complex32::new(val.re as f32, val.im as f32);
127 }
128 }
129 (Self::Half(src), Self::Single(dst)) => {
130 dst.clone_from(src);
131 }
132 (Self::Single(src), Self::Half(dst)) => {
133 dst.clone_from(src);
134 }
135 (Self::Half(src), Self::Double(dst)) => {
136 for (i, &val) in src.iter().enumerate() {
137 dst[i] = Complex64::new(f64::from(val.re), f64::from(val.im));
138 }
139 }
140 (Self::Double(src), Self::Half(dst)) => {
141 for (i, &val) in src.iter().enumerate() {
142 dst[i] = Complex32::new(val.re as f32, val.im as f32);
143 }
144 }
145 _ => {
146 return Err(SimulatorError::UnsupportedOperation(
147 "Complex precision conversion not supported".to_string(),
148 ));
149 }
150 }
151
152 Ok(result)
153 }
154
155 pub fn normalize(&mut self) -> Result<()> {
157 let norm = self.norm();
158 if norm == 0.0 {
159 return Err(SimulatorError::InvalidInput(
160 "Cannot normalize zero vector".to_string(),
161 ));
162 }
163
164 match self {
165 Self::Half(arr) | Self::BFloat16(arr) | Self::TF32(arr) | Self::Single(arr) => {
166 let norm_f32 = norm as f32;
167 for val in arr.iter_mut() {
168 *val /= norm_f32;
169 }
170 }
171 Self::Double(arr) => {
172 for val in arr.iter_mut() {
173 *val /= norm;
174 }
175 }
176 Self::Adaptive {
177 ref mut primary, ..
178 } => {
179 primary.normalize()?;
180 }
181 }
182
183 Ok(())
184 }
185
186 #[must_use]
188 pub fn norm(&self) -> f64 {
189 match self {
190 Self::Half(arr) | Self::BFloat16(arr) | Self::TF32(arr) | Self::Single(arr) => arr
191 .iter()
192 .map(|x| f64::from(x.norm_sqr()))
193 .sum::<f64>()
194 .sqrt(),
195 Self::Double(arr) => arr
196 .iter()
197 .map(scirs2_core::Complex::norm_sqr)
198 .sum::<f64>()
199 .sqrt(),
200 Self::Adaptive { primary, .. } => primary.norm(),
201 }
202 }
203
204 pub fn probability(&self, index: usize) -> Result<f64> {
206 if index >= self.len() {
207 return Err(SimulatorError::InvalidInput(format!(
208 "Index {} out of bounds for state vector of length {}",
209 index,
210 self.len()
211 )));
212 }
213
214 let prob = match self {
215 Self::Half(arr) | Self::BFloat16(arr) | Self::TF32(arr) | Self::Single(arr) => {
216 f64::from(arr[index].norm_sqr())
217 }
218 Self::Double(arr) => arr[index].norm_sqr(),
219 Self::Adaptive { primary, .. } => primary.probability(index)?,
220 };
221
222 Ok(prob)
223 }
224
225 pub fn amplitude(&self, index: usize) -> Result<Complex64> {
227 if index >= self.len() {
228 return Err(SimulatorError::InvalidInput(format!(
229 "Index {} out of bounds for state vector of length {}",
230 index,
231 self.len()
232 )));
233 }
234
235 let amplitude = match self {
236 Self::Half(arr) | Self::BFloat16(arr) | Self::TF32(arr) | Self::Single(arr) => {
237 let val = arr[index];
238 Complex64::new(f64::from(val.re), f64::from(val.im))
239 }
240 Self::Double(arr) => arr[index],
241 Self::Adaptive { primary, .. } => primary.amplitude(index)?,
242 };
243
244 Ok(amplitude)
245 }
246
247 pub fn set_amplitude(&mut self, index: usize, amplitude: Complex64) -> Result<()> {
249 if index >= self.len() {
250 return Err(SimulatorError::InvalidInput(format!(
251 "Index {} out of bounds for state vector of length {}",
252 index,
253 self.len()
254 )));
255 }
256
257 match self {
258 Self::Half(arr) | Self::BFloat16(arr) | Self::TF32(arr) | Self::Single(arr) => {
259 arr[index] = Complex32::new(amplitude.re as f32, amplitude.im as f32);
260 }
261 Self::Double(arr) => {
262 arr[index] = amplitude;
263 }
264 Self::Adaptive {
265 ref mut primary, ..
266 } => {
267 primary.set_amplitude(index, amplitude)?;
268 }
269 }
270
271 Ok(())
272 }
273
274 pub fn fidelity(&self, other: &Self) -> Result<f64> {
276 if self.len() != other.len() {
277 return Err(SimulatorError::InvalidInput(
278 "State vectors must have the same length for fidelity calculation".to_string(),
279 ));
280 }
281
282 let mut inner_product = Complex64::new(0.0, 0.0);
283
284 for i in 0..self.len() {
285 let amp1 = self.amplitude(i)?;
286 let amp2 = other.amplitude(i)?;
287 inner_product += amp1.conj() * amp2;
288 }
289
290 Ok(inner_product.norm_sqr())
291 }
292
293 pub fn clone_to_precision(&self, precision: QuantumPrecision) -> Result<Self> {
295 self.to_precision(precision)
296 }
297
298 #[must_use]
300 pub fn memory_usage(&self) -> usize {
301 match self {
302 Self::Half(arr) | Self::BFloat16(arr) | Self::TF32(arr) | Self::Single(arr) => {
303 arr.len() * std::mem::size_of::<Complex32>()
304 }
305 Self::Double(arr) => arr.len() * std::mem::size_of::<Complex64>(),
306 Self::Adaptive {
307 primary, secondary, ..
308 } => {
309 let mut usage = primary.memory_usage();
310 if let Some(sec) = secondary {
311 usage += sec.memory_usage();
312 }
313 usage += std::mem::size_of::<QuantumPrecision>() * primary.len(); usage
315 }
316 }
317 }
318
319 #[must_use]
321 pub fn is_normalized(&self, tolerance: f64) -> bool {
322 (self.norm() - 1.0).abs() < tolerance
323 }
324
325 #[must_use]
327 pub fn num_qubits(&self) -> usize {
328 (self.len() as f64).log2() as usize
329 }
330}
331
332impl Clone for MixedPrecisionStateVector {
333 fn clone(&self) -> Self {
334 match self {
335 Self::Half(arr) => Self::Half(arr.clone()),
336 Self::BFloat16(arr) => Self::BFloat16(arr.clone()),
337 Self::TF32(arr) => Self::TF32(arr.clone()),
338 Self::Single(arr) => Self::Single(arr.clone()),
339 Self::Double(arr) => Self::Double(arr.clone()),
340 Self::Adaptive {
341 primary,
342 secondary,
343 precision_map,
344 } => Self::Adaptive {
345 primary: primary.clone(),
346 secondary: secondary.clone(),
347 precision_map: precision_map.clone(),
348 },
349 }
350 }
351}
352
353impl std::fmt::Debug for MixedPrecisionStateVector {
354 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 match self {
356 Self::Half(arr) => write!(f, "Half({} elements)", arr.len()),
357 Self::BFloat16(arr) => write!(f, "BFloat16({} elements)", arr.len()),
358 Self::TF32(arr) => write!(f, "TF32({} elements)", arr.len()),
359 Self::Single(arr) => write!(f, "Single({} elements)", arr.len()),
360 Self::Double(arr) => write!(f, "Double({} elements)", arr.len()),
361 Self::Adaptive {
362 primary, secondary, ..
363 } => {
364 write!(
365 f,
366 "Adaptive(primary: {primary:?}, secondary: {secondary:?})"
367 )
368 }
369 }
370 }
371}