forust_ml/
sampler.rs

1use rand::rngs::StdRng;
2use rand::Rng;
3use serde::{Deserialize, Serialize};
4
5#[derive(Serialize, Deserialize)]
6pub enum SampleMethod {
7    None,
8    Random,
9    Goss,
10}
11
12// A sampler can be used to subset the data prior to fitting a new tree.
13pub trait Sampler {
14    /// Sample the data, returning a tuple, where the first item is the samples
15    /// chosen for training, and the second are the samples excluded.
16    fn sample(
17        &mut self,
18        rng: &mut StdRng,
19        index: &[usize],
20        grad: &mut [f32],
21        hess: &mut [f32],
22    ) -> (Vec<usize>, Vec<usize>);
23}
24
/// Sampler configured with a single subsampling rate.
pub struct RandomSampler {
    /// Rate that a uniform draw from `[0, 1)` is compared against when
    /// deciding whether a record is kept.
    subsample: f32,
}

impl RandomSampler {
    /// Creates a sampler that uses `subsample` as its keep rate.
    ///
    /// The value is stored as given; no range validation is performed.
    #[allow(dead_code)]
    pub fn new(subsample: f32) -> Self {
        Self { subsample }
    }
}
35
36impl Sampler for RandomSampler {
37    fn sample(
38        &mut self,
39        rng: &mut StdRng,
40        index: &[usize],
41        _grad: &mut [f32],
42        _hess: &mut [f32],
43    ) -> (Vec<usize>, Vec<usize>) {
44        let subsample = self.subsample;
45        let mut chosen = Vec::new();
46        let mut excluded = Vec::new();
47        for i in index {
48            if rng.gen_range(0.0..1.0) < subsample {
49                chosen.push(*i);
50            } else {
51                excluded.push(*i)
52            }
53        }
54        (chosen, excluded)
55    }
56}
57
/// Parameters for gradient-based one-side sampling (GOSS).
#[allow(dead_code)]
pub struct GossSampler {
    /// Top rate, see
    /// https://lightgbm.readthedocs.io/en/latest/Parameters.html#top_rate
    a: f64,
    /// Other rate, see
    /// https://lightgbm.readthedocs.io/en/latest/Parameters.html#other_rate
    b: f64,
}

impl Default for GossSampler {
    /// Uses a top rate of `0.2` and an other rate of `0.1`.
    fn default() -> Self {
        Self::new(0.2, 0.1)
    }
}

#[allow(dead_code)]
impl GossSampler {
    /// Creates a sampler with top rate `a` and other rate `b`.
    ///
    /// Both values are stored as given; no validation is performed.
    pub fn new(a: f64, b: f64) -> Self {
        Self { a, b }
    }
}
76
77impl Sampler for GossSampler {
78    fn sample(
79        &mut self,
80        rng: &mut StdRng,
81        index: &[usize],
82        grad: &mut [f32],
83        hess: &mut [f32],
84    ) -> (Vec<usize>, Vec<usize>) {
85        let fact = ((1. - self.a) / self.b) as f32;
86        let top_n = (self.a * index.len() as f64) as usize;
87        let rand_n = (self.b * index.len() as f64) as usize;
88
89        // sort gradient by absolute value from highest to lowest
90        let mut sorted = (0..index.len()).collect::<Vec<_>>();
91        sorted.sort_unstable_by(|&a, &b| grad[b].abs().total_cmp(&grad[a].abs()));
92
93        // select the topN largest gradients
94        let mut used_set = sorted[0..top_n].to_vec();
95
96        // sample the rest based on randN
97        let subsample = rand_n as f64 / (index.len() as f64 - top_n as f64);
98
99        // weight the sampled "small gradients" by fact and append indices to used_set
100        for i in &sorted[top_n..sorted.len()] {
101            if rng.gen_range(0.0..1.0) < subsample {
102                grad[*i] *= fact;
103                hess[*i] *= fact;
104                used_set.push(*i);
105            }
106        }
107
108        (used_set, Vec::new())
109    }
110}