//! A concurrent, fixed-capacity message store. A lock-free free list (a
//! Treiber stack with an ABA-guarding generation tag) hands out slot
//! indices, and each slot's `RwLock` serializes access to the message
//! stored there.

use core::ops::{Deref, DerefMut};
use std::sync::{
atomic::{AtomicU32, AtomicU64, AtomicUsize, Ordering},
Arc, RwLock, RwLockReadGuard, RwLockWriteGuard,
};
use crate::errors::MemoryError;
use crate::memory::header_store::HeaderStore;
use crate::memory::manager::MemoryManager;
use crate::memory::MemoryClass;
use crate::message::payload::Payload;
use crate::message::{Message, MessageHeader};
use crate::prelude::ScopedManager;
use crate::types::MessageToken;
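/// Sentinel meaning "no slot" in the free-list head and links. `u32::MAX`
/// never names a real slot: valid indices stay strictly below the asserted
/// capacity bound.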
const EMPTY_INDEX: u32 = u32::MAX;
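/// Packs the free-list head into one word: the generation tag in the high 32
/// bits, the slot index in the low 32 bits. The tag is bumped on every
/// successful swap so a 64-bit CAS can detect the ABA problem.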
#[inline]
fn pack_head(tag: u32, idx: u32) -> u64 {
((tag as u64) << 32) | (idx as u64)
}
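/// Inverse of [`pack_head`]: splits the head word into `(tag, index)`.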
#[inline]
fn unpack_head(v: u64) -> (u32, u32) {
((v >> 32) as u32, (v & 0xffff_ffff) as u32)
}
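/// Interior state of one slot: `None` while the slot sits on the free list,
/// `Some` while a message occupies it.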
struct ConcurrentSlotState<P: Payload> {
message: Option<Message<P>>,
}
impl<P: Payload> ConcurrentSlotState<P> {
fn new() -> Self {
Self { message: None }
}
}
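/// A single storage slot guarded by its own `RwLock`, so readers of one slot
/// never contend with writers of another.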
struct ConcurrentSlot<P: Payload> {
state: RwLock<ConcurrentSlotState<P>>,
}
impl<P: Payload> ConcurrentSlot<P> {
fn new() -> Self {
Self {
state: RwLock::new(ConcurrentSlotState::new()),
}
}
}
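/// State shared by every clone of a manager: the slot array plus a lock-free
/// (Treiber-stack) free list. `next_free[i]` holds the index that follows `i`
/// on the free list, `free_head` is the tagged head word, and
/// `available_count` tracks the number of free slots.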
struct ConcurrentMemoryManagerShared<P: Payload> {
slots: Vec<ConcurrentSlot<P>>,
next_free: Vec<AtomicU32>,
free_head: AtomicU64,
available_count: AtomicUsize,
mem_class: MemoryClass,
}
impl<P: Payload> ConcurrentMemoryManagerShared<P> {
fn new(mem_class: MemoryClass, capacity: usize) -> Self {
        assert!(
            capacity <= u32::MAX as usize,
            "capacity must fit in a u32 slot index"
        );
let mut slots = Vec::with_capacity(capacity);
for _ in 0..capacity {
slots.push(ConcurrentSlot::new());
}
let mut next_free = Vec::with_capacity(capacity);
for i in 0..capacity {
let next = if i + 1 < capacity {
(i + 1) as u32
} else {
EMPTY_INDEX
};
next_free.push(AtomicU32::new(next));
}
let head_index = if capacity > 0 { 0 } else { EMPTY_INDEX };
let free_head = AtomicU64::new(pack_head(0, head_index));
let available_count = AtomicUsize::new(capacity);
Self {
slots,
next_free,
free_head,
available_count,
mem_class,
}
}
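    /// Pops a slot index off the lock-free free list, or returns `None` when
    /// the list is empty. Retries on CAS failure, yielding occasionally.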
fn pop_free(&self) -> Option<usize> {
let mut spins = 0u32;
loop {
let head = self.free_head.load(Ordering::Acquire);
let (tag, idx) = unpack_head(head);
if idx == EMPTY_INDEX {
return None;
}
            // Read the successor before attempting the swap. If another
            // thread pops this node in the meantime, the tag embedded in
            // `head` changes, the CAS below fails, and the stale `next` is
            // discarded instead of installed (the classic ABA hazard).
            let next = self.next_free[idx as usize].load(Ordering::Acquire);
            let new_tag = tag.wrapping_add(1);
            let new_head = pack_head(new_tag, next);
if self
.free_head
.compare_exchange(head, new_head, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
self.available_count.fetch_sub(1, Ordering::AcqRel);
return Some(idx as usize);
}
            spins = spins.wrapping_add(1);
            if spins & 0xFF == 0 {
                // Under sustained contention, yield periodically instead of
                // spinning hot on the CAS retry loop.
                std::thread::yield_now();
            }
}
}
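    /// Pushes a slot index back onto the free list, retrying on CAS failure.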
fn push_free(&self, idx: usize) {
let mut spins = 0u32;
loop {
let head = self.free_head.load(Ordering::Acquire);
let (tag, head_idx) = unpack_head(head);
            // Link the node to the current head before publishing it; if the
            // CAS below fails, this store is simply redone on the next retry.
            self.next_free[idx].store(head_idx, Ordering::Release);
let new_tag = tag.wrapping_add(1);
let new_head = pack_head(new_tag, idx as u32);
if self
.free_head
.compare_exchange(head, new_head, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
self.available_count.fetch_add(1, Ordering::AcqRel);
return;
}
spins = spins.wrapping_add(1);
if spins & 0xFF == 0 {
std::thread::yield_now();
}
}
}
}
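/// A fixed-capacity, clonable, thread-safe store for [`Message`]s.
///
/// All clones share the same slot array through an [`Arc`], so a clone can
/// be handed to another thread as a cheap handle. Slot allocation goes
/// through the lock-free free list; access to a stored message goes through
/// that slot's `RwLock`.
///
/// A minimal usage sketch (marked `ignore` because it leans on a
/// hypothetical `MyPayload` type, message value, and `inspect` helper):
///
/// ```ignore
/// let mgr: ConcurrentMemoryManager<MyPayload> = ConcurrentMemoryManager::new(8);
/// let token = mgr.store_shared(msg)?;
/// {
///     let guard = mgr.read_shared(token)?; // shared access
///     inspect(guard.payload());
/// } // guard dropped: slot lock released
/// mgr.free_shared(token)?; // slot returns to the free list
/// ```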
pub struct ConcurrentMemoryManager<P: Payload> {
shared: Arc<ConcurrentMemoryManagerShared<P>>,
}
impl<P: Payload> Clone for ConcurrentMemoryManager<P> {
fn clone(&self) -> Self {
Self {
shared: Arc::clone(&self.shared),
}
}
}
impl<P: Payload> ConcurrentMemoryManager<P> {
pub fn new(capacity: usize) -> Self {
Self::with_memory_class(capacity, MemoryClass::Host)
}
pub fn with_memory_class(capacity: usize, mem_class: MemoryClass) -> Self {
let shared = ConcurrentMemoryManagerShared::new(mem_class, capacity);
Self {
shared: Arc::new(shared),
}
}
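    /// Stores a message into a free slot and returns a token for later
    /// access. Fails with `NoFreeSlots` when the pool is exhausted. Takes
    /// `&self`, so any clone of the manager may store concurrently.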
pub fn store_shared(&self, value: Message<P>) -> Result<MessageToken, MemoryError> {
let idx = match self.shared.pop_free() {
None => return Err(MemoryError::NoFreeSlots),
Some(i) => i,
};
        let slot = &self.shared.slots[idx];
        // The index was just popped from the free list, so no other thread
        // can race to store here; the write lock only orders this store
        // against stragglers still holding guards from the slot's last use.
        let mut guard = slot.state.write().map_err(|_| MemoryError::Poisoned)?;
        guard.message = Some(value);
        drop(guard);
Ok(MessageToken::new(idx as u32))
}
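    /// Acquires shared read access to the message behind `token`. Multiple
    /// readers of the same slot may hold guards at once.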
pub fn read_shared(
&self,
token: MessageToken,
) -> Result<ConcurrentReadGuard<'_, P>, MemoryError> {
let idx = token.index();
if idx >= self.shared.slots.len() {
return Err(MemoryError::BadToken);
}
let slot = &self.shared.slots[idx];
let guard = slot.state.read().map_err(|_| MemoryError::Poisoned)?;
if guard.message.is_none() {
return Err(MemoryError::NotAllocated);
}
Ok(ConcurrentReadGuard { guard })
}
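    /// Acquires exclusive write access to the message behind `token`,
    /// blocking until every other guard on that slot has been dropped.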
pub fn read_mut_shared(
&self,
token: MessageToken,
) -> Result<ConcurrentWriteGuard<'_, P>, MemoryError> {
let idx = token.index();
if idx >= self.shared.slots.len() {
return Err(MemoryError::BadToken);
}
let slot = &self.shared.slots[idx];
let guard = slot.state.write().map_err(|_| MemoryError::Poisoned)?;
if guard.message.is_none() {
return Err(MemoryError::NotAllocated);
}
Ok(ConcurrentWriteGuard { guard })
}
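    /// Drops the message behind `token` and returns its slot to the free
    /// list. Freeing the same token twice reports `NotAllocated`.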
pub fn free_shared(&self, token: MessageToken) -> Result<(), MemoryError> {
let idx = token.index();
if idx >= self.shared.slots.len() {
return Err(MemoryError::BadToken);
}
let slot = &self.shared.slots[idx];
let mut guard = slot.state.write().map_err(|_| MemoryError::Poisoned)?;
if guard.message.is_none() {
return Err(MemoryError::NotAllocated);
}
guard.message = None;
drop(guard);
self.shared.push_free(idx);
Ok(())
}
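    /// Approximate count of free slots; under concurrent stores and frees the
    /// value may briefly lag the true count.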
pub fn available(&self) -> usize {
self.shared.available_count.load(Ordering::Relaxed)
}
pub fn capacity(&self) -> usize {
self.shared.slots.len()
}
pub fn memory_class(&self) -> MemoryClass {
self.shared.mem_class
}
}
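/// Read guard exposing only the [`MessageHeader`] of a stored message.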
pub struct ConcurrentHeaderGuard<'a, P: Payload> {
guard: RwLockReadGuard<'a, ConcurrentSlotState<P>>,
}
impl<'a, P: Payload> Deref for ConcurrentHeaderGuard<'a, P> {
type Target = MessageHeader;
fn deref(&self) -> &Self::Target {
self.guard
.message
.as_ref()
.expect("header guard constructed only when Some")
.header()
}
}
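/// Shared read guard dereferencing to the whole stored [`Message`].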
pub struct ConcurrentReadGuard<'a, P: Payload> {
guard: RwLockReadGuard<'a, ConcurrentSlotState<P>>,
}
impl<'a, P: Payload> Deref for ConcurrentReadGuard<'a, P> {
type Target = Message<P>;
fn deref(&self) -> &Self::Target {
self.guard
.message
.as_ref()
.expect("read guard constructed only when Some")
}
}
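/// Exclusive write guard; derefs immutably and mutably to the stored
/// [`Message`].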
pub struct ConcurrentWriteGuard<'a, P: Payload> {
guard: RwLockWriteGuard<'a, ConcurrentSlotState<P>>,
}
impl<'a, P: Payload> Deref for ConcurrentWriteGuard<'a, P> {
type Target = Message<P>;
fn deref(&self) -> &Self::Target {
self.guard
.message
.as_ref()
.expect("write guard constructed only when Some")
}
}
impl<'a, P: Payload> DerefMut for ConcurrentWriteGuard<'a, P> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.guard
.message
.as_mut()
.expect("write guard constructed only when Some")
}
}
impl<P: Payload> HeaderStore for ConcurrentMemoryManager<P> {
type HeaderGuard<'a>
= ConcurrentHeaderGuard<'a, P>
where
Self: 'a;
fn peek_header(&self, token: MessageToken) -> Result<Self::HeaderGuard<'_>, MemoryError> {
let idx = token.index();
if idx >= self.shared.slots.len() {
return Err(MemoryError::BadToken);
}
let slot = &self.shared.slots[idx];
let guard = slot.state.read().map_err(|_| MemoryError::Poisoned)?;
if guard.message.is_none() {
return Err(MemoryError::NotAllocated);
}
Ok(ConcurrentHeaderGuard { guard })
}
}
impl<P: Payload> MemoryManager<P> for ConcurrentMemoryManager<P> {
type ReadGuard<'a>
= ConcurrentReadGuard<'a, P>
where
Self: 'a;
type WriteGuard<'a>
= ConcurrentWriteGuard<'a, P>
where
Self: 'a;
fn store(&mut self, value: Message<P>) -> Result<MessageToken, MemoryError> {
self.store_shared(value)
}
fn read(&self, token: MessageToken) -> Result<Self::ReadGuard<'_>, MemoryError> {
self.read_shared(token)
}
fn read_mut(&mut self, token: MessageToken) -> Result<Self::WriteGuard<'_>, MemoryError> {
self.read_mut_shared(token)
}
fn free(&mut self, token: MessageToken) -> Result<(), MemoryError> {
self.free_shared(token)
}
    // The following three delegate to the inherent methods of the same name;
    // inherent methods take lookup priority over trait methods, so the calls
    // do not recurse.
    fn available(&self) -> usize {
        self.available()
    }
    fn capacity(&self) -> usize {
        self.capacity()
    }
    fn memory_class(&self) -> MemoryClass {
        self.memory_class()
    }
}
impl<P: Payload + Send + Sync> ScopedManager<P> for ConcurrentMemoryManager<P> {
type Handle<'a>
= ConcurrentMemoryManager<P>
where
Self: 'a;
fn scoped_handle<'a>(&'a self) -> Self::Handle<'a>
where
Self: 'a,
{
self.clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::message::MessageHeader;
use crate::prelude::{create_test_tensor_filled_with, TestTensor, TEST_TENSOR_BYTE_COUNT};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
fn make_msg(val: u32) -> Message<TestTensor> {
Message::new(MessageHeader::empty(), create_test_tensor_filled_with(val))
}
#[test]
fn basic_store_read_free() {
let mgr: ConcurrentMemoryManager<TestTensor> = ConcurrentMemoryManager::new(4);
let t = mgr.store_shared(make_msg(10)).unwrap();
assert_eq!(mgr.available(), 3);
{
let g = mgr.read_shared(t).unwrap();
assert_eq!(*g.payload(), create_test_tensor_filled_with(10));
}
mgr.free_shared(t).unwrap();
assert_eq!(mgr.available(), 4);
}
#[test]
fn concurrent_reads_same_slot() {
let mgr = Arc::new(ConcurrentMemoryManager::<TestTensor>::new(4));
let t = mgr.store_shared(make_msg(5)).unwrap();
let m1 = mgr.clone();
let th1 = thread::spawn(move || {
let g = m1.read_shared(t).unwrap();
assert_eq!(*g.payload(), create_test_tensor_filled_with(5));
});
let m2 = mgr.clone();
let th2 = thread::spawn(move || {
let g = m2.read_shared(t).unwrap();
assert_eq!(*g.payload(), create_test_tensor_filled_with(5));
});
th1.join().unwrap();
th2.join().unwrap();
mgr.free_shared(t).unwrap();
}
#[test]
fn write_excludes_read() {
use std::sync::Barrier;
let mgr = Arc::new(ConcurrentMemoryManager::<TestTensor>::new(4));
let t = mgr.store_shared(make_msg(7)).unwrap();
let barrier = Arc::new(Barrier::new(2));
let mwriter = mgr.clone();
let bwriter = barrier.clone();
        let writer = thread::spawn(move || {
            let mut w = mwriter.read_mut_shared(t).unwrap();
            *w.payload_mut() = create_test_tensor_filled_with(42);
            bwriter.wait();
            // Keep the write guard alive across the sleep so the reader
            // below is forced to block until the writer finishes.
            std::thread::sleep(Duration::from_millis(50));
        });
        barrier.wait();
        // Blocks until the writer drops its guard, so the updated payload is
        // always observed.
        let g = mgr.read_shared(t).unwrap();
assert_eq!(*g.payload(), create_test_tensor_filled_with(42));
writer.join().unwrap();
}
#[test]
fn allocate_exhaustion_and_reuse() {
let mgr = ConcurrentMemoryManager::<TestTensor>::new(2);
let t0 = mgr.store_shared(make_msg(1)).unwrap();
let t1 = mgr.store_shared(make_msg(2)).unwrap();
assert_eq!(mgr.available(), 0);
assert!(matches!(
mgr.store_shared(make_msg(3)),
Err(MemoryError::NoFreeSlots)
));
mgr.free_shared(t0).unwrap();
assert_eq!(mgr.available(), 1);
let t2 = mgr.store_shared(make_msg(4)).unwrap();
assert_eq!(t2.index(), t0.index());
mgr.free_shared(t1).unwrap();
mgr.free_shared(t2).unwrap();
}
#[test]
fn store_read_free_cycle() {
let mgr: ConcurrentMemoryManager<TestTensor> = ConcurrentMemoryManager::new(4);
assert_eq!(mgr.available(), 4);
assert_eq!(mgr.capacity(), 4);
let token = mgr.store_shared(make_msg(42)).unwrap();
assert_eq!(mgr.available(), 3);
{
let msg = mgr.read_shared(token).unwrap();
assert_eq!(*msg.payload(), create_test_tensor_filled_with(42));
}
mgr.free_shared(token).unwrap();
assert_eq!(mgr.available(), 4);
}
#[test]
fn read_mut_works() {
let mgr: ConcurrentMemoryManager<TestTensor> = ConcurrentMemoryManager::new(4);
let token = mgr.store_shared(make_msg(10)).unwrap();
{
let mut msg = mgr.read_mut_shared(token).unwrap();
*msg.payload_mut() = create_test_tensor_filled_with(99);
}
{
let msg = mgr.read_shared(token).unwrap();
assert_eq!(*msg.payload(), create_test_tensor_filled_with(99));
}
mgr.free_shared(token).unwrap();
}
#[test]
fn peek_header_works() {
let mgr: ConcurrentMemoryManager<TestTensor> = ConcurrentMemoryManager::new(4);
let token = mgr.store_shared(make_msg(7)).unwrap();
{
let header = mgr.peek_header(token).unwrap();
assert_eq!(*header.payload_size_bytes(), TEST_TENSOR_BYTE_COUNT);
}
mgr.free_shared(token).unwrap();
}
#[test]
fn capacity_exhaustion() {
let mgr: ConcurrentMemoryManager<TestTensor> = ConcurrentMemoryManager::new(2);
let _t0 = mgr.store_shared(make_msg(1)).unwrap();
let _t1 = mgr.store_shared(make_msg(2)).unwrap();
assert_eq!(mgr.available(), 0);
let err = mgr.store_shared(make_msg(3));
assert_eq!(err, Err(MemoryError::NoFreeSlots));
}
#[test]
fn double_free_detected() {
let mgr: ConcurrentMemoryManager<TestTensor> = ConcurrentMemoryManager::new(4);
let token = mgr.store_shared(make_msg(1)).unwrap();
mgr.free_shared(token).unwrap();
let err = mgr.free_shared(token);
assert_eq!(err, Err(MemoryError::NotAllocated));
}
#[test]
fn bad_token_detected() {
let mgr: ConcurrentMemoryManager<TestTensor> = ConcurrentMemoryManager::new(4);
let bad = MessageToken::new(99);
assert!(matches!(mgr.read_shared(bad), Err(MemoryError::BadToken)));
assert!(matches!(mgr.peek_header(bad), Err(MemoryError::BadToken)));
}
#[test]
    fn read_freed_slot_is_not_allocated() {
let mgr: ConcurrentMemoryManager<TestTensor> = ConcurrentMemoryManager::new(4);
let token = mgr.store_shared(make_msg(1)).unwrap();
mgr.free_shared(token).unwrap();
assert!(matches!(
mgr.read_shared(token),
Err(MemoryError::NotAllocated)
));
assert!(matches!(
mgr.peek_header(token),
Err(MemoryError::NotAllocated)
));
}
#[test]
fn slot_reuse_after_free() {
let mgr: ConcurrentMemoryManager<TestTensor> = ConcurrentMemoryManager::new(1);
let t0 = mgr.store_shared(make_msg(10)).unwrap();
mgr.free_shared(t0).unwrap();
let t1 = mgr.store_shared(make_msg(20)).unwrap();
assert_eq!(t1.index(), 0);
assert_eq!(
*mgr.read_shared(t1).unwrap().payload(),
create_test_tensor_filled_with(20)
);
}
#[test]
fn memory_class_configurable() {
let mgr: ConcurrentMemoryManager<TestTensor> =
ConcurrentMemoryManager::with_memory_class(4, MemoryClass::Device(0));
assert_eq!(mgr.memory_class(), MemoryClass::Device(0));
}
#[test]
fn default_memory_class_is_host() {
let mgr: ConcurrentMemoryManager<TestTensor> = ConcurrentMemoryManager::new(4);
assert_eq!(mgr.memory_class(), MemoryClass::Host);
}
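    // A small stress sketch (thread and iteration counts are arbitrary
    // assumptions): threads race store/read/free cycles through clones of
    // one manager, and every slot must be back on the free list at the end.
    #[test]
    fn concurrent_store_free_stress() {
        const THREADS: usize = 4;
        const ITERS: usize = 200;
        let mgr = Arc::new(ConcurrentMemoryManager::<TestTensor>::new(THREADS));
        let mut handles = Vec::new();
        for t in 0..THREADS {
            let m = mgr.clone();
            handles.push(thread::spawn(move || {
                for i in 0..ITERS {
                    let val = (t * ITERS + i) as u32;
                    // One slot per thread, so a free slot is always found.
                    let token = m.store_shared(make_msg(val)).unwrap();
                    {
                        let g = m.read_shared(token).unwrap();
                        assert_eq!(*g.payload(), create_test_tensor_filled_with(val));
                    }
                    m.free_shared(token).unwrap();
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(mgr.available(), THREADS);
    }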
}