// sklears_semi_supervised/lib.rs

//! Semi-supervised learning algorithms
//!
//! This crate provides semi-supervised learning algorithms that can utilize
//! both labeled and unlabeled data for training. In this crate's tests,
//! unlabeled samples are marked with the label `-1` in the target array.

// Crate-wide lint relaxations.
//
// NOTE(review): these blanket allows apply to every module in the crate and can
// hide real problems (for example, `unused_comparisons` silences always-true
// checks such as `x >= 0` on unsigned values). They are kept so the crate keeps
// building without new warnings; prefer narrowing each one to the items that
// actually need it.
#![allow(dead_code)]
#![allow(non_snake_case)]
#![allow(missing_docs)]
#![allow(deprecated)]
#![allow(clippy::all)]
#![allow(clippy::pedantic)]
#![allow(clippy::nursery)]
#![allow(unused_imports)]
#![allow(unused_variables)]
#![allow(unused_mut)]
#![allow(unused_assignments)]
#![allow(unused_doc_comments)]
#![allow(unused_parens)]
#![allow(unused_comparisons)]

// TODO: re-enable `#![warn(missing_docs)]` once public items are documented;
// it is currently overridden by `#![allow(missing_docs)]` above.
21
22mod active_learning;
23mod adversarial_graph_learning;
24mod approximate_graph_methods;
25mod batch_active_learning;
26mod bayesian_methods;
27mod co_training;
28mod composable_graph;
29mod contrastive_learning;
30mod convergence_tests;
31mod cross_modal_contrastive;
32mod deep_learning;
33mod democratic_co_learning;
34mod dynamic_graph_learning;
35mod entropy_methods;
36mod few_shot;
37mod graph;
38mod graph_learning;
39mod harmonic_functions;
40mod hierarchical_graph;
41mod information_theory;
42mod label_propagation;
43mod label_spreading;
44mod landmark_methods;
45mod local_global_consistency;
46mod manifold_regularization;
47mod mixture_discriminant_analysis;
48mod multi_armed_bandits;
49mod multi_view_graph;
50mod optimal_transport;
51pub mod parallel_graph;
52mod robust_graph_methods;
53mod self_training;
54mod self_training_classifier;
55mod semi_supervised_gmm;
56mod semi_supervised_naive_bayes;
57pub mod simd_distances;
58mod streaming_graph_learning;
59mod tri_training;
60
61pub use active_learning::*;
62pub use adversarial_graph_learning::*;
63pub use approximate_graph_methods::*;
64pub use batch_active_learning::*;
65pub use bayesian_methods::*;
66pub use co_training::*;
67pub use composable_graph::*;
68pub use contrastive_learning::*;
69pub use convergence_tests::*;
70pub use cross_modal_contrastive::*;
71pub use deep_learning::*;
72pub use democratic_co_learning::*;
73pub use dynamic_graph_learning::*;
74pub use entropy_methods::*;
75pub use few_shot::*;
76pub use graph::*;
77pub use graph_learning::*;
78pub use harmonic_functions::*;
79pub use hierarchical_graph::*;
80pub use information_theory::*;
81pub use label_propagation::*;
82pub use label_spreading::*;
83pub use landmark_methods::*;
84pub use local_global_consistency::*;
85pub use manifold_regularization::*;
86pub use mixture_discriminant_analysis::*;
87pub use multi_armed_bandits::*;
88pub use multi_view_graph::*;
89pub use optimal_transport::*;
90pub use robust_graph_methods::*;
91pub use self_training::*;
92pub use self_training_classifier::*;
93pub use semi_supervised_gmm::*;
94pub use semi_supervised_naive_bayes::*;
95pub use streaming_graph_learning::*;
96pub use tri_training::*;
97
// Smoke and behavioural tests for the estimators and graph utilities that are
// re-exported from the crate root.
//
// Most estimator tests use a tiny dataset in which the first samples are
// labeled (`0` / `1`) and the rest carry `-1`, which marks them as unlabeled.
// `X` is kept upper-case to follow the usual design-matrix naming.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;
    use scirs2_core::ndarray_ext::Array2;
    use sklears_core::traits::{Fit, Predict, PredictProba};

    /// Builds an `n_samples x n_features` matrix of values drawn from
    /// `-1.0..1.0` using a generator seeded with `42`.
    ///
    /// Elements are filled row by row, column by column, so every caller gets
    /// exactly the data produced by the equivalent hand-written nested loops.
    fn seeded_random_features(n_samples: usize, n_features: usize) -> Array2<f64> {
        use scirs2_core::random::Random;

        let mut rng = Random::seed(42);
        let mut features = Array2::<f64>::zeros((n_samples, n_features));
        for i in 0..n_samples {
            for j in 0..n_features {
                features[(i, j)] = rng.random_range(-1.0..1.0);
            }
        }
        features
    }

    #[test]
    fn test_label_propagation() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 indicates unlabeled

        let lp = LabelPropagation::new()
            .kernel("rbf".to_string())
            .gamma(20.0);
        let fitted = lp.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (4, 2));
    }

    #[test]
    fn test_label_spreading() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 indicates unlabeled

        let ls = LabelSpreading::new()
            .kernel("rbf".to_string())
            .gamma(20.0)
            .alpha(0.2);
        let fitted = ls.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (4, 2));
    }

    #[test]
    fn test_self_training() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 indicates unlabeled

        let stc = SelfTrainingClassifier::new().threshold(0.5).max_iter(5);
        let fitted = stc.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);
    }

    // NOTE: The two tests below exercise `LabelPropagation::build_affinity_matrix`,
    // which is private and therefore not callable from this crate-root test
    // module. They should be moved into a `#[cfg(test)]` module inside the
    // module that defines `LabelPropagation` (presumably `label_propagation`).
    //
    // #[test]
    // fn test_affinity_matrix_rbf() {
    //     let lp = LabelPropagation::new().kernel("rbf".to_string()).gamma(1.0);
    //     let X = array![[1.0, 2.0], [3.0, 4.0]];
    //
    //     let W = lp.build_affinity_matrix(&X).unwrap();
    //     assert_eq!(W.dim(), (2, 2));
    //     assert_eq!(W[[0, 0]], 0.0); // Diagonal should be 0
    //     assert_eq!(W[[1, 1]], 0.0);
    //     assert!(W[[0, 1]] > 0.0); // Off-diagonal should be positive
    //     assert!(W[[1, 0]] > 0.0);
    // }
    //
    // #[test]
    // fn test_affinity_matrix_knn() {
    //     let lp = LabelPropagation::new()
    //         .kernel("knn".to_string())
    //         .n_neighbors(1);
    //     let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
    //
    //     let W = lp.build_affinity_matrix(&X).unwrap();
    //     assert_eq!(W.dim(), (3, 3));
    //
    //     // With n_neighbors = 1 and a symmetrised graph, each row has at most 2 edges.
    //     for i in 0..3 {
    //         let non_zero_count = W.row(i).iter().filter(|&&x| x > 0.0).count();
    //         assert!(non_zero_count <= 2);
    //     }
    // }

    #[test]
    fn test_enhanced_self_training() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 indicates unlabeled

        let est = EnhancedSelfTraining::new()
            .threshold(0.6)
            .confidence_method("entropy".to_string())
            .max_iter(5);
        let fitted = est.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_co_training() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0],
            [4.0, 5.0, 6.0, 7.0]
        ];
        let y = array![0, 1, -1, -1]; // -1 indicates unlabeled

        let ct = CoTraining::new()
            .view1_features(vec![0, 1])
            .view2_features(vec![2, 3])
            .p(1)
            .n(1)
            .max_iter(5);
        let fitted = ct.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_tri_training() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 indicates unlabeled

        let tt = TriTraining::new().max_iter(5).theta(0.2);
        let fitted = tt.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_knn_graph() {
        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let W = knn_graph(&X, 1, "connectivity").unwrap();
        assert_eq!(W.dim(), (3, 3));

        // Each row has at most n_neighbors (= 1) non-zero entries.
        for row in W.rows() {
            let non_zero_count = row.iter().filter(|&&x| x > 0.0).count();
            assert!(non_zero_count <= 1);
        }
    }

    #[test]
    fn test_epsilon_graph() {
        let X = array![[1.0, 2.0], [1.1, 2.1], [5.0, 6.0]];

        let W = epsilon_graph(&X, 1.0, "connectivity").unwrap();
        assert_eq!(W.dim(), (3, 3));

        // Points 0 and 1 are ~0.14 apart (< epsilon = 1.0), so they must be connected.
        assert!(W[[0, 1]] > 0.0 || W[[1, 0]] > 0.0);
    }

    #[test]
    fn test_graph_laplacian() {
        let W = array![[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]];

        let L = graph_laplacian(&W, false).unwrap();
        assert_eq!(L.dim(), (3, 3));

        // Unnormalised Laplacian L = D - W; all values here are small integers,
        // so exact float comparison is safe.
        assert_eq!(L[[0, 0]], 1.0); // degree of node 0
        assert_eq!(L[[1, 1]], 2.0); // degree of node 1
        assert_eq!(L[[0, 1]], -1.0); // -adjacency
    }

    #[test]
    fn test_democratic_co_learning() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            [2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            [3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            [4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
        ];
        let y = array![0, 1, -1, -1]; // -1 indicates unlabeled

        let dcl = DemocraticCoLearning::new()
            .views(vec![vec![0, 1], vec![2, 3], vec![4, 5]])
            .k_add(1)
            .min_agreement(2)
            .max_iter(5);
        let fitted = dcl.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_harmonic_functions() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 indicates unlabeled

        let hf = HarmonicFunctions::new()
            .kernel("rbf".to_string())
            .gamma(20.0)
            .max_iter(100);
        let fitted = hf.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (4, 2));
    }

    #[test]
    fn test_local_global_consistency() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 indicates unlabeled

        let lgc = LocalGlobalConsistency::new()
            .kernel("rbf".to_string())
            .gamma(20.0)
            .alpha(0.99)
            .max_iter(100);
        let fitted = lgc.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (4, 2));
    }

    #[test]
    fn test_manifold_regularization() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 indicates unlabeled

        let mr = ManifoldRegularization::new()
            .lambda_a(0.01)
            .lambda_i(0.1)
            .kernel("rbf".to_string())
            .gamma(1.0)
            .max_iter(100);
        let fitted = mr.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_semi_supervised_gmm() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 indicates unlabeled

        let gmm = SemiSupervisedGMM::new()
            .n_components(2)
            .max_iter(50)
            .labeled_weight(10.0);
        let fitted = gmm.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (4, 2));
    }

    #[test]
    fn test_multi_view_co_training() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            [2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            [3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            [4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
            [5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
            [6.0, 7.0, 8.0, 9.0, 10.0, 11.0]
        ];
        let y = array![0, 1, -1, -1, -1, -1]; // -1 indicates unlabeled

        let mvct = MultiViewCoTraining::new()
            .views(vec![vec![0, 1], vec![2, 3], vec![4, 5]])
            .k_add(1)
            .confidence_threshold(0.5)
            .max_iter(5);
        let fitted = mvct.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 6);

        // Labeled samples keep their labels.
        assert_eq!(predictions[0], 0);
        assert_eq!(predictions[1], 1);
    }

    #[test]
    fn test_semi_supervised_naive_bayes() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 indicates unlabeled

        let nb = SemiSupervisedNaiveBayes::new()
            .alpha(1.0)
            .max_iter(50)
            .class_weight(1.0);
        let fitted = nb.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (4, 2));

        // Each row of class probabilities sums to 1.
        for row in probas.rows() {
            let sum: f64 = row.sum();
            assert!((sum - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_random_walk_laplacian() {
        let W = array![[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]];

        let L_rw = random_walk_laplacian(&W).unwrap();
        assert_eq!(L_rw.dim(), (3, 3));

        // L_rw = I - D^-1 W has 1s on the diagonal when W has a zero diagonal.
        for i in 0..3 {
            assert!((L_rw[[i, i]] - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_diffusion_matrix() {
        let W = array![[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]];

        let P = diffusion_matrix(&W, 2).unwrap();
        assert_eq!(P.dim(), (3, 3));

        assert!(
            P.iter().all(|&p| p >= 0.0),
            "Diffusion probabilities should be non-negative"
        );
    }

    #[test]
    fn test_adaptive_knn_graph() {
        let X = array![[1.0, 2.0], [1.1, 2.1], [5.0, 6.0]];

        let W = adaptive_knn_graph(&X, "connectivity").unwrap();
        assert_eq!(W.dim(), (3, 3));

        // The graph is symmetric.
        assert_eq!(W[[0, 1]], W[[1, 0]]);
        assert_eq!(W[[0, 2]], W[[2, 0]]);
        assert_eq!(W[[1, 2]], W[[2, 1]]);
    }

    #[test]
    fn test_sparsify_graph() {
        let W = array![
            [0.0, 0.8, 0.2, 0.1],
            [0.8, 0.0, 0.9, 0.3],
            [0.2, 0.9, 0.0, 0.7],
            [0.1, 0.3, 0.7, 0.0]
        ];

        let W_sparse = sparsify_graph(&W, 0.5).unwrap();
        assert_eq!(W_sparse.dim(), (4, 4));

        let original_edges = W.iter().filter(|&&x| x > 0.0).count();
        let sparse_edges = W_sparse.iter().filter(|&&x| x > 0.0).count();

        // Sparsification must never add edges.
        assert!(sparse_edges <= original_edges);
    }

    #[test]
    fn test_spectral_clustering() {
        let W = array![
            [0.0, 1.0, 0.1, 0.0, 0.0],
            [1.0, 0.0, 0.2, 0.0, 0.0],
            [0.1, 0.2, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0],
            [0.0, 0.0, 0.0, 1.0, 0.0]
        ];

        let labels = spectral_clustering(&W, 2, true, Some(42)).unwrap();
        assert_eq!(labels.len(), 5);

        // Every label is a valid cluster index in 0..n_clusters. A range check
        // avoids the always-true `label >= 0` comparison on unsigned labels.
        for &label in labels.iter() {
            assert!((0..2).contains(&label));
        }
    }

    #[test]
    fn test_spectral_embedding() {
        let W = array![[0.0, 1.0, 0.1], [1.0, 0.0, 0.2], [0.1, 0.2, 0.0]];

        let embedding = spectral_embedding(&W, 2, true).unwrap();
        assert_eq!(embedding.dim(), (3, 2));
    }

    // Robustness tests with label noise

    #[test]
    fn test_label_propagation_robustness() {
        let X = seeded_random_features(50, 5);
        let y_true = array![
            0, 1, 0, 1, 0, 1, 0, 1, 0, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
            -1, -1, -1, -1
        ];

        // Without noise
        let lp = LabelPropagation::new()
            .kernel("rbf".to_string())
            .gamma(20.0);
        let fitted_clean = lp.fit(&X.view(), &y_true.view()).unwrap();
        let pred_clean = fitted_clean.predict(&X.view()).unwrap();

        // With label noise: flip the first and third labels from 0 to 1.
        let mut y_noisy = y_true.clone();
        y_noisy[0] = 1;
        y_noisy[2] = 1;

        let lp_noisy = LabelPropagation::new()
            .kernel("rbf".to_string())
            .gamma(20.0);
        let fitted_noisy = lp_noisy.fit(&X.view(), &y_noisy.view()).unwrap();
        let pred_noisy = fitted_noisy.predict(&X.view()).unwrap();

        // Robustness = fraction of predictions unchanged by the noise.
        let different = pred_clean
            .iter()
            .zip(pred_noisy.iter())
            .filter(|(a, b)| a != b)
            .count();
        let robustness = 1.0 - (different as f64 / pred_clean.len() as f64);

        assert!(
            robustness > 0.6,
            "Label propagation should be somewhat robust to label noise"
        );
    }

    #[test]
    fn test_self_training_robustness() {
        let X = seeded_random_features(30, 4);
        let y_clean = array![
            0, 1, 0, 1, 0, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
            -1, -1, -1, -1, -1, -1, -1
        ];

        // Without noise
        let stc_clean = SelfTrainingClassifier::new().threshold(0.8).max_iter(10);
        let fitted_clean = stc_clean.fit(&X.view(), &y_clean.view()).unwrap();
        let pred_clean = fitted_clean.predict(&X.view()).unwrap();

        // With noise: flip the second label from 1 to 0.
        let mut y_noisy = y_clean.clone();
        y_noisy[1] = 0;

        let stc_noisy = SelfTrainingClassifier::new().threshold(0.8).max_iter(10);
        let fitted_noisy = stc_noisy.fit(&X.view(), &y_noisy.view()).unwrap();
        let pred_noisy = fitted_noisy.predict(&X.view()).unwrap();

        // Both runs still produce labels from the observed class set {0, 1}.
        assert!(
            pred_clean.iter().all(|&p| (0..=1).contains(&p)),
            "Clean predictions should be valid"
        );
        assert!(
            pred_noisy.iter().all(|&p| (0..=1).contains(&p)),
            "Noisy predictions should be valid"
        );
    }

    #[test]
    fn test_co_training_robustness() {
        let X = seeded_random_features(20, 6);
        let y_clean =
            array![0, 1, 0, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1];

        // Without noise
        let ct_clean = CoTraining::new()
            .view1_features(vec![0, 1, 2])
            .view2_features(vec![3, 4, 5])
            .p(1)
            .n(1)
            .max_iter(5);
        let fitted_clean = ct_clean.fit(&X.view(), &y_clean.view()).unwrap();
        let pred_clean = fitted_clean.predict(&X.view()).unwrap();

        // With noise: flip the first label from 0 to 1.
        let mut y_noisy = y_clean.clone();
        y_noisy[0] = 1;

        let ct_noisy = CoTraining::new()
            .view1_features(vec![0, 1, 2])
            .view2_features(vec![3, 4, 5])
            .p(1)
            .n(1)
            .max_iter(5);
        let fitted_noisy = ct_noisy.fit(&X.view(), &y_noisy.view()).unwrap();
        let pred_noisy = fitted_noisy.predict(&X.view()).unwrap();

        assert!(
            pred_clean.iter().all(|&p| (0..=1).contains(&p)),
            "Clean predictions should be valid"
        );
        assert!(
            pred_noisy.iter().all(|&p| (0..=1).contains(&p)),
            "Noisy predictions should be valid"
        );
    }

    // Label efficiency tests

    #[test]
    fn test_label_efficiency_comparison() {
        let X = seeded_random_features(40, 5);

        // Two labeling ratios: 2 labeled samples vs 8 labeled samples.
        let small_labeled = array![
            0, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1
        ];

        let large_labeled = array![
            0, 1, 0, 1, 0, 1, 0, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1
        ];

        let lp_small = LabelPropagation::new()
            .kernel("rbf".to_string())
            .gamma(20.0);
        let fitted_small = lp_small.fit(&X.view(), &small_labeled.view()).unwrap();
        let pred_small = fitted_small.predict(&X.view()).unwrap();

        let lp_large = LabelPropagation::new()
            .kernel("rbf".to_string())
            .gamma(20.0);
        let fitted_large = lp_large.fit(&X.view(), &large_labeled.view()).unwrap();
        let pred_large = fitted_large.predict(&X.view()).unwrap();

        assert!(
            pred_small.iter().all(|&p| (0..=1).contains(&p)),
            "Small labeled predictions should be valid"
        );
        assert!(
            pred_large.iter().all(|&p| (0..=1).contains(&p)),
            "Large labeled predictions should be valid"
        );
    }

    // Convergence tests

    #[test]
    fn test_algorithm_convergence() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];
        let y = array![0, 1, -1, -1, -1];

        // Harmonic functions converge to valid labels.
        let hf = HarmonicFunctions::new()
            .kernel("rbf".to_string())
            .gamma(20.0)
            .max_iter(100);
        let fitted = hf.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert!(
            predictions.iter().all(|&p| (0..=1).contains(&p)),
            "Predictions should be stable and valid"
        );

        // Local-global consistency converges to valid labels.
        let lgc = LocalGlobalConsistency::new()
            .kernel("rbf".to_string())
            .gamma(20.0)
            .alpha(0.99)
            .max_iter(100);
        let fitted_lgc = lgc.fit(&X.view(), &y.view()).unwrap();
        let predictions_lgc = fitted_lgc.predict(&X.view()).unwrap();

        assert!(
            predictions_lgc.iter().all(|&p| (0..=1).contains(&p)),
            "LGC predictions should be stable and valid"
        );
    }
}
729        );
730    }
731}