1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use {Approximator, EvaluationResult, Projection, Projector, UpdateResult};
use geometry::{Matrix, Vector};
use std::marker::PhantomData;

/// Linear function approximator with several outputs sharing one feature projection.
///
/// Inputs of type `I` are mapped to a `Projection` by `projector`; each output is then a
/// linear function of those features, parameterised by one column of `weights`.
#[derive(Clone, Serialize, Deserialize)]
pub struct MultiLinear<I: ?Sized, P: Projector<I>> {
    /// Maps raw inputs of type `I` into the feature space.
    pub projector: P,
    /// Weight matrix of shape `(n_features, n_outputs)`: one row per feature, one column
    /// per output.
    pub weights: Matrix<f64>,

    // Zero-sized marker so the struct is generic over the input type `I` without storing one.
    phantom: PhantomData<I>,
}

impl<I: ?Sized, P: Projector<I>> MultiLinear<I, P> {
    pub fn new(projector: P, n_outputs: usize) -> Self {
        let n_features = projector.span().into();

        Self {
            projector: projector,
            weights: Matrix::zeros((n_features, n_outputs)),

            phantom: PhantomData,
        }
    }

    pub fn assign(&mut self, values: &Matrix<f64>) { self.weights.assign(values); }

    pub fn assign_cols(&mut self, values: &Vector<f64>) {
        let view = values.broadcast(self.weights.dim()).unwrap();

        self.weights.assign(&view);
    }

    pub fn evaluate_projection(&self, p: &Projection) -> Vector<f64> {
        match p {
            &Projection::Dense(ref dense) => self.weights.t().dot(&(dense / p.z())),
            &Projection::Sparse(ref sparse) => (0..self.weights.cols())
                .map(|c| {
                    sparse
                        .iter()
                        .fold(0.0, |acc, idx| acc + self.weights[(*idx, c)])
                })
                .collect(),
        }
    }

    pub fn update_projection(&mut self, p: &Projection, errors: Vector<f64>) {
        let z = p.z();

        match p {
            &Projection::Dense(ref dense) => {
                let view = dense.view().into_shape((self.weights.rows(), 1)).unwrap();
                let error_matrix = errors.view().into_shape((1, self.weights.cols())).unwrap();

                self.weights.scaled_add(1.0 / z, &view.dot(&error_matrix))
            },
            &Projection::Sparse(ref sparse) => for c in 0..self.weights.cols() {
                let mut col = self.weights.column_mut(c);
                let error = errors[c];
                let scaled_error = error / z;

                for idx in sparse {
                    col[*idx] += scaled_error
                }
            },
        }
    }
}

impl<I: ?Sized, P: Projector<I>> Approximator<I> for MultiLinear<I, P> {
    type Value = Vector<f64>;

    /// Project `input` into feature space and evaluate every output.
    fn evaluate(&self, input: &I) -> EvaluationResult<Vector<f64>> {
        Ok(self.evaluate_projection(&self.projector.project(input)))
    }

    /// Project `input` into feature space and apply `errors` (one per output) to the weights.
    fn update(&mut self, input: &I, errors: Vector<f64>) -> UpdateResult<()> {
        let p = self.projector.project(input);

        // `update_projection` returns `()`; call it for its effect rather than wrapping
        // the unit value in `Ok`.
        self.update_projection(&p, errors);

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    extern crate seahash;

    use super::*;
    use projection::{Fourier, TileCoding};
    use std::hash::BuildHasherDefault;

    type SHBuilder = BuildHasherDefault<seahash::SeaHasher>;

    /// One update with errors `e` from zero weights should make the sparse output equal `e`.
    #[test]
    fn test_sparse_update_eval() {
        let p = TileCoding::new(SHBuilder::default(), 4, 100);
        let mut f = MultiLinear::new(p, 2);

        let input = vec![5.0];

        // Fail loudly instead of discarding the update result.
        assert!(f.update(input.as_slice(), Vector::from_vec(vec![20.0, 50.0])).is_ok());
        let out = f.evaluate(input.as_slice()).unwrap();

        assert!((out[0] - 20.0).abs() < 1e-6);
        assert!((out[1] - 50.0).abs() < 1e-6);
    }

    /// One update with errors `e` from zero weights should make the dense output equal `e`.
    #[test]
    fn test_dense_update_eval() {
        let p = Fourier::new(3, vec![(0.0, 10.0)]);
        let mut f = MultiLinear::new(p, 2);

        let input = vec![5.0];

        // Fail loudly instead of discarding the update result.
        assert!(f.update(input.as_slice(), Vector::from_vec(vec![20.0, 50.0])).is_ok());
        let out = f.evaluate(input.as_slice()).unwrap();

        assert!((out[0] - 20.0).abs() < 1e-6);
        assert!((out[1] - 50.0).abs() < 1e-6);
    }
}