1use crate::prelude::SimulatorError;
8use half::f16;
9use scirs2_core::ndarray::Array1;
10use scirs2_core::{Complex32, Complex64};
11use std::fmt;
12
13use crate::error::Result;
14
/// Numeric precision used to store state-vector amplitudes.
///
/// Variants are listed from least to most precise; `Precision::bytes_per_complex`
/// gives the storage cost of one amplitude at each level.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum Precision {
    /// Amplitudes stored as a pair of `f16` components.
    Half,
    /// Amplitudes stored as a pair of `f32` components.
    Single,
    /// Amplitudes stored as a pair of `f64` components.
    Double,
    /// Wider-than-`f64` precision; `AdaptiveStateVector::new` currently rejects it.
    Extended,
}
27
28impl Precision {
29 pub fn bytes_per_complex(&self) -> usize {
31 match self {
32 Precision::Half => 4, Precision::Single => 8, Precision::Double => 16, Precision::Extended => 32, }
37 }
38
39 pub fn epsilon(&self) -> f64 {
41 match self {
42 Precision::Half => 0.001, Precision::Single => 1e-7, Precision::Double => 1e-15, Precision::Extended => 1e-30, }
47 }
48
49 pub fn from_tolerance(tolerance: f64) -> Self {
51 if tolerance >= 0.001 {
52 Precision::Half
53 } else if tolerance >= 1e-7 {
54 Precision::Single
55 } else if tolerance >= 1e-15 {
56 Precision::Double
57 } else {
58 Precision::Extended
59 }
60 }
61}
62
63pub trait ComplexAmplitude: Clone + Send + Sync {
65 fn to_complex64(&self) -> Complex64;
67
68 fn from_complex64(c: Complex64) -> Self;
70
71 fn norm_sqr(&self) -> f64;
73
74 fn scale(&mut self, factor: f64);
76}
77
78impl ComplexAmplitude for Complex64 {
79 fn to_complex64(&self) -> Complex64 {
80 *self
81 }
82
83 fn from_complex64(c: Complex64) -> Self {
84 c
85 }
86
87 fn norm_sqr(&self) -> f64 {
88 self.norm_sqr()
89 }
90
91 fn scale(&mut self, factor: f64) {
92 *self *= factor;
93 }
94}
95
96impl ComplexAmplitude for Complex32 {
97 fn to_complex64(&self) -> Complex64 {
98 Complex64::new(self.re as f64, self.im as f64)
99 }
100
101 fn from_complex64(c: Complex64) -> Self {
102 Complex32::new(c.re as f32, c.im as f32)
103 }
104
105 fn norm_sqr(&self) -> f64 {
106 (self.re * self.re + self.im * self.im) as f64
107 }
108
109 fn scale(&mut self, factor: f64) {
110 *self *= factor as f32;
111 }
112}
113
114#[derive(Debug, Clone, Copy)]
116pub struct ComplexF16 {
117 pub re: f16,
118 pub im: f16,
119}
120
121impl ComplexAmplitude for ComplexF16 {
122 fn to_complex64(&self) -> Complex64 {
123 Complex64::new(self.re.to_f64(), self.im.to_f64())
124 }
125
126 fn from_complex64(c: Complex64) -> Self {
127 ComplexF16 {
128 re: f16::from_f64(c.re),
129 im: f16::from_f64(c.im),
130 }
131 }
132
133 fn norm_sqr(&self) -> f64 {
134 let r = self.re.to_f64();
135 let i = self.im.to_f64();
136 r * r + i * i
137 }
138
139 fn scale(&mut self, factor: f64) {
140 self.re = f16::from_f64(self.re.to_f64() * factor);
141 self.im = f16::from_f64(self.im.to_f64() * factor);
142 }
143}
144
145pub enum AdaptiveStateVector {
147 Half(Array1<ComplexF16>),
148 Single(Array1<Complex32>),
149 Double(Array1<Complex64>),
150}
151
152impl AdaptiveStateVector {
153 pub fn new(num_qubits: usize, precision: Precision) -> Result<Self> {
155 let size = 1 << num_qubits;
156
157 if num_qubits > 30 {
158 return Err(SimulatorError::InvalidQubits(num_qubits));
159 }
160
161 match precision {
162 Precision::Half => {
163 let mut state = Array1::from_elem(
164 size,
165 ComplexF16 {
166 re: f16::from_f64(0.0),
167 im: f16::from_f64(0.0),
168 },
169 );
170 state[0] = ComplexF16 {
171 re: f16::from_f64(1.0),
172 im: f16::from_f64(0.0),
173 };
174 Ok(AdaptiveStateVector::Half(state))
175 }
176 Precision::Single => {
177 let mut state = Array1::zeros(size);
178 state[0] = Complex32::new(1.0, 0.0);
179 Ok(AdaptiveStateVector::Single(state))
180 }
181 Precision::Double => {
182 let mut state = Array1::zeros(size);
183 state[0] = Complex64::new(1.0, 0.0);
184 Ok(AdaptiveStateVector::Double(state))
185 }
186 Precision::Extended => Err(SimulatorError::InvalidConfiguration(
187 "Extended precision not yet supported".to_string(),
188 )),
189 }
190 }
191
192 pub fn precision(&self) -> Precision {
194 match self {
195 AdaptiveStateVector::Half(_) => Precision::Half,
196 AdaptiveStateVector::Single(_) => Precision::Single,
197 AdaptiveStateVector::Double(_) => Precision::Double,
198 }
199 }
200
201 pub fn num_qubits(&self) -> usize {
203 let size = match self {
204 AdaptiveStateVector::Half(v) => v.len(),
205 AdaptiveStateVector::Single(v) => v.len(),
206 AdaptiveStateVector::Double(v) => v.len(),
207 };
208 (size as f64).log2() as usize
209 }
210
211 pub fn to_complex64(&self) -> Array1<Complex64> {
213 match self {
214 AdaptiveStateVector::Half(v) => v.map(|c| c.to_complex64()),
215 AdaptiveStateVector::Single(v) => v.map(|c| c.to_complex64()),
216 AdaptiveStateVector::Double(v) => v.clone(),
217 }
218 }
219
220 pub fn from_complex64(&mut self, data: &Array1<Complex64>) -> Result<()> {
222 match self {
223 AdaptiveStateVector::Half(v) => {
224 if v.len() != data.len() {
225 return Err(SimulatorError::DimensionMismatch(format!(
226 "Size mismatch: {} vs {}",
227 v.len(),
228 data.len()
229 )));
230 }
231 for (i, &c) in data.iter().enumerate() {
232 v[i] = ComplexF16::from_complex64(c);
233 }
234 }
235 AdaptiveStateVector::Single(v) => {
236 if v.len() != data.len() {
237 return Err(SimulatorError::DimensionMismatch(format!(
238 "Size mismatch: {} vs {}",
239 v.len(),
240 data.len()
241 )));
242 }
243 for (i, &c) in data.iter().enumerate() {
244 v[i] = Complex32::from_complex64(c);
245 }
246 }
247 AdaptiveStateVector::Double(v) => {
248 v.assign(data);
249 }
250 }
251 Ok(())
252 }
253
254 pub fn needs_precision_upgrade(&self, threshold: f64) -> bool {
256 let min_amplitude = match self {
258 AdaptiveStateVector::Half(v) => v
259 .iter()
260 .map(|c| c.norm_sqr())
261 .filter(|&n| n > 0.0)
262 .fold(None, |acc, x| match acc {
263 None => Some(x),
264 Some(y) => Some(if x < y { x } else { y }),
265 }),
266 AdaptiveStateVector::Single(v) => v
267 .iter()
268 .map(|c| c.norm_sqr() as f64)
269 .filter(|&n| n > 0.0)
270 .fold(None, |acc, x| match acc {
271 None => Some(x),
272 Some(y) => Some(if x < y { x } else { y }),
273 }),
274 AdaptiveStateVector::Double(v) => v
275 .iter()
276 .map(|c| c.norm_sqr())
277 .filter(|&n| n > 0.0)
278 .fold(None, |acc, x| match acc {
279 None => Some(x),
280 Some(y) => Some(if x < y { x } else { y }),
281 }),
282 };
283
284 if let Some(min_amp) = min_amplitude {
285 min_amp < threshold * self.precision().epsilon()
286 } else {
287 false
288 }
289 }
290
291 pub fn upgrade_precision(&mut self) -> Result<()> {
293 let new_precision = match self.precision() {
294 Precision::Half => Precision::Single,
295 Precision::Single => Precision::Double,
296 Precision::Double => return Ok(()), Precision::Extended => unreachable!(),
298 };
299
300 let data = self.to_complex64();
301 *self = Self::new(self.num_qubits(), new_precision)?;
302 self.from_complex64(&data)?;
303
304 Ok(())
305 }
306
307 pub fn downgrade_precision(&mut self, tolerance: f64) -> Result<()> {
309 let new_precision = match self.precision() {
310 Precision::Half => return Ok(()), Precision::Single => Precision::Half,
312 Precision::Double => Precision::Single,
313 Precision::Extended => Precision::Double,
314 };
315
316 let data = self.to_complex64();
318 let test_vec = Self::new(self.num_qubits(), new_precision)?;
319
320 let mut max_error: f64 = 0.0;
322 match &test_vec {
323 AdaptiveStateVector::Half(_) => {
324 for &c in data.iter() {
325 let converted = ComplexF16::from_complex64(c).to_complex64();
326 let error = (c - converted).norm();
327 max_error = max_error.max(error);
328 }
329 }
330 AdaptiveStateVector::Single(_) => {
331 for &c in data.iter() {
332 let converted = Complex32::from_complex64(c).to_complex64();
333 let error = (c - converted).norm();
334 max_error = max_error.max(error);
335 }
336 }
337 _ => unreachable!(),
338 }
339
340 if max_error < tolerance {
341 *self = test_vec;
342 self.from_complex64(&data)?;
343 }
344
345 Ok(())
346 }
347
348 pub fn memory_usage(&self) -> usize {
350 let elements = match self {
351 AdaptiveStateVector::Half(v) => v.len(),
352 AdaptiveStateVector::Single(v) => v.len(),
353 AdaptiveStateVector::Double(v) => v.len(),
354 };
355 elements * self.precision().bytes_per_complex()
356 }
357}
358
359#[derive(Debug, Clone)]
361pub struct AdaptivePrecisionConfig {
362 pub initial_precision: Precision,
364 pub error_tolerance: f64,
366 pub check_interval: usize,
368 pub auto_upgrade: bool,
370 pub auto_downgrade: bool,
372 pub min_amplitude: f64,
374}
375
376impl Default for AdaptivePrecisionConfig {
377 fn default() -> Self {
378 Self {
379 initial_precision: Precision::Single,
380 error_tolerance: 1e-10,
381 check_interval: 100,
382 auto_upgrade: true,
383 auto_downgrade: true,
384 min_amplitude: 1e-12,
385 }
386 }
387}
388
389#[derive(Debug)]
391pub struct PrecisionTracker {
392 changes: Vec<(usize, Precision, Precision)>, gate_count: usize,
396 config: AdaptivePrecisionConfig,
398}
399
400impl PrecisionTracker {
401 pub fn new(config: AdaptivePrecisionConfig) -> Self {
403 Self {
404 changes: Vec::new(),
405 gate_count: 0,
406 config,
407 }
408 }
409
410 pub fn record_gate(&mut self) {
412 self.gate_count += 1;
413 }
414
415 pub fn should_check_precision(&self) -> bool {
417 self.gate_count % self.config.check_interval == 0
418 }
419
420 pub fn record_change(&mut self, from: Precision, to: Precision) {
422 self.changes.push((self.gate_count, from, to));
423 }
424
425 pub fn history(&self) -> &[(usize, Precision, Precision)] {
427 &self.changes
428 }
429
430 pub fn stats(&self) -> PrecisionStats {
432 let mut upgrades = 0;
433 let mut downgrades = 0;
434
435 for (_, from, to) in &self.changes {
436 match (from, to) {
437 (Precision::Half, Precision::Single)
438 | (Precision::Single, Precision::Double)
439 | (Precision::Double, Precision::Extended) => upgrades += 1,
440 _ => downgrades += 1,
441 }
442 }
443
444 PrecisionStats {
445 total_gates: self.gate_count,
446 precision_changes: self.changes.len(),
447 upgrades,
448 downgrades,
449 }
450 }
451}
452
/// Snapshot of a `PrecisionTracker` history, as returned by
/// `PrecisionTracker::stats`.
#[derive(Debug)]
pub struct PrecisionStats {
    /// Gates recorded by the tracker.
    pub total_gates: usize,
    /// Number of recorded precision changes.
    pub precision_changes: usize,
    /// Changes that `stats` classified as upgrades.
    pub upgrades: usize,
    /// Changes that `stats` classified as downgrades.
    pub downgrades: usize,
}
461
462impl fmt::Display for PrecisionStats {
463 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
464 write!(
465 f,
466 "Precision Stats: {} gates, {} changes ({} upgrades, {} downgrades)",
467 self.total_gates, self.precision_changes, self.upgrades, self.downgrades
468 )
469 }
470}
471
472pub fn benchmark_precisions(num_qubits: usize) -> Result<()> {
474 println!("\nPrecision Benchmark for {} qubits:", num_qubits);
475 println!("{:-<60}", "");
476
477 for precision in [Precision::Half, Precision::Single, Precision::Double] {
478 let state = AdaptiveStateVector::new(num_qubits, precision)?;
479 let memory = state.memory_usage();
480 let memory_mb = memory as f64 / (1024.0 * 1024.0);
481
482 println!(
483 "{:?} precision: {:.2} MB ({} bytes per amplitude)",
484 precision,
485 memory_mb,
486 precision.bytes_per_complex()
487 );
488 }
489
490 Ok(())
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_precision_levels() {
        assert_eq!(Precision::Half.bytes_per_complex(), 4);
        assert_eq!(Precision::Single.bytes_per_complex(), 8);
        assert_eq!(Precision::Double.bytes_per_complex(), 16);
        assert_eq!(Precision::Extended.bytes_per_complex(), 32);
    }

    #[test]
    fn test_precision_from_tolerance() {
        assert_eq!(Precision::from_tolerance(0.01), Precision::Half);
        assert_eq!(Precision::from_tolerance(1e-5), Precision::Single);
        // 1e-8 is finer than Single's 1e-7 epsilon, so Double is required.
        assert_eq!(Precision::from_tolerance(1e-8), Precision::Double);
        assert_eq!(Precision::from_tolerance(1e-16), Precision::Extended);
    }

    #[test]
    fn test_complex_f16() {
        let c = ComplexF16 {
            re: f16::from_f64(0.5),
            im: f16::from_f64(0.5),
        };

        let c64 = c.to_complex64();
        assert!((c64.re - 0.5).abs() < 0.01);
        assert!((c64.im - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_complex_f16_roundtrip() {
        // 0.25 and -0.5 are exactly representable in f16.
        let c = ComplexF16::from_complex64(Complex64::new(0.25, -0.5));
        assert_eq!(c.to_complex64(), Complex64::new(0.25, -0.5));
    }

    #[test]
    fn test_adaptive_state_vector() {
        let state = AdaptiveStateVector::new(2, Precision::Single).unwrap();
        assert_eq!(state.precision(), Precision::Single);
        assert_eq!(state.num_qubits(), 2);

        let c64 = state.to_complex64();
        assert_eq!(c64.len(), 4);
        assert_eq!(c64[0], Complex64::new(1.0, 0.0));
        assert!(c64.iter().skip(1).all(|&c| c == Complex64::new(0.0, 0.0)));
    }

    #[test]
    fn test_invalid_construction_rejected() {
        assert!(AdaptiveStateVector::new(2, Precision::Extended).is_err());
        assert!(AdaptiveStateVector::new(31, Precision::Single).is_err());
    }

    #[test]
    fn test_from_complex64_size_mismatch() {
        let mut state = AdaptiveStateVector::new(2, Precision::Single).unwrap();
        let data = Array1::from(vec![Complex64::new(1.0, 0.0); 3]);
        assert!(state.from_complex64(&data).is_err());
    }

    #[test]
    fn test_precision_upgrade() {
        let mut state = AdaptiveStateVector::new(2, Precision::Half).unwrap();
        state.upgrade_precision().unwrap();
        assert_eq!(state.precision(), Precision::Single);
        state.upgrade_precision().unwrap();
        assert_eq!(state.precision(), Precision::Double);
        // Already at the highest supported level: no-op.
        state.upgrade_precision().unwrap();
        assert_eq!(state.precision(), Precision::Double);
        assert_eq!(state.to_complex64()[0], Complex64::new(1.0, 0.0));
    }

    #[test]
    fn test_precision_downgrade() {
        let mut state = AdaptiveStateVector::new(2, Precision::Double).unwrap();
        // |0…0⟩ is exactly representable at every level, so the error is zero.
        state.downgrade_precision(1e-3).unwrap();
        assert_eq!(state.precision(), Precision::Single);
        state.downgrade_precision(1e-3).unwrap();
        assert_eq!(state.precision(), Precision::Half);
        assert_eq!(state.to_complex64()[0], Complex64::new(1.0, 0.0));
    }

    #[test]
    fn test_precision_downgrade_rejected_when_lossy() {
        let mut state = AdaptiveStateVector::new(1, Precision::Double).unwrap();
        let data = Array1::from(vec![Complex64::new(0.1, 0.0), Complex64::new(0.0, 0.0)]);
        state.from_complex64(&data).unwrap();
        // 0.1 is not exactly representable in f32, so a zero tolerance rejects it.
        state.downgrade_precision(0.0).unwrap();
        assert_eq!(state.precision(), Precision::Double);
    }

    #[test]
    fn test_needs_precision_upgrade() {
        let mut state = AdaptiveStateVector::new(1, Precision::Single).unwrap();
        // The only non-zero |a|² is 1.0, far above 1.0 * 1e-7.
        assert!(!state.needs_precision_upgrade(1.0));

        let data = Array1::from(vec![Complex64::new(1.0, 0.0), Complex64::new(1e-6, 0.0)]);
        state.from_complex64(&data).unwrap();
        // |1e-6|² ≈ 1e-12 < 1.0 * 1e-7.
        assert!(state.needs_precision_upgrade(1.0));
    }

    #[test]
    fn test_precision_tracker() {
        let config = AdaptivePrecisionConfig::default();
        let mut tracker = PrecisionTracker::new(config);

        for _ in 0..100 {
            tracker.record_gate();
        }
        assert!(tracker.should_check_precision());

        tracker.record_gate();
        assert!(!tracker.should_check_precision());

        tracker.record_change(Precision::Single, Precision::Double);
        assert_eq!(
            tracker.history().to_vec(),
            vec![(101, Precision::Single, Precision::Double)]
        );

        let stats = tracker.stats();
        assert_eq!(stats.total_gates, 101);
        assert_eq!(stats.precision_changes, 1);
        assert_eq!(stats.upgrades, 1);
        assert_eq!(stats.downgrades, 0);

        tracker.record_change(Precision::Double, Precision::Single);
        let stats = tracker.stats();
        assert_eq!(stats.upgrades, 1);
        assert_eq!(stats.downgrades, 1);
    }

    #[test]
    fn test_stats_display() {
        let stats = PrecisionStats {
            total_gates: 10,
            precision_changes: 2,
            upgrades: 1,
            downgrades: 1,
        };
        assert_eq!(
            stats.to_string(),
            "Precision Stats: 10 gates, 2 changes (1 upgrades, 1 downgrades)"
        );
    }

    #[test]
    fn test_memory_usage() {
        let half = AdaptiveStateVector::new(10, Precision::Half).unwrap();
        assert_eq!(half.memory_usage(), 1024 * 4);
        let double = AdaptiveStateVector::new(10, Precision::Double).unwrap();
        assert_eq!(double.memory_usage(), 1024 * 16);
    }
}