xgboost_rs/parameters/
linear.rs

1//! `BoosterParameters` for configuring linear boosters.
2
3use std::default::Default;
4
5use derive_builder::Builder;
6
7#[cfg(feature = "use_serde")]
8use serde::{Deserialize, Serialize};
9
/// Linear model algorithm, selected through the linear booster's `updater` parameter.
///
/// This is a plain fieldless enum, so it is cheap to copy and can be compared, hashed and
/// debug-printed.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "use_serde", derive(Deserialize, Serialize))]
pub enum LinearUpdate {
    /// Parallel coordinate descent algorithm based on the shotgun algorithm. Uses 'hogwild'
    /// parallelism and therefore produces a nondeterministic solution on each run.
    ///
    /// This is the default variant.
    #[default]
    Shotgun,

    /// Ordinary coordinate descent algorithm. Also multithreaded, but still produces a
    /// deterministic solution.
    CoordDescent,
}
22
23impl ToString for LinearUpdate {
24    fn to_string(&self) -> String {
25        match *self {
26            LinearUpdate::Shotgun => "shotgun".to_owned(),
27            LinearUpdate::CoordDescent => "coord_descent".to_owned(),
28        }
29    }
30}
31
32/// `BoosterParameters` for Linear Booster.
33#[derive(Builder, Clone)]
34#[cfg_attr(feature = "use_serde", derive(Deserialize, Serialize))]
35#[builder(default)]
36pub struct LinearBoosterParameters {
37    /// L2 regularization term on weights, increase this value will make model more conservative.
38    /// Normalised to number of training examples.
39    ///
40    /// * default: 0.0
41    lambda: f32,
42
43    /// L1 egularization term on weights, increase this value will make model more conservative.
44    /// Normalised to number of training examples.
45    ///
46    /// * default: 0.0
47    alpha: f32,
48
49    /// Linear model algorithm.
50    ///
51    /// * default: `LinearUpdate::Shotgun`
52    updater: LinearUpdate,
53}
54
55impl LinearBoosterParameters {
56    pub(crate) fn as_string_pairs(&self) -> Vec<(String, String)> {
57        let v = vec![
58            ("booster".to_owned(), "gblinear".to_owned()),
59            ("lambda".to_owned(), self.lambda.to_string()),
60            ("alpha".to_owned(), self.alpha.to_string()),
61            ("updater".to_owned(), self.updater.to_string()),
62        ];
63
64        v
65    }
66}
67
68impl Default for LinearBoosterParameters {
69    fn default() -> Self {
70        LinearBoosterParameters {
71            lambda: 0.0,
72            alpha: 0.0,
73            updater: LinearUpdate::default(),
74        }
75    }
76}