use grafeo_common::memory::buffer::{MemoryConsumer, MemoryRegion, SpillError, priorities};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
/// Spill-coordination state for a single execution operator.
///
/// Holds the operator's self-reported memory usage and at most one pending
/// eviction request. Every mutable field is an atomic, so all methods take
/// `&self` and the state can be shared between threads (e.g. behind an `Arc`).
pub(crate) struct OperatorSpillState {
    /// Operator name, returned unchanged by [`OperatorSpillState::name`].
    name: String,
    /// Last value passed to [`OperatorSpillState::set_usage`], in bytes.
    usage_bytes: AtomicUsize,
    /// Raised by `request_eviction`, lowered by `take_eviction_request`.
    pending: AtomicBool,
    /// Byte target of the most recently issued eviction request.
    target_bytes: AtomicUsize,
}

impl OperatorSpillState {
    /// Creates state for the operator called `name`, starting with zero
    /// reported usage and no pending eviction request.
    pub(crate) fn new(name: String) -> Self {
        let usage_bytes = AtomicUsize::new(0);
        let pending = AtomicBool::new(false);
        let target_bytes = AtomicUsize::new(0);
        Self {
            name,
            usage_bytes,
            pending,
            target_bytes,
        }
    }

    /// Overwrites the recorded memory usage with `bytes`.
    pub(crate) fn set_usage(&self, bytes: usize) {
        self.usage_bytes.store(bytes, Ordering::Relaxed);
    }

    /// Returns the most recently recorded memory usage, in bytes.
    pub(crate) fn usage(&self) -> usize {
        self.usage_bytes.load(Ordering::Relaxed)
    }

    /// Records a request to free `target_bytes`.
    ///
    /// If an earlier request has not been taken yet, its target is replaced
    /// by this one (last request wins).
    pub(crate) fn request_eviction(&self, target_bytes: usize) {
        // Write the target first, then raise the flag with Release. That pairs
        // with the Acquire half of the swap in `take_eviction_request`, so a
        // taker that observes the flag also observes this target.
        self.target_bytes.store(target_bytes, Ordering::Relaxed);
        self.pending.store(true, Ordering::Release);
    }

    /// Consumes the pending eviction request and returns its byte target.
    ///
    /// Returns `None` when no request is pending, and clears the request
    /// otherwise. The flag and the target live in separate atomics. A request
    /// issued concurrently with this call can therefore be observed here and
    /// again on the next call.
    pub(crate) fn take_eviction_request(&self) -> Option<usize> {
        let was_pending = self.pending.swap(false, Ordering::AcqRel);
        was_pending.then(|| self.target_bytes.load(Ordering::Relaxed))
    }

    /// Returns the operator name supplied to [`OperatorSpillState::new`].
    pub(crate) fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Exposes an [`OperatorSpillState`] through the [`MemoryConsumer`] trait.
///
/// The state is held behind an `Arc`, so the owning operator can keep its
/// own handle to the same state while this adapter is handed out.
pub(crate) struct OperatorConsumerAdapter {
    /// Shared state that every trait method delegates to.
    state: Arc<OperatorSpillState>,
}

impl OperatorConsumerAdapter {
    /// Wraps the shared `state` in an adapter.
    pub(crate) fn new(state: Arc<OperatorSpillState>) -> Self {
        OperatorConsumerAdapter { state }
    }
}
/// Delegates every `MemoryConsumer` query to the shared [`OperatorSpillState`].
///
/// Eviction and spilling are cooperative. Both only record a request on the
/// shared state and free nothing synchronously, so both report 0 bytes freed.
/// NOTE(review): presumably the owning operator polls
/// `take_eviction_request` and spills itself — confirm against the callers.
impl MemoryConsumer for OperatorConsumerAdapter {
    /// Name of the wrapped operator.
    fn name(&self) -> &str {
        self.state.name()
    }

    /// Usage last reported by the operator via `set_usage`.
    fn memory_usage(&self) -> usize {
        self.state.usage()
    }

    /// All operator adapters share the execution-buffer eviction priority.
    fn eviction_priority(&self) -> u8 {
        priorities::EXECUTION_BUFFERS
    }

    /// All operator adapters account against the execution-buffer region.
    fn region(&self) -> MemoryRegion {
        MemoryRegion::ExecutionBuffers
    }

    /// Records an eviction request for `target_bytes` and reports 0 bytes
    /// freed immediately.
    fn evict(&self, target_bytes: usize) -> usize {
        self.state.request_eviction(target_bytes);
        0
    }

    /// Operators always advertise that they can spill.
    fn can_spill(&self) -> bool {
        true
    }

    /// Same as [`MemoryConsumer::evict`]: records the request and reports
    /// `Ok(0)`. This path never returns an error.
    fn spill(&self, target_bytes: usize) -> Result<usize, SpillError> {
        Ok(self.evict(target_bytes))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_spill_state_usage_tracking() {
        let state = OperatorSpillState::new("test_sort".to_string());
        assert_eq!(state.usage(), 0);
        state.set_usage(1024);
        assert_eq!(state.usage(), 1024);
        state.set_usage(0);
        assert_eq!(state.usage(), 0);
    }

    #[test]
    fn test_spill_state_eviction_request() {
        let state = OperatorSpillState::new("test_agg".to_string());
        assert_eq!(state.take_eviction_request(), None);
        state.request_eviction(4096);
        assert_eq!(state.take_eviction_request(), Some(4096));
        // A request is consumed exactly once.
        assert_eq!(state.take_eviction_request(), None);
    }

    #[test]
    fn test_spill_state_multiple_requests_last_wins() {
        let state = OperatorSpillState::new("test".to_string());
        state.request_eviction(1000);
        state.request_eviction(2000);
        // Untaken requests coalesce: only the latest target is reported.
        assert_eq!(state.take_eviction_request(), Some(2000));
        assert_eq!(state.take_eviction_request(), None);
    }

    #[test]
    fn test_spill_state_zero_target_is_still_a_request() {
        // A zero-byte target must be distinguishable from "no request".
        let state = OperatorSpillState::new("zero".to_string());
        state.request_eviction(0);
        assert_eq!(state.take_eviction_request(), Some(0));
        assert_eq!(state.take_eviction_request(), None);
    }

    #[test]
    fn test_spill_state_request_after_take_is_delivered() {
        let state = OperatorSpillState::new("again".to_string());
        state.request_eviction(100);
        assert_eq!(state.take_eviction_request(), Some(100));
        state.request_eviction(200);
        assert_eq!(state.take_eviction_request(), Some(200));
    }

    #[test]
    fn test_spill_state_eviction_does_not_touch_usage() {
        let state = OperatorSpillState::new("usage".to_string());
        state.set_usage(512);
        state.request_eviction(256);
        assert_eq!(state.usage(), 512);
        assert_eq!(state.take_eviction_request(), Some(256));
        assert_eq!(state.usage(), 512);
    }

    #[test]
    fn test_consumer_adapter_reports_correct_metadata() {
        let state = Arc::new(OperatorSpillState::new("sort_op_1".to_string()));
        state.set_usage(8192);
        let adapter = OperatorConsumerAdapter::new(Arc::clone(&state));
        assert_eq!(adapter.name(), "sort_op_1");
        assert_eq!(adapter.memory_usage(), 8192);
        assert_eq!(adapter.eviction_priority(), priorities::EXECUTION_BUFFERS);
        assert_eq!(adapter.region(), MemoryRegion::ExecutionBuffers);
        assert!(adapter.can_spill());
    }

    #[test]
    fn test_consumer_adapter_tracks_usage_updates() {
        // The adapter reads through the shared Arc, so later updates are visible.
        let state = Arc::new(OperatorSpillState::new("live".to_string()));
        let adapter = OperatorConsumerAdapter::new(Arc::clone(&state));
        assert_eq!(adapter.memory_usage(), 0);
        state.set_usage(64);
        assert_eq!(adapter.memory_usage(), 64);
    }

    #[test]
    fn test_consumer_adapter_evict_sets_flag() {
        let state = Arc::new(OperatorSpillState::new("agg_op".to_string()));
        let adapter = OperatorConsumerAdapter::new(Arc::clone(&state));
        let freed = adapter.evict(4096);
        assert_eq!(freed, 0);
        assert_eq!(state.take_eviction_request(), Some(4096));
    }

    #[test]
    fn test_consumer_adapter_spill_sets_flag() {
        let state = Arc::new(OperatorSpillState::new("sort_op".to_string()));
        let adapter = OperatorConsumerAdapter::new(Arc::clone(&state));
        let freed = adapter
            .spill(2048)
            .expect("adapter spill only records a request and must not fail");
        assert_eq!(freed, 0);
        assert_eq!(state.take_eviction_request(), Some(2048));
    }
}