/// Applies a per-feature batch transform across a multivariate dataset.
///
/// `x` is indexed as `x[sample][feature]`, where each feature is a
/// `Vec<f64>` series. For every feature index, the series from all samples
/// are gathered into one batch, passed through `transform_fn`, and the
/// transformed rows are scattered back so the output keeps the
/// `[sample][feature]` layout.
///
/// # Panics
/// Panics if `x` is empty, if samples disagree on the number of features,
/// or if `transform_fn` returns a batch whose length differs from the
/// number of samples.
pub fn multivariate_transform<F>(x: &[Vec<Vec<f64>>], transform_fn: F) -> Vec<Vec<Vec<f64>>>
where
    F: Fn(&[Vec<f64>]) -> Vec<Vec<f64>>,
{
    assert!(!x.is_empty(), "Input must have at least one sample");
    let n_features = x[0].len();
    assert!(
        x.iter().all(|s| s.len() == n_features),
        "All samples must have the same number of features"
    );
    let n_samples = x.len();
    let mut output: Vec<Vec<Vec<f64>>> = (0..n_samples)
        .map(|_| Vec::with_capacity(n_features))
        .collect();
    for feat_idx in 0..n_features {
        // Gather this feature's series from every sample into one batch.
        let single_feat: Vec<Vec<f64>> =
            x.iter().map(|sample| sample[feat_idx].clone()).collect();
        let transformed = transform_fn(&single_feat);
        // Guard against a transform that drops or adds rows: without this
        // check the scatter below would either panic with an opaque
        // index-out-of-bounds or silently produce ragged output.
        assert_eq!(
            transformed.len(),
            n_samples,
            "transform_fn must return exactly one row per sample"
        );
        for (sample_out, row) in output.iter_mut().zip(transformed) {
            sample_out.push(row);
        }
    }
    output
}
/// Classifies multivariate samples by majority vote over per-feature models.
///
/// For each feature index, `fit_fn` trains a model on that feature's series
/// from every training sample, and `predict_fn` labels the test samples with
/// it. Each test sample's final label is the majority class across the
/// per-feature predictions; ties are broken by choosing the lexicographically
/// smallest class name so the result is deterministic.
///
/// # Panics
/// Panics if `x_train` is empty, if `x_train` and `y_train` differ in
/// length, if any sample has a different (or zero) number of features, or if
/// `predict_fn` returns a prediction count that differs from the number of
/// test samples.
pub fn multivariate_classify<FitFn, PredFn, M>(
    x_train: &[Vec<Vec<f64>>],
    y_train: &[String],
    x_test: &[Vec<Vec<f64>>],
    fit_fn: FitFn,
    predict_fn: PredFn,
) -> Vec<String>
where
    FitFn: Fn(&[Vec<f64>], &[String]) -> M,
    PredFn: Fn(&M, &[Vec<f64>]) -> Vec<String>,
{
    assert!(
        !x_train.is_empty(),
        "Training data must have at least one sample"
    );
    assert_eq!(
        x_train.len(),
        y_train.len(),
        "Training data and labels must have the same length"
    );
    let n_features = x_train[0].len();
    // A zero-feature sample would leave every vote empty; fail loudly here
    // instead of panicking later inside the majority-vote step.
    assert!(n_features > 0, "Samples must have at least one feature");
    // Validate shape uniformity, consistent with multivariate_transform.
    assert!(
        x_train.iter().all(|s| s.len() == n_features)
            && x_test.iter().all(|s| s.len() == n_features),
        "All samples must have the same number of features"
    );
    let n_test = x_test.len();
    let mut all_predictions: Vec<Vec<String>> = (0..n_test)
        .map(|_| Vec::with_capacity(n_features))
        .collect();
    for feat_idx in 0..n_features {
        // Slice out this feature's series for every train/test sample.
        let train_feat: Vec<Vec<f64>> = x_train.iter().map(|s| s[feat_idx].clone()).collect();
        let test_feat: Vec<Vec<f64>> = x_test.iter().map(|s| s[feat_idx].clone()).collect();
        let model = fit_fn(&train_feat, y_train);
        let preds = predict_fn(&model, &test_feat);
        assert_eq!(
            preds.len(),
            n_test,
            "predict_fn must return exactly one prediction per test sample"
        );
        for (sample_preds, pred) in all_predictions.iter_mut().zip(preds) {
            sample_preds.push(pred);
        }
    }
    // Majority vote per test sample. HashMap iteration order is
    // nondeterministic, so ties are broken explicitly: higher count wins,
    // then the lexicographically smaller class name.
    all_predictions
        .iter()
        .map(|preds| {
            let mut counts: std::collections::HashMap<&str, usize> =
                std::collections::HashMap::new();
            for p in preds {
                *counts.entry(p.as_str()).or_insert(0) += 1;
            }
            counts
                .into_iter()
                .max_by(|a, b| a.1.cmp(&b.1).then_with(|| b.0.cmp(a.0)))
                .map(|(class, _)| class.to_string())
                .expect("n_features > 0 guarantees at least one vote")
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An identity transform must leave the [sample][feature] layout intact.
    #[test]
    fn test_multivariate_transform() {
        let data = vec![
            vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]],
            vec![vec![9.0, 10.0, 11.0, 12.0], vec![13.0, 14.0, 15.0, 16.0]],
        ];
        let out = multivariate_transform(&data, |batch| batch.to_vec());
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].len(), 2);
        assert_eq!(out[0][0], vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(out[1][1], vec![13.0, 14.0, 15.0, 16.0]);
    }

    /// Per-feature 1-NN: the lone test point sits next to the "A" cluster,
    /// so every feature votes "A".
    #[test]
    fn test_multivariate_classify() {
        let train = vec![
            vec![vec![0.0, 1.0], vec![0.0, 1.0]],
            vec![vec![10.0, 11.0], vec![10.0, 11.0]],
        ];
        let labels = vec!["A".to_string(), "B".to_string()];
        let queries = vec![vec![vec![0.5, 1.5], vec![0.5, 1.5]]];

        // Fit: memorize the training batch and its labels.
        let fit = |feat: &[Vec<f64>], ys: &[String]| (feat.to_vec(), ys.to_vec());
        // Predict: nearest neighbour by squared Euclidean distance.
        let predict = |model: &(Vec<Vec<f64>>, Vec<String>), tests: &[Vec<f64>]| {
            let (feat, ys) = model;
            tests
                .iter()
                .map(|q| {
                    let mut nearest = (f64::INFINITY, ys[0].clone());
                    for (row, y) in feat.iter().zip(ys) {
                        let dist: f64 = q
                            .iter()
                            .zip(row)
                            .map(|(a, b)| (a - b) * (a - b))
                            .sum();
                        if dist < nearest.0 {
                            nearest = (dist, y.clone());
                        }
                    }
                    nearest.1
                })
                .collect::<Vec<String>>()
        };

        let predictions = multivariate_classify(&train, &labels, &queries, fit, predict);
        assert_eq!(predictions.len(), 1);
        assert_eq!(predictions[0], "A");
    }
}