1use ndarray::{s, Array1, Array2};
7
8use crate::utils::quantile;
9
10#[derive(Clone)]
16pub struct StridedTensor<T: Clone + Default + Copy + 'static> {
17 pub underlying_data: Array2<T>,
19 pub inner_dim: usize,
21 pub element_lengths: Array1<i64>,
23 pub max_element_len: usize,
25 pub precomputed_strides: Vec<usize>,
27 pub cumulative_lengths: Array1<i64>,
29}
30
31impl<T: Clone + Default + Copy + 'static> StridedTensor<T> {
32 fn compute_strides(lengths: &Array1<i64>, max_len: usize) -> Vec<usize> {
37 if lengths.is_empty() {
38 return if max_len > 0 {
39 vec![max_len]
40 } else {
41 Vec::new()
42 };
43 }
44
45 let lengths_f32: Array1<f32> = lengths.mapv(|x| x as f32);
47
48 let target_quantiles = [0.5, 0.75, 0.9, 0.95];
49
50 let mut strides: Vec<usize> = target_quantiles
51 .iter()
52 .map(|&q| quantile(&lengths_f32, q) as usize)
53 .filter(|&s| s > 0)
54 .collect();
55
56 strides.push(max_len);
57 strides.sort_unstable();
58 strides.dedup();
59
60 if strides.len() == 1 && strides[0] == 0 {
61 return Vec::new();
62 }
63
64 strides
65 }
66
67 pub fn new(data: Array2<T>, lengths: Array1<i64>) -> Self {
78 let inner_dim = if data.ncols() > 0 { data.ncols() } else { 0 };
79
80 let max_element_len = if !lengths.is_empty() {
81 lengths.iter().copied().max().unwrap_or(0) as usize
82 } else {
83 0
84 };
85
86 let precomputed_strides = Self::compute_strides(&lengths, max_element_len);
87
88 let mut cumulative = Array1::<i64>::zeros(lengths.len() + 1);
90 for (i, &len) in lengths.iter().enumerate() {
91 cumulative[i + 1] = cumulative[i] + len;
92 }
93
94 let total_needed = if !lengths.is_empty() {
96 cumulative[lengths.len() - 1] as usize + max_element_len
97 } else {
98 0
99 };
100
101 let underlying_data = if total_needed > data.nrows() && inner_dim > 0 {
102 let mut padded = Array2::<T>::default((total_needed, inner_dim));
103 padded.slice_mut(s![..data.nrows(), ..]).assign(&data);
104 padded
105 } else {
106 data
107 };
108
109 Self {
110 underlying_data,
111 inner_dim,
112 element_lengths: lengths,
113 max_element_len,
114 precomputed_strides,
115 cumulative_lengths: cumulative,
116 }
117 }
118
119 pub fn len(&self) -> usize {
121 self.element_lengths.len()
122 }
123
124 pub fn is_empty(&self) -> bool {
126 self.element_lengths.is_empty()
127 }
128
129 pub fn total_tokens(&self) -> usize {
131 self.element_lengths.iter().sum::<i64>() as usize
132 }
133}
134
135impl StridedTensor<i64> {
136 pub fn lookup_1d(&self, indices: &[usize]) -> (Array1<i64>, Array1<i64>) {
146 if indices.is_empty() {
147 return (Array1::zeros(0), Array1::zeros(0));
148 }
149
150 let mut selected_lengths = Array1::<i64>::zeros(indices.len());
152 let mut total_len = 0usize;
153
154 for (i, &idx) in indices.iter().enumerate() {
155 let len = self.element_lengths[idx];
156 selected_lengths[i] = len;
157 total_len += len as usize;
158 }
159
160 let mut result = Array1::<i64>::zeros(total_len);
162 let mut offset = 0usize;
163
164 for &idx in indices {
165 let start = self.cumulative_lengths[idx] as usize;
166 let len = self.element_lengths[idx] as usize;
167
168 for j in 0..len {
169 result[offset + j] = self.underlying_data[[start + j, 0]];
170 }
171 offset += len;
172 }
173
174 (result, selected_lengths)
175 }
176}
177
178impl StridedTensor<u8> {
179 pub fn lookup_2d(&self, indices: &[usize]) -> (Array2<u8>, Array1<i64>) {
189 if indices.is_empty() {
190 return (Array2::zeros((0, self.inner_dim)), Array1::zeros(0));
191 }
192
193 let mut selected_lengths = Array1::<i64>::zeros(indices.len());
195 let mut total_len = 0usize;
196
197 for (i, &idx) in indices.iter().enumerate() {
198 let len = self.element_lengths[idx];
199 selected_lengths[i] = len;
200 total_len += len as usize;
201 }
202
203 let mut result = Array2::<u8>::zeros((total_len, self.inner_dim));
205 let mut offset = 0usize;
206
207 for &idx in indices {
208 let start = self.cumulative_lengths[idx] as usize;
209 let len = self.element_lengths[idx] as usize;
210
211 result
212 .slice_mut(s![offset..offset + len, ..])
213 .assign(&self.underlying_data.slice(s![start..start + len, ..]));
214 offset += len;
215 }
216
217 (result, selected_lengths)
218 }
219}
220
221impl StridedTensor<usize> {
222 pub fn lookup_codes(&self, indices: &[usize]) -> (Array1<usize>, Array1<i64>) {
232 if indices.is_empty() {
233 return (Array1::zeros(0), Array1::zeros(0));
234 }
235
236 let mut selected_lengths = Array1::<i64>::zeros(indices.len());
238 let mut total_len = 0usize;
239
240 for (i, &idx) in indices.iter().enumerate() {
241 let len = self.element_lengths[idx];
242 selected_lengths[i] = len;
243 total_len += len as usize;
244 }
245
246 let mut result = Array1::<usize>::zeros(total_len);
248 let mut offset = 0usize;
249
250 for &idx in indices {
251 let start = self.cumulative_lengths[idx] as usize;
252 let len = self.element_lengths[idx] as usize;
253
254 for j in 0..len {
255 result[offset + j] = self.underlying_data[[start + j, 0]];
256 }
257 offset += len;
258 }
259
260 (result, selected_lengths)
261 }
262}
263
264pub struct IvfStridedTensor {
266 pub passage_ids: Array1<i64>,
268 pub lengths: Array1<i32>,
270 pub offsets: Array1<i64>,
272}
273
274impl IvfStridedTensor {
275 pub fn new(passage_ids: Array1<i64>, lengths: Array1<i32>) -> Self {
277 let num_centroids = lengths.len();
278 let mut offsets = Array1::<i64>::zeros(num_centroids + 1);
279
280 for i in 0..num_centroids {
281 offsets[i + 1] = offsets[i] + lengths[i] as i64;
282 }
283
284 Self {
285 passage_ids,
286 lengths,
287 offsets,
288 }
289 }
290
291 pub fn lookup(&self, centroid_indices: &[usize]) -> Vec<i64> {
293 let mut result = Vec::new();
294
295 for &idx in centroid_indices {
296 if idx < self.lengths.len() {
297 let start = self.offsets[idx] as usize;
298 let len = self.lengths[idx] as usize;
299
300 for i in 0..len {
301 result.push(self.passage_ids[start + i]);
302 }
303 }
304 }
305
306 result.sort_unstable();
308 result.dedup();
309 result
310 }
311
312 pub fn num_centroids(&self) -> usize {
314 self.lengths.len()
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction records the element count, the longest element and the
    /// total number of rows owned by elements.
    #[test]
    fn test_strided_tensor_creation() {
        // 6 rows x 4 columns holding the values 1..=24 in row-major order.
        let values: Vec<u8> = (1..=24).collect();
        let data = Array2::from_shape_vec((6, 4), values).unwrap();
        let lengths = Array1::from_vec(vec![2i64, 3, 1]);

        let st = StridedTensor::new(data, lengths);

        assert_eq!(st.len(), 3);
        assert_eq!(st.max_element_len, 3);
        assert_eq!(st.total_tokens(), 6);
    }

    /// Looking up elements 0 and 2 returns their rows back to back.
    #[test]
    fn test_strided_tensor_lookup() {
        // 6 rows x 2 columns holding the values 1..=12 in row-major order.
        let values: Vec<u8> = (1..=12).collect();
        let data = Array2::from_shape_vec((6, 2), values).unwrap();
        let lengths = Array1::from_vec(vec![2i64, 3, 1]);
        let st = StridedTensor::new(data, lengths);

        let (result, lens) = st.lookup_2d(&[0, 2]);

        assert_eq!(lens.len(), 2);
        assert_eq!(lens[0], 2);
        assert_eq!(lens[1], 1);

        // Two rows from element 0 followed by the single row of element 2.
        assert_eq!(result.nrows(), 3);
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[1, 0]], 3);
        assert_eq!(result[[2, 0]], 11);
    }

    /// Centroid lookups return the ids stored in the requested ranges.
    #[test]
    fn test_ivf_strided_tensor() {
        let passage_ids = Array1::from_vec((0i64..7).collect::<Vec<i64>>());
        // Centroid 0 owns ids [0, 1], centroid 1 owns [2, 3, 4], centroid 2 owns [5, 6].
        let lengths = Array1::from_vec(vec![2i32, 3, 2]);
        let ivf = IvfStridedTensor::new(passage_ids, lengths);

        assert_eq!(ivf.num_centroids(), 3);

        let pids = ivf.lookup(&[0, 2]);
        assert_eq!(pids, vec![0, 1, 5, 6]);

        let pids = ivf.lookup(&[1]);
        assert_eq!(pids, vec![2, 3, 4]);
    }
}