fetish_lib/
sampled_embedding_space.rs

1use ndarray::*;
2use std::collections::HashMap;
3use crate::sampled_model_embedding::*;
4use crate::space_info::*;
5use crate::type_id::*;
6use crate::context::*;
7use crate::term_index::*;
8
///Key under which a [`SampledModelEmbedding`] is stored in
///[`SampledEmbeddingSpace::models`]; an alias for [`TermIndex`].
type ModelKey = TermIndex;
10
11///All information in a [`crate::sampled_embedder_state::SampledEmbedderState`] pertaining to a particular
12///function [`TypeId`]
13pub struct SampledEmbeddingSpace<'a> {
14    pub type_id : TypeId,
15    ///A sample drawn from the [`crate::elaborator::Elaborator`] for this type in the original 
16    ///[`crate::embedder_state::EmbedderState`].
17    ///Maps from the compressed space for the type to the base space for the type.
18    pub elaborator : Array2<f32>,
19    pub models : HashMap<ModelKey, SampledModelEmbedding>,
20    pub ctxt : &'a Context
21}
22
23impl<'a> SampledEmbeddingSpace<'a> {
24    ///Determines whether an embedding exists for the given [`TermIndex`].
25    pub fn has_embedding(&self, model_key : ModelKey) -> bool {
26        self.models.contains_key(&model_key)
27    }
28    ///Gets the [`SampledModelEmbedding`] corresponding to the given  [`TermIndex`].
29    pub fn get_embedding(&self, model_key : ModelKey) -> &SampledModelEmbedding {
30        self.models.get(&model_key).unwrap()
31    }
32
33    ///Given a compressed vector, uses `self.elaborator` to expand it to a vector in the
34    ///base space of `self.type_id`.
35    pub fn expand_compressed_vector(&self, compressed_vec : ArrayView1<f32>) -> Array1<f32> {
36        let elaborated_vec = self.elaborator.dot(&compressed_vec);
37        elaborated_vec
38    }
39
40    ///Given a compressed vector for a function of `self.type_id`, 
41    ///first performs [`Self::expand_compressed_vector`] and
42    ///then inflates the result to yield a linear transformation
43    ///from the feature space of the input type to the compressed space of the output type.
44    pub fn expand_compressed_function(&self, compressed_vec : ArrayView1<f32>) -> Array2<f32> {
45        let func_space_info = self.ctxt.get_function_space_info(self.type_id);
46        let feat_dims = func_space_info.get_feature_dimensions();
47        let out_dims = func_space_info.get_output_dimensions();
48
49        let elaborated_vec = self.expand_compressed_vector(compressed_vec);
50        let result = elaborated_vec.into_shape((out_dims, feat_dims)).unwrap(); 
51        result
52    }
53
54    ///Creates a new, initially-empty [`SampledEmbeddingSpace`].
55    pub fn new(type_id : TypeId, elaborator : Array2<f32>, ctxt : &'a Context) -> SampledEmbeddingSpace<'a> {
56        SampledEmbeddingSpace {
57            type_id,
58            elaborator,
59            models : HashMap::new(),
60            ctxt
61        }
62    }
63}