diskann_benchmark_core/search/graph/
strategy.rs1use std::{fmt::Debug, sync::Arc};
7
8use diskann::ANNError;
9use thiserror::Error;
10
11#[derive(Debug)]
17pub enum Strategy<S> {
18 Broadcast(S),
20 Collection(Box<[S]>),
22 Indexable(Box<dyn Indexable<S> + Send + Sync>),
24}
25
26impl<S> Strategy<S> {
27 pub fn broadcast(strategy: S) -> Self {
29 Self::Broadcast(strategy)
30 }
31
32 pub fn collection<I>(itr: I) -> Self
34 where
35 I: IntoIterator<Item = S>,
36 {
37 Self::Collection(itr.into_iter().collect())
38 }
39
40 pub fn from_indexable<I>(indexable: I) -> Self
45 where
46 S: std::fmt::Debug,
47 I: Indexable<S> + Send + Sync + 'static,
48 {
49 Self::Indexable(Box::new(indexable))
50 }
51
52 pub fn get(&self, index: usize) -> Result<&S, Error> {
54 match self {
55 Self::Broadcast(s) => Ok(s),
56 Self::Collection(strategies) => get_as_slice(strategies, index),
57 Self::Indexable(indexable) => indexable.get(index),
58 }
59 }
60
61 pub fn len(&self) -> Option<usize> {
80 match self {
81 Self::Broadcast(_) => None,
82 Self::Collection(strategies) => Some(strategies.len()),
83 Self::Indexable(indexable) => Some(indexable.len()),
84 }
85 }
86
87 pub fn is_empty(&self) -> bool {
89 self.len() == Some(0)
90 }
91
92 pub fn length_compatible(&self, expected: usize) -> Result<(), LengthIncompatible> {
97 if let Some(len) = self.len()
98 && len != expected
99 {
100 Err(LengthIncompatible {
101 strategies: len,
102 expected,
103 })
104 } else {
105 Ok(())
106 }
107 }
108}
109
110pub trait Indexable<S>: std::fmt::Debug {
112 fn len(&self) -> usize;
116
117 fn get(&self, index: usize) -> Result<&S, Error>;
119
120 fn is_empty(&self) -> bool {
122 self.len() == 0
123 }
124}
125
126fn get_as_slice<T>(x: &[T], index: usize) -> Result<&T, Error> {
127 x.get(index).ok_or_else(|| Error::new(index, x.len()))
128}
129
130impl<S> Indexable<S> for Arc<[S]>
131where
132 S: std::fmt::Debug,
133{
134 fn len(&self) -> usize {
135 <[S]>::len(self)
136 }
137
138 fn get(&self, index: usize) -> Result<&S, Error> {
139 get_as_slice(self, index)
140 }
141}
142
143impl<S> Indexable<S> for Box<[S]>
144where
145 S: std::fmt::Debug,
146{
147 fn len(&self) -> usize {
148 <[S]>::len(self)
149 }
150
151 fn get(&self, index: usize) -> Result<&S, Error> {
152 get_as_slice(self, index)
153 }
154}
155
156#[derive(Debug, Clone, Copy, Error)]
159#[error("Tried to index a strategy collection of length {} at index {}", self.len, self.index)]
160pub struct Error {
161 index: usize,
162 len: usize,
163}
164
165impl Error {
166 fn new(index: usize, len: usize) -> Self {
167 Self { index, len }
168 }
169}
170
171impl From<Error> for ANNError {
172 #[track_caller]
173 fn from(error: Error) -> ANNError {
174 ANNError::opaque(error)
175 }
176}
177
/// Returned by `Strategy::length_compatible` when the number of provided
/// strategies differs from the number expected.
///
/// Derives `Copy`/`PartialEq`/`Eq` like the sibling indexing `Error`, so it
/// can be compared and passed around cheaply.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LengthIncompatible {
    strategies: usize,
    expected: usize,
}

impl LengthIncompatible {
    /// Number of strategies that were provided.
    pub fn strategies(&self) -> usize {
        self.strategies
    }

    /// Number of strategies that were expected.
    pub fn expected(&self) -> usize {
        self.expected
    }
}

impl std::fmt::Display for LengthIncompatible {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Choose the phrase that agrees in number with `count`; only exactly
        // one is singular ("0 strategies were", "1 strategy was").
        fn phrase(count: usize, singular: &'static str, plural: &'static str) -> &'static str {
            if count == 1 {
                singular
            } else {
                plural
            }
        }

        write!(
            f,
            "{} {} provided when {} {}",
            self.strategies,
            phrase(self.strategies, "strategy was", "strategies were"),
            self.expected,
            phrase(self.expected, "was expected", "were expected"),
        )
    }
}

impl std::error::Error for LengthIncompatible {}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal strategy payload used throughout these tests.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct TestStrategy(u32);

    /// A hand-written `Indexable` backed by a `Vec`. It exercises the
    /// `Strategy::Indexable` variant independently of the slice impls.
    #[derive(Debug)]
    struct CustomIndexable {
        strategies: Vec<TestStrategy>,
    }

    impl Indexable<TestStrategy> for CustomIndexable {
        fn len(&self) -> usize {
            self.strategies.len()
        }

        fn get(&self, index: usize) -> Result<&TestStrategy, Error> {
            get_as_slice(&self.strategies, index)
        }
    }

    #[test]
    fn test_strategy_broadcast() {
        let strategy = TestStrategy(42);
        let broadcast = Strategy::broadcast(strategy.clone());

        match &broadcast {
            Strategy::Broadcast(s) => assert_eq!(*s, strategy),
            _ => panic!("Expected Broadcast variant"),
        }

        // A broadcast strategy answers every index with the same value.
        for i in 0..10 {
            assert_eq!(broadcast.get(i).unwrap(), &strategy);
        }
    }

    #[test]
    fn test_strategy_collection() {
        let strategies = [TestStrategy(1), TestStrategy(2), TestStrategy(3)];
        let collection = Strategy::collection(strategies.clone());

        match &collection {
            Strategy::Collection(s) => assert_eq!(&s[..], &strategies[..]),
            _ => panic!("Expected Collection variant"),
        }

        assert_eq!(collection.get(0).unwrap(), &TestStrategy(1));
        assert_eq!(collection.get(1).unwrap(), &TestStrategy(2));
        assert_eq!(collection.get(2).unwrap(), &TestStrategy(3));

        let err = collection.get(3).unwrap_err();
        assert_eq!(err.index, 3);
        assert_eq!(err.len, 3);
    }

    #[test]
    fn test_strategy_collection_empty() {
        let collection = Strategy::<TestStrategy>::collection(vec![]);

        let err = collection.get(0).unwrap_err();
        assert_eq!(err.index, 0);
        assert_eq!(err.len, 0);
    }

    #[test]
    fn test_strategy_indexable() {
        let custom = CustomIndexable {
            strategies: vec![TestStrategy(100), TestStrategy(200)],
        };

        let strategy = Strategy::from_indexable(custom);
        assert!(matches!(strategy, Strategy::Indexable(_)));

        assert_eq!(strategy.get(0).unwrap(), &TestStrategy(100));
        assert_eq!(strategy.get(1).unwrap(), &TestStrategy(200));

        let err = strategy.get(5).unwrap_err();
        assert_eq!(err.index, 5);
        assert_eq!(err.len, 2);
    }

    #[test]
    fn test_indexable_arc_slice() {
        let strategies: Arc<[TestStrategy]> =
            Arc::from(vec![TestStrategy(1), TestStrategy(2), TestStrategy(3)]);

        assert_eq!(strategies.len(), 3);
        assert!(!strategies.is_empty());

        assert_eq!(strategies.get(0).unwrap(), &TestStrategy(1));
        assert_eq!(strategies.get(1).unwrap(), &TestStrategy(2));
        assert_eq!(strategies.get(2).unwrap(), &TestStrategy(3));

        assert!(strategies.get(10).is_err());
    }

    #[test]
    fn test_indexable_box_slice() {
        let strategies: Box<[TestStrategy]> =
            vec![TestStrategy(5), TestStrategy(10)].into_boxed_slice();

        assert_eq!(strategies.len(), 2);
        assert!(!strategies.is_empty());

        assert_eq!(strategies.get(0).unwrap(), &TestStrategy(5));
        assert_eq!(strategies.get(1).unwrap(), &TestStrategy(10));

        assert!(strategies.get(5).is_err());
    }

    #[test]
    fn test_indexable_is_empty() {
        let empty: Box<[TestStrategy]> = vec![].into_boxed_slice();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let non_empty: Box<[TestStrategy]> = vec![TestStrategy(1)].into_boxed_slice();
        assert!(!non_empty.is_empty());
        assert_eq!(non_empty.len(), 1);
    }

    #[test]
    fn test_error_display() {
        // Pins the exact user-facing message, including argument order.
        let error = Error::new(3, 2);
        assert_eq!(
            error.to_string(),
            "Tried to index a strategy collection of length 2 at index 3"
        );
    }

    #[test]
    fn test_error_to_ann_error() {
        let error = Error::new(3, 2);
        let ann_error: ANNError = error.into();

        let message = format!("{:?}", ann_error);
        assert!(!message.is_empty());
    }

    #[test]
    fn test_strategy_len() {
        let broadcast = Strategy::broadcast(TestStrategy(1));
        assert_eq!(broadcast.len(), None);
        assert!(!broadcast.is_empty());

        let collection =
            Strategy::collection(vec![TestStrategy(1), TestStrategy(2), TestStrategy(3)]);
        assert_eq!(collection.len(), Some(3));
        assert!(!collection.is_empty());

        let empty_collection = Strategy::<TestStrategy>::collection(vec![]);
        assert_eq!(empty_collection.len(), Some(0));
        assert!(empty_collection.is_empty());

        let custom = CustomIndexable {
            strategies: vec![TestStrategy(1), TestStrategy(2)],
        };
        let indexable = Strategy::from_indexable(custom);
        assert_eq!(indexable.len(), Some(2));
        assert!(!indexable.is_empty());

        let empty_custom = CustomIndexable { strategies: vec![] };
        let empty_indexable = Strategy::from_indexable(empty_custom);
        assert_eq!(empty_indexable.len(), Some(0));
        assert!(empty_indexable.is_empty());
    }

    #[test]
    fn test_length_compatible_broadcast() {
        let broadcast = Strategy::broadcast(1usize);
        assert!(broadcast.length_compatible(0).is_ok());
        assert!(broadcast.length_compatible(1).is_ok());
        assert!(broadcast.length_compatible(100).is_ok());
        assert!(broadcast.length_compatible(usize::MAX).is_ok());
    }

    #[test]
    fn test_length_compatible_collection() {
        let collection = Strategy::collection([1usize, 2, 3]);
        assert!(collection.length_compatible(3).is_ok());

        let err = collection.length_compatible(2).unwrap_err();
        assert_eq!(
            err.to_string(),
            "3 strategies were provided when 2 were expected"
        );

        let err = collection.length_compatible(5).unwrap_err();
        assert_eq!(
            err.to_string(),
            "3 strategies were provided when 5 were expected"
        );

        let single = Strategy::collection([1usize]);
        assert!(single.length_compatible(1).is_ok());

        let err = single.length_compatible(0).unwrap_err();
        assert_eq!(
            err.to_string(),
            "1 strategy was provided when 0 were expected"
        );

        let empty = Strategy::<usize>::collection([]);
        assert!(empty.length_compatible(0).is_ok());

        let err = empty.length_compatible(1).unwrap_err();
        assert_eq!(
            err.to_string(),
            "0 strategies were provided when 1 was expected"
        );
    }

    #[test]
    fn test_length_compatible_indexable() {
        let custom = CustomIndexable {
            strategies: vec![TestStrategy(1), TestStrategy(2)],
        };
        let indexable = Strategy::from_indexable(custom);
        assert!(indexable.length_compatible(2).is_ok());

        let err = indexable.length_compatible(1).unwrap_err();
        assert_eq!(
            err.to_string(),
            "2 strategies were provided when 1 was expected"
        );

        let err = indexable.length_compatible(10).unwrap_err();
        assert_eq!(
            err.to_string(),
            "2 strategies were provided when 10 were expected"
        );

        let empty_custom = CustomIndexable { strategies: vec![] };
        let empty_indexable = Strategy::from_indexable(empty_custom);
        assert!(empty_indexable.length_compatible(0).is_ok());

        let err = empty_indexable.length_compatible(5).unwrap_err();
        assert_eq!(
            err.to_string(),
            "0 strategies were provided when 5 were expected"
        );
    }
}