// ragc_core/memory_bounded_queue.rs
use std::collections::BinaryHeap;
use std::sync::{Arc, Condvar, Mutex};
6
7pub struct MemoryBoundedQueue<T: Ord> {
17 inner: Arc<Mutex<QueueInner<T>>>,
18 capacity_bytes: usize,
19 not_full: Arc<Condvar>,
20 not_empty: Arc<Condvar>,
21}
22
/// Internal heap entry pairing a queued value with its accounted byte size.
///
/// All ordering and equality delegate to the wrapped `item`; `size` is
/// carried along purely for capacity bookkeeping and never influences
/// priority, which is why the comparison traits are written by hand
/// instead of derived.
#[derive(Debug)]
struct PriorityItem<T: Ord> {
    item: T,
    size: usize,
}

impl<T: Ord> Ord for PriorityItem<T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Compare by payload only; `size` is deliberately ignored.
        self.item.cmp(&other.item)
    }
}

impl<T: Ord> PartialOrd for PriorityItem<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Total order exists, so this is always `Some`.
        Some(self.cmp(other))
    }
}

impl<T: Ord> PartialEq for PriorityItem<T> {
    fn eq(&self, other: &Self) -> bool {
        // Two entries are equal when their payloads are, regardless of size.
        self.item == other.item
    }
}

impl<T: Ord> Eq for PriorityItem<T> {}
49
50struct QueueInner<T: Ord> {
51 items: BinaryHeap<PriorityItem<T>>, current_size: usize, closed: bool, }
55
56impl<T: Ord> MemoryBoundedQueue<T> {
57 pub fn new(capacity_bytes: usize) -> Self {
69 Self {
70 inner: Arc::new(Mutex::new(QueueInner {
71 items: BinaryHeap::new(),
72 current_size: 0,
73 closed: false,
74 })),
75 capacity_bytes,
76 not_full: Arc::new(Condvar::new()),
77 not_empty: Arc::new(Condvar::new()),
78 }
79 }
80
81 pub fn push(&self, item: T, size_bytes: usize) -> Result<(), PushError> {
98 let mut inner = self.inner.lock().unwrap();
99
100 while inner.current_size + size_bytes > self.capacity_bytes && !inner.closed {
102 inner = self.not_full.wait(inner).unwrap();
103 }
104
105 if inner.closed {
107 return Err(PushError::Closed);
108 }
109
110 inner.items.push(PriorityItem {
112 item,
113 size: size_bytes,
114 });
115 inner.current_size += size_bytes;
116
117 self.not_empty.notify_one();
119
120 Ok(())
121 }
122
123 pub fn try_push(&self, item: T, size_bytes: usize) -> Result<(), TryPushError> {
127 let mut inner = self.inner.lock().unwrap();
128
129 if inner.closed {
130 return Err(TryPushError::Closed);
131 }
132
133 if inner.current_size + size_bytes > self.capacity_bytes {
134 return Err(TryPushError::WouldBlock);
135 }
136
137 inner.items.push(PriorityItem {
139 item,
140 size: size_bytes,
141 });
142 inner.current_size += size_bytes;
143
144 self.not_empty.notify_one();
146
147 Ok(())
148 }
149
150 pub fn pull(&self) -> Option<T> {
167 let mut inner = self.inner.lock().unwrap();
168
169 while inner.items.is_empty() && !inner.closed {
171 inner = self.not_empty.wait(inner).unwrap();
172 }
173
174 if inner.items.is_empty() {
176 return None;
177 }
178
179 let priority_item = inner.items.pop().unwrap();
181 inner.current_size -= priority_item.size;
182
183 self.not_full.notify_one();
185
186 Some(priority_item.item)
187 }
188
189 pub fn try_pull(&self) -> Option<T> {
193 let mut inner = self.inner.lock().unwrap();
194
195 if inner.items.is_empty() {
196 return None;
197 }
198
199 let priority_item = inner.items.pop().unwrap();
201 inner.current_size -= priority_item.size;
202
203 self.not_full.notify_one();
205
206 Some(priority_item.item)
207 }
208
209 pub fn close(&self) {
216 let mut inner = self.inner.lock().unwrap();
217 inner.closed = true;
218
219 self.not_full.notify_all();
221 self.not_empty.notify_all();
222 }
223
224 pub fn is_closed(&self) -> bool {
226 self.inner.lock().unwrap().closed
227 }
228
229 pub fn current_size(&self) -> usize {
231 self.inner.lock().unwrap().current_size
232 }
233
234 pub fn len(&self) -> usize {
236 self.inner.lock().unwrap().items.len()
237 }
238
239 pub fn is_empty(&self) -> bool {
241 self.inner.lock().unwrap().items.is_empty()
242 }
243
244 pub fn capacity(&self) -> usize {
246 self.capacity_bytes
247 }
248}
249
250impl<T: Ord> Clone for MemoryBoundedQueue<T> {
252 fn clone(&self) -> Self {
253 Self {
254 inner: Arc::clone(&self.inner),
255 capacity_bytes: self.capacity_bytes,
256 not_full: Arc::clone(&self.not_full),
257 not_empty: Arc::clone(&self.not_empty),
258 }
259 }
260}
261
/// Error returned by `MemoryBoundedQueue::push` once the queue is closed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PushError {
    Closed,
}

impl std::fmt::Display for PushError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Closed => f.write_str("Queue is closed"),
        }
    }
}

impl std::error::Error for PushError {}
276
/// Error returned by `MemoryBoundedQueue::try_push`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryPushError {
    /// The queue has been closed; no further pushes will ever succeed.
    Closed,
    /// The item does not currently fit; a blocking `push` would wait here.
    WouldBlock,
}

impl std::fmt::Display for TryPushError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::Closed => "Queue is closed",
            Self::WouldBlock => "Queue is full - would block",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for TryPushError {}
293
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::thread;
    use std::time::Duration;

    /// A pushed item comes back out unchanged.
    #[test]
    fn test_basic_push_pull() {
        let queue: MemoryBoundedQueue<Vec<u8>> = MemoryBoundedQueue::new(1024);

        let payload = vec![0u8; 100];
        queue.push(payload.clone(), 100).unwrap();

        assert_eq!(queue.pull().unwrap(), payload);
    }

    /// A push that would exceed capacity blocks until a pull frees space.
    #[test]
    fn test_backpressure() {
        let queue: MemoryBoundedQueue<Vec<u8>> = MemoryBoundedQueue::new(1024);

        // Fill the queue exactly to capacity.
        queue.push(vec![0u8; 512], 512).unwrap();
        queue.push(vec![0u8; 512], 512).unwrap();

        let blocked = Arc::new(AtomicBool::new(false));
        let handle = {
            let flag = Arc::clone(&blocked);
            let q = queue.clone();
            thread::spawn(move || {
                flag.store(true, Ordering::SeqCst);
                q.push(vec![0u8; 100], 100).unwrap();
                flag.store(false, Ordering::SeqCst);
            })
        };

        thread::sleep(Duration::from_millis(100));
        assert!(blocked.load(Ordering::SeqCst), "Push should be blocked!");

        // Freeing space lets the producer complete.
        queue.pull().unwrap();
        handle.join().unwrap();
        assert!(
            !blocked.load(Ordering::SeqCst),
            "Push should have completed!"
        );
    }

    /// Closing rejects new pushes but lets queued items drain.
    #[test]
    fn test_close_queue() {
        let queue: MemoryBoundedQueue<Vec<u8>> = MemoryBoundedQueue::new(1024);

        queue.push(vec![0u8; 100], 100).unwrap();
        queue.push(vec![0u8; 100], 100).unwrap();

        queue.close();

        assert!(queue.push(vec![0u8; 100], 100).is_err());

        assert!(queue.pull().is_some());
        assert!(queue.pull().is_some());
        assert!(queue.pull().is_none());
    }

    /// Non-blocking variants report WouldBlock / None instead of waiting.
    #[test]
    fn test_try_operations() {
        let queue: MemoryBoundedQueue<Vec<u8>> = MemoryBoundedQueue::new(100);

        assert!(queue.try_push(vec![0u8; 50], 50).is_ok());
        assert_eq!(
            queue.try_push(vec![0u8; 60], 60),
            Err(TryPushError::WouldBlock)
        );

        assert!(queue.try_pull().is_some());
        assert!(queue.try_pull().is_none());
    }

    /// Three producers push 100 items each; two consumers drain them all.
    #[test]
    fn test_multiple_producers_consumers() {
        let queue: MemoryBoundedQueue<usize> = MemoryBoundedQueue::new(1000);

        let producers: Vec<_> = (0..3)
            .map(|i| {
                let q = queue.clone();
                thread::spawn(move || {
                    for j in 0..100 {
                        q.push(i * 100 + j, 10).unwrap();
                    }
                })
            })
            .collect();

        let consumers: Vec<_> = (0..2)
            .map(|_| {
                let q = queue.clone();
                thread::spawn(move || {
                    let mut pulled = 0;
                    while q.pull().is_some() {
                        pulled += 1;
                        if pulled == 150 {
                            break;
                        }
                    }
                    pulled
                })
            })
            .collect();

        for p in producers {
            p.join().unwrap();
        }

        queue.close();

        let total: i32 = consumers.into_iter().map(|c| c.join().unwrap()).sum();
        assert_eq!(total, 300);
    }

    /// `current_size` tracks pushes and pulls byte-for-byte.
    #[test]
    fn test_size_tracking() {
        let queue: MemoryBoundedQueue<Vec<u8>> = MemoryBoundedQueue::new(1024);

        assert_eq!(queue.current_size(), 0);

        queue.push(vec![0u8; 100], 100).unwrap();
        assert_eq!(queue.current_size(), 100);

        queue.push(vec![0u8; 200], 200).unwrap();
        assert_eq!(queue.current_size(), 300);

        queue.pull().unwrap();
        assert_eq!(queue.current_size(), 100);

        queue.pull().unwrap();
        assert_eq!(queue.current_size(), 0);
    }
}