next_plaid/
strided_tensor.rs

1//! StridedTensor for efficient batch lookup of variable-length sequences.
2//!
3//! This module provides a data structure optimized for storing and retrieving
4//! variable-length sequences (like document token embeddings) efficiently.
5
6use ndarray::{s, Array1, Array2};
7
8use crate::utils::quantile;
9
10/// A data structure for efficient batch lookups on tensors of varying lengths.
11///
12/// `StridedTensor` stores a collection of variable-length sequences as a single,
13/// contiguous array with padding. It precomputes several views with different
14/// strides to optimize retrieval of sequences with common lengths.
15#[derive(Clone)]
16pub struct StridedTensor<T: Clone + Default + Copy + 'static> {
17    /// The flattened, contiguous data containing all sequences with padding
18    pub underlying_data: Array2<T>,
19    /// The shape of each individual element within the data (e.g., embedding dim)
20    pub inner_dim: usize,
21    /// Length of each element sequence
22    pub element_lengths: Array1<i64>,
23    /// Maximum length found among all sequences
24    pub max_element_len: usize,
25    /// Sorted vector of strides for precomputed views
26    pub precomputed_strides: Vec<usize>,
27    /// Cumulative sum of element_lengths for offset calculation
28    pub cumulative_lengths: Array1<i64>,
29}
30
31impl<T: Clone + Default + Copy + 'static> StridedTensor<T> {
32    /// Compute optimal strides based on the distribution of element lengths.
33    ///
34    /// Strides are determined by sampling quantiles, ensuring that common sequence
35    /// lengths are well-represented. The maximum element length is always included.
36    fn compute_strides(lengths: &Array1<i64>, max_len: usize) -> Vec<usize> {
37        if lengths.is_empty() {
38            return if max_len > 0 {
39                vec![max_len]
40            } else {
41                Vec::new()
42            };
43        }
44
45        // Sample lengths for quantile computation
46        let lengths_f32: Array1<f32> = lengths.mapv(|x| x as f32);
47
48        let target_quantiles = [0.5, 0.75, 0.9, 0.95];
49
50        let mut strides: Vec<usize> = target_quantiles
51            .iter()
52            .map(|&q| quantile(&lengths_f32, q) as usize)
53            .filter(|&s| s > 0)
54            .collect();
55
56        strides.push(max_len);
57        strides.sort_unstable();
58        strides.dedup();
59
60        if strides.len() == 1 && strides[0] == 0 {
61            return Vec::new();
62        }
63
64        strides
65    }
66
67    /// Creates a new `StridedTensor` from concatenated data and lengths.
68    ///
69    /// # Arguments
70    ///
71    /// * `data` - Concatenated data of all elements, shape `[total_tokens, inner_dim]`
72    /// * `lengths` - Length of each element sequence
73    ///
74    /// # Returns
75    ///
76    /// A new `StridedTensor` with precomputed views for efficient lookup
77    pub fn new(data: Array2<T>, lengths: Array1<i64>) -> Self {
78        let inner_dim = if data.ncols() > 0 { data.ncols() } else { 0 };
79
80        let max_element_len = if !lengths.is_empty() {
81            lengths.iter().copied().max().unwrap_or(0) as usize
82        } else {
83            0
84        };
85
86        let precomputed_strides = Self::compute_strides(&lengths, max_element_len);
87
88        // Compute cumulative lengths (with leading zero for offset calculation)
89        let mut cumulative = Array1::<i64>::zeros(lengths.len() + 1);
90        for (i, &len) in lengths.iter().enumerate() {
91            cumulative[i + 1] = cumulative[i] + len;
92        }
93
94        // Pad data if necessary
95        let total_needed = if !lengths.is_empty() {
96            cumulative[lengths.len() - 1] as usize + max_element_len
97        } else {
98            0
99        };
100
101        let underlying_data = if total_needed > data.nrows() && inner_dim > 0 {
102            let mut padded = Array2::<T>::default((total_needed, inner_dim));
103            padded.slice_mut(s![..data.nrows(), ..]).assign(&data);
104            padded
105        } else {
106            data
107        };
108
109        Self {
110            underlying_data,
111            inner_dim,
112            element_lengths: lengths,
113            max_element_len,
114            precomputed_strides,
115            cumulative_lengths: cumulative,
116        }
117    }
118
119    /// Returns the number of elements (sequences) stored
120    pub fn len(&self) -> usize {
121        self.element_lengths.len()
122    }
123
124    /// Returns true if there are no elements
125    pub fn is_empty(&self) -> bool {
126        self.element_lengths.is_empty()
127    }
128
129    /// Get the total number of tokens across all sequences
130    pub fn total_tokens(&self) -> usize {
131        self.element_lengths.iter().sum::<i64>() as usize
132    }
133}
134
135impl StridedTensor<i64> {
136    /// Retrieve elements by their indices (for codes - 1D per element).
137    ///
138    /// # Arguments
139    ///
140    /// * `indices` - Indices of elements to retrieve
141    ///
142    /// # Returns
143    ///
144    /// Tuple of (flattened data, lengths for each element)
145    pub fn lookup_1d(&self, indices: &[usize]) -> (Array1<i64>, Array1<i64>) {
146        if indices.is_empty() {
147            return (Array1::zeros(0), Array1::zeros(0));
148        }
149
150        // Gather lengths and calculate total size
151        let mut selected_lengths = Array1::<i64>::zeros(indices.len());
152        let mut total_len = 0usize;
153
154        for (i, &idx) in indices.iter().enumerate() {
155            let len = self.element_lengths[idx];
156            selected_lengths[i] = len;
157            total_len += len as usize;
158        }
159
160        // Gather data
161        let mut result = Array1::<i64>::zeros(total_len);
162        let mut offset = 0usize;
163
164        for &idx in indices {
165            let start = self.cumulative_lengths[idx] as usize;
166            let len = self.element_lengths[idx] as usize;
167
168            for j in 0..len {
169                result[offset + j] = self.underlying_data[[start + j, 0]];
170            }
171            offset += len;
172        }
173
174        (result, selected_lengths)
175    }
176}
177
178impl StridedTensor<u8> {
179    /// Retrieve elements by their indices (for residuals - 2D per element).
180    ///
181    /// # Arguments
182    ///
183    /// * `indices` - Indices of elements to retrieve
184    ///
185    /// # Returns
186    ///
187    /// Tuple of (concatenated data, lengths for each element)
188    pub fn lookup_2d(&self, indices: &[usize]) -> (Array2<u8>, Array1<i64>) {
189        if indices.is_empty() {
190            return (Array2::zeros((0, self.inner_dim)), Array1::zeros(0));
191        }
192
193        // Gather lengths and calculate total size
194        let mut selected_lengths = Array1::<i64>::zeros(indices.len());
195        let mut total_len = 0usize;
196
197        for (i, &idx) in indices.iter().enumerate() {
198            let len = self.element_lengths[idx];
199            selected_lengths[i] = len;
200            total_len += len as usize;
201        }
202
203        // Gather data
204        let mut result = Array2::<u8>::zeros((total_len, self.inner_dim));
205        let mut offset = 0usize;
206
207        for &idx in indices {
208            let start = self.cumulative_lengths[idx] as usize;
209            let len = self.element_lengths[idx] as usize;
210
211            result
212                .slice_mut(s![offset..offset + len, ..])
213                .assign(&self.underlying_data.slice(s![start..start + len, ..]));
214            offset += len;
215        }
216
217        (result, selected_lengths)
218    }
219}
220
221impl StridedTensor<usize> {
222    /// Retrieve elements by their indices (for codes stored as usize).
223    ///
224    /// # Arguments
225    ///
226    /// * `indices` - Indices of elements to retrieve
227    ///
228    /// # Returns
229    ///
230    /// Tuple of (flattened codes, lengths for each element)
231    pub fn lookup_codes(&self, indices: &[usize]) -> (Array1<usize>, Array1<i64>) {
232        if indices.is_empty() {
233            return (Array1::zeros(0), Array1::zeros(0));
234        }
235
236        // Gather lengths and calculate total size
237        let mut selected_lengths = Array1::<i64>::zeros(indices.len());
238        let mut total_len = 0usize;
239
240        for (i, &idx) in indices.iter().enumerate() {
241            let len = self.element_lengths[idx];
242            selected_lengths[i] = len;
243            total_len += len as usize;
244        }
245
246        // Gather data
247        let mut result = Array1::<usize>::zeros(total_len);
248        let mut offset = 0usize;
249
250        for &idx in indices {
251            let start = self.cumulative_lengths[idx] as usize;
252            let len = self.element_lengths[idx] as usize;
253
254            for j in 0..len {
255                result[offset + j] = self.underlying_data[[start + j, 0]];
256            }
257            offset += len;
258        }
259
260        (result, selected_lengths)
261    }
262}
263
264/// StridedTensor for IVF (inverted file) - maps centroid ID to passage IDs
265pub struct IvfStridedTensor {
266    /// Concatenated passage IDs for all centroids
267    pub passage_ids: Array1<i64>,
268    /// Length of each centroid's passage list
269    pub lengths: Array1<i32>,
270    /// Cumulative offsets into passage_ids
271    pub offsets: Array1<i64>,
272}
273
274impl IvfStridedTensor {
275    /// Create a new IVF strided tensor
276    pub fn new(passage_ids: Array1<i64>, lengths: Array1<i32>) -> Self {
277        let num_centroids = lengths.len();
278        let mut offsets = Array1::<i64>::zeros(num_centroids + 1);
279
280        for i in 0..num_centroids {
281            offsets[i + 1] = offsets[i] + lengths[i] as i64;
282        }
283
284        Self {
285            passage_ids,
286            lengths,
287            offsets,
288        }
289    }
290
291    /// Lookup passage IDs for given centroid indices
292    pub fn lookup(&self, centroid_indices: &[usize]) -> Vec<i64> {
293        let mut result = Vec::new();
294
295        for &idx in centroid_indices {
296            if idx < self.lengths.len() {
297                let start = self.offsets[idx] as usize;
298                let len = self.lengths[idx] as usize;
299
300                for i in 0..len {
301                    result.push(self.passage_ids[start + i]);
302                }
303            }
304        }
305
306        // Deduplicate
307        result.sort_unstable();
308        result.dedup();
309        result
310    }
311
312    /// Get number of centroids
313    pub fn num_centroids(&self) -> usize {
314        self.lengths.len()
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_strided_tensor_creation() {
        // Create test data: 3 sequences of lengths [2, 3, 1]
        let data = Array2::from_shape_vec(
            (6, 4),
            vec![
                1, 2, 3, 4, // seq 0, token 0
                5, 6, 7, 8, // seq 0, token 1
                9, 10, 11, 12, // seq 1, token 0
                13, 14, 15, 16, // seq 1, token 1
                17, 18, 19, 20, // seq 1, token 2
                21, 22, 23, 24u8, // seq 2, token 0
            ],
        )
        .unwrap();

        let lengths = Array1::from_vec(vec![2i64, 3, 1]);

        let st = StridedTensor::new(data, lengths);

        assert_eq!(st.len(), 3);
        assert_eq!(st.max_element_len, 3);
        assert_eq!(st.total_tokens(), 6);
    }

    #[test]
    fn test_strided_tensor_offsets_and_padding() {
        let data = Array2::from_shape_vec((6, 1), vec![0u8, 1, 2, 3, 4, 5]).unwrap();
        let lengths = Array1::from_vec(vec![2i64, 3, 1]);
        let st = StridedTensor::new(data, lengths);

        // Leading zero followed by running totals.
        assert_eq!(st.cumulative_lengths.to_vec(), vec![0, 2, 5, 6]);
        // Padded to last start (5) + max length (3).
        assert_eq!(st.underlying_data.nrows(), 8);
        assert_eq!(st.inner_dim, 1);
        // The maximum length is always a stride, and strides are strictly increasing.
        assert!(st.precomputed_strides.contains(&3));
        assert!(st.precomputed_strides.windows(2).all(|w| w[0] < w[1]));
    }

    #[test]
    fn test_empty_strided_tensor() {
        let data = Array2::<u8>::zeros((0, 4));
        let st = StridedTensor::new(data, Array1::<i64>::zeros(0));

        assert!(st.is_empty());
        assert_eq!(st.total_tokens(), 0);
        assert_eq!(st.max_element_len, 0);
        assert!(st.precomputed_strides.is_empty());
        assert_eq!(st.cumulative_lengths.to_vec(), vec![0]);
    }

    #[test]
    fn test_strided_tensor_lookup() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1, 2, // seq 0, token 0
                3, 4, // seq 0, token 1
                5, 6, // seq 1, token 0
                7, 8, // seq 1, token 1
                9, 10, // seq 1, token 2
                11, 12u8, // seq 2, token 0
            ],
        )
        .unwrap();

        let lengths = Array1::from_vec(vec![2i64, 3, 1]);
        let st = StridedTensor::new(data, lengths);

        // Lookup sequences 0 and 2
        let (result, lens) = st.lookup_2d(&[0, 2]);

        assert_eq!(lens.len(), 2);
        assert_eq!(lens[0], 2);
        assert_eq!(lens[1], 1);

        assert_eq!(result.nrows(), 3); // 2 + 1 tokens
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[1, 0]], 3);
        assert_eq!(result[[2, 0]], 11);
    }

    #[test]
    fn test_lookup_2d_empty_indices() {
        let data = Array2::from_shape_vec((2, 3), vec![1u8, 2, 3, 4, 5, 6]).unwrap();
        let st = StridedTensor::new(data, Array1::from_vec(vec![2i64]));

        let (result, lens) = st.lookup_2d(&[]);
        assert_eq!(result.dim(), (0, 3));
        assert!(lens.is_empty());
    }

    #[test]
    fn test_lookup_1d_preserves_index_order() {
        let data = Array2::from_shape_vec((4, 1), vec![10i64, 20, 30, 40]).unwrap();
        let lengths = Array1::from_vec(vec![1i64, 3]);
        let st = StridedTensor::new(data, lengths);

        let (values, lens) = st.lookup_1d(&[1, 0]);
        assert_eq!(values.to_vec(), vec![20, 30, 40, 10]);
        assert_eq!(lens.to_vec(), vec![3, 1]);
    }

    #[test]
    fn test_lookup_codes_with_empty_element() {
        let data = Array2::from_shape_vec((3, 1), vec![7usize, 8, 9]).unwrap();
        let lengths = Array1::from_vec(vec![2i64, 0, 1]);
        let st = StridedTensor::new(data, lengths);

        let (codes, lens) = st.lookup_codes(&[2, 1, 0]);
        assert_eq!(codes.to_vec(), vec![9, 7, 8]);
        assert_eq!(lens.to_vec(), vec![1, 0, 2]);
    }

    #[test]
    fn test_ivf_strided_tensor() {
        // 3 centroids with passage lists
        let passage_ids = Array1::from_vec(vec![0i64, 1, 2, 3, 4, 5, 6]);
        let lengths = Array1::from_vec(vec![2i32, 3, 2]); // centroid 0: [0,1], centroid 1: [2,3,4], centroid 2: [5,6]

        let ivf = IvfStridedTensor::new(passage_ids, lengths);

        assert_eq!(ivf.num_centroids(), 3);

        // Lookup centroids 0 and 2
        let pids = ivf.lookup(&[0, 2]);
        assert_eq!(pids, vec![0, 1, 5, 6]);

        // Lookup centroid 1
        let pids = ivf.lookup(&[1]);
        assert_eq!(pids, vec![2, 3, 4]);
    }

    #[test]
    fn test_ivf_lookup_dedups_and_skips_out_of_range() {
        let passage_ids = Array1::from_vec(vec![3i64, 1, 1, 2]);
        let lengths = Array1::from_vec(vec![2i32, 2]);
        let ivf = IvfStridedTensor::new(passage_ids, lengths);

        assert_eq!(ivf.offsets.to_vec(), vec![0, 2, 4]);
        // Centroid 5 does not exist and is ignored; passage 1 appears twice.
        assert_eq!(ivf.lookup(&[0, 1, 5]), vec![1, 2, 3]);
        assert!(ivf.lookup(&[]).is_empty());
    }
}