Skip to main content

cubecl_runtime/memory_management/drop_queue/
queue.rs

1use alloc::vec::Vec;
2use cubecl_common::bytes::Bytes;
3
4use crate::memory_management::{
5    drop_queue::FlushingPolicy, drop_queue::policy::FlushingPolicyState,
6};
7
/// A synchronization primitive that blocks until the device has finished
/// processing all commands submitted before the fence was created.
pub trait Fence {
    /// Block the current thread until the device signals this fence, i.e.
    /// until every command submitted before the fence was created has
    /// completed.
    ///
    /// Takes `self` by value: a fence is consumed by syncing it, so each
    /// fence can be waited on at most once.
    fn sync(self);
}
14
15/// Defers the drop of CPU-side [`Bytes`] allocations until the device has
16/// finished reading them.
17///
18/// # How it works
19///
20/// The device uploads are asynchronous: after you copy bytes into a staging buffer
21/// and enqueue an upload command, the CPU memory must remain valid until the
22/// device is done. `PendingDropQueue` manages this lifetime with a two-phase
23/// approach:
24///
25/// 1. **Stage** – call [`push`](Self::push) to hand over bytes that are
26///    in-flight. They land in the `staged` list.
27/// 2. **Flush** – call [`flush`](Self::flush) to rotate the lists. The
28///    previously staged bytes move to `pending`, a new [`Fence`] is created
29///    to mark the end of the current upload batch, and any bytes that were
30///    *already* pending (i.e. the batch before that) are freed after syncing
31///    the previous fence.
32///
33/// This double-buffer scheme means CPU memory is held for at most two flush
34/// cycles, while avoiding any unnecessary stalls on the hot path.
35///
36/// # Flushing policy
37///
38/// Call [`should_flush`](Self::should_flush) to check whether enough bytes
39/// have accumulated to warrant a flush. You may also flush unconditionally
40/// (e.g. at the end of a frame).
41pub struct PendingDropQueue<E: Fence> {
42    /// Fence signalling that the device has consumed everything in `pending`.
43    fence: Option<E>,
44    /// Bytes from the *previous* flush cycle, kept alive until `event` fires.
45    pending: Vec<Bytes>,
46    /// Bytes queued in the *current* cycle, not yet associated with a fence.
47    staged: Vec<Bytes>,
48    /// The configuration of the queue.
49    policy: FlushingPolicy,
50    /// The current state of the policy.
51    policy_state: FlushingPolicyState,
52}
53
54impl<E: Fence> core::fmt::Debug for PendingDropQueue<E> {
55    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
56        f.debug_struct("PendingDropQueue")
57            .field("pending", &self.pending)
58            .field("staged", &self.staged)
59            .field("policy", &self.policy)
60            .field("policy_state", &self.policy_state)
61            .finish()
62    }
63}
64
65impl<E: Fence> Default for PendingDropQueue<E> {
66    fn default() -> Self {
67        Self::new(Default::default())
68    }
69}
70
71impl<F: Fence> PendingDropQueue<F> {
72    /// Creates a new `PendingDropQueue`.
73    pub fn new(policy: FlushingPolicy) -> Self {
74        Self {
75            fence: None,
76            pending: Vec::new(),
77            staged: Vec::new(),
78            policy,
79            policy_state: Default::default(),
80        }
81    }
82    /// Enqueue `bytes` to be dropped once the device has finished reading them.
83    ///
84    /// The bytes are added to the current staged batch and will be freed on
85    /// the flush cycle *after* the next call to [`flush`](Self::flush).
86    pub fn push(&mut self, bytes: Bytes) {
87        self.policy_state.register(&bytes);
88        self.staged.push(bytes);
89    }
90
91    /// Returns `true` when the staged batch is large enough to justify a
92    /// flush.
93    pub fn should_flush(&self) -> bool {
94        self.policy_state.should_flush(&self.policy)
95    }
96
97    /// Rotate the double-buffer and free any memory the device is done with.
98    ///
99    /// `factory` is called to produce a [`Fence`]. It should submit (or
100    /// record) a device signal command so that syncing the fence guarantees all
101    /// preceding device work is complete.
102    pub fn flush<Factory: Fn() -> F>(&mut self, factory: Factory) {
103        // Sync the fence from the previous flush and free the bytes it was
104        // protecting.
105        if let Some(event) = self.fence.take() {
106            event.sync();
107            self.pending.clear();
108        }
109
110        // Safety net: if pending is somehow still populated (no prior fence),
111        // stall immediately rather than freeing memory the GPU might still
112        // be reading.
113        if !self.pending.is_empty() {
114            let event = factory();
115            event.sync();
116            self.pending.clear();
117        }
118
119        // The current staged batch becomes the new pending batch.
120        core::mem::swap(&mut self.pending, &mut self.staged);
121
122        // Record a fence so the *next* flush knows when this batch is safe to
123        // free.
124        self.fence = Some(factory());
125        self.policy_state.reset();
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use core::cell::Cell;

    // ---------------------------------------------------------------------------
    // Test helpers
    // ---------------------------------------------------------------------------

    /// Fence that counts how many times it was synced in a shared counter.
    #[derive(Clone)]
    struct MockFence<'a> {
        sync_count: &'a Cell<u32>,
    }

    impl Fence for MockFence<'_> {
        fn sync(self) {
            self.sync_count.set(self.sync_count.get() + 1);
        }
    }

    /// Builds a queue using [`test_policy`] plus a factory whose fences report
    /// into `sync_count`.
    fn make_queue<'a>(
        sync_count: &'a Cell<u32>,
    ) -> (
        PendingDropQueue<MockFence<'a>>,
        impl Fn() -> MockFence<'a> + 'a,
    ) {
        make_queue_with_policy(sync_count, test_policy())
    }

    /// Same as [`make_queue`], but with an explicit flushing policy.
    fn make_queue_with_policy<'a>(
        sync_count: &'a Cell<u32>,
        policy: FlushingPolicy,
    ) -> (
        PendingDropQueue<MockFence<'a>>,
        impl Fn() -> MockFence<'a> + 'a,
    ) {
        let queue = PendingDropQueue::new(policy);
        let factory = move || MockFence { sync_count };
        (queue, factory)
    }

    fn sample_bytes() -> Bytes {
        Bytes::from_elems(vec![1u8, 2, 3])
    }

    fn test_policy() -> FlushingPolicy {
        FlushingPolicy {
            max_bytes_count: 2048,
            max_bytes_size: 8,
        }
    }

    /// Policy whose size threshold cannot be reached by the small test
    /// allocations, so only the count threshold can trigger a flush.
    /// (With `test_policy`, three 3-byte pushes already exceed the 8-byte
    /// size threshold, which would mask the count check.)
    fn count_only_policy() -> FlushingPolicy {
        FlushingPolicy {
            max_bytes_count: test_policy().max_bytes_count,
            max_bytes_size: 1 << 30,
        }
    }

    // ---------------------------------------------------------------------------
    // push / should_flush
    // ---------------------------------------------------------------------------

    #[test]
    fn fresh_queue_does_not_request_flush() {
        let sync_count = Cell::new(0u32);
        let (queue, _factory) = make_queue(&sync_count);

        assert!(!queue.should_flush());
    }

    #[test]
    fn push_at_count_threshold_triggers_flush_hint() {
        let sync_count = Cell::new(0u32);
        let policy = count_only_policy();
        let count = policy.max_bytes_count;
        let (mut queue, _factory) = make_queue_with_policy(&sync_count, policy);

        // One short of the threshold: nothing should trigger yet.
        for _ in 1..count {
            queue.push(sample_bytes());
        }
        assert!(!queue.should_flush(), "below count threshold");

        // Reaching the threshold triggers the hint.
        queue.push(sample_bytes());
        assert!(queue.should_flush(), "at count threshold");
    }

    #[test]
    fn push_large_allocation_triggers_flush_via_size_threshold() {
        let sync_count = Cell::new(0u32);
        let (mut queue, _factory) = make_queue(&sync_count);
        let big = Bytes::from_elems(vec![0u8; test_policy().max_bytes_size as usize + 1]);

        queue.push(big);

        assert!(queue.should_flush());
    }

    // ---------------------------------------------------------------------------
    // flush – fence / sync behaviour
    // ---------------------------------------------------------------------------

    #[test]
    fn first_flush_creates_fence_without_syncing() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        queue.push(sample_bytes());
        queue.flush(&factory);

        // The fence is created but must not be synced yet — that happens on
        // the next flush.
        assert_eq!(
            sync_count.get(),
            0,
            "fence should not be synced on first flush"
        );
    }

    #[test]
    fn second_flush_syncs_fence_from_first_flush() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        queue.push(sample_bytes());
        queue.flush(&factory); // flush 1 – creates fence A

        queue.push(sample_bytes());
        queue.flush(&factory); // flush 2 – syncs fence A, creates fence B

        assert_eq!(sync_count.get(), 1, "exactly one sync after two flushes");
    }

    #[test]
    fn each_subsequent_flush_syncs_the_previous_fence() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        for _ in 0..10 {
            queue.push(sample_bytes());
            queue.flush(&factory);
        }

        // Each flush except the first syncs the fence from the previous one.
        assert_eq!(sync_count.get(), 9);
    }

    #[test]
    fn orphaned_pending_bytes_are_synced_before_release() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        // Force the state the safety net guards against: pending bytes with
        // no fence protecting them.
        queue.pending.push(sample_bytes());
        assert!(queue.fence.is_none());

        queue.flush(&factory);

        assert_eq!(
            sync_count.get(),
            1,
            "unfenced pending bytes must be synced before being freed"
        );
        assert!(queue.pending.is_empty());
    }

    // ---------------------------------------------------------------------------
    // flush – buffer rotation
    // ---------------------------------------------------------------------------

    #[test]
    fn staged_is_empty_after_flush() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        for _ in 0..5 {
            queue.push(sample_bytes());
        }
        queue.flush(&factory);

        assert!(queue.staged.is_empty());
    }

    #[test]
    fn pending_holds_previously_staged_bytes_after_flush() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        for _ in 0..5 {
            queue.push(sample_bytes());
        }
        queue.flush(&factory);

        assert_eq!(queue.pending.len(), 5);
    }

    #[test]
    fn pending_is_replaced_on_second_flush() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        for _ in 0..5 {
            queue.push(sample_bytes());
        }
        queue.flush(&factory); // pending = 5 items

        queue.push(sample_bytes());
        queue.flush(&factory); // syncs fence → pending cleared, rotated

        // Only the one item staged between the two flushes should be pending.
        assert_eq!(queue.pending.len(), 1);
    }

    // ---------------------------------------------------------------------------
    // flush – policy state reset
    // ---------------------------------------------------------------------------

    #[test]
    fn should_flush_resets_after_flush() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        for _ in 0..test_policy().max_bytes_count {
            queue.push(sample_bytes());
        }
        assert!(queue.should_flush());

        queue.flush(&factory);

        assert!(
            !queue.should_flush(),
            "policy state should be reset after flush"
        );
    }

    // ---------------------------------------------------------------------------
    // Edge cases
    // ---------------------------------------------------------------------------

    #[test]
    fn flush_on_empty_queue_is_safe() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        // Should not panic regardless of how many times it is called, and must
        // never conjure bytes out of nothing.
        queue.flush(&factory);
        queue.flush(&factory);
        queue.flush(&factory);

        assert!(queue.pending.is_empty());
        assert!(queue.staged.is_empty());
    }
}