encoderfile/transforms/engine/
sentence_embedding.rs1use crate::{common::model_type, error::ApiError};
2
3use super::{super::tensor::Tensor, Postprocessor, Transform};
4use ndarray::{Array2, Array3, Ix2};
5
6impl Postprocessor for Transform<model_type::SentenceEmbedding> {
7 type Input = (Array3<f32>, Array2<f32>);
8 type Output = Array2<f32>;
9
10 fn postprocess(&self, (data, mask): Self::Input) -> Result<Self::Output, ApiError> {
11 let func = match &self.postprocessor {
12 Some(p) => p,
13 None => {
14 let Tensor(mean_pooled) = Tensor(data.into_dyn())
15 .mean_pool(Tensor(mask.into_dyn()))
16 .map_err(|e| {
17 tracing::error!(
18 "Failed to mean pool. This should not happen. More details: {:?}",
19 e
20 );
21 ApiError::InternalError("Failed to postprocess embeddings")
22 })?;
23
24 return mean_pooled.into_dimensionality::<Ix2>()
25 .map_err(|e| {
26 tracing::error!("Failed to cast mean pool results into Ix2. This should not happen. More details: {:?}", e);
27 ApiError::InternalError("Failed to postprocess embeddings")
28 });
29 }
30 };
31
32 let batch_size = data.shape()[0];
33
34 let tensor = Tensor(data.into_dyn());
35
36 let result = func
37 .call::<Tensor>((tensor, Tensor(mask.into_dyn())))
38 .map_err(|e| ApiError::LuaError(e.to_string()))?
39 .into_inner()
40 .into_dimensionality::<Ix2>().map_err(|e| {
41 tracing::error!("Failed to cast array into Ix2: {e}. Check your lua transform to make sure it returns a tensor of shape [batch_size, *]");
42 ApiError::LuaError("Error postprocessing embeddings".to_string())
43 })?;
44
45 let result_shape = result.shape();
46
47 if batch_size != result_shape[0] {
48 tracing::error!(
49 "Transform error: expected tensor of shape [{}, *], got tensor of shape {:?}",
50 batch_size,
51 result_shape
52 );
53
54 return Err(ApiError::LuaError(
55 "Error postprocessing embeddings".to_string(),
56 ));
57 }
58
59 Ok(result)
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transforms::DEFAULT_LIBS;
    use ndarray::Axis;

    /// Lua transform that sums the input along axis 2.
    const SUM_AXIS_SCRIPT: &str = r##"
        function Postprocess(arr, mask)
            -- sum along second axis (lol)
            return arr:sum_axis(2)
        end
    "##;

    /// Lua transform that returns its rank-3 input unchanged.
    const PASSTHROUGH_SCRIPT: &str = r##"
        function Postprocess(arr, mask)
            return arr
        end
    "##;

    /// Lua transform that returns a number instead of a tensor.
    const NON_TENSOR_SCRIPT: &str = r##"
        function Postprocess(arr, mask)
            return 1
        end
    "##;

    /// Builds a sentence-embedding transform from `script`, panicking if
    /// construction fails.
    fn engine_from(script: &str) -> Transform<model_type::SentenceEmbedding> {
        Transform::<model_type::SentenceEmbedding>::new(
            DEFAULT_LIBS.to_vec(),
            Some(script.to_string()),
        )
        .expect("Failed to create engine")
    }

    /// Returns a `(batch, seq, dim)` array filled with 2.0 and a `(batch, seq)`
    /// mask filled with 1.0.
    fn filled_inputs(
        batch: usize,
        seq: usize,
        dim: usize,
    ) -> (ndarray::Array3<f32>, ndarray::Array2<f32>) {
        let embeddings = ndarray::Array3::<f32>::from_elem((batch, seq, dim), 2.0);
        let attention = ndarray::Array2::<f32>::from_elem((batch, seq), 1.0);
        (embeddings, attention)
    }

    /// With an empty script the result equals the plain mean over axis 1
    /// (the mask is all ones).
    #[test]
    fn test_no_pooling() {
        let engine = engine_from("");
        let (arr, mask) = filled_inputs(16, 32, 128);

        let pooled = engine
            .postprocess((arr.clone(), mask))
            .expect("Failed to compute pool");

        assert_eq!(pooled.shape(), [16, 128]);

        assert_eq!(arr.mean_axis(Axis(1)), Some(pooled));
    }

    /// A transform that returns a rank-2 tensor with the right batch size is accepted.
    #[test]
    fn test_successful_pool() {
        let engine = engine_from(SUM_AXIS_SCRIPT);
        let (arr, mask) = filled_inputs(16, 32, 128);

        let pooled = engine
            .postprocess((arr, mask))
            .expect("Failed to compute pool");

        assert_eq!(pooled.shape(), [16, 128])
    }

    /// A transform that hands back the rank-3 input is rejected.
    #[test]
    fn test_bad_dim_pool() {
        let engine = engine_from(PASSTHROUGH_SCRIPT);
        let (arr, mask) = filled_inputs(16, 32, 128);

        let outcome = engine.postprocess((arr, mask));

        assert!(outcome.is_err());
    }

    /// A transform that returns a non-tensor value is rejected.
    #[test]
    fn test_sentence_embedding_transform_bad_fn() {
        let engine = engine_from(NON_TENSOR_SCRIPT);
        let (arr, mask) = filled_inputs(16, 32, 128);

        let outcome = engine.postprocess((arr, mask));

        assert!(outcome.is_err())
    }

    /// A wrong-rank result is reported as a `LuaError` carrying the generic
    /// postprocessing message.
    #[test]
    fn test_bad_dimensionality_transform_postprocessing() {
        let engine = Transform::<model_type::SentenceEmbedding>::new(
            DEFAULT_LIBS.to_vec(),
            Some(PASSTHROUGH_SCRIPT.to_string()),
        )
        .unwrap();

        let (arr, mask) = filled_inputs(3, 3, 3);
        let outcome = engine.postprocess((arr, mask));

        assert!(outcome.is_err());

        match outcome {
            Err(ApiError::LuaError(msg)) => {
                assert!(msg.contains("Error postprocessing embeddings"))
            }
            Err(_) => panic!("Didn't return lua error"),
            // Already excluded by the assertion above.
            Ok(_) => {}
        }
    }
}