use super::{Key, Value, ValueData, LSMConfig};
use crate::index::diskann::fresh_graph::{FreshVamanaGraph, FreshGraphConfig, VectorNode};
use crate::distance::DistanceKind;
use crate::{Result, StorageError};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use parking_lot::RwLock;
use std::collections::BTreeMap;
/// A single record stored in the unified memtable: the value payload, an
/// optional embedding vector, the write timestamp, and a tombstone flag.
#[derive(Clone, Debug)]
pub struct UnifiedEntry {
    /// Value payload — either inline bytes or a reference to an external blob.
    pub data: ValueData,
    /// Optional embedding vector attached to this key (present only for
    /// entries written through `put_with_vector`).
    pub vector: Option<Vec<f32>>,
    /// Logical write timestamp supplied by the caller.
    pub timestamp: u64,
    /// `true` when this entry is a deletion marker (tombstone).
    pub deleted: bool,
}
impl UnifiedEntry {
    /// Builds a live (non-deleted) entry from its payload, optional vector,
    /// and timestamp.
    pub fn new(data: ValueData, vector: Option<Vec<f32>>, timestamp: u64) -> Self {
        Self {
            data,
            vector,
            timestamp,
            deleted: false,
        }
    }

    /// Builds a deletion marker carrying only the timestamp; the payload is
    /// an empty inline value and no vector is attached.
    pub fn tombstone(timestamp: u64) -> Self {
        Self {
            data: ValueData::Inline(Vec::new()),
            vector: None,
            timestamp,
            deleted: true,
        }
    }

    /// Approximate in-memory footprint of this entry, in bytes, used for
    /// memtable size accounting.
    pub fn memory_size(&self) -> usize {
        let payload_bytes = match &self.data {
            ValueData::Inline(bytes) => bytes.len(),
            // Blob payloads live elsewhere; charge a fixed handle cost.
            ValueData::Blob(_) => 16,
        };
        let vector_bytes = self
            .vector
            .as_ref()
            .map_or(0, |v| v.len() * std::mem::size_of::<f32>());
        // Constant per-entry bookkeeping overhead.
        payload_bytes + vector_bytes + 16
    }
}
/// In-memory write buffer for the LSM tree that can optionally maintain a
/// fresh Vamana graph so newly written vectors are immediately searchable.
pub struct UnifiedMemTable {
    /// Ordered key → entry map; shared so iterators can snapshot it.
    entries: Arc<RwLock<BTreeMap<Key, UnifiedEntry>>>,
    /// In-memory vector index; `None` when created without vector support.
    vector_graph: Option<Arc<FreshVamanaGraph>>,
    /// Expected embedding dimension, enforced on `put_with_vector`.
    vector_dimension: Option<usize>,
    /// Approximate total memory footprint of all entries, in bytes.
    size: AtomicUsize,
    /// Flush threshold in bytes (from `LSMConfig::memtable_size`).
    max_size: usize,
    /// Monotonic write counter, bumped on every put/batch_put.
    next_seq: AtomicUsize,
}
impl UnifiedMemTable {
    /// Creates a plain key/value memtable with no vector index attached.
    pub fn new(config: &LSMConfig) -> Self {
        Self {
            entries: Arc::new(RwLock::new(BTreeMap::new())),
            vector_graph: None,
            vector_dimension: None,
            size: AtomicUsize::new(0),
            max_size: config.memtable_size,
            next_seq: AtomicUsize::new(0),
        }
    }

    /// Creates a memtable that also maintains an in-memory Vamana graph so
    /// vectors written here are immediately searchable.
    ///
    /// `dimension` is enforced on every `put_with_vector` call. The graph
    /// parameters below are fixed memtable-scale defaults.
    pub fn new_with_vector_support(config: &LSMConfig, dimension: usize) -> Self {
        let fresh_config = FreshGraphConfig {
            max_nodes: 5000,
            max_degree: 32,
            search_list_size: 64,
            alpha: 1.2,
            memory_threshold: 20 * 1024 * 1024,
        };
        let metric = DistanceKind::Cosine;
        let vector_graph = FreshVamanaGraph::new(fresh_config, metric);
        Self {
            entries: Arc::new(RwLock::new(BTreeMap::new())),
            vector_graph: Some(Arc::new(vector_graph)),
            vector_dimension: Some(dimension),
            size: AtomicUsize::new(0),
            max_size: config.memtable_size,
            next_seq: AtomicUsize::new(0),
        }
    }

    /// Inserts or overwrites a plain key/value pair (no vector payload).
    pub fn put(&self, key: Key, value: Value) -> Result<()> {
        let entry = UnifiedEntry {
            data: value.data,
            vector: None,
            timestamp: value.timestamp,
            deleted: value.deleted,
        };
        self.put_unified(key, entry)
    }

    /// Inserts a key/value pair together with its embedding vector.
    ///
    /// # Errors
    /// Returns `StorageError::InvalidData` when the vector's length does not
    /// match the dimension this memtable was created with, or propagates a
    /// graph insertion failure.
    pub fn put_with_vector(&self, key: Key, data: ValueData, vector: Vec<f32>, timestamp: u64) -> Result<()> {
        if let Some(expected_dim) = self.vector_dimension {
            if vector.len() != expected_dim {
                return Err(StorageError::InvalidData(
                    format!("Vector dimension mismatch: expected {}, got {}", expected_dim, vector.len())
                ));
            }
        }
        // Register the vector in the graph *before* the entry becomes
        // visible: if the graph insert fails, the memtable is unchanged
        // (previously a failed insert left an entry invisible to vector
        // search). The clone happens only when a graph actually exists.
        if let Some(ref graph) = self.vector_graph {
            graph.insert(key, vector.clone())?;
        }
        let entry = UnifiedEntry::new(data, Some(vector), timestamp);
        self.put_unified(key, entry)
    }

    /// Shared insert path: swaps the entry in under the write lock and keeps
    /// the approximate size counter and sequence counter up to date.
    fn put_unified(&self, key: Key, entry: UnifiedEntry) -> Result<()> {
        let entry_size = entry.memory_size();
        let mut entries = self.entries.write();
        if let Some(old_entry) = entries.get(&key) {
            // Overwrite: release the old entry's accounted bytes first.
            let old_size = old_entry.memory_size();
            self.size.fetch_sub(old_size, Ordering::Relaxed);
        }
        entries.insert(key, entry);
        self.size.fetch_add(entry_size, Ordering::Relaxed);
        self.next_seq.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }

    /// Inserts many key/value pairs under a single write-lock acquisition.
    /// Size accounting is folded into one signed delta and applied once.
    pub fn batch_put(&self, kvs: &[(Key, Value)]) -> Result<()> {
        if kvs.is_empty() {
            return Ok(());
        }
        let mut entries = self.entries.write();
        let mut total_size_change: i64 = 0;
        for (key, value) in kvs {
            let entry = UnifiedEntry {
                data: value.data.clone(),
                vector: None,
                timestamp: value.timestamp,
                deleted: value.deleted,
            };
            let entry_size = entry.memory_size();
            if let Some(old_entry) = entries.get(key) {
                let old_size = old_entry.memory_size();
                total_size_change -= old_size as i64;
            }
            entries.insert(*key, entry);
            total_size_change += entry_size as i64;
        }
        if total_size_change > 0 {
            self.size.fetch_add(total_size_change as usize, Ordering::Relaxed);
        } else if total_size_change < 0 {
            self.size.fetch_sub((-total_size_change) as usize, Ordering::Relaxed);
        }
        self.next_seq.fetch_add(kvs.len(), Ordering::Relaxed);
        Ok(())
    }

    /// Returns a clone of the entry for `key`, or `None` when absent.
    /// Tombstones are returned as-is; callers check `deleted`.
    pub fn get(&self, key: Key) -> Result<Option<UnifiedEntry>> {
        let entries = self.entries.read();
        Ok(entries.get(&key).cloned())
    }

    /// Records a deletion by inserting a tombstone entry for `key`.
    /// NOTE(review): the vector graph is not updated here; `vector_search`
    /// compensates by filtering out deleted entries.
    pub fn delete(&self, key: Key, timestamp: u64) -> Result<()> {
        let entry = UnifiedEntry::tombstone(timestamp);
        // Compute the size up front so the entry can be moved (no clone).
        let entry_size = entry.memory_size();
        let mut entries = self.entries.write();
        if let Some(old_entry) = entries.get(&key) {
            let old_size = old_entry.memory_size();
            self.size.fetch_sub(old_size, Ordering::Relaxed);
        }
        entries.insert(key, entry);
        self.size.fetch_add(entry_size, Ordering::Relaxed);
        Ok(())
    }

    /// True when the accounted size has reached the flush threshold.
    pub fn should_flush(&self) -> bool {
        self.size.load(Ordering::Relaxed) >= self.max_size
    }

    /// Lock-free variant of `should_flush` (same single atomic load);
    /// kept as a separate entry point for callers on hot paths.
    #[inline]
    pub fn should_flush_atomic(&self) -> bool {
        self.size.load(Ordering::Relaxed) >= self.max_size
    }

    /// Approximate total memory footprint of all entries, in bytes.
    pub fn size(&self) -> usize {
        self.size.load(Ordering::Relaxed)
    }

    /// Number of entries (including tombstones) currently stored.
    pub fn len(&self) -> usize {
        self.entries.read().len()
    }

    /// True when the memtable holds no entries at all.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Searches the in-memory vector graph for the `k` nearest neighbors of
    /// `query`, then joins the hits back against the entry map, dropping
    /// tombstoned keys.
    ///
    /// # Errors
    /// Returns `StorageError::Index` when this memtable was created without
    /// vector support.
    pub fn vector_search(&self, query: &[f32], k: usize) -> Result<Vec<(Key, UnifiedEntry, f32)>> {
        let graph = self.vector_graph.as_ref()
            .ok_or_else(|| StorageError::Index("Vector search not supported".into()))?;
        // Over-fetch heuristic: a wider candidate beam improves recall.
        let ef = k * 5;
        let candidates = graph.search(query, k, ef)?;
        let entries = self.entries.read();
        let mut results = Vec::with_capacity(candidates.len());
        for candidate in candidates {
            if let Some(entry) = entries.get(&candidate.id) {
                if !entry.deleted {
                    results.push((candidate.id, entry.clone(), candidate.distance));
                }
            }
        }
        Ok(results)
    }

    /// Returns an iterator over a point-in-time snapshot of the entries.
    pub fn iter(&self) -> UnifiedMemTableIterator {
        UnifiedMemTableIterator::new(self.entries.clone())
    }

    /// Copies all entries out under the read lock, in key order.
    pub fn snapshot(&self) -> Vec<(Key, UnifiedEntry)> {
        let entries = self.entries.read();
        entries.iter()
            .map(|(k, v)| (*k, v.clone()))
            .collect()
    }

    /// Returns all entries with keys in `[start, end)`, in key order.
    /// A reversed range (`start > end`) yields an empty result.
    pub fn scan(&self, start: Key, end: Key) -> Result<Vec<(Key, UnifiedEntry)>> {
        let entries = self.entries.read();
        use std::ops::Bound;
        let range = entries.range((
            Bound::Included(&start),
            Bound::Excluded(&end)
        ));
        // saturating_sub: plain `end - start` overflows (and panics in
        // debug builds) when start > end; an empty estimate is correct then.
        let estimated_size = (end.saturating_sub(start) as usize).min(1000);
        let mut results = Vec::with_capacity(estimated_size);
        for (k, v) in range {
            results.push((*k, v.clone()));
        }
        Ok(results)
    }

    /// Returns every entry (including tombstones), in key order.
    pub fn scan_all(&self) -> Result<Vec<(Key, UnifiedEntry)>> {
        let entries = self.entries.read();
        let mut results = Vec::with_capacity(entries.len());
        for (k, v) in entries.iter() {
            results.push((*k, v.clone()));
        }
        Ok(results)
    }

    /// Exports the vector graph's nodes, e.g. for flushing to disk.
    ///
    /// # Errors
    /// Returns `StorageError::Index` when this memtable has no vector graph.
    pub fn export_vector_nodes(&self) -> Result<Vec<(Key, VectorNode)>> {
        let graph = self.vector_graph.as_ref()
            .ok_or_else(|| StorageError::Index("Vector graph not available".into()))?;
        graph.export_nodes()
    }

    /// The embedding dimension this memtable enforces, if vector-enabled.
    pub fn vector_dimension(&self) -> Option<usize> {
        self.vector_dimension
    }
}
/// Owning iterator over a point-in-time snapshot of a memtable's entries,
/// yielded in key order. Does not hold any lock while being consumed.
pub struct UnifiedMemTableIterator {
    /// Snapshot of the entries, taken at construction time.
    entries: std::vec::IntoIter<(Key, UnifiedEntry)>,
}
impl UnifiedMemTableIterator {
    /// Snapshots the map under a short-lived read lock; the iterator then
    /// walks the copy, so consumers never contend with writers.
    pub fn new(entries: Arc<RwLock<BTreeMap<Key, UnifiedEntry>>>) -> Self {
        let snapshot: Vec<(Key, UnifiedEntry)> = entries
            .read()
            .iter()
            .map(|(k, v)| (*k, v.clone()))
            .collect();
        Self {
            entries: snapshot.into_iter(),
        }
    }
}
impl Iterator for UnifiedMemTableIterator {
    type Item = (Key, UnifiedEntry);

    /// Yields the next `(key, entry)` pair from the snapshot.
    fn next(&mut self) -> Option<(Key, UnifiedEntry)> {
        self.entries.next()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain key/value memtable with default configuration.
    fn create_memtable() -> UnifiedMemTable {
        UnifiedMemTable::new(&LSMConfig::default())
    }

    /// Vector-enabled memtable expecting embeddings of `dimension`.
    fn create_vector_memtable(dimension: usize) -> UnifiedMemTable {
        UnifiedMemTable::new_with_vector_support(&LSMConfig::default(), dimension)
    }

    #[test]
    fn test_put_get() {
        let table = create_memtable();
        let key = 12345u64;
        let value = Value::new(b"test_value".to_vec(), 1);

        table.put(key, value.clone()).unwrap();
        let fetched = table.get(key).unwrap().unwrap();

        assert_eq!(fetched.timestamp, 1);
        assert_eq!(fetched.deleted, false);
    }

    #[test]
    fn test_put_with_vector() {
        let table = create_vector_memtable(128);
        let key = 1u64;
        let payload = ValueData::Inline(b"test_data".to_vec());
        let embedding = vec![1.0f32; 128];

        table.put_with_vector(key, payload, embedding, 1).unwrap();
        let fetched = table.get(key).unwrap().unwrap();

        assert!(fetched.vector.is_some());
        assert_eq!(fetched.vector.unwrap().len(), 128);
    }

    #[test]
    fn test_vector_search() {
        let table = create_vector_memtable(3);
        for i in 0..10 {
            let payload = ValueData::Inline(format!("data_{}", i).into_bytes());
            let embedding = vec![i as f32, (i + 1) as f32, (i + 2) as f32];
            table.put_with_vector(i, payload, embedding, i).unwrap();
        }

        let query = vec![5.0, 6.0, 7.0];
        let hits = table.vector_search(&query, 3).unwrap();

        assert!(hits.len() > 0);
        assert!(hits.len() <= 3);
        for (row_id, entry, distance) in hits {
            assert!(!entry.deleted);
            assert!(entry.vector.is_some());
            println!("Found row_id={}, distance={:.4}", row_id, distance);
        }
    }

    #[test]
    fn test_size_tracking() {
        let table = create_vector_memtable(128);
        assert_eq!(table.size(), 0);

        let payload = ValueData::Inline(b"test".to_vec());
        let embedding = vec![1.0f32; 128];
        table.put_with_vector(1u64, payload, embedding, 1).unwrap();

        let size_after = table.size();
        assert!(size_after > 0);
        println!("Memory size: {} bytes", size_after);
        // data(4) + vector(512) + overhead(16) plus graph bookkeeping.
        assert!(size_after >= 500 && size_after <= 600);
    }

    #[test]
    fn test_should_flush() {
        let mut config = LSMConfig::default();
        config.memtable_size = 5000;
        let table = UnifiedMemTable::new_with_vector_support(&config, 128);
        assert_eq!(table.should_flush(), false);

        for i in 0..10 {
            let payload = ValueData::Inline(vec![0u8; 10]);
            let embedding = vec![1.0f32; 128];
            table.put_with_vector(i, payload, embedding, i).unwrap();
        }
        assert_eq!(table.should_flush(), true);
    }
}