1use rand::rngs::StdRng;
2use rand::Rng;
3use serde::{Deserialize, Serialize};
4
5#[derive(Serialize, Deserialize)]
6pub enum SampleMethod {
7 None,
8 Random,
9 Goss,
10}
11
12pub trait Sampler {
14 fn sample(
17 &mut self,
18 rng: &mut StdRng,
19 index: &[usize],
20 grad: &mut [f32],
21 hess: &mut [f32],
22 ) -> (Vec<usize>, Vec<usize>);
23}
24
/// Sampler that keeps each candidate independently of the others, based on a
/// single fixed threshold.
pub struct RandomSampler {
    /// Threshold that a uniform draw from `[0, 1)` must fall below for a
    /// candidate to be kept.
    subsample: f32,
}

impl RandomSampler {
    /// Builds a sampler that uses `subsample` as its keep threshold.
    #[allow(dead_code)]
    pub fn new(subsample: f32) -> Self {
        Self { subsample }
    }
}
35
36impl Sampler for RandomSampler {
37 fn sample(
38 &mut self,
39 rng: &mut StdRng,
40 index: &[usize],
41 _grad: &mut [f32],
42 _hess: &mut [f32],
43 ) -> (Vec<usize>, Vec<usize>) {
44 let subsample = self.subsample;
45 let mut chosen = Vec::new();
46 let mut excluded = Vec::new();
47 for i in index {
48 if rng.gen_range(0.0..1.0) < subsample {
49 chosen.push(*i);
50 } else {
51 excluded.push(*i)
52 }
53 }
54 (chosen, excluded)
55 }
56}
57
/// Parameters for gradient-based one-side sampling (GOSS).
#[allow(dead_code)]
pub struct GossSampler {
    /// Fraction of the input that is always kept, taken from the entries with
    /// the largest absolute gradient.
    a: f64,
    /// Fraction of the input targeted when randomly sampling from the
    /// remaining entries.
    b: f64,
}

impl Default for GossSampler {
    /// Returns a sampler with `a = 0.2` and `b = 0.1`.
    fn default() -> Self {
        Self::new(0.2, 0.1)
    }
}

#[allow(dead_code)]
impl GossSampler {
    /// Builds a sampler from the always-keep fraction `a` and the
    /// random-sample fraction `b`.
    pub fn new(a: f64, b: f64) -> Self {
        Self { a, b }
    }
}
76
77impl Sampler for GossSampler {
78 fn sample(
79 &mut self,
80 rng: &mut StdRng,
81 index: &[usize],
82 grad: &mut [f32],
83 hess: &mut [f32],
84 ) -> (Vec<usize>, Vec<usize>) {
85 let fact = ((1. - self.a) / self.b) as f32;
86 let top_n = (self.a * index.len() as f64) as usize;
87 let rand_n = (self.b * index.len() as f64) as usize;
88
89 let mut sorted = (0..index.len()).collect::<Vec<_>>();
91 sorted.sort_unstable_by(|&a, &b| grad[b].abs().total_cmp(&grad[a].abs()));
92
93 let mut used_set = sorted[0..top_n].to_vec();
95
96 let subsample = rand_n as f64 / (index.len() as f64 - top_n as f64);
98
99 for i in &sorted[top_n..sorted.len()] {
101 if rng.gen_range(0.0..1.0) < subsample {
102 grad[*i] *= fact;
103 hess[*i] *= fact;
104 used_set.push(*i);
105 }
106 }
107
108 (used_set, Vec::new())
109 }
110}