1use std::{
2 cell::UnsafeCell,
3 fmt::Debug,
4 io,
5 mem::{self, MaybeUninit},
6 ops::{Deref, DerefMut},
7 ptr::{self, NonNull},
8 rc::{Rc, Weak},
9 slice,
10};
11
12use compio_buf::{IoBuf, IoBufMut, SetLen};
13
14use crate::sys::BufControl;
15
/// Source of the raw memory that backs pooled buffers.
///
/// Implementations hand out uninitialized allocations which the pool later
/// returns through [`BufferAllocator::deallocate`].
pub trait BufferAllocator {
    /// Allocates `len` uninitialized bytes and returns a pointer to the
    /// first byte.
    fn allocate(len: u32) -> NonNull<MaybeUninit<u8>>;

    /// Frees an allocation previously produced by [`BufferAllocator::allocate`].
    ///
    /// # Safety
    ///
    /// `ptr` must have been returned by `allocate` of the same implementation
    /// with the same `len`, and must not be used again afterwards.
    unsafe fn deallocate(ptr: NonNull<MaybeUninit<u8>>, len: u32);
}

/// Default allocator: each buffer is an ordinary boxed slice on the global heap.
pub struct BoxAllocator;

impl BufferAllocator for BoxAllocator {
    fn allocate(len: u32) -> NonNull<MaybeUninit<u8>> {
        // Leak the box; ownership is reclaimed by rebuilding it in `deallocate`.
        let storage: &mut [MaybeUninit<u8>] = Box::leak(Box::new_uninit_slice(len as usize));
        NonNull::from(storage).cast()
    }

    unsafe fn deallocate(ptr: NonNull<MaybeUninit<u8>>, len: u32) {
        // SAFETY: the caller guarantees `ptr`/`len` describe an allocation
        // produced by `allocate`, so reconstituting the box is sound.
        let fat = ptr::slice_from_raw_parts_mut(ptr.as_ptr(), len as usize);
        drop(unsafe { Box::from_raw(fat) });
    }
}
49
/// Type-erased pair of allocation entry points, captured from a
/// [`BufferAllocator`] implementation by [`BufferAlloc::new`].
#[derive(Debug, Clone, Copy)]
pub(crate) struct BufferAlloc {
    // Allocates `len` uninitialized bytes (same contract as
    // `BufferAllocator::allocate`).
    allocate: fn(len: u32) -> NonNull<MaybeUninit<u8>>,
    // Frees an allocation produced by `allocate`; unsafe with the same
    // contract as `BufferAllocator::deallocate`.
    deallocate: unsafe fn(ptr: NonNull<MaybeUninit<u8>>, len: u32),
}
55
56impl BufferAlloc {
57 pub fn new<A: BufferAllocator>() -> Self {
58 Self {
59 allocate: A::allocate,
60 deallocate: A::deallocate,
61 }
62 }
63}
64
/// Raw pointer to the start of one pooled buffer's storage.
pub(crate) type BufPtr = NonNull<MaybeUninit<u8>>;
/// A pool slot: `Some` while the buffer is checked in, `None` while lent out.
pub(crate) type Slot = Option<BufPtr>;

// `NonNull`'s niche makes `Option<BufPtr>` cost no extra space; assert it
// stays exactly pointer-sized.
const _: () = assert!(size_of::<Slot>() == size_of::<usize>());
71
/// A user-facing handle to a buffer pool.
///
/// Holds only a [`Weak`] reference to the shared state: once the owning
/// [`BufferPoolRoot`] is released, operations on this handle fail instead of
/// keeping the storage alive.
#[derive(Clone)]
pub struct BufferPool {
    shared: Weak<Shared>,
}
79
/// The owning (strong) handle to a pool's shared state; created and released
/// through the driver.
///
/// `repr(transparent)` over the [`Rc`] so the wrapper has the exact layout of
/// the strong reference itself.
#[repr(transparent)]
#[derive(Debug)]
pub(crate) struct BufferPoolRoot {
    shared: Rc<Shared>,
}
85
/// An exclusive reference to one buffer checked out of a [`BufferPool`].
///
/// Dropping it returns the buffer to the pool, or frees the allocation
/// outright if the pool has already been released.
#[derive(Debug)]
pub struct BufferRef {
    // Deallocator used if the pool is gone when this ref is dropped.
    alloc: BufferAlloc,
    // Number of bytes exposed via `Deref` (tracked through `set_len`).
    len: u32,
    // Current usable capacity; may be lowered via `set_capacity`.
    cap: u32,
    // The buffer's true allocated size; `cap` never exceeds this, and drop
    // deallocates with this value.
    full_cap: u32,
    // Weak link back to the pool so drop can hand the buffer back.
    shared: Weak<Shared>,
    // Start of the buffer's storage.
    ptr: BufPtr,
    // Identifier the pool/backend uses to address this buffer.
    buffer_id: u16,
}
108
/// Interior-mutable shared pool state. `Rc`-based (single-threaded), so all
/// access goes through the unguarded `UnsafeCell` via `Shared::with` —
/// callers must not nest accesses.
#[repr(transparent)]
struct Shared {
    inner: UnsafeCell<Inner>,
}
113
struct Inner {
    // How the buffers were allocated — and therefore how they must be freed.
    alloc: BufferAlloc,
    // Backend-specific buffer control handle (see `crate::sys::BufControl`).
    ctrl: BufControl,
    // Allocated size in bytes of every buffer in the pool.
    size: u32,
    // One slot per buffer, indexed by buffer id; `None` while lent out.
    bufs: Vec<Slot>,
}
127
128impl BufferPoolRoot {
129 pub(crate) fn new(
130 driver: &mut crate::Driver,
131 alloc: BufferAlloc,
132 num_of_bufs: u16,
133 buffer_size: usize,
134 flags: u16,
135 ) -> io::Result<Self> {
136 let size: u32 = buffer_size.try_into().map_err(|_| {
137 io::Error::new(
138 io::ErrorKind::InvalidInput,
139 "Buffer size too large. Should be able to fit into u32.",
140 )
141 })?;
142 let bufs = (0..num_of_bufs.next_power_of_two())
143 .map(|_| Some((alloc.allocate)(size)))
144 .collect::<Vec<_>>();
145 let ctrl = unsafe { BufControl::new(driver, &bufs, size, flags) }?;
146
147 Ok(Self {
148 shared: Shared {
149 inner: Inner {
150 alloc,
151 ctrl,
152 size,
153 bufs,
154 }
155 .into(),
156 }
157 .into(),
158 })
159 }
160
161 pub(crate) unsafe fn release(&mut self, driver: &mut crate::Driver) -> io::Result<()> {
177 unsafe {
178 self.shared.with(|inner| {
179 inner.ctrl.release(driver)?;
180 for buf in mem::take(&mut inner.bufs).into_iter().flatten() {
181 (inner.alloc.deallocate)(buf, inner.size)
183 }
184 io::Result::Ok(())
185 })
186 }?;
187
188 Ok(())
189 }
190
191 pub(crate) fn get_pool(&self) -> BufferPool {
192 BufferPool {
193 shared: Rc::downgrade(&self.shared),
194 }
195 }
196
197 pub(crate) fn is_unique(&self) -> bool {
198 Rc::strong_count(&self.shared) == 1
199 }
200}
201
202impl Debug for BufferPool {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 if let Some(shared) = self.shared.upgrade() {
205 f.debug_struct("BufferPool")
206 .field("shared", &shared)
207 .finish()
208 } else {
209 f.debug_struct("BufferPool")
210 .field("shared", &"<dropped>")
211 .finish()
212 }
213 }
214}
215
216impl Debug for Shared {
217 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218 struct Buf {
219 ptr: BufPtr,
220 }
221
222 impl Debug for Buf {
223 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224 write!(f, "Buf<{:p}>", self.ptr)
225 }
226 }
227
228 struct BuffersDebug<'a> {
229 buffers: &'a [Slot],
230 }
231
232 impl Debug for BuffersDebug<'_> {
233 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234 f.debug_list()
235 .entries(self.buffers.iter().map(|buf| buf.map(|ptr| Buf { ptr })))
236 .finish()
237 }
238 }
239
240 unsafe {
241 self.with(|inner| {
242 let buffers = BuffersDebug {
243 buffers: &inner.bufs,
244 };
245 f.debug_struct("Shared")
246 .field("control", &inner.ctrl)
247 .field("size", &inner.size)
248 .field("buffers", &buffers)
249 .finish()
250 })
251 }
252 }
253}
254
impl BufferPool {
    /// Asks the backend for a free buffer id and checks that buffer out.
    ///
    /// # Errors
    ///
    /// Fails when the driver has been dropped, or when the backend pop
    /// itself fails.
    pub fn pop(&self) -> io::Result<BufferRef> {
        // First `?`: the pool may be gone; second `?`: the backend pop result.
        let buffer_id = unsafe { self.with(|inner| inner.ctrl.pop()) }??;

        // The backend just handed out this id, so its slot must be occupied.
        Ok(self.take(buffer_id)?.expect("Buffer should be available"))
    }

    /// Checks the buffer with `buffer_id` out of the pool.
    ///
    /// Returns `Ok(None)` when the id is out of range or the buffer is
    /// already lent out.
    ///
    /// # Errors
    ///
    /// Fails when the driver has been dropped.
    pub fn take(&self, buffer_id: u16) -> io::Result<Option<BufferRef>> {
        let shared = self.shared()?;
        let Some(ptr) = shared.take(buffer_id) else {
            return Ok(None);
        };
        let cap = shared.len();

        Ok(Some(BufferRef {
            alloc: shared.alloc(),
            // Fresh checkout: nothing initialized yet, full capacity usable.
            len: 0,
            cap,
            full_cap: cap,
            shared: Rc::downgrade(&shared),
            ptr,
            buffer_id,
        }))
    }

    /// Re-advertises the buffer with `buffer_id` to the backend without
    /// creating a `BufferRef`; returns `Ok(false)` when the buffer is not
    /// currently checked in.
    ///
    /// # Errors
    ///
    /// Fails when the driver has been dropped.
    pub fn reset(&self, buffer_id: u16) -> io::Result<bool> {
        let shared = self.shared()?;
        // Take the pointer out, then put it straight back: `Shared::reset`
        // also refreshes the backend-side registration.
        let Some(buf) = shared.take(buffer_id) else {
            return Ok(false);
        };
        shared.reset(buffer_id, buf);
        Ok(true)
    }

    /// Upgrades the weak handle, failing once the owning root is gone.
    fn shared(&self) -> io::Result<Rc<Shared>> {
        self.shared
            .upgrade()
            .ok_or_else(|| io::Error::other("The driver has been dropped"))
    }

    /// Runs `f` with mutable access to the shared state.
    ///
    /// # Safety
    ///
    /// `f` receives a `&mut Inner` derived from the `UnsafeCell`; the caller
    /// must ensure `f` does not re-enter `with` (directly or via `Shared`
    /// methods), or two `&mut Inner` would alias.
    unsafe fn with<F, R>(&self, f: F) -> io::Result<R>
    where
        F: FnOnce(&mut Inner) -> R,
    {
        Ok(unsafe { self.shared()?.with(f) })
    }

    /// The backend buffer-group id for this pool (io_uring builds only).
    #[cfg(io_uring)]
    pub(crate) fn buffer_group(&self) -> io::Result<u16> {
        unsafe { self.with(|i| i.ctrl.buffer_group()) }
    }

    /// Whether the backend currently in use is io_uring (fusion builds only).
    #[cfg(fusion)]
    pub fn is_io_uring(&self) -> io::Result<bool> {
        unsafe { self.with(|inner| inner.ctrl.is_io_uring()) }
    }
}
330
impl Shared {
    /// Grants `f` mutable access to the inner state.
    ///
    /// # Safety
    ///
    /// The `UnsafeCell` has no runtime guard: the caller must guarantee `f`
    /// does not re-enter `with` on the same `Shared`, or two `&mut Inner`
    /// would alias.
    #[inline(always)]
    unsafe fn with<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut Inner) -> R,
    {
        f(unsafe { &mut *self.inner.get() })
    }

    // Copies out the type-erased allocator entry points.
    fn alloc(&self) -> BufferAlloc {
        unsafe { self.with(|inner| inner.alloc) }
    }

    // Removes and returns the pointer in slot `buffer_id`; `None` when the
    // id is out of range or the buffer is already checked out.
    fn take(&self, buffer_id: u16) -> Option<BufPtr> {
        unsafe { self.with(|inner| inner.bufs.get_mut(buffer_id as usize)?.take()) }
    }

    // Puts `ptr` back into slot `buffer_id` and re-advertises the buffer to
    // the backend. Panics if `buffer_id` is out of range.
    fn reset(&self, buffer_id: u16, ptr: BufPtr) {
        unsafe {
            self.with(|inner| {
                inner.bufs[buffer_id as usize] = Some(ptr);
                inner.ctrl.reset(buffer_id, ptr, inner.size);
            })
        }
    }

    // Allocated size in bytes of each buffer in this pool (not a count).
    fn len(&self) -> u32 {
        unsafe { self.with(|inner| inner.size) }
    }
}
364
365impl BufferRef {
366 pub fn with_capacity(mut self, cap: usize) -> Self {
371 self.set_capacity(cap);
372 self
373 }
374
375 pub fn set_capacity(&mut self, cap: usize) {
380 if cap == 0 {
381 return;
382 }
383 self.cap = (cap as u32).min(self.full_cap);
384 self.len = self.len.min(self.cap);
385 }
386}
387
388impl Deref for BufferRef {
389 type Target = [u8];
390
391 fn deref(&self) -> &Self::Target {
392 unsafe { slice::from_raw_parts(self.ptr.as_ptr().cast(), self.len as usize) }
394 }
395}
396
397impl DerefMut for BufferRef {
398 fn deref_mut(&mut self) -> &mut Self::Target {
399 unsafe { slice::from_raw_parts_mut(self.ptr.as_ptr() as _, self.len as usize) }
401 }
402}
403
404impl IoBuf for BufferRef {
405 fn as_init(&self) -> &[u8] {
406 self
407 }
408}
409
410impl SetLen for BufferRef {
411 unsafe fn set_len(&mut self, len: usize) {
412 debug_assert!(len <= u32::MAX as usize);
413 self.len = (len as u32).min(self.cap);
414 }
415}
416
417impl IoBufMut for BufferRef {
418 fn as_uninit(&mut self) -> &mut [MaybeUninit<u8>] {
419 unsafe { slice::from_raw_parts_mut(self.ptr.as_ptr(), self.cap as usize) }
423 }
424}
425
426impl Drop for BufferRef {
427 fn drop(&mut self) {
428 if let Some(shared) = self.shared.upgrade() {
429 shared.reset(self.buffer_id, self.ptr);
431 } else {
432 unsafe { (self.alloc.deallocate)(self.ptr, self.full_cap) }
433 }
434 }
435}