use crate::error::Result;
use crate::index::VectorIndex;
use crate::store::RecordStore;
use crate::types::{MemoryRecord, RecordId};
use crate::wal::WalWriter;
use parking_lot::RwLock;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
/// A single mutation held in the write buffer until the next flush.
#[derive(Clone, Debug)]
pub enum BufferedOp {
    /// Add a complete record. `WriteBuffer::flush` also adds the record's
    /// embedding to the vector index.
    Insert(MemoryRecord),
    /// Forward an outcome observation to the store's `update_stats`.
    UpdateStats {
        /// Identifier of the record whose statistics are updated.
        id: RecordId,
        /// Outcome value passed through to the store unchanged.
        outcome: f64,
    },
    /// Remove the record with this identifier.
    Delete(RecordId),
}
/// Thresholds that decide when a `WriteBuffer` reports that it should flush.
///
/// Start from [`WriteBufferConfig::new`] (or `Default`) and adjust individual
/// limits with the chainable `with_*` / `without_wal` builders.
#[derive(Clone, Debug)]
pub struct WriteBufferConfig {
    /// Flush once at least this many operations are buffered.
    pub max_ops: usize,
    /// Flush once the estimated buffered size reaches this many bytes.
    pub max_bytes: usize,
    /// Flush a non-empty buffer once this many milliseconds have passed since
    /// the last flush. `0` disables the age check.
    pub max_age_ms: u64,
    /// Whether write-ahead logging is requested.
    ///
    /// NOTE(review): the `WriteBuffer` constructors in this file do not read
    /// this flag; presumably callers use it to pick a constructor — confirm.
    pub use_wal: bool,
}

impl Default for WriteBufferConfig {
    /// Identical to [`WriteBufferConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl WriteBufferConfig {
    /// Creates the default configuration: 1000 operations, 64 MiB,
    /// 5000 ms, write-ahead logging requested.
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_ops: 1_000,
            max_bytes: 64 * 1024 * 1024,
            max_age_ms: 5_000,
            use_wal: true,
        }
    }

    /// Returns the configuration with `max_ops` replaced by `max`.
    #[must_use]
    pub const fn with_max_ops(self, max: usize) -> Self {
        Self { max_ops: max, ..self }
    }

    /// Returns the configuration with `max_bytes` replaced by `max`.
    #[must_use]
    pub const fn with_max_bytes(self, max: usize) -> Self {
        Self { max_bytes: max, ..self }
    }

    /// Returns the configuration with `max_age_ms` replaced by `max`.
    #[must_use]
    pub const fn with_max_age_ms(self, max: u64) -> Self {
        Self { max_age_ms: max, ..self }
    }

    /// Returns the configuration with `use_wal` cleared.
    #[must_use]
    pub const fn without_wal(self) -> Self {
        Self { use_wal: false, ..self }
    }
}
/// Point-in-time counters reported by `WriteBuffer::stats`.
#[derive(Clone, Debug, Default)]
pub struct BufferStats {
    /// Operations currently waiting in the buffer.
    pub buffered_ops: usize,
    /// Estimated size in bytes of the waiting operations.
    pub buffered_bytes: usize,
    /// Inserts accepted since the buffer was created.
    pub total_inserts: u64,
    /// Stat updates accepted since the buffer was created.
    pub total_updates: u64,
    /// Deletes accepted since the buffer was created.
    pub total_deletes: u64,
    /// Flushes that applied a non-empty batch.
    pub flush_count: u64,
}
/// In-memory queue of pending record mutations that are applied to a record
/// store (and vector index) in batches.
///
/// All mutating methods take `&self`; shared state lives behind locks and
/// atomics.
pub struct WriteBuffer {
    /// Flush thresholds consulted by `should_flush`.
    config: WriteBufferConfig,
    /// Optional write-ahead log. When present, every operation is logged
    /// before it is buffered and `flush` logs a checkpoint afterwards.
    wal: Option<Arc<WalWriter>>,
    /// Pending operations in arrival order (new ops are pushed at the back).
    ops: RwLock<VecDeque<BufferedOp>>,
    /// Running total of `estimate_op_size` over the buffered operations.
    size_bytes: AtomicUsize,
    /// Time of construction, or of the most recent flush that applied a
    /// non-empty batch.
    last_flush: RwLock<Instant>,
    /// Inserts accepted since construction (never reset).
    total_inserts: AtomicU64,
    /// Stat updates accepted since construction (never reset).
    total_updates: AtomicU64,
    /// Deletes accepted since construction (never reset).
    total_deletes: AtomicU64,
    /// Number of flushes that applied a non-empty batch.
    flush_count: AtomicU64,
}
/// Compact `Debug` output: the configuration plus the current queue length
/// and byte estimate. The queued operations themselves are not printed.
impl std::fmt::Debug for WriteBuffer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("WriteBuffer");
        out.field("config", &self.config);
        out.field("ops_count", &self.ops.read().len());
        out.field("size_bytes", &self.size_bytes.load(Ordering::Relaxed));
        out.finish()
    }
}
impl WriteBuffer {
#[must_use]
pub fn new(config: WriteBufferConfig, wal: Arc<WalWriter>) -> Self {
Self {
config,
wal: Some(wal),
ops: RwLock::new(VecDeque::new()),
size_bytes: AtomicUsize::new(0),
last_flush: RwLock::new(Instant::now()),
total_inserts: AtomicU64::new(0),
total_updates: AtomicU64::new(0),
total_deletes: AtomicU64::new(0),
flush_count: AtomicU64::new(0),
}
}
#[must_use]
pub fn without_wal(config: WriteBufferConfig) -> Self {
Self {
config,
wal: None,
ops: RwLock::new(VecDeque::new()),
size_bytes: AtomicUsize::new(0),
last_flush: RwLock::new(Instant::now()),
total_inserts: AtomicU64::new(0),
total_updates: AtomicU64::new(0),
total_deletes: AtomicU64::new(0),
flush_count: AtomicU64::new(0),
}
}
/// Rough in-memory footprint of `op`, in bytes, used for the `max_bytes`
/// threshold.
///
/// Inserts are charged the size of `MemoryRecord` plus 4 bytes per embedding
/// element (presumably `f32` — confirm), plus the lengths of the context and
/// the id. Updates and deletes are charged a flat amount.
fn estimate_op_size(op: &BufferedOp) -> usize {
    // Flat charge for operations that carry only an id (and an outcome).
    const FIXED_OP_BYTES: usize = 32;

    match op {
        BufferedOp::UpdateStats { .. } | BufferedOp::Delete(_) => FIXED_OP_BYTES,
        BufferedOp::Insert(record) => {
            let payload_bytes =
                record.embedding.len() * 4 + record.context.len() + record.id.len();
            std::mem::size_of::<MemoryRecord>() + payload_bytes
        }
    }
}
/// Reports whether any configured threshold has been reached.
///
/// Returns `true` when the op count or the byte estimate has reached its
/// limit, or when the age check is enabled (`max_age_ms > 0`), the buffer is
/// non-empty, and at least `max_age_ms` milliseconds have passed since the
/// last flush.
fn should_flush(&self) -> bool {
    let queue = self.ops.read();
    let buffered_bytes = self.size_bytes.load(Ordering::Relaxed);
    let flushed_at = self.last_flush.read();
    let limits = &self.config;

    if queue.len() >= limits.max_ops || buffered_bytes >= limits.max_bytes {
        return true;
    }

    // The age limit never triggers a flush of an empty buffer.
    limits.max_age_ms > 0
        && !queue.is_empty()
        && (flushed_at.elapsed().as_millis() as u64) >= limits.max_age_ms
}
/// Buffers an insert of `record`.
///
/// When a write-ahead log is attached, the record is logged first and is
/// only buffered if logging succeeds.
///
/// # Errors
///
/// Returns the error from the write-ahead log, in which case nothing is
/// buffered and no counter changes.
pub fn insert(&self, record: MemoryRecord) -> Result<()> {
    if let Some(wal) = &self.wal {
        wal.log_insert(&record)?;
    }
    let op = BufferedOp::Insert(record);
    let size = Self::estimate_op_size(&op);
    {
        let mut ops = self.ops.write();
        ops.push_back(op);
        // Add the bytes before the queue lock is released. Otherwise a flush
        // could drain this op and reset the counter before the add lands,
        // leaving phantom bytes for an op that is no longer buffered.
        self.size_bytes.fetch_add(size, Ordering::Relaxed);
    }
    self.total_inserts.fetch_add(1, Ordering::Relaxed);
    Ok(())
}
/// Buffers an outcome update for the record identified by `id`.
///
/// When a write-ahead log is attached, the update is logged first and is
/// only buffered if logging succeeds.
///
/// # Errors
///
/// Returns the error from the write-ahead log, in which case nothing is
/// buffered and no counter changes.
pub fn update_stats(&self, id: &RecordId, outcome: f64) -> Result<()> {
    if let Some(wal) = &self.wal {
        wal.log_update_stats(id, outcome)?;
    }
    let op = BufferedOp::UpdateStats {
        id: id.clone(),
        outcome,
    };
    let size = Self::estimate_op_size(&op);
    {
        let mut ops = self.ops.write();
        ops.push_back(op);
        // Add the bytes before the queue lock is released. Otherwise a flush
        // could drain this op and reset the counter before the add lands,
        // leaving phantom bytes for an op that is no longer buffered.
        self.size_bytes.fetch_add(size, Ordering::Relaxed);
    }
    self.total_updates.fetch_add(1, Ordering::Relaxed);
    Ok(())
}
/// Buffers a delete of the record identified by `id`.
///
/// When a write-ahead log is attached, the delete is logged first and is
/// only buffered if logging succeeds.
///
/// # Errors
///
/// Returns the error from the write-ahead log, in which case nothing is
/// buffered and no counter changes.
pub fn delete(&self, id: &RecordId) -> Result<()> {
    if let Some(wal) = &self.wal {
        wal.log_delete(id)?;
    }
    let op = BufferedOp::Delete(id.clone());
    let size = Self::estimate_op_size(&op);
    {
        let mut ops = self.ops.write();
        ops.push_back(op);
        // Add the bytes before the queue lock is released. Otherwise a flush
        // could drain this op and reset the counter before the add lands,
        // leaving phantom bytes for an op that is no longer buffered.
        self.size_bytes.fetch_add(size, Ordering::Relaxed);
    }
    self.total_deletes.fetch_add(1, Ordering::Relaxed);
    Ok(())
}
pub fn flush<S: RecordStore, I: VectorIndex>(
&self,
store: &mut S,
index: &mut I,
) -> Result<FlushResult> {
let ops: Vec<BufferedOp> = {
let mut ops_guard = self.ops.write();
std::mem::take(&mut *ops_guard).into()
};
if ops.is_empty() {
return Ok(FlushResult::default());
}
let mut result = FlushResult::default();
for op in ops {
match op {
BufferedOp::Insert(record) => {
index.add(record.id.to_string(), &record.embedding)?;
store.insert(record)?;
result.inserts += 1;
}
BufferedOp::UpdateStats { id, outcome } => {
store.update_stats(&id, outcome)?;
result.updates += 1;
}
BufferedOp::Delete(id) => {
index.remove(id.as_str())?;
store.remove(&id)?;
result.deletes += 1;
}
}
}
self.size_bytes.store(0, Ordering::SeqCst);
*self.last_flush.write() = Instant::now();
self.flush_count.fetch_add(1, Ordering::Relaxed);
if let Some(wal) = &self.wal {
wal.log_checkpoint()?;
}
Ok(result)
}
pub fn flush_to_store<S: RecordStore>(&self, store: &mut S) -> Result<FlushResult> {
let ops: Vec<BufferedOp> = {
let mut ops_guard = self.ops.write();
std::mem::take(&mut *ops_guard).into()
};
if ops.is_empty() {
return Ok(FlushResult::default());
}
let mut result = FlushResult::default();
for op in ops {
match op {
BufferedOp::Insert(record) => {
store.insert(record)?;
result.inserts += 1;
}
BufferedOp::UpdateStats { id, outcome } => {
store.update_stats(&id, outcome)?;
result.updates += 1;
}
BufferedOp::Delete(id) => {
store.remove(&id)?;
result.deletes += 1;
}
}
}
self.size_bytes.store(0, Ordering::SeqCst);
*self.last_flush.write() = Instant::now();
self.flush_count.fetch_add(1, Ordering::Relaxed);
Ok(result)
}
/// Flushes into `store` and `index` only if a configured threshold has been
/// reached.
///
/// Returns `Ok(true)` when a flush was attempted and succeeded, `Ok(false)`
/// when no threshold was reached.
///
/// # Errors
///
/// Propagates any error returned by `flush`.
pub fn maybe_flush<S: RecordStore, I: VectorIndex>(
    &self,
    store: &mut S,
    index: &mut I,
) -> Result<bool> {
    if !self.should_flush() {
        return Ok(false);
    }
    self.flush(store, index).map(|_| true)
}
/// Returns a snapshot of the buffer's counters.
///
/// Each value is read separately, without a common lock, so under concurrent
/// use the fields may reflect slightly different moments.
#[must_use]
pub fn stats(&self) -> BufferStats {
    let buffered_ops = self.ops.read().len();
    let buffered_bytes = self.size_bytes.load(Ordering::Relaxed);
    let total_inserts = self.total_inserts.load(Ordering::Relaxed);
    let total_updates = self.total_updates.load(Ordering::Relaxed);
    let total_deletes = self.total_deletes.load(Ordering::Relaxed);
    let flush_count = self.flush_count.load(Ordering::Relaxed);
    BufferStats {
        buffered_ops,
        buffered_bytes,
        total_inserts,
        total_updates,
        total_deletes,
        flush_count,
    }
}
/// Returns `true` when no operations are waiting to be flushed.
#[must_use]
pub fn is_empty(&self) -> bool {
    let queue = self.ops.read();
    queue.is_empty()
}
/// Returns the number of operations waiting to be flushed.
#[must_use]
pub fn len(&self) -> usize {
    let queue = self.ops.read();
    queue.len()
}
}
/// Per-kind counts of the operations applied by one flush.
#[derive(Clone, Debug, Default)]
pub struct FlushResult {
    /// Number of inserts applied.
    pub inserts: usize,
    /// Number of stat updates applied.
    pub updates: usize,
    /// Number of deletes applied.
    pub deletes: usize,
}

impl FlushResult {
    /// Returns the number of operations applied, across all kinds.
    #[must_use]
    pub fn total(&self) -> usize {
        let Self {
            inserts,
            updates,
            deletes,
        } = self;
        inserts + updates + deletes
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::{FlatIndex, IndexConfig};
    use crate::stats::OutcomeStats;
    use crate::store::InMemoryStore;
    use crate::types::RecordStatus;

    /// Builds an active record with a fixed three-element embedding.
    fn create_test_record(id: &str) -> MemoryRecord {
        MemoryRecord {
            id: id.into(),
            embedding: vec![1.0, 2.0, 3.0],
            context: format!("Context for {id}"),
            outcome: 0.5,
            metadata: Default::default(),
            created_at: 1_234_567_890,
            status: RecordStatus::Active,
            stats: OutcomeStats::new(1),
        }
    }

    /// Buffer under test: default thresholds, no write-ahead log.
    fn wal_free_buffer() -> WriteBuffer {
        WriteBuffer::without_wal(WriteBufferConfig::new().without_wal())
    }

    /// Inserts land in both the store and the index, and the buffer is
    /// empty afterwards.
    #[test]
    fn test_buffer_insert_and_flush() {
        let buffer = wal_free_buffer();
        for id in ["rec-1", "rec-2"] {
            buffer.insert(create_test_record(id)).unwrap();
        }
        assert_eq!(buffer.len(), 2);

        let mut store = InMemoryStore::new();
        let mut index = FlatIndex::new(IndexConfig::new(3));
        let flushed = buffer.flush(&mut store, &mut index).unwrap();

        assert_eq!(flushed.inserts, 2);
        assert_eq!(store.len(), 2);
        assert_eq!(index.len(), 2);
        assert!(buffer.is_empty());
    }

    /// Stat updates buffered after an insert are applied to that record.
    #[test]
    fn test_buffer_update_stats() {
        let buffer = wal_free_buffer();
        buffer.insert(create_test_record("rec-1")).unwrap();
        for outcome in [0.8, 0.9] {
            buffer.update_stats(&"rec-1".into(), outcome).unwrap();
        }

        let mut store = InMemoryStore::new();
        let mut index = FlatIndex::new(IndexConfig::new(3));
        let flushed = buffer.flush(&mut store, &mut index).unwrap();

        assert_eq!(flushed.inserts, 1);
        assert_eq!(flushed.updates, 2);
        let record = store.get(&"rec-1".into()).unwrap();
        assert_eq!(record.stats.count(), 2);
    }

    /// A delete flushed after the inserts removes the record from both the
    /// store and the index.
    #[test]
    fn test_buffer_delete() {
        let buffer = wal_free_buffer();
        for id in ["rec-1", "rec-2"] {
            buffer.insert(create_test_record(id)).unwrap();
        }

        let mut store = InMemoryStore::new();
        let mut index = FlatIndex::new(IndexConfig::new(3));
        buffer.flush(&mut store, &mut index).unwrap();

        buffer.delete(&"rec-1".into()).unwrap();
        let flushed = buffer.flush(&mut store, &mut index).unwrap();

        assert_eq!(flushed.deletes, 1);
        assert_eq!(store.len(), 1);
        assert_eq!(index.len(), 1);
    }

    /// `maybe_flush` does nothing below `max_ops` and flushes once the
    /// limit is reached.
    #[test]
    fn test_auto_flush_by_ops() {
        let limits = WriteBufferConfig::new().without_wal().with_max_ops(5);
        let buffer = WriteBuffer::without_wal(limits);
        let mut store = InMemoryStore::new();
        let mut index = FlatIndex::new(IndexConfig::new(3));

        for i in 0..4 {
            let id = format!("rec-{i}");
            buffer.insert(create_test_record(&id)).unwrap();
            buffer.maybe_flush(&mut store, &mut index).unwrap();
        }
        assert!(!buffer.is_empty());

        buffer.insert(create_test_record("rec-4")).unwrap();
        let flushed = buffer.maybe_flush(&mut store, &mut index).unwrap();

        assert!(flushed);
        assert!(buffer.is_empty());
        assert_eq!(store.len(), 5);
    }

    /// Counters reflect buffered operations before the flush and the flush
    /// itself afterwards.
    #[test]
    fn test_buffer_stats() {
        let buffer = wal_free_buffer();
        for id in ["rec-1", "rec-2"] {
            buffer.insert(create_test_record(id)).unwrap();
        }
        buffer.update_stats(&"rec-1".into(), 0.8).unwrap();
        buffer.delete(&"rec-2".into()).unwrap();

        let before = buffer.stats();
        assert_eq!(before.buffered_ops, 4);
        assert!(before.buffered_bytes > 0);
        assert_eq!(before.total_inserts, 2);
        assert_eq!(before.total_updates, 1);
        assert_eq!(before.total_deletes, 1);

        let mut store = InMemoryStore::new();
        let mut index = FlatIndex::new(IndexConfig::new(3));
        buffer.flush(&mut store, &mut index).unwrap();

        let stats_after = buffer.stats();
        assert_eq!(stats_after.buffered_ops, 0);
        assert_eq!(stats_after.flush_count, 1);
    }

    /// `flush_to_store` applies inserts without needing an index.
    #[test]
    fn test_flush_to_store_only() {
        let buffer = wal_free_buffer();
        buffer.insert(create_test_record("rec-1")).unwrap();

        let mut store = InMemoryStore::new();
        let flushed = buffer.flush_to_store(&mut store).unwrap();

        assert_eq!(flushed.inserts, 1);
        assert_eq!(store.len(), 1);
    }

    /// Flushing an empty buffer succeeds and reports zero operations.
    #[test]
    fn test_empty_flush() {
        let buffer = wal_free_buffer();
        let mut store = InMemoryStore::new();
        let mut index = FlatIndex::new(IndexConfig::new(3));

        let flushed = buffer.flush(&mut store, &mut index).unwrap();

        assert_eq!(flushed.total(), 0);
    }
}