use hashbrown::HashMap;
use nalgebra::{DMatrix, DVector, OMatrix};
use rand::{Rng, thread_rng};
use rayon::prelude::*;
use crate::errors::ALSError;
/// Iteration count a freshly constructed `ALS` uses for `train`.
const DEFAULT_ITERATIONS: usize = 10;

/// Threshold handed to the pseudo-inverse fallback used during training.
const DEFAULT_EPS: f64 = 1e-9;

/// Regularization weight a freshly constructed `ALS` starts with.
const DEFAULT_REG: f64 = 1.0;

/// Scalar type of the concrete `ALS` implementation in this file.
type T = f64;

/// One entry of the rating matrix, laid out as `(row, column, value)`.
pub type RTriplet<T> = (usize, usize, T);
/// Factorization state for an `n` x `m` rating matrix `R`.
///
/// Known entries of `R` are kept twice (indexed by row and indexed by column)
/// next to one factor vector of length `k` per row and one per column.
pub struct ALS<T> {
    /// Number of rows of `R`; valid row indices are `0..n`.
    n: usize,
    /// Number of columns of `R`; valid column indices are `0..m`.
    m: usize,
    /// Length of every factor vector.
    k: usize,
    /// Known entries as `row -> (column -> value)`.
    r_row_first: HashMap<usize, HashMap<usize, T>>,
    /// The same entries as `column -> (row -> value)`.
    r_col_first: HashMap<usize, HashMap<usize, T>>,
    /// Row factors: one vector per row of `R`.
    x_mat: Vec<DVector<T>>,
    /// Column factors: one vector per column of `R`.
    y_mat: Vec<DVector<T>>,
    /// Iteration count used by `train`.
    default_iters: usize,
    /// Regularization weight used by training and by `cost`.
    default_regularization: T,
}
impl ALS<T> {
pub fn new(n : usize, m : usize, k : usize) -> Self {
let mut als =
ALS {
n,
m,
k,
r_row_first : HashMap::new(),
r_col_first : HashMap::new(),
x_mat : vec![],
y_mat : vec![],
default_iters : DEFAULT_ITERATIONS,
default_regularization: DEFAULT_REG,
};
als.init_y();
als.init_x();
als
}
pub fn add(&mut self, e : RTriplet<T>) -> Result<Option<T>, ALSError<T>> {
if e.0 >= self.n {
return Err(ALSError::InvalidTripletError(e, format!("{} exceeds row index range for R = {}x{}", e.0, self.n, self.m)))
}
if e.1 >= self.m {
return Err(ALSError::InvalidTripletError(e, format!("{} exceeds column index range of R = {}x{}", e.1, self.n, self.m)))
}
let mut previous_entry_val = None;
self.r_row_first.entry(e.0)
.and_modify(|col| {
previous_entry_val = col.insert(e.1, e.2);
})
.or_insert({
let mut col = HashMap::new();
previous_entry_val = col.insert(e.1, e.2);
col
});
self.r_col_first.entry(e.1)
.and_modify(|row| {
row.insert(e.0, e.2);
})
.or_insert({
let mut row = HashMap::new();
row.insert(e.0, e.2);
row
});
Ok(previous_entry_val)
}
pub fn reset_x(&mut self) {
let upper_init_bound : T = 1.0 / (self.k as T).sqrt();
self.x_mat.par_iter_mut().for_each(|x_col| {
x_col.fill_with(|| thread_rng().gen_range(0.0..upper_init_bound))
});
}
pub fn reset_y(&mut self) {
let upper_init_bound : T = 1.0 / (self.k as T).sqrt();
self.y_mat.par_iter_mut().for_each(|y_col| {
y_col.fill_with(|| thread_rng().gen_range(0.0..upper_init_bound))
});
}
fn init_x(&mut self) {
self.x_mat = Vec::with_capacity(self.n);
let upper_init_bound : T = 1.0 / (self.k as T).sqrt();
self.x_mat.par_extend((0..self.n).into_par_iter()
.map(|_| DVector::<T>::from_fn(
self.k,
|_, _| thread_rng().gen_range(0.0..upper_init_bound))));
}
fn init_y(&mut self) {
self.y_mat = Vec::with_capacity(self.m);
let upper_init_bound : T = 1.0 / (self.k as T).sqrt();
self.y_mat.par_extend((0..self.m).into_par_iter()
.map(|_| DVector::<T>::from_fn(
self.k,
|_, _| thread_rng().gen_range(0.0..upper_init_bound))));
}
pub fn reset_r(&mut self) {
self.r_row_first = HashMap::new();
self.r_col_first = HashMap::new();
}
pub fn set_regularization(&mut self, lambda : T) {
self.default_regularization = lambda;
}
pub fn set_default_iters(&mut self, iters : usize) {
self.default_iters = iters;
}
pub fn train_for(&mut self, iters: usize) {
self.ensure_x_y_existence();
let mut precomp_yyt: HashMap<usize, OMatrix<T, _, _>> = HashMap::with_capacity(self.m);
let mut precomp_xxt: HashMap<usize, OMatrix<T, _, _>> = HashMap::with_capacity(self.n);
let reg_diag = DMatrix::<T>::from_diagonal_element(self.k, self.k, self.default_regularization);
precomp_yyt.par_extend(
self.r_col_first.par_keys()
.map(|i_m| {
(*i_m, DMatrix::<T>::zeros(self.k, self.k))
})
);
precomp_xxt.par_extend(
self.r_row_first.par_keys()
.map(|i_n| {
(*i_n, DMatrix::<T>::zeros(self.k, self.k))
})
);
for _ in 0..iters {
precomp_yyt.par_iter_mut().for_each(|(i_m, kk_term)| {
let y_i = &self.y_mat[*i_m];
y_i.mul_to(&y_i.transpose(), kk_term);
});
self.x_mat.par_iter_mut().enumerate().for_each(|(i_n, x_row)| {
if let Some(r_row) = self.r_row_first.get(&i_n) {
let mut first_sum = reg_diag.clone();
let mut second_sum: DVector<T> = DVector::zeros(self.k);
r_row.iter().for_each(|(i_m, r_nm)|{
first_sum += precomp_yyt.get(i_m).unwrap();
second_sum += &(&self.y_mat[*i_m] * *r_nm);
});
if !first_sum.try_inverse_mut() {
first_sum = first_sum.pseudo_inverse(DEFAULT_EPS).unwrap();
}
first_sum.mul_to(&second_sum, x_row);
}
});
precomp_xxt.par_iter_mut().for_each(|(i_n, kk_term)| {
let x_i = &self.x_mat[*i_n];
x_i.mul_to(&x_i.transpose(), kk_term);
});
self.y_mat.par_iter_mut().enumerate().for_each(|(i_m, y_row)| {
if let Some(r_col) = self.r_col_first.get(&i_m) {
let mut first_sum = reg_diag.clone();
let mut second_sum: DVector<T> = DVector::zeros(self.k);
r_col.iter().for_each(|(i_n, r_nm)|{
first_sum += precomp_xxt.get(i_n).unwrap();
second_sum += &(&self.x_mat[*i_n] * *r_nm);
});
if !first_sum.try_inverse_mut() {
first_sum = first_sum.pseudo_inverse(DEFAULT_EPS).unwrap();
}
first_sum.mul_to(&second_sum, y_row);
}
});
}
}
fn ensure_x_y_existence(&mut self) {
if self.x_mat.len() != self.n {
self.init_x();
}
if self.y_mat.len() != self.m {
self.init_y();
}
}
pub fn train(&mut self) {
self.train_for(self.default_iters);
}
pub fn get_row_factors(&self, row : usize) -> Option<&DVector<T>> {
self.x_mat.get(row)
}
pub fn get_col_factors(&self, col : usize) -> Option<&DVector<T>> {
self.y_mat.get(col)
}
pub fn get_x(&self) -> &Vec<DVector<T>> {
&self.x_mat
}
pub fn get_y(&self) -> &Vec<DVector<T>> {
&self.y_mat
}
pub fn cost(&mut self) -> T {
self.ensure_x_y_existence();
let r_term : T = self.r_row_first.par_iter().map(|(i_n, col)| {
col
.par_iter()
.map(|(i_m, val)|
(*val - (self.x_mat[*i_n].transpose() * &self.y_mat[*i_m])[(0, 0)])
.powi(2)
)
.sum::<T>()
}).sum::<T>();
let x_term : T = self.x_mat
.par_iter()
.map(|x_in| (x_in.transpose() * x_in)[(0, 0)])
.sum::<T>();
let y_term : T = self.y_mat
.par_iter()
.map(|y_in| (y_in.transpose() * y_in)[(0, 0)])
.sum::<T>();
r_term + self.default_regularization * (x_term + y_term)
}
pub fn predict_r_val(&self, n :usize, m : usize) -> T {
(self.x_mat[n].transpose() * &self.y_mat[m])[(0, 0)]
}
}