// fetish_lib/embedder_state.rs
extern crate ndarray;
extern crate ndarray_linalg;

use crate::multiple::*;
use rand::prelude::*;
use crate::prior_specification::*;
use crate::space_info::*;
use crate::newly_evaluated_terms::*;
use ndarray::*;
use std::collections::HashSet;
use std::collections::HashMap;
use crate::input_to_schmeared_output::*;
use crate::sampled_embedder_state::*;
use crate::term_index::*;
use crate::interpreter_state::*;
use crate::type_id::*;
use crate::term_pointer::*;
use crate::term_reference::*;
use crate::term_application_result::*;
use crate::term_model::*;
use crate::embedding_space::*;
use crate::schmear::*;
use crate::func_schmear::*;
use crate::func_inverse_schmear::*;
use crate::normal_inverse_wishart::*;
use crate::elaborator::*;
use topological_sort::TopologicalSort;
use crate::context::*;

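///Mutable embedding state: one `EmbeddingSpace` of `TermModel`s per
///non-vector type in the ambient `Context`, keyed by `TypeId`.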
pub struct EmbedderState<'a> {
    pub model_spaces : HashMap::<TypeId, EmbeddingSpace<'a>>,
    pub ctxt : &'a Context
}

impl<'a> EmbedderState<'a> {

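    ///Samples a `SampledEmbedderState` from this `EmbedderState` by drawing
    ///a sampled embedding space from each contained `EmbeddingSpace`.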
    pub fn sample(&self, rng : &mut ThreadRng) -> SampledEmbedderState<'a> {
        let mut embedding_spaces = HashMap::new();
        for (type_id, model_space) in self.model_spaces.iter() {
            let sampled_embedding_space = model_space.sample(rng);
            embedding_spaces.insert(*type_id, sampled_embedding_space);
        }
        SampledEmbedderState {
            embedding_spaces,
            ctxt : self.ctxt
        }
    }

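    ///Constructs a new `EmbedderState` for the given `Context`, creating an
    ///`EmbeddingSpace` for every non-vector type and adding a model for each
    ///primitive term of that type.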
    pub fn new(model_prior_specification : &'a dyn PriorSpecification,
               elaborator_prior_specification : &'a dyn PriorSpecification,
               ctxt : &'a Context) -> EmbedderState<'a> {
        info!("Readying embedder state");

        let mut model_spaces = HashMap::new();
        for func_type_id in 0..ctxt.get_total_num_types() {
            if !ctxt.is_vector_type(func_type_id) {
                let mut model_space = EmbeddingSpace::new(func_type_id, model_prior_specification,
                                                          elaborator_prior_specification, ctxt);

                let primitive_type_space = ctxt.primitive_directory
                                               .primitive_type_spaces.get(&func_type_id).unwrap();
                for term_index in 0..primitive_type_space.terms.len() {
                    model_space.add_model(TermIndex::Primitive(term_index));
                }

                model_spaces.insert(func_type_id, model_space);
            }
        }

        EmbedderState {
            model_spaces,
            ctxt
        }
    }

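    ///Ensures that every term in `newly_evaluated_terms` has an embedding,
    ///initializing fresh `TermModel`s for any terms which lack one.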
    pub fn init_embeddings_for_new_terms(&mut self, newly_evaluated_terms : &NewlyEvaluatedTerms) {
        trace!("Initializing embeddings for {} new terms", newly_evaluated_terms.terms.len());
        for nonprimitive_term_ptr in newly_evaluated_terms.terms.iter() {
            let term_ptr = TermPointer::from(nonprimitive_term_ptr.clone());
            if !self.has_embedding(term_ptr) {
                self.init_embedding(term_ptr);
            }
        }
    }

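    ///Performs a full Bayesian update in response to newly-evaluated terms:
    ///initializes embeddings for new terms, recursively propagates data
    ///updates along the new application results, recursively propagates prior
    ///updates from the data-updated terms, and finally refreshes the
    ///elaborators of every modified term.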
    pub fn bayesian_update_step(&mut self, interpreter_state : &InterpreterState,
                                newly_evaluated_terms : &NewlyEvaluatedTerms) {
        self.init_embeddings_for_new_terms(newly_evaluated_terms);

        let mut data_updated_terms : HashSet<TermPointer> = HashSet::new();
        let mut prior_updated_terms : HashSet<TermPointer> = HashSet::new();

        let mut updated_apps : HashSet::<TermApplicationResult> = HashSet::new();
        for term_app_result in newly_evaluated_terms.term_app_results.iter() {
            updated_apps.insert(term_app_result.clone());
        }

        trace!("Propagating data updates for {} applications", updated_apps.len());
        self.propagate_data_recursive(interpreter_state, &updated_apps, &mut data_updated_terms,
                                      newly_evaluated_terms);
        trace!("Propagating prior updates for {} terms", data_updated_terms.len());
        self.propagate_prior_recursive(interpreter_state, &data_updated_terms, &mut prior_updated_terms,
                                       newly_evaluated_terms);

        let mut all_updated_terms = HashSet::new();
        for data_updated_term in data_updated_terms.drain() {
            all_updated_terms.insert(data_updated_term);
        }
        for prior_updated_term in prior_updated_terms.drain() {
            all_updated_terms.insert(prior_updated_term);
        }
        self.update_elaborators(all_updated_terms);
    }

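    ///Returns true if a `TermModel` exists for the given term.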
    pub fn has_embedding(&self, term_ptr : TermPointer) -> bool {
        let space : &EmbeddingSpace = self.model_spaces.get(&term_ptr.type_id).unwrap();
        space.has_model(term_ptr.index)
    }

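    ///Gets the `TermModel` for the given term. Panics if there is none.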
    pub fn get_embedding(&self, term_ptr : TermPointer) -> &TermModel {
        let space = self.get_model_space(term_ptr.type_id);
        space.get_model(term_ptr.index)
    }

    fn get_model_space(&self, type_id : TypeId) -> &EmbeddingSpace {
        self.model_spaces.get(&type_id).unwrap()
    }

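    ///Gets a mutable reference to the `TermModel` for the given term.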
    pub fn get_mut_embedding(&mut self, term_ptr : TermPointer) -> &mut TermModel<'a> {
        let space : &mut EmbeddingSpace = self.model_spaces.get_mut(&term_ptr.type_id).unwrap();
        space.get_model_mut(term_ptr.index)
    }

    fn init_embedding(&mut self, term_ptr : TermPointer) {
        let space : &mut EmbeddingSpace = self.model_spaces.get_mut(&term_ptr.type_id).unwrap();
        space.add_model(term_ptr.index)
    }

    fn get_schmear_from_ptr(&self, term_ptr : TermPointer) -> FuncSchmear {
        let embedding : &TermModel = self.get_embedding(term_ptr);
        embedding.get_schmear()
    }

    fn get_inverse_schmear_from_ptr(&self, term_ptr : TermPointer) -> FuncInverseSchmear {
        let embedding : &TermModel = self.get_embedding(term_ptr);
        embedding.get_inverse_schmear()
    }

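    ///Gets the schmear of the given term, compressed using the projection
    ///matrix of the term's feature space.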
    fn get_compressed_schmear_from_ptr(&self, term_ptr : TermPointer) -> Schmear {
        let type_id = term_ptr.type_id;
        let func_schmear = self.get_schmear_from_ptr(term_ptr);
        let func_feat_info = self.ctxt.get_feature_space_info(type_id);
        let projection_mat = func_feat_info.get_projection_matrix();
        func_schmear.compress(projection_mat.view())
    }

    fn get_compressed_schmear_from_ref(&self, term_ref : &TermReference) -> Schmear {
        match term_ref {
            TermReference::FuncRef(func_ptr) => self.get_compressed_schmear_from_ptr(*func_ptr),
            TermReference::VecRef(_, vec) => Schmear::from_vector(vec.view())
        }
    }

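    ///For each given term, re-derives the data that the corresponding
    ///elaborator holds about it from the term's current model (downdating
    ///any previously-held data first).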
    fn update_elaborators(&mut self, mut updated_terms : HashSet::<TermPointer>) {
        for term_ptr in updated_terms.drain() {
            let model_space = self.model_spaces.get_mut(&term_ptr.type_id).unwrap();
            let elaborator = &mut model_space.elaborator;

            if elaborator.has_data(&term_ptr.index) {
                elaborator.downdate_data(&term_ptr.index);
            }

            let term_model = model_space.models.get(&term_ptr.index).unwrap();
            elaborator.update_data(term_ptr.index, &term_model.model);
        }
    }

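    ///Recursively propagates prior updates outward from the given terms:
    ///gathers all downstream application results with function-valued outputs
    ///and nontrivial prior information, topologically sorts them, and
    ///propagates a prior update through each in dependency order, recording
    ///the affected function terms in `all_modified`.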
    fn propagate_prior_recursive(&mut self, interpreter_state : &InterpreterState,
                                 to_propagate : &HashSet::<TermPointer>,
                                 all_modified : &mut HashSet::<TermPointer>,
                                 newly_evaluated : &NewlyEvaluatedTerms) {
        let new_count_map = newly_evaluated.get_count_map();

        let mut topo_sort = TopologicalSort::<TermApplicationResult>::new();
        let mut stack = Vec::<TermApplicationResult>::new();

        for func_ptr in to_propagate {
            let applications = interpreter_state.get_app_results_with_func(*func_ptr);
            for application in applications {
                if let TermReference::FuncRef(_) = application.get_ret_ref() {
                    if self.has_nontrivial_prior_update(&application) {
                        topo_sort.insert(application.clone());
                        stack.push(application.clone());
                    }
                }
            }
        }

        let mut ret_type_set = HashSet::new();
        while let Some(elem) = stack.pop() {
            let ret_ref = elem.get_ret_ref();

            ret_type_set.insert(elem.get_ret_type(self.ctxt));

            if let TermReference::FuncRef(ret_func_ptr) = ret_ref {
                let applications = interpreter_state.get_app_results_with_func(ret_func_ptr);
                for application in applications {
                    if let TermReference::FuncRef(_) = application.get_ret_ref() {
                        if self.has_nontrivial_prior_update(&application) {
                            topo_sort.add_dependency(elem.clone(), application.clone());
                            stack.push(application);
                        }
                    }
                }

                all_modified.insert(ret_func_ptr);
            }
        }

        info!("Obtaining elaborator func schmears");
        let mut elaborator_func_schmears = HashMap::new();
        for type_id in ret_type_set.drain() {
            if !self.ctxt.is_vector_type(type_id) {
                let model_space = self.model_spaces.get(&type_id).unwrap();
                let elaborator = &model_space.elaborator;
                let elaborator_func_schmear = elaborator.get_expansion_func_schmear();
                elaborator_func_schmears.insert(type_id, elaborator_func_schmear);
            }
        }
        info!("Propagating priors");

        while !topo_sort.is_empty() {
            let mut to_process = topo_sort.pop_all();
            for elem in to_process.drain(..) {
                let out_type = elem.get_ret_type(self.ctxt);
                let elaborator_func_schmear = elaborator_func_schmears.get(&out_type).unwrap();

                let new_count = match new_count_map.get(&elem) {
                    Option::None => 0,
                    Option::Some(count) => *count
                };

                self.propagate_prior(elem, elaborator_func_schmear, new_count);
            }
        }
    }

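    ///Recursively propagates data updates outward from the given application
    ///results: follows every application which takes an affected function as
    ///its argument or yields it as its result, topologically sorts the
    ///collected applications, and propagates a data update through each in
    ///dependency order, recording the affected function terms in `all_modified`.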
    fn propagate_data_recursive(&mut self, interpreter_state : &InterpreterState,
                                to_propagate : &HashSet::<TermApplicationResult>,
                                all_modified : &mut HashSet::<TermPointer>,
                                newly_evaluated : &NewlyEvaluatedTerms) {
        let new_count_map = newly_evaluated.get_count_map();

        let mut topo_sort = TopologicalSort::<TermApplicationResult>::new();
        let mut stack = Vec::<TermApplicationResult>::new();

        for elem in to_propagate {
            stack.push(elem.clone());
        }

        while let Some(elem) = stack.pop() {
            let func_ptr = elem.get_func_ptr();
            let func_ref = TermReference::FuncRef(func_ptr.clone());

            all_modified.insert(func_ptr);

            let args = interpreter_state.get_app_results_with_arg(&func_ref);
            for arg in args {
                stack.push(arg.clone());
                topo_sort.add_dependency(elem.clone(), arg.clone());
            }

            let rets = interpreter_state.get_app_results_with_result(&func_ref);
            for ret in rets {
                stack.push(ret.clone());
                topo_sort.add_dependency(elem.clone(), ret.clone());
            }

            topo_sort.insert(elem);
        }

        while !topo_sort.is_empty() {
            let to_process = topo_sort.pop_all();
            for elem in to_process {
                let new_count = match new_count_map.get(&elem) {
                    Option::None => 0,
                    Option::Some(count) => *count
                };
                self.propagate_data(elem, new_count);
            }
        }
    }

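    ///Gets the schmear of the applied function as it would be without the
    ///data point contributed by this particular application (by cloning the
    ///model and downdating that input/output pair).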
    fn get_prior_propagation_func_schmear(&self, term_app_res : &TermApplicationResult) -> FuncSchmear {
        let func_model = self.get_embedding(term_app_res.get_func_ptr());
        let term_input_output = term_app_res.get_term_input_output();

        let mut model_clone = func_model.clone();
        model_clone.downdate_data(&term_input_output);
        model_clone.get_schmear()
    }

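    ///Returns true if the applied function's model has data beyond the
    ///input/output pair of this particular application.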
    fn has_nontrivial_prior_update(&self, term_app_res : &TermApplicationResult) -> bool {
        let term_input_output = term_app_res.get_term_input_output();
        let func_model = self.get_embedding(term_app_res.get_func_ptr());
        func_model.has_some_data_other_than(&term_input_output)
    }

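    ///Propagates a prior update along a single application result: applies
    ///the (data-downdated) function schmear to the argument schmear, converts
    ///the resulting output schmear into a `NormalInverseWishart` prior via
    ///the return space's elaborator, and installs that prior on the returned
    ///function's model with the given count increment added to its multiplicity.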
    fn propagate_prior(&mut self, term_app_res : TermApplicationResult,
                       elaborator_func_schmear : &FuncSchmear, count_increment : usize) {
        let func_schmear = self.get_prior_propagation_func_schmear(&term_app_res);

        let ret_space : &EmbeddingSpace = self.model_spaces.get(&term_app_res.get_ret_type(self.ctxt)).unwrap();

        let func_space_info = self.ctxt.get_function_space_info(term_app_res.get_func_type());

        trace!("Propagating prior for space of size {}->{}", func_space_info.get_feature_dimensions(),
                                                             func_space_info.get_output_dimensions());

        let arg_schmear = self.get_compressed_schmear_from_ref(&term_app_res.get_arg_ref());

        let out_schmear : Schmear = func_space_info.apply_schmears(&func_schmear, &arg_schmear);

        if let TermReference::FuncRef(ret_ptr) = term_app_res.get_ret_ref() {
            let out_prior : NormalInverseWishart = ret_space.schmear_to_prior(&self, elaborator_func_schmear,
                                                                              ret_ptr, &out_schmear);
            let ret_embedding : &mut TermModel = self.get_mut_embedding(ret_ptr);
            let prev_count = ret_embedding.downdate_prior(&term_app_res.term_app);
            let new_count = prev_count + count_increment;

            let out_update = Multiple {
                elem : out_prior,
                count : new_count
            };
            ret_embedding.update_prior(term_app_res.term_app, out_update);
        } else {
            panic!();
        }
    }

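    ///Propagates a data update along a single application result: records the
    ///argument's compressed mean together with the result's compressed schmear
    ///as an `InputToSchmearedOutput` observation on the applied function's
    ///model, with the given count increment added to its multiplicity.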
    fn propagate_data(&mut self, term_app_res : TermApplicationResult, count_increment : usize) {
        let term_input_output = term_app_res.get_term_input_output();
        let arg_ref = term_app_res.get_arg_ref();
        let ret_ref = term_app_res.get_ret_ref();

        let arg_schmear = self.get_compressed_schmear_from_ref(&arg_ref);
        let ret_schmear = self.get_compressed_schmear_from_ref(&ret_ref);

        let arg_mean : Array1::<f32> = arg_schmear.mean;

        trace!("Propagating data for space of size {}->{}", arg_mean.shape()[0],
                                                            ret_schmear.mean.shape()[0]);

        let data_point = InputToSchmearedOutput {
            in_vec : arg_mean,
            out_schmear : ret_schmear
        };

        let func_embedding : &mut TermModel = self.get_mut_embedding(term_app_res.get_func_ptr());
        let prev_count = func_embedding.downdate_data(&term_input_output);
        let new_count = prev_count + count_increment;

        let data_update = Multiple {
            elem : data_point,
            count : new_count
        };
        func_embedding.update_data(term_input_output, data_update);
    }
}