light_curve_feature/features/
otsu_split.rs

1use crate::evaluator::*;
2use crate::time_series::DataSample;
3use conv::prelude::*;
4use ndarray::{Array1, ArrayView1, Axis, Zip, s};
5use ndarray_stats::QuantileExt;
6
7macro_const! {
8    const DOC: &'static str = r#"
9Otsu threshholding algorithm
10
11Difference of subset means, standard deviation of the lower subset, standard deviation of the upper
12subset and lower-to-all observation count ratio for two subsets of magnitudes obtained by Otsu's
13method split. Otsu's method is used to perform automatic thresholding. The algorithm returns a
14single threshold that separate values into two classes. This threshold is determined by minimizing
15intra-class intensity variance, or equivalently, by maximizing inter-class variance.
16The algorithm returns the minimum threshold which corresponds to the absolute maximum of the inter-class variance.
17
18- Depends on: **magnitude**
19- Minimum number of observations: **2**
20- Number of features: **4**
21
22Otsu, Nobuyuki 1979. [DOI:10.1109/tsmc.1979.4310076](https://doi.org/10.1109/tsmc.1979.4310076)
23"#;
24}
25
// Field-less evaluator type: the struct stores no parameters.
// Its rustdoc text is taken from the `DOC` constant through the attribute below.
#[doc = DOC ! ()]
#[derive(Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
pub struct OtsuSplit {}
29
// Evaluator metadata for `OtsuSplit`, declared through the project's `lazy_info!` macro.
// NOTE(review): the field meanings below are read off their names and the code in this file;
// confirm against the macro definition.
lazy_info!(
    OTSU_SPLIT_INFO,
    OtsuSplit,
    // Four output values, one per entry of `get_names`.
    size: 4,
    // Same lower bound as the explicit length check in `OtsuSplit::threshold`.
    min_ts_length: 2,
    t_required: false,
    // `eval` passes `ts.m` to `OtsuSplit::threshold`.
    m_required: true,
    w_required: false,
    // `threshold` obtains its own sorted copy of the sample via `get_sorted`.
    sorting_required: false,
);
40
41impl OtsuSplit {
42    pub fn new() -> Self {
43        Self {}
44    }
45
46    pub const fn doc() -> &'static str {
47        DOC
48    }
49
50    pub fn threshold<'a, 'b, T>(
51        ds: &'b mut DataSample<'a, T>,
52    ) -> Result<(T, ArrayView1<'b, T>, ArrayView1<'b, T>), EvaluatorError>
53    where
54        'a: 'b,
55        T: Float,
56    {
57        if ds.sample.len() < 2 {
58            return Err(EvaluatorError::ShortTimeSeries {
59                actual: ds.sample.len(),
60                minimum: 2,
61            });
62        }
63
64        let count = ds.sample.len();
65        let countf = count.approx().unwrap();
66        let sorted = ds.get_sorted();
67
68        if sorted.minimum() == sorted.maximum() {
69            return Err(EvaluatorError::FlatTimeSeries);
70        }
71
72        // size is (count - 1)
73        let cumsum1: Array1<_> = sorted
74            .iter()
75            .take(count - 1)
76            .scan(T::zero(), |state, &m| {
77                *state += m;
78                Some(*state)
79            })
80            .collect();
81
82        let cumsum2: Array1<_> = sorted
83            .iter()
84            .rev()
85            .scan(T::zero(), |state, &m| {
86                *state += m;
87                Some(*state)
88            })
89            .collect();
90        let cumsum2 = cumsum2.slice(s![0..count - 1; -1]);
91
92        let amounts = Array1::linspace(T::one(), (count - 1).approx().unwrap(), count - 1);
93        let mean1 = Zip::from(&cumsum1)
94            .and(&amounts)
95            .map_collect(|&c, &a| c / a);
96        let mean2 = Zip::from(&cumsum2)
97            .and(amounts.slice(s![..;-1]))
98            .map_collect(|&c, &a| c / a);
99
100        let inter_class_variance =
101            Zip::from(&amounts)
102                .and(&mean1)
103                .and(&mean2)
104                .map_collect(|&a, &m1, &m2| {
105                    let w1 = a / countf;
106                    let w2 = T::one() - w1;
107                    w1 * w2 * (m1 - m2).powi(2)
108                });
109
110        let index = inter_class_variance.argmax().unwrap();
111
112        let (lower, upper) = sorted.0.view().split_at(Axis(0), index + 1);
113        Ok((sorted.0[index + 1], lower, upper))
114    }
115}
116
117impl FeatureNamesDescriptionsTrait for OtsuSplit {
118    fn get_names(&self) -> Vec<&str> {
119        vec![
120            "otsu_mean_diff",
121            "otsu_std_lower",
122            "otsu_std_upper",
123            "otsu_lower_to_all_ratio",
124        ]
125    }
126
127    fn get_descriptions(&self) -> Vec<&str> {
128        vec![
129            "difference between mean values of Otsu split subsets",
130            "standard deviation for observations below the threshold given by Otsu method",
131            "standard deviation for observations above the threshold given by Otsu method",
132            "ratio of quantity of observations bellow the threshold given by Otsu method to quantity of all observations",
133        ]
134    }
135}
136
137impl<T> FeatureEvaluator<T> for OtsuSplit
138where
139    T: Float,
140{
141    fn eval(&self, ts: &mut TimeSeries<T>) -> Result<Vec<T>, EvaluatorError> {
142        self.check_ts_length(ts)?;
143
144        let (_, lower, upper) = Self::threshold(&mut ts.m)?;
145        let mut lower: DataSample<_> = lower.into();
146        let mut upper: DataSample<_> = upper.into();
147
148        let std_lower = if lower.sample.len() == 1 {
149            T::zero()
150        } else {
151            lower.get_std()
152        };
153        let mean_lower = lower.get_mean();
154
155        let std_upper = if upper.sample.len() == 1 {
156            T::zero()
157        } else {
158            upper.get_std()
159        };
160        let mean_upper = upper.get_mean();
161
162        let mean_diff = mean_upper - mean_lower;
163        let lower_to_all = lower.sample.len().approx_as::<T>().unwrap() / ts.lenf();
164
165        Ok(vec![mean_diff, std_lower, std_upper, lower_to_all])
166    }
167}
168
#[cfg(test)]
#[allow(clippy::unreadable_literal, clippy::excessive_precision)]
mod tests {
    use super::*;

    use crate::tests::*;

    use approx::assert_relative_eq;
    use ndarray::array;

    check_feature!(OtsuSplit);

    // Two well separated groups: four low values and two high values.
    feature_test!(
        otsu_split,
        [OtsuSplit::new()],
        [
            0.725,
            0.012909944487358068,
            0.07071067811865482,
            0.6666666666666666
        ],
        [0.51, 0.52, 0.53, 0.54, 1.2, 1.3],
    );

    // Shortest accepted input: each subset holds a single observation.
    feature_test!(
        otsu_split_min_observations,
        [OtsuSplit::new()],
        [0.01, 0.0, 0.0, 0.5],
        [0.51, 0.52],
    );

    // A single low value against three equal high values.
    feature_test!(
        otsu_split_lower,
        [OtsuSplit::new()],
        [1.0, 0.0, 0.0, 0.25],
        [0.5, 1.5, 1.5, 1.5],
    );

    // Three equal low values against a single high value.
    feature_test!(
        otsu_split_upper,
        [OtsuSplit::new()],
        [1.0, 0.0, 0.0, 0.75],
        [0.5, 0.5, 0.5, 1.5],
    );

    // The threshold is the first value of the upper subset.
    #[test]
    fn otsu_threshold() {
        let mut ds = vec![0.5, 0.5, 0.5, 1.5].into();
        let (threshold, lower, upper) =
            OtsuSplit::threshold(&mut ds).expect("input is not flat");
        assert_eq!(1.5, threshold);
        assert_eq!(array![0.5, 0.5, 0.5], lower);
        assert_eq!(array![1.5], upper);
    }

    // Input chosen so that two cuts compete; the expected split puts only -1.5 into the
    // lower subset.
    #[test]
    fn otsu_two_max() {
        let mut ds = vec![-1.5, 0.5, 0.5, 1.5].into();
        let (threshold, lower, upper) =
            OtsuSplit::threshold(&mut ds).expect("input is not flat");
        assert_eq!(0.5, threshold);
        assert_eq!(array![-1.5], lower);
        assert_eq!(array![0.5, 0.5, 1.5], upper);
    }

    // A constant series cannot be split and must be reported as flat.
    #[test]
    fn otsu_split_plateau() {
        let feature = OtsuSplit::new();
        let flat = [1.5; 4];
        let mut ts = TimeSeries::new_without_weight(&flat, &flat);
        let result = feature.eval(&mut ts);
        assert_eq!(result, Err(EvaluatorError::FlatTimeSeries));
    }

    // Regression check against the light curve stored in the test-data fixture.
    #[test]
    fn otsu_split_small() {
        let feature = OtsuSplit::new();
        let mut ts = light_curve_feature_test_util::issue_light_curve_mag::<f32, _>(
            "light-curve-feature-72/1.csv",
        )
        .into_triple(None)
        .into();
        let expected = [
            3.0221021243981205,
            0.8847146372743603,
            0.8826366394647659,
            0.507,
        ];
        let computed = feature.eval(&mut ts).unwrap();
        assert_relative_eq!(&expected[..], &computed[..], epsilon = 1e-6);
    }

    #[test]
    #[ignore] // This test takes a long time and requires lots of memory
    fn no_overflow() {
        // It should be large enough to trigger the overflow
        const N: usize = (1 << 25) + 57;
        let eval = OtsuSplit::new();
        let grid = Array1::linspace(0.0_f32, 1.0, N);
        let mut ts = TimeSeries::new_without_weight(grid.view(), grid.view());
        // This should not panic
        let [mean_diff, _std_lower, _std_upper, lower_to_all]: [f32; 4] =
            eval.eval(&mut ts).unwrap().try_into().unwrap();
        assert_relative_eq!(mean_diff, 0.5, epsilon = 1e-3);
        assert_relative_eq!(lower_to_all, 0.5, epsilon = 1e-6);
    }
}