//! Sampled embedder state (`fetish_lib/sampled_embedder_state.rs`).

1use ndarray::*;
2use crate::sampled_embedding_space::*;
3use std::collections::HashMap;
4use crate::space_info::*;
5use crate::term_pointer::*;
6use crate::type_id::*;
7use crate::term_reference::*;
8use crate::array_utils::*;
9use crate::interpreter_state::*;
10use crate::sampled_model_embedding::*;
11use crate::term_application::*;
12use crate::displayable_with_state::*;
13use crate::typed_vector::*;
14use crate::context::*;
15
16///A sampled possible state for embeddings drawn from an [`crate::embedder_state::EmbedderState`].
17///This [`SampledEmbedderState`] has one [`SampledEmbeddingSpace`] for every
18///function [`TypeId`] in the given [`Context`], which in turn contains
19///a sampled [`crate::elaborator::Elaborator`] and [`SampledModelEmbedding`]s
20pub struct SampledEmbedderState<'a> {
21    pub embedding_spaces : HashMap::<TypeId, SampledEmbeddingSpace<'a>>,
22    pub ctxt : &'a Context
23}
24
25impl<'a> SampledEmbedderState<'a> {
26    ///Determines whether this [`SampledEmbedderState`] has an embedding for the
27    ///given [`TermPointer`].
28    pub fn has_embedding(&self, term_ptr : TermPointer) -> bool {
29        let space = self.embedding_spaces.get(&term_ptr.type_id).unwrap();
30        space.has_embedding(term_ptr.index)
31    }
32    ///Gets the [`SampledModelEmbedding`] for the given [`TermPointer`].
33    pub fn get_model_embedding(&self, term_ptr : TermPointer) -> &SampledModelEmbedding {
34        let space = self.embedding_spaces.get(&term_ptr.type_id).unwrap();
35        space.get_embedding(term_ptr.index)
36    }
37
38    ///Given a compressed [`TypedVector`] for a function type, expands the
39    ///compressed vector using this [`SampledEmbedderState`]'s corresponding
40    ///elaborator sample and inflates it to yield a linear transformation
41    ///from the feature space of the input space to the compressed space of the output.
42    pub fn expand_compressed_function(&self, compressed_vec : &TypedVector) -> Array2<f32> {
43        let space = self.embedding_spaces.get(&compressed_vec.type_id).unwrap();
44        let result = space.expand_compressed_function(compressed_vec.vec.view());
45        result
46    }
47
48    ///Given a [`TermApplication`], evaluates the result that the [`SampledModelEmbedding`]s
49    ///in this [`SampledEmbedderState`] would yield for the expression. The result will be
50    ///in the compressed space of the output type.
51    pub fn evaluate_term_application(&self, term_application : &TermApplication) -> TypedVector {
52        let func_type_id = term_application.func_ptr.type_id;
53        let ret_type_id = self.ctxt.get_ret_type_id(func_type_id);
54        
55        let func_space_info = self.ctxt.get_function_space_info(func_type_id);
56        let func_embedding_space = self.embedding_spaces.get(&func_type_id).unwrap();
57        let func_mat = &func_embedding_space.get_embedding(term_application.func_ptr.index).sampled_mat;
58
59        let arg_vec = match (&term_application.arg_ref) {
60            TermReference::VecRef(_, vec) => from_noisy(vec.view()),
61            TermReference::FuncRef(arg_ptr) => {
62                let arg_embedding_space = self.embedding_spaces.get(&arg_ptr.type_id).unwrap();
63                arg_embedding_space.get_embedding(arg_ptr.index).sampled_compressed_vec.clone()
64            }
65        };
66
67        let ret_vec = func_space_info.apply(func_mat.view(), arg_vec.view());
68        TypedVector {
69            vec : ret_vec,
70            type_id : ret_type_id
71        }
72    }
73}