use crate::core::snapshot::reader::SnapshotData;
use crate::core::wal::record::{
build_del_node_vector_payload, build_set_node_vector_payload, WalRecord,
};
use crate::error::{KiteError, Result};
use crate::types::*;
use crate::vector::store::{
create_vector_store, vector_store_delete, vector_store_get, vector_store_has, vector_store_insert,
};
use crate::vector::types::{VectorManifest, VectorStoreConfig};
use std::collections::HashMap;
use std::sync::Arc;
use super::SingleFileDB;
impl SingleFileDB {
    /// Stages `vector` as the value of vector property `prop_key_id` on node
    /// `node_id` in the current write transaction.
    ///
    /// A `SetNodeVector` WAL record is written first. Only if that succeeds is
    /// the vector recorded in the transaction's pending-vector map, replacing
    /// any earlier pending set/delete for the same `(node_id, prop_key_id)`.
    ///
    /// The expected dimensionality for `prop_key_id` is taken, in order, from:
    /// 1. the existing vector store for the property, if there is one;
    /// 2. otherwise, any vector for the property staged in this transaction on a
    ///    different node;
    /// 3. otherwise, any such vector staged in the shared delta.
    ///
    /// Steps 2-3 matter because `apply_pending_vectors` creates a missing store
    /// from the first staged vector it encounters and ignores insert failures.
    /// A mismatched vector accepted here would therefore be dropped later
    /// without any error. Steps 2-3 scan the pending maps, so they cost
    /// O(pending vectors) while no store exists for the property.
    ///
    /// # Errors
    ///
    /// - Any error from `require_write_tx_handle` (presumably: no write
    ///   transaction is active).
    /// - `KiteError::VectorDimensionMismatch` if `vector.len()` disagrees with
    ///   the expected dimensionality. Nothing is written to the WAL in this case.
    /// - Any error from writing the WAL record.
    pub fn set_node_vector(
        &self,
        node_id: NodeId,
        prop_key_id: PropKeyId,
        vector: &[f32],
    ) -> Result<()> {
        let (txid, tx_handle) = self.require_write_tx_handle()?;

        // Each lookup takes and releases its own lock before the next one runs,
        // so none of these locks are held at the same time here.
        let store_dims = self
            .vector_stores
            .read()
            .get(&prop_key_id)
            .map(|store| store.config.dimensions);
        let expected_dims = store_dims
            .or_else(|| {
                Self::staged_vector_len(
                    &tx_handle.lock().pending.pending_vectors,
                    node_id,
                    prop_key_id,
                )
            })
            .or_else(|| {
                Self::staged_vector_len(&self.delta.read().pending_vectors, node_id, prop_key_id)
            });
        if let Some(expected) = expected_dims {
            if expected != vector.len() {
                return Err(KiteError::VectorDimensionMismatch {
                    expected,
                    got: vector.len(),
                });
            }
        }

        let record = WalRecord::new(
            WalRecordType::SetNodeVector,
            txid,
            build_set_node_vector_payload(node_id, prop_key_id, vector),
        );
        self.write_wal_tx(&tx_handle, record)?;

        let mut tx = tx_handle.lock();
        tx.pending.pending_vectors.insert(
            (node_id, prop_key_id),
            Some(VectorRef::from(vector.to_vec())),
        );
        Ok(())
    }

    /// Returns the length of some vector staged for `prop_key_id` in `pending`
    /// on a node other than `node_id`, or `None` if there is none.
    ///
    /// `node_id`'s own entry is skipped. In the transaction map, a new set for
    /// that node overwrites it. Presumably the committed value likewise
    /// supersedes the delta's entry (confirm against the commit path).
    fn staged_vector_len<'a>(
        pending: impl IntoIterator<Item = (&'a (NodeId, PropKeyId), &'a Option<VectorRef>)>,
        node_id: NodeId,
        prop_key_id: PropKeyId,
    ) -> Option<usize> {
        pending
            .into_iter()
            .find_map(|(&(staged_node, staged_key), staged)| match staged {
                Some(staged_vector) if staged_key == prop_key_id && staged_node != node_id => {
                    Some(staged_vector.len())
                }
                _ => None,
            })
    }

    /// Stages deletion of vector property `prop_key_id` on node `node_id` in
    /// the current write transaction.
    ///
    /// A `DelNodeVector` WAL record is written first. Only if that succeeds is
    /// a `None` entry recorded in the pending-vector map, which overrides any
    /// earlier pending set for the same key.
    ///
    /// # Errors
    ///
    /// - Any error from `require_write_tx_handle` (presumably: no write
    ///   transaction is active).
    /// - Any error from writing the WAL record.
    pub fn delete_node_vector(&self, node_id: NodeId, prop_key_id: PropKeyId) -> Result<()> {
        let (txid, tx_handle) = self.require_write_tx_handle()?;
        let record = WalRecord::new(
            WalRecordType::DelNodeVector,
            txid,
            build_del_node_vector_payload(node_id, prop_key_id),
        );
        self.write_wal_tx(&tx_handle, record)?;

        let mut tx = tx_handle.lock();
        tx.pending
            .pending_vectors
            .insert((node_id, prop_key_id), None);
        Ok(())
    }

    /// Returns the vector for property `prop_key_id` on node `node_id` as seen
    /// by the caller.
    ///
    /// Layers are consulted in order:
    /// 1. the current transaction's pending state, if a transaction is active;
    /// 2. the shared delta;
    /// 3. the vector store.
    ///
    /// The first layer that marks the node deleted, or has a pending entry for
    /// the key, decides the result. A pending delete yields `None`.
    ///
    /// NOTE(review): staged vectors are returned exactly as they were passed
    /// in. The persistence test expects values read back after reopen to equal
    /// `normalize(input)`. Confirm callers tolerate the difference.
    pub fn node_vector(&self, node_id: NodeId, prop_key_id: PropKeyId) -> Option<VectorRef> {
        let tx_handle = self.current_tx_handle();
        if let Some(handle) = tx_handle.as_ref() {
            let tx = handle.lock();
            if tx.pending.is_node_deleted(node_id) {
                return None;
            }
            if let Some(pending) = tx.pending.pending_vectors.get(&(node_id, prop_key_id)) {
                return pending.as_ref().map(Arc::clone);
            }
        }

        // The delta guard stays held while the store is read, so the delta and
        // the store are observed together.
        let delta = self.delta.read();
        if delta.is_node_deleted(node_id) {
            return None;
        }
        if let Some(pending) = delta.pending_vectors.get(&(node_id, prop_key_id)) {
            return pending.as_ref().map(Arc::clone);
        }
        let stores = self.vector_stores.read();
        let store = stores.get(&prop_key_id)?;
        vector_store_get(store, node_id).map(Arc::from)
    }

    /// Reports whether node `node_id` has a vector for property `prop_key_id`.
    ///
    /// Uses the same layer precedence as `node_vector` (transaction, then
    /// delta, then store). For the store layer it calls `vector_store_has`
    /// instead of fetching the vector.
    pub fn has_node_vector(&self, node_id: NodeId, prop_key_id: PropKeyId) -> bool {
        let tx_handle = self.current_tx_handle();
        if let Some(handle) = tx_handle.as_ref() {
            let tx = handle.lock();
            if tx.pending.is_node_deleted(node_id) {
                return false;
            }
            if let Some(pending) = tx.pending.pending_vectors.get(&(node_id, prop_key_id)) {
                return pending.is_some();
            }
        }

        let delta = self.delta.read();
        if delta.is_node_deleted(node_id) {
            return false;
        }
        if let Some(pending) = delta.pending_vectors.get(&(node_id, prop_key_id)) {
            return pending.is_some();
        }
        let stores = self.vector_stores.read();
        stores
            .get(&prop_key_id)
            .is_some_and(|store| vector_store_has(store, node_id))
    }

    /// Ensures a vector store with `dimensions` exists for `prop_key_id`,
    /// creating an empty one if there is none.
    ///
    /// The store map is changed directly under its write lock. No WAL record
    /// is written here.
    ///
    /// # Errors
    ///
    /// `KiteError::VectorDimensionMismatch` if a store already exists with a
    /// different dimensionality.
    pub fn vector_store_or_create(&self, prop_key_id: PropKeyId, dimensions: usize) -> Result<()> {
        let mut stores = self.vector_stores.write();
        if let Some(store) = stores.get(&prop_key_id) {
            if store.config.dimensions != dimensions {
                return Err(KiteError::VectorDimensionMismatch {
                    expected: store.config.dimensions,
                    got: dimensions,
                });
            }
            return Ok(());
        }
        stores.insert(
            prop_key_id,
            create_vector_store(VectorStoreConfig::new(dimensions)),
        );
        Ok(())
    }

    /// Applies staged vector operations to the vector stores under a single
    /// write lock.
    ///
    /// - `Some(vector)`: inserts the vector. A missing store is first created
    ///   with `vector.len()` dimensions.
    /// - `None`: deletes the node's vector. A no-op if the property has no
    ///   store.
    pub(crate) fn apply_pending_vectors(
        &self,
        pending_vectors: &HashMap<(NodeId, PropKeyId), Option<VectorRef>>,
    ) {
        let mut stores = self.vector_stores.write();
        for (&(node_id, prop_key_id), operation) in pending_vectors {
            match operation {
                Some(vector) => {
                    let store = stores.entry(prop_key_id).or_insert_with(|| {
                        create_vector_store(VectorStoreConfig::new(vector.len()))
                    });
                    // This method has no error channel, so insert failures are
                    // deliberately ignored. `set_node_vector` rejects vectors
                    // whose length disagrees with the property's established
                    // dimensionality before they are staged.
                    let _ = vector_store_insert(store, node_id, vector.as_ref());
                }
                None => {
                    if let Some(store) = stores.get_mut(&prop_key_id) {
                        vector_store_delete(store, node_id);
                    }
                }
            }
        }
    }
}
/// Rebuilds the per-property vector stores from the `VectorF32` properties
/// stored in `snapshot`.
///
/// - Returns an empty map when the snapshot header lacks
///   `SnapshotFlags::HAS_VECTORS`.
/// - Physical node slots without a node id or without properties are skipped.
/// - The first vector seen for a property key decides that store's
///   dimensionality.
///
/// # Errors
///
/// `KiteError::InvalidSnapshot` when a later vector for the same key has a
/// different length, or when inserting a vector into its store fails.
pub(crate) fn vector_stores_from_snapshot(
    snapshot: &SnapshotData,
) -> Result<HashMap<PropKeyId, VectorManifest>> {
    let mut manifests: HashMap<PropKeyId, VectorManifest> = HashMap::new();
    let has_vectors = snapshot.header.flags.contains(SnapshotFlags::HAS_VECTORS);

    if has_vectors {
        let node_count = snapshot.header.num_nodes as usize;
        for phys in (0..node_count).map(|idx| idx as u32) {
            let Some(node_id) = snapshot.node_id(phys) else {
                continue;
            };
            let Some(props) = snapshot.node_props(phys) else {
                continue;
            };

            for (key_id, value) in props {
                let PropValue::VectorF32(vec) = value else {
                    continue;
                };
                let dims = vec.len();
                let manifest = manifests
                    .entry(key_id)
                    .or_insert_with(|| create_vector_store(VectorStoreConfig::new(dims)));

                if manifest.config.dimensions != dims {
                    return Err(KiteError::InvalidSnapshot(format!(
                        "Vector dimension mismatch for prop key {key_id}: expected {}, got {}",
                        manifest.config.dimensions, dims
                    )));
                }

                vector_store_insert(manifest, node_id, &vec).map_err(|e| {
                    KiteError::InvalidSnapshot(format!(
                        "Failed to insert vector for node {node_id} (prop {key_id}): {e}"
                    ))
                })?;
            }
        }
    }

    Ok(manifests)
}
#[cfg(test)]
mod tests {
    use crate::core::single_file::{close_single_file, open_single_file, SingleFileOpenOptions};
    use crate::vector::distance::normalize;
    use tempfile::tempdir;

    /// A vector committed and checkpointed in one session must be readable
    /// after the database is closed and reopened.
    ///
    /// The reopened value is compared with `normalize` of the input, so the
    /// test expects the persisted vector to come back normalized.
    #[test]
    fn test_vector_persistence_across_checkpoint() {
        let input: [f32; 3] = [0.1, 0.2, 0.3];
        let dir = tempdir().unwrap();
        let path = dir.path().join("vectors.kitedb");

        // Session 1: write the vector, commit, checkpoint, close.
        let writer = open_single_file(&path, SingleFileOpenOptions::new()).unwrap();
        writer.begin(false).unwrap();
        let node_id = writer.create_node(None).unwrap();
        let prop_key_id = writer.define_propkey("embedding").unwrap();
        writer.set_node_vector(node_id, prop_key_id, &input).unwrap();
        writer.commit().unwrap();
        writer.checkpoint().unwrap();
        close_single_file(writer).unwrap();

        // Session 2: reopen and compare element-wise within a small tolerance.
        let reader = open_single_file(&path, SingleFileOpenOptions::new()).unwrap();
        let stored = reader.node_vector(node_id, prop_key_id).unwrap();
        let want = normalize(&input);
        assert_eq!(stored.len(), want.len());
        stored
            .iter()
            .zip(want.iter())
            .for_each(|(got, exp)| assert!((got - exp).abs() < 1e-6));
        close_single_file(reader).unwrap();
    }
}