1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
/// Hyperparameters for GBDT training. Names match LightGBM conventions.
///
/// Construct with `Config { num_iterations: 100, ..Config::default() }`.
/// Call [`Config::validate`] (the trainer does so automatically at the top of
/// `fit`) to surface out-of-range values as [`crate::Error::Config`] instead of
/// silently misbehaving.
///
/// All fields are public. The "Default" noted on each field is the value set
/// by the `Default` impl for this type. The struct derives `Serialize` and
/// `Deserialize`, so it can be persisted through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
/// Maximum number of boosting rounds (trees) to fit. With
/// `early_stopping_round > 0` and a validation set, training may stop
/// sooner and the model is truncated to the best iteration.
/// Default: `100`.
pub num_iterations: usize,
/// Shrinkage applied to each tree's leaf values when added to the running
/// score: `score += learning_rate * leaf_value`. Smaller values need more
/// iterations but generalize better. Must be in `(0, 1]`.
/// Default: `0.1`.
pub learning_rate: f64,
/// Maximum number of leaves per tree. Together with `max_depth`, controls
/// tree complexity. Must be `>= 2`.
/// Default: `31`.
pub num_leaves: usize,
/// Maximum tree depth. `-1` means unlimited (let `num_leaves` cap things).
/// Use a positive cap to discourage very deep, narrow trees on noisy data.
/// Default: `-1`.
///
/// NOTE(review): the meaning of `0` and of negative values other than `-1`
/// is not specified here -- confirm how the trainer treats them.
pub max_depth: i32,
/// Minimum number of samples in a leaf. Splits that would produce a leaf
/// with fewer rows are rejected. Higher = stronger regularization.
/// Default: `20`.
pub min_data_in_leaf: usize,
/// Minimum sum of hessians in a leaf. Cheaper proxy for "this leaf has
/// enough signal to be worth splitting"; complements `min_data_in_leaf`.
/// Default: `1e-3`.
pub min_sum_hessian_in_leaf: f64,
/// L1 regularization on leaf values (soft-threshold in the gain formula).
/// Default: `0.0` (no L1 penalty).
pub lambda_l1: f64,
/// L2 regularization on leaf values (denominator term in the gain formula).
/// Default: `0.0` (no L2 penalty).
pub lambda_l2: f64,
/// Minimum loss reduction required to accept a split. Defaults to `0.0`;
/// raise it to prune trivially-improving splits on noisy datasets.
pub min_gain_to_split: f64,
/// Maximum bins per numerical feature (includes one slot for the MISSING
/// bin, so at most `max_bin - 1` real bins). Must be in `[2, 65535]`.
/// Default: `255`.
pub max_bin: usize,
/// Minimum samples per histogram bin when fitting numerical bin mappers.
/// Bins with fewer samples get merged into neighbors.
/// Default: `3`.
pub min_data_in_bin: usize,
/// Per-tree feature subsample fraction in `(0, 1]`. `1.0` disables.
/// Default: `1.0`.
pub feature_fraction: f64,
/// Row subsample fraction in `(0, 1]`. `1.0` disables. Only takes effect
/// when `bagging_freq > 0`.
/// Default: `1.0`.
pub bagging_fraction: f64,
/// Re-sample bagged rows every `K` iterations. `0` disables bagging.
/// Default: `0`.
pub bagging_freq: usize,
/// Stop after this many iterations without validation-metric improvement.
/// Only active when a validation set is passed to `fit`. `0` disables.
/// On stop, the model is truncated to `best_iter + 1` trees.
/// Default: `0`.
pub early_stopping_round: usize,
/// Seed for the `ChaCha8Rng` used for bagging and feature subsampling.
/// Same config + same data => byte-identical model.
/// Default: `0`.
pub seed: u64,
/// If `true`, log per-iteration validation/training metric and an
/// end-of-fit timing breakdown to stderr.
/// Default: `false`.
pub verbose: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
num_iterations: 100,
learning_rate: 0.1,
num_leaves: 31,
max_depth: -1,
min_data_in_leaf: 20,
min_sum_hessian_in_leaf: 1e-3,
lambda_l1: 0.0,
lambda_l2: 0.0,
min_gain_to_split: 0.0,
max_bin: 255,
min_data_in_bin: 3,
feature_fraction: 1.0,
bagging_fraction: 1.0,
bagging_freq: 0,
early_stopping_round: 0,
seed: 0,
verbose: false,
}
}
}
impl Config {
pub fn validate(&self) -> Result<()> {
if self.num_leaves < 2 {
return Err(Error::Config("num_leaves must be >= 2".into()));
}
if self.max_bin < 2 || self.max_bin > 65535 {
return Err(Error::Config("max_bin must be in [2, 65535]".into()));
}
if !(0.0 < self.learning_rate && self.learning_rate <= 1.0) {
return Err(Error::Config("learning_rate must be in (0, 1]".into()));
}
if !(0.0 < self.feature_fraction && self.feature_fraction <= 1.0) {
return Err(Error::Config("feature_fraction must be in (0, 1]".into()));
}
if !(0.0 < self.bagging_fraction && self.bagging_fraction <= 1.0) {
return Err(Error::Config("bagging_fraction must be in (0, 1]".into()));
}
Ok(())
}
}