use crate::error::WasmError;
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
/// Squared Euclidean distance between two equal-length slices. Callers must
/// guarantee matching lengths; `zip` would silently truncate otherwise.
fn squared_euclidean(a: &[f64], b: &[f64]) -> f64 {
    debug_assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum()
}
/// Dot product of two equal-length slices.
fn dot_f64(a: &[f64], b: &[f64]) -> f64 {
    debug_assert_eq!(a.len(), b.len());
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
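/// A linear model (`y = w . x + b`) exported to JavaScript through
/// wasm-bindgen. When `feature_mean` and `feature_scale` are both present
/// (reachable only through JSON deserialization, since the constructor leaves
/// them `None`), inputs are standardized as `(x - mean) / scale` before the
/// dot product.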
#[derive(Debug, Clone, Serialize, Deserialize)]
#[wasm_bindgen]
pub struct WasmLinearModel {
weights: Vec<f64>,
bias: f64,
n_features: usize,
feature_mean: Option<Vec<f64>>,
feature_scale: Option<Vec<f64>>,
}
#[wasm_bindgen]
impl WasmLinearModel {
#[wasm_bindgen(constructor)]
pub fn new(weights: Vec<f64>, bias: f64) -> Result<WasmLinearModel, JsValue> {
if weights.is_empty() {
return Err(
WasmError::InvalidParameter("WasmLinearModel: weights must not be empty".to_string())
.into(),
);
}
let n_features = weights.len();
Ok(WasmLinearModel {
weights,
bias,
n_features,
feature_mean: None,
feature_scale: None,
})
}
pub fn from_json(json: &str) -> Result<WasmLinearModel, JsValue> {
serde_json::from_str(json)
.map_err(|e| WasmError::SerializationError(e.to_string()).into())
}
pub fn to_json(&self) -> Result<String, JsValue> {
serde_json::to_string(self)
.map_err(|e| WasmError::SerializationError(e.to_string()).into())
}
pub fn predict(&self, x: &[f64]) -> Result<f64, JsValue> {
if x.len() != self.n_features {
return Err(WasmError::InvalidParameter(format!(
"WasmLinearModel::predict: expected {} features, got {}",
self.n_features,
x.len()
))
.into());
}
let standardized = self.standardize_sample(x);
Ok(dot_f64(&standardized, &self.weights) + self.bias)
}
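    /// Predicts one value per sample. `x` is a row-major flattened matrix:
    /// sample `i` occupies `x[i * n_features..(i + 1) * n_features]`. The
    /// other `predict_batch` methods below follow the same convention.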
pub fn predict_batch(&self, x: &[f64], n_features: usize) -> Result<Vec<f64>, JsValue> {
if n_features != self.n_features {
return Err(WasmError::InvalidParameter(format!(
"WasmLinearModel::predict_batch: model expects {} features, got {}",
self.n_features, n_features
))
.into());
}
if x.len() % n_features != 0 {
return Err(WasmError::InvalidParameter(format!(
"WasmLinearModel::predict_batch: x length {} is not a multiple of n_features {}",
x.len(),
n_features
))
.into());
}
let n_samples = x.len() / n_features;
let mut predictions = Vec::with_capacity(n_samples);
for i in 0..n_samples {
let sample = &x[i * n_features..(i + 1) * n_features];
let standardized = self.standardize_sample(sample);
predictions.push(dot_f64(&standardized, &self.weights) + self.bias);
}
Ok(predictions)
}
pub fn n_features(&self) -> usize {
self.n_features
}
pub fn weights(&self) -> Vec<f64> {
self.weights.clone()
}
pub fn bias(&self) -> f64 {
self.bias
}
}
impl WasmLinearModel {
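    /// Applies `(x - mean) / scale` per feature when both standardization
    /// vectors are present; otherwise returns the sample unchanged. A zero
    /// scale maps the feature to 0.0 rather than dividing by zero.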
fn standardize_sample(&self, x: &[f64]) -> Vec<f64> {
match (&self.feature_mean, &self.feature_scale) {
(Some(mean), Some(scale)) => x
.iter()
.zip(mean.iter().zip(scale.iter()))
.map(|(&xi, (&m, &s))| if s == 0.0 { 0.0 } else { (xi - m) / s })
.collect(),
_ => x.to_vec(),
}
}
}
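/// A fitted k-means model: `k` centroids of `n_features` values each, stored
/// row-major in a flat `centroids` buffer. `cluster_labels` is optional
/// metadata that is currently only carried through (de)serialization.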
#[derive(Debug, Clone, Serialize, Deserialize)]
#[wasm_bindgen]
pub struct WasmKMeans {
centroids: Vec<f64>,
k: usize,
n_features: usize,
cluster_labels: Option<Vec<String>>,
}
#[wasm_bindgen]
impl WasmKMeans {
#[wasm_bindgen(constructor)]
pub fn new(centroids: Vec<f64>, k: usize, n_features: usize) -> Result<WasmKMeans, JsValue> {
if k == 0 || n_features == 0 {
return Err(WasmError::InvalidParameter(
"WasmKMeans: k and n_features must be > 0".to_string(),
)
.into());
}
let expected = k.checked_mul(n_features).ok_or_else(|| {
WasmError::InvalidParameter("WasmKMeans: k × n_features overflow".to_string())
})?;
if centroids.len() != expected {
return Err(WasmError::InvalidParameter(format!(
"WasmKMeans: expected {} centroid values ({}×{}), got {}",
expected,
k,
n_features,
centroids.len()
))
.into());
}
Ok(WasmKMeans {
centroids,
k,
n_features,
cluster_labels: None,
})
}
pub fn from_json(json: &str) -> Result<WasmKMeans, JsValue> {
serde_json::from_str(json)
.map_err(|e| WasmError::SerializationError(e.to_string()).into())
}
pub fn to_json(&self) -> Result<String, JsValue> {
serde_json::to_string(self)
.map_err(|e| WasmError::SerializationError(e.to_string()).into())
}
pub fn predict(&self, x: &[f64]) -> Result<u32, JsValue> {
if x.len() != self.n_features {
return Err(WasmError::InvalidParameter(format!(
"WasmKMeans::predict: expected {} features, got {}",
self.n_features,
x.len()
))
.into());
}
let best = self.nearest_centroid(x).ok_or_else(|| {
WasmError::ComputationError("WasmKMeans: no centroids available".to_string())
})?;
Ok(best as u32)
}
pub fn predict_batch(&self, x: &[f64], n_features: usize) -> Result<Vec<u32>, JsValue> {
if n_features != self.n_features {
return Err(WasmError::InvalidParameter(format!(
"WasmKMeans::predict_batch: model has {} features, got {}",
self.n_features, n_features
))
.into());
}
if x.len() % n_features != 0 {
return Err(WasmError::InvalidParameter(
"WasmKMeans::predict_batch: x.len() not divisible by n_features".to_string(),
)
.into());
}
let n_samples = x.len() / n_features;
let mut labels = Vec::with_capacity(n_samples);
for i in 0..n_samples {
    let sample = &x[i * n_features..(i + 1) * n_features];
    // Propagate the degenerate case (no finite centroid distance) instead of
    // silently assigning cluster 0, matching the single-sample `predict`.
    let label = self.nearest_centroid(sample).ok_or_else(|| {
        WasmError::ComputationError("WasmKMeans: no centroids available".to_string())
    })?;
    labels.push(label as u32);
}
Ok(labels)
}
pub fn get_centroid(&self, k_idx: usize) -> Result<Vec<f64>, JsValue> {
if k_idx >= self.k {
return Err(WasmError::IndexOutOfBounds(format!(
"WasmKMeans::get_centroid: index {} out of range (k={})",
k_idx, self.k
))
.into());
}
let start = k_idx * self.n_features;
Ok(self.centroids[start..start + self.n_features].to_vec())
}
pub fn k(&self) -> usize {
self.k
}
pub fn n_features(&self) -> usize {
self.n_features
}
}
impl WasmKMeans {
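    /// Returns the index of the closest centroid by squared Euclidean
    /// distance, or `None` when no centroid yields a finite distance
    /// (e.g., NaN in the input).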
fn nearest_centroid(&self, x: &[f64]) -> Option<usize> {
let mut best_idx = 0usize;
let mut best_dist = f64::INFINITY;
for ci in 0..self.k {
let start = ci * self.n_features;
let centroid = &self.centroids[start..start + self.n_features];
let d = squared_euclidean(x, centroid);
if d < best_dist {
best_dist = d;
best_idx = ci;
}
}
if best_dist.is_finite() {
Some(best_idx)
} else {
None
}
}
}
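/// Per-class parameters for Gaussian naive Bayes: a class label, the log of
/// the class prior, and per-feature means and variances.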
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GaussianClass {
pub label: u32,
pub log_prior: f64,
pub means: Vec<f64>,
pub variances: Vec<f64>,
}
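/// A Gaussian naive Bayes classifier, constructed from a JSON array of
/// `GaussianClass` records. `var_smoothing` is added to every variance to
/// stabilize the log densities.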
#[derive(Debug, Clone, Serialize, Deserialize)]
#[wasm_bindgen]
pub struct WasmNaiveBayes {
classes: Vec<GaussianClass>,
n_features: usize,
var_smoothing: f64,
}
#[wasm_bindgen]
impl WasmNaiveBayes {
#[wasm_bindgen(constructor)]
pub fn new(classes_json: &str, var_smoothing: f64) -> Result<WasmNaiveBayes, JsValue> {
let classes: Vec<GaussianClass> = serde_json::from_str(classes_json)
.map_err(|e| WasmError::SerializationError(e.to_string()))?;
if classes.is_empty() {
return Err(WasmError::InvalidParameter(
"WasmNaiveBayes: at least one class required".to_string(),
)
.into());
}
let n_features = classes[0].means.len();
if n_features == 0 {
    return Err(WasmError::InvalidParameter(
        "WasmNaiveBayes: classes must have at least one feature".to_string(),
    )
    .into());
}
for cls in &classes {
if cls.means.len() != n_features || cls.variances.len() != n_features {
return Err(WasmError::InvalidParameter(format!(
"WasmNaiveBayes: class {} has inconsistent feature count",
cls.label
))
.into());
}
}
Ok(WasmNaiveBayes {
classes,
n_features,
var_smoothing: var_smoothing.max(0.0),
})
}
pub fn from_json(json: &str) -> Result<WasmNaiveBayes, JsValue> {
serde_json::from_str(json)
.map_err(|e| WasmError::SerializationError(e.to_string()).into())
}
pub fn to_json(&self) -> Result<String, JsValue> {
serde_json::to_string(self)
.map_err(|e| WasmError::SerializationError(e.to_string()).into())
}
pub fn predict(&self, x: &[f64]) -> Result<u32, JsValue> {
self.predict_internal(x)
}
pub fn predict_batch(&self, x: &[f64], n_features: usize) -> Result<Vec<u32>, JsValue> {
if n_features != self.n_features {
return Err(WasmError::InvalidParameter(format!(
"WasmNaiveBayes::predict_batch: expected {} features, got {}",
self.n_features, n_features
))
.into());
}
if x.len() % n_features != 0 {
return Err(WasmError::InvalidParameter(
"WasmNaiveBayes::predict_batch: x.len() not divisible by n_features".to_string(),
)
.into());
}
let n_samples = x.len() / n_features;
let mut out = Vec::with_capacity(n_samples);
for i in 0..n_samples {
let sample = &x[i * n_features..(i + 1) * n_features];
let label = self.predict_internal(sample)?;
out.push(label);
}
Ok(out)
}
pub fn log_posteriors(&self, x: &[f64]) -> Result<Vec<f64>, JsValue> {
if x.len() != self.n_features {
return Err(WasmError::InvalidParameter(format!(
"WasmNaiveBayes::log_posteriors: expected {} features, got {}",
self.n_features,
x.len()
))
.into());
}
Ok(self.classes.iter().map(|c| self.log_posterior(c, x)).collect())
}
pub fn n_classes(&self) -> usize {
self.classes.len()
}
pub fn n_features(&self) -> usize {
self.n_features
}
}
impl WasmNaiveBayes {
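    /// Unnormalized log posterior: the class log-prior plus the sum of
    /// per-feature Gaussian log densities, with `var_smoothing` added to
    /// each variance.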
fn log_posterior(&self, cls: &GaussianClass, x: &[f64]) -> f64 {
let mut log_prob = cls.log_prior;
for (i, &xi) in x.iter().enumerate() {
let mean = cls.means[i];
let var = cls.variances[i] + self.var_smoothing;
let diff = xi - mean;
log_prob += -0.5 * ((2.0 * std::f64::consts::PI * var).ln() + diff * diff / var);
}
log_prob
}
fn predict_internal(&self, x: &[f64]) -> Result<u32, JsValue> {
if x.len() != self.n_features {
return Err(WasmError::InvalidParameter(format!(
"WasmNaiveBayes::predict: expected {} features, got {}",
self.n_features,
x.len()
))
.into());
}
// Compute each posterior once and take the argmax; incomparable (NaN)
// posteriors are treated as equal rather than panicking.
let best = self
    .classes
    .iter()
    .map(|c| (c.label, self.log_posterior(c, x)))
    .max_by(|(_, la), (_, lb)| la.partial_cmp(lb).unwrap_or(std::cmp::Ordering::Equal))
    .ok_or_else(|| {
        WasmError::ComputationError("WasmNaiveBayes: no classes defined".to_string())
    })?;
Ok(best.0)
}
}
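/// One-shot convenience wrapper for JavaScript callers: deserializes a
/// `WasmLinearModel` from JSON and runs `predict_batch` without keeping a
/// model instance alive across calls.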
#[wasm_bindgen]
pub fn linear_model_predict_batch(
model_json: &str,
x: &[f64],
n_features: usize,
) -> Result<Vec<f64>, JsValue> {
let model = WasmLinearModel::from_json(model_json)?;
model.predict_batch(x, n_features)
}
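/// One-shot wrapper: deserializes a `WasmKMeans` from JSON and runs
/// `predict_batch`.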
#[wasm_bindgen]
pub fn kmeans_predict_batch(
model_json: &str,
x: &[f64],
n_features: usize,
) -> Result<Vec<u32>, JsValue> {
let model = WasmKMeans::from_json(model_json)?;
model.predict_batch(x, n_features)
}
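/// One-shot wrapper: deserializes a `WasmNaiveBayes` from JSON and runs
/// `predict_batch`.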
#[wasm_bindgen]
pub fn naive_bayes_predict_batch(
model_json: &str,
x: &[f64],
n_features: usize,
) -> Result<Vec<u32>, JsValue> {
let model = WasmNaiveBayes::from_json(model_json)?;
model.predict_batch(x, n_features)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_linear_model_predict() {
let model = WasmLinearModel::new(vec![2.0, 3.0], 1.0).expect("model ok");
let pred = model.predict(&[1.0, 1.0]).expect("predict ok");
assert!((pred - 6.0).abs() < 1e-10, "expected 6, got {pred}");
}
#[test]
fn test_linear_model_batch() {
let model = WasmLinearModel::new(vec![1.0, 1.0], 0.0).expect("model ok");
let x = vec![1.0, 2.0, 3.0, 4.0];
let preds = model.predict_batch(&x, 2).expect("batch ok");
assert_eq!(preds.len(), 2);
assert!((preds[0] - 3.0).abs() < 1e-10);
assert!((preds[1] - 7.0).abs() < 1e-10);
}
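    // A batch whose length is not a multiple of n_features should be
    // rejected rather than silently truncated.
    #[test]
    fn test_linear_model_batch_length_mismatch() {
        let model = WasmLinearModel::new(vec![1.0, 1.0], 0.0).expect("model ok");
        assert!(model.predict_batch(&[1.0, 2.0, 3.0], 2).is_err());
    }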
#[test]
fn test_linear_model_feature_mismatch() {
let model = WasmLinearModel::new(vec![1.0, 2.0], 0.0).expect("ok");
assert!(model.predict(&[1.0]).is_err());
}
#[test]
fn test_linear_model_json_roundtrip() {
let model = WasmLinearModel::new(vec![1.5, -2.3, 0.7], 0.42).expect("ok");
let json = model.to_json().expect("to_json ok");
let recovered = WasmLinearModel::from_json(&json).expect("from_json ok");
assert!((recovered.bias - model.bias).abs() < 1e-12);
for (a, b) in recovered.weights.iter().zip(model.weights.iter()) {
assert!((a - b).abs() < 1e-12);
}
}
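    // The standardization path is only reachable via deserialization: the
    // constructor always leaves `feature_mean`/`feature_scale` as `None`.
    // The JSON below uses the struct's field names directly (no serde rename
    // attributes are in play). Sample [4.0, 6.0] standardizes to [1.0, 1.0],
    // so the prediction is 1 + 1 = 2.
    #[test]
    fn test_linear_model_standardized_predict() {
        let json = serde_json::json!({
            "weights": [1.0, 1.0],
            "bias": 0.0,
            "n_features": 2,
            "feature_mean": [2.0, 4.0],
            "feature_scale": [2.0, 2.0]
        })
        .to_string();
        let model = WasmLinearModel::from_json(&json).expect("from_json ok");
        let pred = model.predict(&[4.0, 6.0]).expect("predict ok");
        assert!((pred - 2.0).abs() < 1e-10, "expected 2, got {pred}");
    }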
#[test]
fn test_kmeans_predict() {
let centroids = vec![0.0_f64, 0.0, 10.0, 10.0];
let km = WasmKMeans::new(centroids, 2, 2).expect("kmeans ok");
let label_near = km.predict(&[0.5, 0.5]).expect("predict ok");
let label_far = km.predict(&[9.8, 10.2]).expect("predict ok");
assert_eq!(label_near, 0, "should assign to cluster 0");
assert_eq!(label_far, 1, "should assign to cluster 1");
}
#[test]
fn test_kmeans_batch() {
let centroids = vec![0.0_f64, 0.0, 10.0, 10.0];
let km = WasmKMeans::new(centroids, 2, 2).expect("ok");
let x = vec![0.1, 0.1, 9.9, 9.9, 0.2, -0.1];
let labels = km.predict_batch(&x, 2).expect("batch ok");
assert_eq!(labels, vec![0, 1, 0]);
}
#[test]
fn test_kmeans_size_mismatch() {
let result = WasmKMeans::new(vec![0.0; 5], 2, 2);
assert!(result.is_err());
}
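    // get_centroid slices the flat row-major centroid buffer; index k itself
    // is already out of range.
    #[test]
    fn test_kmeans_get_centroid() {
        let km = WasmKMeans::new(vec![1.0, 2.0, 3.0, 4.0], 2, 2).expect("ok");
        assert_eq!(km.get_centroid(1).expect("in range"), vec![3.0, 4.0]);
        assert!(km.get_centroid(2).is_err());
    }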
#[test]
fn test_kmeans_json_roundtrip() {
let km = WasmKMeans::new(vec![1.0, 2.0, 3.0, 4.0], 2, 2).expect("ok");
let json = km.to_json().expect("to_json ok");
let recovered = WasmKMeans::from_json(&json).expect("from_json ok");
assert_eq!(recovered.k, 2);
assert_eq!(recovered.n_features, 2);
}
fn make_nb_json() -> String {
let log_prior = 0.5_f64.ln();
serde_json::json!([
{
"label": 0,
"log_prior": log_prior,
"means": [1.0, 2.0],
"variances": [0.1, 0.1]
},
{
"label": 1,
"log_prior": log_prior,
"means": [5.0, 6.0],
"variances": [0.1, 0.1]
}
])
.to_string()
}
#[test]
fn test_naive_bayes_predict() {
let nb = WasmNaiveBayes::new(&make_nb_json(), 1e-9).expect("nb ok");
let label = nb.predict(&[1.05, 1.95]).expect("predict ok");
assert_eq!(label, 0);
let label = nb.predict(&[4.9, 6.1]).expect("predict ok");
assert_eq!(label, 1);
}
#[test]
fn test_naive_bayes_batch() {
let nb = WasmNaiveBayes::new(&make_nb_json(), 1e-9).expect("nb ok");
let x = vec![1.0_f64, 2.0, 5.0, 6.0, 1.1, 2.1];
let labels = nb.predict_batch(&x, 2).expect("batch ok");
assert_eq!(labels, vec![0, 1, 0]);
}
#[test]
fn test_naive_bayes_json_roundtrip() {
let nb = WasmNaiveBayes::new(&make_nb_json(), 1e-9).expect("ok");
let json = nb.to_json().expect("to_json ok");
let recovered = WasmNaiveBayes::from_json(&json).expect("from_json ok");
assert_eq!(recovered.n_classes(), 2);
assert_eq!(recovered.n_features(), 2);
}
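    // Only the ordering of log_posteriors is asserted here: the values are
    // unnormalized log densities plus log priors, not probabilities.
    #[test]
    fn test_naive_bayes_log_posteriors_ordering() {
        let nb = WasmNaiveBayes::new(&make_nb_json(), 1e-9).expect("nb ok");
        let lp = nb.log_posteriors(&[1.0, 2.0]).expect("log_posteriors ok");
        assert_eq!(lp.len(), 2);
        assert!(lp[0] > lp[1], "class 0 should dominate at its own mean");
    }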
#[test]
fn test_naive_bayes_wrong_features() {
let nb = WasmNaiveBayes::new(&make_nb_json(), 1e-9).expect("ok");
assert!(nb.predict(&[1.0]).is_err());
}
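    // The free functions are stateless one-shot wrappers: deserialize, then
    // delegate to predict_batch. Round-trip through to_json for valid input.
    #[test]
    fn test_linear_model_predict_batch_free_fn() {
        let model = WasmLinearModel::new(vec![1.0, 1.0], 0.0).expect("model ok");
        let json = model.to_json().expect("to_json ok");
        let preds = linear_model_predict_batch(&json, &[1.0, 2.0], 2).expect("batch ok");
        assert_eq!(preds.len(), 1);
        assert!((preds[0] - 3.0).abs() < 1e-10);
    }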
}