light_curve_feature/features/
otsu_split.rs1use crate::evaluator::*;
2use crate::time_series::DataSample;
3use conv::prelude::*;
4use ndarray::{Array1, ArrayView1, Axis, Zip, s};
5use ndarray_stats::QuantileExt;
6
7macro_const! {
8 const DOC: &'static str = r#"
9Otsu threshholding algorithm
10
11Difference of subset means, standard deviation of the lower subset, standard deviation of the upper
12subset and lower-to-all observation count ratio for two subsets of magnitudes obtained by Otsu's
13method split. Otsu's method is used to perform automatic thresholding. The algorithm returns a
14single threshold that separate values into two classes. This threshold is determined by minimizing
15intra-class intensity variance, or equivalently, by maximizing inter-class variance.
16The algorithm returns the minimum threshold which corresponds to the absolute maximum of the inter-class variance.
17
18- Depends on: **magnitude**
19- Minimum number of observations: **2**
20- Number of features: **4**
21
22Otsu, Nobuyuki 1979. [DOI:10.1109/tsmc.1979.4310076](https://doi.org/10.1109/tsmc.1979.4310076)
23"#;
24}
25
// Stateless feature evaluator: the struct carries no fields, and its rustdoc text is `DOC`.
// NOTE(review): this is a braced `{}` struct rather than a unit struct; switching to a unit
// struct would likely change its serde/JsonSchema representation, so keep the braces.
#[doc = DOC!()]
#[derive(Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
pub struct OtsuSplit {}
29
// Static metadata for `OtsuSplit`, built by the crate's `lazy_info!` macro (defined elsewhere).
// The values mirror the `DOC` text: 4 output values and at least 2 observations. `eval` reads
// only the magnitudes (`ts.m`) and the series length; times and weights are not used.
lazy_info!(
    OTSU_SPLIT_INFO,
    OtsuSplit,
    size: 4,
    min_ts_length: 2,
    t_required: false,
    m_required: true,
    w_required: false,
    // NOTE(review): presumably refers to time ordering; `threshold` sorts the magnitudes itself
    // via `get_sorted`.
    sorting_required: false,
);
40
41impl OtsuSplit {
42 pub fn new() -> Self {
43 Self {}
44 }
45
46 pub const fn doc() -> &'static str {
47 DOC
48 }
49
50 pub fn threshold<'a, 'b, T>(
51 ds: &'b mut DataSample<'a, T>,
52 ) -> Result<(T, ArrayView1<'b, T>, ArrayView1<'b, T>), EvaluatorError>
53 where
54 'a: 'b,
55 T: Float,
56 {
57 if ds.sample.len() < 2 {
58 return Err(EvaluatorError::ShortTimeSeries {
59 actual: ds.sample.len(),
60 minimum: 2,
61 });
62 }
63
64 let count = ds.sample.len();
65 let countf = count.approx().unwrap();
66 let sorted = ds.get_sorted();
67
68 if sorted.minimum() == sorted.maximum() {
69 return Err(EvaluatorError::FlatTimeSeries);
70 }
71
72 let cumsum1: Array1<_> = sorted
74 .iter()
75 .take(count - 1)
76 .scan(T::zero(), |state, &m| {
77 *state += m;
78 Some(*state)
79 })
80 .collect();
81
82 let cumsum2: Array1<_> = sorted
83 .iter()
84 .rev()
85 .scan(T::zero(), |state, &m| {
86 *state += m;
87 Some(*state)
88 })
89 .collect();
90 let cumsum2 = cumsum2.slice(s![0..count - 1; -1]);
91
92 let amounts = Array1::linspace(T::one(), (count - 1).approx().unwrap(), count - 1);
93 let mean1 = Zip::from(&cumsum1)
94 .and(&amounts)
95 .map_collect(|&c, &a| c / a);
96 let mean2 = Zip::from(&cumsum2)
97 .and(amounts.slice(s![..;-1]))
98 .map_collect(|&c, &a| c / a);
99
100 let inter_class_variance =
101 Zip::from(&amounts)
102 .and(&mean1)
103 .and(&mean2)
104 .map_collect(|&a, &m1, &m2| {
105 let w1 = a / countf;
106 let w2 = T::one() - w1;
107 w1 * w2 * (m1 - m2).powi(2)
108 });
109
110 let index = inter_class_variance.argmax().unwrap();
111
112 let (lower, upper) = sorted.0.view().split_at(Axis(0), index + 1);
113 Ok((sorted.0[index + 1], lower, upper))
114 }
115}
116
117impl FeatureNamesDescriptionsTrait for OtsuSplit {
118 fn get_names(&self) -> Vec<&str> {
119 vec![
120 "otsu_mean_diff",
121 "otsu_std_lower",
122 "otsu_std_upper",
123 "otsu_lower_to_all_ratio",
124 ]
125 }
126
127 fn get_descriptions(&self) -> Vec<&str> {
128 vec![
129 "difference between mean values of Otsu split subsets",
130 "standard deviation for observations below the threshold given by Otsu method",
131 "standard deviation for observations above the threshold given by Otsu method",
132 "ratio of quantity of observations bellow the threshold given by Otsu method to quantity of all observations",
133 ]
134 }
135}
136
137impl<T> FeatureEvaluator<T> for OtsuSplit
138where
139 T: Float,
140{
141 fn eval(&self, ts: &mut TimeSeries<T>) -> Result<Vec<T>, EvaluatorError> {
142 self.check_ts_length(ts)?;
143
144 let (_, lower, upper) = Self::threshold(&mut ts.m)?;
145 let mut lower: DataSample<_> = lower.into();
146 let mut upper: DataSample<_> = upper.into();
147
148 let std_lower = if lower.sample.len() == 1 {
149 T::zero()
150 } else {
151 lower.get_std()
152 };
153 let mean_lower = lower.get_mean();
154
155 let std_upper = if upper.sample.len() == 1 {
156 T::zero()
157 } else {
158 upper.get_std()
159 };
160 let mean_upper = upper.get_mean();
161
162 let mean_diff = mean_upper - mean_lower;
163 let lower_to_all = lower.sample.len().approx_as::<T>().unwrap() / ts.lenf();
164
165 Ok(vec![mean_diff, std_lower, std_upper, lower_to_all])
166 }
167}
168
// Unit tests for `OtsuSplit`. `check_feature!` and `feature_test!` are crate test macros
// brought in by `use crate::tests::*` (their definitions are not visible here).
#[cfg(test)]
#[allow(clippy::unreadable_literal)]
#[allow(clippy::excessive_precision)]
mod tests {
    use super::*;

    use crate::tests::*;

    use approx::assert_relative_eq;
    use ndarray::array;

    check_feature!(OtsuSplit);

    // Two groups: four values in 0.51..=0.54 and two values >= 1.2, so 4 of 6 points are "lower".
    feature_test!(
        otsu_split,
        [OtsuSplit::new()],
        [
            0.725,
            0.012909944487358068,
            0.07071067811865482,
            0.6666666666666666
        ],
        [0.51, 0.52, 0.53, 0.54, 1.2, 1.3],
    );

    // Minimum allowed length: each subset holds one value, so both stds are zero.
    feature_test!(
        otsu_split_min_observations,
        [OtsuSplit::new()],
        [0.01, 0.0, 0.0, 0.5],
        [0.51, 0.52],
    );

    // Single-element lower subset (ratio 1/4).
    feature_test!(
        otsu_split_lower,
        [OtsuSplit::new()],
        [1.0, 0.0, 0.0, 0.25],
        [0.5, 1.5, 1.5, 1.5],
    );

    // Single-element upper subset (ratio 3/4).
    feature_test!(
        otsu_split_upper,
        [OtsuSplit::new()],
        [1.0, 0.0, 0.0, 0.75],
        [0.5, 0.5, 0.5, 1.5],
    );

    // The returned threshold is the first (smallest) element of the upper subset.
    #[test]
    fn otsu_threshold() {
        let mut ds = vec![0.5, 0.5, 0.5, 1.5].into();
        let (expected_threshold, expected_lower, expected_upper) =
            (1.5, array![0.5, 0.5, 0.5], array![1.5]);
        let (actual_threshold, actual_lower, actual_upper) =
            OtsuSplit::threshold(&mut ds).expect("input is not flat");
        assert_eq!(expected_threshold, actual_threshold);
        assert_eq!(expected_lower, actual_lower);
        assert_eq!(expected_upper, actual_upper);
    }

    // Expects the split right after the lowest value (-1.5), with the threshold at 0.5.
    #[test]
    fn otsu_two_max() {
        let mut ds = vec![-1.5, 0.5, 0.5, 1.5].into();
        let (expected_threshold, expected_lower, expected_upper) =
            (0.5, array![-1.5], array![0.5, 0.5, 1.5]);
        let (actual_threshold, actual_lower, actual_upper) =
            OtsuSplit::threshold(&mut ds).expect("input is not flat");
        assert_eq!(expected_threshold, actual_threshold);
        assert_eq!(expected_lower, actual_lower);
        assert_eq!(expected_upper, actual_upper);
    }

    // Constant magnitudes cannot be split and must yield `FlatTimeSeries`.
    #[test]
    fn otsu_split_plateau() {
        let eval = OtsuSplit::new();
        let x = [1.5, 1.5, 1.5, 1.5];
        let mut ts = TimeSeries::new_without_weight(&x, &x);
        assert_eq!(eval.eval(&mut ts), Err(EvaluatorError::FlatTimeSeries));
    }

    // f32 regression data set; presumably from light-curve-feature issue #72 (per the file path).
    #[test]
    fn otsu_split_small() {
        let eval = OtsuSplit::new();
        let mut ts = light_curve_feature_test_util::issue_light_curve_mag::<f32, _>(
            "light-curve-feature-72/1.csv",
        )
        .into_triple(None)
        .into();
        let desired = [
            3.0221021243981205,
            0.8847146372743603,
            0.8826366394647659,
            0.507,
        ];
        let actual = eval.eval(&mut ts).unwrap();
        assert_relative_eq!(&desired[..], &actual[..], epsilon = 1e-6);
    }

    // Very large f32 input (2^25 + 57 evenly spaced points). It is ignored by default,
    // presumably because it is slow and memory-hungry.
    #[test]
    #[ignore]
    fn no_overflow() {
        const N: usize = (1 << 25) + 57;
        let feature = OtsuSplit::new();
        let t = Array1::linspace(0.0_f32, 1.0, N);
        let mut ts = TimeSeries::new_without_weight(t.view(), t.view());
        let [mean_diff, _std_lower, _std_upper, lower_to_all]: [f32; 4] =
            feature.eval(&mut ts).unwrap().try_into().unwrap();
        assert_relative_eq!(mean_diff, 0.5, epsilon = 1e-3);
        assert_relative_eq!(lower_to_all, 0.5, epsilon = 1e-6);
    }
}
279}