1extern crate ndarray;
2extern crate ndarray_linalg;
3
4use ndarray::*;
5
6use crate::multiple::*;
7use crate::term_input_output::*;
8use crate::params::*;
9use crate::space_info::*;
10use crate::type_id::*;
11use crate::schmeared_hole::*;
12use crate::normal_inverse_wishart::*;
13use crate::term_application::*;
14use crate::term_reference::*;
15use crate::func_schmear::*;
16use crate::prior_specification::*;
17use crate::func_inverse_schmear::*;
18use crate::input_to_schmeared_output::*;
19use crate::context::*;
20
21use rand::prelude::*;
22
23use std::collections::HashMap;
24use crate::model::*;
25
26#[derive(Clone)]
30pub struct TermModel<'a> {
31 pub type_id : TypeId,
32 pub model : Model<'a>,
33 prior_updates : HashMap::<TermApplication, Multiple<NormalInverseWishart>>,
34 data_updates : HashMap::<TermInputOutput, Multiple<InputToSchmearedOutput>>
35}
36
37impl <'a> TermModel<'a> {
38 pub fn get_type_id(&self) -> TypeId {
41 self.type_id
42 }
43 pub fn get_total_dims(&self) -> usize {
45 self.model.get_total_dims()
46 }
47 pub fn sample(&self, rng : &mut ThreadRng) -> Array2<f32> {
49 self.model.sample(rng)
50 }
51 pub fn sample_as_vec(&self, rng : &mut ThreadRng) -> Array1::<f32> {
53 self.model.sample_as_vec(rng)
54 }
55 pub fn get_mean_as_vec(&self) -> ArrayView1::<f32> {
57 self.model.get_mean_as_vec()
58 }
59
60 pub fn get_inverse_schmear(&self) -> FuncInverseSchmear {
62 self.model.get_inverse_schmear()
63 }
64
65 pub fn get_schmear(&self) -> FuncSchmear {
67 self.model.get_schmear()
68 }
69
70 pub fn get_schmeared_hole(&self) -> SchmearedHole {
72 let func_type_id = self.get_type_id();
73 let inv_schmear = self.get_inverse_schmear().flatten();
74
75 let result = SchmearedHole {
76 type_id : func_type_id,
77 inv_schmear
78 };
79 result
80 }
81
82 pub fn has_some_data_other_than(&self, term_input_output : &TermInputOutput) -> bool {
85 let mut num_data_updates = self.data_updates.len();
86 if (self.data_updates.contains_key(term_input_output)) {
87 num_data_updates -= 1;
88 }
89 num_data_updates > 0
90 }
91
92 pub fn update_data(&mut self, update_key : TermInputOutput, data_update : Multiple<InputToSchmearedOutput>) {
95 let func_space_info = self.model.ctxt.get_function_space_info(self.get_type_id());
96 let feat_update_elem = data_update.elem.featurize(&func_space_info);
97 let feat_update = Multiple {
98 elem : feat_update_elem,
99 count : data_update.count
100 };
101 self.model.data += &feat_update;
102 self.data_updates.insert(update_key, feat_update);
103 }
104
105 pub fn downdate_data(&mut self, update_key : &TermInputOutput) -> usize {
108 match (self.data_updates.remove(update_key)) {
109 Option::None => 0,
110 Option::Some(multiple) => {
111 self.model.data -= &multiple;
112 multiple.count
113 }
114 }
115 }
116
117 pub fn update_prior(&mut self, update_key : TermApplication, distr : Multiple<NormalInverseWishart>) {
121 self.model.data += &distr;
122 self.prior_updates.insert(update_key, distr);
123 }
124
125 pub fn downdate_prior(&mut self, key : &TermApplication) -> usize {
129 match (self.prior_updates.remove(key)) {
130 Option::None => 0,
131 Option::Some(multiple) => {
132 self.model.data -= &multiple;
133 multiple.count
134 }
135 }
136 }
137
138 pub fn new(type_id : TypeId, prior_specification : &dyn PriorSpecification,
141 ctxt : &'a Context) -> TermModel<'a> {
142 let prior_updates = HashMap::new();
143 let data_updates = HashMap::new();
144 let arg_type_id = ctxt.get_arg_type_id(type_id);
145 let ret_type_id = ctxt.get_ret_type_id(type_id);
146
147 let model = Model::new(prior_specification, arg_type_id, ret_type_id, ctxt);
148 TermModel {
149 type_id,
150 model : model,
151 prior_updates : prior_updates,
152 data_updates : data_updates
153 }
154 }
155}