1use ghostflow_core::Tensor;
10use rand::prelude::*;
11use rayon::prelude::*;
12use std::collections::HashMap;
13
14const MAX_BINS: usize = 255;
15
/// Gradient-boosted decision-tree binary classifier in the style of LightGBM:
/// histogram-based split finding plus leaf-wise (best-first) tree growth.
///
/// Trained with `fit` on a logistic loss; inference via `predict` (hard
/// labels) or `predict_proba` (probabilities).
pub struct LightGBMClassifier {
    /// Number of boosting rounds (one tree per round).
    pub n_estimators: usize,
    /// Shrinkage factor applied to every tree's output.
    pub learning_rate: f32,
    /// Size cap used while growing a tree leaf-wise.
    /// NOTE(review): the growth loop compares this against the total node
    /// count rather than the leaf count — confirm intended semantics.
    pub num_leaves: usize,
    /// Maximum tree depth; values <= 0 disable the depth limit.
    pub max_depth: i32,
    /// Minimum number of samples a leaf must hold to be split further.
    pub min_data_in_leaf: usize,
    /// Minimum hessian sum required on each side of a candidate split.
    pub min_sum_hessian_in_leaf: f32,
    /// Fraction of features sampled per boosting round (1.0 = all).
    pub feature_fraction: f32,
    /// Fraction of rows sampled on bagging rounds (1.0 = all).
    pub bagging_fraction: f32,
    /// Bagging is applied on rounds where `iter % bagging_freq == 0`;
    /// 0 disables bagging entirely.
    pub bagging_freq: usize,
    /// L1 regularization strength (used in the split-gain penalty).
    pub lambda_l1: f32,
    /// L2 regularization strength (leaf values and split gain).
    pub lambda_l2: f32,
    /// Maximum number of histogram bins per feature.
    pub max_bin: usize,
    // Fitted ensemble, appended to by `fit`.
    trees: Vec<LGBMTree>,
    // Per-feature sorted bin boundary values, built by `build_histograms`.
    feature_bins: Vec<Vec<f32>>,
    // Initial raw (pre-sigmoid) score every prediction starts from.
    base_score: f32,
}
34
/// A single fitted tree: a flat arena of nodes, index 0 being the root.
#[derive(Clone)]
struct LGBMTree {
    nodes: Vec<LGBMNode>,
}
39
/// One tree node stored in the `LGBMTree` arena.
///
/// For internal nodes (`is_leaf == false`) the split sends samples whose bin
/// index is <= `threshold_bin` to `left`, otherwise to `right`. For leaves
/// only `value` is meaningful (the other fields stay at their defaults).
#[derive(Clone)]
struct LGBMNode {
    is_leaf: bool,
    feature: usize,
    threshold_bin: usize,
    left: usize,
    right: usize,
    value: f32,
    split_gain: f32,
}
50
51impl LightGBMClassifier {
52 pub fn new(n_estimators: usize) -> Self {
53 Self {
54 n_estimators,
55 learning_rate: 0.1,
56 num_leaves: 31,
57 max_depth: -1, min_data_in_leaf: 20,
59 min_sum_hessian_in_leaf: 1e-3,
60 feature_fraction: 1.0,
61 bagging_fraction: 1.0,
62 bagging_freq: 0,
63 lambda_l1: 0.0,
64 lambda_l2: 0.0,
65 max_bin: MAX_BINS,
66 trees: Vec::new(),
67 feature_bins: Vec::new(),
68 base_score: 0.5,
69 }
70 }
71
72 pub fn learning_rate(mut self, lr: f32) -> Self {
73 self.learning_rate = lr;
74 self
75 }
76
77 pub fn num_leaves(mut self, leaves: usize) -> Self {
78 self.num_leaves = leaves;
79 self
80 }
81
82 pub fn max_depth(mut self, depth: i32) -> Self {
83 self.max_depth = depth;
84 self
85 }
86
87 pub fn feature_fraction(mut self, fraction: f32) -> Self {
88 self.feature_fraction = fraction;
89 self
90 }
91
    /// Builds the per-feature bin boundaries used by histogram split finding.
    ///
    /// For every feature (processed in parallel via rayon) the distinct
    /// sorted values are collected; if there are more than `max_bin` of
    /// them, an evenly strided subsample of `max_bin` values is kept.
    ///
    /// Panics if any feature value is NaN (`partial_cmp(..).unwrap()`).
    fn build_histograms(&mut self, x: &Tensor) {
        // Assumes `x` is row-major (n_samples, n_features) — layout implied
        // by the `i * n_features + feature` indexing below.
        let n_samples = x.dims()[0];
        let n_features = x.dims()[1];
        let x_data = x.data_f32();

        self.feature_bins = (0..n_features)
            .into_par_iter()
            .map(|feature| {
                // Gather this feature's column.
                let mut values: Vec<f32> = (0..n_samples)
                    .map(|i| x_data[i * n_features + feature])
                    .collect();

                // Sort then drop duplicates so each bin boundary is unique.
                values.sort_by(|a, b| a.partial_cmp(b).unwrap());
                values.dedup();

                if values.len() <= self.max_bin {
                    // Few enough distinct values: every value is a bin.
                    values
                } else {
                    // Subsample every `step`-th value as an approximate
                    // quantile boundary (step >= 1 here since len > max_bin).
                    let step = values.len() / self.max_bin;
                    (0..self.max_bin)
                        .map(|i| values[i * step])
                        .collect()
                }
            })
            .collect();
    }
122
123 fn value_to_bin(&self, feature: usize, value: f32) -> usize {
125 let bins = &self.feature_bins[feature];
126
127 match bins.binary_search_by(|&bin| bin.partial_cmp(&value).unwrap()) {
129 Ok(idx) => idx,
130 Err(idx) => idx.saturating_sub(1),
131 }
132 }
133
    /// Trains the ensemble with gradient boosting on the logistic loss.
    ///
    /// `x` is a (n_samples, n_features) tensor, `y` holds 0.0/1.0 labels.
    /// Each round computes per-sample gradients/hessians from the current
    /// raw scores, optionally subsamples rows (bagging) and columns
    /// (feature fraction), fits one tree leaf-wise, and adds its shrunken
    /// output to the running raw scores. Trees accumulate in `self.trees`,
    /// so calling `fit` twice keeps extending the same ensemble.
    pub fn fit(&mut self, x: &Tensor, y: &Tensor) {
        let n_samples = x.dims()[0];
        let n_features = x.dims()[1];
        let x_data = x.data_f32();
        let y_data = y.data_f32();

        // Precompute per-feature bin boundaries for histogram splits.
        self.build_histograms(x);

        // Raw (pre-sigmoid) score per sample.
        // NOTE(review): `base_score` (0.5) is used directly as a raw margin,
        // i.e. an initial probability of sigmoid(0.5) ~ 0.62, not 0.5 —
        // confirm this is intended.
        let mut predictions = vec![self.base_score; n_samples];

        for iter in 0..self.n_estimators {
            let mut gradients = Vec::with_capacity(n_samples);
            let mut hessians = Vec::with_capacity(n_samples);

            // Logistic-loss derivatives: grad = p - y, hess = p(1 - p),
            // with the hessian floored at 1e-6 to keep leaf values finite.
            for i in 0..n_samples {
                let pred = 1.0 / (1.0 + (-predictions[i]).exp());
                let grad = pred - y_data[i];
                let hess = pred * (1.0 - pred);
                gradients.push(grad);
                hessians.push(hess.max(1e-6));
            }

            // Row bagging: only on rounds that are multiples of bagging_freq.
            let mut rng = thread_rng();
            let sample_indices: Vec<usize> = if self.bagging_freq > 0 && iter % self.bagging_freq == 0 {
                let n_subsample = (n_samples as f32 * self.bagging_fraction) as usize;
                (0..n_samples).choose_multiple(&mut rng, n_subsample)
            } else {
                (0..n_samples).collect()
            };

            // Column sampling: a fresh random feature subset every round.
            let feature_indices: Vec<usize> = if self.feature_fraction < 1.0 {
                let n_features_sample = (n_features as f32 * self.feature_fraction) as usize;
                (0..n_features).choose_multiple(&mut rng, n_features_sample)
            } else {
                (0..n_features).collect()
            };

            let tree = self.build_tree_leafwise(
                &x_data,
                &gradients,
                &hessians,
                &sample_indices,
                &feature_indices,
                n_features,
            );

            // Update raw scores for ALL samples (including ones bagged out
            // of this round's tree) with the shrunken tree output.
            for i in 0..n_samples {
                let leaf_value = self.predict_tree(&tree, &x_data[i * n_features..(i + 1) * n_features]);
                predictions[i] += self.learning_rate * leaf_value;
            }

            self.trees.push(tree);
        }
    }
196
    /// Grows one regression tree leaf-wise (best-first), LightGBM style.
    ///
    /// Starting from a single root leaf, repeatedly picks the splittable
    /// leaf with the highest gain across the whole frontier and splits it,
    /// until no positive-gain split exists or the size cap is reached.
    /// Leaf values are the Newton step -sum_grad / (sum_hess + lambda_l2).
    ///
    /// NOTE(review): the loop guard compares `nodes.len()` (internal +
    /// leaf nodes) against `num_leaves`; the leaf count of the tree is
    /// (nodes + 1) / 2, so growth stops earlier than a true leaf cap —
    /// confirm intended.
    fn build_tree_leafwise(
        &self,
        x_data: &[f32],
        gradients: &[f32],
        hessians: &[f32],
        sample_indices: &[usize],
        feature_indices: &[usize],
        n_features: usize,
    ) -> LGBMTree {
        let mut nodes = Vec::new();
        // Frontier of splittable leaves: (samples in leaf, node index, depth).
        let mut leaf_queue: Vec<(Vec<usize>, usize, usize)> = Vec::new();
        let root_idx = nodes.len();
        let sum_grad: f32 = sample_indices.iter().map(|&i| gradients[i]).sum();
        let sum_hess: f32 = sample_indices.iter().map(|&i| hessians[i]).sum();
        let root_value = self.calculate_leaf_value(sum_grad, sum_hess);

        // Root starts out as a leaf holding every (possibly bagged) sample.
        nodes.push(LGBMNode {
            is_leaf: true,
            feature: 0,
            threshold_bin: 0,
            left: 0,
            right: 0,
            value: root_value,
            split_gain: 0.0,
        });

        leaf_queue.push((sample_indices.to_vec(), root_idx, 0));

        while nodes.len() < self.num_leaves && !leaf_queue.is_empty() {
            // Best-first: evaluate every frontier leaf, keep the top gain.
            let (best_idx, best_split) = self.find_best_leaf_to_split(
                x_data,
                gradients,
                hessians,
                &leaf_queue,
                feature_indices,
                n_features,
            );

            if best_split.is_none() {
                break;
            }

            let (_samples, node_idx, depth) = leaf_queue.remove(best_idx);
            let (feature, threshold_bin, gain, left_samples, right_samples) = best_split.unwrap();

            // Depth limit: drop this leaf from the frontier without splitting.
            // NOTE(review): the leaf was already removed above, so it simply
            // stays a leaf; other frontier leaves keep being considered.
            if self.max_depth > 0 && depth >= self.max_depth as usize {
                continue;
            }

            // Convert the chosen leaf into an internal split node.
            nodes[node_idx].is_leaf = false;
            nodes[node_idx].feature = feature;
            nodes[node_idx].threshold_bin = threshold_bin;
            nodes[node_idx].split_gain = gain;

            // Left child: samples with bin <= threshold_bin.
            let left_idx = nodes.len();
            let left_sum_grad: f32 = left_samples.iter().map(|&i| gradients[i]).sum();
            let left_sum_hess: f32 = left_samples.iter().map(|&i| hessians[i]).sum();
            let left_value = self.calculate_leaf_value(left_sum_grad, left_sum_hess);

            nodes.push(LGBMNode {
                is_leaf: true,
                feature: 0,
                threshold_bin: 0,
                left: 0,
                right: 0,
                value: left_value,
                split_gain: 0.0,
            });

            // Right child: the remaining samples.
            let right_idx = nodes.len();
            let right_sum_grad: f32 = right_samples.iter().map(|&i| gradients[i]).sum();
            let right_sum_hess: f32 = right_samples.iter().map(|&i| hessians[i]).sum();
            let right_value = self.calculate_leaf_value(right_sum_grad, right_sum_hess);

            nodes.push(LGBMNode {
                is_leaf: true,
                feature: 0,
                threshold_bin: 0,
                left: 0,
                right: 0,
                value: right_value,
                split_gain: 0.0,
            });

            nodes[node_idx].left = left_idx;
            nodes[node_idx].right = right_idx;

            // Children re-enter the frontier only if large enough to split.
            if left_samples.len() >= self.min_data_in_leaf {
                leaf_queue.push((left_samples, left_idx, depth + 1));
            }
            if right_samples.len() >= self.min_data_in_leaf {
                leaf_queue.push((right_samples, right_idx, depth + 1));
            }
        }

        LGBMTree { nodes }
    }
305
    /// Scans every leaf currently on the frontier and returns the index (in
    /// `leaf_queue`) of the leaf with the highest-gain split, together with
    /// that split, or `None` when no leaf has a positive-gain split.
    ///
    /// Leaves too small to yield two `min_data_in_leaf`-sized children are
    /// skipped up front.
    fn find_best_leaf_to_split(
        &self,
        x_data: &[f32],
        gradients: &[f32],
        hessians: &[f32],
        leaf_queue: &[(Vec<usize>, usize, usize)],
        feature_indices: &[usize],
        n_features: usize,
    ) -> (usize, Option<(usize, usize, f32, Vec<usize>, Vec<usize>)>) {
        let mut best_leaf_idx = 0;
        let mut best_split: Option<(usize, usize, f32, Vec<usize>, Vec<usize>)> = None;
        let mut best_gain = 0.0;

        for (idx, (samples, _, _)) in leaf_queue.iter().enumerate() {
            // A split must leave at least min_data_in_leaf on each side.
            if samples.len() < self.min_data_in_leaf * 2 {
                continue;
            }

            let split = self.find_best_split_histogram(
                x_data,
                gradients,
                hessians,
                samples,
                feature_indices,
                n_features,
            );

            // Strict `>` keeps the first of equally-good leaves.
            if let Some((_, _, gain, _, _)) = &split {
                if *gain > best_gain {
                    best_gain = *gain;
                    best_split = split;
                    best_leaf_idx = idx;
                }
            }
        }

        (best_leaf_idx, best_split)
    }
344
345 fn find_best_split_histogram(
346 &self,
347 x_data: &[f32],
348 gradients: &[f32],
349 hessians: &[f32],
350 sample_indices: &[usize],
351 feature_indices: &[usize],
352 n_features: usize,
353 ) -> Option<(usize, usize, f32, Vec<usize>, Vec<usize>)> {
354 let sum_grad: f32 = sample_indices.iter().map(|&i| gradients[i]).sum();
355 let sum_hess: f32 = sample_indices.iter().map(|&i| hessians[i]).sum();
356
357 let mut best_gain = 0.0;
358 let mut best_feature = 0;
359 let mut best_bin = 0;
360 let mut best_left = Vec::new();
361 let mut best_right = Vec::new();
362
363 for &feature in feature_indices {
364 let n_bins = self.feature_bins[feature].len();
365
366 let mut hist_grad = vec![0.0; n_bins];
368 let mut hist_hess = vec![0.0; n_bins];
369 let mut bin_samples: Vec<Vec<usize>> = vec![Vec::new(); n_bins];
370
371 for &idx in sample_indices {
372 let value = x_data[idx * n_features + feature];
373 let bin = self.value_to_bin(feature, value);
374 hist_grad[bin] += gradients[idx];
375 hist_hess[bin] += hessians[idx];
376 bin_samples[bin].push(idx);
377 }
378
379 let mut left_grad = 0.0;
381 let mut left_hess = 0.0;
382 let mut left_samples = Vec::new();
383
384 for bin in 0..n_bins - 1 {
385 left_grad += hist_grad[bin];
386 left_hess += hist_hess[bin];
387 left_samples.extend(&bin_samples[bin]);
388
389 let right_grad = sum_grad - left_grad;
390 let right_hess = sum_hess - left_hess;
391
392 if left_hess < self.min_sum_hessian_in_leaf || right_hess < self.min_sum_hessian_in_leaf {
394 continue;
395 }
396
397 let gain = self.calculate_split_gain(left_grad, left_hess, right_grad, right_hess, sum_grad, sum_hess);
398
399 if gain > best_gain {
400 best_gain = gain;
401 best_feature = feature;
402 best_bin = bin;
403 best_left = left_samples.clone();
404 best_right = sample_indices.iter()
405 .filter(|&&idx| !best_left.contains(&idx))
406 .copied()
407 .collect();
408 }
409 }
410 }
411
412 if best_gain > 0.0 {
413 Some((best_feature, best_bin, best_gain, best_left, best_right))
414 } else {
415 None
416 }
417 }
418
    /// Optimal Newton-step leaf weight: -G / (H + lambda_l2).
    /// NOTE(review): lambda_l1 is not applied here (only in the split-gain
    /// penalty) — confirm whether L1 soft-thresholding of leaf values was
    /// intended.
    fn calculate_leaf_value(&self, sum_grad: f32, sum_hess: f32) -> f32 {
        -sum_grad / (sum_hess + self.lambda_l2)
    }
422
423 fn calculate_split_gain(
424 &self,
425 left_grad: f32,
426 left_hess: f32,
427 right_grad: f32,
428 right_hess: f32,
429 sum_grad: f32,
430 sum_hess: f32,
431 ) -> f32 {
432 let left_weight = -left_grad / (left_hess + self.lambda_l2);
433 let right_weight = -right_grad / (right_hess + self.lambda_l2);
434 let parent_weight = -sum_grad / (sum_hess + self.lambda_l2);
435
436 let gain = 0.5 * (
437 left_grad * left_weight +
438 right_grad * right_weight -
439 sum_grad * parent_weight
440 );
441
442 let l1_penalty = self.lambda_l1 * (left_weight.abs() + right_weight.abs() - parent_weight.abs());
444
445 gain - l1_penalty
446 }
447
448 fn predict_tree(&self, tree: &LGBMTree, sample: &[f32]) -> f32 {
449 let mut node_idx = 0;
450
451 loop {
452 let node = &tree.nodes[node_idx];
453
454 if node.is_leaf {
455 return node.value;
456 }
457
458 let bin = self.value_to_bin(node.feature, sample[node.feature]);
459 if bin <= node.threshold_bin {
460 node_idx = node.left;
461 } else {
462 node_idx = node.right;
463 }
464 }
465 }
466
467 pub fn predict(&self, x: &Tensor) -> Tensor {
468 let n_samples = x.dims()[0];
469 let n_features = x.dims()[1];
470 let x_data = x.data_f32();
471
472 let predictions: Vec<f32> = (0..n_samples)
473 .into_par_iter()
474 .map(|i| {
475 let sample = &x_data[i * n_features..(i + 1) * n_features];
476 let mut pred = self.base_score;
477
478 for tree in &self.trees {
479 pred += self.learning_rate * self.predict_tree(tree, sample);
480 }
481
482 let prob = 1.0 / (1.0 + (-pred).exp());
483 if prob >= 0.5 { 1.0 } else { 0.0 }
484 })
485 .collect();
486
487 Tensor::from_slice(&predictions, &[n_samples]).unwrap()
488 }
489
490 pub fn predict_proba(&self, x: &Tensor) -> Tensor {
491 let n_samples = x.dims()[0];
492 let n_features = x.dims()[1];
493 let x_data = x.data_f32();
494
495 let probabilities: Vec<f32> = (0..n_samples)
496 .into_par_iter()
497 .map(|i| {
498 let sample = &x_data[i * n_features..(i + 1) * n_features];
499 let mut pred = self.base_score;
500
501 for tree in &self.trees {
502 pred += self.learning_rate * self.predict_tree(tree, sample);
503 }
504
505 1.0 / (1.0 + (-pred).exp())
506 })
507 .collect();
508
509 Tensor::from_slice(&probabilities, &[n_samples]).unwrap()
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: fitting on a tiny XOR dataset must yield exactly one
    /// prediction per input row.
    #[test]
    fn test_lightgbm_classifier() {
        let features = [0.0f32, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0];
        let labels = [0.0f32, 1.0, 1.0, 0.0];

        let x = Tensor::from_slice(&features, &[4, 2]).unwrap();
        let y = Tensor::from_slice(&labels, &[4]).unwrap();

        let mut model = LightGBMClassifier::new(10)
            .learning_rate(0.1)
            .num_leaves(7);
        model.fit(&x, &y);

        let predictions = model.predict(&x);
        assert_eq!(predictions.dims()[0], 4);
    }
}
535
536