use super::config::{MAX_DELTA_CHAIN_DEPTH, MIN_CAPACITY_ESTIMATE};
use crate::core::error::{Result, VectorError};
use crate::core::hasher::IdentityHasher;
use crate::core::id::NodeId;
use crate::core::temporal::Timestamp;
use crate::index::vector::VectorIndex;
use crate::index::vector::hnsw::HnswIndex;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::hash::BuildHasherDefault;
use std::sync::Arc;
/// Build-hasher wrapper letting `HashMap`/`HashSet` keyed by `NodeId` use the
/// crate's `IdentityHasher` (presumably keys are already well-distributed ids
/// — TODO confirm against `IdentityHasher`'s contract).
type IdentityBuildHasher = BuildHasherDefault<IdentityHasher>;

/// Point-in-time view of the node-vector set, stored either as a complete
/// map or as a delta against an earlier snapshot in the history.
#[derive(Clone)]
pub(crate) enum VectorSnapshot {
    /// Complete `NodeId -> vector` map at this timestamp.
    Full(Arc<HashMap<NodeId, Arc<[f32]>, IdentityBuildHasher>>),
    /// Changes relative to the snapshot recorded at `base_time`.
    Delta {
        /// Timestamp keying the base snapshot this delta applies on top of.
        base_time: Timestamp,
        /// Vectors added or updated since `base_time`.
        added: Arc<HashMap<NodeId, Arc<[f32]>, IdentityBuildHasher>>,
        /// Nodes removed since `base_time`.
        removed: Arc<HashSet<NodeId, IdentityBuildHasher>>,
    },
}
impl VectorSnapshot {
    /// Resolve the vector for `node_id` as of this snapshot.
    ///
    /// Walks the delta chain (each `Delta` layer references an earlier
    /// snapshot via `base_time`) until a `Full` snapshot answers, bounded by
    /// `MAX_DELTA_CHAIN_DEPTH` to guard against corrupted chains.
    ///
    /// # Errors
    /// Returns `VectorError::IndexError` when the chain exceeds
    /// `MAX_DELTA_CHAIN_DEPTH` or a referenced base snapshot is missing
    /// from `all_snapshots`.
    pub(crate) fn get_vector(
        &self,
        node_id: &NodeId,
        all_snapshots: &BTreeMap<Timestamp, VectorSnapshot>,
    ) -> Result<Option<Arc<[f32]>>> {
        let mut current = self;
        let mut depth = 0;
        loop {
            if depth >= MAX_DELTA_CHAIN_DEPTH {
                return Err(VectorError::IndexError(format!(
                    "Delta chain depth exceeded {} for node {:?}. \
                     This indicates corrupted snapshot state or misconfiguration. \
                     Reduce full_snapshot_interval or check snapshot integrity.",
                    MAX_DELTA_CHAIN_DEPTH, node_id
                ))
                .into());
            }
            match current {
                VectorSnapshot::Full(vectors) => {
                    return Ok(vectors.get(node_id).cloned());
                }
                VectorSnapshot::Delta {
                    base_time,
                    added,
                    removed,
                } => {
                    // `added` must be consulted before `removed`: a node that
                    // was removed and re-added within the same delta window
                    // resolves to the new vector. This matches `to_hashmap`
                    // (removals replayed before additions within each layer)
                    // and `DeltaIndex::search`, which includes added-layer
                    // hits unconditionally.
                    if let Some(vec) = added.get(node_id) {
                        return Ok(Some(Arc::clone(vec)));
                    }
                    if removed.contains(node_id) {
                        return Ok(None);
                    }
                    if let Some(base) = all_snapshots.get(base_time) {
                        current = base;
                        depth += 1;
                    } else {
                        return Err(VectorError::IndexError(format!(
                            "Base snapshot at {} missing for delta snapshot",
                            base_time
                        ))
                        .into());
                    }
                }
            }
        }
    }

    /// Flatten this snapshot into a complete `NodeId -> vector` map.
    ///
    /// Walks down the delta chain recording each layer, stops at the first
    /// `Full` snapshot, then replays the layers oldest-first (`.rev()`),
    /// applying each layer's removals before its additions so a node removed
    /// and re-added in one layer keeps the new vector.
    ///
    /// # Errors
    /// `VectorError::IndexError` on an over-deep chain or a missing base
    /// snapshot.
    pub(crate) fn to_hashmap(
        &self,
        all_snapshots: &BTreeMap<Timestamp, VectorSnapshot>,
    ) -> Result<HashMap<NodeId, Arc<[f32]>, IdentityBuildHasher>> {
        // Borrowed views of one delta layer, recorded while walking the chain.
        struct DeltaLayer<'a> {
            added: &'a HashMap<NodeId, Arc<[f32]>, IdentityBuildHasher>,
            removed: &'a HashSet<NodeId, IdentityBuildHasher>,
        }
        let mut current = self;
        let mut delta_layers: Vec<DeltaLayer<'_>> = Vec::new();
        let mut depth = 0;
        let base_vectors: HashMap<NodeId, Arc<[f32]>, IdentityBuildHasher> = loop {
            if depth >= MAX_DELTA_CHAIN_DEPTH {
                return Err(VectorError::IndexError(format!(
                    "Delta chain depth exceeded {} in to_hashmap(). \
                     This indicates corrupted snapshot state or misconfiguration. \
                     Reduce full_snapshot_interval or check snapshot integrity.",
                    MAX_DELTA_CHAIN_DEPTH
                ))
                .into());
            }
            match current {
                VectorSnapshot::Full(vectors) => {
                    break vectors.as_ref().clone();
                }
                VectorSnapshot::Delta {
                    base_time,
                    added,
                    removed,
                } => {
                    delta_layers.push(DeltaLayer {
                        added: added.as_ref(),
                        removed: removed.as_ref(),
                    });
                    if let Some(base) = all_snapshots.get(base_time) {
                        current = base;
                        depth += 1;
                    } else {
                        return Err(VectorError::IndexError(format!(
                            "Base snapshot at {} missing during reconstruction",
                            base_time
                        ))
                        .into());
                    }
                }
            }
        };
        // Layers were pushed newest-first; replay oldest-first so later
        // layers overwrite earlier ones.
        let mut result = base_vectors;
        for layer in delta_layers.iter().rev() {
            for node_id in layer.removed.iter() {
                result.remove(node_id);
            }
            for (node_id, vector) in layer.added.iter() {
                result.insert(*node_id, Arc::clone(vector));
            }
        }
        Ok(result)
    }

    /// Materialize every `(NodeId, vector)` pair visible at this snapshot.
    pub(crate) fn collect_all(
        &self,
        all_snapshots: &BTreeMap<Timestamp, VectorSnapshot>,
    ) -> Result<Vec<(NodeId, Arc<[f32]>)>> {
        let map = self.to_hashmap(all_snapshots)?;
        Ok(map.into_iter().collect())
    }

    /// Approximate entry count used for capacity pre-sizing.
    ///
    /// NOTE(review): for `Delta` this is `added.len()` clamped up to
    /// `MIN_CAPACITY_ESTIMATE` — a capacity hint, not the true logical
    /// vector count (which would require resolving the base chain). Callers
    /// must not treat it as an exact length.
    pub(crate) fn len(&self) -> usize {
        match self {
            VectorSnapshot::Full(vectors) => vectors.len(),
            VectorSnapshot::Delta { added, .. } => {
                added.len().max(MIN_CAPACITY_ESTIMATE)
            }
        }
    }

    /// Alias of [`Self::collect_all`], kept for API symmetry.
    #[allow(dead_code)]
    pub(crate) fn to_vec(
        &self,
        all_snapshots: &BTreeMap<Timestamp, VectorSnapshot>,
    ) -> Result<Vec<(NodeId, Arc<[f32]>)>> {
        // Previously a verbatim copy of `collect_all`; delegate instead so
        // the two cannot drift apart.
        self.collect_all(all_snapshots)
    }
}
/// Searchable index snapshot: either a complete HNSW index or a delta
/// layered over an earlier `SnapshotIndex`.
#[derive(Clone)]
pub(crate) enum SnapshotIndex {
    /// Fully built HNSW index for this point in time.
    Full(Arc<HnswIndex>),
    /// Incremental additions/removals over a base `SnapshotIndex`.
    Delta(Arc<DeltaIndex>),
}
impl std::fmt::Debug for SnapshotIndex {
    /// Summarizes the snapshot (lengths/dimensions only) without formatting
    /// vector payloads.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SnapshotIndex::Full(index) => f
                .debug_struct("SnapshotIndex::Full")
                .field("len", &index.len())
                .field("dimensions", &index.dimensions())
                .finish(),
            SnapshotIndex::Delta(delta) => f
                .debug_struct("SnapshotIndex::Delta")
                .field("base_len", &delta.base.len())
                .field("added_len", &delta.added.len())
                .field("removed_len", &delta.removed.len())
                .field(
                    "total_len",
                    // Saturating: in a corrupted state `removed` could exceed
                    // base + added, and Debug output must never panic on
                    // underflow.
                    &(delta.base.len() + delta.added.len())
                        .saturating_sub(delta.removed.len()),
                )
                .finish(),
        }
    }
}
impl SnapshotIndex {
    /// k-nearest-neighbor search, dispatching to the full index or the
    /// delta-merge path.
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>> {
        match self {
            SnapshotIndex::Full(index) => index.search(query, k),
            SnapshotIndex::Delta(delta) => delta.search(query, k),
        }
    }

    /// Like [`Self::search`], but only nodes for which `predicate` returns
    /// `true` are eligible.
    pub fn search_with_filter(
        &self,
        query: &[f32],
        k: usize,
        predicate: &(dyn Fn(&NodeId) -> bool + Send + Sync),
    ) -> Result<Vec<(NodeId, f32)>> {
        match self {
            SnapshotIndex::Full(index) => index.search_with_filter(query, k, predicate),
            SnapshotIndex::Delta(delta) => delta.search_with_filter(query, k, predicate),
        }
    }

    /// Number of nodes visible through this snapshot.
    pub(crate) fn len(&self) -> usize {
        match self {
            SnapshotIndex::Full(index) => index.len(),
            SnapshotIndex::Delta(delta) => {
                // Saturating: guards against an underflow panic if a
                // corrupted delta tombstones more nodes than base + added
                // actually contain.
                (delta.base.len() + delta.added.len()).saturating_sub(delta.removed.len())
            }
        }
    }

    /// Vector dimensionality of this snapshot.
    ///
    /// NOTE(review): the `Delta` arm reports the added layer's dimensions;
    /// assumes the added index is always constructed with the same
    /// dimensionality as its base — TODO confirm at the construction site.
    pub(crate) fn dimensions(&self) -> usize {
        match self {
            SnapshotIndex::Full(index) => index.dimensions(),
            SnapshotIndex::Delta(delta) => delta.added.dimensions(),
        }
    }
}
/// Incremental index layer: searches merge a small `added` HNSW index with
/// the `base` snapshot while masking `removed` node ids.
pub(crate) struct DeltaIndex {
    /// The snapshot this delta applies on top of (may itself be a delta).
    pub(crate) base: Arc<SnapshotIndex>,
    /// Index over vectors added/updated since `base`.
    pub(crate) added: Arc<HnswIndex>,
    /// Node ids tombstoned since `base`.
    pub(crate) removed: Arc<HashSet<NodeId, IdentityBuildHasher>>,
}
impl std::fmt::Debug for DeltaIndex {
    /// Emits layer sizes and dimensionality only; vector data is omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = formatter.debug_struct("DeltaIndex");
        dbg.field("base", &self.base);
        dbg.field("added_len", &self.added.len());
        dbg.field("added_dimensions", &self.added.dimensions());
        dbg.field("removed_count", &self.removed.len());
        dbg.finish()
    }
}
impl DeltaIndex {
    /// Over-fetch budget: query both layers for more than `k` candidates so
    /// the merged, truncated result still holds `k` good hits.
    /// Saturating arithmetic avoids overflow for pathological `k`.
    fn over_fetch(k: usize) -> usize {
        k.saturating_mul(2).max(k.saturating_add(10))
    }

    /// Sort merged candidates by descending score and keep the top `k`.
    /// NaN scores compare as equal rather than panicking.
    fn rank_and_truncate(results: &mut Vec<(NodeId, f32)>, k: usize) {
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        results.truncate(k);
    }

    /// Search the delta: added-layer hits win; the base is consulted with
    /// tombstoned and already-answered ids masked out.
    fn search(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>> {
        let search_k = Self::over_fetch(k);
        let mut results = self.added.search(query, search_k)?;
        let removed = &self.removed;
        // Exclude from the base both tombstoned nodes and nodes already
        // answered by the (newer) added layer, so no dedup pass is needed.
        let added_ids: HashSet<NodeId> = results.iter().map(|(id, _)| *id).collect();
        let predicate = |id: &NodeId| !removed.contains(id) && !added_ids.contains(id);
        let base_results = self.base.search_with_filter(query, search_k, &predicate)?;
        results.extend(base_results);
        Self::rank_and_truncate(&mut results, k);
        Ok(results)
    }

    /// Filtered variant of [`Self::search`]: callers' `predicate` applies to
    /// both layers; the base additionally masks tombstoned ids.
    fn search_with_filter(
        &self,
        query: &[f32],
        k: usize,
        predicate: &(dyn Fn(&NodeId) -> bool + Send + Sync),
    ) -> Result<Vec<(NodeId, f32)>> {
        let search_k = Self::over_fetch(k);
        let removed = &self.removed;
        let combined_predicate = |id: &NodeId| predicate(id) && !removed.contains(id);
        let mut results = self.added.search_with_filter(query, search_k, predicate)?;
        let base_results = self
            .base
            .search_with_filter(query, search_k, &combined_predicate)?;
        results.extend(base_results);
        // Dedup by node id, keeping the first occurrence — added-layer hits
        // precede base hits, so the newer vector's score wins.
        let mut seen = HashSet::new();
        results.retain(|(id, _)| seen.insert(*id));
        Self::rank_and_truncate(&mut results, k);
        Ok(results)
    }
}
/// Time-ordered snapshot storage: searchable indexes plus the raw vector
/// history, keyed by the same timestamps.
pub(crate) struct SnapshotData {
    /// Per-timestamp `(snapshot id, index)` pairs.
    pub(crate) snapshots: BTreeMap<Timestamp, (usize, SnapshotIndex)>,
    /// Per-timestamp raw vector snapshots, kept in lockstep with `snapshots`.
    pub(crate) vector_history: BTreeMap<Timestamp, VectorSnapshot>,
}
impl SnapshotData {
pub(crate) fn new() -> Self {
Self {
snapshots: BTreeMap::new(),
vector_history: BTreeMap::new(),
}
}
pub(crate) fn insert(
&mut self,
timestamp: Timestamp,
id: usize,
index: SnapshotIndex,
vectors: VectorSnapshot,
) {
self.snapshots.insert(timestamp, (id, index));
self.vector_history.insert(timestamp, vectors);
}
pub(crate) fn len(&self) -> usize {
self.snapshots.len()
}
pub(crate) fn remove_oldest(&mut self) {
if let Some(key) = self.snapshots.keys().next().copied() {
self.snapshots.remove(&key);
self.vector_history.remove(&key);
}
}
}
/// Bookkeeping that drives the snapshot scheduler: change/transaction
/// counters since the last (full) snapshot.
pub(crate) struct SnapshotMetadata {
    /// Total snapshots taken so far.
    pub(crate) total_snapshots: usize,
    /// Transactions recorded since the last snapshot of any kind.
    pub(crate) transactions_since_snapshot: usize,
    /// Timestamp of the most recent snapshot (full or delta).
    pub(crate) last_snapshot_time: Timestamp,
    /// Nodes whose vectors changed since the last snapshot.
    pub(crate) vectors_changed_since_snapshot: HashSet<NodeId, IdentityBuildHasher>,
    /// Timestamp of the most recent *full* snapshot.
    pub(crate) last_full_snapshot_time: Timestamp,
    /// Delta snapshots taken since the last full snapshot.
    pub(crate) snapshots_since_full: usize,
    /// Nodes changed since the last *full* snapshot (superset of
    /// `vectors_changed_since_snapshot`).
    pub(crate) changes_accumulated: HashSet<NodeId, IdentityBuildHasher>,
}
impl SnapshotMetadata {
pub(crate) fn new(initial_time: Timestamp) -> Self {
Self {
total_snapshots: 0,
transactions_since_snapshot: 0,
last_snapshot_time: initial_time,
vectors_changed_since_snapshot: HashSet::with_hasher(BuildHasherDefault::default()),
last_full_snapshot_time: initial_time,
snapshots_since_full: 0,
changes_accumulated: HashSet::with_hasher(BuildHasherDefault::default()),
}
}
pub(crate) fn record_change(&mut self, id: NodeId) {
self.vectors_changed_since_snapshot.insert(id);
self.changes_accumulated.insert(id);
}
pub(crate) fn record_transaction(&mut self) {
self.transactions_since_snapshot += 1;
}
pub(crate) fn reset(&mut self, current_time: Timestamp, is_full: bool) {
self.transactions_since_snapshot = 0;
self.last_snapshot_time = current_time;
self.vectors_changed_since_snapshot.clear();
self.total_snapshots += 1;
if is_full {
self.last_full_snapshot_time = current_time;
self.snapshots_since_full = 0;
self.changes_accumulated.clear();
} else {
self.snapshots_since_full += 1;
}
}
}
/// Live (un-snapshotted) vector state: the current vectors plus the
/// scheduling metadata for the next snapshot.
pub(crate) struct VectorState {
    /// Current `NodeId -> vector` map.
    pub(crate) vectors: HashMap<NodeId, Arc<[f32]>, IdentityBuildHasher>,
    /// Snapshot-scheduling counters and change tracking.
    pub(crate) metadata: SnapshotMetadata,
}
impl VectorState {
pub(crate) fn new(initial_time: Timestamp) -> Self {
Self {
vectors: HashMap::with_hasher(BuildHasherDefault::default()),
metadata: SnapshotMetadata::new(initial_time),
}
}
}