// fetish_lib/quadratic_feature_collection.rs

1extern crate ndarray;
2extern crate ndarray_linalg;
3
4use ndarray::*;
5
6use crate::feature_collection::*;
7use crate::count_sketch::*;
8
9use std::sync::Arc;
10use rustfft::FFTplanner;
11use rustfft::FFT;
12use rustfft::num_complex::Complex;
13use rustfft::num_traits::Zero;
14use crate::params::*;
15
16///A feature collection consisting of sketched quadratic features
17///utilizing the [`CountSketch`]-and-[`FFT`] technique
18///described in (TODO: cite reference from paper here)
19#[derive(Clone)]
20pub struct QuadraticFeatureCollection {
21    in_dimensions : usize,
22    alpha : f32,
23    sketch_one : CountSketch,
24    sketch_two : CountSketch,
25    fft : Arc<dyn FFT<f32>>,
26    ifft : Arc<dyn FFT<f32>>
27}
28
29impl QuadraticFeatureCollection {
30    ///Constructs a new [`QuadraticFeatureCollection`] with the given number of input
31    ///dimensions, the given scaling factor `alpha`], and the given number of quadratic
32    ///features `out_dimensions`.
33    pub fn new(in_dimensions : usize, out_dimensions : usize, alpha : f32) -> QuadraticFeatureCollection {
34
35        let sketch_one = CountSketch::new(in_dimensions, out_dimensions);
36        let sketch_two = CountSketch::new(in_dimensions, out_dimensions);
37        
38        let mut fftplanner = FFTplanner::<f32>::new(false);
39        let mut ifftplanner = FFTplanner::<f32>::new(true);
40
41        let fft = fftplanner.plan_fft(out_dimensions);
42        let ifft = ifftplanner.plan_fft(out_dimensions);
43
44        QuadraticFeatureCollection {
45            in_dimensions,
46            alpha,
47            sketch_one,
48            sketch_two,
49            fft,
50            ifft
51        }
52    }
53}
54
55fn to_complex(real : f32) -> Complex<f32> {
56    Complex::<f32>::new(real, 0.0)
57}
58
59fn from_complex(complex : Complex<f32>) -> f32 {
60    complex.re
61}
62
//Only the test module calls this reference implementation. Gating it on `cfg(test)`
//avoids a `dead_code` warning in ordinary builds.
#[cfg(test)]
impl QuadraticFeatureCollection {
    ///Unoptimized reference implementation of
    ///[`get_features`](FeatureCollection::get_features), used to check the FFT path in tests.
    ///
    ///For every pair of input indices `(i, j)`, it adds `s1[i] * s2[j] * x[i] * x[j]` to
    ///output coordinate `(h1[i] + h2[j]) % t`, then scales the result by `alpha`. This costs
    ///`O(s^2)` in the input dimension `s`.
    fn unoptimized_get_features(&self, in_vec : ArrayView1<f32>) -> Array1<f32> {
        let s = self.in_dimensions;
        let t = self.get_dimension();

        let mut result : Array1<f32> = Array::zeros((t,));
        for i in 0..s {
            //Read once per outer iteration instead of once per (i, j) pair
            let x = in_vec[i];
            for j in 0..s {
                let y = in_vec[j];
                let sign = self.sketch_one.signs[i] * self.sketch_two.signs[j];
                let index = (self.sketch_one.indices[i] + self.sketch_two.indices[j]) % t;
                result[index] += sign * x * y;
            }
        }
        self.alpha * result
    }
}
82
83impl FeatureCollection for QuadraticFeatureCollection {
84
85    fn get_jacobian(&self, in_vec: ArrayView1<f32>) -> Array2<f32> {
86        //Yield the t x s jacobian of the feature mapping
87        //since the feature mapping here is a circular convolution
88        //of sketched versions of the input features,
89        //we will actually wind up computing our output manually
90        let s = self.in_dimensions;
91        let t = self.get_dimension();
92
93        let mut result : Array2<f32> = Array::zeros((t, s));
94        for i in 0..s {
95            for j in 0..s {
96                let x = in_vec[[i,]];
97                let y = in_vec[[j,]];
98                let sign = self.sketch_one.signs[i] * self.sketch_two.signs[j];
99                let index = (self.sketch_one.indices[i] + self.sketch_two.indices[j]) % t;
100                
101                result[[index, i]] += sign * y;
102                result[[index, j]] += sign * x;
103            }
104        }
105        self.alpha * result
106    }
107
108    fn get_features(&self, in_vec: ArrayView1<f32>) -> Array1<f32> {
109        let first_sketch = self.sketch_one.sketch(in_vec);
110        let second_sketch = self.sketch_two.sketch(in_vec);
111
112        //FFT polynomial multiplication
113        let mut complex_first_sketch = first_sketch.mapv(to_complex).to_vec();
114        let mut complex_second_sketch = second_sketch.mapv(to_complex).to_vec();
115
116        let out_dim = self.get_dimension();
117        
118        let mut first_fft = vec![Complex::zero(); out_dim];
119        let mut second_fft = vec![Complex::zero(); out_dim];
120
121        self.fft.process(&mut complex_first_sketch, &mut first_fft);
122        self.fft.process(&mut complex_second_sketch, &mut second_fft);
123
124        //Turn second_fft into the multiplied fft in-place
125        for i in 0..out_dim {
126            second_fft[i] *= first_fft[i];
127        }
128
129        //Turn first_fft into the result inverse-fft
130        self.ifft.process(&mut second_fft, &mut first_fft);
131
132        //Normalize [since the fft library does unnormalized ffts]
133        let scale_fac : f32 = 1.0 / (out_dim as f32);
134        for i in 0..out_dim {
135            first_fft[i] *= scale_fac;
136        }
137        
138        let result = Array::from(first_fft).mapv(from_complex);
139        self.alpha * result
140    }
141
142    fn get_in_dimensions(&self) -> usize {
143        self.in_dimensions
144    }
145
146    fn get_dimension(&self) -> usize {
147        self.sketch_one.get_out_dimensions()
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::*;

    const IN_DIMENSIONS : usize = 10;
    const OUT_DIMENSIONS : usize = 15;

    ///Builds the collection every test below exercises, with unit scaling.
    fn build_collection() -> QuadraticFeatureCollection {
        QuadraticFeatureCollection::new(IN_DIMENSIONS, OUT_DIMENSIONS, 1.0f32)
    }

    ///The analytic Jacobian must agree with a numerically estimated one.
    #[test]
    fn empirical_jacobian_is_jacobian() {
        let collection = build_collection();
        let input = random_vector(IN_DIMENSIONS);

        let analytic = collection.get_jacobian(input.view());
        let numeric = empirical_jacobian(|x| collection.get_features(x), input.view());

        assert_equal_matrices_to_within(analytic.view(), numeric.view(), 0.1f32);
    }

    ///The FFT-based features must match the direct O(s^2) reference computation.
    #[test]
    fn unoptimized_get_features_is_get_features() {
        let collection = build_collection();
        let input = random_vector(IN_DIMENSIONS);

        let reference = collection.unoptimized_get_features(input.view());
        let fast = collection.get_features(input.view());

        assert_equal_vectors(fast.view(), reference.view());
    }
}