1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct TreeNode {
9 pub feature_index: usize,
11 pub threshold: f64,
13 pub left: NodeRef,
15 pub right: NodeRef,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub enum NodeRef {
22 Leaf(f64),
24 Node(Box<TreeNode>),
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct GradientBoostedTree {
34 #[serde(default = "default_schema_version")]
36 pub schema_version: u32,
37 pub trees: Vec<TreeNode>,
39 pub base_score: f64,
41 pub learning_rate: f64,
43 pub feature_names: Vec<String>,
45 pub quantile: Option<f64>,
47 #[serde(default = "default_output_scale")]
49 pub output_scale: f64,
50 pub metadata: Option<serde_json::Value>,
52}
53
/// Serde fallback for `schema_version` when a serialized model omits it.
fn default_schema_version() -> u32 {
    const FALLBACK_SCHEMA_VERSION: u32 = 1;
    FALLBACK_SCHEMA_VERSION
}
/// Serde fallback for `output_scale` when a serialized model omits it.
///
/// A scale of one leaves the accumulated score unchanged.
fn default_output_scale() -> f64 {
    const NEUTRAL_SCALE: f64 = 1.0;
    NEUTRAL_SCALE
}
60
61impl GradientBoostedTree {
62 pub fn predict(&self, features: &[f64]) -> f64 {
64 let raw = self.trees.iter().fold(self.base_score, |acc, tree| {
65 acc + self.learning_rate * traverse_node(tree, features)
66 });
67 raw * self.output_scale
68 }
69
70 pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
72 feature_matrix.iter().map(|row| self.predict(row)).collect()
73 }
74
75 pub fn n_trees(&self) -> usize {
77 self.trees.len()
78 }
79
80 pub fn n_features(&self) -> usize {
82 self.feature_names.len()
83 }
84
85 pub fn feature_importance(&self) -> HashMap<usize, f64> {
88 let n = self.feature_names.len();
89 let mut counts = vec![0.0_f64; n];
90 for tree in &self.trees {
91 count_splits(tree, &mut counts);
92 }
93 let total: f64 = counts.iter().sum();
94 let mut importance = HashMap::new();
95 for (i, &count) in counts.iter().enumerate() {
96 let imp = if total > 0.0 {
97 count / total
98 } else {
99 1.0 / n as f64
100 };
101 importance.insert(i, imp);
102 }
103 importance
104 }
105
106 pub fn feature_importance_named(&self) -> HashMap<String, f64> {
108 let indexed = self.feature_importance();
109 indexed
110 .into_iter()
111 .map(|(i, v)| {
112 let name = self
113 .feature_names
114 .get(i)
115 .cloned()
116 .unwrap_or_else(|| format!("f{i}"));
117 (name, v)
118 })
119 .collect()
120 }
121
122 pub fn to_json(&self) -> anyhow::Result<String> {
124 Ok(serde_json::to_string_pretty(self)?)
125 }
126
127 pub fn from_json(json: &str) -> anyhow::Result<Self> {
129 Ok(serde_json::from_str(json)?)
130 }
131
132 pub fn to_bytes(&self) -> anyhow::Result<Vec<u8>> {
134 Ok(serde_json::to_vec(self)?)
135 }
136
137 pub fn from_bytes(data: &[u8]) -> anyhow::Result<Self> {
139 Ok(serde_json::from_slice(data)?)
140 }
141}
142
143pub(crate) fn traverse_node(node: &TreeNode, features: &[f64]) -> f64 {
145 let val = features.get(node.feature_index).copied().unwrap_or(0.0);
146 let child = if val <= node.threshold {
147 &node.left
148 } else {
149 &node.right
150 };
151 match child {
152 NodeRef::Leaf(v) => *v,
153 NodeRef::Node(next) => traverse_node(next, features),
154 }
155}
156
157fn count_splits(node: &TreeNode, counts: &mut [f64]) {
158 if node.feature_index < counts.len() {
159 counts[node.feature_index] += 1.0;
160 }
161 if let NodeRef::Node(ref child) = node.left {
162 count_splits(child, counts);
163 }
164 if let NodeRef::Node(ref child) = node.right {
165 count_splits(child, counts);
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a one-tree model: feature 0 split at 5.0 with leaves 1.0 (left)
    /// and 10.0 (right), zero base score, unit learning rate and output scale.
    fn simple_model() -> GradientBoostedTree {
        GradientBoostedTree {
            schema_version: 1,
            trees: vec![TreeNode {
                feature_index: 0,
                threshold: 5.0,
                left: NodeRef::Leaf(1.0),
                right: NodeRef::Leaf(10.0),
            }],
            base_score: 0.0,
            learning_rate: 1.0,
            feature_names: vec!["x".to_string()],
            quantile: None,
            output_scale: 1.0,
            metadata: None,
        }
    }

    #[test]
    fn test_predict_left_right() {
        let model = simple_model();
        // (input, expected): below, above, and exactly on the threshold.
        let cases = [(3.0, 1.0), (7.0, 10.0), (5.0, 1.0)];
        for (input, expected) in cases {
            assert_eq!(model.predict(&[input]), expected);
        }
    }

    #[test]
    fn test_predict_batch() {
        let rows = [vec![3.0], vec![7.0]];
        let preds = simple_model().predict_batch(&rows);
        assert_eq!(preds, vec![1.0, 10.0]);
    }

    #[test]
    fn test_json_roundtrip() {
        let original = simple_model();
        let text = original.to_json().unwrap();
        let restored = GradientBoostedTree::from_json(&text).unwrap();
        assert_eq!(restored.predict(&[3.0]), original.predict(&[3.0]));
        assert_eq!(restored.n_trees(), 1);
    }

    #[test]
    fn test_bytes_roundtrip() {
        let original = simple_model();
        let encoded = original.to_bytes().unwrap();
        let restored = GradientBoostedTree::from_bytes(&encoded).unwrap();
        assert_eq!(restored.predict(&[7.0]), original.predict(&[7.0]));
    }

    #[test]
    fn test_feature_importance() {
        let importance = simple_model().feature_importance();
        // The only split tests feature 0, so it owns the whole share.
        assert_eq!(importance.get(&0).copied(), Some(1.0));
    }
}