1use crate::prelude::SimulatorError;
8use half::f16;
9use scirs2_core::ndarray::Array1;
10use scirs2_core::{Complex32, Complex64};
11use std::fmt;
12
13use crate::error::Result;
14
/// Floating-point precision used to store complex state-vector amplitudes.
///
/// Variants are listed from least to most precise.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Precision {
    /// 16-bit real and imaginary parts (4 bytes per amplitude).
    Half,
    /// 32-bit real and imaginary parts (8 bytes per amplitude).
    Single,
    /// 64-bit real and imaginary parts (16 bytes per amplitude).
    Double,
    /// 128-bit real and imaginary parts (32 bytes per amplitude).
    Extended,
}

impl Precision {
    /// Number of bytes one complex amplitude occupies: two components of the scalar width.
    pub const fn bytes_per_complex(&self) -> usize {
        let component_bytes = match self {
            Self::Half => 2,
            Self::Single => 4,
            Self::Double => 8,
            Self::Extended => 16,
        };
        2 * component_bytes
    }

    /// Rounded approximation of the machine epsilon for this precision.
    ///
    /// These are deliberately round numbers, not the exact IEEE-754 epsilons.
    pub const fn epsilon(&self) -> f64 {
        match self {
            Self::Half => 0.001,
            Self::Single => 1e-7,
            Self::Double => 1e-15,
            Self::Extended => 1e-30,
        }
    }

    /// Chooses the least precise level whose [`epsilon`](Self::epsilon) is at most `tolerance`.
    ///
    /// The thresholds are inclusive (`0.001` selects `Half`). Anything below `1e-15`,
    /// including a negative or `NaN` tolerance, falls through to `Extended`.
    pub fn from_tolerance(tolerance: f64) -> Self {
        match tolerance {
            t if t >= 0.001 => Self::Half,
            t if t >= 1e-7 => Self::Single,
            t if t >= 1e-15 => Self::Double,
            _ => Self::Extended,
        }
    }
}
62
63pub trait ComplexAmplitude: Clone + Send + Sync {
65 fn to_complex64(&self) -> Complex64;
67
68 fn from_complex64(c: Complex64) -> Self;
70
71 fn norm_sqr(&self) -> f64;
73
74 fn scale(&mut self, factor: f64);
76}
77
78impl ComplexAmplitude for Complex64 {
79 fn to_complex64(&self) -> Complex64 {
80 *self
81 }
82
83 fn from_complex64(c: Complex64) -> Self {
84 c
85 }
86
87 fn norm_sqr(&self) -> f64 {
88 self.norm_sqr()
89 }
90
91 fn scale(&mut self, factor: f64) {
92 *self *= factor;
93 }
94}
95
96impl ComplexAmplitude for Complex32 {
97 fn to_complex64(&self) -> Complex64 {
98 Complex64::new(self.re as f64, self.im as f64)
99 }
100
101 fn from_complex64(c: Complex64) -> Self {
102 Self::new(c.re as f32, c.im as f32)
103 }
104
105 fn norm_sqr(&self) -> f64 {
106 self.re.mul_add(self.re, self.im * self.im) as f64
107 }
108
109 fn scale(&mut self, factor: f64) {
110 *self *= factor as f32;
111 }
112}
113
114#[derive(Debug, Clone, Copy)]
116pub struct ComplexF16 {
117 pub re: f16,
118 pub im: f16,
119}
120
121impl ComplexAmplitude for ComplexF16 {
122 fn to_complex64(&self) -> Complex64 {
123 Complex64::new(self.re.to_f64(), self.im.to_f64())
124 }
125
126 fn from_complex64(c: Complex64) -> Self {
127 Self {
128 re: f16::from_f64(c.re),
129 im: f16::from_f64(c.im),
130 }
131 }
132
133 fn norm_sqr(&self) -> f64 {
134 let r = self.re.to_f64();
135 let i = self.im.to_f64();
136 r.mul_add(r, i * i)
137 }
138
139 fn scale(&mut self, factor: f64) {
140 self.re = f16::from_f64(self.re.to_f64() * factor);
141 self.im = f16::from_f64(self.im.to_f64() * factor);
142 }
143}
144
145pub enum AdaptiveStateVector {
147 Half(Array1<ComplexF16>),
148 Single(Array1<Complex32>),
149 Double(Array1<Complex64>),
150}
151
152impl AdaptiveStateVector {
153 pub fn new(num_qubits: usize, precision: Precision) -> Result<Self> {
155 let size = 1 << num_qubits;
156
157 if num_qubits > 30 {
158 return Err(SimulatorError::InvalidQubits(num_qubits));
159 }
160
161 match precision {
162 Precision::Half => {
163 let mut state = Array1::from_elem(
164 size,
165 ComplexF16 {
166 re: f16::from_f64(0.0),
167 im: f16::from_f64(0.0),
168 },
169 );
170 state[0] = ComplexF16 {
171 re: f16::from_f64(1.0),
172 im: f16::from_f64(0.0),
173 };
174 Ok(Self::Half(state))
175 }
176 Precision::Single => {
177 let mut state = Array1::zeros(size);
178 state[0] = Complex32::new(1.0, 0.0);
179 Ok(Self::Single(state))
180 }
181 Precision::Double => {
182 let mut state = Array1::zeros(size);
183 state[0] = Complex64::new(1.0, 0.0);
184 Ok(Self::Double(state))
185 }
186 Precision::Extended => Err(SimulatorError::InvalidConfiguration(
187 "Extended precision not yet supported".to_string(),
188 )),
189 }
190 }
191
192 pub const fn precision(&self) -> Precision {
194 match self {
195 Self::Half(_) => Precision::Half,
196 Self::Single(_) => Precision::Single,
197 Self::Double(_) => Precision::Double,
198 }
199 }
200
201 pub fn num_qubits(&self) -> usize {
203 let size = match self {
204 Self::Half(v) => v.len(),
205 Self::Single(v) => v.len(),
206 Self::Double(v) => v.len(),
207 };
208 (size as f64).log2() as usize
209 }
210
211 pub fn to_complex64(&self) -> Array1<Complex64> {
213 match self {
214 Self::Half(v) => v.map(|c| c.to_complex64()),
215 Self::Single(v) => v.map(|c| c.to_complex64()),
216 Self::Double(v) => v.clone(),
217 }
218 }
219
220 pub fn from_complex64(&mut self, data: &Array1<Complex64>) -> Result<()> {
222 match self {
223 Self::Half(v) => {
224 if v.len() != data.len() {
225 return Err(SimulatorError::DimensionMismatch(format!(
226 "Size mismatch: {} vs {}",
227 v.len(),
228 data.len()
229 )));
230 }
231 for (i, &c) in data.iter().enumerate() {
232 v[i] = ComplexF16::from_complex64(c);
233 }
234 }
235 Self::Single(v) => {
236 if v.len() != data.len() {
237 return Err(SimulatorError::DimensionMismatch(format!(
238 "Size mismatch: {} vs {}",
239 v.len(),
240 data.len()
241 )));
242 }
243 for (i, &c) in data.iter().enumerate() {
244 v[i] = Complex32::from_complex64(c);
245 }
246 }
247 Self::Double(v) => {
248 v.assign(data);
249 }
250 }
251 Ok(())
252 }
253
254 pub fn needs_precision_upgrade(&self, threshold: f64) -> bool {
256 let min_amplitude = match self {
258 Self::Half(v) => {
259 v.iter()
260 .map(|c| c.norm_sqr())
261 .filter(|&n| n > 0.0)
262 .fold(None, |acc, x| match acc {
263 None => Some(x),
264 Some(y) => Some(if x < y { x } else { y }),
265 })
266 }
267 Self::Single(v) => v
268 .iter()
269 .map(|c| c.norm_sqr() as f64)
270 .filter(|&n| n > 0.0)
271 .fold(None, |acc, x| match acc {
272 None => Some(x),
273 Some(y) => Some(if x < y { x } else { y }),
274 }),
275 Self::Double(v) => {
276 v.iter()
277 .map(|c| c.norm_sqr())
278 .filter(|&n| n > 0.0)
279 .fold(None, |acc, x| match acc {
280 None => Some(x),
281 Some(y) => Some(if x < y { x } else { y }),
282 })
283 }
284 };
285
286 if let Some(min_amp) = min_amplitude {
287 min_amp < threshold * self.precision().epsilon()
288 } else {
289 false
290 }
291 }
292
293 pub fn upgrade_precision(&mut self) -> Result<()> {
295 let new_precision = match self.precision() {
296 Precision::Half => Precision::Single,
297 Precision::Single => Precision::Double,
298 Precision::Double => return Ok(()), Precision::Extended => unreachable!(),
300 };
301
302 let data = self.to_complex64();
303 *self = Self::new(self.num_qubits(), new_precision)?;
304 self.from_complex64(&data)?;
305
306 Ok(())
307 }
308
309 pub fn downgrade_precision(&mut self, tolerance: f64) -> Result<()> {
311 let new_precision = match self.precision() {
312 Precision::Half => return Ok(()), Precision::Single => Precision::Half,
314 Precision::Double => Precision::Single,
315 Precision::Extended => Precision::Double,
316 };
317
318 let data = self.to_complex64();
320 let test_vec = Self::new(self.num_qubits(), new_precision)?;
321
322 let mut max_error: f64 = 0.0;
324 match &test_vec {
325 Self::Half(_) => {
326 for &c in &data {
327 let converted = ComplexF16::from_complex64(c).to_complex64();
328 let error = (c - converted).norm();
329 max_error = max_error.max(error);
330 }
331 }
332 Self::Single(_) => {
333 for &c in &data {
334 let converted = Complex32::from_complex64(c).to_complex64();
335 let error = (c - converted).norm();
336 max_error = max_error.max(error);
337 }
338 }
339 _ => unreachable!(),
340 }
341
342 if max_error < tolerance {
343 *self = test_vec;
344 self.from_complex64(&data)?;
345 }
346
347 Ok(())
348 }
349
350 pub fn memory_usage(&self) -> usize {
352 let elements = match self {
353 Self::Half(v) => v.len(),
354 Self::Single(v) => v.len(),
355 Self::Double(v) => v.len(),
356 };
357 elements * self.precision().bytes_per_complex()
358 }
359}
360
361#[derive(Debug, Clone)]
363pub struct AdaptivePrecisionConfig {
364 pub initial_precision: Precision,
366 pub error_tolerance: f64,
368 pub check_interval: usize,
370 pub auto_upgrade: bool,
372 pub auto_downgrade: bool,
374 pub min_amplitude: f64,
376}
377
378impl Default for AdaptivePrecisionConfig {
379 fn default() -> Self {
380 Self {
381 initial_precision: Precision::Single,
382 error_tolerance: 1e-10,
383 check_interval: 100,
384 auto_upgrade: true,
385 auto_downgrade: true,
386 min_amplitude: 1e-12,
387 }
388 }
389}
390
391#[derive(Debug)]
393pub struct PrecisionTracker {
394 changes: Vec<(usize, Precision, Precision)>, gate_count: usize,
398 config: AdaptivePrecisionConfig,
400}
401
402impl PrecisionTracker {
403 pub const fn new(config: AdaptivePrecisionConfig) -> Self {
405 Self {
406 changes: Vec::new(),
407 gate_count: 0,
408 config,
409 }
410 }
411
412 pub const fn record_gate(&mut self) {
414 self.gate_count += 1;
415 }
416
417 pub const fn should_check_precision(&self) -> bool {
419 self.gate_count % self.config.check_interval == 0
420 }
421
422 pub fn record_change(&mut self, from: Precision, to: Precision) {
424 self.changes.push((self.gate_count, from, to));
425 }
426
427 pub fn history(&self) -> &[(usize, Precision, Precision)] {
429 &self.changes
430 }
431
432 pub fn stats(&self) -> PrecisionStats {
434 let mut upgrades = 0;
435 let mut downgrades = 0;
436
437 for (_, from, to) in &self.changes {
438 match (from, to) {
439 (Precision::Half, Precision::Single)
440 | (Precision::Single, Precision::Double)
441 | (Precision::Double, Precision::Extended) => upgrades += 1,
442 _ => downgrades += 1,
443 }
444 }
445
446 PrecisionStats {
447 total_gates: self.gate_count,
448 precision_changes: self.changes.len(),
449 upgrades,
450 downgrades,
451 }
452 }
453}
454
/// Summary counters for a simulation run, as produced by `PrecisionTracker::stats`.
#[derive(Debug)]
pub struct PrecisionStats {
    /// Gates recorded by the tracker.
    pub total_gates: usize,
    /// Total number of recorded precision changes.
    pub precision_changes: usize,
    /// Changes classified as upgrades.
    pub upgrades: usize,
    /// Changes classified as downgrades.
    pub downgrades: usize,
}

/// Renders a single human-readable summary line.
impl fmt::Display for PrecisionStats {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            total_gates,
            precision_changes,
            upgrades,
            downgrades,
        } = self;
        write!(
            f,
            "Precision Stats: {} gates, {} changes ({} upgrades, {} downgrades)",
            total_gates, precision_changes, upgrades, downgrades
        )
    }
}
473
474pub fn benchmark_precisions(num_qubits: usize) -> Result<()> {
476 println!("\nPrecision Benchmark for {num_qubits} qubits:");
477 println!("{:-<60}", "");
478
479 for precision in [Precision::Half, Precision::Single, Precision::Double] {
480 let state = AdaptiveStateVector::new(num_qubits, precision)?;
481 let memory = state.memory_usage();
482 let memory_mb = memory as f64 / (1024.0 * 1024.0);
483
484 println!(
485 "{:?} precision: {:.2} MB ({} bytes per amplitude)",
486 precision,
487 memory_mb,
488 precision.bytes_per_complex()
489 );
490 }
491
492 Ok(())
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_precision_levels() {
        assert_eq!(Precision::Half.bytes_per_complex(), 4);
        assert_eq!(Precision::Single.bytes_per_complex(), 8);
        assert_eq!(Precision::Double.bytes_per_complex(), 16);
        assert_eq!(Precision::Extended.bytes_per_complex(), 32);
    }

    #[test]
    fn test_precision_from_tolerance() {
        assert_eq!(Precision::from_tolerance(0.01), Precision::Half);
        assert_eq!(Precision::from_tolerance(1e-5), Precision::Single);
        assert_eq!(Precision::from_tolerance(1e-8), Precision::Double);
        assert_eq!(Precision::from_tolerance(1e-16), Precision::Extended);
    }

    #[test]
    fn test_complex_f16() {
        let c = ComplexF16 {
            re: f16::from_f64(0.5),
            im: f16::from_f64(0.5),
        };

        let c64 = c.to_complex64();
        assert!((c64.re - 0.5).abs() < 0.01);
        assert!((c64.im - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_adaptive_state_vector() {
        let state = AdaptiveStateVector::new(2, Precision::Single).unwrap();
        assert_eq!(state.precision(), Precision::Single);
        assert_eq!(state.num_qubits(), 2);

        // Widening to Complex64 preserves the |00⟩ basis state.
        let c64 = state.to_complex64();
        assert_eq!(c64.len(), 4);
        assert_eq!(c64[0], Complex64::new(1.0, 0.0));
    }

    #[test]
    fn test_new_rejects_invalid_configuration() {
        assert!(AdaptiveStateVector::new(31, Precision::Single).is_err());
        assert!(AdaptiveStateVector::new(2, Precision::Extended).is_err());
    }

    #[test]
    fn test_from_complex64_rejects_size_mismatch() {
        let mut state = AdaptiveStateVector::new(2, Precision::Single).unwrap();
        let wrong = Array1::from_vec(vec![Complex64::new(1.0, 0.0); 3]);
        assert!(state.from_complex64(&wrong).is_err());
    }

    #[test]
    fn test_precision_upgrade() {
        let mut state = AdaptiveStateVector::new(2, Precision::Half).unwrap();
        state.upgrade_precision().unwrap();
        assert_eq!(state.precision(), Precision::Single);
        state.upgrade_precision().unwrap();
        assert_eq!(state.precision(), Precision::Double);
        // Double is the highest stored precision; a further upgrade is a no-op.
        state.upgrade_precision().unwrap();
        assert_eq!(state.precision(), Precision::Double);
        assert_eq!(state.to_complex64()[0], Complex64::new(1.0, 0.0));
    }

    #[test]
    fn test_precision_downgrade_exact_values() {
        // 0 and 1 are exact in every format, so each downgrade must be accepted.
        let mut state = AdaptiveStateVector::new(2, Precision::Double).unwrap();
        state.downgrade_precision(1e-6).unwrap();
        assert_eq!(state.precision(), Precision::Single);
        state.downgrade_precision(1e-6).unwrap();
        assert_eq!(state.precision(), Precision::Half);
        state.downgrade_precision(1e-6).unwrap();
        assert_eq!(state.precision(), Precision::Half);
    }

    #[test]
    fn test_precision_downgrade_rejected_when_lossy() {
        let mut state = AdaptiveStateVector::new(1, Precision::Single).unwrap();
        let data = Array1::from_vec(vec![Complex64::new(0.1, 0.0), Complex64::new(0.0, 0.0)]);
        state.from_complex64(&data).unwrap();
        // f16 cannot hold 0.1 to within 1e-6, so the state must stay Single.
        state.downgrade_precision(1e-6).unwrap();
        assert_eq!(state.precision(), Precision::Single);
    }

    #[test]
    fn test_precision_tracker() {
        let config = AdaptivePrecisionConfig::default();
        let mut tracker = PrecisionTracker::new(config);

        // The default check interval is 100 gates.
        for _ in 0..100 {
            tracker.record_gate();
        }

        assert!(tracker.should_check_precision());

        tracker.record_change(Precision::Single, Precision::Double);
        let stats = tracker.stats();
        assert_eq!(stats.upgrades, 1);
        assert_eq!(stats.downgrades, 0);
    }

    #[test]
    fn test_memory_usage() {
        let state = AdaptiveStateVector::new(10, Precision::Half).unwrap();
        assert_eq!(state.memory_usage(), 1024 * 4);

        let state = AdaptiveStateVector::new(10, Precision::Double).unwrap();
        assert_eq!(state.memory_usage(), 1024 * 16);
    }
}