1use hashbrown::HashMap;
2use nalgebra::{DMatrix, DVector, OMatrix};
3use rand::{Rng, thread_rng};
4use rayon::prelude::*;
5use crate::errors::ALSError;
6
/// Scalar type used for ratings and latent factors throughout the model.
type T = f64;
/// A single observed entry `(row, column, value)` of the rating matrix `R`.
pub type RTriplet<T> = (usize, usize, T);

/// Number of sweeps `ALS::train` performs until changed with `set_default_iters`.
const DEFAULT_ITERATIONS: usize = 10;
/// Singular values below this threshold are treated as zero when a normal-equation
/// matrix cannot be inverted and its pseudo-inverse is used instead.
const DEFAULT_EPS: T = 1.0e-9;
/// Initial regularisation weight `lambda` applied during training and in `cost`.
const DEFAULT_REG: T = 1.0;
12
13pub struct ALS<T> {
14 n : usize,
15 m : usize,
16 k : usize,
17 r_row_first: HashMap<usize, HashMap<usize, T>>,
18 r_col_first : HashMap<usize, HashMap<usize, T>>,
19 x_mat : Vec<DVector<T>>,
20 y_mat : Vec<DVector<T>>,
21 default_iters : usize,
22 default_regularization: T,
23}
24
25impl ALS<T> {
26
27 pub fn new(n : usize, m : usize, k : usize) -> Self {
30 let mut als =
31 ALS {
32 n,
33 m,
34 k,
35 r_row_first : HashMap::new(),
36 r_col_first : HashMap::new(),
37 x_mat : vec![],
38 y_mat : vec![],
39 default_iters : DEFAULT_ITERATIONS,
40 default_regularization: DEFAULT_REG,
41 };
42 als.init_y();
43 als.init_x();
44 als
45 }
46
47 pub fn add(&mut self, e : RTriplet<T>) -> Result<Option<T>, ALSError<T>> {
49 if e.0 >= self.n {
50 return Err(ALSError::InvalidTripletError(e, format!("{} exceeds row index range for R = {}x{}", e.0, self.n, self.m)))
51 }
52 if e.1 >= self.m {
53 return Err(ALSError::InvalidTripletError(e, format!("{} exceeds column index range of R = {}x{}", e.1, self.n, self.m)))
54 }
55
56 let mut previous_entry_val = None;
57 self.r_row_first.entry(e.0)
58 .and_modify(|col| {
59 previous_entry_val = col.insert(e.1, e.2);
60 })
61 .or_insert({
62 let mut col = HashMap::new();
63 previous_entry_val = col.insert(e.1, e.2);
64 col
65 });
66
67 self.r_col_first.entry(e.1)
68 .and_modify(|row| {
69 row.insert(e.0, e.2);
70 })
71 .or_insert({
72 let mut row = HashMap::new();
73 row.insert(e.0, e.2);
74 row
75 });
76
77 Ok(previous_entry_val)
78 }
79
80 pub fn reset_x(&mut self) {
82 let upper_init_bound : T = 1.0 / (self.k as T).sqrt();
83 self.x_mat.par_iter_mut().for_each(|x_col| {
84 x_col.fill_with(|| thread_rng().gen_range(0.0..upper_init_bound))
85 });
86 }
87
88 pub fn reset_y(&mut self) {
90 let upper_init_bound : T = 1.0 / (self.k as T).sqrt();
91 self.y_mat.par_iter_mut().for_each(|y_col| {
92 y_col.fill_with(|| thread_rng().gen_range(0.0..upper_init_bound))
93 });
94 }
95
96 fn init_x(&mut self) {
97 self.x_mat = Vec::with_capacity(self.n);
98 let upper_init_bound : T = 1.0 / (self.k as T).sqrt();
99 self.x_mat.par_extend((0..self.n).into_par_iter()
100 .map(|_| DVector::<T>::from_fn(
101 self.k,
102 |_, _| thread_rng().gen_range(0.0..upper_init_bound))));
103 }
104
105 fn init_y(&mut self) {
106 self.y_mat = Vec::with_capacity(self.m);
107 let upper_init_bound : T = 1.0 / (self.k as T).sqrt();
108 self.y_mat.par_extend((0..self.m).into_par_iter()
109 .map(|_| DVector::<T>::from_fn(
110 self.k,
111 |_, _| thread_rng().gen_range(0.0..upper_init_bound))));
112 }
113
114 pub fn reset_r(&mut self) {
116 self.r_row_first = HashMap::new();
117 self.r_col_first = HashMap::new();
118 }
119
120 pub fn set_regularization(&mut self, lambda : T) {
122 self.default_regularization = lambda;
123 }
124
125 pub fn set_default_iters(&mut self, iters : usize) {
126 self.default_iters = iters;
127 }
128
129 pub fn train_for(&mut self, iters: usize) {
131 self.ensure_x_y_existence();
132 let mut precomp_yyt: HashMap<usize, OMatrix<T, _, _>> = HashMap::with_capacity(self.m);
133 let mut precomp_xxt: HashMap<usize, OMatrix<T, _, _>> = HashMap::with_capacity(self.n);
134 let reg_diag = DMatrix::<T>::from_diagonal_element(self.k, self.k, self.default_regularization);
135 precomp_yyt.par_extend(
136 self.r_col_first.par_keys()
137 .map(|i_m| {
138 (*i_m, DMatrix::<T>::zeros(self.k, self.k))
139 })
140 );
141 precomp_xxt.par_extend(
142 self.r_row_first.par_keys()
143 .map(|i_n| {
144 (*i_n, DMatrix::<T>::zeros(self.k, self.k))
145 })
146 );
147 for _ in 0..iters {
148 precomp_yyt.par_iter_mut().for_each(|(i_m, kk_term)| {
149 let y_i = &self.y_mat[*i_m];
150 y_i.mul_to(&y_i.transpose(), kk_term);
151 });
152
153 self.x_mat.par_iter_mut().enumerate().for_each(|(i_n, x_row)| {
154 if let Some(r_row) = self.r_row_first.get(&i_n) {
155 let mut first_sum = reg_diag.clone();
156 let mut second_sum: DVector<T> = DVector::zeros(self.k);
157 r_row.iter().for_each(|(i_m, r_nm)|{
158 first_sum += precomp_yyt.get(i_m).unwrap();
159 second_sum += &(&self.y_mat[*i_m] * *r_nm);
160 });
161 if !first_sum.try_inverse_mut() {
162 first_sum = first_sum.pseudo_inverse(DEFAULT_EPS).unwrap();
163 }
164 first_sum.mul_to(&second_sum, x_row);
165 }
166 });
167
168 precomp_xxt.par_iter_mut().for_each(|(i_n, kk_term)| {
169 let x_i = &self.x_mat[*i_n];
170 x_i.mul_to(&x_i.transpose(), kk_term);
171 });
172
173 self.y_mat.par_iter_mut().enumerate().for_each(|(i_m, y_row)| {
174 if let Some(r_col) = self.r_col_first.get(&i_m) {
175 let mut first_sum = reg_diag.clone();
176 let mut second_sum: DVector<T> = DVector::zeros(self.k);
177 r_col.iter().for_each(|(i_n, r_nm)|{
178 first_sum += precomp_xxt.get(i_n).unwrap();
179 second_sum += &(&self.x_mat[*i_n] * *r_nm);
180 });
181 if !first_sum.try_inverse_mut() {
182 first_sum = first_sum.pseudo_inverse(DEFAULT_EPS).unwrap();
183 }
184 first_sum.mul_to(&second_sum, y_row);
185 }
186
187 });
188 }
189 }
190
191 fn ensure_x_y_existence(&mut self) {
192 if self.x_mat.len() != self.n {
193 self.init_x();
194 }
195
196 if self.y_mat.len() != self.m {
197 self.init_y();
198 }
199 }
200
201 pub fn train(&mut self) {
203 self.train_for(self.default_iters);
204 }
205
206 pub fn get_row_factors(&self, row : usize) -> Option<&DVector<T>> {
208 self.x_mat.get(row)
209 }
210 pub fn get_col_factors(&self, col : usize) -> Option<&DVector<T>> {
211 self.y_mat.get(col)
212 }
213
214 pub fn get_x(&self) -> &Vec<DVector<T>> {
215 &self.x_mat
216 }
217
218 pub fn get_y(&self) -> &Vec<DVector<T>> {
219 &self.y_mat
220 }
221
222
223 pub fn cost(&mut self) -> T {
225 self.ensure_x_y_existence();
226 let r_term : T = self.r_row_first.par_iter().map(|(i_n, col)| {
227 col
228 .par_iter()
229 .map(|(i_m, val)|
230 (*val - (self.x_mat[*i_n].transpose() * &self.y_mat[*i_m])[(0, 0)])
231 .powi(2)
232 )
233 .sum::<T>()
234 }).sum::<T>();
235
236 let x_term : T = self.x_mat
237 .par_iter()
238 .map(|x_in| (x_in.transpose() * x_in)[(0, 0)])
239 .sum::<T>();
240
241 let y_term : T = self.y_mat
242 .par_iter()
243 .map(|y_in| (y_in.transpose() * y_in)[(0, 0)])
244 .sum::<T>();
245
246 r_term + self.default_regularization * (x_term + y_term)
247 }
248
249 pub fn predict_r_val(&self, n :usize, m : usize) -> T {
251 (self.x_mat[n].transpose() * &self.y_mat[m])[(0, 0)]
252 }
253}
254