1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
use crate::errors::ForustError;
use crate::utils::items_to_strings;
use rand::rngs::StdRng;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::str::FromStr;

/// Strategy used to subsample rows before fitting a new tree.
///
/// Can be parsed from a string via [`FromStr`]; only `"random"` and `"goss"`
/// are accepted there, so [`SampleMethod::None`] must be constructed directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SampleMethod {
    /// No sampling strategy selected (presumably: train on every row).
    None,
    /// Random row subsampling; see [`RandomSampler`].
    Random,
    /// Gradient-based sampling; see [`GossSampler`].
    Goss,
}

impl FromStr for SampleMethod {
    type Err = ForustError;

    /// Parses a sampling strategy from its exact, case-sensitive name:
    /// `"random"` or `"goss"`.
    ///
    /// # Errors
    ///
    /// Any other input (including `"none"`) yields
    /// [`ForustError::ParseString`] carrying the rejected string, the type
    /// name `"SampleMethod"`, and the list of accepted names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let method = match s {
            "random" => SampleMethod::Random,
            "goss" => SampleMethod::Goss,
            unknown => {
                return Err(ForustError::ParseString(
                    unknown.to_string(),
                    "SampleMethod".to_string(),
                    items_to_strings(vec!["random", "goss"]),
                ));
            }
        };
        Ok(method)
    }
}

/// A sampler can be used to subset the data prior to fitting a new tree.
pub trait Sampler {
    /// Splits `index` into the samples used for training and those left out.
    ///
    /// Returns a tuple `(chosen, excluded)`: the first vector holds the
    /// samples chosen for training, the second holds the samples excluded.
    /// `rng` is supplied for implementations that need randomness.
    fn sample(&mut self, rng: &mut StdRng, index: &[usize]) -> (Vec<usize>, Vec<usize>);
}

/// Sampler that keeps each row independently with probability `subsample`.
pub struct RandomSampler {
    /// Per-row selection probability. Values >= 1 keep every row and
    /// values <= 0 keep none; no range check is applied.
    subsample: f32,
}

impl RandomSampler {
    /// Builds a sampler with the given per-row selection probability.
    ///
    /// The value is stored unchanged (no validation).
    #[allow(dead_code)] // NOTE(review): silences the unused-constructor lint; confirm still needed
    pub fn new(subsample: f32) -> Self {
        Self { subsample }
    }
}

impl Sampler for RandomSampler {
    /// Draws one value from `0.0..1.0` per entry of `index` and keeps the
    /// entry when the draw falls below `subsample`.
    ///
    /// Both returned vectors preserve the relative order of `index`.
    fn sample(&mut self, rng: &mut StdRng, index: &[usize]) -> (Vec<usize>, Vec<usize>) {
        let threshold = self.subsample;
        // `partition` visits items sequentially, so the RNG advances exactly
        // once per index, in the same order as an explicit loop would.
        index
            .iter()
            .copied()
            .partition(|_| rng.gen_range(0.0..1.0) < threshold)
    }
}

/// Sampler intended for GOSS-style (presumably gradient-based one-side)
/// sampling.
///
/// NOTE(review): at present this only stores an optional gradient slice.
#[allow(dead_code)]
#[derive(Default)]
pub struct GossSampler<'a> {
    /// Borrowed gradient values; `None` until [`GossSampler::add_gradient`]
    /// is called.
    gradient: Option<&'a [f64]>,
}

#[allow(dead_code)]
impl<'a> GossSampler<'a> {
    /// Creates a sampler with no gradient attached.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `gradient`, replacing any previously attached slice.
    pub fn add_gradient(&mut self, gradient: &'a [f64]) {
        self.gradient = Some(gradient);
    }
}

impl Sampler for GossSampler<'_> {
    /// Not implemented yet.
    ///
    /// # Panics
    ///
    /// Always panics via `todo!()`.
    fn sample(&mut self, _rng: &mut StdRng, _index: &[usize]) -> (Vec<usize>, Vec<usize>) {
        todo!()
    }
}