1extern crate ndarray;
2extern crate ndarray_linalg;
3
4use ndarray::*;
5
6use std::ops;
7
8use crate::multiple::*;
9use crate::context::*;
10use crate::prior_specification::*;
11use crate::type_id::*;
12use crate::data_points::*;
13use crate::data_point::*;
14use crate::feature_collection::*;
15use crate::normal_inverse_wishart::*;
16use crate::func_schmear::*;
17use crate::func_inverse_schmear::*;
18use crate::space_info::*;
19
20use rand::prelude::*;
21
22#[derive(Clone)]
29pub struct Model<'a> {
30 pub arg_type_id : TypeId,
31 pub ret_type_id : TypeId,
32 pub data : NormalInverseWishart,
33 pub ctxt : &'a Context
34}
35
36pub fn to_features(feature_collections : &Vec<Box<dyn FeatureCollection>>, in_vec : ArrayView1<f32>) -> Array1<f32> {
39 let comps = feature_collections.iter()
40 .map(|coll| coll.get_features(in_vec))
41 .collect::<Vec<_>>();
42 let comp_views = comps.iter()
43 .map(|comp| ArrayView::from(comp))
44 .collect::<Vec<_>>();
45
46 stack(Axis(0), &comp_views).unwrap()
47}
48
49pub fn to_features_mat(feature_collections : &Vec<Box<dyn FeatureCollection>>, in_mat : ArrayView2<f32>)
52 -> Array2<f32> {
53 let comps = feature_collections.iter()
54 .map(|coll| coll.get_features_mat(in_mat))
55 .collect::<Vec<_>>();
56 let comp_views = comps.iter()
57 .map(|comp| ArrayView::from(comp))
58 .collect::<Vec<_>>();
59 stack(Axis(1), &comp_views).unwrap()
60}
61
62pub fn to_jacobian(feature_collections : &Vec<Box<dyn FeatureCollection>>, in_vec : ArrayView1<f32>) -> Array2<f32> {
65 let comps = feature_collections.iter()
66 .map(|coll| coll.get_jacobian(in_vec))
67 .collect::<Vec<_>>();
68
69 let comp_views = comps.iter()
70 .map(|comp| ArrayView::from(comp))
71 .collect::<Vec<_>>();
72
73 stack(Axis(0), &comp_views).unwrap()
74}
75
76impl <'a> Model<'a> {
77 pub fn get_total_dims(&self) -> usize {
79 self.data.get_total_dims()
80 }
81}
82
83
84impl <'a> Model<'a> {
85 pub fn get_context(&self) -> &'a Context {
87 self.ctxt
88 }
89 pub fn sample(&self, rng : &mut ThreadRng) -> Array2<f32> {
92 self.data.sample(rng)
93 }
94 pub fn sample_as_vec(&self, rng : &mut ThreadRng) -> Array1::<f32> {
96 self.data.sample_as_vec(rng)
97 }
98 pub fn get_mean_as_vec(&self) -> ArrayView1::<f32> {
101 self.data.get_mean_as_vec()
102 }
103 pub fn get_inverse_schmear(&self) -> FuncInverseSchmear {
106 self.data.get_inverse_schmear()
107 }
108
109 pub fn get_schmear(&self) -> FuncSchmear {
112 self.data.get_schmear()
113 }
114}
115
116impl <'a> ops::AddAssign<DataPoint> for Model<'a> {
117 fn add_assign(&mut self, other: DataPoint) {
121 let func_space_info = self.ctxt.build_function_space_info(self.arg_type_id, self.ret_type_id);
122 self.data += &func_space_info.get_data(other);
123 }
124}
125
126impl <'a> ops::AddAssign<DataPoints> for Model<'a> {
127 fn add_assign(&mut self, other : DataPoints) {
131 let func_space_info = self.ctxt.build_function_space_info(self.arg_type_id, self.ret_type_id);
132 self.data.update_datapoints(&func_space_info.get_data_points(other));
133 }
134}
135
136impl <'a> ops::SubAssign<DataPoint> for Model<'a> {
137 fn sub_assign(&mut self, other: DataPoint) {
139 let func_space_info = self.ctxt.build_function_space_info(self.arg_type_id, self.ret_type_id);
140 self.data -= &func_space_info.get_data(other);
141 }
142}
143
144impl <'a> ops::AddAssign<&NormalInverseWishart> for Model<'a> {
145 fn add_assign(&mut self, other : &NormalInverseWishart) {
149 self.data += other;
150 }
151}
152
153impl <'a> ops::SubAssign<&NormalInverseWishart> for Model<'a> {
154 fn sub_assign(&mut self, other : &NormalInverseWishart) {
156 self.data -= other;
157 }
158}
159
160impl <'a> Model<'a> {
161 pub fn new(prior_spec : &dyn PriorSpecification,
165 arg_type_id : TypeId, ret_type_id : TypeId,
166 ctxt : &'a Context) -> Model<'a> {
167
168 let func_space_info = ctxt.build_function_space_info(arg_type_id, ret_type_id);
169 let data = NormalInverseWishart::from_space_info(prior_spec, &func_space_info);
170
171 Model {
172 arg_type_id,
173 ret_type_id,
174 data,
175 ctxt
176 }
177 }
178}
179
180#[cfg(test)]
181mod tests {
182 use super::*;
183 use crate::params::*;
184 use crate::linalg_utils::*;
185 use crate::test_utils::*;
186
187 fn clone_model<'a>(model : &Model<'a>) -> Model<'a> {
188 let prior_spec = TestPriorSpecification {};
189 let mut result = Model::new(&prior_spec, model.arg_type_id, model.ret_type_id, &model.ctxt);
190 result.data = model.data.clone();
191 result
192 }
193
194 fn clone_and_perturb_model<'a>(model : &Model<'a>, epsilon : f32) -> Model<'a> {
195 let prior_spec = TestPriorSpecification {};
196 let mut result = Model::new(&prior_spec, model.arg_type_id, model.ret_type_id, &model.ctxt);
197 result.data = model.data.clone();
198
199 let mean = &model.data.mean;
200 let t = mean.shape()[0];
201 let s = mean.shape()[1];
202
203 let perturbation = epsilon * random_matrix(t, s);
204
205 result.data.mean += &perturbation;
206
207 result.data.recompute_derived();
208
209 result
210 }
211
212 #[test]
213 fn data_updates_bulk_matches_incremental() {
214 let ctxt = get_test_vector_only_context();
215 let mut bulk_updated = random_model(&ctxt, TEST_VECTOR_T, TEST_VECTOR_T);
216 let mut incremental_updated = bulk_updated.clone();
217
218 let mut data_point = random_data_point(TEST_VECTOR_SIZE, TEST_VECTOR_SIZE);
219 data_point.weight = 1.0f32;
220
221 let mut in_vecs = Array::zeros((1, TEST_VECTOR_SIZE));
222 in_vecs.row_mut(0).assign(&data_point.in_vec);
223 let mut out_vecs = Array::zeros((1, TEST_VECTOR_SIZE));
224 out_vecs.row_mut(0).assign(&data_point.out_vec);
225 let data_points = DataPoints {
226 in_vecs,
227 out_vecs
228 };
229
230 incremental_updated += data_point;
231
232 bulk_updated += data_points;
233
234 assert_equal_distributions_to_within(&incremental_updated.data, &bulk_updated.data, 1.0f32);
235 }
236
237 #[test]
238 fn data_updates_undo_cleanly() {
239 let ctxt = get_test_vector_only_context();
240 let expected = random_model(&ctxt, TEST_VECTOR_T, TEST_VECTOR_T);
241
242 let mut model = expected.clone();
243 let data_point = random_data_point(TEST_VECTOR_SIZE, TEST_VECTOR_SIZE);
244
245 model += data_point.clone();
246 model -= data_point.clone();
247
248 assert_equal_distributions_to_within(&model.data, &expected.data, 1.0f32);
249 }
250
251 #[test]
252 fn sampling_accurate() {
253 let epsilon = 10.0f32;
254 let num_samps = 1000;
255
256 let ctxt = get_test_vector_only_context();
257
258 let model = random_model(&ctxt, TEST_VECTOR_T, TEST_VECTOR_T);
259
260 let model_schmear = model.get_schmear().flatten();
261
262 let model_dims = model_schmear.mean.shape()[0];
263
264 let mut mean = Array::zeros((model_dims,));
265 let mut rng = rand::thread_rng();
266
267 let scale_fac = 1.0f32 / (num_samps as f32);
268
269 for _ in 0..num_samps {
270 let sample = model.sample_as_vec(&mut rng);
271
272 mean += &sample;
273 }
274
275 mean *= scale_fac;
276
277 assert_equal_vectors_to_within(mean.view(), model_schmear.mean.view(), epsilon);
278
279
280 let mut covariance = Array::zeros((model_dims, model_dims));
281 for _ in 0..num_samps {
282 let sample = model.sample_as_vec(&mut rng);
283
284 let diff = &sample - &model_schmear.mean;
285 covariance += &(scale_fac * &outer(diff.view(), diff.view()));
286 }
287
288 assert_equal_matrices_to_within(covariance.view(), model_schmear.covariance.view(), epsilon * (model_dims as f32));
289 }
290
291}