use crate::error::DbError;
use crate::types::Row;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Lock-protected, shareable staging vector holding the rows of one series.
type SeriesWriteBuffer = Arc<Mutex<Vec<Row>>>;

/// In-memory staging area that accumulates rows per series until they are drained.
///
/// Each series name maps to its own row buffer; rows staged for the same series
/// are kept in the order they were staged.
#[derive(Default, Debug)]
pub struct WriteBuffer {
    /// Per-series row buffers, keyed by series name.
    buffers: HashMap<String, SeriesWriteBuffer>,
}
impl WriteBuffer {
pub(crate) fn stage(&mut self, series: &str, row: Row) -> Result<(), DbError> {
let buffer_arc = self
.buffers
.entry(series.to_string()) .or_insert_with(|| Arc::new(Mutex::new(Vec::new())));
let mut buffer_guard = buffer_arc.lock()?; buffer_guard.push(row);
Ok(())
}
pub(crate) fn drain_all_buffers(&mut self) -> Result<HashMap<String, Vec<Row>>, DbError> {
let mut drained_data = HashMap::new();
for (series_name, buffer_arc) in self.buffers.iter() {
let mut buffer_guard = buffer_arc.lock().map_err(|e| {
DbError::LockError(format!(
"Write buffer for series '{}' poisoned during drain: {}",
series_name, e
))
})?;
if !buffer_guard.is_empty() {
let points = std::mem::take(&mut *buffer_guard);
drained_data.insert(series_name.clone(), points);
}
}
Ok(drained_data)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Row, TagSet, Timestamp, Value};
    use std::collections::HashSet;
    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Builds a `Row` from its individual fields.
    fn create_row(seq: u64, ts: Timestamp, val: Value, tags: TagSet) -> Row {
        Row {
            seq,
            timestamp: ts,
            value: val,
            tags,
        }
    }

    /// Current wall-clock time in nanoseconds since the Unix epoch.
    fn get_current_timestamp() -> Timestamp {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the Unix epoch")
            .as_nanos() as u64
    }

    /// Builds a tag set from `(key, value)` string pairs.
    fn create_tags(pairs: &[(&str, &str)]) -> TagSet {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn test_stage_single_point() {
        let mut buffer = WriteBuffer::default();
        let series = "test_series";
        let ts = get_current_timestamp();
        let tags = create_tags(&[("host", "server1")]);
        buffer.stage(series, create_row(1, ts, 42.0, tags)).unwrap();

        let drained = buffer.drain_all_buffers().unwrap();
        assert_eq!(drained.len(), 1, "Should have one series");
        assert!(drained.contains_key(series), "Should contain our series");
        let points = &drained[series];
        assert_eq!(points.len(), 1, "Should have one point");
        assert_eq!(points[0].timestamp, ts);
        assert_eq!(points[0].value, 42.0);
    }

    #[test]
    fn test_stage_multiple_points_same_series() {
        let mut buffer = WriteBuffer::default();
        let series = "multi_point_series";
        // Fixed offsets give distinct, increasing timestamps without depending on
        // the clock advancing between calls (a 1 ns sleep does not guarantee that).
        let ts_base = get_current_timestamp();
        let ts1 = ts_base;
        let ts2 = ts_base + 1;
        let ts3 = ts_base + 2;
        let tags = create_tags(&[("host", "server1")]);
        buffer
            .stage(series, create_row(1, ts1, 1.0, tags.clone()))
            .unwrap();
        buffer
            .stage(series, create_row(2, ts2, 2.0, tags.clone()))
            .unwrap();
        buffer.stage(series, create_row(3, ts3, 3.0, tags)).unwrap();

        let drained = buffer.drain_all_buffers().unwrap();
        assert_eq!(drained.len(), 1, "Should have one series");
        let points = &drained[series];
        assert_eq!(points.len(), 3, "Should have three points");
        // Rows must come back in the order they were staged.
        assert_eq!(points[0].timestamp, ts1);
        assert_eq!(points[0].value, 1.0);
        assert_eq!(points[1].timestamp, ts2);
        assert_eq!(points[1].value, 2.0);
        assert_eq!(points[2].timestamp, ts3);
        assert_eq!(points[2].value, 3.0);
    }

    #[test]
    fn test_stage_multiple_series() {
        let mut buffer = WriteBuffer::default();
        let series1 = "series1";
        let series2 = "series2";
        let ts1 = get_current_timestamp();
        let ts2 = get_current_timestamp() + 100;
        let tags1 = create_tags(&[("region", "us-east")]);
        let tags2 = create_tags(&[("region", "us-west")]);
        buffer.stage(series1, create_row(1, ts1, 1.0, tags1)).unwrap();
        buffer.stage(series2, create_row(2, ts2, 2.0, tags2)).unwrap();

        let drained = buffer.drain_all_buffers().unwrap();
        assert_eq!(drained.len(), 2, "Should have two series");
        assert!(drained.contains_key(series1), "Should contain series1");
        assert!(drained.contains_key(series2), "Should contain series2");

        let points1 = &drained[series1];
        assert_eq!(points1.len(), 1, "Series1 should have one point");
        assert_eq!(points1[0].timestamp, ts1);
        assert_eq!(points1[0].value, 1.0);

        let points2 = &drained[series2];
        assert_eq!(points2.len(), 1, "Series2 should have one point");
        assert_eq!(points2[0].timestamp, ts2);
        assert_eq!(points2[0].value, 2.0);
    }

    #[test]
    fn test_drain_empty_buffer() {
        let mut buffer = WriteBuffer::default();
        let drained = buffer.drain_all_buffers().unwrap();
        assert_eq!(drained.len(), 0, "Drained data should be empty");
    }

    #[test]
    fn test_drain_leaves_buffers_empty() {
        let mut buffer = WriteBuffer::default();
        let series = "test_series";
        let ts = get_current_timestamp();
        let tags = create_tags(&[("host", "server1")]);
        buffer.stage(series, create_row(1, ts, 1.0, tags)).unwrap();

        let first_drain = buffer.drain_all_buffers().unwrap();
        assert_eq!(
            first_drain.len(),
            1,
            "First drain should contain our series"
        );
        let second_drain = buffer.drain_all_buffers().unwrap();
        assert_eq!(second_drain.len(), 0, "Second drain should be empty");
    }

    #[test]
    fn test_multithreaded_stage() {
        let buffer = Arc::new(Mutex::new(WriteBuffer::default()));
        let series = "multithreaded_series";
        let num_threads: usize = 4;
        let points_per_thread: usize = 25;

        let mut handles = Vec::with_capacity(num_threads);
        for thread_id in 0..num_threads {
            let buffer = Arc::clone(&buffer);
            let series_name = series.to_string();
            handles.push(thread::spawn(move || {
                for i in 0..points_per_thread {
                    let ts = get_current_timestamp() + (thread_id * 1000 + i) as u64;
                    let value = (thread_id * 100 + i) as f64;
                    let tags = create_tags(&[
                        ("thread_id", &thread_id.to_string()),
                        ("point_id", &i.to_string()),
                    ]);
                    let row = create_row((thread_id * 1000 + i) as u64, ts, value, tags);
                    // The guard is a temporary, so the outer lock is released before
                    // yielding and other threads can interleave their writes.
                    buffer.lock().unwrap().stage(&series_name, row).unwrap();
                    thread::yield_now();
                }
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }

        let drained = buffer.lock().unwrap().drain_all_buffers().unwrap();
        assert_eq!(drained.len(), 1, "Should have one series");
        let points = &drained[series];
        let expected_total = num_threads * points_per_thread;
        assert_eq!(
            points.len(),
            expected_total,
            "Should have {} points",
            expected_total
        );

        // Index each (thread_id, point_id) pair once instead of rescanning every
        // point for every expected pair.
        let seen: HashSet<(String, String)> = points
            .iter()
            .filter_map(|p| {
                let thread_id = p.tags.get("thread_id")?;
                let point_id = p.tags.get("point_id")?;
                Some((thread_id.clone(), point_id.clone()))
            })
            .collect();
        for thread_id in 0..num_threads {
            for i in 0..points_per_thread {
                assert!(
                    seen.contains(&(thread_id.to_string(), i.to_string())),
                    "Point with thread_id={}, point_id={} not found",
                    thread_id,
                    i
                );
            }
        }
    }

    #[test]
    fn test_stage_with_different_tag_combinations() {
        let mut buffer = WriteBuffer::default();
        let series = "tag_test_series";
        let ts_base = get_current_timestamp();

        let no_tags = TagSet::new();
        buffer
            .stage(series, create_row(1, ts_base, 1.0, no_tags))
            .unwrap();

        let one_tag = create_tags(&[("region", "us-east")]);
        buffer
            .stage(series, create_row(2, ts_base + 1, 2.0, one_tag))
            .unwrap();

        let multi_tags = create_tags(&[
            ("region", "eu-west"),
            ("host", "server2"),
            ("service", "api"),
            ("version", "1.0"),
        ]);
        buffer
            .stage(series, create_row(3, ts_base + 2, 3.0, multi_tags.clone()))
            .unwrap();

        let drained = buffer.drain_all_buffers().unwrap();
        let points = &drained[series];
        assert_eq!(points.len(), 3, "Should have three points");
        assert_eq!(points[0].tags.len(), 0, "First point should have no tags");
        assert_eq!(points[1].tags.len(), 1, "Second point should have one tag");
        assert_eq!(points[1].tags.get("region"), Some(&"us-east".to_string()));
        assert_eq!(points[2].tags.len(), 4, "Third point should have four tags");
        assert_eq!(points[2].tags, multi_tags);
    }
}