// fetish_lib/embedder_state.rs

extern crate ndarray;
extern crate ndarray_linalg;

use crate::multiple::*;
use rand::prelude::*;
use crate::prior_specification::*;
use crate::space_info::*;
use crate::newly_evaluated_terms::*;
use ndarray::*;
use std::collections::HashSet;
use std::collections::HashMap;
use crate::input_to_schmeared_output::*;
use crate::sampled_embedder_state::*;
use crate::term_index::*;
use crate::interpreter_state::*;
use crate::type_id::*;
use crate::term_pointer::*;
use crate::term_reference::*;
use crate::term_application_result::*;
use crate::term_model::*;
use crate::embedding_space::*;
use crate::schmear::*;
use crate::func_schmear::*;
use crate::func_inverse_schmear::*;
use crate::normal_inverse_wishart::*;
use crate::elaborator::*;
use topological_sort::TopologicalSort;
use crate::context::*;

///An [`EmbedderState`] keeps track of the embeddings of function terms ([`TermModel`]s)
///which come from some [`InterpreterState`], and also the learned [`Elaborator`]s for
///every function type.
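///
///A minimal usage sketch (the prior specifications and the surrounding
///evaluation loop are illustrative assumptions, not definitions from this file):
///```ignore
///let mut embedder_state = EmbedderState::new(&model_prior, &elaborator_prior, &ctxt);
///embedder_state.bayesian_update_step(&interpreter_state, &newly_evaluated_terms);
///let sampled = embedder_state.sample(&mut rand::thread_rng());
///```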
pub struct EmbedderState<'a> {
    pub model_spaces : HashMap<TypeId, EmbeddingSpace<'a>>,
    pub ctxt : &'a Context
}

impl<'a> EmbedderState<'a> {

    ///Draws a sample from the distribution over [`TermModel`]s represented in this
    ///[`EmbedderState`], yielding a [`SampledEmbedderState`].
    pub fn sample(&self, rng : &mut ThreadRng) -> SampledEmbedderState<'a> {
        let mut embedding_spaces = HashMap::new();
        for (type_id, model_space) in self.model_spaces.iter() {
            let sampled_embedding_space = model_space.sample(rng);
            embedding_spaces.insert(*type_id, sampled_embedding_space);
        }
        SampledEmbedderState {
            embedding_spaces,
            ctxt : self.ctxt
        }
    }

    ///Creates a new [`EmbedderState`], initially populated with default embeddings
    ///for primitive terms in the passed [`Context`].
    pub fn new(model_prior_specification : &'a dyn PriorSpecification,
               elaborator_prior_specification : &'a dyn PriorSpecification,
               ctxt : &'a Context) -> EmbedderState<'a> {
        info!("Readying embedder state");

        let mut model_spaces = HashMap::new();
        for func_type_id in 0..ctxt.get_total_num_types() {
            if !ctxt.is_vector_type(func_type_id) {
                let mut model_space = EmbeddingSpace::new(func_type_id, model_prior_specification,
                                                          elaborator_prior_specification, ctxt);

                //Initialize embeddings for primitive terms
                let primitive_type_space = ctxt.primitive_directory
                                               .primitive_type_spaces.get(&func_type_id).unwrap();
                for term_index in 0..primitive_type_space.terms.len() {
                    model_space.add_model(TermIndex::Primitive(term_index));
                }

                model_spaces.insert(func_type_id, model_space);
            }
        }

        EmbedderState {
            model_spaces,
            ctxt
        }
    }

    ///Initializes default embeddings for the passed collection of terms in a
    ///[`NewlyEvaluatedTerms`].
    pub fn init_embeddings_for_new_terms(&mut self, newly_evaluated_terms : &NewlyEvaluatedTerms) {
        trace!("Initializing embeddings for {} new terms", newly_evaluated_terms.terms.len());
        for nonprimitive_term_ptr in newly_evaluated_terms.terms.iter() {
            let term_ptr = TermPointer::from(nonprimitive_term_ptr.clone());
            if !self.has_embedding(term_ptr) {
                self.init_embedding(term_ptr);
            }
        }
    }

    ///Given an [`InterpreterState`] and a collection of [`NewlyEvaluatedTerms`], performs
    ///a bottom-up (data) update followed by a top-down (prior) update recursively
    ///on all modified terms. This method may be used to keep the [`TermModel`]s in this
    ///[`EmbedderState`] up-to-date with new information.
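    ///
    ///A sketch of a typical call site (the interpreter state and the
    ///newly-evaluated terms are assumed to come from term evaluation elsewhere):
    ///```ignore
    ///embedder_state.bayesian_update_step(&interpreter_state, &newly_evaluated_terms);
    ///```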
    pub fn bayesian_update_step(&mut self, interpreter_state : &InterpreterState,
                                           newly_evaluated_terms : &NewlyEvaluatedTerms) {
        self.init_embeddings_for_new_terms(newly_evaluated_terms);

        let mut data_updated_terms : HashSet<TermPointer> = HashSet::new();
        let mut prior_updated_terms : HashSet<TermPointer> = HashSet::new();

        let updated_apps : HashSet<TermApplicationResult> =
            newly_evaluated_terms.term_app_results.iter().cloned().collect();

        trace!("Propagating data updates for {} applications", updated_apps.len());
        self.propagate_data_recursive(interpreter_state, &updated_apps, &mut data_updated_terms,
                                      newly_evaluated_terms);
        trace!("Propagating prior updates for {} terms", data_updated_terms.len());
        self.propagate_prior_recursive(interpreter_state, &data_updated_terms, &mut prior_updated_terms,
                                       newly_evaluated_terms);

        let mut all_updated_terms = HashSet::new();
        all_updated_terms.extend(data_updated_terms.drain());
        all_updated_terms.extend(prior_updated_terms.drain());
        self.update_elaborators(all_updated_terms);
    }

    ///Determines whether there is a stored [`TermModel`] for the given
    ///[`TermPointer`].
    pub fn has_embedding(&self, term_ptr : TermPointer) -> bool {
        let space : &EmbeddingSpace = self.model_spaces.get(&term_ptr.type_id).unwrap();
        space.has_model(term_ptr.index)
    }

    ///Given a [`TermPointer`] pointing to a [`TermModel`] tracked by this
    ///[`EmbedderState`], yields a reference to the [`TermModel`]. Panics if there is
    ///no such entry stored.
    pub fn get_embedding(&self, term_ptr : TermPointer) -> &TermModel {
        let space = self.get_model_space(term_ptr.type_id);
        space.get_model(term_ptr.index)
    }

    fn get_model_space(&self, type_id : TypeId) -> &EmbeddingSpace {
        self.model_spaces.get(&type_id).unwrap()
    }

    ///Like [`EmbedderState::get_embedding`], but yields a mutable reference to the
    ///[`TermModel`] given a [`TermPointer`] pointing to it.
    pub fn get_mut_embedding(&mut self, term_ptr : TermPointer) -> &mut TermModel<'a> {
        let space : &mut EmbeddingSpace = self.model_spaces.get_mut(&term_ptr.type_id).unwrap();
        space.get_model_mut(term_ptr.index)
    }

    fn init_embedding(&mut self, term_ptr : TermPointer) {
        let space : &mut EmbeddingSpace = self.model_spaces.get_mut(&term_ptr.type_id).unwrap();
        space.add_model(term_ptr.index)
    }

    fn get_schmear_from_ptr(&self, term_ptr : TermPointer) -> FuncSchmear {
        let embedding : &TermModel = self.get_embedding(term_ptr);
        embedding.get_schmear()
    }

    fn get_inverse_schmear_from_ptr(&self, term_ptr : TermPointer) -> FuncInverseSchmear {
        let embedding : &TermModel = self.get_embedding(term_ptr);
        embedding.get_inverse_schmear()
    }

    fn get_compressed_schmear_from_ptr(&self, term_ptr : TermPointer) -> Schmear {
        let type_id = term_ptr.type_id;
        let func_schmear = self.get_schmear_from_ptr(term_ptr);
        let func_feat_info = self.ctxt.get_feature_space_info(type_id);
        let projection_mat = func_feat_info.get_projection_matrix();
        func_schmear.compress(projection_mat.view())
    }

    fn get_compressed_schmear_from_ref(&self, term_ref : &TermReference) -> Schmear {
        match term_ref {
            TermReference::FuncRef(func_ptr) => self.get_compressed_schmear_from_ptr(*func_ptr),
            TermReference::VecRef(_, vec) => Schmear::from_vector(vec.view())
        }
    }

    fn update_elaborators(&mut self, mut updated_terms : HashSet<TermPointer>) {
        for term_ptr in updated_terms.drain() {
            let model_space = self.model_spaces.get_mut(&term_ptr.type_id).unwrap();
            let elaborator = &mut model_space.elaborator;

            //Remove existing data for the term
            if elaborator.has_data(&term_ptr.index) {
                elaborator.downdate_data(&term_ptr.index);
            }

            let term_model = model_space.models.get(&term_ptr.index).unwrap();
            elaborator.update_data(term_ptr.index, &term_model.model);
        }
    }

    //Propagates prior updates downwards
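    //The affected applications are arranged with a topological sort so that an
    //application is processed before any application whose function term is the
    //result it produces; each prior update then sees already-updated upstream models.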
    fn propagate_prior_recursive(&mut self, interpreter_state : &InterpreterState,
                                     to_propagate : &HashSet<TermPointer>,
                                     all_modified : &mut HashSet<TermPointer>,
                                     newly_evaluated : &NewlyEvaluatedTerms) {
        let new_count_map = newly_evaluated.get_count_map();

        let mut topo_sort = TopologicalSort::<TermApplicationResult>::new();
        let mut stack = Vec::<TermApplicationResult>::new();

        for func_ptr in to_propagate {
            let applications = interpreter_state.get_app_results_with_func(*func_ptr);
            for application in applications {
                if let TermReference::FuncRef(_) = application.get_ret_ref() {
                    if self.has_nontrivial_prior_update(&application) {
                        topo_sort.insert(application.clone());
                        stack.push(application.clone());
                    }
                }
            }
        }

        let mut ret_type_set = HashSet::new();
        while let Some(elem) = stack.pop() {
            let ret_ref = elem.get_ret_ref();

            ret_type_set.insert(elem.get_ret_type(self.ctxt));

            if let TermReference::FuncRef(ret_func_ptr) = ret_ref {
                let applications = interpreter_state.get_app_results_with_func(ret_func_ptr);
                for application in applications {
                    if let TermReference::FuncRef(_) = application.get_ret_ref() {
                        if self.has_nontrivial_prior_update(&application) {
                            topo_sort.add_dependency(elem.clone(), application.clone());
                            stack.push(application);
                        }
                    }
                }

                all_modified.insert(ret_func_ptr);
            }
        }

        info!("Obtaining elaborator func schmears");
        let mut elaborator_func_schmears = HashMap::new();
        for type_id in ret_type_set.drain() {
            if !self.ctxt.is_vector_type(type_id) {
                let model_space = self.model_spaces.get(&type_id).unwrap();
                let elaborator = &model_space.elaborator;
                let elaborator_func_schmear = elaborator.get_expansion_func_schmear();
                elaborator_func_schmears.insert(type_id, elaborator_func_schmear);
            }
        }
        info!("Propagating priors");

        while !topo_sort.is_empty() {
            for elem in topo_sort.pop_all() {
                let out_type = elem.get_ret_type(self.ctxt);
                let elaborator_func_schmear = elaborator_func_schmears.get(&out_type).unwrap();

                let new_count = new_count_map.get(&elem).copied().unwrap_or(0);

                self.propagate_prior(elem, elaborator_func_schmear, new_count);
            }
        }
    }

    //Propagates data updates upwards
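    //Each application yields an implicit (argument, result) data point for its
    //function term. Because an updated function may itself appear as the argument
    //or the result of other applications, those applications are pushed onto the
    //stack too, and everything is processed in topological order.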
    fn propagate_data_recursive(&mut self, interpreter_state : &InterpreterState,
                                    to_propagate : &HashSet<TermApplicationResult>,
                                    all_modified : &mut HashSet<TermPointer>,
                                    newly_evaluated : &NewlyEvaluatedTerms) {
        let new_count_map = newly_evaluated.get_count_map();

        let mut topo_sort = TopologicalSort::<TermApplicationResult>::new();
        let mut stack = Vec::<TermApplicationResult>::new();

        for elem in to_propagate {
            stack.push(elem.clone());
        }

        while let Some(elem) = stack.pop() {
            let func_ptr = elem.get_func_ptr();
            let func_ref = TermReference::FuncRef(func_ptr.clone());

            all_modified.insert(func_ptr);

            let args = interpreter_state.get_app_results_with_arg(&func_ref);
            for arg in args {
                stack.push(arg.clone());
                topo_sort.add_dependency(elem.clone(), arg.clone());
            }

            let rets = interpreter_state.get_app_results_with_result(&func_ref);
            for ret in rets {
                stack.push(ret.clone());
                topo_sort.add_dependency(elem.clone(), ret.clone());
            }

            topo_sort.insert(elem);
        }

        while !topo_sort.is_empty() {
            for elem in topo_sort.pop_all() {
                let new_count = new_count_map.get(&elem).copied().unwrap_or(0);
                self.propagate_data(elem, new_count);
            }
        }
    }

    fn get_prior_propagation_func_schmear(&self, term_app_res : &TermApplicationResult) -> FuncSchmear {
        let func_model = self.get_embedding(term_app_res.get_func_ptr());
        //If the model for the function has data updates involving this exact same
        //[`TermApplicationResult`] (which it should), we need to remove all data which
        //was added to the model for it, or we risk reinforcing redundant information.
        let term_input_output = term_app_res.get_term_input_output();

        let mut model_clone = func_model.clone();
        model_clone.downdate_data(&term_input_output);
        model_clone.get_schmear()
    }

    fn has_nontrivial_prior_update(&self, term_app_res : &TermApplicationResult) -> bool {
        let term_input_output = term_app_res.get_term_input_output();
        let func_model = self.get_embedding(term_app_res.get_func_ptr());
        func_model.has_some_data_other_than(&term_input_output)
    }

    //Given a TermApplicationResult, compute the estimated output from the application
    //and use it to update the model for the result. If an update already exists
    //for the given application of terms, this will first remove that update.
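    //Prior updates carry a multiplicity via [`Multiple`]: the count recovered by
    //the downdate is added to `count_increment`, so repeated observations of the
    //same application accumulate rather than duplicate.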
    fn propagate_prior(&mut self, term_app_res : TermApplicationResult,
                       elaborator_func_schmear : &FuncSchmear, count_increment : usize) {
        let func_schmear = self.get_prior_propagation_func_schmear(&term_app_res);

        //Get the model space for the return type
        let ret_space : &EmbeddingSpace = self.model_spaces.get(&term_app_res.get_ret_type(self.ctxt)).unwrap();

        let func_space_info = self.ctxt.get_function_space_info(term_app_res.get_func_type());

        trace!("Propagating prior for space of size {}->{}", func_space_info.get_feature_dimensions(),
                                                             func_space_info.get_output_dimensions());

        let arg_schmear = self.get_compressed_schmear_from_ref(&term_app_res.get_arg_ref());

        let out_schmear : Schmear = func_space_info.apply_schmears(&func_schmear, &arg_schmear);

        if let TermReference::FuncRef(ret_ptr) = term_app_res.get_ret_ref() {
            let out_prior : NormalInverseWishart = ret_space.schmear_to_prior(&self, elaborator_func_schmear,
                                                                              ret_ptr, &out_schmear);
            //Actually perform the update
            let ret_embedding : &mut TermModel = self.get_mut_embedding(ret_ptr);
            let prev_count = ret_embedding.downdate_prior(&term_app_res.term_app);
            let new_count = prev_count + count_increment;

            let out_update = Multiple {
                elem : out_prior,
                count : new_count
            };
            ret_embedding.update_prior(term_app_res.term_app, out_update);
        } else {
            panic!("Prior propagation requires a function-typed result");
        }
    }

    //Given a TermApplicationResult, update the model for the function based on the
    //implicitly-defined data-point for the result
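    //The data point pairs the mean of the argument's compressed schmear with the
    //full compressed schmear of the result ([`InputToSchmearedOutput`]), wrapped
    //in a [`Multiple`] with the same count bookkeeping as prior updates.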
    fn propagate_data(&mut self, term_app_res : TermApplicationResult, count_increment : usize) {
        let term_input_output = term_app_res.get_term_input_output();
        let arg_ref = term_app_res.get_arg_ref();
        let ret_ref = term_app_res.get_ret_ref();

        let arg_schmear = self.get_compressed_schmear_from_ref(&arg_ref);
        let ret_schmear = self.get_compressed_schmear_from_ref(&ret_ref);

        let arg_mean : Array1<f32> = arg_schmear.mean;

        trace!("Propagating data for space of size {}->{}", arg_mean.shape()[0],
                                                            ret_schmear.mean.shape()[0]);

        let data_point = InputToSchmearedOutput {
            in_vec : arg_mean,
            out_schmear : ret_schmear
        };

        let func_embedding : &mut TermModel = self.get_mut_embedding(term_app_res.get_func_ptr());
        let prev_count = func_embedding.downdate_data(&term_input_output);
        let new_count = prev_count + count_increment;

        let data_update = Multiple {
            elem : data_point,
            count : new_count
        };
        func_embedding.update_data(term_input_output, data_update);
    }
}