1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
use {Approximator, EvaluationResult, Projection, Projector, UpdateResult};
use geometry::{Matrix, Vector};
use std::marker::PhantomData;

/// Linear function approximator with several outputs.
///
/// Inputs are turned into a `Projection` by `projector`; every output is a
/// linear combination of the projected features, using one column of
/// `weights` per output.
#[derive(Clone, Serialize, Deserialize)]
pub struct MultiLinear<I: ?Sized, P: Projector<I>> {
    /// Projector used to map raw inputs onto features.
    pub projector: P,

    /// Weight matrix, created with shape `(n_features, n_outputs)`.
    pub weights: Matrix<f64>,

    // Zero-sized marker recording the input type `I`; no `I` is ever stored.
    phantom: PhantomData<I>,
}

impl<I, P> MultiLinear<I, P>
where
    I: ?Sized,
    P: Projector<I>,
{
    /// Builds an approximator with `n_outputs` outputs and all weights zero.
    ///
    /// The number of weight rows is taken from `projector.span()`.
    pub fn new(projector: P, n_outputs: usize) -> Self {
        // The span must be read before `projector` is moved into the struct.
        let n_features: usize = projector.span().into();
        let weights = Matrix::zeros((n_features, n_outputs));

        Self {
            projector,
            weights,
            phantom: PhantomData,
        }
    }

    /// Overwrites the weight matrix with `values`.
    pub fn assign(&mut self, values: &Matrix<f64>) {
        self.weights.assign(values);
    }

    /// Overwrites the weight matrix with `values` broadcast to its shape.
    ///
    /// # Panics
    ///
    /// Panics if `values` cannot be broadcast to the shape of `weights`.
    pub fn assign_cols(&mut self, values: &Vector<f64>) {
        let shape = self.weights.dim();
        let broadcasted = values.broadcast(shape).unwrap();

        self.weights.assign(&broadcasted);
    }

    /// Computes one output value per weight column for the projection `p`.
    ///
    /// * Dense: the transposed weights multiplied by the feature vector
    ///   divided by `p.z()`.
    /// * Sparse: for every column, the plain sum of the weights at the listed
    ///   row indices (this path does not divide by `p.z()`).
    pub fn evaluate_projection(&self, p: &Projection) -> Vector<f64> {
        match *p {
            Projection::Dense(ref dense) => {
                let scaled = dense / p.z();

                self.weights.t().dot(&scaled)
            },
            Projection::Sparse(ref sparse) => {
                let n_outputs = self.weights.cols();
                let mut totals = Vec::with_capacity(n_outputs);

                for c in 0..n_outputs {
                    let mut total = 0.0;
                    for idx in sparse.iter() {
                        total = total + self.weights[(*idx, c)];
                    }
                    totals.push(total);
                }

                Vector::from_vec(totals)
            },
        }
    }

    /// Moves the weights along the projection `p`, with one error term per
    /// output column; every step is divided by `p.z()`.
    ///
    /// * Dense: adds the outer product of the feature vector and `errors`,
    ///   scaled by `1 / z`, to the weights.
    /// * Sparse: adds `errors[c] / z` to every listed row of column `c`.
    ///
    /// # Panics
    ///
    /// Panics if a dense projection or `errors` cannot be reshaped to match
    /// the weight matrix, or if an index is out of bounds.
    pub fn update_projection(&mut self, p: &Projection, errors: Vector<f64>) {
        let z = p.z();

        match *p {
            Projection::Dense(ref dense) => {
                let n_rows = self.weights.rows();
                let features = dense.view().into_shape((n_rows, 1)).unwrap();

                let n_cols = self.weights.cols();
                let error_row = errors.view().into_shape((1, n_cols)).unwrap();

                // (n_rows x 1) . (1 x n_cols) yields a weight-shaped matrix.
                let outer = features.dot(&error_row);
                self.weights.scaled_add(1.0 / z, &outer);
            },
            Projection::Sparse(ref sparse) => {
                let n_cols = self.weights.cols();

                for c in 0..n_cols {
                    let step = errors[c] / z;
                    let mut column = self.weights.column_mut(c);

                    for idx in sparse {
                        column[*idx] += step;
                    }
                }
            },
        }
    }
}

impl<I, P> Approximator<I> for MultiLinear<I, P>
where
    I: ?Sized,
    P: Projector<I>,
{
    type Value = Vector<f64>;

    /// Projects `input` and evaluates every output for it.
    fn evaluate(&self, input: &I) -> EvaluationResult<Vector<f64>> {
        let projection = self.projector.project(input);
        let values = self.evaluate_projection(&projection);

        Ok(values)
    }

    /// Projects `input` and applies `errors` to the weights.
    fn update(&mut self, input: &I, errors: Vector<f64>) -> UpdateResult<()> {
        let projection = self.projector.project(input);
        self.update_projection(&projection, errors);

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    extern crate seahash;

    use super::*;
    use projection::{Fourier, TileCoding};
    use std::hash::BuildHasherDefault;

    type SHBuilder = BuildHasherDefault<seahash::SeaHasher>;

    // A single update from zero weights should make the approximator
    // reproduce the supplied targets at the same input (sparse features).
    #[test]
    fn test_sparse_update_eval() {
        let mut f = MultiLinear::new(TileCoding::new(SHBuilder::default(), 4, 100), 2);
        let input = vec![5.0];
        let targets = Vector::from_vec(vec![20.0, 50.0]);

        let _ = f.update(input.as_slice(), targets);
        let out = f.evaluate(input.as_slice()).unwrap();

        println!("{:?}", out);

        assert!((out[0] - 20.0).abs() < 1e-6);
        assert!((out[1] - 50.0).abs() < 1e-6);
    }

    // Same check as above, using a projector that yields dense features.
    #[test]
    fn test_dense_update_eval() {
        let mut f = MultiLinear::new(Fourier::new(3, vec![(0.0, 10.0)]), 2);
        let input = vec![5.0];
        let targets = Vector::from_vec(vec![20.0, 50.0]);

        let _ = f.update(input.as_slice(), targets);
        let out = f.evaluate(input.as_slice()).unwrap();

        assert!((out[0] - 20.0).abs() < 1e-6);
        assert!((out[1] - 50.0).abs() < 1e-6);
    }
}