1use std::collections::{HashMap, HashSet};
5
6use linfa::Dataset;
7use linfa::prelude::*;
8use ndarray::{Array1, Array2, Axis};
9
10use crate::model::{LinfaKind, ModelMeta, TlModel};
11use crate::tensor::TlTensor;
12
13pub struct TrainConfig {
15 pub features: TlTensor,
17 pub target: TlTensor,
19 pub feature_names: Vec<String>,
21 pub target_name: String,
23 pub model_name: String,
25 pub split_ratio: f64,
27 pub hyperparams: HashMap<String, f64>,
29}
30
31pub fn train(algorithm: &str, config: &TrainConfig) -> Result<TlModel, String> {
33 match algorithm {
34 "linear" => train_linear(config),
35 "logistic" => train_logistic(config),
36 "tree" | "decision_tree" => train_decision_tree(config),
37 "random_forest" | "forest" => train_random_forest(config),
38 "kmeans" | "k_means" => train_kmeans(config),
39 "knn" | "k_nearest_neighbors" => train_knn(config),
40 "naive_bayes" | "gaussian_nb" | "nb" => train_naive_bayes(config),
41 "dbscan" => train_dbscan(config),
42 "ridge" => train_ridge(config),
43 "gradient_boosting" | "gbt" | "gbm" | "xgboost" => train_gradient_boosting(config),
44 _ => Err(format!(
45 "Unknown training algorithm: '{algorithm}'. Supported: linear, ridge, logistic, \
46 tree, random_forest, gradient_boosting, knn, naive_bayes, kmeans, dbscan"
47 )),
48 }
49}
50
51fn apply_rowwise<P: Fn(&[f64]) -> f64>(
56 input: &TlTensor,
57 predict_row: P,
58) -> Result<TlTensor, String> {
59 let shape = input.shape();
60 let flat = input.to_vec();
61 if shape.len() == 1 {
62 Ok(TlTensor::from_list(vec![predict_row(&flat)]))
63 } else if shape.len() == 2 {
64 let (rows, cols) = (shape[0], shape[1]);
65 let mut preds = Vec::with_capacity(rows);
66 for i in 0..rows {
67 preds.push(predict_row(&flat[i * cols..(i + 1) * cols]));
68 }
69 Ok(TlTensor::from_list(preds))
70 } else {
71 Err(format!("Input must be 1D or 2D, got {}D", shape.len()))
72 }
73}
74
75fn tree_node_to_json(node: &linfa_trees::TreeNode<f64, usize>) -> serde_json::Value {
78 if node.is_leaf() {
79 serde_json::json!({ "leaf": true, "value": node.prediction().unwrap_or(0) })
80 } else {
81 let (feature, threshold, _) = node.split();
82 let children = node.children(); let left = children[0]
84 .as_ref()
85 .map(|c| tree_node_to_json(c))
86 .unwrap_or(serde_json::Value::Null);
87 let right = children[1]
88 .as_ref()
89 .map(|c| tree_node_to_json(c))
90 .unwrap_or(serde_json::Value::Null);
91 serde_json::json!({ "leaf": false, "feature": feature, "threshold": threshold, "left": left, "right": right })
92 }
93}
94
95fn predict_tree_json(node: &serde_json::Value, row: &[f64]) -> f64 {
98 if node["leaf"].as_bool().unwrap_or(true) {
99 return node["value"].as_f64().unwrap_or(0.0);
100 }
101 let f = node["feature"].as_u64().unwrap_or(0) as usize;
102 let thr = node["threshold"].as_f64().unwrap_or(0.0);
103 let xv = row.get(f).copied().unwrap_or(0.0);
104 if xv < thr {
105 predict_tree_json(&node["left"], row)
106 } else {
107 predict_tree_json(&node["right"], row)
108 }
109}
110
111fn vote_trees(trees: &[serde_json::Value], row: &[f64]) -> f64 {
114 let mut counts: HashMap<i64, usize> = HashMap::new();
115 for t in trees {
116 *counts.entry(predict_tree_json(t, row) as i64).or_insert(0) += 1;
117 }
118 counts
119 .into_iter()
120 .max_by_key(|(_, c)| *c)
121 .map(|(v, _)| v as f64)
122 .unwrap_or(0.0)
123}
124
125fn features_to_array2(features: &TlTensor) -> Result<Array2<f64>, String> {
126 let shape = features.shape();
127 if shape.len() != 2 {
128 return Err(format!("Features must be 2D, got {}D", shape.len()));
129 }
130 let rows = shape[0];
131 let cols = shape[1];
132 let flat = features.to_vec();
133 Array2::from_shape_vec((rows, cols), flat).map_err(|e| format!("Shape error: {e}"))
134}
135
136fn target_to_array1(target: &TlTensor) -> Result<Array1<f64>, String> {
137 let shape = target.shape();
138 if shape.len() != 1 {
139 return Err(format!("Target must be 1D, got {}D", shape.len()));
140 }
141 Ok(Array1::from_vec(target.to_vec()))
142}
143
144fn train_linear(config: &TrainConfig) -> Result<TlModel, String> {
145 let x = features_to_array2(&config.features)?;
146 let y = target_to_array1(&config.target)?;
147 let dataset = Dataset::new(x, y);
148
149 let model = linfa_linear::LinearRegression::default()
150 .fit(&dataset)
151 .map_err(|e| format!("Linear regression training failed: {e}"))?;
152
153 let pred = model.predict(&dataset);
155 let r2 = pred
156 .r2(&dataset)
157 .map_err(|e| format!("R² computation failed: {e}"))?;
158
159 let params = model.params();
161 let intercept = model.intercept();
162 let model_data = serde_json::json!({
163 "params": params.as_slice().unwrap_or(&[]),
164 "intercept": intercept,
165 });
166 let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
167
168 let mut metrics = HashMap::new();
169 metrics.insert("r2".to_string(), r2);
170
171 Ok(TlModel::Linfa {
172 kind: LinfaKind::LinearRegression,
173 data,
174 metadata: ModelMeta {
175 name: config.model_name.clone(),
176 version: "0.1.0".to_string(),
177 created_at: String::new(),
178 features: config.feature_names.clone(),
179 target: config.target_name.clone(),
180 metrics,
181 },
182 })
183}
184
185fn train_logistic(config: &TrainConfig) -> Result<TlModel, String> {
186 let x = features_to_array2(&config.features)?;
187 let y_float = target_to_array1(&config.target)?;
188
189 let y_bool: Array1<bool> = y_float.mapv(|v| v > 0.5);
191
192 let dataset = Dataset::new(x, y_bool);
193
194 let model = linfa_logistic::LogisticRegression::default()
195 .max_iterations(100)
196 .fit(&dataset)
197 .map_err(|e| format!("Logistic regression training failed: {e}"))?;
198
199 let pred = model.predict(&dataset);
201 let correct = pred
202 .iter()
203 .zip(dataset.targets().iter())
204 .filter(|(p, t)| p == t)
205 .count();
206 let accuracy = correct as f64 / dataset.targets().len() as f64;
207
208 let params = model.params();
210 let intercept = model.intercept();
211 let params_slice = params.as_slice().unwrap_or(&[]);
212
213 let (mut pos_label, mut neg_label) = (1.0_f64, 0.0_f64);
218 {
219 let records = dataset.records();
220 for (i, p) in pred.iter().enumerate() {
221 let row = records.row(i);
222 let logit: f64 = row
223 .iter()
224 .zip(params_slice.iter())
225 .map(|(a, b)| a * b)
226 .sum::<f64>()
227 + intercept;
228 let label = if *p { 1.0 } else { 0.0 };
229 if logit > 0.0 {
230 pos_label = label;
231 } else {
232 neg_label = label;
233 }
234 }
235 }
236
237 let model_data = serde_json::json!({
238 "params": params_slice,
239 "intercept": intercept,
240 "pos_label": pos_label,
241 "neg_label": neg_label,
242 });
243 let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
244
245 let mut metrics = HashMap::new();
246 metrics.insert("accuracy".to_string(), accuracy);
247
248 Ok(TlModel::Linfa {
249 kind: LinfaKind::LogisticRegression,
250 data,
251 metadata: ModelMeta {
252 name: config.model_name.clone(),
253 version: "0.1.0".to_string(),
254 created_at: String::new(),
255 features: config.feature_names.clone(),
256 target: config.target_name.clone(),
257 metrics,
258 },
259 })
260}
261
262fn train_decision_tree(config: &TrainConfig) -> Result<TlModel, String> {
263 let x = features_to_array2(&config.features)?;
264 let y_float = target_to_array1(&config.target)?;
265
266 let y_usize: Array1<usize> = y_float.mapv(|v| v as usize);
268
269 let max_depth = config
270 .hyperparams
271 .get("max_depth")
272 .copied()
273 .map(|d| d as usize);
274
275 let dataset = Dataset::new(x, y_usize);
276
277 let mut builder = linfa_trees::DecisionTree::params();
278 if let Some(depth) = max_depth {
279 builder = builder.max_depth(Some(depth));
280 }
281 let model = builder
282 .fit(&dataset)
283 .map_err(|e| format!("Decision tree training failed: {e}"))?;
284
285 let pred = model.predict(&dataset);
287 let correct = pred
288 .iter()
289 .zip(dataset.targets().iter())
290 .filter(|(p, t)| p == t)
291 .count();
292 let accuracy = correct as f64 / dataset.targets().len() as f64;
293
294 let model_data = serde_json::json!({
296 "type": "decision_tree",
297 "accuracy": accuracy,
298 "tree": tree_node_to_json(model.root_node()),
299 });
300 let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
301
302 let mut metrics = HashMap::new();
303 metrics.insert("accuracy".to_string(), accuracy);
304
305 Ok(TlModel::Linfa {
306 kind: LinfaKind::DecisionTree,
307 data,
308 metadata: ModelMeta {
309 name: config.model_name.clone(),
310 version: "0.1.0".to_string(),
311 created_at: String::new(),
312 features: config.feature_names.clone(),
313 target: config.target_name.clone(),
314 metrics,
315 },
316 })
317}
318
319fn train_random_forest(config: &TrainConfig) -> Result<TlModel, String> {
323 let x = features_to_array2(&config.features)?;
324 let y_float = target_to_array1(&config.target)?;
325 let y_usize: Array1<usize> = y_float.mapv(|v| v as usize);
326
327 let n = x.nrows();
328 if n == 0 {
329 return Err("Random forest: no training samples".to_string());
330 }
331 let n_trees = config
332 .hyperparams
333 .get("n_trees")
334 .or_else(|| config.hyperparams.get("trees"))
335 .copied()
336 .map(|v| (v as usize).max(1))
337 .unwrap_or(10);
338 let max_depth = config
339 .hyperparams
340 .get("max_depth")
341 .copied()
342 .map(|d| d as usize);
343
344 let mut seed: u64 = 0x2545F4914F6CDD1D;
346 let mut next = || {
347 seed ^= seed << 13;
348 seed ^= seed >> 7;
349 seed ^= seed << 17;
350 seed
351 };
352
353 let mut trees: Vec<serde_json::Value> = Vec::with_capacity(n_trees);
354 for _ in 0..n_trees {
355 let rows: Vec<usize> = (0..n).map(|_| (next() as usize) % n).collect();
356 let xb = x.select(Axis(0), &rows);
357 let yb = y_usize.select(Axis(0), &rows);
358 let ds = Dataset::new(xb, yb);
359 let mut builder = linfa_trees::DecisionTree::params();
360 if let Some(d) = max_depth {
361 builder = builder.max_depth(Some(d));
362 }
363 let tree = builder
364 .fit(&ds)
365 .map_err(|e| format!("Random forest tree training failed: {e}"))?;
366 trees.push(tree_node_to_json(tree.root_node()));
367 }
368
369 let flat = x.iter().copied().collect::<Vec<f64>>();
371 let cols = x.ncols();
372 let mut correct = 0usize;
373 for i in 0..n {
374 let row = &flat[i * cols..(i + 1) * cols];
375 if vote_trees(&trees, row) as usize == y_usize[i] {
376 correct += 1;
377 }
378 }
379 let accuracy = correct as f64 / n as f64;
380
381 let model_data = serde_json::json!({ "type": "random_forest", "trees": trees });
382 let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
383
384 let mut metrics = HashMap::new();
385 metrics.insert("accuracy".to_string(), accuracy);
386 metrics.insert("n_trees".to_string(), n_trees as f64);
387
388 Ok(TlModel::Linfa {
389 kind: LinfaKind::RandomForest,
390 data,
391 metadata: ModelMeta {
392 name: config.model_name.clone(),
393 version: "0.1.0".to_string(),
394 created_at: String::new(),
395 features: config.feature_names.clone(),
396 target: config.target_name.clone(),
397 metrics,
398 },
399 })
400}
401
402fn train_kmeans(config: &TrainConfig) -> Result<TlModel, String> {
406 let x = features_to_array2(&config.features)?;
407 let n = x.nrows();
408 let d = x.ncols();
409 if n == 0 {
410 return Err("K-means: no training samples".to_string());
411 }
412 let k = config
413 .hyperparams
414 .get("k")
415 .or_else(|| config.hyperparams.get("clusters"))
416 .copied()
417 .map(|v| (v as usize).max(1))
418 .unwrap_or(3)
419 .min(n);
420 let max_iter = config
421 .hyperparams
422 .get("max_iter")
423 .copied()
424 .map(|v| (v as usize).max(1))
425 .unwrap_or(100);
426
427 let mut centroids: Vec<Vec<f64>> = (0..k).map(|i| x.row((i * n) / k).to_vec()).collect();
429 let mut assign = vec![0usize; n];
430
431 for _ in 0..max_iter {
432 let mut changed = false;
433 for (i, slot) in assign.iter_mut().enumerate() {
434 let row = x.row(i);
435 let mut best = 0usize;
436 let mut best_d = f64::INFINITY;
437 for (c, cen) in centroids.iter().enumerate() {
438 let dist: f64 = row.iter().zip(cen).map(|(a, b)| (a - b) * (a - b)).sum();
439 if dist < best_d {
440 best_d = dist;
441 best = c;
442 }
443 }
444 if *slot != best {
445 *slot = best;
446 changed = true;
447 }
448 }
449 let mut sums = vec![vec![0.0f64; d]; k];
450 let mut counts = vec![0usize; k];
451 for i in 0..n {
452 let row = x.row(i);
453 counts[assign[i]] += 1;
454 for j in 0..d {
455 sums[assign[i]][j] += row[j];
456 }
457 }
458 for c in 0..k {
459 if counts[c] > 0 {
460 for j in 0..d {
461 centroids[c][j] = sums[c][j] / counts[c] as f64;
462 }
463 }
464 }
465 if !changed {
466 break;
467 }
468 }
469
470 let mut inertia = 0.0f64;
472 for i in 0..n {
473 let row = x.row(i);
474 let cen = ¢roids[assign[i]];
475 inertia += row
476 .iter()
477 .zip(cen)
478 .map(|(a, b)| (a - b) * (a - b))
479 .sum::<f64>();
480 }
481
482 let model_data = serde_json::json!({ "type": "kmeans", "centroids": centroids });
483 let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
484
485 let mut metrics = HashMap::new();
486 metrics.insert("k".to_string(), k as f64);
487 metrics.insert("inertia".to_string(), inertia);
488
489 Ok(TlModel::Linfa {
490 kind: LinfaKind::KMeans,
491 data,
492 metadata: ModelMeta {
493 name: config.model_name.clone(),
494 version: "0.1.0".to_string(),
495 created_at: String::new(),
496 features: config.feature_names.clone(),
497 target: config.target_name.clone(),
498 metrics,
499 },
500 })
501}
502
/// Squared Euclidean distance between `a` and `b`.
///
/// Only the first `min(a.len(), b.len())` coordinates are compared.
fn sq_dist(a: &[f64], b: &[f64]) -> f64 {
    let squared_diffs = a.iter().zip(b.iter()).map(|(&p, &q)| {
        let delta = p - q;
        delta * delta
    });
    squared_diffs.sum()
}
507
508fn knn_vote(xtrain: &[Vec<f64>], ytrain: &[f64], k: usize, row: &[f64]) -> f64 {
510 let mut dists: Vec<(f64, f64)> = xtrain
511 .iter()
512 .zip(ytrain)
513 .map(|(p, &l)| (sq_dist(p, row), l))
514 .collect();
515 dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
516 let mut counts: HashMap<i64, usize> = HashMap::new();
517 for (_, l) in dists.iter().take(k.min(dists.len())) {
518 *counts.entry(*l as i64).or_insert(0) += 1;
519 }
520 counts
521 .into_iter()
522 .max_by_key(|(_, c)| *c)
523 .map(|(v, _)| v as f64)
524 .unwrap_or(0.0)
525}
526
/// Solves the square system `a · x = b` by Gauss–Jordan elimination with
/// partial pivoting, consuming both inputs.
///
/// For each column, the pivot is the first row at or below the diagonal with
/// the largest absolute value. Returns `None` when that pivot is smaller than
/// `1e-12` in magnitude, i.e. the system is treated as singular.
fn solve_linear_system(mut a: Vec<Vec<f64>>, mut b: Vec<f64>) -> Option<Vec<f64>> {
    let dim = b.len();
    for k in 0..dim {
        // Partial pivoting: keep the earliest row on equal magnitudes.
        let pivot = (k + 1..dim).fold(k, |best, r| {
            if a[r][k].abs() > a[best][k].abs() {
                r
            } else {
                best
            }
        });
        if a[pivot][k].abs() < 1e-12 {
            return None;
        }
        a.swap(k, pivot);
        b.swap(k, pivot);

        // Scale the pivot row so the diagonal entry becomes 1.
        let scale = a[k][k];
        a[k].iter_mut().for_each(|v| *v /= scale);
        b[k] /= scale;

        // Eliminate column k from every other row.
        let (pivot_row, pivot_rhs) = (a[k].clone(), b[k]);
        for (r, (row, rhs)) in a.iter_mut().zip(b.iter_mut()).enumerate() {
            if r == k {
                continue;
            }
            let factor = row[k];
            if factor == 0.0 {
                continue;
            }
            for (v, p) in row.iter_mut().zip(&pivot_row) {
                *v -= factor * p;
            }
            *rhs -= factor * pivot_rhs;
        }
    }
    Some(b)
}
564
565fn train_knn(config: &TrainConfig) -> Result<TlModel, String> {
568 let x = features_to_array2(&config.features)?;
569 let y = target_to_array1(&config.target)?;
570 let k = config
571 .hyperparams
572 .get("k")
573 .or_else(|| config.hyperparams.get("neighbors"))
574 .copied()
575 .map(|v| (v as usize).max(1))
576 .unwrap_or(5);
577 let rows: Vec<Vec<f64>> = (0..x.nrows()).map(|i| x.row(i).to_vec()).collect();
578 let labels: Vec<f64> = y.to_vec();
579
580 let mut correct = 0usize;
581 for i in 0..rows.len() {
582 if knn_vote(&rows, &labels, k, &rows[i]) == labels[i] {
583 correct += 1;
584 }
585 }
586 let accuracy = if rows.is_empty() {
587 0.0
588 } else {
589 correct as f64 / rows.len() as f64
590 };
591
592 let model_data = serde_json::json!({ "type": "knn", "k": k, "x": rows, "y": labels });
593 let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
594 let mut metrics = HashMap::new();
595 metrics.insert("accuracy".to_string(), accuracy);
596 metrics.insert("k".to_string(), k as f64);
597 Ok(linfa_model(LinfaKind::Knn, data, config, metrics))
598}
599
600fn train_naive_bayes(config: &TrainConfig) -> Result<TlModel, String> {
603 let x = features_to_array2(&config.features)?;
604 let y = target_to_array1(&config.target)?;
605 let n = x.nrows();
606 let d = x.ncols();
607 if n == 0 {
608 return Err("Naive Bayes: no training samples".to_string());
609 }
610 let mut by_class: HashMap<i64, Vec<usize>> = HashMap::new();
612 for i in 0..n {
613 by_class.entry(y[i] as i64).or_default().push(i);
614 }
615 let mut classes: Vec<serde_json::Value> = Vec::new();
616 for (label, idxs) in &by_class {
617 let cnt = idxs.len();
618 let mut means = vec![0.0f64; d];
619 for &i in idxs {
620 let row = x.row(i);
621 for j in 0..d {
622 means[j] += row[j];
623 }
624 }
625 for m in &mut means {
626 *m /= cnt as f64;
627 }
628 let mut vars = vec![0.0f64; d];
629 for &i in idxs {
630 let row = x.row(i);
631 for j in 0..d {
632 vars[j] += (row[j] - means[j]).powi(2);
633 }
634 }
635 for v in &mut vars {
636 *v = (*v / cnt as f64).max(1e-9); }
638 classes.push(serde_json::json!({
639 "label": *label as f64,
640 "prior": (cnt as f64 / n as f64).ln(),
641 "means": means,
642 "vars": vars,
643 }));
644 }
645
646 let nb = NaiveBayesModel::from_json(&classes);
648 let correct = (0..n)
649 .filter(|&i| nb.predict(&x.row(i).to_vec()) == y[i].round())
650 .count();
651 let accuracy = correct as f64 / n as f64;
652
653 let model_data = serde_json::json!({ "type": "naive_bayes", "classes": classes });
654 let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
655 let mut metrics = HashMap::new();
656 metrics.insert("accuracy".to_string(), accuracy);
657 metrics.insert("classes".to_string(), by_class.len() as f64);
658 Ok(linfa_model(LinfaKind::NaiveBayes, data, config, metrics))
659}
660
661struct NaiveBayesModel {
663 classes: Vec<(f64, f64, Vec<f64>, Vec<f64>)>, }
665impl NaiveBayesModel {
666 fn from_json(classes: &[serde_json::Value]) -> Self {
667 let classes = classes
668 .iter()
669 .map(|c| {
670 let label = c["label"].as_f64().unwrap_or(0.0);
671 let prior = c["prior"].as_f64().unwrap_or(0.0);
672 let means: Vec<f64> =
673 serde_json::from_value(c["means"].clone()).unwrap_or_default();
674 let vars: Vec<f64> = serde_json::from_value(c["vars"].clone()).unwrap_or_default();
675 (label, prior, means, vars)
676 })
677 .collect();
678 Self { classes }
679 }
680 fn predict(&self, row: &[f64]) -> f64 {
681 let mut best_label = 0.0;
682 let mut best_score = f64::NEG_INFINITY;
683 for (label, log_prior, means, vars) in &self.classes {
684 let mut score = *log_prior;
685 for j in 0..row.len().min(means.len()) {
686 let v = vars[j].max(1e-9);
687 score += -0.5
688 * ((row[j] - means[j]).powi(2) / v + (2.0 * std::f64::consts::PI * v).ln());
689 }
690 if score > best_score {
691 best_score = score;
692 best_label = *label;
693 }
694 }
695 best_label
696 }
697}
698
699fn train_dbscan(config: &TrainConfig) -> Result<TlModel, String> {
704 let x = features_to_array2(&config.features)?;
705 let n = x.nrows();
706 if n == 0 {
707 return Err("DBSCAN: no training samples".to_string());
708 }
709 let pts: Vec<Vec<f64>> = (0..n).map(|i| x.row(i).to_vec()).collect();
710 let eps = config.hyperparams.get("eps").copied().unwrap_or(0.5);
711 let min_samples = config
712 .hyperparams
713 .get("min_samples")
714 .or_else(|| config.hyperparams.get("min_points"))
715 .copied()
716 .map(|v| (v as usize).max(1))
717 .unwrap_or(3);
718 let eps2 = eps * eps;
719 let neighbors = |i: usize| -> Vec<usize> {
720 (0..n)
721 .filter(|&j| sq_dist(&pts[i], &pts[j]) <= eps2)
722 .collect()
723 };
724
725 let mut labels = vec![-1i64; n];
726 let mut visited = vec![false; n];
727 let mut cid = 0i64;
728 for i in 0..n {
729 if visited[i] {
730 continue;
731 }
732 visited[i] = true;
733 let nb = neighbors(i);
734 if nb.len() < min_samples {
735 continue; }
737 labels[i] = cid;
738 let mut queue = nb;
739 let mut qi = 0;
740 while qi < queue.len() {
741 let q = queue[qi];
742 qi += 1;
743 if labels[q] < 0 {
744 labels[q] = cid;
745 }
746 if !visited[q] {
747 visited[q] = true;
748 let qnb = neighbors(q);
749 if qnb.len() >= min_samples {
750 for m in qnb {
751 if !queue.contains(&m) {
752 queue.push(m);
753 }
754 }
755 }
756 }
757 }
758 cid += 1;
759 }
760
761 let mut cores: Vec<serde_json::Value> = Vec::new();
762 let mut n_noise = 0usize;
763 for i in 0..n {
764 if labels[i] < 0 {
765 n_noise += 1;
766 } else if neighbors(i).len() >= min_samples {
767 cores.push(serde_json::json!({ "p": pts[i], "c": labels[i] as f64 }));
768 }
769 }
770
771 let model_data = serde_json::json!({ "type": "dbscan", "eps": eps, "cores": cores });
772 let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
773 let mut metrics = HashMap::new();
774 metrics.insert("clusters".to_string(), cid as f64);
775 metrics.insert("noise".to_string(), n_noise as f64);
776 Ok(linfa_model(LinfaKind::Dbscan, data, config, metrics))
777}
778
779fn train_ridge(config: &TrainConfig) -> Result<TlModel, String> {
783 let x = features_to_array2(&config.features)?;
784 let y = target_to_array1(&config.target)?;
785 let n = x.nrows();
786 let d = x.ncols();
787 if n == 0 {
788 return Err("Ridge: no training samples".to_string());
789 }
790 let lambda = config
791 .hyperparams
792 .get("alpha")
793 .or_else(|| config.hyperparams.get("lambda"))
794 .copied()
795 .unwrap_or(1.0);
796
797 let p = d + 1; let row_aug = |i: usize| -> Vec<f64> {
799 let mut r = x.row(i).to_vec();
800 r.push(1.0);
801 r
802 };
803 let mut a = vec![vec![0.0f64; p]; p];
804 let mut bvec = vec![0.0f64; p];
805 for i in 0..n {
806 let r = row_aug(i);
807 let yi = y[i];
808 for j in 0..p {
809 for k2 in 0..p {
810 a[j][k2] += r[j] * r[k2];
811 }
812 bvec[j] += r[j] * yi;
813 }
814 }
815 for (j, row) in a.iter_mut().enumerate().take(d) {
817 row[j] += lambda;
818 }
819 let w = solve_linear_system(a, bvec)
820 .ok_or("Ridge: singular system — try a larger alpha or fewer collinear features")?;
821 let coef: Vec<f64> = w[0..d].to_vec();
822 let intercept = w[d];
823
824 let mean_y = y.iter().sum::<f64>() / n as f64;
826 let (mut ss_res, mut ss_tot) = (0.0, 0.0);
827 for i in 0..n {
828 let row = x.row(i);
829 let pred: f64 = row.iter().zip(&coef).map(|(a, b)| a * b).sum::<f64>() + intercept;
830 ss_res += (y[i] - pred).powi(2);
831 ss_tot += (y[i] - mean_y).powi(2);
832 }
833 let r2 = if ss_tot > 0.0 {
834 1.0 - ss_res / ss_tot
835 } else {
836 0.0
837 };
838
839 let model_data = serde_json::json!({ "type": "ridge", "params": coef, "intercept": intercept });
840 let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
841 let mut metrics = HashMap::new();
842 metrics.insert("r2".to_string(), r2);
843 Ok(linfa_model(LinfaKind::Ridge, data, config, metrics))
844}
845
846fn build_reg_tree(
851 idx: &[usize],
852 x: &Array2<f64>,
853 r: &[f64],
854 w: &[f64],
855 depth: usize,
856 max_depth: usize,
857 min_leaf: usize,
858) -> serde_json::Value {
859 let (mut sw, mut swr, mut swr2) = (0.0f64, 0.0f64, 0.0f64);
860 for &i in idx {
861 sw += w[i];
862 swr += w[i] * r[i];
863 swr2 += w[i] * r[i] * r[i];
864 }
865 let leaf_val = if sw > 0.0 { swr / sw } else { 0.0 };
866 let leaf = serde_json::json!({ "leaf": true, "value": leaf_val });
867 if depth >= max_depth || idx.len() <= min_leaf.max(1) || sw <= 0.0 {
868 return leaf;
869 }
870 let parent_sse = swr2 - swr * swr / sw;
871
872 let d = x.ncols();
873 let mut best: Option<(usize, f64, f64)> = None; for f in 0..d {
875 let mut order: Vec<usize> = idx.to_vec();
876 order.sort_by(|&a, &b| {
877 x[[a, f]]
878 .partial_cmp(&x[[b, f]])
879 .unwrap_or(std::cmp::Ordering::Equal)
880 });
881 let (mut lw, mut lwr, mut lwr2) = (0.0f64, 0.0f64, 0.0f64);
882 for k in 0..order.len() - 1 {
883 let i = order[k];
884 lw += w[i];
885 lwr += w[i] * r[i];
886 lwr2 += w[i] * r[i] * r[i];
887 let (xi, xnext) = (x[[order[k], f]], x[[order[k + 1], f]]);
888 if xi == xnext {
889 continue;
890 }
891 let rw = sw - lw;
892 if lw <= 0.0 || rw <= 0.0 {
893 continue;
894 }
895 let sse_l = lwr2 - lwr * lwr / lw;
896 let sse_r = (swr2 - lwr2) - (swr - lwr) * (swr - lwr) / rw;
897 let sse = sse_l + sse_r;
898 if best.is_none_or(|(_, _, bs)| sse < bs) {
899 best = Some((f, (xi + xnext) / 2.0, sse));
900 }
901 }
902 }
903
904 match best {
905 Some((f, thr, sse)) if sse < parent_sse - 1e-12 => {
906 let left: Vec<usize> = idx.iter().copied().filter(|&i| x[[i, f]] < thr).collect();
907 let right: Vec<usize> = idx.iter().copied().filter(|&i| x[[i, f]] >= thr).collect();
908 if left.is_empty() || right.is_empty() {
909 return leaf;
910 }
911 serde_json::json!({
912 "leaf": false, "feature": f, "threshold": thr,
913 "left": build_reg_tree(&left, x, r, w, depth + 1, max_depth, min_leaf),
914 "right": build_reg_tree(&right, x, r, w, depth + 1, max_depth, min_leaf),
915 })
916 }
917 _ => leaf,
918 }
919}
920
921fn train_gradient_boosting(config: &TrainConfig) -> Result<TlModel, String> {
928 let x = features_to_array2(&config.features)?;
929 let y = target_to_array1(&config.target)?;
930 let n = x.nrows();
931 if n == 0 {
932 return Err("Gradient boosting: no training samples".to_string());
933 }
934 let hp_usize = |a: &str, b: &str, def: usize| -> usize {
935 config
936 .hyperparams
937 .get(a)
938 .or_else(|| config.hyperparams.get(b))
939 .copied()
940 .map(|v| (v as usize).max(1))
941 .unwrap_or(def)
942 };
943 let n_est = hp_usize("n_estimators", "trees", 100);
944 let max_depth = hp_usize("max_depth", "depth", 3);
945 let min_leaf = hp_usize("min_leaf", "min_samples_leaf", 1);
946 let lr = config
947 .hyperparams
948 .get("learning_rate")
949 .or_else(|| config.hyperparams.get("eta"))
950 .copied()
951 .unwrap_or(0.1);
952
953 let all01 = y.iter().all(|v| *v == 0.0 || *v == 1.0);
954 let distinct: HashSet<i64> = y.iter().map(|v| *v as i64).collect();
955 let binary = match config.hyperparams.get("objective") {
956 Some(o) => *o > 0.5,
957 None => all01 && distinct.len() <= 2,
958 };
959 if binary && !all01 {
960 return Err("Gradient boosting (binary objective) requires 0/1 targets".to_string());
961 }
962
963 let init = if binary {
965 let pos = y.iter().filter(|&&v| v == 1.0).count() as f64;
966 let p = (pos / n as f64).clamp(1e-6, 1.0 - 1e-6);
967 (p / (1.0 - p)).ln()
968 } else {
969 y.iter().sum::<f64>() / n as f64
970 };
971
972 let mut f_scores = vec![init; n];
973 let all_idx: Vec<usize> = (0..n).collect();
974 let mut trees: Vec<serde_json::Value> = Vec::with_capacity(n_est);
975
976 for _ in 0..n_est {
977 let mut r = vec![0.0f64; n];
980 let mut w = vec![0.0f64; n];
981 for i in 0..n {
982 let (g, h) = if binary {
983 let p = 1.0 / (1.0 + (-f_scores[i]).exp());
984 (p - y[i], (p * (1.0 - p)).max(1e-6))
985 } else {
986 (f_scores[i] - y[i], 1.0)
987 };
988 r[i] = -g / h;
989 w[i] = h;
990 }
991 let tree = build_reg_tree(&all_idx, &x, &r, &w, 0, max_depth, min_leaf);
992 for (i, fs) in f_scores.iter_mut().enumerate() {
993 *fs += lr * predict_tree_json(&tree, &x.row(i).to_vec());
994 }
995 trees.push(tree);
996 }
997
998 let mut metrics = HashMap::new();
1000 if binary {
1001 let correct = (0..n)
1002 .filter(|&i| ((1.0 / (1.0 + (-f_scores[i]).exp()) > 0.5) as i32 as f64) == y[i])
1003 .count();
1004 metrics.insert("accuracy".to_string(), correct as f64 / n as f64);
1005 } else {
1006 let mean_y = y.iter().sum::<f64>() / n as f64;
1007 let (mut ss_res, mut ss_tot) = (0.0, 0.0);
1008 for i in 0..n {
1009 ss_res += (y[i] - f_scores[i]).powi(2);
1010 ss_tot += (y[i] - mean_y).powi(2);
1011 }
1012 metrics.insert(
1013 "r2".to_string(),
1014 if ss_tot > 0.0 {
1015 1.0 - ss_res / ss_tot
1016 } else {
1017 0.0
1018 },
1019 );
1020 }
1021 metrics.insert("n_estimators".to_string(), n_est as f64);
1022
1023 let model_data = serde_json::json!({
1024 "type": "gradient_boosting", "binary": binary, "init": init, "lr": lr, "trees": trees,
1025 });
1026 let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;
1027 Ok(linfa_model(
1028 LinfaKind::GradientBoosting,
1029 data,
1030 config,
1031 metrics,
1032 ))
1033}
1034
1035fn linfa_model(
1037 kind: LinfaKind,
1038 data: Vec<u8>,
1039 config: &TrainConfig,
1040 metrics: HashMap<String, f64>,
1041) -> TlModel {
1042 TlModel::Linfa {
1043 kind,
1044 data,
1045 metadata: ModelMeta {
1046 name: config.model_name.clone(),
1047 version: "0.1.0".to_string(),
1048 created_at: String::new(),
1049 features: config.feature_names.clone(),
1050 target: config.target_name.clone(),
1051 metrics,
1052 },
1053 }
1054}
1055
1056pub fn predict_linfa(model: &TlModel, input: &TlTensor) -> Result<TlTensor, String> {
1058 match model {
1059 TlModel::Linfa { kind, data, .. } => match kind {
1060 LinfaKind::LinearRegression | LinfaKind::Ridge => {
1061 let model_data: serde_json::Value = serde_json::from_slice(data)
1062 .map_err(|e| format!("Deserialization failed: {e}"))?;
1063 let params: Vec<f64> = model_data["params"]
1064 .as_array()
1065 .ok_or("Missing params")?
1066 .iter()
1067 .map(|v| v.as_f64().unwrap_or(0.0))
1068 .collect();
1069 let intercept: f64 = model_data["intercept"].as_f64().unwrap_or(0.0);
1070
1071 let shape = input.shape();
1072 if shape.len() == 1 {
1073 let x = input.to_vec();
1074 let pred: f64 =
1075 x.iter().zip(params.iter()).map(|(a, b)| a * b).sum::<f64>() + intercept;
1076 Ok(TlTensor::from_list(vec![pred]))
1077 } else if shape.len() == 2 {
1078 let rows = shape[0];
1079 let cols = shape[1];
1080 let flat = input.to_vec();
1081 let mut preds = Vec::with_capacity(rows);
1082 for i in 0..rows {
1083 let row = &flat[i * cols..(i + 1) * cols];
1084 let pred: f64 = row
1085 .iter()
1086 .zip(params.iter())
1087 .map(|(a, b)| a * b)
1088 .sum::<f64>()
1089 + intercept;
1090 preds.push(pred);
1091 }
1092 Ok(TlTensor::from_list(preds))
1093 } else {
1094 Err(format!("Input must be 1D or 2D, got {}D", shape.len()))
1095 }
1096 }
1097 LinfaKind::LogisticRegression => {
1098 let model_data: serde_json::Value = serde_json::from_slice(data)
1099 .map_err(|e| format!("Deserialization failed: {e}"))?;
1100 let params: Vec<f64> = model_data["params"]
1101 .as_array()
1102 .ok_or("Missing params")?
1103 .iter()
1104 .map(|v| v.as_f64().unwrap_or(0.0))
1105 .collect();
1106 let intercept: f64 = model_data["intercept"].as_f64().unwrap_or(0.0);
1107 let pos_label = model_data["pos_label"].as_f64().unwrap_or(1.0);
1110 let neg_label = model_data["neg_label"].as_f64().unwrap_or(0.0);
1111
1112 apply_rowwise(input, |row| {
1113 let logit: f64 = row
1114 .iter()
1115 .zip(params.iter())
1116 .map(|(a, b)| a * b)
1117 .sum::<f64>()
1118 + intercept;
1119 let prob = 1.0 / (1.0 + (-logit).exp());
1120 if prob > 0.5 { pos_label } else { neg_label }
1121 })
1122 }
1123 LinfaKind::DecisionTree => {
1124 let model_data: serde_json::Value = serde_json::from_slice(data)
1125 .map_err(|e| format!("Deserialization failed: {e}"))?;
1126 let tree = model_data["tree"].clone();
1127 if tree.is_null() {
1128 return Err(
1129 "This decision-tree model was saved without its tree structure; retrain it."
1130 .to_string(),
1131 );
1132 }
1133 apply_rowwise(input, |row| predict_tree_json(&tree, row))
1134 }
1135 LinfaKind::RandomForest => {
1136 let model_data: serde_json::Value = serde_json::from_slice(data)
1137 .map_err(|e| format!("Deserialization failed: {e}"))?;
1138 let trees: Vec<serde_json::Value> = model_data["trees"]
1139 .as_array()
1140 .ok_or("Missing trees")?
1141 .clone();
1142 apply_rowwise(input, |row| vote_trees(&trees, row))
1143 }
1144 LinfaKind::KMeans => {
1145 let model_data: serde_json::Value = serde_json::from_slice(data)
1146 .map_err(|e| format!("Deserialization failed: {e}"))?;
1147 let centroids: Vec<Vec<f64>> =
1148 serde_json::from_value(model_data["centroids"].clone())
1149 .map_err(|e| format!("Missing centroids: {e}"))?;
1150 apply_rowwise(input, |row| {
1151 let mut best = 0usize;
1152 let mut best_d = f64::INFINITY;
1153 for (c, cen) in centroids.iter().enumerate() {
1154 let dist: f64 = row.iter().zip(cen).map(|(a, b)| (a - b) * (a - b)).sum();
1155 if dist < best_d {
1156 best_d = dist;
1157 best = c;
1158 }
1159 }
1160 best as f64
1161 })
1162 }
1163 LinfaKind::Knn => {
1164 let model_data: serde_json::Value = serde_json::from_slice(data)
1165 .map_err(|e| format!("Deserialization failed: {e}"))?;
1166 let k = model_data["k"].as_u64().unwrap_or(5) as usize;
1167 let xtrain: Vec<Vec<f64>> = serde_json::from_value(model_data["x"].clone())
1168 .map_err(|e| format!("Missing training data: {e}"))?;
1169 let ytrain: Vec<f64> = serde_json::from_value(model_data["y"].clone())
1170 .map_err(|e| format!("Missing labels: {e}"))?;
1171 apply_rowwise(input, |row| knn_vote(&xtrain, &ytrain, k, row))
1172 }
1173 LinfaKind::NaiveBayes => {
1174 let model_data: serde_json::Value = serde_json::from_slice(data)
1175 .map_err(|e| format!("Deserialization failed: {e}"))?;
1176 let classes = model_data["classes"]
1177 .as_array()
1178 .ok_or("Missing classes")?
1179 .clone();
1180 let nb = NaiveBayesModel::from_json(&classes);
1181 apply_rowwise(input, |row| nb.predict(row))
1182 }
1183 LinfaKind::Dbscan => {
1184 let model_data: serde_json::Value = serde_json::from_slice(data)
1185 .map_err(|e| format!("Deserialization failed: {e}"))?;
1186 let eps = model_data["eps"].as_f64().unwrap_or(0.5);
1187 let eps2 = eps * eps;
1188 let cores: Vec<(Vec<f64>, f64)> = model_data["cores"]
1189 .as_array()
1190 .ok_or("Missing cores")?
1191 .iter()
1192 .map(|c| {
1193 let p: Vec<f64> =
1194 serde_json::from_value(c["p"].clone()).unwrap_or_default();
1195 (p, c["c"].as_f64().unwrap_or(-1.0))
1196 })
1197 .collect();
1198 apply_rowwise(input, |row| {
1199 let mut best = -1.0;
1200 let mut best_d = f64::INFINITY;
1201 for (p, c) in &cores {
1202 let dist = sq_dist(p, row);
1203 if dist <= eps2 && dist < best_d {
1204 best_d = dist;
1205 best = *c;
1206 }
1207 }
1208 best
1209 })
1210 }
1211 LinfaKind::GradientBoosting => {
1212 let model_data: serde_json::Value = serde_json::from_slice(data)
1213 .map_err(|e| format!("Deserialization failed: {e}"))?;
1214 let binary = model_data["binary"].as_bool().unwrap_or(false);
1215 let init = model_data["init"].as_f64().unwrap_or(0.0);
1216 let lr = model_data["lr"].as_f64().unwrap_or(0.1);
1217 let trees: Vec<serde_json::Value> = model_data["trees"]
1218 .as_array()
1219 .ok_or("Missing trees")?
1220 .clone();
1221 apply_rowwise(input, |row| {
1222 let mut score = init;
1223 for t in &trees {
1224 score += lr * predict_tree_json(t, row);
1225 }
1226 if binary {
1227 if 1.0 / (1.0 + (-score).exp()) > 0.5 {
1228 1.0
1229 } else {
1230 0.0
1231 }
1232 } else {
1233 score
1234 }
1235 })
1236 }
1237 },
1238 _ => Err("predict_linfa called on non-Linfa model".to_string()),
1239 }
1240}
1241
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TrainConfig` with two features named `x1`/`x2` and a target named `y`.
    ///
    /// Uses a `split_ratio` of 1.0 and an empty hyperparameter map; only the model name
    /// varies between the tests below.
    fn two_feature_config(
        features: TlTensor,
        target: TlTensor,
        model_name: &str,
    ) -> TrainConfig {
        TrainConfig {
            features,
            target,
            feature_names: vec![String::from("x1"), String::from("x2")],
            target_name: String::from("y"),
            model_name: model_name.to_owned(),
            split_ratio: 1.0,
            hyperparams: HashMap::new(),
        }
    }

    /// Training "linear" on noise-free data (y = 2*x1 + 3*x2 + 1) must yield a Linfa
    /// model whose recorded R² metric exceeds 0.9.
    #[test]
    fn test_train_linear_regression() {
        // Ten (x1, x2) rows laid out row-major in a 10x2 tensor.
        let xs = TlTensor::from_vec(
            vec![
                1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 2.0, 1.0, 3.0, 2.0, 3.0,
                3.0, 3.0, 4.0, 4.0,
            ],
            &[10, 2],
        )
        .unwrap();
        let ys = TlTensor::from_list(vec![
            6.0, 8.0, 10.0, 9.0, 11.0, 13.0, 12.0, 14.0, 16.0, 21.0,
        ]);

        let trained = train("linear", &two_feature_config(xs, ys, "test_linear")).unwrap();
        match &trained {
            TlModel::Linfa { metadata, .. } => {
                assert!(metadata.metrics["r2"] > 0.9, "R² should be > 0.9");
            }
            _ => panic!("Expected Linfa model"),
        }
    }

    /// A linear model fitted to y = 2*x1 + 3*x2 should predict close to 2.0 for (1, 0).
    #[test]
    fn test_predict_linear() {
        let xs =
            TlTensor::from_vec(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 0.0], &[4, 2]).unwrap();
        let ys = TlTensor::from_list(vec![2.0, 3.0, 5.0, 4.0]);

        let trained = train("linear", &two_feature_config(xs, ys, "test")).unwrap();

        // A single 1x2 query row.
        let query = TlTensor::from_vec(vec![1.0, 0.0], &[1, 2]).unwrap();
        let predicted = predict_linfa(&trained, &query).unwrap().to_vec();
        let abs_err = (predicted[0] - 2.0).abs();
        assert!(abs_err < 1.0);
    }
}