use scirs2_core::ndarray::{Array1, Array2, Array3};
use scirs2_core::random::{Rng, SeedableRng};
use sklears_core::{error::SklearsError, traits::Estimator, traits::Fit, traits::Predict};
use std::collections::HashMap;
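
/// Baseline strategy families, one variant per supported problem domain.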
#[derive(Debug, Clone)]
pub enum DomainStrategy {
    ComputerVision(CVStrategy),
    NLP(NLPStrategy),
    TimeSeries(TimeSeriesStrategy),
    Recommendation(RecStrategy),
    AnomalyDetection(AnomalyStrategy),
}
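
/// Baseline strategies for computer-vision inputs (images flattened into rows).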
#[derive(Debug, Clone)]
pub enum CVStrategy {
    PixelIntensity { statistic: PixelStatistic },
    ColorHistogram {
        bins: usize,
        color_space: ColorSpace,
    },
    SpatialFrequency { method: FrequencyMethod },
    Texture { method: TextureMethod },
    EdgeDetection { threshold: f64 },
    MostFrequentImageClass,
    RandomImageClass,
}
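
/// Summary statistics computed over pixel intensities.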
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PixelStatistic {
    Mean,
    Median,
    StandardDeviation,
    Skewness,
    Kurtosis,
}
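
/// Color spaces available for histogram features.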
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ColorSpace {
    RGB,
    HSV,
    Grayscale,
}
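
/// Spatial-frequency analysis methods.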
#[derive(Debug, Clone, Copy)]
pub enum FrequencyMethod {
    DFT,
    DCT,
    Wavelet,
}
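
/// Texture-description methods.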
#[derive(Debug, Clone, Copy)]
pub enum TextureMethod {
    LocalBinaryPattern,
    GrayLevelCooccurrence,
    Gabor,
}
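
/// Baseline strategies for text data.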
#[derive(Debug, Clone)]
pub enum NLPStrategy {
    WordFrequency { top_k: usize },
    NGram { n: usize, top_k: usize },
    DocumentLength,
    VocabularyRichness,
    SentimentPolarity,
    MostFrequentTextClass,
    TopicKeywords { num_topics: usize },
}
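
/// Baseline strategies for time-series data.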
#[derive(Debug, Clone)]
pub enum TimeSeriesStrategy {
    SeasonalPattern { period: usize },
    TrendAnalysis { window_size: usize },
    CyclicalPattern { cycles: Vec<usize> },
    Autocorrelation { max_lag: usize },
    MovingAverage { windows: Vec<usize> },
    RandomWalk { drift: f64 },
}
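
/// Baseline strategies for recommendation (user-item rating) data.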
#[derive(Debug, Clone)]
pub enum RecStrategy {
    ItemPopularity,
    UserAverage,
    ItemAverage,
    GlobalAverage,
    RandomRating,
    DemographicSimilarity,
}
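
/// Baseline strategies for anomaly detection.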
#[derive(Debug, Clone)]
pub enum AnomalyStrategy {
    StatisticalThreshold {
        method: ThresholdMethod,
        contamination: f64,
    },
    IsolationBased { n_estimators: usize },
    DistanceBased { k: usize },
    DensityBased { min_samples: usize, eps: f64 },
    AlwaysNormal,
    RandomAnomaly { contamination: f64 },
}
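
/// Statistical thresholding methods for anomaly detection.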
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ThresholdMethod {
    ZScore,
    ModifiedZScore,
    IQR,
    Percentile,
}
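
/// A domain-aware baseline classifier configured with a [`DomainStrategy`].
///
/// A minimal usage sketch, mirroring the unit tests below (the surrounding
/// crate paths are assumed to be available in this workspace):
///
/// ```ignore
/// let x = Array2::from_shape_vec((4, 1), vec![0.1, 0.9, 0.2, 0.8]).unwrap();
/// let y = array![0, 1, 0, 1];
/// let fitted = DomainClassifier::computer_vision(CVStrategy::MostFrequentImageClass)
///     .with_random_state(42)
///     .fit(&x, &y)
///     .unwrap();
/// let predictions = fitted.predict(&x).unwrap();
/// assert_eq!(predictions.len(), 4);
/// ```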
#[derive(Debug, Clone)]
pub struct DomainClassifier {
    strategy: DomainStrategy,
    random_state: Option<u64>,
}
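
/// A fitted [`DomainClassifier`] holding the class statistics and domain
/// features extracted during `fit`.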
#[derive(Debug, Clone)]
pub struct TrainedDomainClassifier {
    strategy: DomainStrategy,
    classes: Vec<i32>,
    class_counts: HashMap<i32, usize>,
    domain_features: DomainFeatures,
    random_state: Option<u64>,
}
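
/// Features extracted at fit time, one variant per domain.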
#[derive(Debug, Clone)]
pub enum DomainFeatures {
    ComputerVision(CVFeatures),
    NLP(NLPFeatures),
    TimeSeries(TSFeatures),
    Recommendation(RecFeatures),
    AnomalyDetection(AnomalyFeatures),
}
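
/// Features extracted from image data.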
#[derive(Debug, Clone)]
pub struct CVFeatures {
    pub pixel_statistics: HashMap<PixelStatistic, f64>,
    pub color_histograms: HashMap<ColorSpace, Vec<f64>>,
    pub spatial_frequencies: Vec<f64>,
    pub texture_features: Vec<f64>,
    pub edge_features: Vec<f64>,
}
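
/// Features extracted from text data.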
#[derive(Debug, Clone)]
pub struct NLPFeatures {
    pub word_frequencies: HashMap<String, usize>,
    pub ngram_frequencies: HashMap<String, usize>,
    pub document_lengths: Vec<usize>,
    pub vocabulary_size: usize,
    pub sentiment_scores: Vec<f64>,
    pub topic_keywords: HashMap<usize, Vec<String>>,
}
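
/// Features extracted from time-series data.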
#[derive(Debug, Clone)]
pub struct TSFeatures {
    pub seasonal_patterns: HashMap<usize, Vec<f64>>,
    pub trend_coefficients: Vec<f64>,
    pub cyclical_components: HashMap<usize, Vec<f64>>,
    pub autocorrelations: Vec<f64>,
    pub moving_averages: HashMap<usize, Vec<f64>>,
}
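
/// Features extracted from user-item rating data.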
#[derive(Debug, Clone)]
pub struct RecFeatures {
    pub item_popularity: HashMap<usize, f64>,
    pub user_averages: HashMap<usize, f64>,
    pub item_averages: HashMap<usize, f64>,
    pub global_average: f64,
    pub rating_range: (f64, f64),
}
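
/// Features extracted for anomaly detection.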
#[derive(Debug, Clone)]
pub struct AnomalyFeatures {
    pub statistical_thresholds: HashMap<ThresholdMethod, f64>,
    pub isolation_scores: Vec<f64>,
    pub distance_thresholds: Vec<f64>,
    pub density_thresholds: Vec<f64>,
    pub contamination_rate: f64,
}
impl DomainClassifier {
    /// Creates a classifier for the given domain strategy.
    pub fn new(strategy: DomainStrategy) -> Self {
        Self {
            strategy,
            random_state: None,
        }
    }

    /// Sets the seed used by randomized strategies.
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Convenience constructor for computer-vision strategies.
    pub fn computer_vision(strategy: CVStrategy) -> Self {
        Self::new(DomainStrategy::ComputerVision(strategy))
    }

    /// Convenience constructor for NLP strategies.
    pub fn nlp(strategy: NLPStrategy) -> Self {
        Self::new(DomainStrategy::NLP(strategy))
    }

    /// Convenience constructor for time-series strategies.
    pub fn time_series(strategy: TimeSeriesStrategy) -> Self {
        Self::new(DomainStrategy::TimeSeries(strategy))
    }

    /// Convenience constructor for recommendation strategies.
    pub fn recommendation(strategy: RecStrategy) -> Self {
        Self::new(DomainStrategy::Recommendation(strategy))
    }

    /// Convenience constructor for anomaly-detection strategies.
    pub fn anomaly_detection(strategy: AnomalyStrategy) -> Self {
        Self::new(DomainStrategy::AnomalyDetection(strategy))
    }
}

impl Estimator for DomainClassifier {
    type Config = DomainStrategy;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        &self.strategy
    }
}

impl Fit<Array2<f64>, Array1<i32>> for DomainClassifier {
    type Fitted = TrainedDomainClassifier;

    fn fit(self, x: &Array2<f64>, y: &Array1<i32>) -> Result<Self::Fitted, SklearsError> {
        let mut class_counts = HashMap::new();
        for &class in y.iter() {
            *class_counts.entry(class).or_insert(0) += 1;
        }

        let mut classes: Vec<_> = class_counts.keys().cloned().collect();
        classes.sort();

        let domain_features = self.extract_domain_features(x, y)?;

        Ok(TrainedDomainClassifier {
            strategy: self.strategy,
            classes,
            class_counts,
            domain_features,
            random_state: self.random_state,
        })
    }
}

impl DomainClassifier {
    fn extract_domain_features(
        &self,
        x: &Array2<f64>,
        y: &Array1<i32>,
    ) -> Result<DomainFeatures, SklearsError> {
        match &self.strategy {
            DomainStrategy::ComputerVision(cv_strategy) => {
                let cv_features = self.extract_cv_features(x, y, cv_strategy)?;
                Ok(DomainFeatures::ComputerVision(cv_features))
            }
            DomainStrategy::NLP(nlp_strategy) => {
                let nlp_features = self.extract_nlp_features(x, y, nlp_strategy)?;
                Ok(DomainFeatures::NLP(nlp_features))
            }
            DomainStrategy::TimeSeries(ts_strategy) => {
                let ts_features = self.extract_ts_features(x, y, ts_strategy)?;
                Ok(DomainFeatures::TimeSeries(ts_features))
            }
            DomainStrategy::Recommendation(rec_strategy) => {
                let rec_features = self.extract_rec_features(x, y, rec_strategy)?;
                Ok(DomainFeatures::Recommendation(rec_features))
            }
            DomainStrategy::AnomalyDetection(anomaly_strategy) => {
                let anomaly_features = self.extract_anomaly_features(x, y, anomaly_strategy)?;
                Ok(DomainFeatures::AnomalyDetection(anomaly_features))
            }
        }
    }

    fn extract_cv_features(
        &self,
        x: &Array2<f64>,
        _y: &Array1<i32>,
        strategy: &CVStrategy,
    ) -> Result<CVFeatures, SklearsError> {
        let mut pixel_statistics = HashMap::new();
        let mut color_histograms = HashMap::new();
        // Placeholders: these feature families are not yet computed.
        let spatial_frequencies = Vec::new();
        let texture_features = Vec::new();
        let edge_features = Vec::new();

        match strategy {
            CVStrategy::PixelIntensity { statistic } => {
                let value = self.compute_pixel_statistic(x, *statistic)?;
                pixel_statistics.insert(*statistic, value);
            }
            CVStrategy::ColorHistogram { bins, color_space } => {
                let histogram = self.compute_color_histogram(x, *bins, *color_space)?;
                color_histograms.insert(*color_space, histogram);
            }
            _ => {
                // Fall back to the global mean intensity for the remaining strategies.
                pixel_statistics.insert(PixelStatistic::Mean, x.mean().unwrap_or(0.0));
            }
        }

        Ok(CVFeatures {
            pixel_statistics,
            color_histograms,
            spatial_frequencies,
            texture_features,
            edge_features,
        })
    }

    fn extract_nlp_features(
        &self,
        x: &Array2<f64>,
        _y: &Array1<i32>,
        strategy: &NLPStrategy,
    ) -> Result<NLPFeatures, SklearsError> {
        let mut word_frequencies = HashMap::new();
        let ngram_frequencies = HashMap::new();
        let document_lengths = Vec::new();
        let vocabulary_size = 0;
        let sentiment_scores = Vec::new();
        let topic_keywords = HashMap::new();

        match strategy {
            NLPStrategy::WordFrequency { top_k } => {
                // Treat each column as a word count; record the first `top_k` columns.
                for i in 0..(*top_k).min(x.ncols()) {
                    let word = format!("word_{}", i);
                    let freq = x.column(i).sum() as usize;
                    word_frequencies.insert(word, freq);
                }
            }
            NLPStrategy::DocumentLength => {
                // Document lengths are derived at prediction time from row sums.
            }
            _ => {
                // Remaining strategies require no fitted features.
            }
        }

        Ok(NLPFeatures {
            word_frequencies,
            ngram_frequencies,
            document_lengths,
            vocabulary_size,
            sentiment_scores,
            topic_keywords,
        })
    }

    fn extract_ts_features(
        &self,
        x: &Array2<f64>,
        _y: &Array1<i32>,
        strategy: &TimeSeriesStrategy,
    ) -> Result<TSFeatures, SklearsError> {
        let mut seasonal_patterns = HashMap::new();
        let trend_coefficients = Vec::new();
        let cyclical_components = HashMap::new();
        let autocorrelations = Vec::new();
        let mut moving_averages = HashMap::new();

        match strategy {
            TimeSeriesStrategy::SeasonalPattern { period } => {
                // Use the first column as the series of interest.
                if x.ncols() > 0 {
                    let series = x.column(0);
                    let pattern = self.compute_seasonal_pattern(&series, *period)?;
                    seasonal_patterns.insert(*period, pattern);
                }
            }
            TimeSeriesStrategy::MovingAverage { windows } => {
                if x.ncols() > 0 {
                    let series = x.column(0);
                    for &window in windows {
                        let ma = self.compute_moving_average(&series, window)?;
                        moving_averages.insert(window, ma);
                    }
                }
            }
            _ => {
                // Remaining strategies require no fitted features.
            }
        }

        Ok(TSFeatures {
            seasonal_patterns,
            trend_coefficients,
            cyclical_components,
            autocorrelations,
            moving_averages,
        })
    }

    fn extract_rec_features(
        &self,
        x: &Array2<f64>,
        y: &Array1<i32>,
        _strategy: &RecStrategy,
    ) -> Result<RecFeatures, SklearsError> {
        let mut item_popularity = HashMap::new();
        let mut user_averages = HashMap::new();
        let mut item_averages = HashMap::new();
        let global_average = if y.is_empty() {
            0.0
        } else {
            y.iter().map(|&v| v as f64).sum::<f64>() / y.len() as f64
        };
        let rating_range = {
            let min_rating = y.iter().min().copied().unwrap_or(0) as f64;
            let max_rating = y.iter().max().copied().unwrap_or(5) as f64;
            (min_rating, max_rating)
        };

        // Expect columns [user_id, item_id, ...]; accumulate (sum, count) pairs.
        for (i, &rating) in y.iter().enumerate() {
            if x.ncols() >= 2 {
                let user_id = x[[i, 0]] as usize;
                let item_id = x[[i, 1]] as usize;
                let rating_f64 = rating as f64;

                *item_popularity.entry(item_id).or_insert(0.0) += 1.0;

                let user_entry = user_averages.entry(user_id).or_insert((0.0, 0));
                user_entry.0 += rating_f64;
                user_entry.1 += 1;

                let item_entry = item_averages.entry(item_id).or_insert((0.0, 0));
                item_entry.0 += rating_f64;
                item_entry.1 += 1;
            }
        }

        // Convert the accumulated (sum, count) pairs into per-user and per-item means.
        let user_averages: HashMap<usize, f64> = user_averages
            .into_iter()
            .map(|(id, (sum, count))| (id, sum / count as f64))
            .collect();

        let item_averages: HashMap<usize, f64> = item_averages
            .into_iter()
            .map(|(id, (sum, count))| (id, sum / count as f64))
            .collect();

        Ok(RecFeatures {
            item_popularity,
            user_averages,
            item_averages,
            global_average,
            rating_range,
        })
    }

    fn extract_anomaly_features(
        &self,
        x: &Array2<f64>,
        _y: &Array1<i32>,
        strategy: &AnomalyStrategy,
    ) -> Result<AnomalyFeatures, SklearsError> {
        let mut statistical_thresholds = HashMap::new();
        let isolation_scores = Vec::new();
        let distance_thresholds = Vec::new();
        let density_thresholds = Vec::new();

        let contamination_rate = match strategy {
            AnomalyStrategy::StatisticalThreshold { contamination, .. }
            | AnomalyStrategy::RandomAnomaly { contamination } => *contamination,
            _ => 0.1, // default contamination rate
        };

        if x.ncols() > 0 {
            let feature = x.column(0);
            let mean = feature.mean().unwrap_or(0.0);
            // Sample standard deviation; guard against a division by zero for
            // inputs with fewer than two samples.
            let std = if feature.len() > 1 {
                let variance = feature.iter().map(|&val| (val - mean).powi(2)).sum::<f64>()
                    / (feature.len() - 1) as f64;
                variance.sqrt()
            } else {
                0.0
            };

            statistical_thresholds.insert(ThresholdMethod::ZScore, 2.0 * std);

            // Interquartile-range threshold: 1.5 * (Q3 - Q1).
            let mut sorted_values = feature.to_vec();
            if !sorted_values.is_empty() {
                sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
                let q1_idx = sorted_values.len() / 4;
                let q3_idx = 3 * sorted_values.len() / 4;
                let q1 = sorted_values[q1_idx];
                let q3 = sorted_values[q3_idx];
                let iqr = q3 - q1;
                statistical_thresholds.insert(ThresholdMethod::IQR, 1.5 * iqr);
            }
        }

        Ok(AnomalyFeatures {
            statistical_thresholds,
            isolation_scores,
            distance_thresholds,
            density_thresholds,
            contamination_rate,
        })
    }

    fn compute_pixel_statistic(
        &self,
        x: &Array2<f64>,
        statistic: PixelStatistic,
    ) -> Result<f64, SklearsError> {
        match statistic {
            PixelStatistic::Mean => Ok(x.mean().unwrap_or(0.0)),
            PixelStatistic::Median => {
                let mut values: Vec<f64> = x.iter().cloned().collect();
                if values.is_empty() {
                    return Ok(0.0);
                }
                values.sort_by(|a, b| a.partial_cmp(b).unwrap());
                let mid = values.len() / 2;
                Ok(if values.len() % 2 == 0 {
                    (values[mid - 1] + values[mid]) / 2.0
                } else {
                    values[mid]
                })
            }
            PixelStatistic::StandardDeviation => {
                let mean = x.mean().unwrap_or(0.0);
                let variance =
                    x.iter().map(|&val| (val - mean).powi(2)).sum::<f64>() / x.len() as f64;
                Ok(variance.sqrt())
            }
            // Skewness and kurtosis are not yet implemented; fall back to 0.0.
            _ => Ok(0.0),
        }
    }

    fn compute_color_histogram(
        &self,
        x: &Array2<f64>,
        bins: usize,
        _color_space: ColorSpace,
    ) -> Result<Vec<f64>, SklearsError> {
        if bins == 0 {
            return Ok(Vec::new());
        }

        let min_val = x.iter().fold(f64::INFINITY, |a, &b| a.min(b));
        let max_val = x.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let bin_width = (max_val - min_val) / bins as f64;

        let mut histogram = vec![0.0; bins];
        for &value in x.iter() {
            // Constant inputs (bin_width == 0) all land in the first bin.
            let bin_idx = if bin_width > 0.0 {
                (((value - min_val) / bin_width).floor() as usize).min(bins - 1)
            } else {
                0
            };
            histogram[bin_idx] += 1.0;
        }

        // Normalize counts to a probability distribution.
        let total: f64 = histogram.iter().sum();
        if total > 0.0 {
            for count in &mut histogram {
                *count /= total;
            }
        }

        Ok(histogram)
    }

    fn compute_seasonal_pattern(
        &self,
        series: &scirs2_core::ndarray::ArrayView1<f64>,
        period: usize,
    ) -> Result<Vec<f64>, SklearsError> {
        if period == 0 {
            return Err(SklearsError::InvalidInput(
                "Seasonal period must be positive".to_string(),
            ));
        }

        // Average the observations that fall at each position within the period.
        let mut pattern = vec![0.0; period];
        let mut counts = vec![0; period];

        for (i, &value) in series.iter().enumerate() {
            let seasonal_idx = i % period;
            pattern[seasonal_idx] += value;
            counts[seasonal_idx] += 1;
        }

        for (i, count) in counts.iter().enumerate() {
            if *count > 0 {
                pattern[i] /= *count as f64;
            }
        }

        Ok(pattern)
    }

    fn compute_moving_average(
        &self,
        series: &scirs2_core::ndarray::ArrayView1<f64>,
        window: usize,
    ) -> Result<Vec<f64>, SklearsError> {
        // Trailing moving average; windows are truncated at the start of the series.
        let mut moving_avg = Vec::new();

        for i in 0..series.len() {
            let start = if i >= window { i - window + 1 } else { 0 };
            let end = i + 1;
            let window_sum: f64 = series.slice(scirs2_core::ndarray::s![start..end]).sum();
            let window_size = end - start;
            moving_avg.push(window_sum / window_size as f64);
        }

        Ok(moving_avg)
    }
}

impl Predict<Array2<f64>, Array1<i32>> for TrainedDomainClassifier {
    fn predict(&self, x: &Array2<f64>) -> Result<Array1<i32>, SklearsError> {
        let n_samples = x.nrows();
        let mut predictions = Array1::zeros(n_samples);

        match &self.strategy {
            DomainStrategy::ComputerVision(cv_strategy) => {
                self.predict_cv(x, cv_strategy, &mut predictions)?;
            }
            DomainStrategy::NLP(nlp_strategy) => {
                self.predict_nlp(x, nlp_strategy, &mut predictions)?;
            }
            DomainStrategy::TimeSeries(ts_strategy) => {
                self.predict_ts(x, ts_strategy, &mut predictions)?;
            }
            DomainStrategy::Recommendation(rec_strategy) => {
                self.predict_rec(x, rec_strategy, &mut predictions)?;
            }
            DomainStrategy::AnomalyDetection(anomaly_strategy) => {
                self.predict_anomaly(x, anomaly_strategy, &mut predictions)?;
            }
        }

        Ok(predictions)
    }
}

impl TrainedDomainClassifier {
    /// Returns the most frequent training class (ties broken arbitrarily).
    fn most_frequent_class(&self) -> i32 {
        self.class_counts
            .iter()
            .max_by_key(|(_, &count)| count)
            .map(|(&class, _)| class)
            .unwrap_or(0)
    }

    fn predict_cv(
        &self,
        x: &Array2<f64>,
        strategy: &CVStrategy,
        predictions: &mut Array1<i32>,
    ) -> Result<(), SklearsError> {
        match strategy {
            CVStrategy::MostFrequentImageClass => {
                predictions.fill(self.most_frequent_class());
            }
            CVStrategy::RandomImageClass => {
                let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(
                    self.random_state.unwrap_or(0),
                );

                // Sample classes in proportion to their training frequencies.
                let total_count: usize = self.class_counts.values().sum();
                for i in 0..predictions.len() {
                    let rand_val = rng.gen_range(0..total_count);
                    let mut cumsum = 0;
                    for (&class, &count) in &self.class_counts {
                        cumsum += count;
                        if rand_val < cumsum {
                            predictions[i] = class;
                            break;
                        }
                    }
                }
            }
            CVStrategy::PixelIntensity { statistic } => {
                if let DomainFeatures::ComputerVision(cv_features) = &self.domain_features {
                    if let Some(&threshold) = cv_features.pixel_statistics.get(statistic) {
                        // Threshold each image's mean intensity against the fitted statistic.
                        for i in 0..predictions.len() {
                            let pixel_value = x.row(i).mean().unwrap_or(0.0);
                            predictions[i] = if pixel_value > threshold { 1 } else { 0 };
                        }
                    }
                }
            }
            _ => {
                predictions.fill(self.most_frequent_class());
            }
        }
        Ok(())
    }

    fn predict_nlp(
        &self,
        x: &Array2<f64>,
        strategy: &NLPStrategy,
        predictions: &mut Array1<i32>,
    ) -> Result<(), SklearsError> {
        match strategy {
            NLPStrategy::MostFrequentTextClass => {
                predictions.fill(self.most_frequent_class());
            }
            NLPStrategy::DocumentLength => {
                // Split on the median document length, using row sums as a length proxy.
                let median_length = {
                    let mut lengths: Vec<f64> = (0..x.nrows()).map(|i| x.row(i).sum()).collect();
                    lengths.sort_by(|a, b| a.partial_cmp(b).unwrap());
                    lengths[lengths.len() / 2]
                };

                for i in 0..predictions.len() {
                    let doc_length = x.row(i).sum();
                    predictions[i] = if doc_length > median_length { 1 } else { 0 };
                }
            }
            _ => {
                predictions.fill(self.most_frequent_class());
            }
        }
        Ok(())
    }

    fn predict_ts(
        &self,
        _x: &Array2<f64>,
        strategy: &TimeSeriesStrategy,
        predictions: &mut Array1<i32>,
    ) -> Result<(), SklearsError> {
        match strategy {
            TimeSeriesStrategy::SeasonalPattern { period } => {
                if let DomainFeatures::TimeSeries(ts_features) = &self.domain_features {
                    if let Some(pattern) = ts_features.seasonal_patterns.get(period) {
                        for i in 0..predictions.len() {
                            let seasonal_idx = i % period;
                            let seasonal_value = pattern.get(seasonal_idx).unwrap_or(&0.0);
                            predictions[i] = if *seasonal_value > 0.5 { 1 } else { 0 };
                        }
                    }
                }
            }
            TimeSeriesStrategy::RandomWalk { drift } => {
                let mut current_value = 0.0;
                let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(
                    self.random_state.unwrap_or(0),
                );

                // Simulate a drifting random walk and threshold it at zero.
                for i in 0..predictions.len() {
                    current_value += drift + rng.gen_range(-0.1..0.1);
                    predictions[i] = if current_value > 0.0 { 1 } else { 0 };
                }
            }
            _ => {
                predictions.fill(self.most_frequent_class());
            }
        }
        Ok(())
    }

    fn predict_rec(
        &self,
        x: &Array2<f64>,
        strategy: &RecStrategy,
        predictions: &mut Array1<i32>,
    ) -> Result<(), SklearsError> {
        match strategy {
            RecStrategy::GlobalAverage => {
                if let DomainFeatures::Recommendation(rec_features) = &self.domain_features {
                    let threshold = rec_features.global_average;
                    for i in 0..predictions.len() {
                        // Use the third column as a rating proxy when available.
                        let rating_proxy = if x.ncols() > 2 { x[[i, 2]] } else { threshold };
                        predictions[i] = if rating_proxy > threshold { 1 } else { 0 };
                    }
                }
            }
            RecStrategy::ItemPopularity => {
                if let DomainFeatures::Recommendation(rec_features) = &self.domain_features {
                    // Split items on the median training popularity.
                    let median_popularity = {
                        let mut popularities: Vec<f64> =
                            rec_features.item_popularity.values().cloned().collect();
                        if popularities.is_empty() {
                            0.0
                        } else {
                            popularities.sort_by(|a, b| a.partial_cmp(b).unwrap());
                            popularities[popularities.len() / 2]
                        }
                    };

                    for i in 0..predictions.len() {
                        let item_id = if x.ncols() > 1 { x[[i, 1]] as usize } else { 0 };
                        let popularity = rec_features.item_popularity.get(&item_id).unwrap_or(&0.0);
                        predictions[i] = if *popularity > median_popularity { 1 } else { 0 };
                    }
                }
            }
            _ => {
                predictions.fill(self.most_frequent_class());
            }
        }
        Ok(())
    }

    fn predict_anomaly(
        &self,
        x: &Array2<f64>,
        strategy: &AnomalyStrategy,
        predictions: &mut Array1<i32>,
    ) -> Result<(), SklearsError> {
        match strategy {
            AnomalyStrategy::AlwaysNormal => {
                predictions.fill(0); // 0 = normal, 1 = anomaly
            }
            AnomalyStrategy::RandomAnomaly { contamination } => {
                let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(
                    self.random_state.unwrap_or(0),
                );

                // Flag each sample as anomalous with probability `contamination`.
                for i in 0..predictions.len() {
                    predictions[i] = if rng.gen::<f64>() < *contamination { 1 } else { 0 };
                }
            }
            AnomalyStrategy::StatisticalThreshold { method, .. } => {
                if let DomainFeatures::AnomalyDetection(anomaly_features) = &self.domain_features {
                    if let Some(&threshold) = anomaly_features.statistical_thresholds.get(method) {
                        for i in 0..predictions.len() {
                            if x.ncols() > 0 {
                                let value = x[[i, 0]];
                                let is_anomaly = match method {
                                    ThresholdMethod::ZScore | ThresholdMethod::ModifiedZScore => {
                                        value.abs() > threshold
                                    }
                                    ThresholdMethod::IQR | ThresholdMethod::Percentile => {
                                        value > threshold
                                    }
                                };
                                predictions[i] = if is_anomaly { 1 } else { 0 };
                            }
                        }
                    }
                }
            }
            _ => {
                predictions.fill(0);
            }
        }
        Ok(())
    }
}

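/// Utilities that flatten domain-specific inputs into 2-D feature matrices.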
pub struct DomainPreprocessor;

impl DomainPreprocessor {
    /// Flattens a stack of images of shape (n_images, height, width) into an
    /// (n_images, height * width) feature matrix.
    pub fn preprocess_images(images: &Array3<f64>) -> Result<Array2<f64>, SklearsError> {
        let (n_images, height, width) = images.dim();
        let features = images
            .clone()
            .into_shape((n_images, height * width))
            .map_err(|e| SklearsError::InvalidInput(format!("Shape error: {}", e)))?;
        Ok(features)
    }

    /// Encodes each text as a fixed-length row of its byte values, zero-padded
    /// to the longest text.
    pub fn preprocess_text(texts: &[String]) -> Result<Array2<f64>, SklearsError> {
        let n_texts = texts.len();
        let max_length = texts.iter().map(|s| s.len()).max().unwrap_or(0);

        let mut features = Array2::zeros((n_texts, max_length));
        for (i, text) in texts.iter().enumerate() {
            for (j, byte) in text.bytes().enumerate() {
                if j < max_length {
                    features[[i, j]] = byte as f64 / 255.0; // scale bytes to [0, 1]
                }
            }
        }

        Ok(features)
    }

    /// Converts each series into overlapping windows of length `window_size`,
    /// stacking all windows from all series into one matrix.
    pub fn preprocess_timeseries(
        series: &Array2<f64>,
        window_size: usize,
    ) -> Result<Array2<f64>, SklearsError> {
        let (n_series, length) = series.dim();
        if length < window_size {
            return Err(SklearsError::InvalidInput(
                "Time series length must be at least window size".to_string(),
            ));
        }

        let n_windows = length - window_size + 1;
        let mut windowed = Array2::zeros((n_series * n_windows, window_size));

        for i in 0..n_series {
            for j in 0..n_windows {
                let window = series.slice(scirs2_core::ndarray::s![i, j..j + window_size]);
                windowed
                    .slice_mut(scirs2_core::ndarray::s![i * n_windows + j, ..])
                    .assign(&window);
            }
        }

        Ok(windowed)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_cv_pixel_intensity_classifier() {
        let x = Array2::from_shape_vec(
            (4, 4),
            vec![
                0.1, 0.2, 0.3, 0.4, 0.8, 0.9, 0.7, 0.6, 0.2, 0.1, 0.4, 0.3, 0.9, 0.8, 0.6, 0.7,
            ],
        )
        .unwrap();
        let y = array![0, 1, 0, 1];

        let classifier = DomainClassifier::computer_vision(CVStrategy::PixelIntensity {
            statistic: PixelStatistic::Mean,
        });
        let fitted = classifier.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_nlp_document_length_classifier() {
        let x = Array2::from_shape_vec(
            (4, 3),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5],
        )
        .unwrap();
        let y = array![0, 1, 0, 1];

        let classifier = DomainClassifier::nlp(NLPStrategy::DocumentLength);
        let fitted = classifier.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_anomaly_detection_classifier() {
        // The third sample is an obvious outlier.
        let x = Array2::from_shape_vec(
            (4, 2),
            vec![1.0, 2.0, 3.0, 4.0, 100.0, 200.0, 2.0, 3.0],
        )
        .unwrap();
        let y = array![0, 0, 1, 0];

        let classifier =
            DomainClassifier::anomaly_detection(AnomalyStrategy::StatisticalThreshold {
                method: ThresholdMethod::ZScore,
                contamination: 0.25,
            });
        let fitted = classifier.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_time_series_seasonal_classifier() {
        let x =
            Array2::from_shape_vec((8, 1), vec![1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 1, 1, 0, 0, 1, 1, 0];

        let classifier =
            DomainClassifier::time_series(TimeSeriesStrategy::SeasonalPattern { period: 4 });
        let fitted = classifier.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 8);
    }

    #[test]
    fn test_recommendation_classifier() {
        // Columns are [user_id, item_id, rating proxy].
        let x = Array2::from_shape_vec(
            (4, 3),
            vec![0.0, 0.0, 4.0, 0.0, 1.0, 5.0, 1.0, 0.0, 3.0, 1.0, 1.0, 2.0],
        )
        .unwrap();
        let y = array![1, 1, 0, 0];

        let classifier = DomainClassifier::recommendation(RecStrategy::GlobalAverage);
        let fitted = classifier.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_domain_preprocessor() {
        // Two 4x4 images flatten to rows of 16 features.
        let images = Array3::zeros((2, 4, 4));
        let flattened = DomainPreprocessor::preprocess_images(&images).unwrap();
        assert_eq!(flattened.shape(), &[2, 16]);

        let texts = vec!["hello".to_string(), "world".to_string()];
        let text_features = DomainPreprocessor::preprocess_text(&texts).unwrap();
        assert_eq!(text_features.shape(), &[2, 5]);

        // Two series of length 6 with window 3 yield 2 * 4 = 8 windows.
        let series = Array2::from_shape_vec(
            (2, 6),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
        )
        .unwrap();
        let windowed = DomainPreprocessor::preprocess_timeseries(&series, 3).unwrap();
        assert_eq!(windowed.shape(), &[8, 3]);
    }
}