use crossbeam_queue::ArrayQueue;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
/// One kibibyte; used to spell out the tier capacities below.
const KIB: usize = 1024;
/// One mebibyte.
const MIB: usize = 1024 * KIB;
/// Default per-buffer capacity of the small tier (64 KiB).
const SMALL_BUF_CAPACITY: usize = 64 * KIB;
/// Default per-buffer capacity of the large tier (8 MiB).
const LARGE_BUF_CAPACITY: usize = 8 * MIB;
/// Default per-buffer capacity of the huge tier (64 MiB).
const HUGE_BUF_CAPACITY: usize = 64 * MIB;
/// Default per-buffer capacity of the giant tier (256 MiB).
const GIANT_BUF_CAPACITY: usize = 256 * MIB;
/// Preset tier layouts accepted by `BufferPool::with_profile`.
///
/// Both presets use the default tier capacities and differ only in how many
/// idle buffers each tier is configured for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PoolProfile {
    /// Tier counts 256 / 32 / 16 / 4 (small / large / huge / giant).
    Training,
    /// Tier counts 256 / 16 / 4 / 0 (small / large / huge / giant).
    Inference,
}
/// Configuration of a single pool tier.
#[derive(Debug, Clone, Copy)]
pub struct TierConfig {
    /// Number of idle buffers the tier is configured to keep for reuse.
    pub count: usize,
    /// Per-buffer capacity in bytes; requests up to this length can be
    /// routed to the tier.
    pub capacity: usize,
}
/// Builder for a [`BufferPool`] with explicit per-tier counts and capacities.
///
/// Start from [`PoolBuilder::new`] (or `Default`), override individual tiers
/// with the setter methods, then call [`PoolBuilder::build`].
pub struct PoolBuilder {
    /// Tier serving the smallest requests.
    small: TierConfig,
    /// Tier checked after `small`.
    large: TierConfig,
    /// Tier checked after `large`.
    huge: TierConfig,
    /// Tier serving the largest pooled requests.
    giant: TierConfig,
}
impl PoolBuilder {
    /// Creates a builder with the default layout.
    ///
    /// The small, large and huge tiers each have a count of 1; the giant tier
    /// has a count of 0. Every tier uses its default capacity.
    pub fn new() -> Self {
        let tier = |count, capacity| TierConfig { count, capacity };
        Self {
            small: tier(1, SMALL_BUF_CAPACITY),
            large: tier(1, LARGE_BUF_CAPACITY),
            huge: tier(1, HUGE_BUF_CAPACITY),
            giant: tier(0, GIANT_BUF_CAPACITY),
        }
    }

    /// Replaces the small-tier configuration.
    pub fn small(self, config: TierConfig) -> Self {
        Self { small: config, ..self }
    }

    /// Replaces the large-tier configuration.
    pub fn large(self, config: TierConfig) -> Self {
        Self { large: config, ..self }
    }

    /// Replaces the huge-tier configuration.
    pub fn huge(self, config: TierConfig) -> Self {
        Self { huge: config, ..self }
    }

    /// Replaces the giant-tier configuration.
    pub fn giant(self, config: TierConfig) -> Self {
        Self { giant: config, ..self }
    }

    /// Consumes the builder and creates the shared pool.
    pub fn build(self) -> Arc<BufferPool> {
        let Self {
            small,
            large,
            huge,
            giant,
        } = self;
        BufferPool::from_tier_configs(small, large, huge, giant)
    }
}
impl Default for PoolBuilder {
fn default() -> Self {
Self::new()
}
}
/// A tiered pool of reusable, zero-filled byte buffers.
///
/// Each request is routed by length to the first tier (small, large, huge,
/// then giant) whose capacity is at least that length. Requests larger than
/// every tier, or routed to a tier configured with `count == 0`, are served
/// by a one-off allocation. That allocation is freed, not pooled, when its
/// [`PooledBuf`] is dropped.
pub struct BufferPool {
    /// Idle buffers of the small tier; `None` when the tier's count was 0.
    small: Option<ArrayQueue<Vec<u8>>>,
    /// Idle buffers of the large tier; `None` when the tier's count was 0.
    large: Option<ArrayQueue<Vec<u8>>>,
    /// Idle buffers of the huge tier; `None` when the tier's count was 0.
    huge: Option<ArrayQueue<Vec<u8>>>,
    /// Idle buffers of the giant tier; `None` when the tier's count was 0.
    giant: Option<ArrayQueue<Vec<u8>>>,
    /// Per-buffer capacity of each tier, which is also the largest request
    /// length that tier serves.
    small_cap: usize,
    large_cap: usize,
    huge_cap: usize,
    giant_cap: usize,
}

impl BufferPool {
    /// A returned buffer whose capacity exceeds this multiple of its tier's
    /// capacity is dropped instead of pooled. This stops a single oversized
    /// buffer from pinning excessive memory.
    const MAX_RETAINED_CAP_FACTOR: usize = 4;

    /// Creates a pool with the [`PoolProfile::Training`] layout.
    pub fn new() -> Arc<Self> {
        Self::with_profile(PoolProfile::Training)
    }

    /// Creates a pool from a preset layout at the default tier capacities.
    ///
    /// Idle-buffer counts (small / large / huge / giant):
    /// - `Training`: 256 / 32 / 16 / 4
    /// - `Inference`: 256 / 16 / 4 / 0 (giant-sized requests are not pooled)
    pub fn with_profile(profile: PoolProfile) -> Arc<Self> {
        let (small_count, large_count, huge_count, giant_count) = match profile {
            PoolProfile::Training => (256, 32, 16, 4),
            PoolProfile::Inference => (256, 16, 4, 0),
        };
        let tier = |count, capacity| TierConfig { count, capacity };
        Self::from_tier_configs(
            tier(small_count, SMALL_BUF_CAPACITY),
            tier(large_count, LARGE_BUF_CAPACITY),
            tier(huge_count, HUGE_BUF_CAPACITY),
            tier(giant_count, GIANT_BUF_CAPACITY),
        )
    }

    /// Creates a pool with a custom small tier and fixed defaults for the rest.
    ///
    /// The small tier keeps up to `small_pool_size` buffers of
    /// `small_buf_cap` bytes. The other tiers keep up to 4 large, 2 huge and
    /// 1 giant buffer, each at the default capacity.
    pub fn with_config(small_pool_size: usize, small_buf_cap: usize) -> Arc<Self> {
        let tier = |count, capacity| TierConfig { count, capacity };
        Self::from_tier_configs(
            tier(small_pool_size, small_buf_cap),
            tier(4, LARGE_BUF_CAPACITY),
            tier(2, HUGE_BUF_CAPACITY),
            tier(1, GIANT_BUF_CAPACITY),
        )
    }

    /// Builds the pool from explicit per-tier configurations.
    ///
    /// All queues start empty. Buffers are allocated lazily by
    /// [`checkout`](Self::checkout) and only retained once returned.
    fn from_tier_configs(
        small: TierConfig,
        large: TierConfig,
        huge: TierConfig,
        giant: TierConfig,
    ) -> Arc<Self> {
        // `ArrayQueue::new` panics on a capacity of zero. A tier with
        // `count == 0` therefore gets no queue at all and is treated as
        // unpooled, rather than silently keeping one buffer.
        fn queue(config: TierConfig) -> Option<ArrayQueue<Vec<u8>>> {
            (config.count > 0).then(|| ArrayQueue::new(config.count))
        }
        Arc::new(Self {
            small: queue(small),
            large: queue(large),
            huge: queue(huge),
            giant: queue(giant),
            small_cap: small.capacity,
            large_cap: large.capacity,
            huge_cap: huge.capacity,
            giant_cap: giant.capacity,
        })
    }

    /// Checks out a zero-filled buffer of exactly `len` bytes.
    ///
    /// For a pooled tier, an idle buffer is reused when available. Otherwise a
    /// new one is allocated with the full tier capacity, so that once returned
    /// it can serve any request routed to that tier. Unpooled requests
    /// allocate exactly `len` bytes. The buffer is handed back to the pool
    /// when the returned [`PooledBuf`] is dropped.
    pub fn checkout(self: &Arc<Self>, len: usize) -> PooledBuf {
        let (queue, tier, capacity) = self.tier_for_size(len);
        let buf = match queue {
            Some(q) => {
                // Idle buffers are stored cleared (len 0), so `resize`
                // zero-fills exactly the requested prefix.
                let mut buf = q.pop().unwrap_or_else(|| Vec::with_capacity(capacity));
                buf.resize(len, 0);
                buf
            }
            // Not pooled: one zeroed allocation of exactly `len` bytes.
            None => vec![0; len],
        };
        PooledBuf {
            buf: Some(buf),
            pool: Arc::clone(self),
            tier,
        }
    }

    /// Routes a request of `len` bytes to a tier.
    ///
    /// Returns a tuple of three values:
    /// - the tier's idle-buffer queue (`None` when the request is unpooled);
    /// - the tier tag to record on the [`PooledBuf`];
    /// - the capacity to allocate when the queue is empty (`len` itself for
    ///   unpooled requests).
    ///
    /// Tiers are checked in order small → large → huge → giant. If the
    /// configured capacities are not ascending, an earlier tier can shadow a
    /// later one.
    fn tier_for_size(&self, len: usize) -> (Option<&ArrayQueue<Vec<u8>>>, PoolTier, usize) {
        let (queue, tier, cap) = if len <= self.small_cap {
            (&self.small, PoolTier::Small, self.small_cap)
        } else if len <= self.large_cap {
            (&self.large, PoolTier::Large, self.large_cap)
        } else if len <= self.huge_cap {
            (&self.huge, PoolTier::Huge, self.huge_cap)
        } else if len <= self.giant_cap {
            (&self.giant, PoolTier::Giant, self.giant_cap)
        } else {
            return (None, PoolTier::Unpooled, len);
        };
        match queue {
            Some(q) => (Some(q), tier, cap),
            // The matching tier was configured with `count == 0`: serve the
            // request without pooling instead of allocating a full tier buffer.
            None => (None, PoolTier::Unpooled, len),
        }
    }

    /// Returns a buffer to its tier's idle queue, or drops it.
    ///
    /// The buffer is kept only when all of these hold:
    /// - its tier is pooled;
    /// - its capacity is at least the tier capacity;
    /// - its capacity is at most `MAX_RETAINED_CAP_FACTOR` times the tier
    ///   capacity.
    ///
    /// The lower bound keeps undersized buffers (e.g. adopted via
    /// [`PooledBuf::from_vec`]) from displacing full-sized ones, so
    /// [`checkout`](Self::checkout) never has to grow a recycled buffer.
    fn return_buf(&self, mut buf: Vec<u8>, tier: PoolTier) {
        let (queue, tier_cap) = match tier {
            PoolTier::Small => (&self.small, self.small_cap),
            PoolTier::Large => (&self.large, self.large_cap),
            PoolTier::Huge => (&self.huge, self.huge_cap),
            PoolTier::Giant => (&self.giant, self.giant_cap),
            PoolTier::Unpooled => return,
        };
        let Some(q) = queue else { return };
        // Saturate so a user-supplied capacity near `usize::MAX` cannot
        // overflow, which would panic inside `Drop` in debug builds.
        let max_cap = tier_cap.saturating_mul(Self::MAX_RETAINED_CAP_FACTOR);
        if !(tier_cap..=max_cap).contains(&buf.capacity()) {
            return;
        }
        buf.clear();
        // Best effort: when the queue is already full, `push` hands the
        // buffer back in `Err` and it is dropped here.
        let _ = q.push(buf);
    }
}
/// The tier a checked-out buffer is routed back to when it is returned.
#[derive(Debug, Clone, Copy)]
enum PoolTier {
    /// Returned to the small-tier queue.
    Small,
    /// Returned to the large-tier queue.
    Large,
    /// Returned to the huge-tier queue.
    Huge,
    /// Returned to the giant-tier queue.
    Giant,
    /// Not pooled; the buffer is simply dropped on return.
    Unpooled,
}
/// A byte buffer associated with a [`BufferPool`].
///
/// It is obtained from `BufferPool::checkout` (zero-filled) or
/// `PooledBuf::from_vec`, and it dereferences to `[u8]`. When dropped, the
/// underlying `Vec` is handed back to the pool, which decides whether to
/// keep it.
pub struct PooledBuf {
    /// The owned buffer; becomes `None` only once `Drop` has taken it.
    buf: Option<Vec<u8>>,
    /// The pool the buffer is handed back to on drop.
    pool: Arc<BufferPool>,
    /// The tier the buffer is handed back to.
    tier: PoolTier,
}
impl PooledBuf {
pub fn from_vec(v: Vec<u8>, pool: Arc<BufferPool>) -> Self {
let len = v.len();
let (_, tier, _) = pool.tier_for_size(len);
Self {
buf: Some(v),
pool,
tier,
}
}
}
/// Read access to the buffer's bytes as a slice.
impl Deref for PooledBuf {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        // `buf` is only emptied inside `Drop`, so this cannot fail in practice.
        self.buf.as_deref().expect("PooledBuf used after drop")
    }
}
/// Write access to the buffer's bytes as a mutable slice.
impl DerefMut for PooledBuf {
    fn deref_mut(&mut self) -> &mut [u8] {
        // `buf` is only emptied inside `Drop`, so this cannot fail in practice.
        self.buf.as_deref_mut().expect("PooledBuf used after drop")
    }
}
/// Hands the buffer back to its pool, which decides whether to keep it.
impl Drop for PooledBuf {
    fn drop(&mut self) {
        let Some(buf) = self.buf.take() else {
            return;
        };
        self.pool.return_buf(buf, self.tier);
    }
}
/// Borrows the buffer's bytes via its `Deref` implementation.
impl AsRef<[u8]> for PooledBuf {
    fn as_ref(&self) -> &[u8] {
        &self[..]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks out every size in `sizes` while holding all of them at once,
    /// asserting each buffer has the requested length, then releases them
    /// together.
    fn assert_checkout(pool: &Arc<BufferPool>, sizes: &[usize]) {
        let mut held = Vec::with_capacity(sizes.len());
        for &size in sizes {
            let buf = pool.checkout(size);
            assert_eq!(buf.len(), size);
            held.push(buf);
        }
        drop(held);
    }

    #[test]
    fn test_checkout_zeroed_and_deref_mut() {
        let pool = BufferPool::with_config(4, 1024);
        let mut buf = pool.checkout(100);
        assert_eq!(buf.len(), 100);
        assert!(buf.iter().all(|&byte| byte == 0));
        buf[0] = 0xAA;
        assert_eq!(buf[0], 0xAA);
    }

    #[test]
    fn test_exhaustion_fallback_and_return() {
        // One-slot small tier: the pool starts empty, so both checkouts
        // allocate, and returning both overfills the queue.
        let pool = BufferPool::with_config(1, 64);
        let first = pool.checkout(10);
        let second = pool.checkout(10);
        assert_eq!(second.len(), 10);
        drop(first);
        drop(second);
    }

    #[test]
    fn test_tier_selection() {
        let pool = BufferPool::new();
        let mib = 1024 * 1024;
        // One request per tier plus one larger than every tier.
        for size in [100, mib, 32 * mib, 128 * mib, 512 * mib] {
            let buf = pool.checkout(size);
            assert_eq!(buf.len(), size);
        }
    }

    #[test]
    fn test_workload_profiles() {
        let mib = 1024 * 1024;
        let training = BufferPool::new();
        assert_checkout(&training, &[14 * mib, 14 * mib, 14 * mib, 208 * mib, 48 * mib]);
        let inference = BufferPool::with_profile(PoolProfile::Inference);
        assert_checkout(
            &inference,
            &[16 * mib, 16 * mib, 64 * mib, mib, mib, mib, mib, 128 * mib],
        );
    }

    #[test]
    fn test_lazy_allocation() {
        let pool = BufferPool::with_profile(PoolProfile::Training);
        let first = pool.checkout(1024);
        assert_eq!(first.len(), 1024);
        drop(first);
        // The second checkout may reuse the buffer returned above.
        let second = pool.checkout(512);
        assert_eq!(second.len(), 512);
    }

    #[test]
    fn test_pool_builder() {
        let small = TierConfig {
            count: 4,
            capacity: 1024,
        };
        let large = TierConfig {
            count: 2,
            capacity: 1024 * 1024,
        };
        let custom = PoolBuilder::new().small(small).large(large).build();
        let buf = custom.checkout(500);
        assert_eq!(buf.len(), 500);

        let defaulted = PoolBuilder::default().build();
        let other = defaulted.checkout(100);
        assert_eq!(other.len(), 100);
    }
}