use crate::traits::SizedDimension;
use ndarray::{Dimension, StrideShape};
use std::collections::HashMap;
use std::sync::RwLock;
use uuid::Uuid;
/// A pair of positions: one for the writer and one for the reader.
///
/// This is plain data. Nothing here enforces any relationship between the two
/// fields (e.g. that `read_ptr <= write_ptr`); that is left to whoever updates them.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct PipeState {
    /// Writer position.
    pub write_ptr: usize,
    /// Reader position.
    pub read_ptr: usize,
}
/// Holds an optional metadata value behind an `RwLock` so it can be read and
/// replaced through a shared reference.
///
/// The value starts out unset (`None`). [`MetadataManager::set`] overwrites it, and
/// [`MetadataManager::get`] returns a clone of the current value.
pub struct MetadataManager<M> {
    metadata: RwLock<Option<M>>,
}

impl<M> MetadataManager<M> {
    /// Creates a manager with no metadata stored.
    pub fn new() -> Self {
        Self {
            metadata: RwLock::new(None),
        }
    }

    /// Replaces the stored metadata with `metadata`, dropping any previous value.
    ///
    /// # Panics
    /// Panics if the lock was poisoned by a thread that panicked while holding it.
    pub fn set(&self, metadata: M) {
        *self
            .metadata
            .write()
            .expect("MetadataManager lock poisoned by a panicking thread") = Some(metadata);
    }

    /// Returns a clone of the stored metadata, or `None` if nothing has been set yet.
    ///
    /// # Panics
    /// Panics if the lock was poisoned by a thread that panicked while holding it.
    pub fn get(&self) -> Option<M>
    where
        M: Clone,
    {
        self.metadata
            .read()
            .expect("MetadataManager lock poisoned by a panicking thread")
            .clone()
    }
}

impl<M> Default for MetadataManager<M> {
    /// Same as [`MetadataManager::new`]: no metadata stored.
    fn default() -> Self {
        Self::new()
    }
}
/// Pairs an ndarray [`StrideShape`] with the same dimensions converted into the
/// dimension type's `CurrentSize` representation, and derives element counts from it.
pub struct ShapeManager<D: SizedDimension + Dimension> {
    /// Shape/stride description exactly as supplied to [`ShapeManager::new`].
    shape: StrideShape<D>,
    /// The dimensions of `shape`, converted with `D::from_array_view`.
    shape_tuple: D::CurrentSize,
}

impl<D: SizedDimension + Dimension> ShapeManager<D> {
    /// Builds a manager from anything convertible into a `StrideShape<D>`.
    pub fn new<Sh: Into<StrideShape<D>>>(shape_input: Sh) -> Self {
        let shape = Into::<StrideShape<D>>::into(shape_input);
        let dims = shape.raw_dim();
        let shape_tuple = D::from_array_view(dims.as_array_view());
        Self { shape_tuple, shape }
    }

    /// Number of scalars in one element, as reported by `StrideShape::size`.
    pub fn element_size(&self) -> usize {
        self.shape.size()
    }

    /// Delegates to `D::get_larger_array_size` with `n_elements` and a copy of the
    /// stored `CurrentSize` shape.
    pub fn get_larger_array_size(&self, n_elements: usize) -> D::LargerSize
    where
        D::CurrentSize: Clone,
    {
        let current = self.shape_tuple.clone();
        D::get_larger_array_size(n_elements, current)
    }

    /// Total scalar count for `n_elements` elements: `n_elements * element_size()`.
    pub fn total_scalars(&self, n_elements: usize) -> usize {
        let per_element = self.element_size();
        per_element * n_elements
    }
}
/// Tracks a read position for each registered reader, keyed by reader id.
///
/// Positions are stored as plain `usize` values. This type only records, advances
/// and removes them; it does not bound them against any buffer.
pub struct ReaderManager {
    read_ptrs: RwLock<HashMap<Uuid, usize>>,
}

impl ReaderManager {
    /// Creates a manager with no registered readers.
    pub fn new() -> Self {
        Self {
            read_ptrs: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `reader_id` at `start_position`.
    ///
    /// If the id is already registered, its position is overwritten.
    ///
    /// # Panics
    /// Panics if the lock was poisoned by a thread that panicked while holding it.
    pub fn register_reader(&self, reader_id: Uuid, start_position: usize) {
        self.read_ptrs
            .write()
            .expect("ReaderManager lock poisoned by a panicking thread")
            .insert(reader_id, start_position);
    }

    /// Returns the current position of `reader_id`, or `None` if it is not registered.
    ///
    /// # Panics
    /// Panics if the lock was poisoned by a thread that panicked while holding it.
    pub fn get_reader_position(&self, reader_id: Uuid) -> Option<usize> {
        self.read_ptrs
            .read()
            .expect("ReaderManager lock poisoned by a panicking thread")
            .get(&reader_id)
            .copied()
    }

    /// Moves `reader_id` forward by `n_to_consume`.
    ///
    /// Unregistered ids are ignored. The call is a no-op and does not register them.
    ///
    /// # Panics
    /// Panics if the lock was poisoned by a thread that panicked while holding it.
    pub fn advance_reader(&self, reader_id: Uuid, n_to_consume: usize) {
        let mut ptrs = self
            .read_ptrs
            .write()
            .expect("ReaderManager lock poisoned by a panicking thread");
        // `get_mut` rather than `entry`: we never insert here, and `entry` would
        // reserve capacity even for an unknown id.
        if let Some(ptr) = ptrs.get_mut(&reader_id) {
            *ptr += n_to_consume;
        }
    }

    /// Removes `reader_id`. Removing an unregistered id is a no-op.
    ///
    /// # Panics
    /// Panics if the lock was poisoned by a thread that panicked while holding it.
    pub fn unregister_reader(&self, reader_id: Uuid) {
        self.read_ptrs
            .write()
            .expect("ReaderManager lock poisoned by a panicking thread")
            .remove(&reader_id);
    }

    /// Returns the smallest `write_ptr - read_ptr` over all registered readers, or
    /// `None` when no readers are registered.
    ///
    /// A reader positioned past `write_ptr` counts as distance `0`. Previously the
    /// plain subtraction underflowed: it panicked in debug builds and wrapped to a
    /// huge value in release, which made `min` silently ignore that reader.
    ///
    /// NOTE(review): this reports the reader *closest* to the writer (smallest gap).
    /// A caller that needs the slowest reader's backlog would want the maximum
    /// instead. Confirm against callers.
    ///
    /// # Panics
    /// Panics if the lock was poisoned by a thread that panicked while holding it.
    pub fn get_min_distance_from(&self, write_ptr: usize) -> Option<usize> {
        self.read_ptrs
            .read()
            .expect("ReaderManager lock poisoned by a panicking thread")
            .values()
            .map(|&read_ptr| write_ptr.saturating_sub(read_ptr))
            .min()
    }
}

impl Default for ReaderManager {
    /// Same as [`ReaderManager::new`]: no registered readers.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Ix0;

    /// Smallest registered read position, read directly from the private map.
    fn get_min_position(rdr: &ReaderManager) -> Option<usize> {
        rdr.read_ptrs.read().unwrap().values().min().copied()
    }

    #[test]
    fn test_metadata_manager() {
        let manager = MetadataManager::<String>::new();
        assert_eq!(manager.get(), None);
        manager.set("test metadata".to_string());
        assert_eq!(manager.get(), Some("test metadata".to_string()));
        // A second `set` replaces the first value.
        manager.set("replaced".to_string());
        assert_eq!(manager.get(), Some("replaced".to_string()));
    }

    #[test]
    fn test_shape_manager_ix0() {
        let manager = ShapeManager::<Ix0>::new([]);
        assert_eq!(manager.element_size(), 1);
        assert_eq!(manager.total_scalars(10), 10);
        assert_eq!(manager.get_larger_array_size(5), 5);
    }

    #[test]
    fn test_reader_manager() {
        let manager = ReaderManager::new();
        let reader1 = Uuid::new_v4();
        let reader2 = Uuid::new_v4();
        manager.register_reader(reader1, 0);
        manager.register_reader(reader2, 10);
        assert_eq!(manager.get_reader_position(reader1), Some(0));
        assert_eq!(manager.get_reader_position(reader2), Some(10));
        manager.advance_reader(reader1, 5);
        assert_eq!(manager.get_reader_position(reader1), Some(5));
        assert_eq!(get_min_position(&manager), Some(5));
        manager.unregister_reader(reader1);
        assert_eq!(manager.get_reader_position(reader1), None);
        assert_eq!(get_min_position(&manager), Some(10));
    }

    #[test]
    fn test_advance_unknown_reader_is_noop() {
        let manager = ReaderManager::new();
        let ghost = Uuid::new_v4();
        manager.advance_reader(ghost, 3);
        assert_eq!(manager.get_reader_position(ghost), None);
        assert_eq!(get_min_position(&manager), None);
    }

    #[test]
    fn test_min_distance_from() {
        let manager = ReaderManager::new();
        assert_eq!(manager.get_min_distance_from(7), None);
        let reader1 = Uuid::new_v4();
        let reader2 = Uuid::new_v4();
        manager.register_reader(reader1, 2);
        manager.register_reader(reader2, 5);
        // Distances are 5 and 2; the smaller one is reported.
        assert_eq!(manager.get_min_distance_from(7), Some(2));
        manager.advance_reader(reader1, 5);
        assert_eq!(manager.get_min_distance_from(7), Some(0));
    }
}