use std::{
fmt::Debug,
sync::atomic::{AtomicUsize, Ordering},
};
use diskann::graph::AdjacencyList;
use diskann_utils::future::AsyncFriendly;
use super::{
bf_cache::{AdjacencyListCacher, Cache, CacheableId},
error::CacheAccessError,
provider,
};
/// Shared hit/miss counters for a cache.
///
/// Both counters are atomics, so they can be bumped and read through `&self`
/// from any number of threads at once.
#[derive(Debug, Default)]
pub struct HitStats {
    hits: AtomicUsize,
    misses: AtomicUsize,
}

impl HitStats {
    /// Creates a tally with both counters set to zero.
    pub fn new() -> Self {
        Self {
            hits: AtomicUsize::new(0),
            misses: AtomicUsize::new(0),
        }
    }

    /// Adds `count` to the hit counter.
    pub fn hit(&self, count: usize) {
        Self::bump(&self.hits, count);
    }

    /// Adds `count` to the miss counter.
    pub fn miss(&self, count: usize) {
        Self::bump(&self.misses, count);
    }

    /// Returns the current value of the hit counter.
    pub fn get_hits(&self) -> usize {
        Self::read(&self.hits)
    }

    /// Returns the current value of the miss counter.
    pub fn get_misses(&self) -> usize {
        Self::read(&self.misses)
    }

    /// Atomically adds `count` to `counter`.
    ///
    /// `Relaxed` is enough here: each counter is an independent statistic and no
    /// other memory is published through it.
    fn bump(counter: &AtomicUsize, count: usize) {
        counter.fetch_add(count, Ordering::Relaxed);
    }

    /// Reads `counter` with the same `Relaxed` ordering used for updates.
    fn read(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::Relaxed)
    }
}
/// Per-owner hit/miss counters that are folded into a parent [`HitStats`] on drop.
///
/// `hit`/`miss` only touch plain `usize` fields. The accumulated totals are added to
/// `parent` once, when this value is dropped.
#[derive(Debug)]
pub struct LocalStats<'a> {
    parent: &'a HitStats,
    hits: usize,
    misses: usize,
}

impl<'a> LocalStats<'a> {
    /// Creates zeroed local counters that will report into `parent` when dropped.
    pub fn new(parent: &'a HitStats) -> Self {
        LocalStats {
            parent,
            hits: 0,
            misses: 0,
        }
    }

    /// Records a single hit locally.
    pub fn hit(&mut self) {
        self.hits += 1;
    }

    /// Records a single miss locally.
    pub fn miss(&mut self) {
        self.misses += 1;
    }

    /// Number of hits recorded by this instance (not yet flushed to the parent).
    pub fn get_local_hits(&self) -> usize {
        self.hits
    }

    /// Number of misses recorded by this instance (not yet flushed to the parent).
    pub fn get_local_misses(&self) -> usize {
        self.misses
    }
}

impl Drop for LocalStats<'_> {
    /// Flushes the local totals into the parent counters.
    fn drop(&mut self) {
        let (parent, hits, misses) = (self.parent, self.hits, self.misses);
        parent.hit(hits);
        parent.miss(misses);
    }
}
/// Maps a caller-level identifier of type `K` to the key that is stored in the cache.
pub trait KeyGen<K> {
    /// Key type produced by [`KeyGen::generate`]; it must be [`bytemuck::Pod`].
    type Key: bytemuck::Pod;

    /// Builds the cache key for `id`.
    fn generate(&self, id: K) -> Self::Key;
}
/// Composite cache key: an `id` qualified by a `tag`.
///
/// The struct is `#[repr(C, packed)]` so that its layout is fixed and has no padding
/// bytes, which `bytemuck::Pod` requires. Because it is packed, fields must only be
/// read by value; never take a reference to `id` or `tag`.
#[derive(Debug, Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C, packed)]
pub struct CacheKey<I> {
    id: I,
    tag: I,
}

impl<I: bytemuck::Pod> CacheKey<I> {
    /// Builds a key from an `id` and the `tag` that qualifies it.
    pub fn new(id: I, tag: I) -> Self {
        CacheKey { id, tag }
    }

    /// Returns (a copy of) the id component.
    pub fn id(&self) -> I {
        self.id
    }

    /// Returns (a copy of) the tag component.
    pub fn tag(&self) -> I {
        self.tag
    }
}
/// Key generator that pairs every id with one fixed tag value.
#[derive(Debug, Clone, Copy)]
pub struct Tag<I>(I);

impl<I> Tag<I> {
    /// Wraps `tag` as the value attached to every generated key.
    pub fn new(tag: I) -> Self {
        Tag(tag)
    }

    /// Returns the wrapped tag value, consuming `self`.
    pub fn tag(self) -> I {
        let Tag(inner) = self;
        inner
    }
}
impl<I> KeyGen<I> for Tag<I>
where
I: bytemuck::Pod,
{
type Key = CacheKey<I>;
fn generate(&self, key: I) -> CacheKey<I> {
CacheKey {
id: key,
tag: self.0,
}
}
}
/// Cache-backed view for reading and writing adjacency lists.
///
/// Keys are derived from ids through `keygen`. Hit/miss counts are accumulated
/// locally and reported to the shared [`HitStats`] when the view is dropped.
#[derive(Debug)]
pub struct Graph<'a, I, T> {
    cache: &'a Cache,
    stats: LocalStats<'a>,
    accessor: AdjacencyListCacher<I>,
    keygen: T,
}

impl<'a, I, T> Graph<'a, I, T>
where
    I: Default + Clone,
{
    /// Creates a view over `cache`.
    ///
    /// * `max_degree` is forwarded to `AdjacencyListCacher::new`.
    /// * `keygen` maps ids to cache keys.
    /// * `stats` receives this view's hit/miss totals when the view is dropped.
    pub fn new(cache: &'a Cache, max_degree: usize, keygen: T, stats: &'a HitStats) -> Self {
        let stats = LocalStats::new(stats);
        let accessor = AdjacencyListCacher::new(max_degree);
        Graph {
            cache,
            stats,
            accessor,
            keygen,
        }
    }
}
impl<'a, I, T> Graph<'a, I, T> {
    /// Returns the cache this view reads from and writes to.
    ///
    /// The reference carries the cache's full `'a` lifetime instead of being tied to
    /// the borrow of `self`. Callers may therefore keep it while later calling
    /// `&mut self` methods on the graph. Because `&'a Cache` coerces to any shorter
    /// lifetime, existing callers are unaffected.
    pub fn cache(&self) -> &'a Cache {
        self.cache
    }

    /// Returns this view's local (not yet flushed) hit/miss counters.
    ///
    /// The counters keep their `'a` parent lifetime rather than having it shortened
    /// to the borrow of `self`.
    pub fn stats(&self) -> &LocalStats<'a> {
        &self.stats
    }

    /// Returns this view's local hit/miss counters mutably.
    pub fn stats_mut(&mut self) -> &mut LocalStats<'a> {
        &mut self.stats
    }
}
impl<I, T> provider::NeighborCache<I> for Graph<'_, I, T>
where
    I: CacheableId,
    T: KeyGen<I> + AsyncFriendly,
{
    type Error = CacheAccessError;

    /// Loads the cached adjacency list for `id` into `neighbors`.
    ///
    /// Returns `Hit` or `Miss` depending on whether the cache held an entry for the
    /// generated key, and records the outcome in the local stats. Cache read
    /// failures are reported as `CacheAccessError::read` for `id`.
    fn try_get_neighbors(
        &mut self,
        id: I,
        neighbors: &mut AdjacencyList<I>,
    ) -> Result<provider::NeighborStatus, CacheAccessError> {
        let key = self.keygen.generate(id);
        let found = self
            .cache
            .get_into(key, &mut self.accessor, neighbors)
            .map_err(|err| CacheAccessError::read(id, err))?
            .into_inner();

        let status = if found {
            self.stats.hit();
            provider::NeighborStatus::Hit
        } else {
            self.stats.miss();
            provider::NeighborStatus::Miss
        };
        Ok(status)
    }

    /// Stores `neighbors` under the key generated for `id`.
    ///
    /// Cache write failures are reported as `CacheAccessError::write` for `id`.
    fn set_neighbors(&mut self, id: I, neighbors: &[I]) -> Result<(), CacheAccessError> {
        let key = self.keygen.generate(id);
        self.cache
            .set(key, &mut self.accessor, neighbors)
            .map_err(|err| CacheAccessError::write(id, err))
    }

    /// Removes the cache entry for the key generated for `id`.
    fn invalidate_neighbors(&mut self, id: I) {
        let key = self.keygen.generate(id);
        self.cache.delete(key);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    use diskann_quantization::num::PowerOfTwo;

    use crate::model::graph::provider::async_::caching::provider::NeighborCache;

    #[test]
    fn test_hit_stats() {
        let stats = HitStats::new();
        assert_eq!(stats.get_hits(), 0);
        assert_eq!(stats.get_misses(), 0);

        stats.hit(5);
        stats.miss(10);
        assert_eq!(stats.get_hits(), 5);
        assert_eq!(stats.get_misses(), 10);

        stats.hit(1);
        stats.miss(2);
        assert_eq!(stats.get_hits(), 6);
        assert_eq!(stats.get_misses(), 12);

        let hits = stats.get_hits();
        let misses = stats.get_misses();
        {
            let mut local = LocalStats::new(&stats);
            assert_eq!(local.get_local_hits(), 0);
            assert_eq!(local.get_local_misses(), 0);

            for _ in 0..5 {
                local.hit();
            }
            for _ in 0..10 {
                local.miss();
            }

            assert_eq!(local.get_local_hits(), 5);
            assert_eq!(local.get_local_misses(), 10);

            // Nothing reaches the parent until `local` is dropped.
            assert_eq!(local.parent.get_hits(), hits);
            assert_eq!(local.parent.get_misses(), misses);
        }
        assert_eq!(stats.get_hits(), hits + 5);
        assert_eq!(stats.get_misses(), misses + 10);
    }

    #[test]
    fn test_tag() {
        let tag0 = Tag::<usize>::new(0);
        let tag1 = Tag::<usize>::new(1);

        assert_eq!(tag0.tag(), 0);
        assert_eq!(tag1.tag(), 1);

        for k in 0..10 {
            let key = tag0.generate(k);
            assert_eq!(key.id(), k);
            assert_eq!(key.tag(), 0);

            let key = tag1.generate(k);
            assert_eq!(key.id(), k);
            assert_eq!(key.tag(), 1);
        }
    }

    #[test]
    fn test_graph() {
        let tag = 42;
        let cache = Cache::new(PowerOfTwo::new(128 * 1024).unwrap()).unwrap();
        let max_degree = 4;
        let keygen = Tag::<u32>::new(tag);
        let stats = HitStats::new();

        let mut graph = Graph::new(&cache, max_degree, keygen, &stats);
        assert!(
            std::ptr::eq(graph.cache(), &cache),
            "`cache()` should return the cache the graph was built over",
        );
        assert_eq!(graph.stats().get_local_hits(), 0);
        assert_eq!(graph.stats().get_local_misses(), 0);

        let mut a = AdjacencyList::new();
        let id = 90u32;
        assert_eq!(
            graph.try_get_neighbors(id, &mut a).unwrap(),
            provider::NeighborStatus::Miss,
            "`try_get_neighbors` should return `Miss` when the term does not exist",
        );

        graph.set_neighbors(id, &[1, 2, 3]).unwrap();
        assert_eq!(
            graph.try_get_neighbors(id, &mut a).unwrap(),
            provider::NeighborStatus::Hit,
            "`try_get_neighbors` should succeed when neighbors are present",
        );
        assert_eq!(&*a, &[1, 2, 3]);

        // Access the underlying cache directly: entries are stored under tagged keys.
        {
            let mut cacher = AdjacencyListCacher::<u32>::new(max_degree);
            let mut a = AdjacencyList::<u32>::new();
            assert!(
                !cache
                    .get_into(id, &mut cacher, &mut a)
                    .unwrap()
                    .into_inner(),
                "attempt to access via raw `id` should fail because keys are tagged"
            );

            assert!(
                cache
                    .get_into(CacheKey { id, tag }, &mut cacher, &mut a)
                    .unwrap()
                    .into_inner()
            );
            assert_eq!(&*a, &[1, 2, 3]);
        }

        graph.set_neighbors(id, &[]).unwrap();
        assert_eq!(
            graph.try_get_neighbors(id, &mut a).unwrap(),
            provider::NeighborStatus::Hit,
            "`try_get_neighbors` should return `Hit` for a cached empty neighbor list",
        );
        assert!(a.is_empty());

        graph.invalidate_neighbors(id);
        assert_eq!(
            graph.try_get_neighbors(id, &mut a).unwrap(),
            provider::NeighborStatus::Miss,
            "`try_get_neighbors` should return `Miss` after `invalidate_neighbors`",
        );

        // Four lookups went through `graph`: two hits and two misses. The raw
        // `cache.get_into` calls above bypass `graph` and are not counted.
        assert_eq!(graph.stats().get_local_hits(), 2);
        assert_eq!(graph.stats().get_local_misses(), 2);

        // Local counts reach the shared `HitStats` only when the graph is dropped.
        assert_eq!(stats.get_hits(), 0);
        assert_eq!(stats.get_misses(), 0);
        drop(graph);
        assert_eq!(stats.get_hits(), 2);
        assert_eq!(stats.get_misses(), 2);
    }
}