Skip to main content

dynamo_runtime/utils/
pool.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::VecDeque;
5use std::ops::{Deref, DerefMut};
6use std::sync::Arc;
7use std::sync::{Condvar, Mutex};
8use tokio::sync::Notify;
9
/// Contract for values that can be handed back to a pool.
///
/// Implementors must be `Send + Sync + 'static` so that pooled values can be
/// moved between threads and stored behind shared handles.
pub trait Returnable: Send + Sync + 'static {
    /// Hook invoked on the value when it is being returned to its pool.
    ///
    /// The default implementation does nothing; override it to reset or
    /// scrub state before the value is reused.
    fn on_return(&mut self) {}
}
15
16pub trait ReturnHandle<T: Returnable>: Send + Sync + 'static {
17    fn return_to_pool(&self, value: PoolValue<T>);
18}
19
20/// Enum to hold either a `Box<T>` or `T` directly
21pub enum PoolValue<T: Returnable> {
22    Boxed(Box<T>),
23    Direct(T),
24}
25
26impl<T: Returnable> PoolValue<T> {
27    /// Create a new PoolValue from a boxed item
28    pub fn from_boxed(value: Box<T>) -> Self {
29        PoolValue::Boxed(value)
30    }
31
32    /// Create a new PoolValue from a direct item
33    pub fn from_direct(value: T) -> Self {
34        PoolValue::Direct(value)
35    }
36
37    /// Get a reference to the underlying item
38    pub fn get(&self) -> &T {
39        match self {
40            PoolValue::Boxed(boxed) => boxed.as_ref(),
41            PoolValue::Direct(direct) => direct,
42        }
43    }
44
45    /// Get a mutable reference to the underlying item
46    pub fn get_mut(&mut self) -> &mut T {
47        match self {
48            PoolValue::Boxed(boxed) => boxed.as_mut(),
49            PoolValue::Direct(direct) => direct,
50        }
51    }
52
53    /// Call on_return on the underlying item
54    pub fn on_return(&mut self) {
55        self.get_mut().on_return();
56    }
57}
58
59impl<T: Returnable> Deref for PoolValue<T> {
60    type Target = T;
61
62    fn deref(&self) -> &Self::Target {
63        self.get()
64    }
65}
66
67impl<T: Returnable> DerefMut for PoolValue<T> {
68    fn deref_mut(&mut self) -> &mut Self::Target {
69        self.get_mut()
70    }
71}
72
// Holds a marker type whose only field is private to this module, so a value
// of it can only be produced through `PoolItemToken::new`.
mod private {
    /// Zero-sized marker stored inside `PoolItem`.
    ///
    /// The inner `()` field is private, so struct-literal construction is
    /// impossible outside this module.
    #[derive(Clone, Copy)]
    pub struct PoolItemToken(());

    impl PoolItemToken {
        /// Produce a token; visible only to the parent module.
        pub(super) fn new() -> Self {
            Self(())
        }
    }
}
85
86/// Core trait defining pool operations
87pub trait PoolExt<T: Returnable>: Send + Sync + 'static {
88    /// Create a new PoolItem (only available to implementors)
89    fn create_pool_item(
90        &self,
91        value: PoolValue<T>,
92        handle: Arc<dyn ReturnHandle<T>>,
93    ) -> PoolItem<T> {
94        PoolItem::new(value, handle)
95    }
96}
97
98/// An item borrowed from a pool
99pub struct PoolItem<T: Returnable> {
100    value: Option<PoolValue<T>>,
101    handle: Arc<dyn ReturnHandle<T>>,
102    _token: private::PoolItemToken,
103}
104
105impl<T: Returnable> PoolItem<T> {
106    /// Create a new PoolItem (only available within this module)
107    fn new(value: PoolValue<T>, handle: Arc<dyn ReturnHandle<T>>) -> Self {
108        Self {
109            value: Some(value),
110            handle,
111            _token: private::PoolItemToken::new(),
112        }
113    }
114
115    /// Convert this unique PoolItem into a shared reference
116    pub fn into_shared(self) -> SharedPoolItem<T> {
117        SharedPoolItem {
118            inner: Arc::new(self),
119        }
120    }
121
122    /// Check if this item still contains a value
123    pub fn has_value(&self) -> bool {
124        self.value.is_some()
125    }
126}
127
128impl<T: Returnable> Deref for PoolItem<T> {
129    type Target = T;
130
131    fn deref(&self) -> &Self::Target {
132        self.value.as_ref().unwrap().get()
133    }
134}
135
136impl<T: Returnable> DerefMut for PoolItem<T> {
137    fn deref_mut(&mut self) -> &mut Self::Target {
138        self.value.as_mut().unwrap().get_mut()
139    }
140}
141
142impl<T: Returnable> Drop for PoolItem<T> {
143    fn drop(&mut self) {
144        if let Some(mut value) = self.value.take() {
145            value.on_return();
146            // Use blocking version for drop
147            self.handle.return_to_pool(value);
148        }
149    }
150}
151
152/// A shared reference to a pooled item
153pub struct SharedPoolItem<T: Returnable> {
154    inner: Arc<PoolItem<T>>,
155}
156
157impl<T: Returnable> Clone for SharedPoolItem<T> {
158    fn clone(&self) -> Self {
159        Self {
160            inner: self.inner.clone(),
161        }
162    }
163}
164
165impl<T: Returnable> SharedPoolItem<T> {
166    /// Get a reference to the underlying item
167    pub fn get(&self) -> &T {
168        self.inner.value.as_ref().unwrap().get()
169    }
170
171    pub fn strong_count(&self) -> usize {
172        Arc::strong_count(&self.inner)
173    }
174}
175
176impl<T: Returnable> Deref for SharedPoolItem<T> {
177    type Target = T;
178
179    fn deref(&self) -> &Self::Target {
180        self.inner.value.as_ref().unwrap().get()
181    }
182}
183
184/// Standard pool implementation
185pub struct Pool<T: Returnable> {
186    state: Arc<PoolState<T>>,
187    capacity: usize,
188}
189
190struct PoolState<T: Returnable> {
191    pool: Arc<Mutex<VecDeque<PoolValue<T>>>>,
192    available: Arc<Notify>,
193}
194
195impl<T: Returnable> ReturnHandle<T> for PoolState<T> {
196    fn return_to_pool(&self, value: PoolValue<T>) {
197        let mut pool = self.pool.lock().unwrap();
198        pool.push_back(value);
199        self.available.notify_one();
200    }
201}
202
203impl<T: Returnable> Pool<T> {
204    /// Create a new pool with the given initial elements
205    pub fn new(initial_elements: Vec<PoolValue<T>>) -> Self {
206        let capacity = initial_elements.len();
207        let pool = initial_elements
208            .into_iter()
209            .collect::<VecDeque<PoolValue<T>>>();
210
211        let state = Arc::new(PoolState {
212            pool: Arc::new(Mutex::new(pool)),
213            available: Arc::new(Notify::new()),
214        });
215
216        Self { state, capacity }
217    }
218
219    /// Create a new pool with initial boxed elements
220    pub fn new_boxed(initial_elements: Vec<Box<T>>) -> Self {
221        let initial_values = initial_elements
222            .into_iter()
223            .map(PoolValue::from_boxed)
224            .collect();
225        Self::new(initial_values)
226    }
227
228    /// Create a new pool with initial direct elements
229    pub fn new_direct(initial_elements: Vec<T>) -> Self {
230        let initial_values = initial_elements
231            .into_iter()
232            .map(PoolValue::from_direct)
233            .collect();
234        Self::new(initial_values)
235    }
236
237    async fn try_acquire(&self) -> Option<PoolItem<T>> {
238        let mut pool = self.state.pool.lock().unwrap();
239        pool.pop_front()
240            .map(|value| PoolItem::new(value, self.state.clone()))
241    }
242
243    async fn acquire(&self) -> PoolItem<T> {
244        loop {
245            if let Some(guard) = self.try_acquire().await {
246                return guard;
247            }
248            self.state.available.notified().await;
249        }
250    }
251
252    fn notify_return(&self) {
253        self.state.available.notify_one();
254    }
255
256    fn capacity(&self) -> usize {
257        self.capacity
258    }
259}
260
261impl<T: Returnable> PoolExt<T> for Pool<T> {}
262
263impl<T: Returnable> Clone for Pool<T> {
264    fn clone(&self) -> Self {
265        Self {
266            state: self.state.clone(),
267            capacity: self.capacity,
268        }
269    }
270}
271
272pub struct SyncPool<T: Returnable> {
273    state: Arc<SyncPoolState<T>>,
274    capacity: usize,
275}
276
277struct SyncPoolState<T: Returnable> {
278    pool: Mutex<VecDeque<PoolValue<T>>>,
279    available: Condvar,
280}
281
282impl<T: Returnable> SyncPool<T> {
283    pub fn new(initial_elements: Vec<PoolValue<T>>) -> Self {
284        let capacity = initial_elements.len();
285        let pool = initial_elements
286            .into_iter()
287            .collect::<VecDeque<PoolValue<T>>>();
288
289        let state = Arc::new(SyncPoolState {
290            pool: Mutex::new(pool),
291            available: Condvar::new(),
292        });
293
294        Self { state, capacity }
295    }
296
297    pub fn new_direct(initial_elements: Vec<T>) -> Self {
298        let initial_values = initial_elements
299            .into_iter()
300            .map(PoolValue::from_direct)
301            .collect();
302        Self::new(initial_values)
303    }
304
305    pub fn try_acquire(&self) -> Option<SyncPoolItem<T>> {
306        let mut pool = self.state.pool.lock().unwrap();
307        pool.pop_front()
308            .map(|value| SyncPoolItem::new(value, self.state.clone()))
309    }
310
311    pub fn acquire_blocking(&self) -> SyncPoolItem<T> {
312        let mut pool = self.state.pool.lock().unwrap();
313
314        while pool.is_empty() {
315            tracing::debug!("SyncPool: waiting for available resource (pool empty)");
316            pool = self.state.available.wait(pool).unwrap();
317            tracing::debug!(
318                "SyncPool: woke up, checking pool again (size: {})",
319                pool.len()
320            );
321        }
322
323        let value = pool.pop_front().unwrap();
324        tracing::debug!("SyncPool: acquired resource, pool size now: {}", pool.len());
325        SyncPoolItem::new(value, self.state.clone())
326    }
327
328    pub fn capacity(&self) -> usize {
329        self.capacity
330    }
331}
332
333impl<T: Returnable> Clone for SyncPool<T> {
334    fn clone(&self) -> Self {
335        Self {
336            state: self.state.clone(),
337            capacity: self.capacity,
338        }
339    }
340}
341
342pub struct SyncPoolItem<T: Returnable> {
343    value: Option<PoolValue<T>>,
344    state: Arc<SyncPoolState<T>>,
345}
346
347impl<T: Returnable> SyncPoolItem<T> {
348    fn new(value: PoolValue<T>, state: Arc<SyncPoolState<T>>) -> Self {
349        Self {
350            value: Some(value),
351            state,
352        }
353    }
354}
355
356impl<T: Returnable> Deref for SyncPoolItem<T> {
357    type Target = T;
358
359    fn deref(&self) -> &Self::Target {
360        self.value.as_ref().unwrap().get()
361    }
362}
363
364impl<T: Returnable> DerefMut for SyncPoolItem<T> {
365    fn deref_mut(&mut self) -> &mut Self::Target {
366        self.value.as_mut().unwrap().get_mut()
367    }
368}
369
370impl<T: Returnable> Drop for SyncPoolItem<T> {
371    fn drop(&mut self) {
372        if let Some(mut value) = self.value.take() {
373            value.on_return();
374
375            let mut pool = self.state.pool.lock().unwrap();
376            pool.push_back(value);
377            tracing::debug!(
378                "SyncPool: returned resource, pool size now: {}, notifying waiters",
379                pool.len()
380            );
381
382            self.state.available.notify_one();
383        }
384    }
385}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use tokio::time::{Duration, timeout};

    // Implement Returnable for u32 just for testing: returning resets it to 0.
    impl Returnable for u32 {
        fn on_return(&mut self) {
            *self = 0;
            tracing::debug!("Resetting u32 to 0");
        }
    }

    /// Values come out in FIFO order, are reset on return, and `acquire`
    /// waits while the pool is empty.
    #[tokio::test]
    async fn test_acquire_release() {
        let initial_elements = vec![
            PoolValue::Direct(1),
            PoolValue::Direct(2),
            PoolValue::Direct(3),
            PoolValue::Direct(4),
            PoolValue::Direct(5),
        ];
        let pool = Pool::new(initial_elements);

        // Acquire an element from the pool
        if let Some(mut item) = pool.try_acquire().await {
            assert_eq!(*item, 1); // It should be the first element we put in

            // Modify the value
            *item += 10;
            assert_eq!(*item, 11);

            // The item will be dropped at the end of this scope,
            // and the value will be returned to the pool
        }

        // Acquire all remaining elements and the one we returned
        let mut values = Vec::new();
        let mut items = Vec::new();
        while let Some(item) = pool.try_acquire().await {
            values.push(*item);
            items.push(item);
        }

        // The last element in `values` is the one we returned; on_return reset it to 0
        assert_eq!(values, vec![2, 3, 4, 5, 0]);

        // Test the awaitable acquire: the pool is empty, so this must time out
        let pool_clone = pool.clone();
        let task = tokio::spawn(async move {
            let first_acquired = pool_clone.acquire().await;
            assert_eq!(*first_acquired, 0);
        });

        timeout(Duration::from_secs(1), task)
            .await
            .expect_err("Expected timeout");

        // Drop the guards to return the PoolItems to the pool.
        items.clear();

        let pool_clone = pool.clone();
        let task = tokio::spawn(async move {
            let first_acquired = pool_clone.acquire().await;
            assert_eq!(*first_acquired, 0);
        });

        // Now the task should be able to finish.
        timeout(Duration::from_secs(1), task)
            .await
            .expect("Task did not complete in time")
            .unwrap();
    }

    /// A shared item is returned only after its last clone is dropped.
    #[tokio::test]
    async fn test_shared_items() {
        let pool = Pool::new(vec![PoolValue::Direct(1)]);

        // Acquire and convert to shared
        let mut item = pool.acquire().await;
        *item += 10; // Modify before sharing
        let shared = item.into_shared();
        assert_eq!(*shared, 11);

        // Create a clone of the shared item
        let shared_clone = shared.clone();
        assert_eq!(*shared_clone, 11);

        // Drop the original shared item
        drop(shared);

        // Clone should still be valid
        assert_eq!(*shared_clone, 11);

        // Drop the clone
        drop(shared_clone);

        // Now we should be able to acquire the item again
        let item = pool.acquire().await;
        assert_eq!(*item, 0); // Value was reset by on_return
    }

    /// Boxed values behave the same as direct ones.
    #[tokio::test]
    async fn test_boxed_values() {
        let pool = Pool::new(vec![PoolValue::Boxed(Box::new(1))]);

        // Acquire an element from the pool
        let mut item = pool.acquire().await;
        assert_eq!(*item, 1);

        // Modify and return to pool
        *item += 10;
        drop(item);

        // Should get the reset value when acquired again
        let item = pool.acquire().await;
        assert_eq!(*item, 0);
    }

    #[tokio::test]
    async fn test_pool_item_creation() {
        let pool = Pool::new(vec![PoolValue::Direct(1)]);

        // This works - acquiring from the pool
        let item = pool.acquire().await;
        assert_eq!(*item, 1);

        // Outside this module the following would not compile, because the
        // token type cannot be constructed there:
        // let invalid_item = PoolItem {
        //     value: Some(PoolValue::Direct(2)),
        //     handle: /* some Arc<dyn ReturnHandle<u32>> */,
        //     _token: /* can't create this */
        // };
    }

    #[test]
    fn test_sync_pool_basic_acquire_release() {
        let initial_elements = vec![1u32, 2, 3];
        let pool = SyncPool::new_direct(initial_elements);

        // Try acquire (non-blocking)
        let item1 = pool.try_acquire().unwrap();
        assert_eq!(*item1, 1);

        let item2 = pool.try_acquire().unwrap();
        assert_eq!(*item2, 2);

        // Pool should have one item left
        let item3 = pool.try_acquire().unwrap();
        assert_eq!(*item3, 3);

        // Pool should be empty now
        assert!(pool.try_acquire().is_none());

        // Drop items to return to pool
        drop(item1); // Returns 0 (after on_return)
        drop(item2); // Returns 0 (after on_return)
        drop(item3); // Returns 0 (after on_return)

        // Should be able to acquire again
        let item = pool.try_acquire().unwrap();
        assert_eq!(*item, 0); // Value was reset by on_return
    }

    #[test]
    fn test_sync_pool_blocking_acquire() {
        let pool = SyncPool::new_direct(vec![42u32]);

        // Acquire the only item
        let item = pool.acquire_blocking();
        assert_eq!(*item, 42);

        let pool_clone = pool.clone();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();

        // Spawn a thread that will wait for the item
        let handle = thread::spawn(move || {
            counter_clone.store(1, Ordering::SeqCst); // Mark that we're waiting
            let waiting_item = pool_clone.acquire_blocking(); // This will block
            counter_clone.store(2, Ordering::SeqCst); // Mark that we got it
            assert_eq!(*waiting_item, 0); // Should be reset value
        });

        // Wait (bounded) until the thread reports it has started, instead of
        // assuming a fixed sleep is long enough on a loaded machine.
        let deadline = std::time::Instant::now() + Duration::from_secs(5);
        while counter.load(Ordering::SeqCst) != 1 {
            assert!(
                std::time::Instant::now() < deadline,
                "waiter thread never started"
            );
            thread::yield_now();
        }

        // Best effort: give the thread a moment to actually block on the condvar.
        thread::sleep(Duration::from_millis(10));
        // It cannot have progressed past the acquire while we hold the item.
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // Drop the item to trigger condvar notification
        drop(item);

        // Wait for the other thread to complete
        handle.join().unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 2); // Should have completed
    }

    #[test]
    fn test_sync_pool_multiple_waiters() {
        let pool = SyncPool::new_direct(vec![1u32]);

        // Acquire the only item
        let item = pool.acquire_blocking();

        let pool_clone1 = pool.clone();
        let pool_clone2 = pool.clone();
        let completed = Arc::new(AtomicUsize::new(0));
        let completed1 = completed.clone();
        let completed2 = completed.clone();

        // Spawn two threads that will wait
        let handle1 = thread::spawn(move || {
            let _item = pool_clone1.acquire_blocking(); // Will block
            completed1.fetch_add(1, Ordering::SeqCst); // Mark completion
            // Item drops here, potentially waking thread 2
        });

        let handle2 = thread::spawn(move || {
            let _item = pool_clone2.acquire_blocking(); // Will block
            completed2.fetch_add(1, Ordering::SeqCst); // Mark completion
            // Item drops here
        });

        // Give threads time to start waiting
        thread::sleep(Duration::from_millis(50));
        // Neither can have completed while the only item is still held here.
        assert_eq!(completed.load(Ordering::SeqCst), 0);

        // Drop the item - should wake exactly one thread
        drop(item);

        // Wait for both threads to complete
        handle1.join().unwrap();
        handle2.join().unwrap();

        // Both threads should have completed eventually
        assert_eq!(completed.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_sync_vs_async_pool_compatibility() {
        // Test that both pool types work with the same Returnable type
        let async_pool = Pool::new_direct(vec![1u32, 2u32]);
        let sync_pool = SyncPool::new_direct(vec![3u32, 4u32]);

        // Both should work
        let async_rt = tokio::runtime::Runtime::new().unwrap();
        let async_item = async_rt.block_on(async { async_pool.acquire().await });
        assert_eq!(*async_item, 1);

        let sync_item = sync_pool.acquire_blocking();
        assert_eq!(*sync_item, 3);

        // Both use the same Returnable trait
        drop(async_item); // Should reset to 0
        drop(sync_item); // Should reset to 0
    }

    #[test]
    fn test_sync_pool_condvar_performance() {
        let pool = SyncPool::new_direct((0..10).collect::<Vec<u32>>());
        let start = std::time::Instant::now();

        // Rapid acquire/release cycles
        for _ in 0..1000 {
            let item = pool.acquire_blocking();
            // Simulate minimal work
            let _ = *item + 1;
            drop(item); // Return to pool
        }

        let duration = start.elapsed();
        println!("1000 sync pool operations took {:?}", duration);

        // Should be fast (< 10ms on most systems)
        // Update(grahamk): Takes 144ms on my box which is much faster than CI, so something
        // is odd about claim above.
        assert!(duration < Duration::from_millis(200));
    }
}