// cubecl_runtime/memory_management/drop_queue/queue.rs
1use alloc::vec::Vec;
2use cubecl_common::bytes::Bytes;
3
4use crate::memory_management::{
5 drop_queue::FlushingPolicy, drop_queue::policy::FlushingPolicyState,
6};
7
/// A synchronization primitive that blocks until the device has finished
/// processing all commands submitted before the fence was created.
pub trait Fence {
    /// Block the current thread until the device signals this fence.
    ///
    /// Consumes the fence: each fence is a one-shot synchronization point.
    fn sync(self);
}
14
/// Defers the drop of CPU-side [`Bytes`] allocations until the device has
/// finished reading them.
///
/// # How it works
///
/// The device uploads are asynchronous: after you copy bytes into a staging buffer
/// and enqueue an upload command, the CPU memory must remain valid until the
/// device is done. `PendingDropQueue` manages this lifetime with a two-phase
/// approach:
///
/// 1. **Stage** – call [`push`](Self::push) to hand over bytes that are
///    in-flight. They land in the `staged` list.
/// 2. **Flush** – call [`flush`](Self::flush) to rotate the lists. The
///    previously staged bytes move to `pending`, a new [`Fence`] is created
///    to mark the end of the current upload batch, and any bytes that were
///    *already* pending (i.e. the batch before that) are freed after syncing
///    the previous fence.
///
/// This double-buffer scheme means CPU memory is held for at most two flush
/// cycles, while avoiding any unnecessary stalls on the hot path.
///
/// # Flushing policy
///
/// Call [`should_flush`](Self::should_flush) to check whether enough bytes
/// have accumulated to warrant a flush. You may also flush unconditionally
/// (e.g. at the end of a frame).
pub struct PendingDropQueue<E: Fence> {
    /// Fence signalling that the device has consumed everything in `pending`.
    fence: Option<E>,
    /// Bytes from the *previous* flush cycle, kept alive until `fence` fires.
    pending: Vec<Bytes>,
    /// Bytes queued in the *current* cycle, not yet associated with a fence.
    staged: Vec<Bytes>,
    /// The configuration of the queue.
    policy: FlushingPolicy,
    /// The current state of the policy.
    policy_state: FlushingPolicyState,
}
53
54impl<E: Fence> core::fmt::Debug for PendingDropQueue<E> {
55 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
56 f.debug_struct("PendingDropQueue")
57 .field("pending", &self.pending)
58 .field("staged", &self.staged)
59 .field("policy", &self.policy)
60 .field("policy_state", &self.policy_state)
61 .finish()
62 }
63}
64
65impl<E: Fence> Default for PendingDropQueue<E> {
66 fn default() -> Self {
67 Self::new(Default::default())
68 }
69}
70
71impl<F: Fence> PendingDropQueue<F> {
72 /// Creates a new `PendingDropQueue`.
73 pub fn new(policy: FlushingPolicy) -> Self {
74 Self {
75 fence: None,
76 pending: Vec::new(),
77 staged: Vec::new(),
78 policy,
79 policy_state: Default::default(),
80 }
81 }
82 /// Enqueue `bytes` to be dropped once the device has finished reading them.
83 ///
84 /// The bytes are added to the current staged batch and will be freed on
85 /// the flush cycle *after* the next call to [`flush`](Self::flush).
86 pub fn push(&mut self, bytes: Bytes) {
87 self.policy_state.register(&bytes);
88 self.staged.push(bytes);
89 }
90
91 /// Returns `true` when the staged batch is large enough to justify a
92 /// flush.
93 pub fn should_flush(&self) -> bool {
94 self.policy_state.should_flush(&self.policy)
95 }
96
97 /// Rotate the double-buffer and free any memory the device is done with.
98 ///
99 /// `factory` is called to produce a [`Fence`]. It should submit (or
100 /// record) a device signal command so that syncing the fence guarantees all
101 /// preceding device work is complete.
102 pub fn flush<Factory: Fn() -> F>(&mut self, factory: Factory) {
103 // Sync the fence from the previous flush and free the bytes it was
104 // protecting.
105 if let Some(event) = self.fence.take() {
106 event.sync();
107 self.pending.clear();
108 }
109
110 // Safety net: if pending is somehow still populated (no prior fence),
111 // stall immediately rather than freeing memory the GPU might still
112 // be reading.
113 if !self.pending.is_empty() {
114 let event = factory();
115 event.sync();
116 self.pending.clear();
117 }
118
119 // The current staged batch becomes the new pending batch.
120 core::mem::swap(&mut self.pending, &mut self.staged);
121
122 // Record a fence so the *next* flush knows when this batch is safe to
123 // free.
124 self.fence = Some(factory());
125 self.policy_state.reset();
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use core::cell::Cell;

    // -----------------------------------------------------------------------
    // Test helpers
    // -----------------------------------------------------------------------

    /// Fence stub that bumps a shared counter each time it is synced.
    #[derive(Clone)]
    struct MockFence<'a> {
        sync_count: &'a Cell<u32>,
    }

    impl Fence for MockFence<'_> {
        fn sync(self) {
            let previous = self.sync_count.get();
            self.sync_count.set(previous + 1);
        }
    }

    /// Builds a queue using [`test_policy`] together with a factory producing
    /// fences tied to `sync_count`.
    fn make_queue<'a>(
        sync_count: &'a Cell<u32>,
    ) -> (
        PendingDropQueue<MockFence<'a>>,
        impl Fn() -> MockFence<'a> + 'a,
    ) {
        (
            PendingDropQueue::new(test_policy()),
            move || MockFence { sync_count },
        )
    }

    /// A small three-byte allocation.
    fn sample_bytes() -> Bytes {
        Bytes::from_elems(vec![1u8, 2, 3])
    }

    fn test_policy() -> FlushingPolicy {
        FlushingPolicy {
            max_bytes_count: 2048,
            max_bytes_size: 8,
        }
    }

    // -----------------------------------------------------------------------
    // push / should_flush
    // -----------------------------------------------------------------------

    #[test]
    fn push_at_count_threshold_triggers_flush_hint() {
        let sync_count = Cell::new(0u32);
        let (mut queue, _factory) = make_queue(&sync_count);

        // NOTE(review): with `max_bytes_size: 8` the cumulative size may trip
        // the policy before the count does — confirm which threshold fires.
        (0..test_policy().max_bytes_count).for_each(|_| queue.push(sample_bytes()));

        assert!(queue.should_flush());
    }

    #[test]
    fn push_large_allocation_triggers_flush_via_size_threshold() {
        let sync_count = Cell::new(0u32);
        let (mut queue, _factory) = make_queue(&sync_count);

        let oversized_len = test_policy().max_bytes_size as usize + 1;
        queue.push(Bytes::from_elems(vec![0u8; oversized_len]));

        assert!(queue.should_flush());
    }

    // -----------------------------------------------------------------------
    // flush – fence / sync behaviour
    // -----------------------------------------------------------------------

    #[test]
    fn first_flush_creates_fence_without_syncing() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        queue.push(sample_bytes());
        queue.flush(&factory);

        // The fence is created but must not be synced yet — that happens on
        // the next flush.
        let syncs = sync_count.get();
        assert_eq!(syncs, 0, "fence should not be synced on first flush");
    }

    #[test]
    fn second_flush_syncs_fence_from_first_flush() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        // Cycle 1 creates fence A; cycle 2 syncs fence A and creates fence B.
        for _ in 0..2 {
            queue.push(sample_bytes());
            queue.flush(&factory);
        }

        assert_eq!(sync_count.get(), 1, "exactly one sync after two flushes");
    }

    #[test]
    fn each_subsequent_flush_syncs_the_previous_fence() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        let cycles: u32 = 10;
        for _ in 0..cycles {
            queue.push(sample_bytes());
            queue.flush(&factory);
        }

        // Each flush except the first syncs the fence from the previous one.
        assert_eq!(sync_count.get(), cycles - 1);
    }

    // -----------------------------------------------------------------------
    // flush – buffer rotation
    // -----------------------------------------------------------------------

    #[test]
    fn staged_is_empty_after_flush() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        (0..5).for_each(|_| queue.push(sample_bytes()));
        queue.flush(&factory);

        assert!(queue.staged.is_empty());
    }

    #[test]
    fn pending_holds_previously_staged_bytes_after_flush() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        (0..5).for_each(|_| queue.push(sample_bytes()));
        queue.flush(&factory);

        assert_eq!(queue.pending.len(), 5);
    }

    #[test]
    fn pending_is_replaced_on_second_flush() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        // First cycle: five items become pending.
        (0..5).for_each(|_| queue.push(sample_bytes()));
        queue.flush(&factory);

        // Second cycle: the fence is synced, pending is cleared and rotated.
        queue.push(sample_bytes());
        queue.flush(&factory);

        // Only the one item staged between the two flushes should be pending.
        assert_eq!(queue.pending.len(), 1);
    }

    // -----------------------------------------------------------------------
    // flush – policy state reset
    // -----------------------------------------------------------------------

    #[test]
    fn should_flush_resets_after_flush() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        (0..test_policy().max_bytes_count).for_each(|_| queue.push(sample_bytes()));
        assert!(queue.should_flush());

        queue.flush(&factory);

        assert!(
            !queue.should_flush(),
            "policy state should be reset after flush"
        );
    }

    // -----------------------------------------------------------------------
    // Edge cases
    // -----------------------------------------------------------------------

    #[test]
    fn flush_on_empty_queue_is_safe() {
        let sync_count = Cell::new(0u32);
        let (mut queue, factory) = make_queue(&sync_count);

        // Should not panic regardless of how many times it is called.
        for _ in 0..3 {
            queue.flush(&factory);
        }
    }
}
331}