fetish_lib/
term_model.rs

1extern crate ndarray;
2extern crate ndarray_linalg;
3
4use ndarray::*;
5
6use crate::multiple::*;
7use crate::term_input_output::*;
8use crate::params::*;
9use crate::space_info::*;
10use crate::type_id::*;
11use crate::schmeared_hole::*;
12use crate::normal_inverse_wishart::*;
13use crate::term_application::*;
14use crate::term_reference::*;
15use crate::func_schmear::*;
16use crate::prior_specification::*;
17use crate::func_inverse_schmear::*;
18use crate::input_to_schmeared_output::*;
19use crate::context::*;
20
21use rand::prelude::*;
22
23use std::collections::HashMap;
24use crate::model::*;
25
26///A [`Model`] for a term, with information about what
27///prior updates and data updates have been applied as part of the operation
28///of the Bayesian embedding process in an [`crate::embedder_state::EmbedderState`].
29#[derive(Clone)]
30pub struct TermModel<'a> {
31    pub type_id : TypeId,
32    pub model : Model<'a>,
33    prior_updates : HashMap::<TermApplication, Multiple<NormalInverseWishart>>,
34    data_updates : HashMap::<TermInputOutput, Multiple<InputToSchmearedOutput>>
35}
36
37impl <'a> TermModel<'a> {
38    ///Gets the [`TypeId`] of the term that this [`TermModel`] is responsible
39    ///for learning regression information about.
40    pub fn get_type_id(&self) -> TypeId {
41        self.type_id
42    }
43    ///See [`Model::get_total_dims`].
44    pub fn get_total_dims(&self) -> usize {
45        self.model.get_total_dims()
46    }
47    ///See [`Model::sample`].
48    pub fn sample(&self, rng : &mut ThreadRng) -> Array2<f32> {
49        self.model.sample(rng)
50    }
51    ///See [`Model::sample_as_vec`].
52    pub fn sample_as_vec(&self, rng : &mut ThreadRng) -> Array1::<f32> {
53        self.model.sample_as_vec(rng)
54    }
55    ///See [`Model::get_mean_as_vec`].
56    pub fn get_mean_as_vec(&self) -> ArrayView1::<f32> {
57        self.model.get_mean_as_vec()
58    }
59
60    ///See [`Model::get_inverse_schmear`].
61    pub fn get_inverse_schmear(&self) -> FuncInverseSchmear {
62        self.model.get_inverse_schmear()
63    }
64
65    ///See [`Model::get_schmear`].
66    pub fn get_schmear(&self) -> FuncSchmear {
67        self.model.get_schmear()
68    }
69
70    ///Gets the [`SchmearedHole`] in the base space of the type for this [`TermModel`].
71    pub fn get_schmeared_hole(&self) -> SchmearedHole {
72        let func_type_id = self.get_type_id();
73        let inv_schmear = self.get_inverse_schmear().flatten();
74
75        let result = SchmearedHole {
76            type_id : func_type_id,
77            inv_schmear
78        };
79        result
80    }
81
82    ///Returns true iff this [`TermModel`] has had at least one [`TermInputOutput`]
83    ///applied which is not the given one.
84    pub fn has_some_data_other_than(&self, term_input_output : &TermInputOutput) -> bool {
85        let mut num_data_updates = self.data_updates.len();
86        if (self.data_updates.contains_key(term_input_output)) {
87            num_data_updates -= 1;
88        }
89        num_data_updates > 0
90    }
91
92    ///Updates this [`TermModel`] with a data update stemming from the given [`TermInputOutput`]
93    ///with data given by possibly multiple copies of the same [`InputToSchmearedOutput`].
94    pub fn update_data(&mut self, update_key : TermInputOutput, data_update : Multiple<InputToSchmearedOutput>) {
95        let func_space_info = self.model.ctxt.get_function_space_info(self.get_type_id());
96        let feat_update_elem = data_update.elem.featurize(&func_space_info);
97        let feat_update = Multiple {
98            elem : feat_update_elem,
99            count : data_update.count
100        };
101        self.model.data += &feat_update;
102        self.data_updates.insert(update_key, feat_update);
103    }
104
105    ///Downdates this [`TermModel`] for data updates with the given [`TermInputOutput`] key.
106    ///Yields the number of data-points which were removed as a consequence of this operation.
107    pub fn downdate_data(&mut self, update_key : &TermInputOutput) -> usize {
108        match (self.data_updates.remove(update_key)) {
109            Option::None => 0,
110            Option::Some(multiple) => {
111                self.model.data -= &multiple;
112                multiple.count
113            }
114        }
115    }
116
117    ///Updates this [`TermModel`] with a prior update stemming from the given [`TermApplication`]
118    ///with data given by possibly multiple copies of the same [`NormalInverseWishart`]
119    ///distribution.
120    pub fn update_prior(&mut self, update_key : TermApplication, distr : Multiple<NormalInverseWishart>) {
121        self.model.data += &distr;
122        self.prior_updates.insert(update_key, distr);
123    }
124
125    ///Downdates this [`TermModel`] for prior updates with the given [`TermApplication`] key.
126    ///Yields the number of prior applications which were removed as a consequence of this
127    ///operation.
128    pub fn downdate_prior(&mut self, key : &TermApplication) -> usize {
129        match (self.prior_updates.remove(key)) {
130            Option::None => 0,
131            Option::Some(multiple) => {
132                self.model.data -= &multiple;
133                multiple.count
134            }
135        }
136    }
137
138    ///Constructs a new [`TermModel`] for the given type with the given [`PriorSpecification`]
139    ///within the given [`Context`].
140    pub fn new(type_id : TypeId, prior_specification : &dyn PriorSpecification,
141                                 ctxt : &'a Context) -> TermModel<'a> {
142        let prior_updates = HashMap::new();
143        let data_updates = HashMap::new();
144        let arg_type_id = ctxt.get_arg_type_id(type_id);
145        let ret_type_id = ctxt.get_ret_type_id(type_id);
146
147        let model = Model::new(prior_specification, arg_type_id, ret_type_id, ctxt);
148        TermModel {
149            type_id,
150            model : model,
151            prior_updates : prior_updates,
152            data_updates : data_updates
153        }
154    }
155}