1use diskann_utils::views::{self, Matrix};
7
8use crate::recall;
9
10#[derive(Debug)]
17pub struct ResultIds<I> {
18 inner: ResultIdsInner<I>,
19}
20
21impl<I> ResultIds<I> {
22 pub fn as_rows(&self) -> &dyn recall::Rows<I> {
24 self.inner.as_rows()
25 }
26
27 pub(crate) fn new(inner: ResultIdsInner<I>) -> Self {
28 Self { inner }
29 }
30}
31
32#[derive(Debug)]
38pub(crate) struct Bounded<I> {
39 ids: Matrix<I>,
40 lengths: Vec<usize>,
42}
43
44impl<I> Bounded<I> {
45 pub(crate) fn new(ids: Matrix<I>, lengths: Vec<usize>) -> Self {
56 assert_eq!(
57 ids.nrows(),
58 lengths.len(),
59 "an internal invariant was not upheld",
60 );
61 Self { ids, lengths }
62 }
63
64 pub(crate) fn len(&self) -> usize {
66 self.lengths.len()
67 }
68
69 pub(crate) fn iter(&self) -> impl ExactSizeIterator<Item = &[I]> {
74 std::iter::zip(self.ids.row_iter(), self.lengths.iter()).map(|(row, len)| {
75 match row.get(..*len) {
76 Some(v) => v,
77 None => row,
78 }
79 })
80 }
81}
82
83impl<I> recall::Rows<I> for Bounded<I> {
84 fn nrows(&self) -> usize {
85 self.len()
86 }
87 fn row(&self, index: usize) -> &[I] {
88 let length = self.lengths[index];
89 let row = self.ids.row(index);
90 match row.get(..length) {
91 Some(v) => v,
92 None => row,
93 }
94 }
95 fn ncols(&self) -> Option<usize> {
96 None
97 }
98}
99
100#[derive(Debug)]
109pub(crate) enum ResultIdsInner<I> {
110 Fixed(Bounded<I>),
111 Dynamic(Vec<Vec<I>>),
112}
113
114impl<I> ResultIdsInner<I> {
115 pub(crate) fn as_rows(&self) -> &dyn recall::Rows<I> {
116 match self {
117 Self::Fixed(bounded) => bounded,
118 Self::Dynamic(ids) => ids,
119 }
120 }
121}
122
123#[derive(Debug, Default)]
129pub(crate) enum IdAggregator<I> {
130 #[default]
132 Empty,
133 Fixed {
137 matrices: Vec<Bounded<I>>,
138 len: usize,
139 num_ids: usize,
140 },
141 Dynamic(Vec<ResultIdsInner<I>>),
143}
144
145impl<I> IdAggregator<I>
146where
147 I: Clone + Default,
148{
149 pub(crate) fn new() -> Self {
151 Self::Empty
152 }
153
154 pub(crate) fn push(&mut self, ids: ResultIdsInner<I>) {
156 *self = match std::mem::take(self) {
172 Self::Empty => match ids {
173 ResultIdsInner::Fixed(bounded) => {
174 let len = bounded.ids.nrows();
175 let num_ids = bounded.ids.ncols();
176 Self::Fixed {
177 matrices: vec![bounded],
178 len,
179 num_ids,
180 }
181 }
182 ResultIdsInner::Dynamic(ids) => Self::Dynamic(vec![ResultIdsInner::Dynamic(ids)]),
183 },
184 Self::Fixed {
185 mut matrices,
186 len,
187 num_ids,
188 } => match ids {
189 ResultIdsInner::Fixed(bounded) => {
190 if bounded.ids.ncols() == num_ids {
191 let len = len + bounded.len();
192 matrices.push(bounded);
193 Self::Fixed {
194 matrices,
195 len,
196 num_ids,
197 }
198 } else {
199 let mut dynamic: Vec<_> =
200 matrices.into_iter().map(ResultIdsInner::Fixed).collect();
201 dynamic.push(ResultIdsInner::Fixed(bounded));
202 Self::Dynamic(dynamic)
203 }
204 }
205 ResultIdsInner::Dynamic(ids) => {
206 let mut dynamic: Vec<_> =
207 matrices.into_iter().map(ResultIdsInner::Fixed).collect();
208 dynamic.push(ResultIdsInner::Dynamic(ids));
209 Self::Dynamic(dynamic)
210 }
211 },
212 Self::Dynamic(mut dynamic) => {
213 dynamic.push(ids);
214 Self::Dynamic(dynamic)
215 }
216 };
217 }
218
219 pub(crate) fn finish(self) -> ResultIds<I> {
222 match self {
228 Self::Empty => ResultIds::new(ResultIdsInner::Dynamic(Vec::new())),
229 Self::Fixed {
230 matrices,
231 len,
232 num_ids,
233 } => {
234 let mut dst = Matrix::new(views::Init(|| I::default()), len, num_ids);
235 let mut lengths = Vec::with_capacity(len);
236
237 let mut output_row = 0;
238 for bounded in matrices {
239 for row in bounded.ids.row_iter() {
240 dst.row_mut(output_row).clone_from_slice(row);
241 output_row += 1;
242 }
243 lengths.extend_from_slice(&bounded.lengths);
244 }
245
246 ResultIds::new(ResultIdsInner::Fixed(Bounded::new(dst, lengths)))
247 }
248 Self::Dynamic(all) => {
249 let mut dst = Vec::<Vec<I>>::new();
250 for ids in all {
251 match ids {
252 ResultIdsInner::Fixed(bounded) => {
253 bounded.iter().for_each(|row| dst.push(row.into()));
254 }
255 ResultIdsInner::Dynamic(dynamic) => {
256 dynamic.into_iter().for_each(|i| dst.push(i));
257 }
258 }
259 }
260
261 ResultIds::new(ResultIdsInner::Dynamic(dst))
262 }
263 }
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    use crate::recall::Rows;

    /// Build a `Bounded` whose matrix is as wide as the longest row in `data`.
    ///
    /// Shorter rows are zero-padded, and each row's length records how many of its
    /// entries are valid.
    fn make_bounded(data: Vec<Vec<u32>>) -> Bounded<u32> {
        let nrows = data.len();
        let ncols = data.iter().map(Vec::len).max().unwrap_or(0);

        let mut matrix = Matrix::new(0u32, nrows, ncols);
        for (row, src) in std::iter::zip(matrix.row_iter_mut(), data.iter()) {
            // `ncols` is the maximum row length, so every `src` fits inside `row`.
            for (dst, value) in std::iter::zip(row.iter_mut(), src.iter()) {
                *dst = *value;
            }
        }

        let lengths = data.iter().map(Vec::len).collect();
        Bounded::new(matrix, lengths)
    }

    #[test]
    fn test_bounded_new_valid() {
        let matrix = Matrix::new(0u32, 3, 5);
        let lengths = vec![2, 3, 1];
        let bounded = Bounded::new(matrix, lengths);

        assert_eq!(bounded.len(), 3);
    }

    #[test]
    fn test_bounded_length_clamping() {
        let matrix = Matrix::new(0u32, 3, 3);
        // The last length exceeds the matrix width and must be clamped to it.
        let lengths = vec![2, 3, 5];
        let bounded = Bounded::new(matrix, lengths);

        assert_eq!(bounded.len(), 3);
        assert_eq!(bounded.row(0), &[0, 0]);
        assert_eq!(bounded.row(1), &[0, 0, 0]);
        assert_eq!(bounded.row(2), &[0, 0, 0]);

        let rows: Vec<&[u32]> = bounded.iter().collect();
        assert_eq!(rows[0], &[0, 0]);
        assert_eq!(rows[1], &[0, 0, 0]);
        assert_eq!(rows[2], &[0, 0, 0]);
    }

    #[test]
    #[should_panic(expected = "an internal invariant was not upheld")]
    fn test_bounded_new_mismatched_lengths() {
        let matrix = Matrix::new(0u32, 3, 5);
        let lengths = vec![2, 3];
        Bounded::new(matrix, lengths);
    }

    #[test]
    fn test_bounded() {
        let bounded = make_bounded(vec![vec![1, 2], vec![3, 4, 5], vec![6]]);
        assert_eq!(bounded.len(), 3);

        assert_eq!(bounded.nrows(), 3);
        assert_eq!(bounded.row(0), &[1, 2]);
        assert_eq!(bounded.row(1), &[3, 4, 5]);
        assert_eq!(bounded.row(2), &[6]);
        assert_eq!(bounded.ncols(), None);

        let rows: Vec<&[u32]> = bounded.iter().collect();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], &[1, 2]);
        assert_eq!(rows[1], &[3, 4, 5]);
        assert_eq!(rows[2], &[6]);
    }

    #[test]
    fn test_result_ids_inner_fixed() {
        let bounded = make_bounded(vec![vec![1, 2], vec![3, 4, 5]]);
        let inner = ResultIdsInner::Fixed(bounded);

        let rows = inner.as_rows();
        assert_eq!(rows.nrows(), 2);
        assert_eq!(rows.row(0), &[1, 2]);
        assert_eq!(rows.row(1), &[3, 4, 5]);
    }

    #[test]
    fn test_result_ids_inner_dynamic() {
        let vecs = vec![vec![1, 2, 3], vec![4], vec![5, 6]];
        let inner = ResultIdsInner::Dynamic(vecs);

        let rows = inner.as_rows();
        assert_eq!(rows.nrows(), 3);
        assert_eq!(rows.row(0), &[1, 2, 3]);
        assert_eq!(rows.row(1), &[4]);
        assert_eq!(rows.row(2), &[5, 6]);
    }

    #[test]
    fn test_result_ids_wrapper() {
        let bounded = make_bounded(vec![vec![10], vec![20, 30]]);
        let result = ResultIds::new(ResultIdsInner::Fixed(bounded));

        let rows = result.as_rows();
        assert_eq!(rows.nrows(), 2);
        assert_eq!(rows.row(0), &[10]);
        assert_eq!(rows.row(1), &[20, 30]);
    }

    #[test]
    fn test_aggregator_empty_finish() {
        let aggregator = IdAggregator::<u32>::new();
        let result = aggregator.finish();

        let rows = result.as_rows();
        assert_eq!(rows.nrows(), 0);
        assert_eq!(rows.ncols(), None);
    }

    #[test]
    fn test_aggregator_empty_to_fixed() {
        let mut aggregator = IdAggregator::new();

        let bounded = make_bounded(vec![vec![1, 2], vec![3], vec![4, 5]]);
        aggregator.push(ResultIdsInner::Fixed(bounded));

        match aggregator {
            IdAggregator::Fixed { len, num_ids, .. } => {
                assert_eq!(len, 3);
                assert_eq!(num_ids, 2);
            }
            _ => panic!("Expected Fixed state"),
        }

        let finished = aggregator.finish();
        let rows = finished.as_rows();
        assert_eq!(rows.nrows(), 3);
        assert_eq!(rows.row(0), &[1, 2]);
        assert_eq!(rows.row(1), &[3]);
        assert_eq!(rows.row(2), &[4, 5]);
    }

    #[test]
    fn test_aggregator_single_fixed_preserves_clamped_lengths() {
        let mut aggregator = IdAggregator::new();

        // The second length exceeds the width and must still read back as the full row.
        let bounded = Bounded::new(Matrix::new(7u32, 2, 3), vec![1, 5]);
        aggregator.push(ResultIdsInner::Fixed(bounded));

        let finished = aggregator.finish();
        let rows = finished.as_rows();
        assert_eq!(rows.nrows(), 2);
        assert_eq!(rows.row(0), &[7]);
        assert_eq!(rows.row(1), &[7, 7, 7]);
    }

    #[test]
    fn test_aggregator_empty_to_dynamic() {
        let mut aggregator = IdAggregator::new();

        let vecs = vec![vec![1, 2, 3], vec![4]];
        aggregator.push(ResultIdsInner::Dynamic(vecs));

        match aggregator {
            IdAggregator::Dynamic(ref inner) => {
                assert_eq!(inner.len(), 1);
            }
            _ => panic!("Expected Dynamic state"),
        }

        let finished = aggregator.finish();
        let rows = finished.as_rows();
        assert_eq!(rows.nrows(), 2);
        assert_eq!(rows.row(0), &[1, 2, 3]);
        assert_eq!(rows.row(1), &[4]);
    }

    #[test]
    fn test_aggregator_empty_dynamic_batch() {
        let mut aggregator = IdAggregator::<u32>::new();
        aggregator.push(ResultIdsInner::Dynamic(Vec::new()));

        let finished = aggregator.finish();
        assert_eq!(finished.as_rows().nrows(), 0);
    }

    #[test]
    fn test_aggregator_fixed_stays_fixed_same_size() {
        let mut aggregator = IdAggregator::new();

        let bounded1 = make_bounded(vec![vec![1, 2, 3], vec![4, 5]]);
        aggregator.push(ResultIdsInner::Fixed(bounded1));

        let bounded2 = make_bounded(vec![vec![6, 7, 8]]);
        aggregator.push(ResultIdsInner::Fixed(bounded2));

        match &aggregator {
            IdAggregator::Fixed {
                len,
                num_ids,
                matrices,
            } => {
                assert_eq!(*len, 3);
                assert_eq!(*num_ids, 3);
                assert_eq!(matrices.len(), 2);
            }
            _ => panic!("Expected Fixed state"),
        }

        let finished = aggregator.finish();
        let rows = finished.as_rows();
        assert_eq!(rows.nrows(), 3);
        assert_eq!(rows.row(0), &[1, 2, 3]);
        assert_eq!(rows.row(1), &[4, 5]);
        assert_eq!(rows.row(2), &[6, 7, 8]);
    }

    #[test]
    fn test_aggregator_fixed_concat_three_batches() {
        let mut aggregator = IdAggregator::new();

        aggregator.push(ResultIdsInner::Fixed(make_bounded(vec![vec![1, 2], vec![3]])));
        aggregator.push(ResultIdsInner::Fixed(make_bounded(vec![vec![4, 5]])));
        aggregator.push(ResultIdsInner::Fixed(make_bounded(vec![vec![6], vec![7, 8]])));

        match &aggregator {
            IdAggregator::Fixed {
                len,
                num_ids,
                matrices,
            } => {
                assert_eq!(*len, 5);
                assert_eq!(*num_ids, 2);
                assert_eq!(matrices.len(), 3);
            }
            _ => panic!("Expected Fixed state"),
        }

        let finished = aggregator.finish();
        let rows = finished.as_rows();
        assert_eq!(rows.nrows(), 5);
        assert_eq!(rows.row(0), &[1, 2]);
        assert_eq!(rows.row(1), &[3]);
        assert_eq!(rows.row(2), &[4, 5]);
        assert_eq!(rows.row(3), &[6]);
        assert_eq!(rows.row(4), &[7, 8]);
    }

    #[test]
    fn test_aggregator_fixed_to_dynamic_different_sizes() {
        let mut aggregator = IdAggregator::new();

        let bounded1 = make_bounded(vec![vec![1, 2], vec![3, 4]]);
        aggregator.push(ResultIdsInner::Fixed(bounded1));

        let bounded2 = make_bounded(vec![vec![5, 6, 7]]);
        aggregator.push(ResultIdsInner::Fixed(bounded2));

        match aggregator {
            IdAggregator::Dynamic(ref inner) => {
                assert_eq!(inner.len(), 2);
            }
            _ => panic!("Expected Dynamic state after size mismatch"),
        }

        let finished = aggregator.finish();
        let rows = finished.as_rows();
        assert_eq!(rows.nrows(), 3);
        assert_eq!(rows.row(0), &[1, 2]);
        assert_eq!(rows.row(1), &[3, 4]);
        assert_eq!(rows.row(2), &[5, 6, 7]);
    }

    #[test]
    fn test_aggregator_fixed_to_dynamic_incoming_dynamic() {
        let mut aggregator = IdAggregator::new();

        let bounded = make_bounded(vec![vec![1, 2], vec![3, 4]]);
        aggregator.push(ResultIdsInner::Fixed(bounded));

        let vecs = vec![vec![5, 6, 7]];
        aggregator.push(ResultIdsInner::Dynamic(vecs));

        match aggregator {
            IdAggregator::Dynamic(ref inner) => {
                assert_eq!(inner.len(), 2);
            }
            _ => panic!("Expected Dynamic state"),
        }

        let finished = aggregator.finish();
        let rows = finished.as_rows();
        assert_eq!(rows.nrows(), 3);
        assert_eq!(rows.row(0), &[1, 2]);
        assert_eq!(rows.row(1), &[3, 4]);
        assert_eq!(rows.row(2), &[5, 6, 7]);
    }

    #[test]
    fn test_aggregator_dynamic_respects_clamped_lengths() {
        let mut aggregator = IdAggregator::new();

        // Lengths shorter than and longer than the matrix width.
        let bounded = Bounded::new(Matrix::new(7u32, 2, 3), vec![1, 5]);
        aggregator.push(ResultIdsInner::Fixed(bounded));
        aggregator.push(ResultIdsInner::Dynamic(vec![vec![9]]));

        let finished = aggregator.finish();
        let rows = finished.as_rows();
        assert_eq!(rows.nrows(), 3);
        assert_eq!(rows.row(0), &[7]);
        assert_eq!(rows.row(1), &[7, 7, 7]);
        assert_eq!(rows.row(2), &[9]);
    }

    #[test]
    fn test_aggregator_dynamic_stays_dynamic() {
        let mut aggregator = IdAggregator::new();

        let vecs1 = vec![vec![1, 2]];
        aggregator.push(ResultIdsInner::Dynamic(vecs1));

        let vecs2 = vec![vec![3, 4, 5]];
        aggregator.push(ResultIdsInner::Dynamic(vecs2));

        let bounded = make_bounded(vec![vec![6, 7]]);
        aggregator.push(ResultIdsInner::Fixed(bounded));

        match aggregator {
            IdAggregator::Dynamic(ref inner) => {
                assert_eq!(inner.len(), 3);
            }
            _ => panic!("Expected Dynamic state"),
        }

        let finished = aggregator.finish();
        let rows = finished.as_rows();
        assert_eq!(rows.nrows(), 3);
        assert_eq!(rows.row(0), &[1, 2]);
        assert_eq!(rows.row(1), &[3, 4, 5]);
        assert_eq!(rows.row(2), &[6, 7]);
    }
}