1#![allow(dead_code)]
7
8use crate::sampler::{SampleResult, Sampler, SamplerError, SamplerResult};
9use scirs2_core::ndarray::{Array, Array1, Array2, IxDyn};
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::{Distribution, RandNormal, Rng, SeedableRng};
12use scirs2_core::Complex64;
13use std::collections::HashMap;
14
15type Normal<T> = RandNormal<T>;
16use std::f64::consts::PI;
17
/// Software simulator of a coherent Ising machine (CIM).
///
/// Every spin is modelled as a complex amplitude that is integrated in time
/// with additive Gaussian noise; the sign of the real part at the end of the
/// run is reported as the spin value.
#[derive(Debug, Clone)]
pub struct CIMSimulator {
    /// Number of spins (amplitudes) that are simulated.
    pub n_spins: usize,
    /// Pump strength; the linear gain applied to each amplitude is `pump_parameter - 1`.
    pump_parameter: f64,
    /// Detuning; each amplitude is rotated by `-i * detuning` per unit time.
    detuning: f64,
    /// Integration time step.
    dt: f64,
    /// Total simulated time; the step count is `evolution_time / dt`.
    evolution_time: f64,
    /// Scale factor applied to the standard-normal noise samples.
    noise_strength: f64,
    /// Factor every coupling-matrix entry is multiplied by.
    coupling_scale: f64,
    /// Seed for the random number generator, if one was requested.
    seed: Option<u64>,
    /// When `true`, spins couple through the real part of the other amplitudes only.
    use_feedback: bool,
    /// Delay of the feedback signal, in the same time unit as `dt`.
    feedback_delay: f64,
}
42
43impl CIMSimulator {
44 pub const fn new(n_spins: usize) -> Self {
46 Self {
47 n_spins,
48 pump_parameter: 1.0,
49 detuning: 0.0,
50 dt: 0.01,
51 evolution_time: 10.0,
52 noise_strength: 0.1,
53 coupling_scale: 1.0,
54 seed: None,
55 use_feedback: true,
56 feedback_delay: 0.1,
57 }
58 }
59
60 pub const fn with_pump_parameter(mut self, pump: f64) -> Self {
62 self.pump_parameter = pump;
63 self
64 }
65
66 pub const fn with_detuning(mut self, detuning: f64) -> Self {
68 self.detuning = detuning;
69 self
70 }
71
72 pub const fn with_time_step(mut self, dt: f64) -> Self {
74 self.dt = dt;
75 self
76 }
77
78 pub const fn with_evolution_time(mut self, time: f64) -> Self {
80 self.evolution_time = time;
81 self
82 }
83
84 pub const fn with_noise_strength(mut self, noise: f64) -> Self {
86 self.noise_strength = noise;
87 self
88 }
89
90 pub const fn with_coupling_scale(mut self, scale: f64) -> Self {
92 self.coupling_scale = scale;
93 self
94 }
95
96 pub const fn with_seed(mut self, seed: u64) -> Self {
98 self.seed = Some(seed);
99 self
100 }
101
102 pub const fn with_feedback(mut self, use_feedback: bool) -> Self {
104 self.use_feedback = use_feedback;
105 self
106 }
107
108 fn simulate_cim(
110 &self,
111 coupling_matrix: &Array2<f64>,
112 local_fields: &Array1<f64>,
113 rng: &mut StdRng,
114 ) -> Result<Vec<f64>, String> {
115 let n = self.n_spins;
116 let steps = (self.evolution_time / self.dt) as usize;
117
118 let mut amplitudes: Vec<Complex64> = (0..n)
120 .map(|_| {
121 let r = rng.random_range(0.0..0.1);
122 let theta = rng.random_range(0.0..2.0 * PI);
123 Complex64::new(r * theta.cos(), r * theta.sin())
124 })
125 .collect();
126
127 let standard_normal = Normal::<f64>::new(0.0_f64, 1.0_f64)
129 .expect("Normal distribution with mean=0, std=1 is always valid");
130
131 for step in 0..steps {
133 let mut new_amplitudes = amplitudes.clone();
134
135 let noise_scale = self.noise_strength;
137
138 for i in 0..n {
139 let mut coupling_term = Complex64::new(0.0, 0.0);
141 for j in 0..n {
142 if i != j {
143 let coupling = coupling_matrix[[i, j]] * self.coupling_scale;
144
145 if self.use_feedback {
146 let delayed_step =
148 (step as f64 - self.feedback_delay / self.dt).max(0.0) as usize;
149 let delayed_amp = if delayed_step < step {
150 amplitudes[j]
151 } else {
152 amplitudes[j]
153 };
154 coupling_term += coupling * delayed_amp.re;
155 } else {
156 coupling_term += coupling * amplitudes[j];
158 }
159 }
160 }
161
162 coupling_term += local_fields[i];
164
165 let nonlinear_term = amplitudes[i] * amplitudes[i].norm_sqr();
167 let pump_term = self.pump_parameter;
168 let detuning_term = Complex64::new(0.0, -self.detuning) * amplitudes[i];
169
170 let deterministic = (pump_term - 1.0) * amplitudes[i] - nonlinear_term
172 + detuning_term
173 + coupling_term;
174
175 let noise_re = standard_normal.sample(rng) * noise_scale;
178 let noise_im = standard_normal.sample(rng) * noise_scale;
179 let noise = Complex64::new(noise_re, noise_im);
180
181 new_amplitudes[i] =
184 amplitudes[i] + self.dt * deterministic + (self.dt.sqrt()) * noise;
185 }
186
187 amplitudes = new_amplitudes;
188
189 if step % 100 == 0 {
191 self.apply_constraints(&mut amplitudes);
192 }
193 }
194
195 let spins: Vec<f64> = amplitudes.iter().map(|amp| amp.re.signum()).collect();
197
198 Ok(spins)
199 }
200
201 fn apply_constraints(&self, amplitudes: &mut Vec<Complex64>) {
203 let max_amplitude = 2.0;
205 for amp in amplitudes.iter_mut() {
206 if amp.norm() > max_amplitude {
207 *amp = *amp / amp.norm() * max_amplitude;
208 }
209 }
210 }
211
212 fn qubo_to_ising(&self, qubo_matrix: &Array2<f64>) -> (Array2<f64>, Array1<f64>, f64) {
214 let n = qubo_matrix.shape()[0];
215 let mut j_matrix = Array2::zeros((n, n));
216 let mut h_vector = Array1::zeros(n);
217 let mut offset = 0.0;
218
219 for i in 0..n {
221 for j in 0..n {
222 if i == j {
223 h_vector[i] += qubo_matrix[[i, i]];
224 offset += qubo_matrix[[i, i]] / 2.0;
225 } else if i < j {
226 j_matrix[[i, j]] = qubo_matrix[[i, j]] / 4.0;
227 j_matrix[[j, i]] = qubo_matrix[[i, j]] / 4.0;
228 h_vector[i] += qubo_matrix[[i, j]] / 2.0;
229 h_vector[j] += qubo_matrix[[i, j]] / 2.0;
230 offset += qubo_matrix[[i, j]] / 4.0;
231 }
232 }
233 }
234
235 (j_matrix, h_vector, offset)
236 }
237
238 fn spins_to_binary(&self, spins: &[f64]) -> Vec<bool> {
240 spins.iter().map(|&s| s > 0.0).collect()
241 }
242
243 fn calculate_ising_energy(
245 &self,
246 spins: &[f64],
247 j_matrix: &Array2<f64>,
248 h_vector: &Array1<f64>,
249 ) -> f64 {
250 let n = spins.len();
251 let mut energy = 0.0;
252
253 for i in 0..n {
255 for j in i + 1..n {
256 energy += j_matrix[[i, j]] * spins[i] * spins[j];
257 }
258 }
259
260 for i in 0..n {
262 energy += h_vector[i] * spins[i];
263 }
264
265 energy
266 }
267}
268
269impl Sampler for CIMSimulator {
270 fn run_qubo(
271 &self,
272 qubo: &(Array2<f64>, HashMap<String, usize>),
273 shots: usize,
274 ) -> SamplerResult<Vec<SampleResult>> {
275 let (qubo_matrix, var_map) = qubo;
276 let n = qubo_matrix.shape()[0];
277
278 if n != self.n_spins {
279 return Err(SamplerError::InvalidParameter(format!(
280 "CIM configured for {} spins but QUBO has {} variables",
281 self.n_spins, n
282 )));
283 }
284
285 let (j_matrix, h_vector, offset) = self.qubo_to_ising(qubo_matrix);
287
288 let mut rng = match self.seed {
290 Some(seed) => StdRng::seed_from_u64(seed),
291 None => StdRng::seed_from_u64(42), };
293
294 let mut results = Vec::new();
295 let mut solution_counts: HashMap<Vec<bool>, (f64, usize)> = HashMap::new();
296
297 for _ in 0..shots {
299 let spins = self.simulate_cim(&j_matrix, &h_vector, &mut rng)?;
301
302 let binary = self.spins_to_binary(&spins);
304
305 let ising_energy = self.calculate_ising_energy(&spins, &j_matrix, &h_vector);
307 let qubo_energy = ising_energy + offset;
308
309 let entry = solution_counts
311 .entry(binary.clone())
312 .or_insert((qubo_energy, 0));
313 entry.1 += 1;
314 }
315
316 for (binary, (energy, count)) in solution_counts {
318 let assignments: HashMap<String, bool> = var_map
319 .iter()
320 .map(|(var, &idx)| (var.clone(), binary[idx]))
321 .collect();
322
323 results.push(SampleResult {
324 assignments,
325 energy,
326 occurrences: count,
327 });
328 }
329
330 results.sort_by(|a, b| {
332 a.energy
333 .partial_cmp(&b.energy)
334 .unwrap_or(std::cmp::Ordering::Equal)
335 });
336
337 Ok(results)
338 }
339
340 fn run_hobo(
341 &self,
342 hobo: &(Array<f64, IxDyn>, HashMap<String, usize>),
343 shots: usize,
344 ) -> SamplerResult<Vec<SampleResult>> {
345 let (tensor, var_map) = hobo;
346
347 let (qubo, ext_var_map) = crate::sampler::energy::hobo_to_qubo(tensor, var_map)
349 .map_err(SamplerError::InvalidModel)?;
350
351 let n_qubo = qubo.shape()[0];
355 let mut tmp_cim = self.clone();
356 tmp_cim.n_spins = n_qubo;
357
358 let mut results = tmp_cim.run_qubo(&(qubo, ext_var_map), shots)?;
360
361 for result in &mut results {
363 result.assignments.retain(|k, _| !k.starts_with("_aux_"));
364 }
365
366 Ok(results)
367 }
368}
369
370pub struct AdvancedCIM {
372 pub base_cim: CIMSimulator,
374 pulse_shape: PulseShape,
376 error_correction: ErrorCorrectionScheme,
378 pub bifurcation_control: BifurcationControl,
380 pub num_rounds: usize,
382}
383
/// Envelope applied to the pump as a function of time.
#[derive(Debug, Clone, PartialEq)]
pub enum PulseShape {
    /// `amplitude * exp(-t^2 / (2 * width^2))`.
    Gaussian { width: f64, amplitude: f64 },
    /// `amplitude / cosh(t / width)`.
    Sech { width: f64, amplitude: f64 },
    /// Named custom shape with free-form parameters.
    Custom { name: String, parameters: Vec<f64> },
}
393
394#[derive(Debug, Clone)]
395pub enum ErrorCorrectionScheme {
396 None,
398 MajorityVoting { window_size: usize },
400 ParityCheck { check_matrix: Array2<bool> },
402 Stabilizer { generators: Vec<Vec<bool>> },
404}
405
406#[derive(Debug, Clone)]
407pub struct BifurcationControl {
408 pub initial_param: f64,
410 pub final_param: f64,
412 ramp_time: f64,
414 ramp_type: RampType,
416}
417
/// Shape of the pump ramp as a function of progress `p` in `[0, 1]`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RampType {
    /// Fraction `p`.
    Linear,
    /// Fraction `1 - exp(-5 p)`.
    Exponential,
    /// Fraction `1 / (1 + exp(-10 (p - 0.5)))`.
    Sigmoid,
    /// Fraction `p^2`.
    Adaptive,
}
425
426impl AdvancedCIM {
427 pub const fn new(n_spins: usize) -> Self {
429 Self {
430 base_cim: CIMSimulator::new(n_spins),
431 pulse_shape: PulseShape::Gaussian {
432 width: 1.0,
433 amplitude: 1.0,
434 },
435 error_correction: ErrorCorrectionScheme::None,
436 bifurcation_control: BifurcationControl {
437 initial_param: 0.0,
438 final_param: 2.0,
439 ramp_time: 5.0,
440 ramp_type: RampType::Linear,
441 },
442 num_rounds: 1,
443 }
444 }
445
446 pub fn with_pulse_shape(mut self, shape: PulseShape) -> Self {
448 self.pulse_shape = shape;
449 self
450 }
451
452 pub fn with_error_correction(mut self, scheme: ErrorCorrectionScheme) -> Self {
454 self.error_correction = scheme;
455 self
456 }
457
458 pub const fn with_bifurcation_control(mut self, control: BifurcationControl) -> Self {
460 self.bifurcation_control = control;
461 self
462 }
463
464 pub const fn with_num_rounds(mut self, rounds: usize) -> Self {
466 self.num_rounds = rounds;
467 self
468 }
469
470 fn apply_pulse_shaping(&self, t: f64) -> f64 {
472 match &self.pulse_shape {
473 PulseShape::Gaussian { width, amplitude } => {
474 let sigma = width;
475 amplitude * (-t * t / (2.0 * sigma * sigma)).exp()
476 }
477 PulseShape::Sech { width, amplitude } => amplitude / (t / width).cosh(),
478 PulseShape::Custom { .. } => {
479 1.0
481 }
482 }
483 }
484
485 fn apply_error_correction(&self, spins: &mut Vec<f64>, history: &[Vec<f64>]) {
487 match &self.error_correction {
488 ErrorCorrectionScheme::None => {}
489 ErrorCorrectionScheme::MajorityVoting { window_size } => {
490 if history.len() >= *window_size {
491 for i in 0..spins.len() {
492 let mut sum = 0.0;
493 for h in history.iter().rev().take(*window_size) {
494 sum += h[i];
495 }
496 spins[i] = if sum > 0.0 { 1.0 } else { -1.0 };
497 }
498 }
499 }
500 ErrorCorrectionScheme::ParityCheck { check_matrix } => {
501 let n = spins.len();
503 let m = check_matrix.shape()[0];
504
505 for i in 0..m {
506 let mut parity = 0;
507 for j in 0..n {
508 if check_matrix[[i, j]] && spins[j] > 0.0 {
509 parity ^= 1;
510 }
511 }
512 if parity != 0 {
514 for j in 0..n {
517 if check_matrix[[i, j]] {
518 spins[j] *= -1.0;
519 break;
520 }
521 }
522 }
523 }
524 }
525 ErrorCorrectionScheme::Stabilizer { .. } => {
526 }
528 }
529 }
530
531 fn compute_bifurcation_param(&self, t: f64) -> f64 {
533 let progress = (t / self.bifurcation_control.ramp_time).min(1.0);
534 let initial = self.bifurcation_control.initial_param;
535 let final_param = self.bifurcation_control.final_param;
536
537 match self.bifurcation_control.ramp_type {
538 RampType::Linear => (final_param - initial).mul_add(progress, initial),
539 RampType::Exponential => {
540 (final_param - initial).mul_add(1.0 - (-5.0 * progress).exp(), initial)
541 }
542 RampType::Sigmoid => {
543 let x = 10.0 * (progress - 0.5);
544 let sigmoid = 1.0 / (1.0 + (-x).exp());
545 (final_param - initial).mul_add(sigmoid, initial)
546 }
547 RampType::Adaptive => {
548 (final_param - initial).mul_add(progress.powi(2), initial)
550 }
551 }
552 }
553}
554
555impl Sampler for AdvancedCIM {
556 fn run_qubo(
557 &self,
558 qubo: &(Array2<f64>, HashMap<String, usize>),
559 shots: usize,
560 ) -> SamplerResult<Vec<SampleResult>> {
561 let mut all_results = Vec::new();
562 let shots_per_round = shots / self.num_rounds.max(1);
563
564 for round in 0..self.num_rounds {
565 let t = round as f64 * self.base_cim.evolution_time / self.num_rounds as f64;
567 let pump = self.compute_bifurcation_param(t);
568
569 let mut round_cim = self.base_cim.clone();
570 round_cim.pump_parameter = pump * self.apply_pulse_shaping(t);
571
572 let round_results = round_cim.run_qubo(qubo, shots_per_round)?;
574 all_results.extend(round_results);
575 }
576
577 let mut aggregated: HashMap<Vec<bool>, (f64, usize)> = HashMap::new();
579
580 for result in all_results {
581 let state: Vec<bool> = qubo.1.keys().map(|var| result.assignments[var]).collect();
582
583 let entry = aggregated.entry(state).or_insert((result.energy, 0));
584 entry.1 += result.occurrences;
585 }
586
587 let mut final_results: Vec<SampleResult> = aggregated
588 .into_iter()
589 .map(|(state, (energy, count))| {
590 let assignments: HashMap<String, bool> = qubo
591 .1
592 .iter()
593 .zip(state.iter())
594 .map(|((var, _), &val)| (var.clone(), val))
595 .collect();
596
597 SampleResult {
598 assignments,
599 energy,
600 occurrences: count,
601 }
602 })
603 .collect();
604
605 final_results.sort_by(|a, b| {
606 a.energy
607 .partial_cmp(&b.energy)
608 .unwrap_or(std::cmp::Ordering::Equal)
609 });
610
611 Ok(final_results)
612 }
613
614 fn run_hobo(
615 &self,
616 hobo: &(Array<f64, IxDyn>, HashMap<String, usize>),
617 shots: usize,
618 ) -> SamplerResult<Vec<SampleResult>> {
619 self.base_cim.run_hobo(hobo, shots)
620 }
621}
622
623pub struct NetworkedCIM {
625 pub modules: Vec<CIMSimulator>,
627 topology: NetworkTopology,
629 sync_scheme: SynchronizationScheme,
631 comm_delay: f64,
633}
634
635#[derive(Debug, Clone)]
636pub enum NetworkTopology {
637 FullyConnected,
639 Ring,
641 Grid2D { rows: usize, cols: usize },
643 Hierarchical { levels: usize },
645 Custom { adjacency: Array2<bool> },
647}
648
/// Synchronization policy between the modules of a networked CIM.
#[derive(Debug, Clone, PartialEq)]
pub enum SynchronizationScheme {
    /// Fully synchronous operation.
    Synchronous,
    /// Fully asynchronous operation.
    Asynchronous,
    /// Synchronous within blocks of `block_size`.
    BlockSynchronous { block_size: usize },
    /// Event-driven synchronization with the given threshold.
    EventDriven { threshold: f64 },
}
660
661impl NetworkedCIM {
662 pub fn new(num_modules: usize, spins_per_module: usize, topology: NetworkTopology) -> Self {
664 let modules = (0..num_modules)
665 .map(|_| CIMSimulator::new(spins_per_module))
666 .collect();
667
668 Self {
669 modules,
670 topology,
671 sync_scheme: SynchronizationScheme::Synchronous,
672 comm_delay: 0.0,
673 }
674 }
675
676 pub const fn with_sync_scheme(mut self, scheme: SynchronizationScheme) -> Self {
678 self.sync_scheme = scheme;
679 self
680 }
681
682 pub const fn with_comm_delay(mut self, delay: f64) -> Self {
684 self.comm_delay = delay;
685 self
686 }
687
688 pub fn get_neighbors(&self, module_idx: usize) -> Vec<usize> {
690 match &self.topology {
691 NetworkTopology::FullyConnected => (0..self.modules.len())
692 .filter(|&i| i != module_idx)
693 .collect(),
694 NetworkTopology::Ring => {
695 let n = self.modules.len();
696 vec![(module_idx + n - 1) % n, (module_idx + 1) % n]
697 }
698 NetworkTopology::Grid2D { rows, cols } => {
699 let row = module_idx / cols;
700 let col = module_idx % cols;
701 let mut neighbors = Vec::new();
702
703 if row > 0 {
704 neighbors.push((row - 1) * cols + col);
705 }
706 if row < rows - 1 {
707 neighbors.push((row + 1) * cols + col);
708 }
709 if col > 0 {
710 neighbors.push(row * cols + (col - 1));
711 }
712 if col < cols - 1 {
713 neighbors.push(row * cols + (col + 1));
714 }
715
716 neighbors
717 }
718 _ => Vec::new(),
719 }
720 }
721}
722
#[cfg(test)]
mod tests {
    use super::*;

    /// A seeded four-spin run on a small QUBO must return at least one sample.
    #[test]
    fn test_cim_simulator() {
        let cim = CIMSimulator::new(4)
            .with_pump_parameter(1.5)
            .with_evolution_time(5.0)
            .with_seed(42);

        // Only variables 0 and 1 interact.
        let mut qubo_matrix = Array2::<f64>::zeros((4, 4));
        for (row, col) in [(0, 1), (1, 0)] {
            qubo_matrix[[row, col]] = -1.0;
        }

        let var_map: HashMap<String, usize> = (0..4).map(|i| (format!("x{i}"), i)).collect();

        let results = cim
            .run_qubo(&(qubo_matrix, var_map), 10)
            .expect("CIM run_qubo should succeed for valid QUBO input");
        assert!(!results.is_empty());
    }

    /// The builder must store the requested number of rounds.
    #[test]
    fn test_advanced_cim() {
        let shape = PulseShape::Gaussian {
            width: 1.0,
            amplitude: 1.5,
        };
        let cim = AdvancedCIM::new(3)
            .with_pulse_shape(shape)
            .with_num_rounds(2);

        assert_eq!(cim.num_rounds, 2);
    }

    /// A four-module ring must report predecessor and successor of module 0.
    #[test]
    fn test_networked_cim() {
        let net_cim = NetworkedCIM::new(4, 2, NetworkTopology::Ring)
            .with_sync_scheme(SynchronizationScheme::Synchronous);

        assert_eq!(net_cim.modules.len(), 4);
        assert_eq!(net_cim.get_neighbors(0), vec![3, 1]);
    }
}