Skip to main content

diskann_benchmark_core/search/graph/
strategy.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{fmt::Debug, sync::Arc};
7
8use diskann::ANNError;
9use thiserror::Error;
10
11/// A dynamic strategy (e.g. `diskann::graph::glue::SearchStrategy`) manager for built-in
12/// searcher such as [`super::KNN`], [`super::Range`], and [`super::MultiHop`].
13///
14/// This provides an efficient means for either broadcasting a single strategy to all
15/// search queries, or maintaining a collection of strategies, one for each query.
16#[derive(Debug)]
17pub enum Strategy<S> {
18    /// Use the same strategy for all queries.
19    Broadcast(S),
20    /// Use a custom strategy for each query.
21    Collection(Box<[S]>),
22    /// Use a custom strategy for each query via an [`Indexable`] trait object.
23    Indexable(Box<dyn Indexable<S> + Send + Sync>),
24}
25
26impl<S> Strategy<S> {
27    /// Create a strategy that broadcasts `strategy` to all queries.
28    pub fn broadcast(strategy: S) -> Self {
29        Self::Broadcast(strategy)
30    }
31
32    /// Create a strategy that uses the strategies in `itr` for each query.
33    pub fn collection<I>(itr: I) -> Self
34    where
35        I: IntoIterator<Item = S>,
36    {
37        Self::Collection(itr.into_iter().collect())
38    }
39
40    /// Create a strategy that uses `indexable` for each query's strategy.
41    ///
42    /// This method is most useful when the strategies are stored in a custom data
43    /// structure and can avoid the cost of rematerializing a collection.
44    pub fn from_indexable<I>(indexable: I) -> Self
45    where
46        S: std::fmt::Debug,
47        I: Indexable<S> + Send + Sync + 'static,
48    {
49        Self::Indexable(Box::new(indexable))
50    }
51
52    /// Get the strategy for the query at `index`.
53    pub fn get(&self, index: usize) -> Result<&S, Error> {
54        match self {
55            Self::Broadcast(s) => Ok(s),
56            Self::Collection(strategies) => get_as_slice(strategies, index),
57            Self::Indexable(indexable) => indexable.get(index),
58        }
59    }
60
61    /// Return the number of strategies contained in `self`, or `None` if there are
62    /// an unbounded number of strategies.
63    ///
64    /// ```rust
65    /// use diskann_benchmark_core::search::graph::Strategy;
66    ///
67    /// let strategy = Strategy::broadcast(42usize);
68    /// assert_eq!(*strategy.get(0).unwrap(), 42);
69    /// assert!(
70    ///     strategy.len().is_none(),
71    ///     "broadcasted strategies can be retrieved from any index",
72    /// );
73    ///
74    /// let strategy = Strategy::collection([42usize, 128usize]);
75    /// assert_eq!(*strategy.get(0).unwrap(), 42);
76    /// assert_eq!(*strategy.get(1).unwrap(), 128);
77    /// assert_eq!(strategy.len(), Some(2));
78    /// ```
79    pub fn len(&self) -> Option<usize> {
80        match self {
81            Self::Broadcast(_) => None,
82            Self::Collection(strategies) => Some(strategies.len()),
83            Self::Indexable(indexable) => Some(indexable.len()),
84        }
85    }
86
87    /// Return `true` only if the number of strategies is bounded and equal to zero.
88    pub fn is_empty(&self) -> bool {
89        self.len() == Some(0)
90    }
91
92    /// Check if the number of strategies in `self` is compatible with `expected`.
93    ///
94    /// [`Self::Broadcast`] is always compatible. Otherwise, the number of strategies must
95    /// exactly match `expected`.
96    pub fn length_compatible(&self, expected: usize) -> Result<(), LengthIncompatible> {
97        if let Some(len) = self.len()
98            && len != expected
99        {
100            Err(LengthIncompatible {
101                strategies: len,
102                expected,
103            })
104        } else {
105            Ok(())
106        }
107    }
108}
109
110/// A helper trait for [`Strategy`] that allows custom collections of strategies.
111pub trait Indexable<S>: std::fmt::Debug {
112    /// Return the number of strategies in the collection.
113    ///
114    /// Implementations should ensure that `get(i)` returns `Ok(s)` for all `i < Self::len()`.
115    fn len(&self) -> usize;
116
117    /// Return the strategy at `index`.
118    fn get(&self, index: usize) -> Result<&S, Error>;
119
120    /// Return `true` if the collection is empty. Otherwise, return `false`.
121    fn is_empty(&self) -> bool {
122        self.len() == 0
123    }
124}
125
126fn get_as_slice<T>(x: &[T], index: usize) -> Result<&T, Error> {
127    x.get(index).ok_or_else(|| Error::new(index, x.len()))
128}
129
130impl<S> Indexable<S> for Arc<[S]>
131where
132    S: std::fmt::Debug,
133{
134    fn len(&self) -> usize {
135        <[S]>::len(self)
136    }
137
138    fn get(&self, index: usize) -> Result<&S, Error> {
139        get_as_slice(self, index)
140    }
141}
142
143impl<S> Indexable<S> for Box<[S]>
144where
145    S: std::fmt::Debug,
146{
147    fn len(&self) -> usize {
148        <[S]>::len(self)
149    }
150
151    fn get(&self, index: usize) -> Result<&S, Error> {
152        get_as_slice(self, index)
153    }
154}
155
156/// An error indicating that an attempt was made to index a strategy collection
157/// at an out-of-bounds index.
158#[derive(Debug, Clone, Copy, Error)]
159#[error("Tried to index a strategy collection of length {} at index {}", self.len, self.index)]
160pub struct Error {
161    index: usize,
162    len: usize,
163}
164
165impl Error {
166    fn new(index: usize, len: usize) -> Self {
167        Self { index, len }
168    }
169}
170
171impl From<Error> for ANNError {
172    #[track_caller]
173    fn from(error: Error) -> ANNError {
174        ANNError::opaque(error)
175    }
176}
177
/// Error for an incorrect number of strategies.
///
/// Returned by [`Strategy::length_compatible`] when a bounded strategy container holds a
/// different number of strategies than expected. Its [`Display`](std::fmt::Display)
/// output is a sentence such as `"3 strategies were provided when 2 were expected"`.
///
/// See: [`Strategy::length_compatible`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LengthIncompatible {
    /// The number of strategies actually provided.
    strategies: usize,
    /// The number of strategies that were expected.
    expected: usize,
}

impl std::fmt::Display for LengthIncompatible {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Renders `value` followed by `singular` when `value == 1`, otherwise `plural`.
        struct Plural {
            value: usize,
            singular: &'static str,
            plural: &'static str,
        }

        impl std::fmt::Display for Plural {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let phrase = if self.value == 1 {
                    self.singular
                } else {
                    self.plural
                };
                write!(f, "{} {}", self.value, phrase)
            }
        }

        let strategies = Plural {
            value: self.strategies,
            singular: "strategy was",
            plural: "strategies were",
        };

        let expected = Plural {
            value: self.expected,
            singular: "was expected",
            plural: "were expected",
        };

        write!(f, "{strategies} provided when {expected}")
    }
}

impl std::error::Error for LengthIncompatible {}
222
223///////////
224// Tests //
225///////////
226
#[cfg(test)]
mod tests {
    use super::*;

    // Simple test strategy type
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct TestStrategy(u32);

    // Custom indexable implementation for testing
    #[derive(Debug)]
    struct CustomIndexable {
        strategies: Vec<TestStrategy>,
    }

    impl Indexable<TestStrategy> for CustomIndexable {
        fn len(&self) -> usize {
            self.strategies.len()
        }

        fn get(&self, index: usize) -> Result<&TestStrategy, Error> {
            get_as_slice(&self.strategies, index)
        }
    }

    #[test]
    fn test_strategy_broadcast() {
        let strategy = TestStrategy(42);
        let broadcast = Strategy::broadcast(strategy.clone());

        match &broadcast {
            Strategy::Broadcast(s) => assert_eq!(*s, strategy),
            _ => panic!("Expected Broadcast variant"),
        }

        for i in 0..10 {
            assert_eq!(broadcast.get(i).unwrap(), &strategy);
        }
    }

    #[test]
    fn test_strategy_collection() {
        let strategies = [TestStrategy(1), TestStrategy(2), TestStrategy(3)];
        let collection = Strategy::collection(strategies.clone());

        match &collection {
            Strategy::Collection(s) => {
                assert_eq!(s.len(), 3);
                assert_eq!(s[0], strategies[0]);
                assert_eq!(s[1], strategies[1]);
                assert_eq!(s[2], strategies[2]);
            }
            _ => panic!("Expected Collection variant"),
        }

        assert_eq!(collection.get(0).unwrap(), &TestStrategy(1));
        assert_eq!(collection.get(1).unwrap(), &TestStrategy(2));
        assert_eq!(collection.get(2).unwrap(), &TestStrategy(3));

        let err = collection.get(3).unwrap_err();
        assert_eq!(err.index, 3);
        assert_eq!(err.len, 3);
    }

    #[test]
    fn test_strategy_collection_empty() {
        let collection = Strategy::<TestStrategy>::collection(vec![]);

        // Any lookup into an empty collection is out of bounds.
        let err = collection.get(0).unwrap_err();
        assert_eq!(err.index, 0);
        assert_eq!(err.len, 0);
    }

    #[test]
    fn test_strategy_indexable() {
        let custom = CustomIndexable {
            strategies: vec![TestStrategy(100), TestStrategy(200)],
        };

        let strategy = Strategy::from_indexable(custom);
        assert!(
            matches!(&strategy, Strategy::Indexable(_)),
            "Expected Indexable variant"
        );

        assert_eq!(strategy.get(0).unwrap(), &TestStrategy(100));
        assert_eq!(strategy.get(1).unwrap(), &TestStrategy(200));
        let err = strategy.get(5).unwrap_err();
        assert_eq!(err.index, 5);
        assert_eq!(err.len, 2);
    }

    #[test]
    fn test_strategy_from_arc_slice() {
        // The provided `Arc<[S]>` implementation should work through `Strategy` as well.
        let shared: Arc<[TestStrategy]> = Arc::from(vec![TestStrategy(7), TestStrategy(8)]);
        let strategy: Strategy<TestStrategy> = Strategy::from_indexable(Arc::clone(&shared));

        assert_eq!(strategy.len(), Some(2));
        assert!(!strategy.is_empty());
        assert_eq!(strategy.get(0).unwrap(), &TestStrategy(7));
        assert_eq!(strategy.get(1).unwrap(), &TestStrategy(8));

        let err = strategy.get(2).unwrap_err();
        assert_eq!(err.index, 2);
        assert_eq!(err.len, 2);
    }

    #[test]
    fn test_indexable_arc_slice() {
        let strategies: Arc<[TestStrategy]> =
            Arc::from(vec![TestStrategy(1), TestStrategy(2), TestStrategy(3)]);

        assert_eq!(strategies.len(), 3);
        assert!(!strategies.is_empty());

        assert_eq!(strategies.get(0).unwrap(), &TestStrategy(1));
        assert_eq!(strategies.get(1).unwrap(), &TestStrategy(2));
        assert_eq!(strategies.get(2).unwrap(), &TestStrategy(3));

        assert!(strategies.get(10).is_err());
    }

    #[test]
    fn test_indexable_box_slice() {
        let strategies: Box<[TestStrategy]> =
            vec![TestStrategy(5), TestStrategy(10)].into_boxed_slice();

        assert_eq!(strategies.len(), 2);
        assert!(!strategies.is_empty());

        assert_eq!(strategies.get(0).unwrap(), &TestStrategy(5));
        assert_eq!(strategies.get(1).unwrap(), &TestStrategy(10));

        assert!(strategies.get(5).is_err());
    }

    #[test]
    fn test_indexable_is_empty() {
        let empty: Box<[TestStrategy]> = vec![].into_boxed_slice();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let non_empty: Box<[TestStrategy]> = vec![TestStrategy(1)].into_boxed_slice();
        assert!(!non_empty.is_empty());
        assert_eq!(non_empty.len(), 1);
    }

    #[test]
    fn test_error_display() {
        let error = Error::new(3, 2);
        assert_eq!(
            error.to_string(),
            "Tried to index a strategy collection of length 2 at index 3"
        );
    }

    #[test]
    fn test_error_to_ann_error() {
        let error = Error::new(3, 2);
        let ann_error: ANNError = error.into();

        // Verify it converts without panicking
        let message = format!("{ann_error:?}");
        assert!(!message.is_empty());
    }

    #[test]
    fn test_strategy_len() {
        // Broadcast returns None (unbounded)
        let broadcast = Strategy::broadcast(TestStrategy(1));
        assert_eq!(broadcast.len(), None);
        assert!(!broadcast.is_empty());

        // Collection returns Some(len)
        let collection =
            Strategy::collection(vec![TestStrategy(1), TestStrategy(2), TestStrategy(3)]);
        assert_eq!(collection.len(), Some(3));
        assert!(!collection.is_empty());

        // Empty collection returns Some(0)
        let empty_collection = Strategy::<TestStrategy>::collection(vec![]);
        assert_eq!(empty_collection.len(), Some(0));
        assert!(empty_collection.is_empty());

        // Indexable returns Some(len)
        let custom = CustomIndexable {
            strategies: vec![TestStrategy(1), TestStrategy(2)],
        };
        let indexable = Strategy::from_indexable(custom);
        assert_eq!(indexable.len(), Some(2));
        assert!(!indexable.is_empty());

        // Empty indexable returns Some(0)
        let empty_custom = CustomIndexable { strategies: vec![] };
        let empty_indexable = Strategy::from_indexable(empty_custom);
        assert_eq!(empty_indexable.len(), Some(0));
        assert!(empty_indexable.is_empty());
    }

    #[test]
    fn test_length_compatible_broadcast() {
        // Broadcast is always compatible with any expected length
        let broadcast = Strategy::broadcast(1usize);
        assert!(broadcast.length_compatible(0).is_ok());
        assert!(broadcast.length_compatible(1).is_ok());
        assert!(broadcast.length_compatible(100).is_ok());
        assert!(broadcast.length_compatible(usize::MAX).is_ok());
    }

    #[test]
    fn test_length_compatible_collection() {
        let collection = Strategy::collection([1usize, 2, 3]);
        assert!(collection.length_compatible(3).is_ok());

        // Incompatible when expected doesn't match
        let err = collection.length_compatible(2).unwrap_err();
        assert_eq!(
            err.to_string(),
            "3 strategies were provided when 2 were expected"
        );

        let err = collection.length_compatible(5).unwrap_err();
        assert_eq!(
            err.to_string(),
            "3 strategies were provided when 5 were expected"
        );

        // One Strategy
        let single = Strategy::collection([1usize]);
        assert!(single.length_compatible(1).is_ok());

        let err = single.length_compatible(0).unwrap_err();
        assert_eq!(
            err.to_string(),
            "1 strategy was provided when 0 were expected"
        );

        // Empty collection
        let empty = Strategy::<usize>::collection([]);
        assert!(empty.length_compatible(0).is_ok());

        let err = empty.length_compatible(1).unwrap_err();
        assert_eq!(
            err.to_string(),
            "0 strategies were provided when 1 was expected"
        );
    }

    #[test]
    fn test_length_compatible_indexable() {
        let custom = CustomIndexable {
            strategies: vec![TestStrategy(1), TestStrategy(2)],
        };
        let indexable = Strategy::from_indexable(custom);
        assert!(indexable.length_compatible(2).is_ok());

        // Incompatible when expected doesn't match
        let err = indexable.length_compatible(1).unwrap_err();
        assert_eq!(
            err.to_string(),
            "2 strategies were provided when 1 was expected"
        );

        let err = indexable.length_compatible(10).unwrap_err();
        assert_eq!(
            err.to_string(),
            "2 strategies were provided when 10 were expected"
        );

        // Empty indexable
        let empty_custom = CustomIndexable { strategies: vec![] };
        let empty_indexable = Strategy::from_indexable(empty_custom);
        assert!(empty_indexable.length_compatible(0).is_ok());

        let err = empty_indexable.length_compatible(5).unwrap_err();
        assert_eq!(
            err.to_string(),
            "0 strategies were provided when 5 were expected"
        );
    }
}