use alloc::vec::Vec;
use cubecl_common::bytes::Bytes;
use crate::memory_management::{
drop_queue::FlushingPolicy, drop_queue::policy::FlushingPolicyState,
};
/// A synchronization handle produced by the factory passed to
/// [`PendingDropQueue::flush`].
///
/// The queue calls [`sync`](Fence::sync) immediately before dropping the bytes this
/// fence guards, so an implementation should only return once those bytes may be
/// released.
pub trait Fence {
    /// Consumes the fence and synchronizes on it (the exact wait semantics are
    /// implementation-defined).
    fn sync(self);
}
/// Defers dropping of [`Bytes`] until a fence created after they were handed over has
/// been synced.
///
/// Bytes move through two stages: [`push`](Self::push) appends to `staged`, and each
/// [`flush`](Self::flush) syncs the fence from the previous flush, drops `pending`,
/// then moves `staged` into `pending` under a newly created fence.
pub struct PendingDropQueue<E: Fence> {
    /// Fence created by the most recent flush; synced at the start of the next one.
    fence: Option<E>,
    /// Bytes handed over at the most recent flush; dropped once `fence` is synced.
    pending: Vec<Bytes>,
    /// Bytes pushed since the most recent flush.
    staged: Vec<Bytes>,
    /// Thresholds consulted by `should_flush`.
    policy: FlushingPolicy,
    /// Running tally updated by `push` and reset by `flush`.
    policy_state: FlushingPolicyState,
}
impl<E: Fence> core::fmt::Debug for PendingDropQueue<E> {
    /// Formats the queue's bookkeeping.
    ///
    /// `E` is not required to implement `Debug`, so the fence itself cannot be printed.
    /// Instead, `has_fence` reports whether one is currently held, which tells whether
    /// `pending` is still waiting on a sync.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("PendingDropQueue")
            .field("has_fence", &self.fence.is_some())
            .field("pending", &self.pending)
            .field("staged", &self.staged)
            .field("policy", &self.policy)
            .field("policy_state", &self.policy_state)
            .finish()
    }
}
impl<E: Fence> Default for PendingDropQueue<E> {
fn default() -> Self {
Self::new(Default::default())
}
}
impl<F: Fence> PendingDropQueue<F> {
    /// Creates an empty queue that consults `policy` in [`Self::should_flush`].
    pub fn new(policy: FlushingPolicy) -> Self {
        Self {
            fence: None,
            pending: Vec::new(),
            staged: Vec::new(),
            policy,
            policy_state: Default::default(),
        }
    }

    /// Stages `bytes` for release on a later [`Self::flush`] and records it in the
    /// policy state.
    pub fn push(&mut self, bytes: Bytes) {
        self.policy_state.register(&bytes);
        self.staged.push(bytes);
    }

    /// Returns whether the state accumulated since the last flush asks for a flush
    /// under the configured [`FlushingPolicy`].
    #[must_use]
    pub fn should_flush(&self) -> bool {
        self.policy_state.should_flush(&self.policy)
    }

    /// Drops the bytes guarded by the previous flush and hands the staged bytes over to
    /// a new fence.
    ///
    /// 1. If a fence from the previous flush is held, it is synced and `pending` is
    ///    cleared.
    /// 2. Should `pending` still hold bytes without a fence (not reachable through this
    ///    type's methods, kept as a safeguard), a fresh fence from `factory` is synced
    ///    before clearing it.
    /// 3. `staged` is swapped into `pending`, and a fence from `factory` is stored to
    ///    guard it.
    /// 4. The policy state is reset.
    ///
    /// The first flush therefore never syncs; each later flush syncs the fence created
    /// by the one before it. `factory` only needs to be `FnMut`, so stateful fence
    /// sources are accepted (any `Fn` closure or `&impl Fn` still works).
    pub fn flush<Factory: FnMut() -> F>(&mut self, mut factory: Factory) {
        if let Some(fence) = self.fence.take() {
            fence.sync();
            self.pending.clear();
        }

        // Never release bytes without a completed sync, even if the invariant
        // "non-empty pending implies a held fence" were ever broken.
        if !self.pending.is_empty() {
            factory().sync();
            self.pending.clear();
        }

        // `pending` is empty here, so the swap leaves `staged` empty but keeps its
        // allocation for reuse.
        core::mem::swap(&mut self.pending, &mut self.staged);
        self.fence = Some(factory());
        self.policy_state.reset();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use core::cell::Cell;

    /// Fence double that counts how many times any fence sharing `sync_count` was synced.
    #[derive(Clone)]
    struct MockFence<'a> {
        sync_count: &'a Cell<u32>,
    }

    impl Fence for MockFence<'_> {
        fn sync(self) {
            self.sync_count.set(self.sync_count.get() + 1);
        }
    }

    /// Builds a queue governed by `policy`, plus a factory whose fences share `sync_count`.
    fn make_queue_with<'a>(
        sync_count: &'a Cell<u32>,
        policy: FlushingPolicy,
    ) -> (
        PendingDropQueue<MockFence<'a>>,
        impl Fn() -> MockFence<'a> + 'a,
    ) {
        let queue = PendingDropQueue::new(policy);
        let factory = move || MockFence { sync_count };
        (queue, factory)
    }

    /// Builds a queue governed by [`test_policy`].
    fn make_queue<'a>(
        sync_count: &'a Cell<u32>,
    ) -> (
        PendingDropQueue<MockFence<'a>>,
        impl Fn() -> MockFence<'a> + 'a,
    ) {
        make_queue_with(sync_count, test_policy())
    }

    fn sample_bytes() -> Bytes {
        Bytes::from_elems(vec![1u8, 2, 3])
    }

    /// Policy with a tiny size limit: three `sample_bytes` already exceed it.
    fn test_policy() -> FlushingPolicy {
        FlushingPolicy {
            max_bytes_count: 2048,
            max_bytes_size: 8,
        }
    }

    /// Policy whose size limit is far above anything `max_bytes_count + 1`
    /// `sample_bytes` can add up to, so only the count threshold can request a flush.
    fn count_only_policy() -> FlushingPolicy {
        FlushingPolicy {
            max_bytes_count: 4,
            max_bytes_size: 1024,
        }
    }

    #[test]
    fn fresh_queue_does_not_request_flush() {
        let sync_count = Cell::new(0u32);
        let (queue, _factory) = make_queue(&sync_count);
        assert!(!queue.should_flush());
    }

    #[test]
    fn push_below_thresholds_does_not_request_flush() {
        let sync_count = Cell::new(0u32);
        let (mut queue, _factory) = make_queue(&sync_count);
        queue.push(sample_bytes());
        assert!(!queue.should_flush());
    }

    #[test]
    fn push_past_count_threshold_triggers_flush_hint() {
        let sync_count = Cell::new(0u32);
        let (mut queue, _factory) = make_queue_with(&sync_count, count_only_policy());
        for _ in 1..count_only_policy().max_bytes_count {
            queue.push(sample_bytes());
        }
        assert!(
            !queue.should_flush(),
            "below the count threshold no flush should be requested"
        );
        queue.push(sample_bytes());
        queue.push(sample_bytes());
        assert!(queue.should_flush());
    }

    #[test]
    fn push_large_allocation_triggers_flush_via_size_threshold() {
        let sync_count = Cell::new(0u32);
        let (mut queue, _factory) = make_queue(&sync_count);
        let big = Bytes::from_elems(vec![0u8; test_policy().max_bytes_size as usize + 1]);
        queue.push(big);
        assert!(queue.should_flush());
    }

    #[test]
    fn first_flush_creates_fence_without_syncing() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);
        queue.push(sample_bytes());
        queue.flush(&factory);
        assert_eq!(
            sync_count.get(),
            0,
            "fence should not be synced on first flush"
        );
    }

    #[test]
    fn second_flush_syncs_fence_from_first_flush() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);
        queue.push(sample_bytes());
        queue.flush(&factory);
        queue.push(sample_bytes());
        queue.flush(&factory);
        assert_eq!(sync_count.get(), 1, "exactly one sync after two flushes");
    }

    #[test]
    fn each_subsequent_flush_syncs_the_previous_fence() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);
        for _ in 0..10 {
            queue.push(sample_bytes());
            queue.flush(&factory);
        }
        assert_eq!(sync_count.get(), 9);
    }

    #[test]
    fn staged_is_empty_after_flush() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);
        for _ in 0..5 {
            queue.push(sample_bytes());
        }
        queue.flush(&factory);
        assert!(queue.staged.is_empty());
    }

    #[test]
    fn pending_holds_previously_staged_bytes_after_flush() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);
        for _ in 0..5 {
            queue.push(sample_bytes());
        }
        queue.flush(&factory);
        assert_eq!(queue.pending.len(), 5);
    }

    #[test]
    fn pending_is_replaced_on_second_flush() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);
        for _ in 0..5 {
            queue.push(sample_bytes());
        }
        queue.flush(&factory);
        queue.push(sample_bytes());
        queue.flush(&factory);
        assert_eq!(queue.pending.len(), 1);
    }

    #[test]
    fn should_flush_resets_after_flush() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);
        for _ in 0..test_policy().max_bytes_count {
            queue.push(sample_bytes());
        }
        assert!(queue.should_flush());
        queue.flush(&factory);
        assert!(
            !queue.should_flush(),
            "policy state should be reset after flush"
        );
    }

    #[test]
    fn flush_on_empty_queue_is_safe() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);
        queue.flush(&factory);
        queue.flush(&factory);
        queue.flush(&factory);
        // Every flush still creates a fence, so the second and third syncs happen.
        assert_eq!(sync_count.get(), 2);
        assert!(queue.pending.is_empty());
        assert!(queue.staged.is_empty());
    }
}