use std::collections::HashMap;
use linfa::Dataset;
use linfa::prelude::*;
use ndarray::{Array1, Array2};
use crate::model::{LinfaKind, ModelMeta, TlModel};
use crate::tensor::TlTensor;
/// Inputs shared by every trainer in this module.
///
/// The trainers borrow the configuration; tensors are converted to `ndarray`
/// arrays and the name fields are cloned into the resulting model metadata.
pub struct TrainConfig {
/// Feature tensor; the trainers reject it unless its shape has exactly two dimensions.
pub features: TlTensor,
/// Target tensor; the trainers reject it unless its shape has exactly one dimension.
pub target: TlTensor,
/// Feature names, cloned into the `features` field of the model metadata.
pub feature_names: Vec<String>,
/// Target name, cloned into the `target` field of the model metadata.
pub target_name: String,
/// Model name, cloned into the `name` field of the model metadata.
pub model_name: String,
/// NOTE(review): not read by any trainer visible in this file; presumably a
/// train/test split fraction — confirm against callers.
pub split_ratio: f64,
/// Algorithm-specific numeric settings. Within this file only `"max_depth"`
/// is looked up (by the decision-tree trainer).
pub hyperparams: HashMap<String, f64>,
}
pub fn train(algorithm: &str, config: &TrainConfig) -> Result<TlModel, String> {
match algorithm {
"linear" => train_linear(config),
"logistic" => train_logistic(config),
"tree" | "decision_tree" => train_decision_tree(config),
_ => Err(format!(
"Unknown training algorithm: '{algorithm}'. Supported: linear, logistic, tree"
)),
}
}
/// Converts a feature tensor into a two-dimensional `f64` array.
///
/// # Errors
///
/// Fails when the tensor's shape is not two-dimensional, or when `ndarray`
/// rejects the flattened data for the reported `(rows, cols)` shape.
fn features_to_array2(features: &TlTensor) -> Result<Array2<f64>, String> {
    let dims = features.shape();
    match dims.len() {
        2 => Array2::from_shape_vec((dims[0], dims[1]), features.to_vec())
            .map_err(|e| format!("Shape error: {e}")),
        rank => Err(format!("Features must be 2D, got {}D", rank)),
    }
}
/// Converts a target tensor into a one-dimensional `f64` array.
///
/// # Errors
///
/// Fails when the tensor's shape is not one-dimensional.
fn target_to_array1(target: &TlTensor) -> Result<Array1<f64>, String> {
    let rank = target.shape().len();
    if rank == 1 {
        Ok(Array1::from_vec(target.to_vec()))
    } else {
        Err(format!("Target must be 1D, got {}D", rank))
    }
}
/// Fits a linear regression on the full dataset described by `config`.
///
/// The fitted coefficients and intercept are stored as JSON
/// (`{"params": [...], "intercept": ...}`) in the returned model's `data`,
/// and the R² score is recorded under the `"r2"` metric. The score is
/// computed on the same data the model was fitted on; `split_ratio` is not
/// consulted.
///
/// # Errors
///
/// Returns an error when the tensors have the wrong rank, when the number of
/// feature rows differs from the number of target values, or when fitting,
/// scoring, or serialisation fails.
fn train_linear(config: &TrainConfig) -> Result<TlModel, String> {
    let x = features_to_array2(&config.features)?;
    let y = target_to_array1(&config.target)?;

    // Validate up front so a mismatch is always reported as an `Err` from
    // this function instead of depending on how the fitter reacts to it.
    // NOTE(review): linfa-linear appears to assert on this condition rather
    // than return an error — confirm.
    if x.nrows() != y.len() {
        return Err(format!(
            "Feature/target length mismatch: {} feature rows but {} target values",
            x.nrows(),
            y.len()
        ));
    }

    let dataset = Dataset::new(x, y);
    let model = linfa_linear::LinearRegression::default()
        .fit(&dataset)
        .map_err(|e| format!("Linear regression training failed: {e}"))?;

    let pred = model.predict(&dataset);
    let r2 = pred
        .r2(&dataset)
        .map_err(|e| format!("R² computation failed: {e}"))?;

    // `to_vec` copies the coefficients regardless of memory layout, so there
    // is no fallback path that could silently serialise an empty list.
    let coefficients = model.params().to_vec();
    let intercept = model.intercept();
    let model_data = serde_json::json!({
        "params": coefficients,
        "intercept": intercept,
    });
    let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;

    let mut metrics = HashMap::new();
    metrics.insert("r2".to_string(), r2);

    Ok(TlModel::Linfa {
        kind: LinfaKind::LinearRegression,
        data,
        metadata: ModelMeta {
            name: config.model_name.clone(),
            version: "0.1.0".to_string(),
            created_at: String::new(),
            features: config.feature_names.clone(),
            target: config.target_name.clone(),
            metrics,
        },
    })
}
/// Fits a binary logistic regression on the full dataset described by `config`.
///
/// Target values are turned into boolean labels with `value > 0.5` (so NaN
/// becomes `false`). The fitted coefficients and intercept are stored as JSON
/// (`{"params": [...], "intercept": ...}`) in the returned model's `data`,
/// and the fraction of training rows predicted correctly is recorded under
/// the `"accuracy"` metric. `split_ratio` is not consulted.
///
/// # Errors
///
/// Returns an error when the tensors have the wrong rank, when the number of
/// feature rows differs from the number of target values, or when fitting or
/// serialisation fails.
fn train_logistic(config: &TrainConfig) -> Result<TlModel, String> {
    let x = features_to_array2(&config.features)?;
    let y_float = target_to_array1(&config.target)?;

    // Validate up front so a mismatch is reported with a clear message from
    // this function, independent of how the fitter handles it.
    if x.nrows() != y_float.len() {
        return Err(format!(
            "Feature/target length mismatch: {} feature rows but {} target values",
            x.nrows(),
            y_float.len()
        ));
    }

    let y_bool: Array1<bool> = y_float.mapv(|v| v > 0.5);
    let dataset = Dataset::new(x, y_bool);
    let model = linfa_logistic::LogisticRegression::default()
        .max_iterations(100)
        .fit(&dataset)
        .map_err(|e| format!("Logistic regression training failed: {e}"))?;

    // Training accuracy: share of rows whose prediction equals the label.
    let pred = model.predict(&dataset);
    let correct = pred
        .iter()
        .zip(dataset.targets().iter())
        .filter(|(p, t)| p == t)
        .count();
    let accuracy = correct as f64 / dataset.targets().len() as f64;

    // `to_vec` copies the coefficients regardless of memory layout, so there
    // is no fallback path that could silently serialise an empty list.
    let coefficients = model.params().to_vec();
    let intercept = model.intercept();
    let model_data = serde_json::json!({
        "params": coefficients,
        "intercept": intercept,
    });
    let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;

    let mut metrics = HashMap::new();
    metrics.insert("accuracy".to_string(), accuracy);

    Ok(TlModel::Linfa {
        kind: LinfaKind::LogisticRegression,
        data,
        metadata: ModelMeta {
            name: config.model_name.clone(),
            version: "0.1.0".to_string(),
            created_at: String::new(),
            features: config.feature_names.clone(),
            target: config.target_name.clone(),
            metrics,
        },
    })
}
/// Fits a decision-tree classifier on the full dataset described by `config`.
///
/// Target values are cast to `usize` class labels (fractional parts are
/// truncated). The optional `"max_depth"` hyperparameter limits the tree
/// depth. The returned model's `data` only records
/// `{"type": "decision_tree", "accuracy": ...}`; the tree itself is not
/// serialised. Training accuracy is stored under the `"accuracy"` metric.
/// `split_ratio` is not consulted.
///
/// # Errors
///
/// Returns an error when the tensors have the wrong rank, when the number of
/// feature rows differs from the number of target values, when a target value
/// is negative or not finite, when `"max_depth"` is NaN or negative, or when
/// fitting or serialisation fails.
fn train_decision_tree(config: &TrainConfig) -> Result<TlModel, String> {
    let x = features_to_array2(&config.features)?;
    let y_float = target_to_array1(&config.target)?;

    // Validate up front so a mismatch is always reported as an `Err` from
    // this function instead of depending on how the fitter reacts to it.
    if x.nrows() != y_float.len() {
        return Err(format!(
            "Feature/target length mismatch: {} feature rows but {} target values",
            x.nrows(),
            y_float.len()
        ));
    }

    // A float-to-usize `as` cast saturates: NaN and negatives become 0 and
    // +inf becomes usize::MAX. Reject those values rather than letting them
    // masquerade as real class labels.
    if let Some((idx, bad)) = y_float
        .iter()
        .enumerate()
        .find(|(_, v)| !v.is_finite() || **v < 0.0)
    {
        return Err(format!(
            "Decision tree target must contain finite, non-negative class labels; got {bad} at index {idx}"
        ));
    }
    let y_usize: Array1<usize> = y_float.mapv(|v| v as usize);

    // NaN or negative depths would silently cast to 0. +inf is still accepted
    // and saturates to usize::MAX, as it did before.
    let max_depth = match config.hyperparams.get("max_depth").copied() {
        Some(d) if d.is_nan() || d < 0.0 => {
            return Err(format!("Invalid max_depth hyperparameter: {d}"));
        }
        Some(d) => Some(d as usize),
        None => None,
    };

    let dataset = Dataset::new(x, y_usize);
    let mut builder = linfa_trees::DecisionTree::params();
    if let Some(depth) = max_depth {
        builder = builder.max_depth(Some(depth));
    }
    let model = builder
        .fit(&dataset)
        .map_err(|e| format!("Decision tree training failed: {e}"))?;

    // Training accuracy: share of rows whose prediction equals the label.
    let pred = model.predict(&dataset);
    let correct = pred
        .iter()
        .zip(dataset.targets().iter())
        .filter(|(p, t)| p == t)
        .count();
    let accuracy = correct as f64 / dataset.targets().len() as f64;

    let model_data = serde_json::json!({
        "type": "decision_tree",
        "accuracy": accuracy,
    });
    let data = serde_json::to_vec(&model_data).map_err(|e| format!("Serialization failed: {e}"))?;

    let mut metrics = HashMap::new();
    metrics.insert("accuracy".to_string(), accuracy);

    Ok(TlModel::Linfa {
        kind: LinfaKind::DecisionTree,
        data,
        metadata: ModelMeta {
            name: config.model_name.clone(),
            version: "0.1.0".to_string(),
            created_at: String::new(),
            features: config.feature_names.clone(),
            target: config.target_name.clone(),
            metrics,
        },
    })
}
/// Runs prediction for a serialised linfa model.
///
/// * Linear regression returns one real-valued prediction per input row.
/// * Logistic regression returns `1.0` when the sigmoid of the linear score
///   exceeds `0.5`, otherwise `0.0`.
/// * Decision trees are not supported, because their serialised payload does
///   not contain the tree.
///
/// `input` may be 1-D (a single sample) or 2-D (`rows x features`). The result
/// is built with `TlTensor::from_list` and holds one value per sample.
///
/// # Errors
///
/// Returns an error when `model` is not a `TlModel::Linfa`, when the stored
/// payload cannot be decoded, when the input rank is not 1 or 2, when the
/// number of input features differs from the number of stored coefficients,
/// or when the model kind is a decision tree.
pub fn predict_linfa(model: &TlModel, input: &TlTensor) -> Result<TlTensor, String> {
    match model {
        TlModel::Linfa { kind, data, .. } => match kind {
            LinfaKind::LinearRegression => {
                let (params, intercept) = decode_linear_coefficients(data)?;
                let preds = linear_scores(input, &params, intercept)?;
                Ok(TlTensor::from_list(preds))
            }
            LinfaKind::LogisticRegression => {
                let (params, intercept) = decode_linear_coefficients(data)?;
                let preds: Vec<f64> = linear_scores(input, &params, intercept)?
                    .into_iter()
                    .map(|logit| {
                        let prob = 1.0 / (1.0 + (-logit).exp());
                        if prob > 0.5 {
                            1.0
                        } else {
                            0.0
                        }
                    })
                    .collect();
                Ok(TlTensor::from_list(preds))
            }
            LinfaKind::DecisionTree => Err(
                "Decision tree prediction from serialized model not yet supported. \
                 Use predict immediately after training."
                    .to_string(),
            ),
        },
        _ => Err("predict_linfa called on non-Linfa model".to_string()),
    }
}

/// Decodes a `{"params": [...], "intercept": ...}` JSON payload into
/// `(coefficients, intercept)`.
///
/// A coefficient that is not a JSON number is an error, since reading it as
/// zero would silently change predictions. A missing or non-numeric intercept
/// is read as `0.0`.
fn decode_linear_coefficients(data: &[u8]) -> Result<(Vec<f64>, f64), String> {
    let model_data: serde_json::Value =
        serde_json::from_slice(data).map_err(|e| format!("Deserialization failed: {e}"))?;
    let params = model_data["params"]
        .as_array()
        .ok_or("Missing params")?
        .iter()
        .enumerate()
        .map(|(idx, value)| {
            value
                .as_f64()
                .ok_or_else(|| format!("Model coefficient {idx} is not a number"))
        })
        .collect::<Result<Vec<f64>, String>>()?;
    let intercept = model_data["intercept"].as_f64().unwrap_or(0.0);
    Ok((params, intercept))
}

/// Computes `dot(row, params) + intercept` for every sample in `input`.
///
/// A 1-D input is treated as a single sample; a 2-D input as one sample per
/// row of the flattened, row-major buffer.
fn linear_scores(input: &TlTensor, params: &[f64], intercept: f64) -> Result<Vec<f64>, String> {
    let shape = input.shape();
    let rank = shape.len();
    let (rows, cols) = match rank {
        1 => (1, shape[0]),
        2 => (shape[0], shape[1]),
        _ => return Err(format!("Input must be 1D or 2D, got {}D", rank)),
    };

    // `zip` stops at the shorter side, so a width mismatch would otherwise
    // produce a plausible-looking but wrong prediction.
    if cols != params.len() {
        return Err(format!(
            "Feature count mismatch: model expects {} features, input has {}",
            params.len(),
            cols
        ));
    }

    let flat = input.to_vec();
    // Guard the row slicing below against a buffer that disagrees with the shape.
    if rows.checked_mul(cols) != Some(flat.len()) {
        return Err(format!(
            "Input data length {} does not match shape {}x{}",
            flat.len(),
            rows,
            cols
        ));
    }

    Ok((0..rows)
        .map(|i| {
            let row = &flat[i * cols..(i + 1) * cols];
            row.iter()
                .zip(params.iter())
                .map(|(a, b)| a * b)
                .sum::<f64>()
                + intercept
        })
        .collect())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the two-feature configuration shared by the tests below.
    fn two_feature_config(features: TlTensor, target: TlTensor, model_name: &str) -> TrainConfig {
        TrainConfig {
            features,
            target,
            feature_names: vec!["x1".to_string(), "x2".to_string()],
            target_name: "y".to_string(),
            model_name: model_name.to_string(),
            split_ratio: 1.0,
            hyperparams: HashMap::new(),
        }
    }

    /// Linear regression on ten two-feature rows should report R² above 0.9.
    #[test]
    fn test_train_linear_regression() {
        let rows = vec![
            1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 2.0, 1.0, 3.0, 2.0, 3.0, 3.0,
            3.0, 4.0, 4.0,
        ];
        let features = TlTensor::from_vec(rows, &[10, 2]).unwrap();
        let target = TlTensor::from_list(vec![
            6.0, 8.0, 10.0, 9.0, 11.0, 13.0, 12.0, 14.0, 16.0, 21.0,
        ]);
        let config = two_feature_config(features, target, "test_linear");

        let model = train("linear", &config).unwrap();
        match &model {
            TlModel::Linfa { metadata, .. } => {
                assert!(metadata.metrics["r2"] > 0.9, "R² should be > 0.9");
            }
            _ => panic!("Expected Linfa model"),
        }
    }

    /// A trained linear model should predict close to 2.0 for the sample `[1, 0]`.
    #[test]
    fn test_predict_linear() {
        let features =
            TlTensor::from_vec(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 0.0], &[4, 2]).unwrap();
        let target = TlTensor::from_list(vec![2.0, 3.0, 5.0, 4.0]);
        let config = two_feature_config(features, target, "test");

        let model = train("linear", &config).unwrap();
        let sample = TlTensor::from_vec(vec![1.0, 0.0], &[1, 2]).unwrap();
        let prediction = predict_linfa(&model, &sample).unwrap();
        assert!((prediction.to_vec()[0] - 2.0).abs() < 1.0);
    }
}