use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::thread::JoinHandle;
#[derive(Debug, Clone)]
pub struct LsmConfig {
pub dim: usize,
pub mutable_capacity: usize,
pub wal_path: PathBuf,
pub sync_wal: bool,
pub wal_batch_size: usize,
pub build_threads: usize,
pub auto_compact: bool,
pub compact_threshold: usize,
}
impl Default for LsmConfig {
fn default() -> Self {
Self {
dim: 768,
mutable_capacity: 10_000,
wal_path: PathBuf::from("./wal"),
sync_wal: true,
wal_batch_size: 100,
build_threads: 2,
auto_compact: true,
compact_threshold: 4,
}
}
}
pub type VectorKey = u64;
mod wal_record_compat {
use sochdb_core::txn::WalRecordType;
pub(super) fn to_disk_byte(rt: WalRecordType) -> u8 {
match rt {
WalRecordType::Data => 1, WalRecordType::Delete => 2, WalRecordType::Flush => 3, WalRecordType::Compaction => 4, _ => 0xFF, }
}
}
use sochdb_core::txn::WalRecordType;
#[repr(C, packed)]
struct WalHeader {
record_type: u8,
key: u64,
dim: u32,
checksum: u32,
}
pub struct WriteAheadLog {
writer: BufWriter<File>,
position: u64,
pending: usize,
sync_interval: usize,
writes: AtomicU64,
syncs: AtomicU64,
}
impl WriteAheadLog {
pub fn open(path: &Path, sync_interval: usize) -> std::io::Result<Self> {
let file = OpenOptions::new()
.create(true)
.write(true)
.append(true)
.open(path)?;
let position = file.metadata()?.len();
let writer = BufWriter::with_capacity(64 * 1024, file);
Ok(Self {
writer,
position,
pending: 0,
sync_interval,
writes: AtomicU64::new(0),
syncs: AtomicU64::new(0),
})
}
pub fn write_insert(&mut self, key: VectorKey, vector: &[f32]) -> std::io::Result<()> {
let dim = vector.len() as u32;
let checksum = self.compute_checksum(key, vector);
let header = WalHeader {
record_type: wal_record_compat::to_disk_byte(WalRecordType::Data),
key,
dim,
checksum,
};
let header_bytes = unsafe {
std::slice::from_raw_parts(
&header as *const WalHeader as *const u8,
std::mem::size_of::<WalHeader>(),
)
};
self.writer.write_all(header_bytes)?;
let vector_bytes = unsafe {
std::slice::from_raw_parts(
vector.as_ptr() as *const u8,
vector.len() * std::mem::size_of::<f32>(),
)
};
self.writer.write_all(vector_bytes)?;
self.position += header_bytes.len() as u64 + vector_bytes.len() as u64;
self.pending += 1;
self.writes.fetch_add(1, Ordering::Relaxed);
if self.pending >= self.sync_interval {
self.sync()?;
}
Ok(())
}
pub fn write_seal_start(&mut self, segment_id: u64) -> std::io::Result<()> {
let header = WalHeader {
record_type: wal_record_compat::to_disk_byte(WalRecordType::Flush),
key: segment_id,
dim: 0,
checksum: 0,
};
let header_bytes = unsafe {
std::slice::from_raw_parts(
&header as *const WalHeader as *const u8,
std::mem::size_of::<WalHeader>(),
)
};
self.writer.write_all(header_bytes)?;
self.sync()?;
Ok(())
}
pub fn write_seal_complete(&mut self, segment_id: u64) -> std::io::Result<()> {
let header = WalHeader {
record_type: wal_record_compat::to_disk_byte(WalRecordType::Compaction),
key: segment_id,
dim: 0,
checksum: 0,
};
let header_bytes = unsafe {
std::slice::from_raw_parts(
&header as *const WalHeader as *const u8,
std::mem::size_of::<WalHeader>(),
)
};
self.writer.write_all(header_bytes)?;
self.sync()?;
Ok(())
}
pub fn sync(&mut self) -> std::io::Result<()> {
self.writer.flush()?;
self.writer.get_ref().sync_all()?;
self.pending = 0;
self.syncs.fetch_add(1, Ordering::Relaxed);
Ok(())
}
fn compute_checksum(&self, key: VectorKey, vector: &[f32]) -> u32 {
let mut hasher = crc32fast::Hasher::new();
hasher.update(&(key as u32).to_le_bytes());
for &x in vector {
hasher.update(&x.to_le_bytes());
}
hasher.finalize()
}
pub fn stats(&self) -> WalStats {
WalStats {
writes: self.writes.load(Ordering::Relaxed),
syncs: self.syncs.load(Ordering::Relaxed),
position: self.position,
}
}
}
#[derive(Debug, Clone)]
pub struct WalStats {
pub writes: u64,
pub syncs: u64,
pub position: u64,
}
pub struct MutableSegment {
vectors: HashMap<VectorKey, (u32, Vec<f32>)>,
keys: Vec<VectorKey>,
#[allow(dead_code)]
dim: usize,
capacity: usize,
}
impl MutableSegment {
pub fn new(dim: usize, capacity: usize) -> Self {
Self {
vectors: HashMap::with_capacity(capacity),
keys: Vec::with_capacity(capacity),
dim,
capacity,
}
}
pub fn insert(&mut self, key: VectorKey, vector: Vec<f32>) -> bool {
if self.vectors.len() >= self.capacity {
return false;
}
let index = self.keys.len() as u32;
self.vectors.insert(key, (index, vector));
self.keys.push(key);
true
}
pub fn is_full(&self) -> bool {
self.vectors.len() >= self.capacity
}
pub fn len(&self) -> usize {
self.vectors.len()
}
pub fn is_empty(&self) -> bool {
self.vectors.is_empty()
}
pub fn get(&self, key: VectorKey) -> Option<&[f32]> {
self.vectors.get(&key).map(|(_, v)| v.as_slice())
}
pub fn drain(&mut self) -> Vec<(VectorKey, Vec<f32>)> {
let result: Vec<_> = self
.keys
.drain(..)
.filter_map(|k| self.vectors.remove(&k).map(|(_, v)| (k, v)))
.collect();
result
}
}
pub struct SealedSegment {
pub id: u64,
pub data: Vec<f32>,
pub key_to_index: HashMap<VectorKey, u32>,
pub index_to_key: Vec<VectorKey>,
pub dim: usize,
pub build_time_ns: u64,
}
impl SealedSegment {
pub fn len(&self) -> usize {
self.index_to_key.len()
}
pub fn is_empty(&self) -> bool {
self.index_to_key.is_empty()
}
pub fn get(&self, key: VectorKey) -> Option<&[f32]> {
self.key_to_index.get(&key).map(|&idx| {
let start = idx as usize * self.dim;
&self.data[start..start + self.dim]
})
}
pub fn get_by_index(&self, index: u32) -> Option<&[f32]> {
if (index as usize) < self.index_to_key.len() {
let start = index as usize * self.dim;
Some(&self.data[start..start + self.dim])
} else {
None
}
}
}
struct BuildTask {
segment_id: u64,
vectors: Vec<(VectorKey, Vec<f32>)>,
#[allow(dead_code)]
dim: usize,
}
#[allow(dead_code)]
struct BuildResult {
segment: SealedSegment,
}
pub struct AsyncLsmManager {
config: LsmConfig,
wal: Mutex<WriteAheadLog>,
mutable: RwLock<MutableSegment>,
sealed: RwLock<Vec<Arc<SealedSegment>>>,
pending_builds: Mutex<Vec<BuildTask>>,
workers: Mutex<Vec<JoinHandle<()>>>,
shutdown: Arc<AtomicBool>,
segment_counter: AtomicU64,
stats: LsmStats,
}
impl AsyncLsmManager {
pub fn new(config: LsmConfig) -> std::io::Result<Self> {
std::fs::create_dir_all(&config.wal_path)?;
let wal_file = config.wal_path.join("wal.log");
let wal = WriteAheadLog::open(&wal_file, config.wal_batch_size)?;
let mutable = MutableSegment::new(config.dim, config.mutable_capacity);
let shutdown = Arc::new(AtomicBool::new(false));
Ok(Self {
config,
wal: Mutex::new(wal),
mutable: RwLock::new(mutable),
sealed: RwLock::new(Vec::new()),
pending_builds: Mutex::new(Vec::new()),
workers: Mutex::new(Vec::new()),
shutdown,
segment_counter: AtomicU64::new(0),
stats: LsmStats::default(),
})
}
pub fn insert(&self, key: VectorKey, vector: Vec<f32>) -> Result<(), LsmError> {
{
let mut wal = self.wal.lock().unwrap();
wal.write_insert(key, &vector)?;
}
{
let mut mutable = self.mutable.write().unwrap();
if mutable.is_full() {
drop(mutable);
self.seal_async()?;
mutable = self.mutable.write().unwrap();
}
if !mutable.insert(key, vector) {
return Err(LsmError::SegmentFull);
}
}
self.stats.inserts.fetch_add(1, Ordering::Relaxed);
Ok(())
}
pub fn insert_batch(&self, items: Vec<(VectorKey, Vec<f32>)>) -> Result<(), LsmError> {
{
let mut wal = self.wal.lock().unwrap();
for (key, vector) in &items {
wal.write_insert(*key, vector)?;
}
wal.sync()?;
}
let mut mutable = self.mutable.write().unwrap();
for (key, vector) in items {
if mutable.is_full() {
drop(mutable);
self.seal_async()?;
mutable = self.mutable.write().unwrap();
}
mutable.insert(key, vector);
self.stats.inserts.fetch_add(1, Ordering::Relaxed);
}
Ok(())
}
pub fn seal_async(&self) -> Result<u64, LsmError> {
let segment_id = self.segment_counter.fetch_add(1, Ordering::Relaxed);
{
let mut wal = self.wal.lock().unwrap();
wal.write_seal_start(segment_id)?;
}
let vectors = {
let mut mutable = self.mutable.write().unwrap();
let vectors = mutable.drain();
*mutable = MutableSegment::new(self.config.dim, self.config.mutable_capacity);
vectors
};
if vectors.is_empty() {
return Ok(segment_id);
}
let task = BuildTask {
segment_id,
vectors,
dim: self.config.dim,
};
{
let mut pending = self.pending_builds.lock().unwrap();
pending.push(task);
}
self.ensure_worker_running();
self.stats.seals.fetch_add(1, Ordering::Relaxed);
Ok(segment_id)
}
pub fn seal_blocking(&self) -> Result<Arc<SealedSegment>, LsmError> {
let segment_id = self.segment_counter.fetch_add(1, Ordering::Relaxed);
let vectors = {
let mut mutable = self.mutable.write().unwrap();
let vectors = mutable.drain();
*mutable = MutableSegment::new(self.config.dim, self.config.mutable_capacity);
vectors
};
if vectors.is_empty() {
return Err(LsmError::EmptySegment);
}
let segment = self.build_segment(segment_id, vectors);
let segment = Arc::new(segment);
{
let mut sealed = self.sealed.write().unwrap();
sealed.push(Arc::clone(&segment));
}
{
let mut wal = self.wal.lock().unwrap();
wal.write_seal_complete(segment_id)?;
}
self.stats.seals.fetch_add(1, Ordering::Relaxed);
Ok(segment)
}
pub fn search(&self, query: &[f32], k: usize) -> Vec<(VectorKey, f32)> {
let mut results = Vec::new();
{
let mutable = self.mutable.read().unwrap();
for &key in &mutable.keys {
if let Some(vector) = mutable.get(key) {
let dist = l2_squared(query, vector);
results.push((key, dist));
}
}
}
{
let sealed = self.sealed.read().unwrap();
for segment in sealed.iter() {
for (i, &key) in segment.index_to_key.iter().enumerate() {
if let Some(vector) = segment.get_by_index(i as u32) {
let dist = l2_squared(query, vector);
results.push((key, dist));
}
}
}
}
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
results.truncate(k);
results
}
pub fn get(&self, key: VectorKey) -> Option<Vec<f32>> {
{
let mutable = self.mutable.read().unwrap();
if let Some(v) = mutable.get(key) {
return Some(v.to_vec());
}
}
{
let sealed = self.sealed.read().unwrap();
for segment in sealed.iter().rev() {
if let Some(v) = segment.get(key) {
return Some(v.to_vec());
}
}
}
None
}
pub fn flush(&self) -> Result<(), LsmError> {
loop {
let task = {
let mut pending = self.pending_builds.lock().unwrap();
pending.pop()
};
match task {
Some(task) => {
let segment = self.build_segment(task.segment_id, task.vectors);
let segment = Arc::new(segment);
let mut sealed = self.sealed.write().unwrap();
sealed.push(segment);
let mut wal = self.wal.lock().unwrap();
wal.write_seal_complete(task.segment_id)?;
}
None => break,
}
}
let mut wal = self.wal.lock().unwrap();
wal.sync()?;
Ok(())
}
pub fn stats(&self) -> LsmManagerStats {
let mutable_len = self.mutable.read().unwrap().len();
let sealed_count = self.sealed.read().unwrap().len();
let pending = self.pending_builds.lock().unwrap().len();
LsmManagerStats {
inserts: self.stats.inserts.load(Ordering::Relaxed),
seals: self.stats.seals.load(Ordering::Relaxed),
mutable_vectors: mutable_len,
sealed_segments: sealed_count,
pending_builds: pending,
}
}
fn build_segment(&self, segment_id: u64, vectors: Vec<(VectorKey, Vec<f32>)>) -> SealedSegment {
let start = std::time::Instant::now();
let dim = self.config.dim;
let mut data = Vec::with_capacity(vectors.len() * dim);
let mut key_to_index = HashMap::with_capacity(vectors.len());
let mut index_to_key = Vec::with_capacity(vectors.len());
for (i, (key, vector)) in vectors.into_iter().enumerate() {
data.extend_from_slice(&vector);
key_to_index.insert(key, i as u32);
index_to_key.push(key);
}
SealedSegment {
id: segment_id,
data,
key_to_index,
index_to_key,
dim,
build_time_ns: start.elapsed().as_nanos() as u64,
}
}
fn ensure_worker_running(&self) {
}
}
impl Drop for AsyncLsmManager {
fn drop(&mut self) {
self.shutdown.store(true, Ordering::Release);
let _ = self.flush();
let mut workers = self.workers.lock().unwrap();
for handle in workers.drain(..) {
let _ = handle.join();
}
}
}
#[derive(Default)]
struct LsmStats {
inserts: AtomicU64,
seals: AtomicU64,
}
#[derive(Debug, Clone)]
pub struct LsmManagerStats {
pub inserts: u64,
pub seals: u64,
pub mutable_vectors: usize,
pub sealed_segments: usize,
pub pending_builds: usize,
}
#[derive(Debug)]
pub enum LsmError {
Io(std::io::Error),
SegmentFull,
EmptySegment,
KeyNotFound,
}
impl From<std::io::Error> for LsmError {
fn from(e: std::io::Error) -> Self {
LsmError::Io(e)
}
}
impl std::fmt::Display for LsmError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LsmError::Io(e) => write!(f, "IO error: {}", e),
LsmError::SegmentFull => write!(f, "segment full"),
LsmError::EmptySegment => write!(f, "empty segment"),
LsmError::KeyNotFound => write!(f, "key not found"),
}
}
}
impl std::error::Error for LsmError {}
fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| {
let d = x - y;
d * d
})
.sum()
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
#[test]
fn test_wal_basic() {
let dir = tempdir().unwrap();
let wal_path = dir.path().join("wal.log");
let mut wal = WriteAheadLog::open(&wal_path, 10).unwrap();
let vector = vec![1.0, 2.0, 3.0, 4.0];
wal.write_insert(42, &vector).unwrap();
wal.sync().unwrap();
let stats = wal.stats();
assert_eq!(stats.writes, 1);
assert!(stats.position > 0);
}
#[test]
fn test_mutable_segment() {
let mut segment = MutableSegment::new(4, 10);
segment.insert(1, vec![1.0, 2.0, 3.0, 4.0]);
segment.insert(2, vec![5.0, 6.0, 7.0, 8.0]);
assert_eq!(segment.len(), 2);
assert_eq!(segment.get(1).unwrap(), &[1.0, 2.0, 3.0, 4.0]);
let drained = segment.drain();
assert_eq!(drained.len(), 2);
assert!(segment.is_empty());
}
#[test]
fn test_lsm_manager_basic() {
let dir = tempdir().unwrap();
let config = LsmConfig {
dim: 4,
mutable_capacity: 10,
wal_path: dir.path().to_path_buf(),
..Default::default()
};
let manager = AsyncLsmManager::new(config).unwrap();
manager.insert(1, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
manager.insert(2, vec![5.0, 6.0, 7.0, 8.0]).unwrap();
let v1 = manager.get(1).unwrap();
assert_eq!(v1, vec![1.0, 2.0, 3.0, 4.0]);
let stats = manager.stats();
assert_eq!(stats.inserts, 2);
assert_eq!(stats.mutable_vectors, 2);
}
#[test]
fn test_lsm_seal_blocking() {
let dir = tempdir().unwrap();
let config = LsmConfig {
dim: 4,
mutable_capacity: 10,
wal_path: dir.path().to_path_buf(),
..Default::default()
};
let manager = AsyncLsmManager::new(config).unwrap();
manager.insert(1, vec![1.0, 0.0, 0.0, 0.0]).unwrap();
manager.insert(2, vec![0.0, 1.0, 0.0, 0.0]).unwrap();
let segment = manager.seal_blocking().unwrap();
assert_eq!(segment.len(), 2);
assert!(manager.get(1).is_some());
let stats = manager.stats();
assert_eq!(stats.seals, 1);
assert_eq!(stats.sealed_segments, 1);
assert_eq!(stats.mutable_vectors, 0);
}
#[test]
fn test_lsm_search() {
let dir = tempdir().unwrap();
let config = LsmConfig {
dim: 4,
mutable_capacity: 100,
wal_path: dir.path().to_path_buf(),
..Default::default()
};
let manager = AsyncLsmManager::new(config).unwrap();
manager.insert(1, vec![1.0, 0.0, 0.0, 0.0]).unwrap();
manager.insert(2, vec![0.0, 1.0, 0.0, 0.0]).unwrap();
manager.insert(3, vec![0.5, 0.5, 0.0, 0.0]).unwrap();
let query = vec![1.0, 0.0, 0.0, 0.0];
let results = manager.search(&query, 2);
assert_eq!(results.len(), 2);
assert_eq!(results[0].0, 1); assert!(results[0].1 < 0.01); }
#[test]
fn test_lsm_batch_insert() {
let dir = tempdir().unwrap();
let config = LsmConfig {
dim: 4,
mutable_capacity: 100,
wal_path: dir.path().to_path_buf(),
..Default::default()
};
let manager = AsyncLsmManager::new(config).unwrap();
let batch: Vec<_> = (0..10).map(|i| (i as u64, vec![i as f32; 4])).collect();
manager.insert_batch(batch).unwrap();
let stats = manager.stats();
assert_eq!(stats.inserts, 10);
for i in 0..10 {
let v = manager.get(i as u64).unwrap();
assert_eq!(v[0], i as f32);
}
}
}