use datafusion_common::{Result, internal_datafusion_err};
use std::hash::{Hash, Hasher};
use std::{cmp::Ordering, sync::Arc, sync::atomic};
mod pool;
#[cfg(feature = "arrow_buffer_pool")]
pub mod arrow;
pub mod proxy {
pub use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt};
}
pub use datafusion_common::{
human_readable_count, human_readable_duration, human_readable_size, units,
};
pub use pool::*;
/// Tracks memory reservations made by one or more [`MemoryConsumer`]s,
/// optionally enforcing a shared limit on the total bytes reserved.
///
/// Implementations must be thread-safe (`Send + Sync`): reservations may be
/// grown and shrunk concurrently from multiple threads.
pub trait MemoryPool: Send + Sync + std::fmt::Debug {
    /// Called when `consumer` is registered with this pool. Default is a no-op.
    fn register(&self, _consumer: &MemoryConsumer) {}

    /// Called when the registration for `consumer` is dropped. Default is a no-op.
    fn unregister(&self, _consumer: &MemoryConsumer) {}

    /// Infallibly grow `reservation` by `additional` bytes
    /// (may push the pool past its limit — see the tests below).
    fn grow(&self, reservation: &MemoryReservation, additional: usize);

    /// Return `shrink` bytes from `reservation` back to the pool.
    fn shrink(&self, reservation: &MemoryReservation, shrink: usize);

    /// Attempt to grow `reservation` by `additional` bytes, returning an
    /// error (and reserving nothing) if the pool cannot accommodate it.
    fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()>;

    /// Total bytes currently reserved across all consumers.
    fn reserved(&self) -> usize;

    /// The pool's configured limit, if known. Default: [`MemoryLimit::Unknown`].
    fn memory_limit(&self) -> MemoryLimit {
        MemoryLimit::Unknown
    }
}
/// The limit (if any) enforced by a [`MemoryPool`], as reported by
/// [`MemoryPool::memory_limit`].
///
/// Derives are additive and backward-compatible; `Debug` in particular is
/// expected on public types so pools can be inspected in logs/tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryLimit {
    /// The pool imposes no limit on reservations.
    Infinite,
    /// The pool limits total reservations to this many bytes.
    Finite(usize),
    /// The pool does not report its limit.
    Unknown,
}
/// A named participant in a [`MemoryPool`], identified by a process-unique id.
#[derive(Debug)]
pub struct MemoryConsumer {
    // Human-readable name; not necessarily unique (identity is `id`).
    name: String,
    // Whether this consumer reports it can release memory under pressure
    // (e.g. by spilling) — interpretation is up to pool implementations.
    can_spill: bool,
    // Process-unique identifier; equality and hashing are based on this field.
    id: usize,
}
impl PartialEq for MemoryConsumer {
    /// Two consumers are equal iff their unique ids match.
    fn eq(&self, other: &Self) -> bool {
        let same = self.id == other.id;
        // In debug builds, verify the invariant that consumers sharing an id
        // also agree on every other field.
        #[cfg(debug_assertions)]
        if same {
            assert_eq!(self.name, other.name);
            assert_eq!(self.can_spill, other.can_spill);
        }
        same
    }
}

impl Eq for MemoryConsumer {}
impl Hash for MemoryConsumer {
    /// Hashes only the unique `id`, mirroring [`PartialEq`], which compares
    /// ids alone.
    ///
    /// Hashing `name`/`can_spill` as well would risk violating the
    /// `Eq`/`Hash` contract (`a == b` must imply `hash(a) == hash(b)`): the
    /// agreement of those fields for equal ids is only asserted in debug
    /// builds, so a release-mode violation would silently corrupt hash-based
    /// collections. Hashing the id alone is always consistent with `eq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.id.hash(state);
    }
}
impl MemoryConsumer {
    /// Create a consumer with the given `name`, a fresh unique id, and
    /// spilling disabled.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            can_spill: false,
            id: Self::new_unique_id(),
        }
    }

    /// Copy of this consumer with the same name and spill flag but a newly
    /// allocated unique id (so it compares unequal to `self`).
    pub fn clone_with_new_id(&self) -> Self {
        Self {
            name: self.name.clone(),
            can_spill: self.can_spill,
            id: Self::new_unique_id(),
        }
    }

    /// Builder-style setter for the spill flag.
    pub fn with_can_spill(self, can_spill: bool) -> Self {
        Self { can_spill, ..self }
    }

    /// This consumer's process-unique id.
    pub fn id(&self) -> usize {
        self.id
    }

    /// Whether this consumer reports it can release memory under pressure.
    pub fn can_spill(&self) -> bool {
        self.can_spill
    }

    /// This consumer's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Register this consumer with `pool` and return an (initially empty)
    /// [`MemoryReservation`] tied to it.
    ///
    /// The pool is notified via [`MemoryPool::register`] now and via
    /// [`MemoryPool::unregister`] when the last reservation sharing the
    /// registration is dropped.
    pub fn register(self, pool: &Arc<dyn MemoryPool>) -> MemoryReservation {
        pool.register(&self);
        let registration = Arc::new(SharedRegistration {
            pool: Arc::clone(pool),
            consumer: self,
        });
        MemoryReservation {
            registration,
            size: atomic::AtomicUsize::new(0),
        }
    }

    /// Allocate a process-unique id from a shared atomic counter.
    fn new_unique_id() -> usize {
        static ID: atomic::AtomicUsize = atomic::AtomicUsize::new(0);
        ID.fetch_add(1, atomic::Ordering::Relaxed)
    }
}
/// Pairs a [`MemoryConsumer`] with the pool it registered with.
///
/// Shared (via `Arc`) by every reservation derived from the same
/// registration (`split`/`new_empty`/`take`), so the pool is unregistered
/// exactly once — when the last of them is dropped.
#[derive(Debug)]
struct SharedRegistration {
    pool: Arc<dyn MemoryPool>,
    consumer: MemoryConsumer,
}
impl Drop for SharedRegistration {
    /// Unregister the consumer from the pool once the last `Arc` clone of
    /// this registration is dropped.
    fn drop(&mut self) {
        self.pool.unregister(&self.consumer);
    }
}
/// A number of bytes reserved from a [`MemoryPool`] on behalf of a
/// [`MemoryConsumer`]; any outstanding bytes are returned to the pool on drop.
#[derive(Debug)]
pub struct MemoryReservation {
    // Registration shared with sibling reservations created via
    // `split`, `new_empty` or `take`.
    registration: Arc<SharedRegistration>,
    // Bytes currently held by this reservation.
    size: atomic::AtomicUsize,
}
impl MemoryReservation {
    /// Returns the size in bytes of this reservation.
    pub fn size(&self) -> usize {
        self.size.load(atomic::Ordering::Relaxed)
    }

    /// Returns the [`MemoryConsumer`] this reservation is registered to.
    pub fn consumer(&self) -> &MemoryConsumer {
        &self.registration.consumer
    }

    /// Frees all bytes from this reservation back to the pool, returning
    /// the number of bytes freed.
    pub fn free(&self) -> usize {
        let size = self.size.swap(0, atomic::Ordering::Relaxed);
        if size != 0 {
            self.registration.pool.shrink(self, size);
        }
        size
    }

    /// Frees `capacity` bytes from this reservation back to the pool.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` exceeds the currently reserved size.
    pub fn shrink(&self, capacity: usize) {
        // `checked_sub` returns `None` on underflow, which makes
        // `fetch_update` fail with the previous value so it can be reported.
        self.size
            .fetch_update(
                atomic::Ordering::Relaxed,
                atomic::Ordering::Relaxed,
                |prev| prev.checked_sub(capacity),
            )
            .unwrap_or_else(|prev| {
                panic!("Cannot free the capacity {capacity} out of allocated size {prev}")
            });
        self.registration.pool.shrink(self, capacity);
    }

    /// Tries to free `capacity` bytes from this reservation, returning the
    /// remaining reserved size on success.
    ///
    /// # Errors
    ///
    /// Returns an error (leaving the reservation unchanged) if `capacity`
    /// exceeds the currently reserved size.
    pub fn try_shrink(&self, capacity: usize) -> Result<usize> {
        let prev = self
            .size
            .fetch_update(
                atomic::Ordering::Relaxed,
                atomic::Ordering::Relaxed,
                |prev| prev.checked_sub(capacity),
            )
            .map_err(|prev| {
                internal_datafusion_err!(
                    "Cannot free the capacity {capacity} out of allocated size {prev}"
                )
            })?;
        self.registration.pool.shrink(self, capacity);
        Ok(prev - capacity)
    }

    /// Sets this reservation's size to `capacity`, growing or shrinking as
    /// needed. Panics (via [`Self::shrink`]) on a concurrent underflow.
    ///
    /// NOTE(review): the load and the subsequent grow/shrink are not one
    /// atomic step, so concurrent resizes of the same reservation can
    /// interleave — presumably each reservation is driven by one thread at
    /// a time; confirm with callers.
    pub fn resize(&self, capacity: usize) {
        let size = self.size.load(atomic::Ordering::Relaxed);
        match capacity.cmp(&size) {
            Ordering::Greater => self.grow(capacity - size),
            Ordering::Less => self.shrink(size - capacity),
            _ => {}
        }
    }

    /// Fallible version of [`Self::resize`]: growth beyond the pool limit
    /// and underflowing shrinks return an error instead of panicking.
    pub fn try_resize(&self, capacity: usize) -> Result<()> {
        let size = self.size.load(atomic::Ordering::Relaxed);
        match capacity.cmp(&size) {
            Ordering::Greater => self.try_grow(capacity - size)?,
            Ordering::Less => {
                self.try_shrink(size - capacity)?;
            }
            _ => {}
        };
        Ok(())
    }

    /// Increase the size of this reservation by `capacity` bytes,
    /// unconditionally (the pool limit is not enforced).
    pub fn grow(&self, capacity: usize) {
        self.registration.pool.grow(self, capacity);
        self.size.fetch_add(capacity, atomic::Ordering::Relaxed);
    }

    /// Try to increase the size of this reservation by `capacity` bytes,
    /// respecting the pool limit.
    ///
    /// # Errors
    ///
    /// Returns the pool's error (reserving nothing) if the grow is refused.
    pub fn try_grow(&self, capacity: usize) -> Result<()> {
        self.registration.pool.try_grow(self, capacity)?;
        self.size.fetch_add(capacity, atomic::Ordering::Relaxed);
        Ok(())
    }

    /// Splits off `capacity` bytes into a new reservation sharing this
    /// registration; the pool's total is unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` exceeds the currently reserved size.
    pub fn split(&self, capacity: usize) -> MemoryReservation {
        self.size
            .fetch_update(
                atomic::Ordering::Relaxed,
                atomic::Ordering::Relaxed,
                |prev| prev.checked_sub(capacity),
            )
            // Was a bare `unwrap()`, which panicked with no diagnostic;
            // report the same information as `shrink` does.
            .unwrap_or_else(|prev| {
                panic!("Cannot split the capacity {capacity} out of allocated size {prev}")
            });
        Self {
            size: atomic::AtomicUsize::new(capacity),
            registration: Arc::clone(&self.registration),
        }
    }

    /// Returns a new empty reservation sharing this registration.
    pub fn new_empty(&self) -> Self {
        Self {
            size: atomic::AtomicUsize::new(0),
            registration: Arc::clone(&self.registration),
        }
    }

    /// Splits off all bytes into a new reservation, leaving this one empty.
    pub fn take(&mut self) -> MemoryReservation {
        self.split(self.size.load(atomic::Ordering::Relaxed))
    }
}
impl Drop for MemoryReservation {
    /// Return any outstanding bytes to the pool when the reservation is dropped.
    fn drop(&mut self) {
        self.free();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_id_uniqueness() {
        // Every newly created consumer must receive a distinct id.
        let mut seen = std::collections::HashSet::new();
        for _ in 0..100 {
            assert!(seen.insert(MemoryConsumer::new("test").id()));
        }
    }

    #[test]
    fn test_memory_pool_underflow() {
        let pool = Arc::new(GreedyMemoryPool::new(50)) as _;
        let a1 = MemoryConsumer::new("a1").register(&pool);
        assert_eq!(pool.reserved(), 0);

        // Infallible grow may exceed the 50-byte limit.
        a1.grow(100);
        assert_eq!(pool.reserved(), 100);
        assert_eq!(a1.free(), 100);
        assert_eq!(pool.reserved(), 0);

        // Fallible grow respects the limit.
        a1.try_grow(100).unwrap_err();
        assert_eq!(pool.reserved(), 0);
        a1.try_grow(30).unwrap();
        assert_eq!(pool.reserved(), 30);

        // A second consumer cannot exceed the remaining budget...
        let a2 = MemoryConsumer::new("a2").register(&pool);
        a2.try_grow(25).unwrap_err();
        assert_eq!(pool.reserved(), 30);

        // ...until the first reservation is released.
        drop(a1);
        assert_eq!(pool.reserved(), 0);
        a2.try_grow(25).unwrap();
        assert_eq!(pool.reserved(), 25);
    }

    #[test]
    fn test_split() {
        let pool = Arc::new(GreedyMemoryPool::new(50)) as _;
        let original = MemoryConsumer::new("r1").register(&pool);
        original.try_grow(20).unwrap();
        assert_eq!(original.size(), 20);
        assert_eq!(pool.reserved(), 20);

        // Splitting moves bytes between reservations without touching the pool.
        let split_off = original.split(5);
        assert_eq!(original.size(), 15);
        assert_eq!(split_off.size(), 5);
        assert_eq!(pool.reserved(), 20);

        // Dropping one reservation frees only its share.
        drop(original);
        assert_eq!(split_off.size(), 5);
        assert_eq!(pool.reserved(), 5);
    }

    #[test]
    fn test_new_empty() {
        let pool = Arc::new(GreedyMemoryPool::new(50)) as _;
        let first = MemoryConsumer::new("r1").register(&pool);
        first.try_grow(20).unwrap();

        // An empty sibling starts at zero and grows independently.
        let sibling = first.new_empty();
        sibling.try_grow(5).unwrap();
        assert_eq!(first.size(), 20);
        assert_eq!(sibling.size(), 5);
        assert_eq!(pool.reserved(), 25);
    }

    #[test]
    fn test_take() {
        let pool = Arc::new(GreedyMemoryPool::new(50)) as _;
        let mut source = MemoryConsumer::new("r1").register(&pool);
        source.try_grow(20).unwrap();

        // `take` moves everything out, leaving the source empty but usable.
        let taken = source.take();
        taken.try_grow(5).unwrap();
        assert_eq!(source.size(), 0);
        assert_eq!(taken.size(), 25);
        assert_eq!(pool.reserved(), 25);

        source.try_grow(3).unwrap();
        assert_eq!(source.size(), 3);
        assert_eq!(taken.size(), 25);
        assert_eq!(pool.reserved(), 28);
    }

    #[test]
    fn test_try_shrink() {
        let pool = Arc::new(GreedyMemoryPool::new(100)) as _;
        let r1 = MemoryConsumer::new("r1").register(&pool);
        r1.try_grow(50).unwrap();
        assert_eq!(r1.size(), 50);
        assert_eq!(pool.reserved(), 50);

        // Successful shrink returns the remaining size.
        assert_eq!(r1.try_shrink(30).unwrap(), 20);
        assert_eq!(r1.size(), 20);
        assert_eq!(pool.reserved(), 20);

        // Fill the pool with a second consumer.
        let r2 = MemoryConsumer::new("r2").register(&pool);
        r2.try_grow(80).unwrap();
        assert_eq!(pool.reserved(), 100);

        // Shrinking below zero fails and leaves state untouched.
        assert!(r1.try_shrink(25).is_err());
        assert_eq!(r1.size(), 20);
        assert_eq!(pool.reserved(), 100);

        // Shrinking to exactly zero succeeds.
        assert_eq!(r1.try_shrink(20).unwrap(), 0);
        assert_eq!(r1.size(), 0);
        assert_eq!(pool.reserved(), 80);
    }
}