//! `compio_driver` — `buffer_pool.rs`

1use std::{
2    cell::UnsafeCell,
3    fmt::Debug,
4    io,
5    mem::{self, MaybeUninit},
6    ops::{Deref, DerefMut},
7    ptr::{self, NonNull},
8    rc::{Rc, Weak},
9    slice,
10};
11
12use compio_buf::{IoBuf, IoBufMut, SetLen};
13
14use crate::sys::BufControl;
15
/// Trait used to allocate buffers for compio-driver's buffer pool.
///
/// Default implementation is [`BoxAllocator`], which uses [`Box`] to allocate
/// and deallocate each buffer.
pub trait BufferAllocator {
    /// Allocate a chunk of memory with `len`.
    ///
    /// The returned memory may be uninitialized; the pool tracks the
    /// initialized prefix separately (see [`BufferRef`]).
    fn allocate(len: u32) -> NonNull<MaybeUninit<u8>>;

    /// Deallocate a chunk of memory.
    ///
    /// # Safety
    ///
    /// The pointer passed in must be previously allocated by this allocator,
    /// with the same `len` that was passed to [`Self::allocate`].
    unsafe fn deallocate(ptr: NonNull<MaybeUninit<u8>>, len: u32);
}
31
32/// Default implementation of [`BufferAllocator`].
33pub struct BoxAllocator;
34
35// Default implementation of [`BufferAllocator`]
36impl BufferAllocator for BoxAllocator {
37    fn allocate(len: u32) -> NonNull<MaybeUninit<u8>> {
38        let ptr = Box::into_raw(Box::<[u8]>::new_uninit_slice(len as usize)).cast();
39        // SAFETY: Creating `NonNull` from `Box`
40        unsafe { NonNull::new_unchecked(ptr) }
41    }
42
43    unsafe fn deallocate(ptr: NonNull<MaybeUninit<u8>>, len: u32) {
44        let ptr = ptr::slice_from_raw_parts_mut(ptr.as_ptr(), len as usize);
45        // SAFETY: Caller guarantees the pointer was allocated by us with `len`
46        _ = unsafe { Box::from_raw(ptr) };
47    }
48}
49
/// Type-erased allocate/deallocate function pair captured from a
/// [`BufferAllocator`] implementation.
#[derive(Debug, Clone, Copy)]
pub(crate) struct BufferAlloc {
    // Allocates a buffer of `len` bytes.
    allocate: fn(len: u32) -> NonNull<MaybeUninit<u8>>,
    // Deallocates a buffer previously produced by `allocate` with the same `len`.
    deallocate: unsafe fn(ptr: NonNull<MaybeUninit<u8>>, len: u32),
}
55
56impl BufferAlloc {
57    pub fn new<A: BufferAllocator>() -> Self {
58        Self {
59            allocate: A::allocate,
60            deallocate: A::deallocate,
61        }
62    }
63}
64
/// A buffer pointer without length part.
pub(crate) type BufPtr = NonNull<MaybeUninit<u8>>;
/// A buffer slot. It's always 1-pointer sized thanks to niche optimization.
///
/// `None` means the buffer is currently taken out of the pool (see
/// [`BufferPool::take`]).
pub(crate) type Slot = Option<BufPtr>;

// Compile-time proof of the niche-optimization claim above: `Option<NonNull<_>>`
// is exactly one pointer wide.
const _: () = assert!(size_of::<Slot>() == size_of::<usize>());
71
/// A buffer pool.
///
/// This type by itself does nothing, and should only be used by `*Managed` ops.
#[derive(Clone)]
pub struct BufferPool {
    /// Weak handle to the shared state owned by [`BufferPoolRoot`]; operations
    /// fail with "The driver has been dropped" once the root is gone.
    shared: Weak<Shared>,
}
79
/// Owning handle to the pool's shared state; [`BufferPool`]s hold weak
/// references obtained via [`Self::get_pool`].
#[repr(transparent)]
#[derive(Debug)]
pub(crate) struct BufferPoolRoot {
    // The only strong reference created directly; see `Self::is_unique`.
    shared: Rc<Shared>,
}
85
/// A unique reference to a buffer within the buffer pool.
///
/// Dropping this type will reset the buffer back to the pool instead of
/// releasing buffer's memory.
///
/// Invariant: `len <= cap <= full_cap` (maintained by `set_capacity` and
/// `SetLen::set_len`).
#[derive(Debug)]
pub struct BufferRef {
    /// Allocator to deallocate the buffer in case the driver is dropped.
    alloc: BufferAlloc,
    /// Initialized length of the buffer, set with [`SetLen`]
    len: u32,
    /// User-set capacity, default to `full_cap`
    cap: u32,
    /// Full capacity of the buffer, used to release memory if driver (buffer
    /// pool) is dropped
    full_cap: u32,
    /// Weak handle of the buffer pool
    shared: Weak<Shared>,
    /// Pointer of the buffer
    ptr: BufPtr,
    /// Buffer id (index within the Vec)
    buffer_id: u16,
}
108
/// Shared state behind the pool: owned strongly by [`BufferPoolRoot`] and
/// referenced weakly by [`BufferPool`] and [`BufferRef`].
///
/// All access goes through [`Shared::with`], whose caller must guarantee the
/// closure does not re-enter — that contract replaces `RefCell` borrow checks.
#[repr(transparent)]
struct Shared {
    /// Interior mutability; `Rc`/`Weak` sharing keeps this single-threaded.
    inner: UnsafeCell<Inner>,
}

struct Inner {
    /// Allocator of the buffers
    alloc: BufferAlloc,

    /// Control block corresponds to each driver
    ctrl: BufControl,

    /// Size of each buffer (in bytes; also the initial `cap`/`full_cap` of
    /// every [`BufferRef`])
    size: u32,

    /// Buffer pointers
    bufs: Vec<Slot>,
}
127
128impl BufferPoolRoot {
129    pub(crate) fn new(
130        driver: &mut crate::Driver,
131        alloc: BufferAlloc,
132        num_of_bufs: u16,
133        buffer_size: usize,
134        flags: u16,
135    ) -> io::Result<Self> {
136        let size: u32 = buffer_size.try_into().map_err(|_| {
137            io::Error::new(
138                io::ErrorKind::InvalidInput,
139                "Buffer size too large. Should be able to fit into u32.",
140            )
141        })?;
142        let bufs = (0..num_of_bufs.next_power_of_two())
143            .map(|_| Some((alloc.allocate)(size)))
144            .collect::<Vec<_>>();
145        let ctrl = unsafe { BufControl::new(driver, &bufs, size, flags) }?;
146
147        Ok(Self {
148            shared: Shared {
149                inner: Inner {
150                    alloc,
151                    ctrl,
152                    size,
153                    bufs,
154                }
155                .into(),
156            }
157            .into(),
158        })
159    }
160
161    /// Release the buffer pool and deallocate all buffers.
162    ///
163    /// If the buffer pool root is dropped without calling this function,
164    /// everything will be leaked and there will be no chance to recover them
165    /// back, except those have been taken by [`BufferRef`], which will be
166    /// released when they're dropped.
167    ///
168    /// If the control block failed to release, this function will return an io
169    /// Error without deallocating buffers, and it's possible to retry.
170    ///
171    /// # Safety
172    ///
173    /// [`BufferPoolRoot`] must not be used after `release` is called and
174    /// returned successfully. Only thing that's safe to do afterwards is to
175    /// drop it.
176    pub(crate) unsafe fn release(&mut self, driver: &mut crate::Driver) -> io::Result<()> {
177        unsafe {
178            self.shared.with(|inner| {
179                inner.ctrl.release(driver)?;
180                for buf in mem::take(&mut inner.bufs).into_iter().flatten() {
181                    // Control is successfully released, now deallocate buffers
182                    (inner.alloc.deallocate)(buf, inner.size)
183                }
184                io::Result::Ok(())
185            })
186        }?;
187
188        Ok(())
189    }
190
191    pub(crate) fn get_pool(&self) -> BufferPool {
192        BufferPool {
193            shared: Rc::downgrade(&self.shared),
194        }
195    }
196
197    pub(crate) fn is_unique(&self) -> bool {
198        Rc::strong_count(&self.shared) == 1
199    }
200}
201
202impl Debug for BufferPool {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        if let Some(shared) = self.shared.upgrade() {
205            f.debug_struct("BufferPool")
206                .field("shared", &shared)
207                .finish()
208        } else {
209            f.debug_struct("BufferPool")
210                .field("shared", &"<dropped>")
211                .finish()
212        }
213    }
214}
215
216impl Debug for Shared {
217    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218        struct Buf {
219            ptr: BufPtr,
220        }
221
222        impl Debug for Buf {
223            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224                write!(f, "Buf<{:p}>", self.ptr)
225            }
226        }
227
228        struct BuffersDebug<'a> {
229            buffers: &'a [Slot],
230        }
231
232        impl Debug for BuffersDebug<'_> {
233            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234                f.debug_list()
235                    .entries(self.buffers.iter().map(|buf| buf.map(|ptr| Buf { ptr })))
236                    .finish()
237            }
238        }
239
240        unsafe {
241            self.with(|inner| {
242                let buffers = BuffersDebug {
243                    buffers: &inner.bufs,
244                };
245                f.debug_struct("Shared")
246                    .field("control", &inner.ctrl)
247                    .field("size", &inner.size)
248                    .field("buffers", &buffers)
249                    .finish()
250            })
251        }
252    }
253}
254
255impl BufferPool {
256    /// Pop an available buffer from the pool with given capacity.
257    ///
258    /// This operation is not supported on io-uring driver and will always
259    /// return [`Unsupported`].
260    ///
261    /// [`Unsupported`]: io::ErrorKind::Unsupported
262    pub fn pop(&self) -> io::Result<BufferRef> {
263        let buffer_id = unsafe { self.with(|inner| inner.ctrl.pop()) }??;
264
265        Ok(self.take(buffer_id)?.expect("Buffer should be available"))
266    }
267
268    /// Take the indicated buffer from the pool.
269    ///
270    /// Returns `None` if the buffer is not reset back yet or does not exist.
271    pub fn take(&self, buffer_id: u16) -> io::Result<Option<BufferRef>> {
272        let shared = self.shared()?;
273        let Some(ptr) = shared.take(buffer_id) else {
274            return Ok(None);
275        };
276        let cap = shared.len();
277
278        Ok(Some(BufferRef {
279            alloc: shared.alloc(),
280            len: 0,
281            cap,
282            full_cap: cap,
283            shared: Rc::downgrade(&shared),
284            ptr,
285            buffer_id,
286        }))
287    }
288
289    /// Reset the `buffer_id` so that it's available for kernel to use, return
290    /// whether a buffer has been reset.
291    ///
292    /// This is the same as `take(buffer_id)` and immediately drop it.
293    pub fn reset(&self, buffer_id: u16) -> io::Result<bool> {
294        let shared = self.shared()?;
295        let Some(buf) = shared.take(buffer_id) else {
296            return Ok(false);
297        };
298        shared.reset(buffer_id, buf);
299        Ok(true)
300    }
301
302    fn shared(&self) -> io::Result<Rc<Shared>> {
303        self.shared
304            .upgrade()
305            .ok_or_else(|| io::Error::other("The driver has been dropped"))
306    }
307
308    /// # Safety
309    ///
310    /// `f` must not access `self` reentrantly
311    unsafe fn with<F, R>(&self, f: F) -> io::Result<R>
312    where
313        F: FnOnce(&mut Inner) -> R,
314    {
315        Ok(unsafe { self.shared()?.with(f) })
316    }
317
318    /// Get the group id of this buffer pool.
319    #[cfg(io_uring)]
320    pub(crate) fn buffer_group(&self) -> io::Result<u16> {
321        unsafe { self.with(|i| i.ctrl.buffer_group()) }
322    }
323
324    /// Test if the buffer pool is an io_uring one.
325    #[cfg(fusion)]
326    pub fn is_io_uring(&self) -> io::Result<bool> {
327        unsafe { self.with(|inner| inner.ctrl.is_io_uring()) }
328    }
329}
330
331impl Shared {
332    /// # Safety
333    ///
334    /// `f` must not access [`Self::inner`] reentrantly
335    #[inline(always)]
336    unsafe fn with<F, R>(&self, f: F) -> R
337    where
338        F: FnOnce(&mut Inner) -> R,
339    {
340        f(unsafe { &mut *self.inner.get() })
341    }
342
343    fn alloc(&self) -> BufferAlloc {
344        unsafe { self.with(|inner| inner.alloc) }
345    }
346
347    fn take(&self, buffer_id: u16) -> Option<BufPtr> {
348        unsafe { self.with(|inner| inner.bufs.get_mut(buffer_id as usize)?.take()) }
349    }
350
351    fn reset(&self, buffer_id: u16, ptr: BufPtr) {
352        unsafe {
353            self.with(|inner| {
354                inner.bufs[buffer_id as usize] = Some(ptr);
355                inner.ctrl.reset(buffer_id, ptr, inner.size);
356            })
357        }
358    }
359
360    fn len(&self) -> u32 {
361        unsafe { self.with(|inner| inner.size) }
362    }
363}
364
365impl BufferRef {
366    /// Set the capacity of this buffer.
367    ///
368    /// This does nothing if `cap` is greater than underlying buffer's
369    /// length.
370    pub fn with_capacity(mut self, cap: usize) -> Self {
371        self.set_capacity(cap);
372        self
373    }
374
375    /// Set the capacity of this buffer.
376    ///
377    /// This does nothing if `cap` is greater than underlying buffer's
378    /// length or equals 0.
379    pub fn set_capacity(&mut self, cap: usize) {
380        if cap == 0 {
381            return;
382        }
383        self.cap = (cap as u32).min(self.full_cap);
384        self.len = self.len.min(self.cap);
385    }
386}
387
388impl Deref for BufferRef {
389    type Target = [u8];
390
391    fn deref(&self) -> &Self::Target {
392        // SAFETY: `SetLen` guarantees the range is initialized
393        unsafe { slice::from_raw_parts(self.ptr.as_ptr().cast(), self.len as usize) }
394    }
395}
396
397impl DerefMut for BufferRef {
398    fn deref_mut(&mut self) -> &mut Self::Target {
399        // SAFETY: `SetLen` guarantees the range is initialized
400        unsafe { slice::from_raw_parts_mut(self.ptr.as_ptr() as _, self.len as usize) }
401    }
402}
403
404impl IoBuf for BufferRef {
405    fn as_init(&self) -> &[u8] {
406        self
407    }
408}
409
410impl SetLen for BufferRef {
411    unsafe fn set_len(&mut self, len: usize) {
412        debug_assert!(len <= u32::MAX as usize);
413        self.len = (len as u32).min(self.cap);
414    }
415}
416
417impl IoBufMut for BufferRef {
418    fn as_uninit(&mut self) -> &mut [MaybeUninit<u8>] {
419        // SAFETY: Cap is initialized as the buffer length, and setting it is
420        // is capped at full_cap, so it will never exceed buffer length. Pointer is
421        // not deallocated.
422        unsafe { slice::from_raw_parts_mut(self.ptr.as_ptr(), self.cap as usize) }
423    }
424}
425
426impl Drop for BufferRef {
427    fn drop(&mut self) {
428        if let Some(shared) = self.shared.upgrade() {
429            // If the buffer pool is alive, set the pointer back
430            shared.reset(self.buffer_id, self.ptr);
431        } else {
432            unsafe { (self.alloc.deallocate)(self.ptr, self.full_cap) }
433        }
434    }
435}