use bytes::BytesMut;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Process-wide buffer pool returned by [`MemoryPool::global`], built on first use.
static BUFFER_POOL: std::sync::LazyLock<Arc<MemoryPool>> =
    std::sync::LazyLock::new(|| Arc::new(MemoryPool::new()));
/// Process-wide interner returned by [`StringInterner::global`], built on first use.
static STRING_INTERNER: std::sync::LazyLock<Arc<StringInterner>> =
    std::sync::LazyLock::new(|| Arc::new(StringInterner::new()));
/// Limits shared by [`MemoryPool`] and [`StringInterner`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryPoolConfig {
    /// Maximum idle buffers kept per size class; buffers returned to a full
    /// class are dropped instead of pooled.
    pub max_buffers_per_size: usize,
    /// Buffer capacities (bytes) that requested sizes are rounded up to.
    pub size_classes: Vec<usize>,
    /// Number of cached strings at which a [`StringInterner`] starts evicting.
    pub max_interned_strings: usize,
}

impl Default for MemoryPoolConfig {
    /// 1000 buffers per class, power-of-two classes from 64 B to 64 KiB, and
    /// an interner limit of 10 000 strings.
    fn default() -> Self {
        Self {
            max_buffers_per_size: 1000,
            // 2^6 ..= 2^16, i.e. 64, 128, ..., 65536.
            size_classes: (6..=16).map(|shift| 1usize << shift).collect(),
            max_interned_strings: 10000,
        }
    }
}
/// Size-classed pool of reusable [`BytesMut`] buffers.
///
/// Idle buffers are kept in per-size-class buckets behind an async mutex.
#[derive(Debug)]
pub struct MemoryPool {
    /// Idle buffers keyed by size class (in bytes).
    pools: Mutex<HashMap<usize, Vec<BytesMut>>>,
    /// Limits and size classes this pool was built with.
    config: MemoryPoolConfig,
}

impl Default for MemoryPool {
    /// Same as [`MemoryPool::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl MemoryPool {
    /// Creates an empty pool using [`MemoryPoolConfig::default`].
    pub fn new() -> Self {
        Self::with_config(MemoryPoolConfig::default())
    }

    /// Creates an empty pool governed by `config`.
    pub fn with_config(config: MemoryPoolConfig) -> Self {
        Self {
            pools: Mutex::new(HashMap::new()),
            config,
        }
    }

    /// Returns the process-wide shared pool.
    pub fn global() -> &'static Arc<MemoryPool> {
        &BUFFER_POOL
    }

    /// Returns an empty buffer whose capacity is at least `capacity` bytes.
    ///
    /// An idle buffer from `capacity`'s size class is reused when one is
    /// available; otherwise a new buffer sized to the whole class is allocated.
    pub async fn get_buffer(&self, capacity: usize) -> BytesMut {
        let size_class = self.find_size_class(capacity);
        // Hold the lock only for the pop; clearing and any reallocation happen
        // after the guard is released.
        let reused = {
            let mut pools = self.pools.lock().await;
            pools.get_mut(&size_class).and_then(|bucket| bucket.pop())
        };
        match reused {
            Some(mut buffer) => {
                buffer.clear();
                // Buffers are filed under the class that *covers* their
                // capacity, so a pooled one may be smaller than requested.
                // `reserve` counts from the current length (0 after `clear`),
                // so this guarantees `buffer.capacity() >= capacity`.
                buffer.reserve(capacity);
                buffer
            }
            None => BytesMut::with_capacity(size_class),
        }
    }

    /// Hands `buffer` back for reuse.
    ///
    /// The buffer is cleared and filed under the size class covering its
    /// current capacity. It is dropped instead if that class already holds
    /// `max_buffers_per_size` idle buffers.
    pub async fn return_buffer(&self, mut buffer: BytesMut) {
        // `clear` keeps the capacity, so it does not affect the class chosen below.
        buffer.clear();
        let size_class = self.find_size_class(buffer.capacity());
        let mut pools = self.pools.lock().await;
        let bucket = pools.entry(size_class).or_default();
        if bucket.len() < self.config.max_buffers_per_size {
            bucket.push(buffer);
        }
    }

    /// Maps a byte count to the size class that covers it.
    ///
    /// This is the smallest configured class that is `>= capacity`. When no
    /// configured class is large enough, it is the next power of two (never
    /// below 64).
    fn find_size_class(&self, capacity: usize) -> usize {
        self.config
            .size_classes
            .iter()
            .copied()
            .filter(|&size| size >= capacity)
            .min()
            .unwrap_or_else(|| {
                // `checked_` avoids the debug-mode overflow panic for sizes
                // above the largest power of two representable in `usize`.
                capacity
                    .checked_next_power_of_two()
                    .unwrap_or(capacity)
                    .max(64)
            })
    }

    /// Reports the number of idle buffers, the bytes they reserve, and how
    /// many size-class buckets exist (including emptied ones).
    pub async fn stats(&self) -> MemoryPoolStats {
        let pools = self.pools.lock().await;
        let total_buffers: usize = pools.values().map(Vec::len).sum();
        // Sum real capacities: a buffer's capacity can differ from the class
        // it is filed under.
        let total_memory_bytes: usize = pools
            .values()
            .flatten()
            .map(|buffer| buffer.capacity())
            .sum();
        MemoryPoolStats {
            total_buffers,
            total_memory_bytes,
            pools_count: pools.len(),
        }
    }
}
/// Point-in-time snapshot of a [`MemoryPool`]'s idle buffers.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MemoryPoolStats {
    /// Number of idle buffers across all size classes.
    pub total_buffers: usize,
    /// Estimated bytes reserved by the idle buffers.
    pub total_memory_bytes: usize,
    /// Number of size-class buckets that exist, including empty ones.
    pub pools_count: usize,
}
/// Deduplicating string store that hands out shared `Arc<str>` handles.
///
/// Interning equal text returns pointer-equal `Arc`s while the entry stays
/// cached. The cache is bounded by `max_interned_strings`: when it is full, a
/// batch of arbitrary entries is evicted. Re-interning an evicted string
/// yields a fresh allocation that is not pointer-equal to earlier handles.
#[derive(Debug)]
pub struct StringInterner {
    // A synchronous mutex is used because the critical section never awaits.
    // This lets `intern_sync` share the same code path without blocking on an
    // async lock from synchronous code. `Arc<str>: Borrow<str>`, so the set is
    // probed with `&str` and each string is stored exactly once.
    strings: std::sync::Mutex<std::collections::HashSet<Arc<str>>>,
    config: MemoryPoolConfig,
}

impl Default for StringInterner {
    /// Same as [`StringInterner::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl StringInterner {
    /// Extra entries evicted beyond the minimum once the cache is full, so
    /// eviction runs once per batch of inserts rather than on every insert.
    const EVICTION_BATCH: usize = 100;

    /// Creates an empty interner using [`MemoryPoolConfig::default`].
    pub fn new() -> Self {
        Self::with_config(MemoryPoolConfig::default())
    }

    /// Creates an empty interner whose size limit comes from
    /// `config.max_interned_strings`.
    pub fn with_config(config: MemoryPoolConfig) -> Self {
        Self {
            strings: std::sync::Mutex::new(std::collections::HashSet::new()),
            config,
        }
    }

    /// Returns the process-wide shared interner.
    pub fn global() -> &'static Arc<StringInterner> {
        &STRING_INTERNER
    }

    /// Returns the shared handle for `s`, caching it if it is not already present.
    pub async fn intern(&self, s: &str) -> Arc<str> {
        // The lock guards only a hash lookup/insert (plus an occasional
        // eviction pass) and is never held across an `.await`.
        self.intern_now(s)
    }

    /// Synchronous counterpart of [`StringInterner::intern`].
    ///
    /// It never drives a future, so it is usable from plain synchronous code.
    pub fn intern_sync(&self, s: &str) -> Arc<str> {
        self.intern_now(s)
    }

    /// Reports the number of cached strings and their combined byte length.
    pub async fn stats(&self) -> StringInternerStats {
        let strings = self.lock_strings();
        StringInternerStats {
            total_strings: strings.len(),
            total_memory_bytes: strings.iter().map(|s| s.len()).sum(),
        }
    }

    /// Locks the cache, recovering the data if a previous holder panicked.
    /// The set stays structurally valid even then.
    fn lock_strings(&self) -> std::sync::MutexGuard<'_, std::collections::HashSet<Arc<str>>> {
        self.strings
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Shared lookup-or-insert used by both the async and the sync entry points.
    fn intern_now(&self, s: &str) -> Arc<str> {
        let mut strings = self.lock_strings();
        if let Some(existing) = strings.get(s) {
            return Arc::clone(existing);
        }
        let limit = self.config.max_interned_strings;
        if strings.len() >= limit {
            // Evict enough to get back under `limit`, plus a batch of headroom.
            let excess = strings.len() - limit + Self::EVICTION_BATCH;
            let victims: Vec<Arc<str>> = strings.iter().take(excess).cloned().collect();
            for victim in &victims {
                strings.remove(victim);
            }
        }
        let interned: Arc<str> = Arc::from(s);
        strings.insert(Arc::clone(&interned));
        interned
    }
}
/// Point-in-time snapshot of a [`StringInterner`]'s cache.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StringInternerStats {
    /// Number of distinct strings currently cached.
    pub total_strings: usize,
    /// Combined byte length of the cached strings (excluding `Arc` overhead).
    pub total_memory_bytes: usize,
}
/// Free-function shortcuts that forward to the process-wide [`MemoryPool`]
/// and [`StringInterner`] singletons.
pub mod global {
    use super::*;

    /// Fetches a buffer for `capacity` bytes from the global pool
    /// (see `MemoryPool::get_buffer`).
    pub async fn get_buffer(capacity: usize) -> BytesMut {
        let pool = MemoryPool::global();
        pool.get_buffer(capacity).await
    }

    /// Hands `buffer` back to the global pool (see `MemoryPool::return_buffer`).
    pub async fn return_buffer(buffer: BytesMut) {
        let pool = MemoryPool::global();
        pool.return_buffer(buffer).await
    }

    /// Interns `s` in the global interner (see `StringInterner::intern`).
    pub async fn intern_string(s: &str) -> Arc<str> {
        let interner = StringInterner::global();
        interner.intern(s).await
    }

    /// Blocking variant of [`intern_string`] (see `StringInterner::intern_sync`).
    pub fn intern_string_sync(s: &str) -> Arc<str> {
        let interner = StringInterner::global();
        interner.intern_sync(s)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A returned buffer is handed out again, leaving the pool empty.
    #[tokio::test]
    async fn test_buffer_pooling() {
        let pool = MemoryPool::new();

        let mut first = pool.get_buffer(100).await;
        assert!(first.capacity() >= 100);
        first.extend_from_slice(b"hello world");
        assert_eq!(&first[..], b"hello world");
        pool.return_buffer(first).await;

        let second = pool.get_buffer(100).await;
        assert!(second.capacity() >= 100);
        // The single pooled buffer was just taken back out.
        let snapshot = pool.stats().await;
        assert_eq!(snapshot.total_buffers, 0);
    }

    /// Equal text shares one allocation; different text does not.
    #[tokio::test]
    async fn test_string_interning() {
        let interner = StringInterner::new();

        let hello_a = interner.intern("hello").await;
        let hello_b = interner.intern("hello").await;
        let world = interner.intern("world").await;

        assert!(Arc::ptr_eq(&hello_a, &hello_b));
        assert!(!Arc::ptr_eq(&hello_a, &world));
        assert_eq!(&*hello_a, "hello");
        assert_eq!(&*world, "world");

        let snapshot = interner.stats().await;
        assert_eq!(snapshot.total_strings, 2);
        // "hello" + "world" is 10 bytes.
        assert!(snapshot.total_memory_bytes >= 10);
    }

    /// The `global` shortcuts reach working singletons.
    #[tokio::test]
    async fn test_global_access() {
        let buf = global::get_buffer(64).await;
        assert!(buf.capacity() >= 64);

        let first = global::intern_string("test").await;
        let again = global::intern_string("test").await;
        assert!(Arc::ptr_eq(&first, &again));
    }

    /// Requests are rounded up to their covering size class.
    #[tokio::test]
    async fn test_size_classes() {
        let pool = MemoryPool::new();
        let small = pool.get_buffer(32).await;
        let medium = pool.get_buffer(100).await;
        assert!(small.capacity() >= 64);
        assert!(medium.capacity() >= 128);
    }

    /// Buffers beyond `max_buffers_per_size` are dropped on return.
    #[tokio::test]
    async fn test_pool_limits() {
        let pool = MemoryPool::with_config(MemoryPoolConfig {
            max_buffers_per_size: 2,
            ..Default::default()
        });

        let mut handed_out = Vec::new();
        for _ in 0..3 {
            handed_out.push(pool.get_buffer(64).await);
        }
        for buf in handed_out {
            pool.return_buffer(buf).await;
        }

        let snapshot = pool.stats().await;
        assert_eq!(snapshot.total_buffers, 2);
    }
}