use std::fmt::Debug;
use std::hash::Hash;
use serde::{Serialize, de::DeserializeOwned};
/// Marker trait collecting every bound an identifier type must satisfy.
///
/// The trait declares no methods of its own; it only names the combination of
/// comparison, hashing, cloning, (de)serialization, debug-formatting and
/// thread-safety bounds so they can be written as a single `DataId` bound.
pub trait DataId:
    Clone + Debug + Eq + Ord + Hash + Serialize + DeserializeOwned + Send + Sync + 'static
{
}
/// Blanket implementation: every type that meets the full set of bounds is a
/// [`DataId`] automatically, so no type ever needs a hand-written impl.
impl<
        T: Clone + Debug + Eq + Ord + Hash + Serialize + DeserializeOwned + Send + Sync + 'static,
    > DataId for T
{
}
/// Read-only access to a collection of `Item` values addressed by `Id`.
///
/// Implementors must be `Send + Sync`.
pub trait DataLoader<Id, Item>: Send + Sync
where
    Id: DataId,
{
    /// Returns the identifiers this loader knows about.
    fn all_ids(&self) -> Vec<Id>;

    /// Looks up the items for `ids`.
    ///
    /// # Errors
    ///
    /// Fallible; the conditions under which an error is returned are defined
    /// by each implementation.
    fn fetch(&self, ids: &[Id]) -> crate::error::Result<Vec<Item>>;

    /// Returns the size reported by the loader (presumably its item count).
    fn len(&self) -> usize;

    /// Returns `true` exactly when [`len`](Self::len) reports zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
/// Extension of [`DataLoader`] for loaders that accept new items after
/// construction.
pub trait MutableDataLoader<Id, Item>: DataLoader<Id, Item>
where
    Id: DataId,
{
    /// Hands ownership of `items` to the loader so they become part of it.
    fn add_items(&mut self, items: Vec<Item>);
}
/// Loader that keeps all of its items in an in-memory `Vec<T>`.
#[derive(Debug, Clone)]
pub struct VecLoader<T> {
    /// Backing storage, in the order the items were supplied.
    items: Vec<T>,
}

impl<T> VecLoader<T> {
    /// Creates a loader that takes ownership of `items`.
    pub fn new(items: Vec<T>) -> Self {
        VecLoader { items }
    }

    /// Borrows the stored items as a slice, without copying.
    pub fn as_slice(&self) -> &[T] {
        self.items.as_slice()
    }
}
impl<T> DataLoader<usize, T> for VecLoader<T>
where
T: Clone + Send + Sync,
{
fn all_ids(&self) -> Vec<usize> {
(0..self.items.len()).collect()
}
fn fetch(&self, ids: &[usize]) -> crate::error::Result<Vec<T>> {
ids.iter()
.map(|&idx| {
self.items.get(idx).cloned().ok_or_else(|| {
crate::error::GEPAError::Config(format!(
"VecLoader: index {idx} out of range (len={})",
self.items.len()
))
})
})
.collect()
}
fn len(&self) -> usize {
self.items.len()
}
}
/// `VecLoader` grows by appending to the end of its backing vector.
impl<T: Clone + Send + Sync> MutableDataLoader<usize, T> for VecLoader<T> {
    /// Moves every element of `items` onto the end of the stored items,
    /// preserving their relative order.
    fn add_items(&mut self, items: Vec<T>) {
        let mut incoming = items;
        self.items.append(&mut incoming);
    }
}
/// Wraps a plain vector of items in a [`VecLoader`].
///
/// This is a thin convenience over [`VecLoader::new`]; the `Clone + Send + Sync`
/// bound on `T` is required here even though `VecLoader::new` itself does not
/// need it.
pub fn ensure_loader<T>(items: Vec<T>) -> VecLoader<T>
where
    T: Clone + Send + Sync,
{
    VecLoader::<T>::new(items)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Ids are the 0-based positions and `fetch` returns the items at those positions.
    #[test]
    fn vec_loader_all_ids_and_fetch() {
        let words = VecLoader::new(vec!["alpha", "beta", "gamma"]);

        let ids = words.all_ids();
        assert_eq!(ids, vec![0, 1, 2]);

        let picked = words.fetch(&[0, 2]).unwrap();
        assert_eq!(picked, vec!["alpha", "gamma"]);

        assert_eq!(words.len(), 3);
        assert!(!words.is_empty());
    }

    // Asking for a position past the end must yield an error, not a panic.
    #[test]
    fn vec_loader_out_of_range_returns_error() {
        let numbers = VecLoader::new(vec![10_i32, 20, 30]);
        let outcome = numbers.fetch(&[5]);
        assert!(outcome.is_err(), "expected an error for index 5");
    }

    // Newly added items land after the existing ones and get the next ids.
    #[test]
    fn mutable_loader_add_items() {
        let mut letters = VecLoader::new(vec!["a", "b"]);
        letters.add_items(vec!["c", "d"]);

        assert_eq!(letters.len(), 4);
        assert_eq!(letters.all_ids(), vec![0, 1, 2, 3]);

        let appended = letters.fetch(&[2, 3]).unwrap();
        assert_eq!(appended, vec!["c", "d"]);
    }

    // The convenience constructor produces a working loader.
    #[test]
    fn ensure_loader_wraps_vec() {
        let wrapped = ensure_loader(vec![1_u32, 2, 3]);
        assert_eq!(wrapped.len(), 3);

        let second = wrapped.fetch(&[1]).unwrap();
        assert_eq!(second, vec![2_u32]);
    }

    // Results follow the order of the requested ids, not storage order.
    #[test]
    fn vec_loader_preserves_order() {
        let numbers = VecLoader::new(vec![10_i32, 20, 30, 40, 50]);
        let picked = numbers.fetch(&[4, 2, 0]).unwrap();
        assert_eq!(picked, vec![50, 30, 10]);
    }
}