use std::{fmt::Debug, sync::Arc};
use diskann::ANNError;
use thiserror::Error;
/// A source of per-index strategies.
///
/// A strategy is either shared by every index (`Broadcast`), stored as an owned
/// fixed-length sequence (`Collection`), or supplied by a custom type-erased container
/// (`Indexable`).
#[derive(Debug)]
pub enum Strategy<S> {
    /// One strategy reused for every index.
    Broadcast(S),
    /// An owned sequence of strategies; entry `i` serves index `i`.
    Collection(Box<[S]>),
    /// A user-supplied container behind a trait object (the `'static` bound below is the
    /// default object lifetime for a boxed trait object, written out for clarity).
    Indexable(Box<dyn Indexable<S> + Send + Sync + 'static>),
}
impl<S> Strategy<S> {
    /// Creates a strategy that supplies the same `strategy` for every index.
    pub fn broadcast(strategy: S) -> Self {
        Self::Broadcast(strategy)
    }

    /// Creates a strategy from an ordered sequence: the `i`-th item serves index `i`.
    pub fn collection<I>(itr: I) -> Self
    where
        I: IntoIterator<Item = S>,
    {
        Self::Collection(itr.into_iter().collect())
    }

    /// Wraps a user-provided [`Indexable`] container.
    ///
    /// The container is boxed behind a trait object, so it must be
    /// `Send + Sync + 'static`.
    //
    // No `S: Debug` bound is required here: the `Debug` supertrait of `Indexable<S>`
    // applies to the container type itself, and `I: Indexable<S>` already guarantees it.
    // Dropping the bound only widens the set of accepted callers.
    pub fn from_indexable<I>(indexable: I) -> Self
    where
        I: Indexable<S> + Send + Sync + 'static,
    {
        Self::Indexable(Box::new(indexable))
    }

    /// Returns the strategy that applies to `index`.
    ///
    /// A [`Strategy::Broadcast`] returns its single strategy for every index.
    ///
    /// # Errors
    ///
    /// For [`Strategy::Collection`], returns an [`Error`] if `index` is out of bounds.
    /// For [`Strategy::Indexable`], forwards whatever the container's
    /// [`Indexable::get`] returns.
    pub fn get(&self, index: usize) -> Result<&S, Error> {
        match self {
            Self::Broadcast(strategy) => Ok(strategy),
            Self::Collection(strategies) => get_as_slice(strategies, index),
            Self::Indexable(indexable) => indexable.get(index),
        }
    }

    /// Returns the number of strategies held.
    ///
    /// Returns `None` for a [`Strategy::Broadcast`], which has no fixed length and
    /// serves any index.
    pub fn len(&self) -> Option<usize> {
        match self {
            Self::Broadcast(_) => None,
            Self::Collection(strategies) => Some(strategies.len()),
            Self::Indexable(indexable) => Some(indexable.len()),
        }
    }

    /// Returns `true` only when the strategy has a known length of zero.
    ///
    /// A [`Strategy::Broadcast`] is never considered empty.
    pub fn is_empty(&self) -> bool {
        self.len() == Some(0)
    }

    /// Checks that this strategy can serve exactly `expected` items.
    ///
    /// A [`Strategy::Broadcast`] is compatible with any `expected` length.
    ///
    /// # Errors
    ///
    /// Returns [`LengthIncompatible`] when the strategy has a known length that differs
    /// from `expected`.
    pub fn length_compatible(&self, expected: usize) -> Result<(), LengthIncompatible> {
        match self.len() {
            Some(strategies) if strategies != expected => {
                Err(LengthIncompatible {
                    strategies,
                    expected,
                })
            }
            // Either a broadcast (no length constraint) or an exact match.
            _ => Ok(()),
        }
    }
}
/// A random-access container of strategies that can back [`Strategy::Indexable`].
///
/// Implementors must themselves be `Debug` so that the type-erased container can be
/// printed as part of a [`Strategy`].
pub trait Indexable<S>: std::fmt::Debug {
    /// Returns the number of strategies held by the container.
    fn len(&self) -> usize;

    /// Returns `true` when the container holds no strategies.
    ///
    /// The default implementation compares [`Indexable::len`] against zero.
    fn is_empty(&self) -> bool {
        0 == self.len()
    }

    /// Returns the strategy stored at `index`.
    ///
    /// # Errors
    ///
    /// The slice-backed implementations in this module return an [`Error`] when
    /// `index` is out of bounds; custom implementations should do the same.
    fn get(&self, index: usize) -> Result<&S, Error>;
}
fn get_as_slice<T>(x: &[T], index: usize) -> Result<&T, Error> {
x.get(index).ok_or_else(|| Error::new(index, x.len()))
}
/// Shared, immutable slices can be used directly as an [`Indexable`] container.
impl<S> Indexable<S> for Arc<[S]>
where
    S: std::fmt::Debug,
{
    // Method-call syntax on `self` (`self.len()`) would resolve to this trait method and
    // recurse, so the data is first viewed as a plain slice via deref coercion.
    fn len(&self) -> usize {
        let slice: &[S] = self;
        slice.len()
    }

    fn get(&self, index: usize) -> Result<&S, Error> {
        let slice: &[S] = self;
        get_as_slice(slice, index)
    }
}
/// Owned boxed slices can be used directly as an [`Indexable`] container.
impl<S> Indexable<S> for Box<[S]>
where
    S: std::fmt::Debug,
{
    fn len(&self) -> usize {
        // Deref twice (`&Box<[S]>` -> `Box<[S]>` -> `[S]`) to reach the inherent slice
        // `len`; a plain `self.len()` would call this trait method again.
        (**self).len()
    }

    fn get(&self, index: usize) -> Result<&S, Error> {
        get_as_slice(&self[..], index)
    }
}
#[derive(Debug, Clone, Copy, Error)]
#[error("Tried to index a strategy collection of length {} at index {}", self.len, self.index)]
pub struct Error {
index: usize,
len: usize,
}
impl Error {
fn new(index: usize, len: usize) -> Self {
Self { index, len }
}
}
impl From<Error> for ANNError {
#[track_caller]
fn from(error: Error) -> ANNError {
ANNError::opaque(error)
}
}
/// Error returned by [`Strategy::length_compatible`] when a strategy with a known length
/// does not match the number of items it is expected to serve.
///
/// Like the indexing [`Error`], this is a small value type made of two counts, so it is
/// `Copy` as well as `Clone`.
#[derive(Debug, Clone, Copy)]
pub struct LengthIncompatible {
    /// Number of strategies that were provided.
    strategies: usize,
    /// Number of strategies that were expected.
    expected: usize,
}

impl std::fmt::Display for LengthIncompatible {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick grammatically correct wording for each count so the message reads
        // naturally, e.g. "1 strategy was provided when 3 were expected".
        let provided = if self.strategies == 1 {
            "strategy was"
        } else {
            "strategies were"
        };
        let wanted = if self.expected == 1 {
            "was expected"
        } else {
            "were expected"
        };
        write!(
            f,
            "{} {provided} provided when {} {wanted}",
            self.strategies, self.expected
        )
    }
}

impl std::error::Error for LengthIncompatible {}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, PartialEq, Eq)]
    struct TestStrategy(u32);

    /// A minimal user-defined container used to exercise `Strategy::Indexable`.
    #[derive(Debug)]
    struct CustomIndexable {
        strategies: Vec<TestStrategy>,
    }

    impl Indexable<TestStrategy> for CustomIndexable {
        fn len(&self) -> usize {
            self.strategies.len()
        }

        fn get(&self, index: usize) -> Result<&TestStrategy, Error> {
            get_as_slice(&self.strategies, index)
        }
    }

    #[test]
    fn test_strategy_broadcast() {
        let strategy = TestStrategy(42);
        let broadcast = Strategy::broadcast(strategy.clone());
        assert!(matches!(&broadcast, Strategy::Broadcast(s) if *s == strategy));
        // A broadcast strategy ignores the index entirely.
        for i in 0..10 {
            assert_eq!(broadcast.get(i).unwrap(), &strategy);
        }
        assert_eq!(broadcast.get(usize::MAX).unwrap(), &strategy);
    }

    #[test]
    fn test_strategy_collection() {
        let strategies = [TestStrategy(1), TestStrategy(2), TestStrategy(3)];
        let collection = Strategy::collection(strategies.clone());
        match &collection {
            Strategy::Collection(s) => assert_eq!(&s[..], &strategies[..]),
            _ => panic!("Expected Collection variant"),
        }
        for (i, expected) in strategies.iter().enumerate() {
            assert_eq!(collection.get(i).unwrap(), expected);
        }
        let err = collection.get(3).unwrap_err();
        assert_eq!(err.index, 3);
        assert_eq!(err.len, 3);
    }

    #[test]
    fn test_strategy_collection_empty() {
        let collection = Strategy::<TestStrategy>::collection(vec![]);
        let err = collection.get(0).unwrap_err();
        assert_eq!(err.index, 0);
        assert_eq!(err.len, 0);
    }

    #[test]
    fn test_strategy_indexable() {
        let custom = CustomIndexable {
            strategies: vec![TestStrategy(100), TestStrategy(200)],
        };
        let strategy = Strategy::from_indexable(custom);
        assert!(matches!(&strategy, Strategy::Indexable(_)));
        assert_eq!(strategy.get(0).unwrap(), &TestStrategy(100));
        assert_eq!(strategy.get(1).unwrap(), &TestStrategy(200));
        let err = strategy.get(5).unwrap_err();
        assert_eq!(err.index, 5);
        assert_eq!(err.len, 2);
    }

    #[test]
    fn test_strategy_from_arc_slice() {
        let shared: Arc<[TestStrategy]> = Arc::from(vec![TestStrategy(7), TestStrategy(8)]);
        let strategy = Strategy::from_indexable(Arc::clone(&shared));
        assert_eq!(strategy.len(), Some(2));
        assert_eq!(strategy.get(0).unwrap(), &TestStrategy(7));
        assert_eq!(strategy.get(1).unwrap(), &TestStrategy(8));
        let err = strategy.get(2).unwrap_err();
        assert_eq!(err.index, 2);
        assert_eq!(err.len, 2);
    }

    #[test]
    fn test_indexable_arc_slice() {
        let strategies: Arc<[TestStrategy]> =
            Arc::from(vec![TestStrategy(1), TestStrategy(2), TestStrategy(3)]);
        assert_eq!(strategies.len(), 3);
        assert!(!strategies.is_empty());
        assert_eq!(strategies.get(0).unwrap(), &TestStrategy(1));
        assert_eq!(strategies.get(1).unwrap(), &TestStrategy(2));
        assert_eq!(strategies.get(2).unwrap(), &TestStrategy(3));
        assert!(strategies.get(10).is_err());
    }

    #[test]
    fn test_indexable_box_slice() {
        let strategies: Box<[TestStrategy]> =
            vec![TestStrategy(5), TestStrategy(10)].into_boxed_slice();
        assert_eq!(strategies.len(), 2);
        assert!(!strategies.is_empty());
        assert_eq!(strategies.get(0).unwrap(), &TestStrategy(5));
        assert_eq!(strategies.get(1).unwrap(), &TestStrategy(10));
        assert!(strategies.get(5).is_err());
    }

    #[test]
    fn test_indexable_is_empty() {
        let empty: Box<[TestStrategy]> = vec![].into_boxed_slice();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        let non_empty: Box<[TestStrategy]> = vec![TestStrategy(1)].into_boxed_slice();
        assert!(!non_empty.is_empty());
        assert_eq!(non_empty.len(), 1);
    }

    #[test]
    fn test_error_display() {
        let error = Error::new(3, 2);
        assert_eq!(
            error.to_string(),
            "Tried to index a strategy collection of length 2 at index 3"
        );
    }

    #[test]
    fn test_error_to_ann_error() {
        let error = Error::new(3, 2);
        let ann_error: ANNError = error.into();
        let message = format!("{:?}", ann_error);
        assert!(!message.is_empty());
    }

    #[test]
    fn test_strategy_len() {
        let broadcast = Strategy::broadcast(TestStrategy(1));
        assert_eq!(broadcast.len(), None);
        assert!(!broadcast.is_empty());

        let collection =
            Strategy::collection(vec![TestStrategy(1), TestStrategy(2), TestStrategy(3)]);
        assert_eq!(collection.len(), Some(3));
        assert!(!collection.is_empty());

        let empty_collection = Strategy::<TestStrategy>::collection(vec![]);
        assert_eq!(empty_collection.len(), Some(0));
        assert!(empty_collection.is_empty());

        let custom = CustomIndexable {
            strategies: vec![TestStrategy(1), TestStrategy(2)],
        };
        let indexable = Strategy::from_indexable(custom);
        assert_eq!(indexable.len(), Some(2));
        assert!(!indexable.is_empty());

        let empty_custom = CustomIndexable { strategies: vec![] };
        let empty_indexable = Strategy::from_indexable(empty_custom);
        assert_eq!(empty_indexable.len(), Some(0));
        assert!(empty_indexable.is_empty());
    }

    #[test]
    fn test_length_compatible_broadcast() {
        let broadcast = Strategy::broadcast(1usize);
        assert!(broadcast.length_compatible(0).is_ok());
        assert!(broadcast.length_compatible(1).is_ok());
        assert!(broadcast.length_compatible(100).is_ok());
        assert!(broadcast.length_compatible(usize::MAX).is_ok());
    }

    #[test]
    fn test_length_compatible_collection() {
        let collection = Strategy::collection([1usize, 2, 3]);
        assert!(collection.length_compatible(3).is_ok());
        let err = collection.length_compatible(2).unwrap_err();
        assert_eq!(
            err.to_string(),
            "3 strategies were provided when 2 were expected"
        );
        let err = collection.length_compatible(5).unwrap_err();
        assert_eq!(
            err.to_string(),
            "3 strategies were provided when 5 were expected"
        );

        let single = Strategy::collection([1usize]);
        assert!(single.length_compatible(1).is_ok());
        let err = single.length_compatible(0).unwrap_err();
        assert_eq!(
            err.to_string(),
            "1 strategy was provided when 0 were expected"
        );

        let empty = Strategy::<usize>::collection([]);
        assert!(empty.length_compatible(0).is_ok());
        let err = empty.length_compatible(1).unwrap_err();
        assert_eq!(
            err.to_string(),
            "0 strategies were provided when 1 was expected"
        );
    }

    #[test]
    fn test_length_compatible_indexable() {
        let custom = CustomIndexable {
            strategies: vec![TestStrategy(1), TestStrategy(2)],
        };
        let indexable = Strategy::from_indexable(custom);
        assert!(indexable.length_compatible(2).is_ok());
        let err = indexable.length_compatible(1).unwrap_err();
        assert_eq!(
            err.to_string(),
            "2 strategies were provided when 1 was expected"
        );
        let err = indexable.length_compatible(10).unwrap_err();
        assert_eq!(
            err.to_string(),
            "2 strategies were provided when 10 were expected"
        );

        let empty_custom = CustomIndexable { strategies: vec![] };
        let empty_indexable = Strategy::from_indexable(empty_custom);
        assert!(empty_indexable.length_compatible(0).is_ok());
        let err = empty_indexable.length_compatible(5).unwrap_err();
        assert_eq!(
            err.to_string(),
            "0 strategies were provided when 5 were expected"
        );
    }
}