use crate::error::{RecsysError, RecsysResult};
/// Hyper-parameters and problem dimensions for fitting an [`Ease`] model.
#[derive(Clone, Debug)]
pub struct EaseConfig {
    /// Number of rows (users) in the row-major interaction matrix.
    pub n_users: usize,
    /// Number of columns (items) in the row-major interaction matrix.
    pub n_items: usize,
    /// L2 regularisation strength; added to every diagonal entry of the Gram
    /// matrix. `Ease::fit` rejects non-positive values.
    pub lambda: f32,
    /// L1 strength. The soft-thresholding pass only runs when this is `> 0`.
    pub lambda_l1: f32,
    /// Upper bound on the number of soft-thresholding passes.
    pub l1_iter: usize,
    /// The soft-thresholding loop stops once the largest weight change in a
    /// pass drops below this value.
    pub l1_tol: f32,
}

impl Default for EaseConfig {
    /// Returns a configuration with zero dimensions (callers must fill in
    /// `n_users` / `n_items`), `lambda = 500`, and the L1 step disabled.
    fn default() -> Self {
        let (n_users, n_items) = (0, 0);
        EaseConfig {
            n_users,
            n_items,
            lambda: 500.0,
            lambda_l1: 0.0,
            l1_iter: 50,
            l1_tol: 1e-4,
        }
    }
}
/// A fitted item-item linear model.
#[derive(Clone, Debug)]
pub struct Ease {
    /// Configuration the model was fitted with; `cfg.n_items` fixes the
    /// side length of `weights`.
    pub cfg: EaseConfig,
    /// Row-major `n_items × n_items` item-item weight matrix. Entry
    /// `weights[i * n_items + j]` is the contribution of item `i` to the
    /// score of item `j`; `fit` leaves the diagonal at zero.
    pub weights: Vec<f32>,
}
impl Ease {
/// Fits the item-item weight matrix from a dense interaction matrix.
///
/// `interactions` is row-major with `cfg.n_users` rows and `cfg.n_items`
/// columns. The regularised Gram matrix is inverted via Cholesky, and each
/// off-diagonal weight is set to `-P[i, j] / P[j, j]`, where `P` is that
/// inverse. When `cfg.lambda_l1 > 0` the weights are then passed through
/// the soft-thresholding step.
///
/// # Errors
///
/// * `InvalidNumUsers` / `InvalidNumItems` when the dimension is zero.
/// * `InvalidLambda` when `cfg.lambda` is not a finite, strictly positive
///   number (NaN and infinities are rejected as well).
/// * `DimensionMismatch` when `interactions.len() != n_users * n_items`.
/// * Any error propagated from `invert_via_cholesky`.
pub fn fit(interactions: &[f32], cfg: EaseConfig) -> RecsysResult<Self> {
    if cfg.n_users == 0 {
        return Err(RecsysError::InvalidNumUsers { n: cfg.n_users });
    }
    if cfg.n_items == 0 {
        return Err(RecsysError::InvalidNumItems { n: cfg.n_items });
    }
    // `NaN <= 0.0` is false, so a plain `<=` check would let NaN (and +inf)
    // through and yield a model made of NaN weights.
    if !cfg.lambda.is_finite() || cfg.lambda <= 0.0 {
        return Err(RecsysError::InvalidLambda { val: cfg.lambda });
    }
    // Saturate instead of overflowing: no `f32` slice can hold `usize::MAX`
    // elements, so an overflowing product is always reported as a mismatch.
    let expected_len = cfg.n_users.saturating_mul(cfg.n_items);
    if interactions.len() != expected_len {
        return Err(RecsysError::DimensionMismatch {
            expected: expected_len,
            got: interactions.len(),
        });
    }

    let gram = Self::compute_gram(interactions, cfg.n_users, cfg.n_items, cfg.lambda);
    let g_inv = Self::invert_via_cholesky(gram, cfg.n_items)?;

    let n = cfg.n_items;
    // Zero-initialised, so the diagonal needs no explicit write.
    let mut weights = vec![0.0_f32; n * n];
    for j in 0..n {
        let g_inv_jj = g_inv[j * n + j];
        for i in (0..n).filter(|&i| i != j) {
            weights[i * n + j] = -g_inv[i * n + j] / g_inv_jj;
        }
    }

    if cfg.lambda_l1 > 0.0 {
        Self::easer_coordinate_descent(&mut weights, &g_inv, n, &cfg);
    }
    Ok(Self { cfg, weights })
}
/// Scores every item for a single user.
///
/// Computes `scores[j] = Σ_i user_row[i] * weights[i, j]`, skipping rows
/// whose interaction value is exactly zero.
///
/// # Errors
///
/// Returns `DimensionMismatch` when `user_row.len() != cfg.n_items`.
pub fn predict(&self, user_row: &[f32]) -> RecsysResult<Vec<f32>> {
    let n_items = self.cfg.n_items;
    if user_row.len() != n_items {
        return Err(RecsysError::DimensionMismatch {
            expected: n_items,
            got: user_row.len(),
        });
    }

    let mut scores = vec![0.0_f32; n_items];
    let interacted = user_row.iter().enumerate().filter(|&(_, &x)| x != 0.0);
    for (item, &strength) in interacted {
        let start = item * n_items;
        let weight_row = &self.weights[start..start + n_items];
        for (score, &w) in scores.iter_mut().zip(weight_row) {
            *score += strength * w;
        }
    }
    Ok(scores)
}
/// Scores every item for each of `n_users` users.
///
/// `user_matrix` is row-major with `n_users` rows of `cfg.n_items` values;
/// the result has the same layout, one score row per user.
///
/// # Errors
///
/// Returns `DimensionMismatch` when `user_matrix.len()` differs from
/// `n_users * cfg.n_items`.
pub fn predict_batch(&self, user_matrix: &[f32], n_users: usize) -> RecsysResult<Vec<f32>> {
    let width = self.cfg.n_items;
    let expected = n_users * width;
    if user_matrix.len() != expected {
        return Err(RecsysError::DimensionMismatch {
            expected,
            got: user_matrix.len(),
        });
    }

    let mut all_scores = Vec::with_capacity(expected);
    for user in 0..n_users {
        let start = user * width;
        let user_scores = self.predict(&user_matrix[start..start + width])?;
        all_scores.extend_from_slice(&user_scores);
    }
    Ok(all_scores)
}
/// Returns the indices of the (at most) `k` best-scoring items for a user,
/// ordered from highest to lowest score.
///
/// When `exclude_interacted` is true, items whose entry in `user_row` is
/// non-zero are removed from the candidate set before ranking. Fewer than
/// `k` indices are returned when fewer candidates exist, and `k == 0`
/// always yields an empty list.
///
/// # Errors
///
/// Propagates the `DimensionMismatch` raised by `predict` when `user_row`
/// has the wrong length.
pub fn recommend_top_k(
    &self,
    user_row: &[f32],
    k: usize,
    exclude_interacted: bool,
) -> RecsysResult<Vec<usize>> {
    let scores = self.predict(user_row)?;
    let mut candidates: Vec<(f32, usize)> = scores
        .iter()
        .enumerate()
        .filter(|&(i, _)| !exclude_interacted || user_row[i] == 0.0)
        .map(|(i, &s)| (s, i))
        .collect();

    let take = k.min(candidates.len());
    // Without this guard a request for zero items skipped the truncation
    // below and returned every candidate.
    if take == 0 {
        return Ok(Vec::new());
    }

    // Descending by score; incomparable (NaN) pairs are treated as equal.
    let by_score_desc = |a: &(f32, usize), b: &(f32, usize)| {
        b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
    };
    if take < candidates.len() {
        // Partition so the best `take` entries sit at the front, then drop
        // the rest before doing the full sort on the small prefix.
        candidates.select_nth_unstable_by(take - 1, by_score_desc);
        candidates.truncate(take);
    }
    candidates.sort_unstable_by(by_score_desc);
    Ok(candidates.into_iter().map(|(_, idx)| idx).collect())
}
/// Builds the regularised Gram matrix `XᵀX + lambda·I`.
///
/// `interactions` is the row-major `n_users × n_items` matrix `X`; the
/// result is a row-major `n_items × n_items` matrix. Only the upper
/// triangle is computed per user and mirrored into the lower triangle.
///
/// # Panics
///
/// Panics if `interactions` holds fewer than `n_users * n_items` values.
pub fn compute_gram(
    interactions: &[f32],
    n_users: usize,
    n_items: usize,
    lambda: f32,
) -> Vec<f32> {
    let mut gram = vec![0.0_f32; n_items * n_items];

    for user in 0..n_users {
        let start = user * n_items;
        let row = &interactions[start..start + n_items];
        for (i, &xi) in row.iter().enumerate() {
            // A zero entry contributes nothing to row `i` of the upper triangle.
            if xi == 0.0 {
                continue;
            }
            for (j, &xj) in row.iter().enumerate().skip(i) {
                let product = xi * xj;
                gram[i * n_items + j] += product;
                if j != i {
                    gram[j * n_items + i] += product;
                }
            }
        }
    }

    // Diagonal entries are spaced `n_items + 1` apart in row-major storage.
    for diag in gram.iter_mut().step_by(n_items + 1) {
        *diag += lambda;
    }
    gram
}
/// In-place Cholesky factorisation `A = L·Lᵀ` of a row-major `n × n` matrix.
///
/// Only the lower triangle of `a` (including the diagonal) is read. On
/// success it holds `L`, and the strict upper triangle is overwritten with
/// zeros.
///
/// # Errors
///
/// Returns `NotPositiveDefinite` as soon as a pivot is not a finite,
/// strictly positive number. `a` is left partially overwritten in that
/// case.
///
/// # Panics
///
/// Panics if `a` holds fewer than `n * n` values.
pub fn cholesky(a: &mut [f32], n: usize) -> RecsysResult<()> {
    for i in 0..n {
        // Off-diagonal entries of row `i`.
        for j in 0..i {
            let mut sum: f32 = a[i * n + j];
            for k in 0..j {
                sum -= a[i * n + k] * a[j * n + k];
            }
            a[i * n + j] = sum / a[j * n + j];
        }

        // Pivot: remaining diagonal mass after removing the row's squares.
        let mut diag_sum: f32 = a[i * n + i];
        for k in 0..i {
            diag_sum -= a[i * n + k] * a[i * n + k];
        }
        // `NaN <= 0.0` is false, so NaN must be rejected explicitly or the
        // factor would silently fill with NaN and still be reported as Ok.
        if !diag_sum.is_finite() || diag_sum <= 0.0 {
            return Err(RecsysError::NotPositiveDefinite);
        }
        a[i * n + i] = diag_sum.sqrt();

        // Clear the strict upper triangle so `a` is exactly `L`.
        for j in (i + 1)..n {
            a[i * n + j] = 0.0;
        }
    }
    Ok(())
}
/// Solves `L·y = b` in place for a lower-triangular, row-major `n × n`
/// matrix `l`; on return `b` holds `y`.
///
/// # Panics
///
/// Panics if `l` or `b` is too short for the given `n`.
pub fn forward_substitution(l: &[f32], b: &mut [f32], n: usize) {
    for row in 0..n {
        let base = row * n;
        // Subtract the already-solved unknowns one at a time, left to right.
        let residual = (0..row).fold(b[row], |acc, col| acc - l[base + col] * b[col]);
        b[row] = residual / l[base + row];
    }
}
/// Solves `Lᵀ·x = b` in place, where `l` is the lower-triangular, row-major
/// `n × n` factor (read column-wise to act as its transpose); on return
/// `b` holds `x`.
///
/// # Panics
///
/// Panics if `l` or `b` is too short for the given `n`.
pub fn backward_substitution(l: &[f32], b: &mut [f32], n: usize) {
    for row in (0..n).rev() {
        // Subtract the already-solved unknowns in increasing index order.
        let residual = ((row + 1)..n).fold(b[row], |acc, col| acc - l[col * n + row] * b[col]);
        b[row] = residual / l[row * n + row];
    }
}
/// Inverts a row-major `n × n` matrix through its Cholesky factor.
///
/// `gram` is factored in place; each column of the inverse is then obtained
/// by running forward and backward substitution on the matching unit
/// vector.
///
/// # Errors
///
/// Propagates any error returned by `cholesky`.
pub fn invert_via_cholesky(mut gram: Vec<f32>, n: usize) -> RecsysResult<Vec<f32>> {
    Self::cholesky(&mut gram, n)?;
    let factor = gram;

    let mut inverse = vec![0.0_f32; n * n];
    let mut column = vec![0.0_f32; n];
    for j in 0..n {
        // Reset the scratch buffer to the j-th unit vector.
        column.fill(0.0);
        column[j] = 1.0;
        Self::forward_substitution(&factor, &mut column, n);
        Self::backward_substitution(&factor, &mut column, n);
        for (i, &value) in column.iter().enumerate() {
            inverse[i * n + j] = value;
        }
    }
    Ok(inverse)
}
/// Overwrites every off-diagonal weight with the soft-thresholded value of
/// `-g_inv[i, j] / g_inv[j, j]`, using the per-column threshold
/// `lambda_l1 / (2 · g_inv[j, j])`.
///
/// Runs at most `cfg.l1_iter` passes and stops early once the largest
/// change in a pass is below `cfg.l1_tol`. Diagonal entries are never
/// touched.
///
/// NOTE(review): the update is computed from `g_inv` alone and never reads
/// the current weights, so a second pass reproduces the first one exactly.
/// Confirm whether a true coordinate-descent update was intended.
fn easer_coordinate_descent(weights: &mut [f32], g_inv: &[f32], n: usize, cfg: &EaseConfig) {
    for _ in 0..cfg.l1_iter {
        let mut largest_change = 0.0_f32;
        for j in 0..n {
            let pivot = g_inv[j * n + j];
            let shrink = cfg.lambda_l1 / (2.0 * pivot);
            for i in (0..n).filter(|&i| i != j) {
                let slot = i * n + j;
                let updated = soft_threshold(-(g_inv[slot] / pivot), shrink);
                // `f32::max` ignores a NaN argument, matching a `>` comparison.
                largest_change = largest_change.max((updated - weights[slot]).abs());
                weights[slot] = updated;
            }
        }
        if largest_change < cfg.l1_tol {
            break;
        }
    }
}
}
/// Soft-thresholding operator: shrinks `x` towards zero by `threshold` and
/// returns `0.0` when `x` lies within `[-threshold, threshold]` (or is NaN).
#[inline]
fn soft_threshold(x: f32, threshold: f32) -> f32 {
    match x {
        v if v > threshold => v - threshold,
        v if v < -threshold => v + threshold,
        _ => 0.0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Toy 3-user × 3-item interaction matrix, row-major.
    fn small_interactions() -> Vec<f32> {
        vec![
            1.0, 1.0, 0.0,
            0.0, 1.0, 1.0,
            1.0, 0.0, 1.0,
        ]
    }

    /// Default configuration with the given dimensions and `lambda = 500`.
    fn default_cfg(n_users: usize, n_items: usize) -> EaseConfig {
        EaseConfig {
            n_users,
            n_items,
            lambda: 500.0,
            ..Default::default()
        }
    }

    /// Fits the toy matrix with `default_cfg(3, 3)`.
    fn fit_small() -> Ease {
        Ease::fit(&small_interactions(), default_cfg(3, 3)).expect("fit should succeed")
    }

    // ---- fitting -------------------------------------------------------

    #[test]
    fn ease_fit_small() {
        let fitted = fit_small();
        assert_eq!(fitted.weights.len(), 9);
    }

    #[test]
    fn ease_weights_diagonal_zero() {
        let fitted = fit_small();
        let n = 3;
        for i in 0..n {
            assert_eq!(fitted.weights[i * n + i], 0.0, "W[{i},{i}] should be 0");
        }
    }

    #[test]
    fn ease_weights_shape_correct() {
        let (n_users, n_items) = (4, 5);
        let ones: Vec<f32> = vec![1.0; n_users * n_items];
        let config = EaseConfig {
            n_users,
            n_items,
            lambda: 10.0,
            ..Default::default()
        };
        let fitted = Ease::fit(&ones, config).expect("fit should succeed");
        assert_eq!(fitted.weights.len(), n_items * n_items);
    }

    // ---- prediction ----------------------------------------------------

    #[test]
    fn ease_predict_shape() {
        let fitted = fit_small();
        let scores = fitted
            .predict(&[1.0, 0.0, 0.0])
            .expect("predict should succeed");
        assert_eq!(scores.len(), 3);
    }

    #[test]
    fn ease_predict_batch_shape() {
        let matrix = small_interactions();
        let fitted = Ease::fit(&matrix, default_cfg(3, 3)).expect("fit should succeed");
        let scores = fitted
            .predict_batch(&matrix, 3)
            .expect("predict_batch should succeed");
        assert_eq!(scores.len(), 9);
    }

    // ---- recommendation ------------------------------------------------

    #[test]
    fn ease_recommend_top_k_length() {
        let fitted = fit_small();
        let history = vec![1.0, 0.0, 0.0];
        let recs = fitted
            .recommend_top_k(&history, 2, false)
            .expect("recommend_top_k should succeed");
        assert_eq!(recs.len(), 2);
    }

    #[test]
    fn ease_recommend_excludes_interacted() {
        let fitted = fit_small();
        let history = vec![1.0, 1.0, 0.0];
        let recs = fitted
            .recommend_top_k(&history, 1, true)
            .expect("recommend_top_k should succeed");
        for idx in recs.iter().copied() {
            assert!(
                history[idx] == 0.0,
                "Recommendation {idx} should not be an already-interacted item"
            );
        }
    }

    #[test]
    fn ease_recommend_not_excluding() {
        let fitted = fit_small();
        let history = vec![1.0, 1.0, 0.0];
        let recs = fitted
            .recommend_top_k(&history, 3, false)
            .expect("recommend_top_k should succeed");
        assert_eq!(recs.len(), 3);
    }

    // ---- effect of lambda ----------------------------------------------

    #[test]
    fn ease_high_lambda_forces_small_weights() {
        let config = EaseConfig {
            n_users: 3,
            n_items: 3,
            lambda: 1_000_000.0,
            ..Default::default()
        };
        let fitted = Ease::fit(&small_interactions(), config).expect("fit should succeed");
        // In a 3×3 row-major matrix the diagonal sits at indices 0, 4, 8.
        let off_diagonal = fitted
            .weights
            .iter()
            .enumerate()
            .filter(|(i, _)| i % 4 != 0);
        for (i, &w) in off_diagonal {
            assert!(
                w.abs() < 0.01,
                "Weight {w} too large at index {i} with high lambda"
            );
        }
    }

    #[test]
    fn ease_low_lambda_allows_larger_weights() {
        let config = EaseConfig {
            n_users: 3,
            n_items: 3,
            lambda: 1.0,
            ..Default::default()
        };
        let fitted = Ease::fit(&small_interactions(), config).expect("fit should succeed");
        let mut max_off_diag = 0.0_f32;
        for (pos, w) in fitted.weights.iter().enumerate() {
            if pos % 4 != 0 {
                max_off_diag = f32::max(max_off_diag, w.abs());
            }
        }
        assert!(
            max_off_diag > 0.01,
            "Expected larger weights with low lambda, got max={max_off_diag}"
        );
    }

    // ---- Gram matrix ---------------------------------------------------

    #[test]
    fn gram_matrix_symmetric() {
        let g = Ease::compute_gram(&small_interactions(), 3, 3, 1.0);
        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (g[i * 3 + j] - g[j * 3 + i]).abs() < 1e-6,
                    "G[{i},{j}]={} ≠ G[{j},{i}]={}",
                    g[i * 3 + j],
                    g[j * 3 + i]
                );
            }
        }
    }

    #[test]
    fn gram_matrix_diagonal_includes_lambda() {
        let lambda = 7.0;
        let g = Ease::compute_gram(&small_interactions(), 3, 3, lambda);
        for i in 0..3 {
            assert!(
                g[i * 3 + i] >= lambda,
                "G[{i},{i}]={} should be >= lambda={lambda}",
                g[i * 3 + i]
            );
        }
    }

    // ---- linear algebra ------------------------------------------------

    #[test]
    fn cholesky_identity_gives_identity() {
        let lambda = 4.0;
        let n = 3;
        let scale: f32 = 1.0 + lambda;
        // Scaled identity: `scale` on the diagonal, zero elsewhere.
        let mut a = vec![0.0_f32; n * n];
        for d in 0..n {
            a[d * n + d] = scale;
        }
        Ease::cholesky(&mut a, n).expect("cholesky should succeed");
        for i in 0..n {
            let diag = a[i * n + i];
            assert!(
                (diag - scale.sqrt()).abs() < 1e-5,
                "L[{i},{i}] = {diag}, expected {}",
                scale.sqrt()
            );
        }
    }

    #[test]
    fn cholesky_fails_non_positive_definite() {
        let mut indefinite = vec![1.0_f32, 2.0, 2.0, 1.0];
        let outcome = Ease::cholesky(&mut indefinite, 2);
        assert!(
            outcome.is_err(),
            "Cholesky should fail on non-positive-definite matrix"
        );
    }

    #[test]
    fn invert_cholesky_product_is_identity() {
        let n = 3;
        let gram = Ease::compute_gram(&small_interactions(), 3, 3, 500.0);
        let g_inv = Ease::invert_via_cholesky(gram.clone(), 3)
            .expect("invert_via_cholesky should succeed");
        for i in 0..n {
            for j in 0..n {
                let dot: f32 = (0..n).map(|k| gram[i * n + k] * g_inv[k * n + j]).sum();
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (dot - expected).abs() < 1e-3,
                    "G*G_inv[{i},{j}] = {dot}, expected {expected}"
                );
            }
        }
    }

    // ---- L1 variant ----------------------------------------------------

    #[test]
    fn easer_lambda_l1_zero_matches_ease() {
        let matrix = small_interactions();
        let l1_disabled = EaseConfig {
            n_users: 3,
            n_items: 3,
            lambda: 500.0,
            lambda_l1: 0.0,
            ..Default::default()
        };
        let ease = Ease::fit(&matrix, default_cfg(3, 3)).expect("fit should succeed");
        let easer = Ease::fit(&matrix, l1_disabled).expect("fit should succeed");
        let pairs = ease.weights.iter().zip(easer.weights.iter());
        for (i, (&w_e, &w_r)) in pairs.enumerate() {
            assert!(
                (w_e - w_r).abs() < 1e-5,
                "Weight mismatch at {i}: EASE={w_e} EASER={w_r}"
            );
        }
    }

    #[test]
    fn easer_l1_reduces_weights() {
        let matrix = small_interactions();
        let l1_enabled = EaseConfig {
            n_users: 3,
            n_items: 3,
            lambda: 500.0,
            lambda_l1: 1_000.0,
            l1_iter: 100,
            l1_tol: 1e-6,
        };
        let ease = Ease::fit(&matrix, default_cfg(3, 3)).expect("fit should succeed");
        let easer = Ease::fit(&matrix, l1_enabled).expect("fit should succeed");
        let ease_sum: f32 = ease.weights.iter().map(|w| w.abs()).sum();
        let easer_sum: f32 = easer.weights.iter().map(|w| w.abs()).sum();
        assert!(
            easer_sum <= ease_sum,
            "EASER L1 sum {easer_sum} should be ≤ EASE sum {ease_sum}"
        );
    }

    // ---- input validation ----------------------------------------------

    #[test]
    fn err_n_items_zero() {
        let config = EaseConfig {
            n_users: 3,
            n_items: 0,
            lambda: 1.0,
            ..Default::default()
        };
        let outcome = Ease::fit(&[], config);
        assert!(
            matches!(outcome, Err(RecsysError::InvalidNumItems { .. })),
            "Expected InvalidNumItems error"
        );
    }

    #[test]
    fn err_lambda_negative() {
        let config = EaseConfig {
            n_users: 3,
            n_items: 3,
            lambda: -1.0,
            ..Default::default()
        };
        let outcome = Ease::fit(&small_interactions(), config);
        assert!(
            matches!(outcome, Err(RecsysError::InvalidLambda { .. })),
            "Expected InvalidLambda error"
        );
    }

    #[test]
    fn err_interaction_length_mismatch() {
        let config = EaseConfig {
            n_users: 3,
            n_items: 3,
            lambda: 1.0,
            ..Default::default()
        };
        let too_short = vec![1.0_f32; 5];
        let outcome = Ease::fit(&too_short, config);
        assert!(
            matches!(outcome, Err(RecsysError::DimensionMismatch { .. })),
            "Expected DimensionMismatch error"
        );
    }
}