1use crate::prelude::SimulatorError;
8use half::f16;
9use scirs2_core::ndarray::Array1;
10use scirs2_core::{Complex32, Complex64};
11use std::fmt;
12
13use crate::error::Result;
14
/// Floating-point precision levels for complex amplitudes.
///
/// `Half`, `Single` and `Double` correspond to the storage variants of
/// `AdaptiveStateVector`; `Extended` is described by the size and tolerance
/// tables below but has no storage variant in this file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Precision {
    /// Two 2-byte components per amplitude.
    Half,
    /// Two 4-byte components per amplitude.
    Single,
    /// Two 8-byte components per amplitude.
    Double,
    /// Two 16-byte components per amplitude.
    Extended,
}

impl Precision {
    /// Returns the number of bytes used by one complex amplitude (real plus
    /// imaginary component) at this precision.
    #[must_use]
    pub const fn bytes_per_complex(&self) -> usize {
        let component_bytes = match *self {
            Self::Half => 2,
            Self::Single => 4,
            Self::Double => 8,
            Self::Extended => 16,
        };
        // One real and one imaginary component.
        2 * component_bytes
    }

    /// Returns the nominal tolerance threshold associated with this level.
    #[must_use]
    pub const fn epsilon(&self) -> f64 {
        match *self {
            Self::Half => 0.001,
            Self::Single => 1e-7,
            Self::Double => 1e-15,
            Self::Extended => 1e-30,
        }
    }

    /// Picks the lowest precision whose [`epsilon`](Self::epsilon) does not
    /// exceed `tolerance`.
    ///
    /// Any tolerance below the `Double` threshold (including negative values
    /// and NaN, for which every comparison is false) yields `Extended`.
    #[must_use]
    pub fn from_tolerance(tolerance: f64) -> Self {
        // Ordered from coarsest to finest; the first level that satisfies the
        // tolerance wins.
        [Self::Half, Self::Single, Self::Double]
            .iter()
            .copied()
            .find(|level| tolerance >= level.epsilon())
            .unwrap_or(Self::Extended)
    }
}
65
66pub trait ComplexAmplitude: Clone + Send + Sync {
68 fn to_complex64(&self) -> Complex64;
70
71 fn from_complex64(c: Complex64) -> Self;
73
74 fn norm_sqr(&self) -> f64;
76
77 fn scale(&mut self, factor: f64);
79}
80
81impl ComplexAmplitude for Complex64 {
82 fn to_complex64(&self) -> Complex64 {
83 *self
84 }
85
86 fn from_complex64(c: Complex64) -> Self {
87 c
88 }
89
90 fn norm_sqr(&self) -> f64 {
91 self.norm_sqr()
92 }
93
94 fn scale(&mut self, factor: f64) {
95 *self *= factor;
96 }
97}
98
99impl ComplexAmplitude for Complex32 {
100 fn to_complex64(&self) -> Complex64 {
101 Complex64::new(f64::from(self.re), f64::from(self.im))
102 }
103
104 fn from_complex64(c: Complex64) -> Self {
105 Self::new(c.re as f32, c.im as f32)
106 }
107
108 fn norm_sqr(&self) -> f64 {
109 f64::from(self.re.mul_add(self.re, self.im * self.im))
110 }
111
112 fn scale(&mut self, factor: f64) {
113 *self *= factor as f32;
114 }
115}
116
117#[derive(Debug, Clone, Copy)]
119pub struct ComplexF16 {
120 pub re: f16,
121 pub im: f16,
122}
123
124impl ComplexAmplitude for ComplexF16 {
125 fn to_complex64(&self) -> Complex64 {
126 Complex64::new(self.re.to_f64(), self.im.to_f64())
127 }
128
129 fn from_complex64(c: Complex64) -> Self {
130 Self {
131 re: f16::from_f64(c.re),
132 im: f16::from_f64(c.im),
133 }
134 }
135
136 fn norm_sqr(&self) -> f64 {
137 let r = self.re.to_f64();
138 let i = self.im.to_f64();
139 r.mul_add(r, i * i)
140 }
141
142 fn scale(&mut self, factor: f64) {
143 self.re = f16::from_f64(self.re.to_f64() * factor);
144 self.im = f16::from_f64(self.im.to_f64() * factor);
145 }
146}
147
148pub enum AdaptiveStateVector {
150 Half(Array1<ComplexF16>),
151 Single(Array1<Complex32>),
152 Double(Array1<Complex64>),
153}
154
155impl AdaptiveStateVector {
156 pub fn new(num_qubits: usize, precision: Precision) -> Result<Self> {
158 let size = 1 << num_qubits;
159
160 if num_qubits > 30 {
161 return Err(SimulatorError::InvalidQubits(num_qubits));
162 }
163
164 match precision {
165 Precision::Half => {
166 let mut state = Array1::from_elem(
167 size,
168 ComplexF16 {
169 re: f16::from_f64(0.0),
170 im: f16::from_f64(0.0),
171 },
172 );
173 state[0] = ComplexF16 {
174 re: f16::from_f64(1.0),
175 im: f16::from_f64(0.0),
176 };
177 Ok(Self::Half(state))
178 }
179 Precision::Single => {
180 let mut state = Array1::zeros(size);
181 state[0] = Complex32::new(1.0, 0.0);
182 Ok(Self::Single(state))
183 }
184 Precision::Double => {
185 let mut state = Array1::zeros(size);
186 state[0] = Complex64::new(1.0, 0.0);
187 Ok(Self::Double(state))
188 }
189 Precision::Extended => Err(SimulatorError::InvalidConfiguration(
190 "Extended precision not yet supported".to_string(),
191 )),
192 }
193 }
194
195 #[must_use]
197 pub const fn precision(&self) -> Precision {
198 match self {
199 Self::Half(_) => Precision::Half,
200 Self::Single(_) => Precision::Single,
201 Self::Double(_) => Precision::Double,
202 }
203 }
204
205 #[must_use]
207 pub fn num_qubits(&self) -> usize {
208 let size = match self {
209 Self::Half(v) => v.len(),
210 Self::Single(v) => v.len(),
211 Self::Double(v) => v.len(),
212 };
213 (size as f64).log2() as usize
214 }
215
216 #[must_use]
218 pub fn to_complex64(&self) -> Array1<Complex64> {
219 match self {
220 Self::Half(v) => v.map(ComplexAmplitude::to_complex64),
221 Self::Single(v) => v.map(ComplexAmplitude::to_complex64),
222 Self::Double(v) => v.clone(),
223 }
224 }
225
226 pub fn from_complex64(&mut self, data: &Array1<Complex64>) -> Result<()> {
228 match self {
229 Self::Half(v) => {
230 if v.len() != data.len() {
231 return Err(SimulatorError::DimensionMismatch(format!(
232 "Size mismatch: {} vs {}",
233 v.len(),
234 data.len()
235 )));
236 }
237 for (i, &c) in data.iter().enumerate() {
238 v[i] = ComplexF16::from_complex64(c);
239 }
240 }
241 Self::Single(v) => {
242 if v.len() != data.len() {
243 return Err(SimulatorError::DimensionMismatch(format!(
244 "Size mismatch: {} vs {}",
245 v.len(),
246 data.len()
247 )));
248 }
249 for (i, &c) in data.iter().enumerate() {
250 v[i] = Complex32::from_complex64(c);
251 }
252 }
253 Self::Double(v) => {
254 v.assign(data);
255 }
256 }
257 Ok(())
258 }
259
260 #[must_use]
262 pub fn needs_precision_upgrade(&self, threshold: f64) -> bool {
263 let min_amplitude = match self {
265 Self::Half(v) => v
266 .iter()
267 .map(ComplexAmplitude::norm_sqr)
268 .filter(|&n| n > 0.0)
269 .fold(None, |acc, x| match acc {
270 None => Some(x),
271 Some(y) => Some(if x < y { x } else { y }),
272 }),
273 Self::Single(v) => v
274 .iter()
275 .map(|c| f64::from(c.norm_sqr()))
276 .filter(|&n| n > 0.0)
277 .fold(None, |acc, x| match acc {
278 None => Some(x),
279 Some(y) => Some(if x < y { x } else { y }),
280 }),
281 Self::Double(v) => v
282 .iter()
283 .map(scirs2_core::Complex::norm_sqr)
284 .filter(|&n| n > 0.0)
285 .fold(None, |acc, x| match acc {
286 None => Some(x),
287 Some(y) => Some(if x < y { x } else { y }),
288 }),
289 };
290
291 if let Some(min_amp) = min_amplitude {
292 min_amp < threshold * self.precision().epsilon()
293 } else {
294 false
295 }
296 }
297
298 pub fn upgrade_precision(&mut self) -> Result<()> {
300 let new_precision = match self.precision() {
301 Precision::Half => Precision::Single,
302 Precision::Single => Precision::Double,
303 Precision::Double => return Ok(()), Precision::Extended => unreachable!(),
305 };
306
307 let data = self.to_complex64();
308 *self = Self::new(self.num_qubits(), new_precision)?;
309 self.from_complex64(&data)?;
310
311 Ok(())
312 }
313
314 pub fn downgrade_precision(&mut self, tolerance: f64) -> Result<()> {
316 let new_precision = match self.precision() {
317 Precision::Half => return Ok(()), Precision::Single => Precision::Half,
319 Precision::Double => Precision::Single,
320 Precision::Extended => Precision::Double,
321 };
322
323 let data = self.to_complex64();
325 let test_vec = Self::new(self.num_qubits(), new_precision)?;
326
327 let mut max_error: f64 = 0.0;
329 match &test_vec {
330 Self::Half(_) => {
331 for &c in &data {
332 let converted = ComplexF16::from_complex64(c).to_complex64();
333 let error = (c - converted).norm();
334 max_error = max_error.max(error);
335 }
336 }
337 Self::Single(_) => {
338 for &c in &data {
339 let converted = Complex32::from_complex64(c).to_complex64();
340 let error = (c - converted).norm();
341 max_error = max_error.max(error);
342 }
343 }
344 Self::Double(_) => unreachable!(),
345 }
346
347 if max_error < tolerance {
348 *self = test_vec;
349 self.from_complex64(&data)?;
350 }
351
352 Ok(())
353 }
354
355 #[must_use]
357 pub fn memory_usage(&self) -> usize {
358 let elements = match self {
359 Self::Half(v) => v.len(),
360 Self::Single(v) => v.len(),
361 Self::Double(v) => v.len(),
362 };
363 elements * self.precision().bytes_per_complex()
364 }
365}
366
367#[derive(Debug, Clone)]
369pub struct AdaptivePrecisionConfig {
370 pub initial_precision: Precision,
372 pub error_tolerance: f64,
374 pub check_interval: usize,
376 pub auto_upgrade: bool,
378 pub auto_downgrade: bool,
380 pub min_amplitude: f64,
382}
383
384impl Default for AdaptivePrecisionConfig {
385 fn default() -> Self {
386 Self {
387 initial_precision: Precision::Single,
388 error_tolerance: 1e-10,
389 check_interval: 100,
390 auto_upgrade: true,
391 auto_downgrade: true,
392 min_amplitude: 1e-12,
393 }
394 }
395}
396
397#[derive(Debug)]
399pub struct PrecisionTracker {
400 changes: Vec<(usize, Precision, Precision)>, gate_count: usize,
404 config: AdaptivePrecisionConfig,
406}
407
408impl PrecisionTracker {
409 #[must_use]
411 pub const fn new(config: AdaptivePrecisionConfig) -> Self {
412 Self {
413 changes: Vec::new(),
414 gate_count: 0,
415 config,
416 }
417 }
418
419 pub const fn record_gate(&mut self) {
421 self.gate_count += 1;
422 }
423
424 #[must_use]
426 pub const fn should_check_precision(&self) -> bool {
427 self.gate_count % self.config.check_interval == 0
428 }
429
430 pub fn record_change(&mut self, from: Precision, to: Precision) {
432 self.changes.push((self.gate_count, from, to));
433 }
434
435 #[must_use]
437 pub fn history(&self) -> &[(usize, Precision, Precision)] {
438 &self.changes
439 }
440
441 #[must_use]
443 pub fn stats(&self) -> PrecisionStats {
444 let mut upgrades = 0;
445 let mut downgrades = 0;
446
447 for (_, from, to) in &self.changes {
448 match (from, to) {
449 (Precision::Half, Precision::Single)
450 | (Precision::Single, Precision::Double)
451 | (Precision::Double, Precision::Extended) => upgrades += 1,
452 _ => downgrades += 1,
453 }
454 }
455
456 PrecisionStats {
457 total_gates: self.gate_count,
458 precision_changes: self.changes.len(),
459 upgrades,
460 downgrades,
461 }
462 }
463}
464
/// Summary counters produced by `PrecisionTracker::stats`.
#[derive(Debug)]
pub struct PrecisionStats {
    /// Number of gates recorded.
    pub total_gates: usize,
    /// Number of precision changes recorded.
    pub precision_changes: usize,
    /// Number of changes counted as upgrades.
    pub upgrades: usize,
    /// Number of changes counted as downgrades.
    pub downgrades: usize,
}

impl fmt::Display for PrecisionStats {
    /// Writes all four counters as a single line of text.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            total_gates,
            precision_changes,
            upgrades,
            downgrades,
        } = self;
        write!(
            f,
            "Precision Stats: {} gates, {} changes ({} upgrades, {} downgrades)",
            total_gates, precision_changes, upgrades, downgrades
        )
    }
}
483
484pub fn benchmark_precisions(num_qubits: usize) -> Result<()> {
486 println!("\nPrecision Benchmark for {num_qubits} qubits:");
487 println!("{:-<60}", "");
488
489 for precision in [Precision::Half, Precision::Single, Precision::Double] {
490 let state = AdaptiveStateVector::new(num_qubits, precision)?;
491 let memory = state.memory_usage();
492 let memory_mb = memory as f64 / (1024.0 * 1024.0);
493
494 println!(
495 "{:?} precision: {:.2} MB ({} bytes per amplitude)",
496 precision,
497 memory_mb,
498 precision.bytes_per_complex()
499 );
500 }
501
502 Ok(())
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-amplitude byte sizes of the three storable precisions.
    #[test]
    fn test_precision_levels() {
        assert_eq!(Precision::Half.bytes_per_complex(), 4);
        assert_eq!(Precision::Single.bytes_per_complex(), 8);
        assert_eq!(Precision::Double.bytes_per_complex(), 16);
    }

    /// Tolerances map to the coarsest level whose epsilon they cover.
    #[test]
    fn test_precision_from_tolerance() {
        assert_eq!(Precision::from_tolerance(0.01), Precision::Half);
        assert_eq!(Precision::from_tolerance(1e-8), Precision::Double);
        assert_eq!(Precision::from_tolerance(1e-16), Precision::Extended);
    }

    /// 0.5 survives the f16 round trip within a loose tolerance.
    #[test]
    fn test_complex_f16() {
        let c = ComplexF16 {
            re: f16::from_f64(0.5),
            im: f16::from_f64(0.5),
        };

        let c64 = c.to_complex64();
        assert!((c64.re - 0.5).abs() < 0.01);
        assert!((c64.im - 0.5).abs() < 0.01);
    }

    /// A new two-qubit state has four amplitudes with 1 at index 0.
    #[test]
    fn test_adaptive_state_vector() {
        // The state is only read here, so the binding does not need `mut`.
        let state = AdaptiveStateVector::new(2, Precision::Single)
            .expect("Failed to create adaptive state vector");
        assert_eq!(state.precision(), Precision::Single);
        assert_eq!(state.num_qubits(), 2);

        let c64 = state.to_complex64();
        assert_eq!(c64.len(), 4);
        assert_eq!(c64[0], Complex64::new(1.0, 0.0));
    }

    /// Upgrading a half-precision state yields single precision.
    #[test]
    fn test_precision_upgrade() {
        let mut state = AdaptiveStateVector::new(2, Precision::Half)
            .expect("Failed to create half precision state");
        state
            .upgrade_precision()
            .expect("Failed to upgrade precision");
        assert_eq!(state.precision(), Precision::Single);
    }

    /// With the default interval of 100, a check is due after 100 gates, and
    /// a Single -> Double change is reported as one upgrade.
    #[test]
    fn test_precision_tracker() {
        let config = AdaptivePrecisionConfig::default();
        let mut tracker = PrecisionTracker::new(config);

        for _ in 0..100 {
            tracker.record_gate();
        }

        assert!(tracker.should_check_precision());

        tracker.record_change(Precision::Single, Precision::Double);
        let stats = tracker.stats();
        assert_eq!(stats.upgrades, 1);
        assert_eq!(stats.downgrades, 0);
    }

    /// Ten qubits at half precision: 1024 amplitudes of 4 bytes each.
    #[test]
    fn test_memory_usage() {
        let state = AdaptiveStateVector::new(10, Precision::Half)
            .expect("Failed to create state for memory test");
        let memory = state.memory_usage();
        assert_eq!(memory, 1024 * 4);
    }
}