mountpoint_s3_crt/s3/
pool.rs

1//! Bridge custom memory pool implementations to the CRT S3 Client interface.
2
3use std::marker::PhantomPinned;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::{fmt::Debug, ptr::NonNull};
7
8use mountpoint_s3_crt_sys::{
9    aws_allocator, aws_byte_buf, aws_byte_buf_from_empty_array, aws_byte_cursor, aws_future_s3_buffer_ticket,
10    aws_future_s3_buffer_ticket_acquire, aws_future_s3_buffer_ticket_new, aws_future_s3_buffer_ticket_release,
11    aws_future_s3_buffer_ticket_set_result_by_move, aws_ref_count_init, aws_s3_buffer_pool, aws_s3_buffer_pool_config,
12    aws_s3_buffer_pool_factory_fn, aws_s3_buffer_pool_reserve_meta, aws_s3_buffer_pool_vtable, aws_s3_buffer_ticket,
13    aws_s3_buffer_ticket_vtable,
14};
15
16use crate::ToAwsByteCursor as _;
17use crate::common::allocator::Allocator;
18use crate::s3::client::MetaRequestType;
19
20/// A custom memory pool.
21///
22/// **WARNING:** The API for this trait is still experimental and will likely change
23/// in future releases.
24pub trait MemoryPool: Clone + Send + Sync {
25    /// Associated buffer type.
26    type Buffer: AsMut<[u8]>;
27
28    /// Get a buffer of at least the requested size.
29    fn get_buffer(&self, size: usize, meta_request_type: MetaRequestType) -> Self::Buffer;
30
31    /// Trim the pool.
32    ///
33    /// Return `true` if the pool freed any memory.
34    fn trim(&self) -> bool;
35}
36
37/// Factory for a custom memory pool.
38pub trait MemoryPoolFactory: Send + Sync {
39    /// The [MemoryPool] implementation created by this factory.
40    type Pool: MemoryPool;
41
42    /// Create a memory pool instance.
43    fn create(&self, options: MemoryPoolFactoryOptions) -> Self::Pool;
44}
45
46impl<F, P> MemoryPoolFactory for F
47where
48    F: Fn(MemoryPoolFactoryOptions) -> P + Send + Sync,
49    P: MemoryPool,
50{
51    type Pool = P;
52
53    fn create(&self, options: MemoryPoolFactoryOptions) -> Self::Pool {
54        self(options)
55    }
56}
57
/// Options to create a [MemoryPool].
///
/// Built from the CRT buffer pool configuration and passed to [MemoryPoolFactory::create].
#[derive(Debug)]
pub struct MemoryPoolFactoryOptions {
    /// Default part size configured on the client.
    part_size: usize,
    /// Maximum part size for the client.
    max_part_size: usize,
    /// Memory limit configured on the client.
    memory_limit: usize,
}

impl MemoryPoolFactoryOptions {
    /// Returns the default part size set on the client.
    pub fn part_size(&self) -> usize {
        self.part_size
    }

    /// Returns the max part size for the client.
    pub fn max_part_size(&self) -> usize {
        self.max_part_size
    }

    /// Returns the memory limit set on the client.
    pub fn memory_limit(&self) -> usize {
        self.memory_limit
    }
}
80
81/// Factory used by [Client](`super::client::Client`) to create CRT wrappers for [MemoryPool] implementations.
82#[derive(Debug, Clone)]
83pub struct CrtBufferPoolFactory(Arc<CrtBufferPoolFactoryInner>);
84
85#[derive(Debug)]
86struct CrtBufferPoolFactoryInner {
87    factory_ptr: NonNull<libc::c_void>,
88    factory_fn: aws_s3_buffer_pool_factory_fn,
89    drop_fn: fn(*mut ::libc::c_void),
90}
91
92// SAFETY: `CrtBufferPoolFactoryInner` is safe to transfer across threads because it wraps a [MemoryPoolFactory] implementation that is [Send].
93unsafe impl Send for CrtBufferPoolFactoryInner {}
94// SAFETY: `CrtBufferPoolFactoryInner` is safe to share across threads because it wraps a [MemoryPoolFactory] implementation that is [Sync].
95unsafe impl Sync for CrtBufferPoolFactoryInner {}
96
97impl Drop for CrtBufferPoolFactoryInner {
98    fn drop(&mut self) {
99        (self.drop_fn)(self.factory_ptr.as_ptr());
100    }
101}
102
103impl CrtBufferPoolFactory {
104    /// Builds a factory for the given pool.
105    pub fn new<PoolFactory: MemoryPoolFactory>(pool_factory: PoolFactory) -> Self {
106        let factory = Box::pin(pool_factory);
107        // SAFETY: The pointer to the factory will only be used in `buffer_pool_factory` and
108        // `drop_pool_factory`, which will treat it as pinned.
109        let leaked = Box::leak(unsafe { Pin::into_inner_unchecked(factory) });
110        // SAFETY: `leaked` is not null.
111        let factory_ptr = unsafe { NonNull::new_unchecked(leaked as *mut PoolFactory as *mut libc::c_void) };
112        Self(Arc::new(CrtBufferPoolFactoryInner {
113            factory_ptr,
114            factory_fn: Some(buffer_pool_factory::<PoolFactory>),
115            drop_fn: drop_pool_factory::<PoolFactory>,
116        }))
117    }
118
119    /// Returns the factory callback and user_data pointer to pass to the CRT.
120    pub(crate) fn as_inner(&self) -> (aws_s3_buffer_pool_factory_fn, *mut ::libc::c_void) {
121        (self.0.factory_fn, self.0.factory_ptr.as_ptr())
122    }
123}
124
125unsafe extern "C" fn buffer_pool_factory<PoolFactory: MemoryPoolFactory>(
126    allocator: *mut aws_allocator,
127    config: aws_s3_buffer_pool_config,
128    user_data: *mut libc::c_void,
129) -> *mut aws_s3_buffer_pool {
130    // SAFETY: `user_data` references a pinned box owned by the `CrtBufferPoolFactory` instance.
131    let pool_factory = unsafe { &*(user_data as *mut PoolFactory) };
132
133    // SAFETY: `allocator` is a non-null pointer to a `aws_allocator` instance.
134    let allocator = unsafe { NonNull::new_unchecked(allocator).into() };
135
136    let options = MemoryPoolFactoryOptions {
137        part_size: config.part_size,
138        max_part_size: config.max_part_size,
139        memory_limit: config.memory_limit,
140    };
141    let pool = pool_factory.create(options);
142
143    let crt_pool = CrtBufferPool::new(pool.clone(), allocator);
144
145    // SAFETY: the CRT will only use the pool through its vtable and refcount.
146    unsafe { crt_pool.leak() }
147}
148
149fn drop_pool_factory<PoolFactory: MemoryPoolFactory>(factory_ptr: *mut libc::c_void) {
150    // SAFETY: `factory_ptr` was leaked in `CrtBufferPoolFactory::new`.
151    _ = unsafe { Pin::new_unchecked(Box::from_raw(factory_ptr as *mut PoolFactory)) };
152}
153
154/// Internal wrapper to bridge the [MemoryPool] implementation to
155/// the `aws_s3_buffer_pool` to provide to the CRT.
156///
157/// Instances of this type also hold the vtables to set up both
158/// the `aws_s3_buffer_pool` itself and the `aws_s3_buffer_ticket`s
159/// it returns. Notably, all the functions in the vtables are generic
160/// in the same [MemoryPool] implementation as [CrtBufferPool], so
161/// that the CRT can handle different implementations and there is
162/// no need for dynamic dispatch on the Rust side.
163struct CrtBufferPool<Pool: MemoryPool> {
164    /// Inner struct to pass to CRT functions.
165    inner: aws_s3_buffer_pool,
166    /// [MemoryPool] implementation.
167    pool: Pool,
168    /// Holds the vtable to point to in `inner`.
169    pool_vtable: aws_s3_buffer_pool_vtable,
170    /// Holds the vtable for the `aws_s3_buffer_ticket` instances.
171    ticket_vtable: aws_s3_buffer_ticket_vtable,
172    /// CRT allocator.
173    allocator: Allocator,
174    /// Pin this struct because inner.impl_ will be a pointer to this object.
175    _pinned: PhantomPinned,
176}
177
178impl<Pool: MemoryPool> CrtBufferPool<Pool> {
179    fn new(pool: Pool, allocator: Allocator) -> Pin<Box<Self>> {
180        // `inner` will be initialized after pinning because its fields require pinned addresses.
181        let mut crt_pool = Box::pin(CrtBufferPool {
182            inner: Default::default(),
183            pool,
184            pool_vtable: aws_s3_buffer_pool_vtable {
185                reserve: Some(pool_reserve::<Pool>),
186                trim: Some(pool_trim::<Pool>),
187                acquire: None,
188                release: None,
189            },
190            ticket_vtable: aws_s3_buffer_ticket_vtable {
191                claim: Some(ticket_claim::<Pool::Buffer>),
192                acquire: None,
193                release: None,
194            },
195            allocator,
196            _pinned: Default::default(),
197        });
198
199        // Set up the vtable and `impl_` to the pinned addresses (self-referential) and initialize ref-counting.
200        // SAFETY: We're setting up the struct to be self-referential, and we're not moving out
201        // of the struct, so the unchecked deref of the pinned pointer is okay.
202        unsafe {
203            let pool_ref = Pin::get_unchecked_mut(Pin::as_mut(&mut crt_pool));
204            pool_ref.inner.vtable = &raw mut pool_ref.pool_vtable;
205            pool_ref.inner.impl_ = pool_ref as *mut CrtBufferPool<Pool> as *mut libc::c_void;
206            aws_ref_count_init(
207                &mut pool_ref.inner.ref_count,
208                &mut pool_ref.inner as *mut aws_s3_buffer_pool as *mut libc::c_void,
209                Some(pool_destroy::<Pool>),
210            );
211        }
212
213        crt_pool
214    }
215
216    /// Leak a pinned instance and returns a raw pointer.
217    ///
218    /// # Safety
219    /// The returned pointer must eventually be passed to [from_raw] and can
220    /// additionally only used in [ref_from_raw].
221    unsafe fn leak(self: Pin<Box<Self>>) -> *mut aws_s3_buffer_pool {
222        // SAFETY: the resulting pointer will be only used in `pool_reserve`, `pool_trim`, and `pool_destroy`.
223        let pool = Box::leak(unsafe { Pin::into_inner_unchecked(self) });
224        &raw mut pool.inner
225    }
226
227    /// Returns a reference to original instance from a raw pointer.
228    ///
229    /// # Safety
230    /// The raw pointer must have been obtained through [leak()].
231    unsafe fn ref_from_raw(pool: &*mut aws_s3_buffer_pool) -> &Self {
232        // SAFETY: `pool` points to the `inner` field of a pinned instance.
233        unsafe {
234            let impl_ptr = (**pool).impl_;
235            &*(impl_ptr as *mut Self)
236        }
237    }
238
239    /// Re-constructs the original pinned instance from a raw pointer.
240    ///
241    /// # Safety
242    /// The raw pointer must have been obtained through [leak()].
243    unsafe fn from_raw(pool: *mut aws_s3_buffer_pool) -> Pin<Box<Self>> {
244        // SAFETY: `pool` points to the `inner` field of a pinned instance.
245        unsafe { Pin::new_unchecked(Box::from_raw((*pool).impl_ as *mut Self)) }
246    }
247
248    fn trim(&self) {
249        self.pool.trim();
250    }
251
252    fn reserve(&self, size: usize, meta_request_type: MetaRequestType) -> CrtTicketFuture {
253        let future = CrtTicketFuture::new(&self.allocator);
254
255        // Get a buffer from the pool, build its ticket, and immediately fullfil the future.
256        // This will likely change later, when we make the method on the pool async.
257        let buffer = self.pool.get_buffer(size, meta_request_type);
258        let ticket = self.make_ticket(buffer);
259        future.set(ticket);
260
261        future
262    }
263
264    fn make_ticket(&self, buffer: Pool::Buffer) -> Pin<Box<CrtTicket<Pool::Buffer>>> {
265        // `inner` will be initialized after pinning because its fields require pinned addresses.
266        let mut ticket = Box::pin(CrtTicket {
267            inner: Default::default(),
268            ticket_vtable: self.ticket_vtable,
269            buffer,
270            _pinned: Default::default(),
271        });
272
273        // Set up the vtable and `impl_` to the pinned addresses (self-referential) and initialize ref-counting.
274        // SAFETY: We're setting up the struct to be self-referential, and we're not moving out
275        // of the struct, so the unchecked deref of the pinned pointer is okay.
276        unsafe {
277            let ticket_ref = Pin::get_unchecked_mut(Pin::as_mut(&mut ticket));
278            ticket_ref.inner.vtable = &raw mut ticket_ref.ticket_vtable;
279            ticket_ref.inner.impl_ = ticket_ref as *mut CrtTicket<Pool::Buffer> as *mut libc::c_void;
280            aws_ref_count_init(
281                &mut ticket_ref.inner.ref_count,
282                &mut ticket_ref.inner as *mut aws_s3_buffer_ticket as *mut libc::c_void,
283                Some(ticket_destroy::<Pool::Buffer>),
284            );
285        }
286
287        ticket
288    }
289}
290
291unsafe extern "C" fn pool_reserve<Pool: MemoryPool>(
292    pool: *mut aws_s3_buffer_pool,
293    meta: aws_s3_buffer_pool_reserve_meta,
294) -> *mut aws_future_s3_buffer_ticket {
295    // SAFETY: `pool` was obtained through `CrtMemoryPool::leak`.
296    let crt_pool = unsafe { CrtBufferPool::<Pool>::ref_from_raw(&pool) };
297
298    // SAFETY: `meta.meta_request` is a pointer to a valid `aws_s3_meta_request`.
299    let request_type = unsafe { (*meta.meta_request).type_ };
300
301    let future = crt_pool.reserve(meta.size, request_type.into());
302
303    // SAFETY: the CRT will take ownership of the future.
304    unsafe { future.into_inner_ptr() }
305}
306
307unsafe extern "C" fn pool_trim<Pool: MemoryPool>(pool: *mut aws_s3_buffer_pool) {
308    // SAFETY: `pool` was obtained through `CrtMemoryPool::leak`.
309    let crt_pool = unsafe { CrtBufferPool::<Pool>::ref_from_raw(&pool) };
310    crt_pool.trim();
311}
312
313unsafe extern "C" fn pool_destroy<Pool: MemoryPool>(data: *mut libc::c_void) {
314    let pool = data as *mut aws_s3_buffer_pool;
315
316    // SAFETY: `pool` was obtained through `CrtMemoryPool::leak`.
317    _ = unsafe { CrtBufferPool::<Pool>::from_raw(pool) };
318}
319
320/// Wrapper for [aws_s3_buffer_ticket].
321struct CrtTicket<Buffer: AsMut<[u8]>> {
322    /// Inner struct to pass to CRT functions.
323    inner: aws_s3_buffer_ticket,
324    /// Holds the vtable to point to in `inner`.
325    ticket_vtable: aws_s3_buffer_ticket_vtable,
326    /// Buffer implementing [AsMut<\[u8\]>].
327    buffer: Buffer,
328    /// Pin this struct because inner.impl_ will be a pointer to this object.
329    _pinned: PhantomPinned,
330}
331
332impl<Buffer: AsMut<[u8]>> CrtTicket<Buffer> {
333    /// Leak a pinned instance and returns a raw pointer.
334    ///
335    /// # Safety
336    /// The returned pointer must eventually be passed to [from_raw] and can
337    /// additionally only be used in [ref_mut_from_raw].
338    unsafe fn leak(self: Pin<Box<Self>>) -> *mut aws_s3_buffer_ticket {
339        // SAFETY: the resulting pointer will be only used in `ticket_claim` and `ticket_destroy`.
340        let boxed = unsafe { Pin::into_inner_unchecked(self) };
341        let pool = Box::leak(boxed);
342        &raw mut pool.inner
343    }
344
345    /// Returns a reference to original instance from a raw pointer.
346    ///
347    /// # Safety
348    /// The raw pointer must have been obtained through [leak()].
349    unsafe fn ref_mut_from_raw(ticket: &mut *mut aws_s3_buffer_ticket) -> &mut Self {
350        // SAFETY: `ticket` points to the `inner` field of a pinned instance.
351        unsafe {
352            let impl_ptr = (**ticket).impl_;
353            &mut *(impl_ptr as *mut Self)
354        }
355    }
356
357    /// Re-constructs the original pinned instance from a raw pointer.
358    ///
359    /// # Safety
360    /// The raw pointer must have been obtained through [leak()].
361    unsafe fn from_raw(ticket: *mut aws_s3_buffer_ticket) -> Pin<Box<Self>> {
362        // SAFETY: `ticket` points to the `inner` field of a pinned instance.
363        unsafe { Pin::new_unchecked(Box::from_raw((*ticket).impl_ as *mut Self)) }
364    }
365}
366
367/// Return the buffer associated with a ticket.
368unsafe extern "C" fn ticket_claim<Buffer: AsMut<[u8]>>(mut ticket: *mut aws_s3_buffer_ticket) -> aws_byte_buf {
369    // SAFETY: `ticket` was obtained through `Ticket::leak`.
370    let ticket = unsafe { CrtTicket::<Buffer>::ref_mut_from_raw(&mut ticket) };
371
372    // SAFETY: the CRT guarantees to only use the returned buffer while holding the ticket.
373    let aws_byte_cursor { len, ptr } = unsafe { ticket.buffer.as_mut().as_aws_byte_cursor() };
374
375    // Use `aws_byte_buf_from_empty_array` to build an `aws_byte_buf` with 0 length and `len` capacity.
376    // SAFETY: `ptr` is a valid buffer with capacity >= `len`.
377    unsafe { aws_byte_buf_from_empty_array(ptr as *mut libc::c_void, len) }
378}
379
380unsafe extern "C" fn ticket_destroy<Buffer: AsMut<[u8]>>(data: *mut libc::c_void) {
381    let ticket = data as *mut aws_s3_buffer_ticket;
382    // SAFETY: `ticket` was obtained through `Ticket::leak`.
383    _ = unsafe { CrtTicket::<Buffer>::from_raw(ticket) };
384}
385
386/// Wrapper for [aws_future_s3_buffer_ticket].
387#[derive(Debug)]
388struct CrtTicketFuture {
389    inner: *mut aws_future_s3_buffer_ticket,
390}
391
392// SAFETY: `aws_future_s3_buffer_ticket` is reference counted and its methods are thread-safe.
393unsafe impl Send for CrtTicketFuture {}
394
395// SAFETY: `aws_future_s3_buffer_ticket` is reference counted and its methods are thread-safe.
396unsafe impl Sync for CrtTicketFuture {}
397
398impl CrtTicketFuture {
399    fn new(allocator: &Allocator) -> Self {
400        // SAFETY: aws_future_s3_buffer_ticket_new return a non-null pointer to a new aws_future_s3_buffer_ticket with a reference count of 1.
401        let inner = unsafe { aws_future_s3_buffer_ticket_new(allocator.inner.as_ptr()) };
402        Self { inner }
403    }
404
405    fn set<Buffer: AsMut<[u8]>>(&self, ticket: Pin<Box<CrtTicket<Buffer>>>) {
406        // SAFETY: `ticket` will be passed to the CRT which will only use it through its vtable and refcount.
407        let mut ticket = unsafe { ticket.leak() };
408        // SAFETY: `self.inner` is a valid future and we are setting it to `ticket`.
409        unsafe {
410            aws_future_s3_buffer_ticket_set_result_by_move(self.inner, &mut ticket);
411        }
412    }
413
414    /// Return the pointer to the inner `aws_future_s3_buffer_ticket` instance.
415    ///
416    /// # Safety
417    /// The returned pointer follows ref-counting rules and must be eventually released.
418    unsafe fn into_inner_ptr(mut self) -> *mut aws_future_s3_buffer_ticket {
419        // Swap the pointer to return with null, so drop will be a no-op.
420        std::mem::replace(&mut self.inner, std::ptr::null_mut())
421    }
422}
423
424impl Clone for CrtTicketFuture {
425    fn clone(&self) -> Self {
426        // SAFETY: `self.inner` is a valid `aws_future_s3_buffer_ticket`, and we increment its
427        // reference count on Clone and decrement it on Drop.
428        let inner = unsafe { aws_future_s3_buffer_ticket_acquire(self.inner) };
429        Self { inner }
430    }
431}
432
433impl Drop for CrtTicketFuture {
434    fn drop(&mut self) {
435        // SAFETY: `self.inner` is a valid `aws_future_s3_buffer_ticket` (or null), and on Drop
436        // it's safe to decrement the reference count since this is balancing the `acquire` in `new`.
437        unsafe { aws_future_s3_buffer_ticket_release(self.inner) };
438    }
439}