use crate::error::{Result, ZiporaError};
use std::cell::RefCell;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex, OnceLock};
/// Default limit on the combined capacity, in bytes, a pool keeps cached (16 MiB).
const DEFAULT_CAPACITY: usize = 16 * 1024 * 1024;
/// Default limit on how many buffers a pool keeps cached.
const DEFAULT_MAX_BUFFERS: usize = 32;
/// Buffers whose capacity is below this many bytes are not handed back to a pool on drop.
const MIN_REUSE_SIZE: usize = 1024;
/// Buffers whose capacity exceeds this many bytes (1 MiB) are not handed back to a pool
/// on drop; larger requests also bypass the thread-local lookup in `alloc`.
const MAX_REUSE_SIZE: usize = 1024 * 1024;

/// Tuning knobs for an `EntropyContext` and the buffer pool behind it.
#[derive(Debug, Clone)]
pub struct EntropyContextConfig {
    /// Upper bound on the summed capacity, in bytes, of buffers the pool caches.
    pub capacity: usize,
    /// Upper bound on the number of buffers the pool caches.
    pub max_buffers: usize,
    /// When set, `alloc` consults the per-thread pool for requests up to `MAX_REUSE_SIZE`.
    pub thread_local: bool,
    /// Zero-copy flag.
    /// NOTE(review): not read by any code visible in this module — confirm its consumer.
    pub zero_copy: bool,
}

impl Default for EntropyContextConfig {
    /// 16 MiB / 32 buffers, with thread-local lookups and zero-copy enabled.
    fn default() -> Self {
        let (capacity, max_buffers) = (DEFAULT_CAPACITY, DEFAULT_MAX_BUFFERS);
        Self {
            capacity,
            max_buffers,
            thread_local: true,
            zero_copy: true,
        }
    }
}
/// A byte buffer that may be handed back to a `BufferPool` when dropped.
///
/// Buffers created with [`ContextBuffer::new`] are standalone. Buffers handed out by
/// `EntropyContext::alloc` remember the pool they came from.
pub struct ContextBuffer {
    /// The owned bytes.
    data: Vec<u8>,
    /// Pool to hand `data` back to on drop; `None` for standalone buffers.
    context: Option<Arc<Mutex<BufferPool>>>,
}

impl ContextBuffer {
    /// Creates a standalone, empty buffer with room for at least `capacity` bytes.
    pub fn new(capacity: usize) -> Self {
        Self {
            data: Vec::with_capacity(capacity),
            context: None,
        }
    }

    /// Wraps a (possibly recycled) vector and records the pool it belongs to.
    ///
    /// The vector is cleared first, because recycled vectors may still hold old bytes.
    fn with_context(mut data: Vec<u8>, context: Arc<Mutex<BufferPool>>) -> Self {
        data.clear();
        Self {
            data,
            context: Some(context),
        }
    }

    /// Mutable access to the underlying vector.
    pub fn as_mut_vec(&mut self) -> &mut Vec<u8> {
        &mut self.data
    }

    /// Shared access to the underlying vector.
    pub fn as_vec(&self) -> &Vec<u8> {
        &self.data
    }

    /// The buffer's current contents.
    #[inline]
    pub fn as_slice(&self) -> &[u8] {
        &self.data
    }

    /// Number of bytes currently stored.
    #[inline]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// `true` when no bytes are stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Number of bytes the buffer can hold without reallocating.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.data.capacity()
    }

    /// Resizes to `new_len` bytes, filling any new bytes with `value`.
    pub fn resize(&mut self, new_len: usize, value: u8) {
        self.data.resize(new_len, value);
    }

    /// Ensures the buffer can hold at least `capacity` bytes in total without
    /// reallocating. Existing contents are preserved; this never shrinks.
    pub fn ensure_capacity(&mut self, capacity: usize) {
        if self.data.capacity() < capacity {
            // `Vec::reserve(n)` guarantees room for `len() + n` elements, so the
            // shortfall has to be measured from `len()`, not from `capacity()`.
            // No underflow: len() <= capacity() < capacity.
            self.data.reserve(capacity - self.data.len());
        }
    }

    /// Removes all bytes, keeping the allocation.
    pub fn clear(&mut self) {
        self.data.clear();
    }

    /// Consumes the buffer and returns its bytes.
    ///
    /// The returned vector is detached: it will not be handed back to any pool.
    pub fn into_vec(mut self) -> Vec<u8> {
        // Detach from the pool first so the drop of `self` has nothing to return.
        self.context = None;
        std::mem::take(&mut self.data)
    }
}
impl Drop for ContextBuffer {
    /// Hands the allocation back to the originating pool, if there is one.
    ///
    /// Only vectors whose capacity lies in `MIN_REUSE_SIZE..=MAX_REUSE_SIZE` are offered
    /// back. If the pool's mutex is poisoned, the vector is simply freed.
    fn drop(&mut self) {
        let Some(owner) = self.context.take() else {
            return;
        };
        let reusable = (MIN_REUSE_SIZE..=MAX_REUSE_SIZE).contains(&self.data.capacity());
        if !reusable {
            return;
        }
        if let Ok(mut guard) = owner.lock() {
            guard.return_buffer(std::mem::take(&mut self.data));
        }
    }
}
/// Cache of previously used byte vectors, bounded by count and by summed capacity.
struct BufferPool {
    /// Cached vectors, in the order they were handed back.
    buffers: VecDeque<Vec<u8>>,
    /// Sum of `capacity()` over `buffers`.
    total_capacity: usize,
    /// Upper bound for `total_capacity`.
    max_capacity: usize,
    /// Upper bound for `buffers.len()`.
    max_buffers: usize,
}

impl BufferPool {
    /// Creates an empty pool whose limits come from `config`.
    fn new(config: &EntropyContextConfig) -> Self {
        Self {
            max_capacity: config.capacity,
            max_buffers: config.max_buffers,
            total_capacity: 0,
            buffers: VecDeque::new(),
        }
    }

    /// Removes and returns the smallest cached vector with capacity of at least
    /// `min_capacity`, taking the earliest one on ties.
    ///
    /// If none qualifies, a fresh vector with that capacity is allocated instead.
    /// Recycled vectors are returned as-is and may still contain old bytes.
    fn get_buffer(&mut self, min_capacity: usize) -> Vec<u8> {
        // `min_by_key` yields the first of several equal minima.
        let best_fit = self
            .buffers
            .iter()
            .enumerate()
            .filter(|(_, candidate)| candidate.capacity() >= min_capacity)
            .min_by_key(|(_, candidate)| candidate.capacity())
            .map(|(position, _)| position);

        match best_fit {
            Some(position) => {
                let reused = self
                    .buffers
                    .remove(position)
                    .expect("buffer index valid from position()");
                self.total_capacity -= reused.capacity();
                reused
            }
            None => Vec::with_capacity(min_capacity),
        }
    }

    /// Caches `buffer` if that keeps both pool limits satisfied; otherwise drops it.
    fn return_buffer(&mut self, buffer: Vec<u8>) {
        let incoming = buffer.capacity();
        if self.buffers.len() >= self.max_buffers
            || self.total_capacity + incoming > self.max_capacity
        {
            return;
        }
        self.total_capacity += incoming;
        self.buffers.push_back(buffer);
    }
}
/// Per-thread pool state, created lazily by `EntropyContext::thread_local`.
struct ThreadLocalContext {
    /// This thread's pool; a `RefCell` suffices because it lives in a thread-local.
    pool: RefCell<BufferPool>,
    /// Configuration the pool was built from (a clone of `GLOBAL_CONFIG`).
    config: EntropyContextConfig,
}
thread_local! {
    // Per-thread context; stays uninitialised until `EntropyContext::thread_local`
    // runs on this thread, so `alloc` may find nothing here.
    // NOTE(review): no code visible in this module ever adds buffers to this pool
    // (`ContextBuffer::drop` returns buffers to the context's shared pool instead).
    static TLS_CONTEXT: OnceLock<ThreadLocalContext> = const { OnceLock::new() };
}
/// Process-wide pool backing `EntropyContext::global`, created on first use.
static GLOBAL_CONTEXT: OnceLock<Arc<Mutex<BufferPool>>> = OnceLock::new();
/// Configuration for the global and thread-local pools. In the visible code it is
/// only ever initialised with `EntropyContextConfig::default()`.
static GLOBAL_CONFIG: OnceLock<EntropyContextConfig> = OnceLock::new();
/// Allocator front-end that hands out [`ContextBuffer`]s backed by a `BufferPool`.
///
/// Buffers allocated through a context are handed back to that context's
/// mutex-protected pool (`pool`) when dropped.
pub struct EntropyContext {
    /// Pool that this context's dropped buffers are returned to.
    pool: Arc<Mutex<BufferPool>>,
    /// Configuration this context was created with.
    config: EntropyContextConfig,
}

impl EntropyContext {
    /// Creates a context with its own pool and the default configuration.
    pub fn new() -> Self {
        Self::with_config(EntropyContextConfig::default())
    }

    /// Creates a context with its own, freshly created pool limited by `config`.
    pub fn with_config(config: EntropyContextConfig) -> Self {
        let pool = Arc::new(Mutex::new(BufferPool::new(&config)));
        Self { pool, config }
    }

    /// Returns a context that shares the process-wide pool.
    ///
    /// The pool and its configuration are created lazily on first use; every
    /// context returned here holds a clone of the same `Arc`.
    pub fn global() -> Self {
        let config = GLOBAL_CONFIG.get_or_init(EntropyContextConfig::default);
        let pool = GLOBAL_CONTEXT.get_or_init(|| Arc::new(Mutex::new(BufferPool::new(config))));
        Self {
            pool: Arc::clone(pool),
            config: config.clone(),
        }
    }

    /// Initialises this thread's thread-local pool (if needed) and returns a new
    /// context using the same configuration.
    ///
    /// Note: the returned context still owns a separate pool created by
    /// [`EntropyContext::with_config`].
    pub fn thread_local() -> Self {
        TLS_CONTEXT.with(|ctx| {
            let tls_ctx = ctx.get_or_init(|| {
                let config = GLOBAL_CONFIG.get_or_init(EntropyContextConfig::default).clone();
                ThreadLocalContext {
                    pool: RefCell::new(BufferPool::new(&config)),
                    config,
                }
            });
            Self::with_config(tls_ctx.config.clone())
        })
    }

    /// Allocates an empty buffer with room for at least `capacity` bytes.
    ///
    /// When `config.thread_local` is set and `capacity <= MAX_REUSE_SIZE`, the lookup
    /// proceeds in this order:
    /// 1. a large-enough buffer from this thread's thread-local pool, if any;
    /// 2. otherwise, this context's shared pool, which is where dropped
    ///    [`ContextBuffer`]s are returned;
    /// 3. otherwise, a newly allocated buffer.
    ///
    /// # Errors
    /// Returns a `resource_exhausted` error if the shared pool's mutex is poisoned
    /// on the non-thread-local path. On the thread-local path, a poisoned mutex
    /// falls back to a fresh allocation.
    pub fn alloc(&self, capacity: usize) -> Result<ContextBuffer> {
        let buffer = if self.config.thread_local && capacity <= MAX_REUSE_SIZE {
            match Self::take_thread_local(capacity) {
                Some(buffer) => buffer,
                // Without this fallback, buffers returned to `self.pool` on drop
                // would never be reused by a thread-local-enabled context.
                None => match self.pool.lock() {
                    Ok(mut pool) => pool.get_buffer(capacity),
                    Err(_) => Vec::with_capacity(capacity),
                },
            }
        } else {
            match self.pool.lock() {
                Ok(mut pool) => pool.get_buffer(capacity),
                Err(_) => {
                    return Err(ZiporaError::resource_exhausted("Context pool lock failed"))
                }
            }
        };
        Ok(ContextBuffer::with_context(buffer, Arc::clone(&self.pool)))
    }

    /// Takes a buffer of at least `capacity` bytes from this thread's thread-local
    /// pool.
    ///
    /// Returns `None` if any of these hold:
    /// - the thread-local pool was never initialised;
    /// - it is already borrowed;
    /// - it holds no large-enough buffer.
    fn take_thread_local(capacity: usize) -> Option<Vec<u8>> {
        TLS_CONTEXT.with(|ctx| {
            let tls_ctx = ctx.get()?;
            let mut pool = tls_ctx.pool.try_borrow_mut().ok()?;
            if pool.buffers.iter().any(|b| b.capacity() >= capacity) {
                Some(pool.get_buffer(capacity))
            } else {
                None
            }
        })
    }

    /// Allocates a buffer of exactly `size` zero bytes.
    ///
    /// # Errors
    /// Same as [`EntropyContext::alloc`].
    pub fn alloc_zeroed(&self, size: usize) -> Result<ContextBuffer> {
        let mut buffer = self.alloc(size)?;
        buffer.resize(size, 0);
        Ok(buffer)
    }

    /// The configuration this context was created with.
    pub fn config(&self) -> &EntropyContextConfig {
        &self.config
    }

    /// Snapshot of this context's shared pool occupancy and limits.
    ///
    /// # Errors
    /// Returns a `resource_exhausted` error if the pool's mutex is poisoned.
    pub fn stats(&self) -> Result<ContextStats> {
        match self.pool.lock() {
            Ok(pool) => Ok(ContextStats {
                cached_buffers: pool.buffers.len(),
                total_capacity: pool.total_capacity,
                max_capacity: pool.max_capacity,
                max_buffers: pool.max_buffers,
            }),
            Err(_) => Err(ZiporaError::resource_exhausted("Context pool lock failed")),
        }
    }

    /// Returns a detached `Vec` of `size` zero bytes.
    ///
    /// The vector is taken via `into_vec`, so it is not handed back to the pool.
    ///
    /// # Errors
    /// Same as [`EntropyContext::alloc`].
    pub fn get_buffer(&mut self, size: usize) -> Result<Vec<u8>> {
        let buffer = self.alloc(size)?;
        let mut vec = buffer.into_vec();
        vec.resize(size, 0);
        Ok(vec)
    }

    /// Alias for [`EntropyContext::get_buffer`].
    ///
    /// # Errors
    /// Same as [`EntropyContext::alloc`].
    pub fn get_temp_buffer(&mut self, size: usize) -> Result<Vec<u8>> {
        self.get_buffer(size)
    }
}
impl Default for EntropyContext {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of a pool's occupancy and limits, as returned by `EntropyContext::stats`.
#[derive(Debug, Clone)]
pub struct ContextStats {
    /// Number of buffers currently cached in the pool.
    pub cached_buffers: usize,
    /// Sum of the capacities, in bytes, of the cached buffers.
    pub total_capacity: usize,
    /// The pool's limit on `total_capacity`.
    pub max_capacity: usize,
    /// The pool's limit on `cached_buffers`.
    pub max_buffers: usize,
}
/// Produced bytes paired with the [`ContextBuffer`] used while producing them.
pub struct EntropyResult {
    /// The produced bytes.
    pub data: Vec<u8>,
    /// Working buffer kept alive alongside `data`.
    pub buffer: ContextBuffer,
}

impl EntropyResult {
    /// Bundles `data` with the working `buffer`.
    pub fn new(data: Vec<u8>, buffer: ContextBuffer) -> Self {
        EntropyResult { data, buffer }
    }

    /// Consumes the result and returns the bytes; the working buffer is dropped.
    pub fn into_data(self) -> Vec<u8> {
        let Self { data, .. } = self;
        data
    }

    /// Borrowed view of the produced bytes.
    #[inline]
    pub fn data(&self) -> &[u8] {
        self.data.as_slice()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_buffer_basic() {
        let mut buffer = ContextBuffer::new(1024);
        // `Vec::with_capacity` only guarantees *at least* the requested capacity.
        assert!(buffer.capacity() >= 1024);
        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());
        buffer.resize(100, 42);
        assert_eq!(buffer.len(), 100);
        assert!(!buffer.is_empty());
        assert_eq!(buffer.as_slice()[0], 42);
    }

    #[test]
    fn test_entropy_context_allocation() {
        let context = EntropyContext::new();
        let buffer1 = context.alloc(1024).unwrap();
        assert!(buffer1.capacity() >= 1024);
        let buffer2 = context.alloc_zeroed(512).unwrap();
        assert_eq!(buffer2.len(), 512);
        assert!(buffer2.as_slice().iter().all(|&b| b == 0));
    }

    #[test]
    fn test_buffer_reuse() {
        let context = EntropyContext::new();
        {
            let _buffer = context.alloc(2048).unwrap();
        }
        // A 2 KiB buffer lies inside the reuse window, so dropping it caches it
        // in the context's pool.
        let stats = context.stats().unwrap();
        assert_eq!(stats.cached_buffers, 1);
        assert!(stats.total_capacity >= 2048);
        let buffer2 = context.alloc(1024).unwrap();
        assert!(buffer2.capacity() >= 1024);
    }

    #[test]
    fn test_thread_local_context() {
        let ctx1 = EntropyContext::thread_local();
        let ctx2 = EntropyContext::thread_local();
        assert_eq!(ctx1.config().capacity, ctx2.config().capacity);
        assert_eq!(ctx1.config().max_buffers, ctx2.config().max_buffers);
        let buffer = ctx1.alloc(1024).unwrap();
        assert!(buffer.capacity() >= 1024);
    }

    #[test]
    fn test_global_context() {
        let ctx1 = EntropyContext::global();
        let ctx2 = EntropyContext::global();
        // Every global context shares the same pool.
        assert!(Arc::ptr_eq(&ctx1.pool, &ctx2.pool));
        let _buffer1 = ctx1.alloc(1024).unwrap();
        let _buffer2 = ctx2.alloc(1024).unwrap();
    }

    #[test]
    fn test_entropy_result() {
        let context = EntropyContext::new();
        let buffer = context.alloc(1024).unwrap();
        let data = vec![1, 2, 3, 4, 5];
        let result = EntropyResult::new(data.clone(), buffer);
        assert_eq!(result.data(), &data);
        let extracted_data = result.into_data();
        assert_eq!(extracted_data, data);
    }

    #[test]
    fn test_context_config() {
        let config = EntropyContextConfig {
            capacity: 1024,
            max_buffers: 4,
            thread_local: false,
            zero_copy: true,
        };
        let context = EntropyContext::with_config(config.clone());
        assert_eq!(context.config().capacity, 1024);
        assert_eq!(context.config().max_buffers, 4);
        assert!(!context.config().thread_local);
        assert!(context.config().zero_copy);
    }
}