use crate::error::DbError;
use crate::types::DataPoint;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Reference-counted, mutex-guarded list of the data points pending for a single series.
type SeriesWriteBuffer = Arc<Mutex<Vec<DataPoint>>>;
/// In-memory staging area that accumulates data points per series until they are drained.
///
/// `Default` yields a buffer with no series registered.
#[derive(Debug, Default)]
pub struct WriteBuffer {
/// Pending points keyed by series name; an entry is created the first time a point is
/// staged for that series.
buffers: HashMap<String, SeriesWriteBuffer>,
}
impl WriteBuffer {
pub fn stage(&mut self, series: &str, point: DataPoint) -> Result<(), DbError> {
let buffer_arc = self
.buffers
.entry(series.to_string()) .or_insert_with(|| Arc::new(Mutex::new(Vec::new())));
let mut buffer_guard = buffer_arc.lock()?; buffer_guard.push(point);
Ok(())
}
pub fn drain_all_buffers(&mut self) -> HashMap<String, Vec<DataPoint>> {
let mut drained_data = HashMap::new();
for (series_name, buffer_arc) in self.buffers.iter() {
if let Ok(mut buffer_guard) = buffer_arc.lock() {
if !buffer_guard.is_empty() {
let points = std::mem::take(&mut *buffer_guard);
drained_data.insert(series_name.clone(), points);
}
} else {
eprintln!("Warning: Buffer for series {} is poisoned, skipping drain.", series_name);
}
}
drained_data
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{TagSet, Timestamp, Value};
    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    /// Builds a `DataPoint` from its timestamp, value and tag set.
    fn create_point(ts: Timestamp, val: Value, tags: TagSet) -> DataPoint {
        DataPoint {
            timestamp: ts,
            value: val,
            tags,
        }
    }

    /// Returns the current wall-clock time as nanoseconds since the Unix epoch.
    fn get_current_timestamp() -> Timestamp {
        // `as_nanos` yields a u128; the cast narrows it to the u64 used for timestamps here.
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos() as u64
    }

    /// Collects `(key, value)` string pairs into an owned `TagSet`.
    fn create_tags(pairs: &[(&str, &str)]) -> TagSet {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    /// A single staged point comes back unchanged from the next drain.
    #[test]
    fn test_stage_single_point() {
        let mut buffer = WriteBuffer::default();
        let series = "test_series";
        let ts = get_current_timestamp();
        let tags = create_tags(&[("host", "server1")]);
        let point = create_point(ts, 42.0, tags);

        buffer.stage(series, point).unwrap();

        let drained = buffer.drain_all_buffers();
        assert_eq!(drained.len(), 1, "Should have one series");
        assert!(drained.contains_key(series), "Should contain our series");
        let points = &drained[series];
        assert_eq!(points.len(), 1, "Should have one point");
        assert_eq!(points[0].timestamp, ts);
        assert_eq!(points[0].value, 42.0);
    }

    /// Points staged for one series are drained in insertion order.
    #[test]
    fn test_stage_multiple_points_same_series() {
        let mut buffer = WriteBuffer::default();
        let series = "multi_point_series";
        // Fixed offsets from one base guarantee three distinct, increasing timestamps
        // without relying on sleeps or on the clock's resolution.
        let base = get_current_timestamp();
        let (ts1, ts2, ts3) = (base, base + 1, base + 2);
        let tags = create_tags(&[("host", "server1")]);

        buffer.stage(series, create_point(ts1, 1.0, tags.clone())).unwrap();
        buffer.stage(series, create_point(ts2, 2.0, tags.clone())).unwrap();
        buffer.stage(series, create_point(ts3, 3.0, tags)).unwrap();

        let drained = buffer.drain_all_buffers();
        assert_eq!(drained.len(), 1, "Should have one series");
        let points = &drained[series];
        assert_eq!(points.len(), 3, "Should have three points");
        assert_eq!(points[0].timestamp, ts1);
        assert_eq!(points[0].value, 1.0);
        assert_eq!(points[1].timestamp, ts2);
        assert_eq!(points[1].value, 2.0);
        assert_eq!(points[2].timestamp, ts3);
        assert_eq!(points[2].value, 3.0);
    }

    /// Points staged under different series names are kept apart.
    #[test]
    fn test_stage_multiple_series() {
        let mut buffer = WriteBuffer::default();
        let series1 = "series1";
        let series2 = "series2";
        let ts1 = get_current_timestamp();
        let ts2 = get_current_timestamp() + 100;
        let tags1 = create_tags(&[("region", "us-east")]);
        let tags2 = create_tags(&[("region", "us-west")]);

        buffer.stage(series1, create_point(ts1, 1.0, tags1)).unwrap();
        buffer.stage(series2, create_point(ts2, 2.0, tags2)).unwrap();

        let drained = buffer.drain_all_buffers();
        assert_eq!(drained.len(), 2, "Should have two series");
        assert!(drained.contains_key(series1), "Should contain series1");
        assert!(drained.contains_key(series2), "Should contain series2");
        let points1 = &drained[series1];
        assert_eq!(points1.len(), 1, "Series1 should have one point");
        assert_eq!(points1[0].timestamp, ts1);
        assert_eq!(points1[0].value, 1.0);
        let points2 = &drained[series2];
        assert_eq!(points2.len(), 1, "Series2 should have one point");
        assert_eq!(points2[0].timestamp, ts2);
        assert_eq!(points2[0].value, 2.0);
    }

    /// Draining a buffer that never received a point yields an empty map.
    #[test]
    fn test_drain_empty_buffer() {
        let mut buffer = WriteBuffer::default();
        let drained = buffer.drain_all_buffers();
        assert_eq!(drained.len(), 0, "Drained data should be empty");
    }

    /// A second drain with no staging in between returns nothing.
    #[test]
    fn test_drain_leaves_buffers_empty() {
        let mut buffer = WriteBuffer::default();
        let series = "test_series";
        let ts = get_current_timestamp();
        let tags = create_tags(&[("host", "server1")]);
        buffer.stage(series, create_point(ts, 1.0, tags)).unwrap();

        let first_drain = buffer.drain_all_buffers();
        assert_eq!(first_drain.len(), 1, "First drain should contain our series");
        let second_drain = buffer.drain_all_buffers();
        assert_eq!(second_drain.len(), 0, "Second drain should be empty");
    }

    /// Several threads staging through a shared `Mutex<WriteBuffer>` lose no points.
    #[test]
    fn test_multithreaded_stage() {
        let buffer = Arc::new(Mutex::new(WriteBuffer::default()));
        let series = "multithreaded_series";
        let num_threads: usize = 4;
        let points_per_thread: usize = 25;
        let expected_total = num_threads * points_per_thread;

        let mut handles = vec![];
        for thread_id in 0..num_threads {
            let buffer_clone = Arc::clone(&buffer);
            let series_name = series.to_string();
            let handle = thread::spawn(move || {
                let thread_label = thread_id.to_string();
                for i in 0..points_per_thread {
                    let ts = get_current_timestamp() + (thread_id * 1000 + i) as u64;
                    let value = (thread_id * 100 + i) as f64;
                    let point_label = i.to_string();
                    let tags = create_tags(&[
                        ("thread_id", thread_label.as_str()),
                        ("point_id", point_label.as_str()),
                    ]);
                    let point = create_point(ts, value, tags);
                    // The guard is a temporary dropped at the end of this statement, so the
                    // lock is released before sleeping and other workers can interleave.
                    buffer_clone
                        .lock()
                        .unwrap()
                        .stage(&series_name, point)
                        .unwrap();
                    thread::sleep(Duration::from_nanos(1));
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }

        let drained = buffer.lock().unwrap().drain_all_buffers();
        assert_eq!(drained.len(), 1, "Should have one series");
        let points = &drained[series];
        assert_eq!(
            points.len(),
            expected_total,
            "Should have {} points",
            expected_total
        );
        for thread_id in 0..num_threads {
            let thread_label = thread_id.to_string();
            for i in 0..points_per_thread {
                let point_label = i.to_string();
                let found = points.iter().any(|p| {
                    p.tags.get("thread_id") == Some(&thread_label)
                        && p.tags.get("point_id") == Some(&point_label)
                });
                assert!(
                    found,
                    "Point with thread_id={}, point_id={} not found",
                    thread_id,
                    i
                );
            }
        }
    }

    /// Tag sets of different sizes (none, one, many) survive staging and draining intact.
    #[test]
    fn test_stage_with_different_tag_combinations() {
        let mut buffer = WriteBuffer::default();
        let series = "tag_test_series";
        let ts_base = get_current_timestamp();

        let no_tags = TagSet::new();
        buffer.stage(series, create_point(ts_base, 1.0, no_tags)).unwrap();

        let one_tag = create_tags(&[("region", "us-east")]);
        buffer.stage(series, create_point(ts_base + 1, 2.0, one_tag)).unwrap();

        let multi_tags = create_tags(&[
            ("region", "eu-west"),
            ("host", "server2"),
            ("service", "api"),
            ("version", "1.0"),
        ]);
        // Cloned because `multi_tags` is compared against the drained point below.
        buffer
            .stage(series, create_point(ts_base + 2, 3.0, multi_tags.clone()))
            .unwrap();

        let drained = buffer.drain_all_buffers();
        let points = &drained[series];
        assert_eq!(points.len(), 3, "Should have three points");
        assert_eq!(points[0].tags.len(), 0, "First point should have no tags");
        assert_eq!(points[1].tags.len(), 1, "Second point should have one tag");
        assert_eq!(points[1].tags.get("region"), Some(&"us-east".to_string()));
        assert_eq!(points[2].tags.len(), 4, "Third point should have four tags");
        assert_eq!(points[2].tags, multi_tags);
    }
}