fetish_lib/
sampled_embedder_state.rs1use ndarray::*;
2use crate::sampled_embedding_space::*;
3use std::collections::HashMap;
4use crate::space_info::*;
5use crate::term_pointer::*;
6use crate::type_id::*;
7use crate::term_reference::*;
8use crate::array_utils::*;
9use crate::interpreter_state::*;
10use crate::sampled_model_embedding::*;
11use crate::term_application::*;
12use crate::displayable_with_state::*;
13use crate::typed_vector::*;
14use crate::context::*;
15
16pub struct SampledEmbedderState<'a> {
21 pub embedding_spaces : HashMap::<TypeId, SampledEmbeddingSpace<'a>>,
22 pub ctxt : &'a Context
23}
24
25impl<'a> SampledEmbedderState<'a> {
26 pub fn has_embedding(&self, term_ptr : TermPointer) -> bool {
29 let space = self.embedding_spaces.get(&term_ptr.type_id).unwrap();
30 space.has_embedding(term_ptr.index)
31 }
32 pub fn get_model_embedding(&self, term_ptr : TermPointer) -> &SampledModelEmbedding {
34 let space = self.embedding_spaces.get(&term_ptr.type_id).unwrap();
35 space.get_embedding(term_ptr.index)
36 }
37
38 pub fn expand_compressed_function(&self, compressed_vec : &TypedVector) -> Array2<f32> {
43 let space = self.embedding_spaces.get(&compressed_vec.type_id).unwrap();
44 let result = space.expand_compressed_function(compressed_vec.vec.view());
45 result
46 }
47
48 pub fn evaluate_term_application(&self, term_application : &TermApplication) -> TypedVector {
52 let func_type_id = term_application.func_ptr.type_id;
53 let ret_type_id = self.ctxt.get_ret_type_id(func_type_id);
54
55 let func_space_info = self.ctxt.get_function_space_info(func_type_id);
56 let func_embedding_space = self.embedding_spaces.get(&func_type_id).unwrap();
57 let func_mat = &func_embedding_space.get_embedding(term_application.func_ptr.index).sampled_mat;
58
59 let arg_vec = match (&term_application.arg_ref) {
60 TermReference::VecRef(_, vec) => from_noisy(vec.view()),
61 TermReference::FuncRef(arg_ptr) => {
62 let arg_embedding_space = self.embedding_spaces.get(&arg_ptr.type_id).unwrap();
63 arg_embedding_space.get_embedding(arg_ptr.index).sampled_compressed_vec.clone()
64 }
65 };
66
67 let ret_vec = func_space_info.apply(func_mat.view(), arg_vec.view());
68 TypedVector {
69 vec : ret_vec,
70 type_id : ret_type_id
71 }
72 }
73}