Skip to main content

encoderfile/transforms/engine/
sentence_embedding.rs

1use crate::{common::model_type, error::ApiError};
2
3use super::{super::tensor::Tensor, Postprocessor, Transform};
4use ndarray::{Array2, Array3, Ix2};
5
6impl Postprocessor for Transform<model_type::SentenceEmbedding> {
7    type Input = (Array3<f32>, Array2<f32>);
8    type Output = Array2<f32>;
9
10    fn postprocess(&self, (data, mask): Self::Input) -> Result<Self::Output, ApiError> {
11        let func = match &self.postprocessor {
12            Some(p) => p,
13            None => {
14                let Tensor(mean_pooled) = Tensor(data.into_dyn())
15                    .mean_pool(Tensor(mask.into_dyn()))
16                    .map_err(|e| {
17                        tracing::error!(
18                            "Failed to mean pool. This should not happen. More details: {:?}",
19                            e
20                        );
21                        ApiError::InternalError("Failed to postprocess embeddings")
22                    })?;
23
24                return mean_pooled.into_dimensionality::<Ix2>()
25                    .map_err(|e| {
26                        tracing::error!("Failed to cast mean pool results into Ix2. This should not happen. More details: {:?}", e);
27                        ApiError::InternalError("Failed to postprocess embeddings")
28                    });
29            }
30        };
31
32        let batch_size = data.shape()[0];
33
34        let tensor = Tensor(data.into_dyn());
35
36        let result = func
37            .call::<Tensor>((tensor, Tensor(mask.into_dyn())))
38            .map_err(|e| ApiError::LuaError(e.to_string()))?
39            .into_inner()
40            .into_dimensionality::<Ix2>().map_err(|e| {
41                tracing::error!("Failed to cast array into Ix2: {e}. Check your lua transform to make sure it returns a tensor of shape [batch_size, *]");
42                ApiError::LuaError("Error postprocessing embeddings".to_string())
43            })?;
44
45        let result_shape = result.shape();
46
47        if batch_size != result_shape[0] {
48            tracing::error!(
49                "Transform error: expected tensor of shape [{}, *], got tensor of shape {:?}",
50                batch_size,
51                result_shape
52            );
53
54            return Err(ApiError::LuaError(
55                "Error postprocessing embeddings".to_string(),
56            ));
57        }
58
59        Ok(result)
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transforms::DEFAULT_LIBS;
    use ndarray::Axis;

    /// Builds a sentence-embedding transform from the given Lua source.
    fn embedding_engine(script: &str) -> Transform<model_type::SentenceEmbedding> {
        Transform::<model_type::SentenceEmbedding>::new(
            DEFAULT_LIBS.to_vec(),
            Some(script.to_string()),
        )
        .expect("Failed to create engine")
    }

    /// Constant inputs: hidden states of shape `(batch, seq, hidden)` filled with
    /// 2.0, and an all-ones mask of shape `(batch, seq)`.
    fn constant_inputs(batch: usize, seq: usize, hidden: usize) -> (Array3<f32>, Array2<f32>) {
        (
            Array3::<f32>::from_elem((batch, seq, hidden), 2.0),
            Array2::<f32>::from_elem((batch, seq), 1.0),
        )
    }

    /// Asserts that `result` is the generic `LuaError` that `postprocess` returns
    /// for a badly shaped transform output.
    fn assert_postprocess_lua_error(result: Result<Array2<f32>, ApiError>) {
        match result {
            Err(ApiError::LuaError(msg)) => assert!(
                msg.contains("Error postprocessing embeddings"),
                "unexpected lua error message: {msg}"
            ),
            Err(_) => panic!("Didn't return lua error"),
            Ok(_) => panic!("Expected postprocessing to fail, but it succeeded"),
        }
    }

    #[test]
    fn test_no_pooling() {
        // An empty script defines no `Postprocess` function. The expected
        // output below is what the built-in mean-pooling fallback produces.
        let engine = embedding_engine("");
        let (arr, mask) = constant_inputs(16, 32, 128);

        let result = engine
            .postprocess((arr.clone(), mask))
            .expect("Failed to compute pool");

        assert_eq!(result.shape(), [16, 128]);

        // if all elements are the same and all mask = 1, should return mean axis array
        assert_eq!(arr.mean_axis(Axis(1)), Some(result));
    }

    #[test]
    fn test_successful_pool() {
        let engine = embedding_engine(
            r##"
        function Postprocess(arr, mask)
            -- sum along second axis (lol)
            return arr:sum_axis(2)
        end
        "##,
        );
        let (arr, mask) = constant_inputs(16, 32, 128);

        let result = engine
            .postprocess((arr, mask))
            .expect("Failed to compute pool");

        assert_eq!(result.shape(), [16, 128])
    }

    #[test]
    fn test_bad_dim_pool() {
        // Returning the 3-D input unchanged must be rejected by the rank check.
        let engine = embedding_engine(
            r##"
        function Postprocess(arr, mask)
            return arr
        end
        "##,
        );
        let (arr, mask) = constant_inputs(16, 32, 128);

        assert_postprocess_lua_error(engine.postprocess((arr, mask)));
    }

    #[test]
    fn test_sentence_embedding_transform_bad_fn() {
        // A non-tensor return value cannot be converted, so the Lua call fails.
        let engine = embedding_engine(
            r##"
        function Postprocess(arr, mask)
            return 1
        end
        "##,
        );
        let (arr, mask) = constant_inputs(16, 32, 128);

        let result = engine.postprocess((arr, mask));

        assert!(matches!(result, Err(ApiError::LuaError(_))));
    }

    #[test]
    fn test_bad_dimensionality_transform_postprocessing() {
        let engine = embedding_engine(
            r##"
        function Postprocess(arr, mask)
            return arr
        end
        "##,
        );
        let (arr, mask) = constant_inputs(3, 3, 3);

        assert_postprocess_lua_error(engine.postprocess((arr, mask)));
    }

    #[test]
    fn test_batch_size_mismatch() {
        // Covers the batch-size check, which the rank-failure tests above never
        // reach. `test_successful_pool` shows `sum_axis(2)` reducing the second
        // axis, which implies 1-based Lua axes. `sum_axis(1)` therefore reduces
        // the batch axis and yields a well-formed 2-D tensor of shape [32, 128],
        // whose leading dimension no longer matches the batch size of 16.
        let engine = embedding_engine(
            r##"
        function Postprocess(arr, mask)
            return arr:sum_axis(1)
        end
        "##,
        );
        let (arr, mask) = constant_inputs(16, 32, 128);

        assert_postprocess_lua_error(engine.postprocess((arr, mask)));
    }
}