use std::sync::{Mutex, Arc, Weak, OnceLock};
use std::collections::{VecDeque, HashMap};
use std::marker::PhantomData;
use std::cell::{RefCell, UnsafeCell};
use std::any::{TypeId, Any};
use std::sync::atomic::{AtomicU32, Ordering};
use crate::error::{Result, ZiporaError};
/// Default number of rows in each thread's instance matrix (rows are allocated on demand).
const DEFAULT_ROWS: usize = 1 << 8;
/// Default number of slots per row; together with `DEFAULT_ROWS` this allows up to 65,536 ids
/// per `(T, ROWS, COLS)` combination.
const DEFAULT_COLS: usize = 1 << 8;
/// Handle to one logical thread-local variable of type `T`.
///
/// The handle owns an id that addresses a cell in a per-thread `ROWS x COLS` matrix, so every
/// thread observes its own independent value for the same handle. Clones share the id and
/// therefore read and write the same per-thread value.
pub struct InstanceTls<
    T: Send + Sync + 'static,
    const ROWS: usize = DEFAULT_ROWS,
    const COLS: usize = DEFAULT_COLS,
> {
    id: u32,
    _phantom: PhantomData<T>,
    _cleanup: Arc<CleanupHandle<T, ROWS, COLS>>,
}
/// Per-thread backing store for every `InstanceTls<T, ROWS, COLS>` of one concrete type.
///
/// Rows are boxed and allocated on first write, so a thread pays for `ROWS` row pointers plus
/// only the rows it actually touches. Each occupied cell records the generation of the
/// instance that wrote it: ids are recycled once an instance is dropped, but other threads'
/// matrices cannot be cleared at that moment, so a value left behind by a previous owner of an
/// id must be recognised as stale instead of being handed to the new owner.
struct TlsMatrix<T, const ROWS: usize, const COLS: usize> {
    rows: [Option<Box<[Option<(u32, T)>; COLS]>>; ROWS],
}

/// Process-wide id allocator shared by every instance of one `(T, ROWS, COLS)` combination.
#[derive(Default)]
struct GlobalTlsState {
    /// Released ids, each paired with the generation its next owner will use (FIFO reuse).
    free_ids: VecDeque<(u32, u32)>,
    /// Smallest id that has never been handed out.
    next_id: u32,
}

/// Shared by all clones of one `InstanceTls`; returns the id to the allocator when dropped.
struct CleanupHandle<T, const ROWS: usize, const COLS: usize>
where
    T: Send + Sync + 'static,
{
    id: u32,
    /// Tells this owner of `id` apart from earlier and later owners of the same id.
    generation: u32,
    _phantom: PhantomData<T>,
}

thread_local! {
    /// This thread's matrices, one per `TlsMatrix<T, ROWS, COLS>` type, keyed by its `TypeId`.
    static TLS_MATRICES: RefCell<HashMap<TypeId, Box<dyn Any>>> = RefCell::new(HashMap::new());
}

/// Map from [`state_key`] to a boxed `Mutex<GlobalTlsState>`.
type TlsRegistry = std::sync::RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>;

/// Returns the single registry used by both id allocation and id release.
///
/// Both sides must go through one named `static`: two function-local statics are two distinct
/// registries, and ids released into one would never be seen by the other.
fn tls_registry() -> &'static TlsRegistry {
    static REGISTRY: OnceLock<TlsRegistry> = OnceLock::new();
    REGISTRY.get_or_init(|| std::sync::RwLock::new(HashMap::new()))
}

/// Registry key for the allocator state of one `(T, ROWS, COLS)` combination.
fn state_key<T: 'static, const ROWS: usize, const COLS: usize>() -> TypeId {
    TypeId::of::<(T, [(); ROWS], [(); COLS])>()
}

impl GlobalTlsState {
    /// Hands out an `(id, generation)` pair, reusing released ids before minting new ones.
    ///
    /// # Errors
    /// Returns `resource_exhausted` once `capacity` ids are simultaneously in use.
    fn allocate(&mut self, capacity: usize) -> Result<(u32, u32)> {
        if let Some(slot) = self.free_ids.pop_front() {
            return Ok(slot);
        }
        let id = self.next_id;
        if id as usize >= capacity {
            return Err(ZiporaError::resource_exhausted(&format!(
                "Too many TLS instances: max {}",
                capacity
            )));
        }
        // The caller caps `capacity` at `u32::MAX`, so `id < u32::MAX` and this cannot overflow.
        self.next_id += 1;
        Ok((id, 0))
    }

    /// Makes `id` available again; its next owner gets the following generation.
    ///
    /// Generations wrap, so a stale value could only be misattributed if the same id were
    /// recycled 2^32 times while some thread never touched the old value.
    fn release(&mut self, id: u32, generation: u32) {
        self.free_ids.push_back((id, generation.wrapping_add(1)));
    }
}

impl<T, const ROWS: usize, const COLS: usize> InstanceTls<T, ROWS, COLS>
where
    T: Send + Sync + Default + Clone + 'static,
{
    /// Creates a new instance with its own id (and therefore its own per-thread values).
    ///
    /// # Errors
    /// - `invalid_parameter` if `ROWS` or `COLS` is zero.
    /// - `resource_exhausted` if every id for this `(T, ROWS, COLS)` is currently in use.
    /// - `system_error` if a registry lock is poisoned.
    pub fn new() -> Result<Self> {
        if ROWS == 0 || COLS == 0 {
            return Err(ZiporaError::invalid_parameter("Matrix dimensions cannot be zero"));
        }
        let (id, generation) = Self::allocate_slot()?;
        Ok(Self {
            id,
            _phantom: PhantomData,
            _cleanup: Arc::new(CleanupHandle {
                id,
                generation,
                _phantom: PhantomData,
            }),
        })
    }

    /// Returns a clone of this thread's value, first storing `T::default()` if this thread
    /// has no value for this instance yet.
    ///
    /// # Panics
    /// Panics if invoked re-entrantly (e.g. from `T::default` or `T::clone`) while this
    /// thread's matrix map is already borrowed.
    #[inline]
    pub fn get(&self) -> T
    where
        T: Clone,
    {
        let (row, col) = self.get_indices();
        let generation = self.generation();
        Self::with_matrix_mut(|matrix| matrix.get_or_create_value(row, col, generation))
    }

    /// Returns a clone of this thread's value, or `None` if this thread never set one.
    #[inline]
    pub fn get_value(&self) -> Option<T>
    where
        T: Clone,
    {
        let (row, col) = self.get_indices();
        let generation = self.generation();
        TLS_MATRICES.with(|matrices| {
            let matrices = matrices.borrow();
            let matrix = matrices
                .get(&TypeId::of::<TlsMatrix<T, ROWS, COLS>>())?
                .downcast_ref::<TlsMatrix<T, ROWS, COLS>>()?;
            matrix.get_value(row, col, generation)
        })
    }

    /// Same as [`Self::get_value`]: never creates a value.
    #[inline]
    pub fn try_get(&self) -> Option<T>
    where
        T: Clone,
    {
        self.get_value()
    }

    /// Stores `value` as this thread's value for this instance.
    pub fn set(&self, value: T) {
        let (row, col) = self.get_indices();
        let generation = self.generation();
        let previous = Self::with_matrix_mut(|matrix| matrix.set(row, col, generation, value));
        // Run the replaced value's destructor only after the thread-local borrow is released,
        // so a destructor that touches TLS does not hit a RefCell borrow panic.
        drop(previous);
    }

    /// Removes and returns this thread's value for this instance, if any.
    pub fn remove(&self) -> Option<T> {
        let (row, col) = self.get_indices();
        let generation = self.generation();
        TLS_MATRICES.with(|matrices| {
            let mut matrices = matrices.borrow_mut();
            matrices
                .get_mut(&TypeId::of::<TlsMatrix<T, ROWS, COLS>>())?
                .downcast_mut::<TlsMatrix<T, ROWS, COLS>>()?
                .remove(row, col, generation)
        })
    }

    /// Returns the matrix index assigned to this instance; ids of dropped instances may be
    /// reassigned to new instances.
    pub fn id(&self) -> u32 {
        self.id
    }

    /// Splits the id into `(row, column)` coordinates within the matrix.
    #[inline]
    fn get_indices(&self) -> (usize, usize) {
        let id = self.id as usize;
        (id / COLS, id % COLS)
    }

    /// Generation that distinguishes this instance from other owners of the same id.
    #[inline]
    fn generation(&self) -> u32 {
        self._cleanup.generation
    }

    /// Runs `f` on this thread's matrix for `(T, ROWS, COLS)`, creating the matrix on first use.
    fn with_matrix_mut<R>(f: impl FnOnce(&mut TlsMatrix<T, ROWS, COLS>) -> R) -> R {
        TLS_MATRICES.with(|matrices| {
            let mut matrices = matrices.borrow_mut();
            let matrix = matrices
                .entry(TypeId::of::<TlsMatrix<T, ROWS, COLS>>())
                .or_insert_with(|| Box::new(TlsMatrix::<T, ROWS, COLS>::new()))
                .downcast_mut::<TlsMatrix<T, ROWS, COLS>>()
                .expect("TLS matrix type mismatch");
            f(matrix)
        })
    }

    /// Reserves an `(id, generation)` pair from the process-wide allocator for this type.
    fn allocate_slot() -> Result<(u32, u32)> {
        // Never hand out more ids than the matrix can index or than fit in a `u32`.
        let capacity = ROWS.saturating_mul(COLS).min(u32::MAX as usize);
        let key = state_key::<T, ROWS, COLS>();
        let registry = tls_registry();

        // Fast path: the state for this type already exists, so a shared lock suffices.
        {
            let read_guard = registry.read().map_err(|_| {
                ZiporaError::system_error("Failed to acquire TLS registry read lock")
            })?;
            if let Some(state) = read_guard
                .get(&key)
                .and_then(|entry| entry.downcast_ref::<Mutex<GlobalTlsState>>())
            {
                let mut state = state.lock().map_err(|_| {
                    ZiporaError::system_error("Failed to acquire TLS state lock")
                })?;
                return state.allocate(capacity);
            }
        }

        // Slow path: first instance of this type; create its state under the write lock.
        // `entry` tolerates another thread having inserted it since the read lock was dropped.
        let mut write_guard = registry.write().map_err(|_| {
            ZiporaError::system_error("Failed to acquire TLS registry write lock")
        })?;
        let state = write_guard
            .entry(key)
            .or_insert_with(|| Box::new(Mutex::new(GlobalTlsState::default())))
            .downcast_ref::<Mutex<GlobalTlsState>>()
            .ok_or_else(|| ZiporaError::system_error("Type downcast failed"))?;
        let mut state = state.lock().map_err(|_| {
            ZiporaError::system_error("Failed to acquire TLS state lock")
        })?;
        state.allocate(capacity)
    }
}

/// Clones share the id and the cleanup handle, so every clone reads and writes the same
/// per-thread value; the id is released only when the last clone is dropped.
impl<T, const ROWS: usize, const COLS: usize> Clone for InstanceTls<T, ROWS, COLS>
where
    T: Send + Sync + 'static,
{
    fn clone(&self) -> Self {
        Self {
            id: self.id,
            _phantom: PhantomData,
            _cleanup: Arc::clone(&self._cleanup),
        }
    }
}

impl<T, const ROWS: usize, const COLS: usize> TlsMatrix<T, ROWS, COLS> {
    /// Creates a matrix with no rows allocated.
    fn new() -> Self {
        Self {
            rows: std::array::from_fn(|_| None),
        }
    }

    /// Returns row `row`, allocating it (all cells empty) on first use.
    fn row_mut(&mut self, row: usize) -> &mut [Option<(u32, T)>; COLS] {
        self.rows[row].get_or_insert_with(|| Box::new(std::array::from_fn(|_| None)))
    }

    /// Returns a clone of the cell's value for `generation`, first storing `T::default()` if
    /// the cell is empty or holds a stale value from another generation.
    fn get_or_create_value(&mut self, row: usize, col: usize, generation: u32) -> T
    where
        T: Default + Clone,
    {
        let cell = &mut self.row_mut(row)[col];
        if !cell.as_ref().is_some_and(|(owner, _)| *owner == generation) {
            *cell = Some((generation, T::default()));
        }
        cell.as_ref()
            .map(|(_, value)| value.clone())
            .expect("cell holds a value for this generation")
    }

    /// Returns a clone of the cell's value if it was written by `generation`.
    fn get_value(&self, row: usize, col: usize, generation: u32) -> Option<T>
    where
        T: Clone,
    {
        let row_data = self.rows[row].as_ref()?;
        match &row_data[col] {
            Some((owner, value)) if *owner == generation => Some(value.clone()),
            _ => None,
        }
    }

    /// Stores `value` for `generation` and returns whatever the cell previously held
    /// (possibly a stale value from another generation) so the caller decides when to drop it.
    fn set(&mut self, row: usize, col: usize, generation: u32, value: T) -> Option<(u32, T)> {
        self.row_mut(row)[col].replace((generation, value))
    }

    /// Takes the cell's value if it was written by `generation`; a stale value from another
    /// generation is left in place (it is overwritten by the next `set`/`get`).
    fn remove(&mut self, row: usize, col: usize, generation: u32) -> Option<T> {
        let cell = &mut self.rows[row].as_mut()?[col];
        if cell.as_ref().is_some_and(|(owner, _)| *owner == generation) {
            cell.take().map(|(_, value)| value)
        } else {
            None
        }
    }
}

impl<T: Send + Sync + 'static, const ROWS: usize, const COLS: usize> Drop
    for CleanupHandle<T, ROWS, COLS>
{
    /// Returns this handle's id to the shared allocator.
    ///
    /// Best effort: if a lock is poisoned the id is simply never reused.
    fn drop(&mut self) {
        let Ok(registry) = tls_registry().read() else {
            return;
        };
        let state = registry
            .get(&state_key::<T, ROWS, COLS>())
            .and_then(|entry| entry.downcast_ref::<Mutex<GlobalTlsState>>());
        if let Some(state) = state {
            if let Ok(mut state) = state.lock() {
                state.release(self.id, self.generation);
            }
        }
    }
}
/// Associates each owner object with its own [`InstanceTls`] slot.
///
/// Owners are identified only by the address they live at. If an owner is dropped without
/// calling `remove`, a later object placed at the same address inherits its slot, and owners
/// of a zero-sized type may all share one address. The raw-pointer keys also make this type
/// neither `Send` nor `Sync`.
pub struct OwnerTls<T: Send + Sync + 'static, O: Send + Sync + 'static> {
    instances: HashMap<*const O, InstanceTls<T>>,
    _phantom: PhantomData<O>,
}
impl<T, O> OwnerTls<T, O>
where
    T: Send + Sync + Default + Clone + 'static,
    O: Send + Sync + 'static,
{
    /// Creates an empty map with no registered owners.
    pub fn new() -> Self {
        Self {
            instances: HashMap::new(),
            _phantom: PhantomData,
        }
    }

    /// Returns this thread's value for `owner`, registering a new [`InstanceTls`] for the
    /// owner's address on first use.
    ///
    /// The value comes from [`InstanceTls::get`], which fills an empty slot with
    /// `T::default()`. The returned `T` is a clone; mutating it does not update the stored value.
    ///
    /// # Errors
    /// Propagates the error from [`InstanceTls::new`] when a new slot cannot be allocated; the
    /// owner is not registered in that case.
    pub fn get_or_create(&mut self, owner: &O) -> Result<T> {
        use std::collections::hash_map::Entry;

        // One hash lookup instead of contains_key + insert + get.
        let instance = match self.instances.entry(owner as *const O) {
            Entry::Occupied(slot) => slot.into_mut(),
            Entry::Vacant(slot) => slot.insert(InstanceTls::new()?),
        };
        Ok(instance.get())
    }

    /// Returns this thread's value for `owner` without creating anything; `None` if the owner
    /// is not registered or this thread has no value for it.
    #[inline]
    pub fn get(&self, owner: &O) -> Option<T> {
        self.instances.get(&(owner as *const O))?.get_value()
    }

    /// Unregisters `owner` and returns its instance handle, if it was registered.
    pub fn remove(&mut self, owner: &O) -> Option<InstanceTls<T>> {
        self.instances.remove(&(owner as *const O))
    }

    /// Number of registered owners.
    #[inline]
    pub fn len(&self) -> usize {
        self.instances.len()
    }

    /// Returns `true` if no owner is registered.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.instances.is_empty()
    }
}
/// Fixed-size set of independent [`InstanceTls`] slots that are handed out round-robin.
pub struct TlsPool<T: Send + Sync + 'static, const POOL_SIZE: usize = 64> {
    pool: [Option<InstanceTls<T>>; POOL_SIZE],
    /// Round-robin cursor; `AtomicU32::fetch_add` wraps on overflow.
    next_slot: AtomicU32,
}
impl<T, const POOL_SIZE: usize> TlsPool<T, POOL_SIZE>
where
T: Send + Sync + Default + Clone + 'static,
{
pub fn new() -> Result<Self> {
let pool = std::array::from_fn(|_| Some(InstanceTls::new().expect("TLS initialization")));
Ok(Self {
pool,
next_slot: AtomicU32::new(0),
})
}
pub fn get_next(&self) -> T {
let slot = self.next_slot.fetch_add(1, Ordering::Relaxed) as usize % POOL_SIZE;
self.pool[slot].as_ref().expect("pool slot is initialized").get()
}
pub fn get_slot(&self, slot: usize) -> Option<T> {
if slot < POOL_SIZE {
Some(self.pool[slot].as_ref()?.get())
} else {
None
}
}
#[inline]
pub fn len(&self) -> usize {
POOL_SIZE
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    #[derive(Debug, Default, Clone, PartialEq)]
    struct TestData {
        value: u32,
        name: String,
    }

    #[test]
    fn test_instance_tls_basic() {
        let tls = InstanceTls::<TestData>::new().unwrap();
        assert_eq!(tls.get().value, 0);
        assert_eq!(tls.get().name, "");
        let test_data = TestData {
            value: 42,
            name: "test".to_string(),
        };
        tls.set(test_data);
        let retrieved = tls.get();
        assert_eq!(retrieved.value, 42);
        assert_eq!(retrieved.name, "test");
    }

    #[test]
    fn test_instance_tls_multiple_threads() {
        let tls = Arc::new(InstanceTls::<TestData>::new().unwrap());
        let handles: Vec<_> = (0..5)
            .map(|i| {
                let tls = Arc::clone(&tls);
                thread::spawn(move || {
                    let test_data = TestData {
                        value: i * 10,
                        name: format!("thread_{}", i),
                    };
                    tls.set(test_data);
                    thread::sleep(Duration::from_millis(10));
                    let retrieved = tls.get();
                    assert_eq!(retrieved.value, i * 10);
                    assert_eq!(retrieved.name, format!("thread_{}", i));
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_instance_tls_set_remove() {
        let tls = InstanceTls::<TestData>::new().unwrap();
        assert!(tls.try_get().is_none());
        let test_data = TestData {
            value: 100,
            name: "test_set".to_string(),
        };
        tls.set(test_data);
        assert!(tls.try_get().is_some());
        let retrieved = tls.try_get().unwrap();
        assert_eq!(retrieved.value, 100);
        let removed = tls.remove().unwrap();
        assert_eq!(removed.value, 100);
        assert_eq!(removed.name, "test_set");
        assert!(tls.try_get().is_none());
    }

    #[test]
    fn test_owner_tls() {
        // Non-zero-sized so the two owners are guaranteed distinct addresses.
        struct Owner {
            _id: u32,
        }
        let owner1 = Owner { _id: 1 };
        let owner2 = Owner { _id: 2 };
        let mut owner_tls = OwnerTls::<TestData, Owner>::new();
        assert!(owner_tls.is_empty());

        // `get_or_create` returns clones: mutating them must not change the stored values.
        let mut data1 = owner_tls.get_or_create(&owner1).unwrap();
        data1.value = 11;
        assert_eq!(data1.value, 11);
        let mut data2 = owner_tls.get_or_create(&owner2).unwrap();
        data2.value = 22;
        assert_eq!(data2.value, 22);
        assert_eq!(owner_tls.len(), 2);
        assert_eq!(owner_tls.get_or_create(&owner1).unwrap().value, 0);
        assert_eq!(owner_tls.get_or_create(&owner2).unwrap().value, 0);

        assert!(owner_tls.remove(&owner1).is_some());
        assert!(owner_tls.get(&owner1).is_none());
        assert_eq!(owner_tls.get(&owner2).unwrap().value, 0);
        assert_eq!(owner_tls.len(), 1);
    }

    #[test]
    fn test_tls_pool() {
        let pool = TlsPool::<TestData, 4>::new().unwrap();
        for _ in 0..8 {
            let _ = pool.get_next();
        }
        for slot in 0..4 {
            assert_eq!(pool.get_slot(slot).unwrap().value, 0);
        }
        assert!(pool.get_slot(4).is_none());
        assert_eq!(pool.len(), 4);
    }

    #[test]
    fn test_tls_id_management() {
        let tls1 = InstanceTls::<TestData>::new().unwrap();
        let tls2 = InstanceTls::<TestData>::new().unwrap();
        assert_ne!(tls1.id(), tls2.id());
        drop(tls1);
        // Whatever id `tls3` receives, it must not collide with a live instance.
        let tls3 = InstanceTls::<TestData>::new().unwrap();
        assert_ne!(tls3.id(), tls2.id());
    }

    #[test]
    fn test_matrix_dimensions() {
        let tls = InstanceTls::<TestData, 4, 4>::new().unwrap();
        let test_data = TestData {
            value: 99,
            name: "test".to_string(),
        };
        tls.set(test_data);
        assert_eq!(tls.get().value, 99);
    }

    #[test]
    fn test_recycled_id_does_not_expose_previous_value() {
        // A (T, ROWS, COLS) combination used only here, so no other test shares its ids.
        let first = InstanceTls::<u64, 2, 2>::new().unwrap();
        first.set(7);
        drop(first);
        let second = InstanceTls::<u64, 2, 2>::new().unwrap();
        assert_eq!(second.try_get(), None);
        assert_eq!(second.get(), 0);
    }

    #[test]
    fn test_capacity_exhaustion() {
        // Dedicated 1x2 matrix: exactly two live instances fit.
        let _a = InstanceTls::<u8, 1, 2>::new().unwrap();
        let _b = InstanceTls::<u8, 1, 2>::new().unwrap();
        assert!(InstanceTls::<u8, 1, 2>::new().is_err());
    }

    #[test]
    fn test_zero_dimensions_rejected() {
        assert!(InstanceTls::<u8, 0, 4>::new().is_err());
        assert!(InstanceTls::<u8, 4, 0>::new().is_err());
    }
}