// gbt_quantile/tree.rs
//! Tree model data structures: nodes, leaves, and the full GBT model.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

6/// A single decision tree node.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct TreeNode {
9    /// Index of the feature used for splitting.
10    pub feature_index: usize,
11    /// Threshold value: samples with `feature[feature_index] <= threshold` go left.
12    pub threshold: f64,
13    /// Left child (feature <= threshold).
14    pub left: NodeRef,
15    /// Right child (feature > threshold).
16    pub right: NodeRef,
17}
18
19/// A reference to a child node: either a leaf value or a subtree.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub enum NodeRef {
22    /// Terminal leaf with a prediction value.
23    Leaf(f64),
24    /// Internal node with further splits.
25    Node(Box<TreeNode>),
26}
27
28/// A gradient-boosted tree ensemble model.
29///
30/// Predictions are computed as: `base_score + learning_rate * sum(tree_predictions)`.
31/// Models are JSON-serializable with schema versioning for forward compatibility.
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct GradientBoostedTree {
34    /// Schema version for forward/backward compatibility.
35    #[serde(default = "default_schema_version")]
36    pub schema_version: u32,
37    /// The ensemble of decision trees.
38    pub trees: Vec<TreeNode>,
39    /// Initial prediction (bias / intercept).
40    pub base_score: f64,
41    /// Learning rate (shrinkage factor applied to each tree).
42    pub learning_rate: f64,
43    /// Feature names in order matching feature_index.
44    pub feature_names: Vec<String>,
45    /// Target quantile: `None` for L2/mean, `Some(q)` for quantile regression.
46    pub quantile: Option<f64>,
47    /// Output scale factor applied to the final prediction.
48    #[serde(default = "default_output_scale")]
49    pub output_scale: f64,
50    /// Optional metadata about the model.
51    pub metadata: Option<serde_json::Value>,
52}
53
/// Schema version assumed for serialized models that predate the field.
fn default_schema_version() -> u32 {
    1
}

/// Identity output scale used when a serialized model omits `output_scale`.
fn default_output_scale() -> f64 {
    1.0
}

61impl GradientBoostedTree {
62    /// Predict a single sample.
63    pub fn predict(&self, features: &[f64]) -> f64 {
64        let raw = self.trees.iter().fold(self.base_score, |acc, tree| {
65            acc + self.learning_rate * traverse_node(tree, features)
66        });
67        raw * self.output_scale
68    }
69
70    /// Predict a batch of samples.
71    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
72        feature_matrix.iter().map(|row| self.predict(row)).collect()
73    }
74
75    /// Number of trees in the ensemble.
76    pub fn n_trees(&self) -> usize {
77        self.trees.len()
78    }
79
80    /// Number of features the model expects.
81    pub fn n_features(&self) -> usize {
82        self.feature_names.len()
83    }
84
85    /// Compute feature importance based on split frequency.
86    /// Returns a map from feature index to normalized importance [0, 1].
87    pub fn feature_importance(&self) -> HashMap<usize, f64> {
88        let n = self.feature_names.len();
89        let mut counts = vec![0.0_f64; n];
90        for tree in &self.trees {
91            count_splits(tree, &mut counts);
92        }
93        let total: f64 = counts.iter().sum();
94        let mut importance = HashMap::new();
95        for (i, &count) in counts.iter().enumerate() {
96            let imp = if total > 0.0 {
97                count / total
98            } else {
99                1.0 / n as f64
100            };
101            importance.insert(i, imp);
102        }
103        importance
104    }
105
106    /// Compute named feature importance.
107    pub fn feature_importance_named(&self) -> HashMap<String, f64> {
108        let indexed = self.feature_importance();
109        indexed
110            .into_iter()
111            .map(|(i, v)| {
112                let name = self
113                    .feature_names
114                    .get(i)
115                    .cloned()
116                    .unwrap_or_else(|| format!("f{i}"));
117                (name, v)
118            })
119            .collect()
120    }
121
122    /// Serialize the model to JSON string.
123    pub fn to_json(&self) -> anyhow::Result<String> {
124        Ok(serde_json::to_string_pretty(self)?)
125    }
126
127    /// Deserialize a model from JSON string.
128    pub fn from_json(json: &str) -> anyhow::Result<Self> {
129        Ok(serde_json::from_str(json)?)
130    }
131
132    /// Serialize the model to bytes (UTF-8 JSON).
133    pub fn to_bytes(&self) -> anyhow::Result<Vec<u8>> {
134        Ok(serde_json::to_vec(self)?)
135    }
136
137    /// Deserialize a model from bytes (UTF-8 JSON).
138    pub fn from_bytes(data: &[u8]) -> anyhow::Result<Self> {
139        Ok(serde_json::from_slice(data)?)
140    }
141}
142
143/// Traverse a tree node to get the leaf prediction for a feature vector.
144pub(crate) fn traverse_node(node: &TreeNode, features: &[f64]) -> f64 {
145    let val = features.get(node.feature_index).copied().unwrap_or(0.0);
146    let child = if val <= node.threshold {
147        &node.left
148    } else {
149        &node.right
150    };
151    match child {
152        NodeRef::Leaf(v) => *v,
153        NodeRef::Node(next) => traverse_node(next, features),
154    }
155}
156
157fn count_splits(node: &TreeNode, counts: &mut [f64]) {
158    if node.feature_index < counts.len() {
159        counts[node.feature_index] += 1.0;
160    }
161    if let NodeRef::Node(ref child) = node.left {
162        count_splits(child, counts);
163    }
164    if let NodeRef::Node(ref child) = node.right {
165        count_splits(child, counts);
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// One-tree model: x <= 5 predicts 1.0, otherwise 10.0.
    fn simple_model() -> GradientBoostedTree {
        GradientBoostedTree {
            schema_version: 1,
            trees: vec![TreeNode {
                feature_index: 0,
                threshold: 5.0,
                left: NodeRef::Leaf(1.0),
                right: NodeRef::Leaf(10.0),
            }],
            base_score: 0.0,
            learning_rate: 1.0,
            feature_names: vec!["x".to_string()],
            quantile: None,
            output_scale: 1.0,
            metadata: None,
        }
    }

    #[test]
    fn test_predict_left_right() {
        let model = simple_model();
        // Left branch, right branch, and the boundary (<= goes left).
        for (input, expected) in [(3.0, 1.0), (7.0, 10.0), (5.0, 1.0)] {
            assert_eq!(model.predict(&[input]), expected);
        }
    }

    #[test]
    fn test_predict_batch() {
        let model = simple_model();
        assert_eq!(model.predict_batch(&[vec![3.0], vec![7.0]]), vec![1.0, 10.0]);
    }

    #[test]
    fn test_json_roundtrip() {
        let model = simple_model();
        let json = model.to_json().unwrap();
        let restored = GradientBoostedTree::from_json(&json).unwrap();
        assert_eq!(restored.predict(&[3.0]), model.predict(&[3.0]));
        assert_eq!(restored.n_trees(), 1);
    }

    #[test]
    fn test_bytes_roundtrip() {
        let model = simple_model();
        let bytes = model.to_bytes().unwrap();
        let restored = GradientBoostedTree::from_bytes(&bytes).unwrap();
        assert_eq!(restored.predict(&[7.0]), model.predict(&[7.0]));
    }

    #[test]
    fn test_feature_importance() {
        // The single tree splits only on feature 0, so it gets all the weight.
        assert_eq!(simple_model().feature_importance()[&0], 1.0);
    }
}