fetish_lib/
sampled_embedding_space.rs1use ndarray::*;
2use std::collections::HashMap;
3use crate::sampled_model_embedding::*;
4use crate::space_info::*;
5use crate::type_id::*;
6use crate::context::*;
7use crate::term_index::*;
8
/// Key under which a sampled model embedding is stored in a
/// `SampledEmbeddingSpace`; an alias for `TermIndex`.
type ModelKey = TermIndex;
10
11pub struct SampledEmbeddingSpace<'a> {
14 pub type_id : TypeId,
15 pub elaborator : Array2<f32>,
19 pub models : HashMap<ModelKey, SampledModelEmbedding>,
20 pub ctxt : &'a Context
21}
22
23impl<'a> SampledEmbeddingSpace<'a> {
24 pub fn has_embedding(&self, model_key : ModelKey) -> bool {
26 self.models.contains_key(&model_key)
27 }
28 pub fn get_embedding(&self, model_key : ModelKey) -> &SampledModelEmbedding {
30 self.models.get(&model_key).unwrap()
31 }
32
33 pub fn expand_compressed_vector(&self, compressed_vec : ArrayView1<f32>) -> Array1<f32> {
36 let elaborated_vec = self.elaborator.dot(&compressed_vec);
37 elaborated_vec
38 }
39
40 pub fn expand_compressed_function(&self, compressed_vec : ArrayView1<f32>) -> Array2<f32> {
45 let func_space_info = self.ctxt.get_function_space_info(self.type_id);
46 let feat_dims = func_space_info.get_feature_dimensions();
47 let out_dims = func_space_info.get_output_dimensions();
48
49 let elaborated_vec = self.expand_compressed_vector(compressed_vec);
50 let result = elaborated_vec.into_shape((out_dims, feat_dims)).unwrap();
51 result
52 }
53
54 pub fn new(type_id : TypeId, elaborator : Array2<f32>, ctxt : &'a Context) -> SampledEmbeddingSpace<'a> {
56 SampledEmbeddingSpace {
57 type_id,
58 elaborator,
59 models : HashMap::new(),
60 ctxt
61 }
62 }
63}