1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
//! Core traits for the high-level ML Pipeline API.
//!
//! This module defines the two foundational traits used throughout the
//! SciRS2 ML pipeline framework:
//!
//! - [`FeatureTransformer`]: stateful transformations (fit + transform)
//! - [`ModelPredictor`]: inference interface (predict / predict_batch)
//!
//! Downstream crates implement these traits to plug into [`super::Pipeline`]
//! and [`super::MLPipelineGeneric`].
use ;
use ;
use fmt;
use PipelineError;
/// Trait for stateful feature transformations.
///
/// A `FeatureTransformer` learns statistics from training data during [`fit`]
/// and applies those statistics during [`transform`]. Implementations must
/// handle zero-variance columns, empty inputs, and other degenerate cases
/// gracefully, reporting any failure as a [`PipelineError`] instead of
/// panicking.
///
/// [`fit`]: FeatureTransformer::fit
/// [`transform`]: FeatureTransformer::transform
///
/// # Implementation Contract
///
/// 1. `fit` **must** be called before `transform`.
/// 2. `fit_transform` is equivalent to calling `fit` followed by `transform`
/// on the same data; the default implementation does exactly this, and any
/// override must preserve that equivalence.
/// 3. After `fit`, subsequent `transform` calls with different data of the
/// same column count must succeed.
/// 4. No method may use `unwrap()` or `expect()` in production paths.
///
/// # Example
///
/// ```rust
/// # #[cfg(feature = "ml_pipeline")]
/// # {
/// use scirs2_core::ml_pipeline::{FeatureTransformer, StandardScaler, PipelineError};
/// use ndarray::Array2;
///
/// let data = Array2::from_shape_vec((3, 2), vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("should succeed");
/// let mut scaler = StandardScaler::new();
/// scaler.fit(&data).expect("should succeed");
/// let transformed = scaler.transform(&data).expect("should succeed");
/// assert_eq!(transformed.shape(), &[3, 2]);
/// # }
/// ```
/// Trait for ML model predictors.
///
/// A `ModelPredictor` encapsulates the inference step of an ML pipeline.
/// It receives a feature matrix (already transformed by the upstream
/// [`FeatureTransformer`] steps) and returns a 1-D prediction array.
///
/// # Implementation Contract
///
/// 1. `predict` must return a 1-D array of length `data.nrows()`.
/// 2. `predict_batch` is semantically identical to `predict`; it exists as a
/// separate method to allow implementations to apply different optimisations
/// for batched inference (e.g., running inference in chunks).
/// 3. No panics — all errors must be returned as [`PipelineError`].
///
/// # Example
///
/// ```rust
/// # #[cfg(feature = "ml_pipeline")]
/// # {
/// use scirs2_core::ml_pipeline::{ModelPredictor, PipelineError};
/// use ndarray::{Array1, Array2};
///
/// /// A trivial predictor that returns the mean of each row.
/// struct RowMeanPredictor;
///
/// impl ModelPredictor<f64> for RowMeanPredictor {
/// fn predict(&self, data: &Array2<f64>) -> Result<Array1<f64>, PipelineError> {
/// let n = data.nrows();
/// if n == 0 {
/// return Err(PipelineError::EmptyInput("predict".to_string()));
/// }
/// let preds: Vec<f64> = (0..n)
/// .map(|i| data.row(i).sum() / data.ncols() as f64)
/// .collect();
/// Ok(Array1::from_vec(preds))
/// }
///
/// fn predict_batch(&self, data: &Array2<f64>) -> Result<Array1<f64>, PipelineError> {
/// self.predict(data)
/// }
///
/// fn name(&self) -> &str { "RowMeanPredictor" }
/// }
///
/// let data = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("should succeed");
/// let predictor = RowMeanPredictor;
/// let preds = predictor.predict(&data).expect("should succeed");
/// assert_eq!(preds.len(), 2);
/// assert!((preds[0] - 2.0).abs() < 1e-10);
/// assert!((preds[1] - 5.0).abs() < 1e-10);
/// # }
/// ```