quantrs2_tytan/
quantum_inspired_ml.rs

1//! Quantum-inspired machine learning algorithms.
2//!
3//! This module provides quantum-inspired ML algorithms that leverage
4//! quantum optimization principles for classical machine learning tasks.
5
6#![allow(dead_code)]
7
8use crate::sampler::Sampler;
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::prelude::*;
12use std::collections::HashMap;
13
14/// Quantum-Inspired Support Vector Machine
15pub struct QuantumSVM {
16    /// Kernel type
17    kernel: KernelType,
18    /// Regularization parameter
19    c: f64,
20    /// Kernel parameters
21    kernel_params: KernelParams,
22    /// Support vectors
23    support_vectors: Option<Array2<f64>>,
24    /// Alphas (Lagrange multipliers)
25    alphas: Option<Array1<f64>>,
26    /// Bias term
27    bias: Option<f64>,
28    /// Labels of support vectors
29    sv_labels: Option<Array1<f64>>,
30}
31
32#[derive(Debug, Clone)]
33pub enum KernelType {
34    /// Linear kernel: K(x, y) = x^T y
35    Linear,
36    /// RBF kernel: K(x, y) = exp(-gamma ||x - y||^2)
37    RBF { gamma: f64 },
38    /// Polynomial kernel: K(x, y) = (x^T y + c)^d
39    Polynomial { degree: usize, coef0: f64 },
40    /// Quantum kernel: K(x, y) = |<φ(x)|φ(y)>|^2
41    Quantum { feature_map: FeatureMap },
42}
43
/// Solver settings carried by a [`QuantumSVM`].
///
/// Within this module the values are only stored; the QUBO-based training
/// path does not read them.
#[derive(Clone, Debug)]
pub struct KernelParams {
    /// Size budget for a kernel-matrix cache.
    cache_size: usize,
    /// Convergence tolerance.
    tolerance: f64,
    /// Upper bound on solver iterations.
    max_iter: usize,
}
53
/// Feature encodings understood by the quantum-inspired kernel.
#[derive(Clone, Debug)]
pub enum FeatureMap {
    /// Pauli-Z style encoding repeated `depth` times.
    PauliZ {
        /// Number of encoding layers.
        depth: usize,
    },
    /// Pauli-ZZ style encoding with pairwise (entangling) terms.
    PauliZZ {
        /// Number of encoding layers.
        depth: usize,
        /// Label describing the entanglement pattern; the built-in simulation ignores it.
        entanglement: String,
    },
    /// A user-named encoding; the built-in simulation falls back to a linear kernel.
    Custom {
        /// Identifier of the custom map.
        name: String,
    },
}
63
64impl QuantumSVM {
65    /// Create new Quantum SVM
66    pub const fn new(kernel: KernelType, c: f64) -> Self {
67        Self {
68            kernel,
69            c,
70            kernel_params: KernelParams {
71                cache_size: 200,
72                tolerance: 1e-3,
73                max_iter: 1000,
74            },
75            support_vectors: None,
76            alphas: None,
77            bias: None,
78            sv_labels: None,
79        }
80    }
81
82    /// Train the SVM using quantum optimization
83    pub fn fit(
84        &mut self,
85        x: &Array2<f64>,
86        y: &Array1<f64>,
87        sampler: &dyn Sampler,
88    ) -> Result<(), String> {
89        let n_samples = x.shape()[0];
90
91        // Compute kernel matrix
92        let k_matrix = self.compute_kernel_matrix(x)?;
93
94        // Formulate as QUBO for alpha optimization
95        let (qubo, var_map) = self.create_svm_qubo(&k_matrix, y)?;
96
97        // Solve using quantum sampler
98        let results = sampler
99            .run_qubo(&(qubo, var_map.clone()), 100)
100            .map_err(|e| format!("Sampling error: {e:?}"))?;
101
102        if let Some(best) = results.first() {
103            // Extract alphas from solution
104            let alphas = self.decode_alphas(&best.assignments, &var_map, n_samples);
105
106            // Identify support vectors
107            let sv_indices: Vec<usize> = alphas
108                .iter()
109                .enumerate()
110                .filter(|(_, &alpha)| alpha > 1e-5)
111                .map(|(i, _)| i)
112                .collect();
113
114            if sv_indices.is_empty() {
115                return Err("No support vectors found".to_string());
116            }
117
118            // Store support vectors and alphas
119            let mut support_vectors = Array2::zeros((sv_indices.len(), x.shape()[1]));
120            let mut sv_alphas = Array1::zeros(sv_indices.len());
121            let mut sv_labels = Array1::zeros(sv_indices.len());
122
123            for (i, &idx) in sv_indices.iter().enumerate() {
124                support_vectors.row_mut(i).assign(&x.row(idx));
125                sv_alphas[i] = alphas[idx];
126                sv_labels[i] = y[idx];
127            }
128
129            self.support_vectors = Some(support_vectors);
130            self.alphas = Some(sv_alphas);
131            self.sv_labels = Some(sv_labels);
132
133            // Calculate bias
134            self.bias = Some(self.calculate_bias(x, y, &alphas)?);
135        }
136
137        Ok(())
138    }
139
140    /// Predict labels for new data
141    pub fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>, String> {
142        let support_vectors = self.support_vectors.as_ref().ok_or("Model not trained")?;
143        let alphas = self.alphas.as_ref().ok_or("Model not trained")?;
144        let sv_labels = self.sv_labels.as_ref().ok_or("Model not trained")?;
145        let bias = self.bias.ok_or("Model not trained")?;
146
147        let n_samples = x.shape()[0];
148        let mut predictions = Array1::zeros(n_samples);
149
150        for i in 0..n_samples {
151            let mut decision = bias;
152
153            for j in 0..support_vectors.shape()[0] {
154                let kernel_val = self.kernel_function(&x.row(i), &support_vectors.row(j))?;
155                decision += alphas[j] * sv_labels[j] * kernel_val;
156            }
157
158            predictions[i] = if decision >= 0.0 { 1.0 } else { -1.0 };
159        }
160
161        Ok(predictions)
162    }
163
164    /// Compute kernel matrix
165    fn compute_kernel_matrix(&self, x: &Array2<f64>) -> Result<Array2<f64>, String> {
166        let n = x.shape()[0];
167        let mut k_matrix = Array2::zeros((n, n));
168
169        for i in 0..n {
170            for j in i..n {
171                let k_val = self.kernel_function(&x.row(i), &x.row(j))?;
172                k_matrix[[i, j]] = k_val;
173                k_matrix[[j, i]] = k_val;
174            }
175        }
176
177        Ok(k_matrix)
178    }
179
180    /// Kernel function evaluation
181    fn kernel_function(&self, x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> Result<f64, String> {
182        match &self.kernel {
183            KernelType::Linear => Ok(x.dot(y)),
184            KernelType::RBF { gamma } => {
185                let diff = x - y;
186                Ok((-gamma * diff.dot(&diff)).exp())
187            }
188            KernelType::Polynomial { degree, coef0 } => Ok((x.dot(y) + coef0).powi(*degree as i32)),
189            KernelType::Quantum { feature_map } => {
190                // Simulate quantum kernel
191                self.quantum_kernel(x, y, feature_map)
192            }
193        }
194    }
195
196    /// Quantum kernel computation
197    fn quantum_kernel(
198        &self,
199        x: &ArrayView1<f64>,
200        y: &ArrayView1<f64>,
201        feature_map: &FeatureMap,
202    ) -> Result<f64, String> {
203        // Simplified quantum kernel simulation
204        match feature_map {
205            FeatureMap::PauliZ { depth } => {
206                // Simulate Pauli-Z feature map
207                let mut kernel = 1.0;
208                for _ in 0..*depth {
209                    let phase_x: f64 = x.iter().sum();
210                    let phase_y: f64 = y.iter().sum();
211                    kernel *= (phase_x - phase_y).cos();
212                }
213                Ok(kernel * kernel) // |<φ(x)|φ(y)>|^2
214            }
215            FeatureMap::PauliZZ { depth, .. } => {
216                // Simulate Pauli-ZZ feature map with entanglement
217                let mut kernel = 1.0;
218                for d in 0..*depth {
219                    for i in 0..x.len() - 1 {
220                        let phase = (x[i] - y[i]) * (x[i + 1] - y[i + 1]);
221                        kernel *= (phase * (d + 1) as f64).cos();
222                    }
223                }
224                Ok(kernel * kernel)
225            }
226            FeatureMap::Custom { .. } => {
227                // Placeholder for custom feature maps
228                Ok(x.dot(y))
229            }
230        }
231    }
232
233    /// Create QUBO for SVM optimization
234    fn create_svm_qubo(
235        &self,
236        k_matrix: &Array2<f64>,
237        y: &Array1<f64>,
238    ) -> Result<(Array2<f64>, HashMap<String, usize>), String> {
239        let n = k_matrix.shape()[0];
240        let n_bits = 5; // Bits per alpha variable
241        let total_vars = n * n_bits;
242
243        let mut qubo = Array2::zeros((total_vars, total_vars));
244        let mut var_map = HashMap::new();
245
246        // Create variable mapping
247        for i in 0..n {
248            for b in 0..n_bits {
249                let var_name = format!("alpha_{i}_{b}");
250                var_map.insert(var_name, i * n_bits + b);
251            }
252        }
253
254        // Objective: maximize sum(alpha_i) - 0.5 * sum(alpha_i * alpha_j * y_i * y_j * K_ij)
255        // Convert to minimization and binary variables
256
257        // Linear terms (maximize sum becomes minimize negative sum)
258        for i in 0..n {
259            for b in 0..n_bits {
260                let idx = i * n_bits + b;
261                let weight = -(1 << b) as f64 / (1 << n_bits) as f64;
262                qubo[[idx, idx]] += weight;
263            }
264        }
265
266        // Quadratic terms
267        for i in 0..n {
268            for j in 0..n {
269                let coef = 0.5 * y[i] * y[j] * k_matrix[[i, j]];
270
271                for bi in 0..n_bits {
272                    for bj in 0..n_bits {
273                        let idx_i = i * n_bits + bi;
274                        let idx_j = j * n_bits + bj;
275
276                        let weight = coef * (1 << bi) as f64 * (1 << bj) as f64
277                            / ((1 << n_bits) * (1 << n_bits)) as f64;
278
279                        if idx_i == idx_j {
280                            qubo[[idx_i, idx_j]] += weight;
281                        } else {
282                            qubo[[idx_i, idx_j]] += weight / 2.0;
283                            qubo[[idx_j, idx_i]] += weight / 2.0;
284                        }
285                    }
286                }
287            }
288        }
289
290        // Constraints: 0 <= alpha_i <= C
291        let penalty = 100.0 * self.c;
292        for i in 0..n {
293            // Add penalty for exceeding C
294            let alpha_max = (1 << n_bits) - 1;
295            if alpha_max as f64 > self.c {
296                // Add quadratic penalty
297                for b1 in 0..n_bits {
298                    for b2 in b1..n_bits {
299                        if (1 << b1) + (1 << b2) > self.c as usize {
300                            let idx1 = i * n_bits + b1;
301                            let idx2 = i * n_bits + b2;
302
303                            if idx1 == idx2 {
304                                qubo[[idx1, idx1]] += penalty;
305                            } else {
306                                qubo[[idx1, idx2]] += penalty;
307                                qubo[[idx2, idx1]] += penalty;
308                            }
309                        }
310                    }
311                }
312            }
313        }
314
315        Ok((qubo, var_map))
316    }
317
318    /// Decode alpha values from binary solution
319    fn decode_alphas(
320        &self,
321        assignments: &HashMap<String, bool>,
322        var_map: &HashMap<String, usize>,
323        n_samples: usize,
324    ) -> Array1<f64> {
325        let n_bits = 5;
326        let mut alphas = Array1::zeros(n_samples);
327
328        for i in 0..n_samples {
329            let mut alpha = 0.0;
330            for b in 0..n_bits {
331                let var_name = format!("alpha_{i}_{b}");
332                if let Some(&_var_idx) = var_map.get(&var_name) {
333                    if assignments.get(&var_name).copied().unwrap_or(false) {
334                        alpha += (1 << b) as f64 / (1 << n_bits) as f64 * self.c;
335                    }
336                }
337            }
338            alphas[i] = alpha;
339        }
340
341        alphas
342    }
343
344    /// Calculate bias term
345    fn calculate_bias(
346        &self,
347        x: &Array2<f64>,
348        y: &Array1<f64>,
349        alphas: &Array1<f64>,
350    ) -> Result<f64, String> {
351        // Use first support vector to calculate bias
352        for i in 0..x.shape()[0] {
353            if alphas[i] > 1e-5 && alphas[i] < self.c - 1e-5 {
354                let mut sum = 0.0;
355                for j in 0..x.shape()[0] {
356                    if alphas[j] > 1e-5 {
357                        let k_val = self.kernel_function(&x.row(i), &x.row(j))?;
358                        sum += alphas[j] * y[j] * k_val;
359                    }
360                }
361                return Ok(y[i] - sum);
362            }
363        }
364
365        Ok(0.0)
366    }
367}
368
369/// Quantum Boltzmann Machine for generative modeling
370pub struct QuantumBoltzmannMachine {
371    /// Number of visible units
372    n_visible: usize,
373    /// Number of hidden units
374    n_hidden: usize,
375    /// Weights between visible and hidden
376    weights: Array2<f64>,
377    /// Visible bias
378    visible_bias: Array1<f64>,
379    /// Hidden bias
380    hidden_bias: Array1<f64>,
381    /// Learning rate
382    learning_rate: f64,
383    /// Temperature parameter
384    temperature: f64,
385}
386
387impl QuantumBoltzmannMachine {
388    /// Create new QBM
389    pub fn new(n_visible: usize, n_hidden: usize) -> Self {
390        let mut rng = thread_rng();
391
392        Self {
393            n_visible,
394            n_hidden,
395            weights: {
396                let mut weights = Array2::zeros((n_visible, n_hidden));
397                for element in &mut weights {
398                    *element = rng.gen_range(-0.01..0.01);
399                }
400                weights
401            },
402            visible_bias: Array1::zeros(n_visible),
403            hidden_bias: Array1::zeros(n_hidden),
404            learning_rate: 0.01,
405            temperature: 1.0,
406        }
407    }
408
409    /// Train using quantum sampling
410    pub fn train(
411        &mut self,
412        data: &Array2<f64>,
413        sampler: &dyn Sampler,
414        epochs: usize,
415    ) -> Result<Vec<f64>, String> {
416        let mut losses = Vec::new();
417        let batch_size = data.shape()[0];
418
419        for epoch in 0..epochs {
420            #[allow(unused_assignments)]
421            let mut epoch_loss = 0.0;
422
423            // Positive phase - from data
424            let pos_hidden = self.sample_hidden_given_visible(&data.view(), sampler)?;
425            let pos_associations = data.t().dot(&pos_hidden);
426
427            // Negative phase - from model
428            let neg_visible = self.sample_visible_given_hidden(&pos_hidden.view(), sampler)?;
429            let neg_hidden = self.sample_hidden_given_visible(&neg_visible.view(), sampler)?;
430            let neg_associations = neg_visible.t().dot(&neg_hidden);
431
432            // Update weights
433            self.weights +=
434                &((pos_associations - neg_associations) * self.learning_rate / batch_size as f64);
435
436            // Update biases
437            let pos_v_mean = data
438                .mean_axis(Axis(0))
439                .ok_or_else(|| "Empty data batch: cannot compute visible mean".to_string())?;
440            let neg_v_mean = neg_visible
441                .mean_axis(Axis(0))
442                .ok_or_else(|| "Empty negative visible batch: cannot compute mean".to_string())?;
443            self.visible_bias += &((pos_v_mean - neg_v_mean) * self.learning_rate);
444
445            let pos_h_mean = pos_hidden
446                .mean_axis(Axis(0))
447                .ok_or_else(|| "Empty positive hidden batch: cannot compute mean".to_string())?;
448            let neg_h_mean = neg_hidden
449                .mean_axis(Axis(0))
450                .ok_or_else(|| "Empty negative hidden batch: cannot compute mean".to_string())?;
451            self.hidden_bias += &((pos_h_mean - neg_h_mean) * self.learning_rate);
452
453            // Calculate reconstruction error
454            let reconstruction_error =
455                ((data - &neg_visible).mapv(|x| x * x)).sum() / batch_size as f64;
456            epoch_loss = reconstruction_error;
457
458            losses.push(epoch_loss);
459
460            if epoch % 10 == 0 {
461                println!("Epoch {epoch}: Loss = {epoch_loss:.4}");
462            }
463        }
464
465        Ok(losses)
466    }
467
468    /// Sample hidden given visible using quantum sampler
469    fn sample_hidden_given_visible(
470        &self,
471        visible: &ArrayView2<f64>,
472        sampler: &dyn Sampler,
473    ) -> Result<Array2<f64>, String> {
474        let batch_size = visible.shape()[0];
475        let mut hidden = Array2::zeros((batch_size, self.n_hidden));
476
477        // Create QUBO for each sample
478        for i in 0..batch_size {
479            let v = visible.row(i);
480
481            // Energy function: -sum(b_j * h_j) - sum(v_i * W_ij * h_j)
482            let mut qubo = Array2::zeros((self.n_hidden, self.n_hidden));
483            let mut var_map = HashMap::new();
484
485            for j in 0..self.n_hidden {
486                var_map.insert(format!("h_{j}"), j);
487
488                // Linear term
489                let linear = self.hidden_bias[j] + v.dot(&self.weights.column(j));
490                qubo[[j, j]] = -linear / self.temperature;
491            }
492
493            // Sample using quantum sampler
494            let results = sampler
495                .run_qubo(&(qubo, var_map), 1)
496                .map_err(|e| format!("Sampling error: {e:?}"))?;
497
498            if let Some(result) = results.first() {
499                for j in 0..self.n_hidden {
500                    let var_name = format!("h_{j}");
501                    hidden[[i, j]] = if result.assignments.get(&var_name).copied().unwrap_or(false)
502                    {
503                        1.0
504                    } else {
505                        0.0
506                    };
507                }
508            }
509        }
510
511        Ok(hidden)
512    }
513
514    /// Sample visible given hidden using quantum sampler
515    fn sample_visible_given_hidden(
516        &self,
517        hidden: &ArrayView2<f64>,
518        sampler: &dyn Sampler,
519    ) -> Result<Array2<f64>, String> {
520        let batch_size = hidden.shape()[0];
521        let mut visible = Array2::zeros((batch_size, self.n_visible));
522
523        // Similar to sample_hidden_given_visible but reversed
524        for i in 0..batch_size {
525            let h = hidden.row(i);
526
527            let mut qubo = Array2::zeros((self.n_visible, self.n_visible));
528            let mut var_map = HashMap::new();
529
530            for j in 0..self.n_visible {
531                var_map.insert(format!("v_{j}"), j);
532
533                // Linear term
534                let linear = self.visible_bias[j] + self.weights.row(j).dot(&h);
535                qubo[[j, j]] = -linear / self.temperature;
536            }
537
538            let results = sampler
539                .run_qubo(&(qubo, var_map), 1)
540                .map_err(|e| format!("Sampling error: {e:?}"))?;
541
542            if let Some(result) = results.first() {
543                for j in 0..self.n_visible {
544                    let var_name = format!("v_{j}");
545                    visible[[i, j]] = if result.assignments.get(&var_name).copied().unwrap_or(false)
546                    {
547                        1.0
548                    } else {
549                        0.0
550                    };
551                }
552            }
553        }
554
555        Ok(visible)
556    }
557
558    /// Generate new samples
559    pub fn generate(&self, n_samples: usize, sampler: &dyn Sampler) -> Result<Array2<f64>, String> {
560        // Start with random hidden state
561        let mut rng = thread_rng();
562        let mut hidden = {
563            let mut hidden = Array2::zeros((n_samples, self.n_hidden));
564            for element in &mut hidden {
565                *element = if rng.gen::<bool>() { 1.0 } else { 0.0 };
566            }
567            hidden
568        };
569
570        // Gibbs sampling
571        for _ in 0..10 {
572            let visible = self.sample_visible_given_hidden(&hidden.view(), sampler)?;
573            hidden = self.sample_hidden_given_visible(&visible.view(), sampler)?;
574        }
575
576        // Final visible sample
577        self.sample_visible_given_hidden(&hidden.view(), sampler)
578    }
579}
580
581/// Quantum-inspired clustering using quantum optimization
582pub struct QuantumClustering {
583    /// Number of clusters
584    n_clusters: usize,
585    /// Distance metric
586    distance_metric: DistanceMetric,
587    /// Regularization for balanced clusters
588    balance_weight: f64,
589}
590
/// Pairwise distance used by [`QuantumClustering`] when building its QUBO.
#[derive(Clone, Debug)]
pub enum DistanceMetric {
    /// Euclidean (L2) distance.
    Euclidean,
    /// Manhattan (L1) distance.
    Manhattan,
    /// Cosine distance: one minus the cosine similarity.
    Cosine,
    /// Fidelity-style distance `sqrt(1 - F^2)`, where `F` is the absolute cosine similarity.
    Quantum,
}
598
599impl QuantumClustering {
600    /// Create new quantum clustering
601    pub const fn new(n_clusters: usize) -> Self {
602        Self {
603            n_clusters,
604            distance_metric: DistanceMetric::Euclidean,
605            balance_weight: 0.1,
606        }
607    }
608
609    /// Set distance metric
610    pub const fn with_distance_metric(mut self, metric: DistanceMetric) -> Self {
611        self.distance_metric = metric;
612        self
613    }
614
615    /// Perform clustering using quantum optimization
616    pub fn fit_predict(
617        &self,
618        data: &Array2<f64>,
619        sampler: &dyn Sampler,
620    ) -> Result<Array1<usize>, String> {
621        let n_samples = data.shape()[0];
622
623        // Compute distance matrix
624        let distances = self.compute_distance_matrix(data)?;
625
626        // Create QUBO for clustering
627        let (qubo, var_map) = self.create_clustering_qubo(&distances)?;
628
629        // Solve using quantum sampler
630        let results = sampler
631            .run_qubo(&(qubo, var_map.clone()), 100)
632            .map_err(|e| format!("Sampling error: {e:?}"))?;
633
634        if let Some(best) = results.first() {
635            // Decode cluster assignments
636            let assignments = self.decode_clusters(&best.assignments, &var_map, n_samples);
637            Ok(assignments)
638        } else {
639            Err("No solution found".to_string())
640        }
641    }
642
643    /// Compute distance matrix
644    fn compute_distance_matrix(&self, data: &Array2<f64>) -> Result<Array2<f64>, String> {
645        let n = data.shape()[0];
646        let mut distances = Array2::zeros((n, n));
647
648        for i in 0..n {
649            for j in i + 1..n {
650                let dist = match &self.distance_metric {
651                    DistanceMetric::Euclidean => {
652                        let diff = &data.row(i) - &data.row(j);
653                        diff.dot(&diff).sqrt()
654                    }
655                    DistanceMetric::Manhattan => {
656                        (&data.row(i) - &data.row(j)).mapv(|x| x.abs()).sum()
657                    }
658                    DistanceMetric::Cosine => {
659                        let dot = data.row(i).dot(&data.row(j));
660                        let norm_i = data.row(i).dot(&data.row(i)).sqrt();
661                        let norm_j = data.row(j).dot(&data.row(j)).sqrt();
662                        1.0 - dot / (norm_i * norm_j)
663                    }
664                    DistanceMetric::Quantum => {
665                        // Quantum-inspired distance
666                        self.quantum_distance(&data.row(i), &data.row(j))
667                    }
668                };
669
670                distances[[i, j]] = dist;
671                distances[[j, i]] = dist;
672            }
673        }
674
675        Ok(distances)
676    }
677
678    /// Quantum-inspired distance metric
679    fn quantum_distance(&self, x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> f64 {
680        // Based on quantum state fidelity
681        let inner_product = x.dot(y);
682        let norm_x = x.dot(x).sqrt();
683        let norm_y = y.dot(y).sqrt();
684
685        let fidelity = (inner_product / (norm_x * norm_y)).abs();
686        fidelity.mul_add(-fidelity, 1.0).sqrt()
687    }
688
689    /// Create QUBO for clustering
690    fn create_clustering_qubo(
691        &self,
692        distances: &Array2<f64>,
693    ) -> Result<(Array2<f64>, HashMap<String, usize>), String> {
694        let n_samples = distances.shape()[0];
695        let n_vars = n_samples * self.n_clusters;
696
697        let mut qubo = Array2::zeros((n_vars, n_vars));
698        let mut var_map = HashMap::new();
699
700        // Variable mapping: x[i,k] = 1 if sample i is in cluster k
701        for i in 0..n_samples {
702            for k in 0..self.n_clusters {
703                let var_name = format!("x_{i}_{k}");
704                var_map.insert(var_name, i * self.n_clusters + k);
705            }
706        }
707
708        // Objective: minimize sum of intra-cluster distances
709        for i in 0..n_samples {
710            for j in i + 1..n_samples {
711                for k in 0..self.n_clusters {
712                    let idx_ik = i * self.n_clusters + k;
713                    let idx_jk = j * self.n_clusters + k;
714
715                    qubo[[idx_ik, idx_jk]] += distances[[i, j]];
716                    qubo[[idx_jk, idx_ik]] += distances[[i, j]];
717                }
718            }
719        }
720
721        // Constraint: each sample in exactly one cluster
722        let penalty = distances.sum() * 10.0;
723        for i in 0..n_samples {
724            // One-hot constraint
725            for k1 in 0..self.n_clusters {
726                let idx1 = i * self.n_clusters + k1;
727
728                // Linear penalty
729                qubo[[idx1, idx1]] -= penalty;
730
731                // Quadratic penalty
732                for k2 in k1 + 1..self.n_clusters {
733                    let idx2 = i * self.n_clusters + k2;
734                    qubo[[idx1, idx2]] += penalty;
735                    qubo[[idx2, idx1]] += penalty;
736                }
737            }
738        }
739
740        // Balance term: encourage equal-sized clusters
741        if self.balance_weight > 0.0 {
742            let target_size = n_samples as f64 / self.n_clusters as f64;
743
744            for k in 0..self.n_clusters {
745                // Penalize deviation from target size
746                for i in 0..n_samples {
747                    for j in i + 1..n_samples {
748                        let idx_ik = i * self.n_clusters + k;
749                        let idx_jk = j * self.n_clusters + k;
750
751                        let weight = self.balance_weight / (target_size * target_size);
752                        qubo[[idx_ik, idx_jk]] += weight;
753                        qubo[[idx_jk, idx_ik]] += weight;
754                    }
755                }
756            }
757        }
758
759        Ok((qubo, var_map))
760    }
761
762    /// Decode cluster assignments
763    fn decode_clusters(
764        &self,
765        assignments: &HashMap<String, bool>,
766        _var_map: &HashMap<String, usize>,
767        n_samples: usize,
768    ) -> Array1<usize> {
769        let mut clusters = Array1::zeros(n_samples);
770
771        for i in 0..n_samples {
772            for k in 0..self.n_clusters {
773                let var_name = format!("x_{i}_{k}");
774                if assignments.get(&var_name).copied().unwrap_or(false) {
775                    clusters[i] = k;
776                    break;
777                }
778            }
779        }
780
781        clusters
782    }
783}
784
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampler::SASampler;
    use quantrs2_anneal::simulator::AnnealingParams;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_quantum_svm() {
        // Simple linearly separable data
        let x = array![[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0],];
        let y = array![-1.0, -1.0, 1.0, 1.0];

        let mut svm = QuantumSVM::new(KernelType::Linear, 1.0);

        // Fast annealing parameters keep the test quick.
        let mut params = AnnealingParams::new();
        params.timeout = Some(10.0); // 10 second timeout
        params.num_sweeps = 100; // Reduce from default 1000
        params.num_repetitions = 2; // Reduce from default 10

        let sampler = SASampler::with_params(Some(42), params);

        svm.fit(&x, &y, &sampler)
            .expect("SVM training should succeed on linearly separable data");

        let predictions = svm
            .predict(&x)
            .expect("SVM prediction should succeed after training");

        // Check that it learned something reasonable
        assert!(svm.support_vectors.is_some());
        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_predict_before_fit_fails() {
        let svm = QuantumSVM::new(KernelType::Linear, 1.0);
        let x = array![[0.0, 1.0]];
        assert!(svm.predict(&x).is_err());
    }

    #[test]
    fn test_linear_kernel_matrix_is_gram_matrix() {
        let svm = QuantumSVM::new(KernelType::Linear, 1.0);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let k = svm
            .compute_kernel_matrix(&x)
            .expect("linear kernel on equal-length rows cannot fail");
        assert_eq!(k, array![[5.0, 11.0], [11.0, 25.0]]);
    }

    #[test]
    fn test_euclidean_distance_matrix() {
        let clustering = QuantumClustering::new(2);
        let data = array![[0.0, 0.0], [3.0, 4.0]];
        let d = clustering
            .compute_distance_matrix(&data)
            .expect("Euclidean distances cannot fail");
        assert_eq!(d, array![[0.0, 5.0], [5.0, 0.0]]);
    }

    #[test]
    #[ignore]
    fn test_quantum_clustering() {
        let data = array![[0.0, 0.0], [0.1, 0.1], [5.0, 5.0], [5.1, 5.1],];

        let clustering = QuantumClustering::new(2).with_distance_metric(DistanceMetric::Euclidean);

        let sampler = SASampler::new(Some(42));
        let labels = clustering
            .fit_predict(&data, &sampler)
            .expect("Clustering should succeed on simple test data");

        // Check that similar points are in same cluster
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);
    }
}