use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation};

use crate::engine::EngineId;
use crate::governor::MemoryGovernor;
use crate::pressure::PressureLevel;
#[derive(Debug)]
pub struct GovernedMemoryPool {
governor: Arc<MemoryGovernor>,
engine: EngineId,
}
impl GovernedMemoryPool {
pub fn new(governor: Arc<MemoryGovernor>, engine: EngineId) -> Self {
Self { governor, engine }
}
pub fn for_queries(governor: Arc<MemoryGovernor>) -> Self {
Self::new(governor, EngineId::Query)
}
pub fn pressure(&self) -> PressureLevel {
self.governor.engine_pressure(self.engine)
}
}
impl MemoryPool for GovernedMemoryPool {
fn grow(&self, _reservation: &MemoryReservation, additional: usize) {
if self.governor.try_reserve(self.engine, additional).is_err() {
tracing::warn!(
engine = %self.engine,
additional,
allocated = self.governor.budget(self.engine)
.map(|b| b.allocated()).unwrap_or(0),
"infallible grow exceeded budget — subsequent try_grow will fail"
);
}
}
fn shrink(&self, _reservation: &MemoryReservation, shrink: usize) {
self.governor.release(self.engine, shrink);
}
fn try_grow(
&self,
_reservation: &MemoryReservation,
additional: usize,
) -> datafusion_common::Result<()> {
self.governor
.try_reserve(self.engine, additional)
.map_err(|e| {
datafusion_common::DataFusionError::ResourcesExhausted(format!(
"query memory budget exhausted: {e}"
))
})
}
fn reserved(&self) -> usize {
self.governor
.budget(self.engine)
.map(|b| b.allocated())
.unwrap_or(0)
}
fn register(&self, _consumer: &MemoryConsumer) {
tracing::debug!(
consumer = _consumer.name(),
engine = %self.engine,
"DataFusion consumer registered"
);
}
fn unregister(&self, _consumer: &MemoryConsumer) {
tracing::debug!(
consumer = _consumer.name(),
engine = %self.engine,
"DataFusion consumer unregistered"
);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::governor::GovernorConfig;
    use std::collections::HashMap;

    /// One mebibyte. Used as the vector engine's fixed budget and as the
    /// extra headroom in the global ceiling.
    const MIB: usize = 1024 * 1024;

    /// Builds a governor with a `query_budget`-byte query budget and a 1 MiB
    /// vector budget. The global ceiling is exactly the sum of the two.
    fn governor_with_query_budget(query_budget: usize) -> Arc<MemoryGovernor> {
        let engine_limits = HashMap::from([
            (EngineId::Query, query_budget),
            (EngineId::Vector, MIB),
        ]);
        let config = GovernorConfig {
            global_ceiling: query_budget + MIB,
            engine_limits,
        };
        Arc::new(MemoryGovernor::new(config).expect("test governor config must be valid"))
    }

    /// Wraps `governor` in a query-engine pool, type-erased the way DataFusion
    /// consumes it.
    fn query_pool(governor: Arc<MemoryGovernor>) -> Arc<dyn MemoryPool> {
        Arc::new(GovernedMemoryPool::for_queries(governor))
    }

    #[test]
    fn try_grow_within_budget_succeeds() {
        let half_mib = 512 * 1024;
        let pool = query_pool(governor_with_query_budget(MIB));
        let mut reservation = MemoryConsumer::new("test_sort").register(&pool);

        assert!(reservation.try_grow(half_mib).is_ok());
        assert_eq!(reservation.size(), half_mib);
    }

    #[test]
    fn try_grow_exceeding_budget_fails() {
        let pool = query_pool(governor_with_query_budget(1024));
        let mut reservation = MemoryConsumer::new("big_aggregate").register(&pool);

        assert!(reservation.try_grow(512).is_ok());
        let message = match reservation.try_grow(1024) {
            Ok(()) => panic!("growing past the budget must fail"),
            Err(err) => err.to_string(),
        };
        assert!(message.contains("budget exhausted"), "got: {message}");
    }

    #[test]
    fn shrink_frees_budget() {
        let pool = query_pool(governor_with_query_budget(1024));
        let mut reservation = MemoryConsumer::new("test").register(&pool);

        reservation.try_grow(1024).unwrap();
        // The budget is now full, so even one more byte is refused.
        assert!(reservation.try_grow(1).is_err());

        reservation.shrink(512);
        assert!(reservation.try_grow(512).is_ok());
    }

    #[test]
    fn reserved_tracks_allocations() {
        let pool = query_pool(governor_with_query_budget(4096));
        assert_eq!(pool.reserved(), 0);

        let mut reservation = MemoryConsumer::new("test").register(&pool);
        reservation.try_grow(1000).unwrap();
        assert_eq!(pool.reserved(), 1000);

        reservation.shrink(600);
        assert_eq!(pool.reserved(), 400);
    }

    #[test]
    fn pressure_reflects_utilization() {
        let governor = governor_with_query_budget(1000);
        let pool = GovernedMemoryPool::for_queries(Arc::clone(&governor));
        assert_eq!(pool.pressure(), PressureLevel::Normal);

        // 850 / 1000 used.
        governor.try_reserve(EngineId::Query, 850).unwrap();
        assert_eq!(pool.pressure(), PressureLevel::Critical);

        // 960 / 1000 used.
        governor.try_reserve(EngineId::Query, 110).unwrap();
        assert_eq!(pool.pressure(), PressureLevel::Emergency);
    }
}