async_cffi/
lib.rs

1use core::panic;
2use std::{ffi::c_void, future::Future, pin::Pin, ptr::NonNull, sync::Arc, task::Waker};
3
4use futures::task::{self, ArcWake};
5
6pub type CWaker = extern "C" fn();
7// () -> nullable result pointer
8pub type CffiPollFuncT = extern "C" fn() -> *const c_void;
9
10#[derive(Debug, Clone)]
11pub struct CWakerWrapper {
12    pub waker: CWaker,
13}
14
15impl ArcWake for CWakerWrapper {
16    fn wake_by_ref(arc_self: &Arc<Self>) {
17        (arc_self.waker)();
18    }
19}
20
21pub struct RustWakerWrapper {
22    waker: Box<dyn Fn() + Send + Sync>,
23}
24
25impl ArcWake for RustWakerWrapper {
26    fn wake_by_ref(arc_self: &Arc<Self>) {
27        (arc_self.waker)();
28    }
29}
30
// Opaque raw-pointer newtype that can be moved into `Send` closures (e.g. poll
// callbacks) and returned across threads.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct SafePtr(pub *const c_void);

// SAFETY: `SafePtr` never dereferences the pointer itself; it only carries the address
// between threads. Whoever dereferences it is responsible for the pointee being safe to
// access from that thread.
unsafe impl Sync for SafePtr {}
unsafe impl Send for SafePtr {}
37
38pub struct CffiFuture {
39    // Locking just so Rust knows this is thread-safe.
40    // Since we should never be polling this from multiple threads at once,
41    // just an unsafe cell would be enough.
42    // Using a std Mutex so tokio doesn't complain about blocking the thread
43    // if we are using this in a tokio context.
44    waker: std::sync::Mutex<Option<Waker>>,
45    // CffiFuture_ptr (for waker) -> result pointer
46    poll_fn: Box<dyn FnMut() -> SafePtr + Send>,
47    pub debug: bool,
48}
49
50unsafe impl Send for CffiFuture {}
51unsafe impl Sync for CffiFuture {}
52
53// C FFI Future that can be polled from Rust, but uses a C callback to poll.
54// Usage:
55// let fut = CffiFuture::new(poll_fn);
56// let result = fut.await; // This will call the C callback to poll the future.
57// // From C
58// void* poll_fn() {}
59// // Call fut.wake() to wake the future from C when it is ready.
60impl CffiFuture {
61    pub fn new<F>(poll_fn: F) -> Pin<Box<Self>>
62    where
63        F: FnMut() -> SafePtr + Send + Sync + 'static,
64    {
65        Box::pin(CffiFuture {
66            waker: std::sync::Mutex::new(None),
67            poll_fn: Box::new(poll_fn),
68            debug: false,
69        })
70    }
71
72    // Wrap a Rust Future to create a CffiFuture.
73    // The output of the Rust Future must be a non-null pointer.
74    // Use `from_rust_future_boxed` if the output may be null.
75    pub fn from_rust_nonnull_future<F>(fut: F) -> Pin<Box<CffiFuture>>
76    where
77        F: Future<Output = *const c_void> + Send + 'static,
78    {
79        let mut box_fut = Box::pin(fut);
80        let mut cffi_fut = Box::pin(CffiFuture {
81            waker: std::sync::Mutex::new(None),
82            poll_fn: Box::new(move || {
83                SafePtr(std::ptr::null()) // Placeholder for the poll function
84            }),
85            debug: false,
86        });
87
88        // Poll fn is self referential, so must be constructed after the CffiFuture is created.
89        // TODO: Add test that uses this waker
90        let cffi_fut_ptr = SafePtr(cffi_fut.as_mut().get_mut() as *mut CffiFuture as *const c_void);
91        cffi_fut.poll_fn = Box::new(move || {
92            let waker = Arc::new(RustWakerWrapper {
93                waker: Box::new(move || {
94                    let cffi_fut_ptr = cffi_fut_ptr;
95                    let fut = Pin::new(unsafe {
96                        (cffi_fut_ptr.0 as *mut CffiFuture)
97                            .as_mut()
98                            .expect("CffiFuture cannot be null")
99                    });
100                    fut.as_ref().wake();
101                }),
102            });
103            let waker = task::waker(waker);
104
105            let mut ctx = std::task::Context::from_waker(&waker);
106            match box_fut.as_mut().poll(&mut ctx) {
107                std::task::Poll::Ready(result) => SafePtr(result),
108                std::task::Poll::Pending => {
109                    // If the future is pending, we return a null pointer.
110                    SafePtr(std::ptr::null())
111                }
112            }
113        });
114
115        cffi_fut
116    }
117
118    // Wrap a Rust Future to create a CffiFuture.
119    // The rust future can output any type.
120    // If the output is guaranteed to be a non-null pointer,
121    // `from_rust_nonnull_future` can be used directly.
122    pub fn from_rust_future_boxed<T: 'static>(
123        fut: impl Future<Output = T> + Send + 'static,
124    ) -> Pin<Box<CffiFuture>> {
125        CffiFuture::from_rust_nonnull_future(box_future_output(fut))
126    }
127
128    pub fn into_raw(self: Pin<Box<Self>>) -> *mut c_void {
129        Box::into_raw(Pin::into_inner(self)) as *mut c_void
130    }
131
132    pub fn from_raw(ptr: *mut c_void) -> Pin<&'static mut Self> {
133        unsafe {
134            Pin::new_unchecked(
135                (ptr as *mut CffiFuture)
136                    .as_mut()
137                    .expect("CffiFuture cannot be null"),
138            )
139        }
140    }
141
142    pub fn poll_inner(self: std::pin::Pin<&mut Self>, waker: &Waker) -> SafePtr {
143        if self.debug {
144            dbg!("Rust: called poll_inner on CffiFuture");
145        }
146        let cffi_future = self.get_mut();
147        let result = (cffi_future.poll_fn)();
148        if !result.0.is_null() {
149            if cffi_future.debug {
150                dbg!("Rust: CffiFuture is ready, returning result");
151            }
152            return result;
153        } else {
154            if cffi_future.debug {
155                dbg!("Rust: CffiFuture not ready, registering waker");
156            }
157            // If the result is null, we need to register the waker.
158            let mut waker_lock = cffi_future.waker.lock().unwrap();
159            match waker_lock.as_mut() {
160                Some(existing_waker) => existing_waker.clone_from(waker),
161                None => *waker_lock = Some(waker.clone()),
162            }
163            SafePtr(std::ptr::null()) // Return null if not ready
164        }
165    }
166
167    pub fn wake(self: std::pin::Pin<&Self>) {
168        if let Some(waker) = self.waker.lock().unwrap().take() {
169            if self.debug {
170                dbg!("Rust calling waker.wake()");
171            }
172            waker.wake();
173            if self.debug {
174                dbg!("Rust called waker.wake()");
175            }
176        } else {
177            // Wake called before poll
178            panic!("CffiFuture: No waker to wake up");
179        }
180    }
181}
182
183impl Future for CffiFuture {
184    type Output = *const c_void;
185
186    fn poll(
187        self: std::pin::Pin<&mut Self>,
188        cx: &mut std::task::Context<'_>,
189    ) -> std::task::Poll<Self::Output> {
190        if self.debug {
191            dbg!("Rust: Polling CffiFuture via Future trait");
192        }
193        let waker = cx.waker();
194        let result = self.poll_inner(waker);
195        if !result.0.is_null() {
196            std::task::Poll::Ready(result.0)
197        } else {
198            std::task::Poll::Pending
199        }
200    }
201}
202
// Awaits `fut` and moves its output to the heap, returning an owning, type-erased,
// never-null pointer. The caller must eventually free it as a `Box<T>`.
pub async fn box_future_output<'a, T>(fut: impl Future<Output = T> + Send + 'a) -> *const c_void {
    let output: T = fut.await;
    let raw: *mut T = Box::into_raw(Box::new(output));
    raw as *const c_void
}
207
208pub fn waker_from_wrapper_ptr(wrapper: *mut c_void) -> Waker {
209    let wrapper = unsafe {
210        (wrapper as *mut CWakerWrapper)
211            .as_mut()
212            .expect("Context cannot be null")
213    };
214    waker_from_wrapper(wrapper.clone())
215}
216
217fn waker_from_wrapper(wrapper: CWakerWrapper) -> Waker {
218    let arc_wrapper = Arc::new(wrapper);
219    task::waker(arc_wrapper)
220}
221
// Demo: returns an owning pointer to a `Box<dyn Fn()>` that prints a greeting.
/// `example_dyn_fn_new() -> ptr<Box<dyn Fn()>>`
pub fn example_dyn_fn_new() -> *mut c_void {
    let greet = || println!("Hello from the dynamic function!");
    let erased: Box<dyn Fn()> = Box::new(greet);
    Box::into_raw(Box::new(erased)).cast::<c_void>()
}
230
231/// `example_pointer_buffer_new() -> CffiPointerBuffer`
232pub fn example_pointer_buffer_new() -> CffiPointerBuffer {
233    let pointers: Box<[*const c_void]> = Box::new([box_i32(1), box_i32(2), box_i32(3)]);
234
235    CffiPointerBuffer::from_slice(pointers)
236}
237
// Invokes the closure behind a `Box<dyn Fn()>` pointer (such as one returned by
// `example_dyn_fn_new`) without taking ownership of it. Panics if `dyn_fn` is null.
/// `call_dyn_fn(dyn_fn: ptr<Box<dyn Fn()>>) -> ()`
pub fn call_dyn_fn(dyn_fn: *mut c_void) {
    // SAFETY: a non-null `dyn_fn` must point to a live `Box<dyn Fn()>`; only a shared
    // borrow is taken, for the duration of the call.
    let callback = unsafe { dyn_fn.cast::<Box<dyn Fn()>>().as_ref() }
        .expect("Failed to get dyn_fn");
    callback();
}
247
248/// `blocking_wait(fut: ptr<CffiFuture<T>>) -> ptr<T>`
249pub fn blocking_wait(fut: *mut c_void) -> *const c_void {
250    let fut = unsafe {
251        (fut as *mut CffiFuture)
252            .as_mut()
253            .expect("CffiFuture cannot be null")
254    };
255    let pinned = std::pin::Pin::new(fut);
256
257    let rt = tokio::runtime::Runtime::new().unwrap();
258    rt.block_on(pinned)
259}
260
261/// `waker_wrapper_new(waker: extern fn()) -> ptr<CWakerWrapper>`
262pub fn waker_wrapper_new(waker: CWaker) -> *mut c_void {
263    let wrapper = Box::new(CWakerWrapper { waker });
264    Box::into_raw(wrapper) as *mut c_void
265}
266
267/// `new_cffi_future(poll_fn: extern fn() -> opt_ptr<T>, debug: bool) -> ptr<CffiFuture<T>>`
268pub fn new_cffi_future(poll_fn: CffiPollFuncT, debug: bool) -> *mut c_void {
269    let mut future = CffiFuture::new(move || {
270        let result = poll_fn();
271        SafePtr(result)
272    });
273    future.debug = debug;
274    future.into_raw()
275}
276
277/// `poll_cffi_future(fut: ptr<CffiFuture<T>>, waker: ptr<CWakerWrapper>) -> opt_ptr<T>`
278pub fn poll_cffi_future(fut: *mut c_void, waker: *mut c_void) -> *const c_void {
279    let fut = unsafe {
280        (fut as *mut CffiFuture)
281            .as_mut()
282            .expect("CffiFuture cannot be null")
283    };
284    let fut = Pin::new(fut);
285    let waker = waker_from_wrapper_ptr(waker);
286
287    if fut.debug {
288        dbg!("Rust: Polling CffiFuture from C");
289    }
290
291    fut.poll_inner(&waker).0
292}
293
294/// `wake_cffi_future(fut: ptr<CffiFuture<T>>) -> ()`
295pub fn wake_cffi_future(fut: *mut c_void) {
296    let fut = unsafe {
297        (fut as *mut CffiFuture)
298            .as_mut()
299            .expect("CffiFuture cannot be null")
300    };
301    let pinned = std::pin::Pin::new(fut);
302    pinned.as_ref().wake();
303}
304
// Moves `value` to the heap; the caller owns the allocation (free it as a `Box<i32>`).
/// `box_i32(value: i32) -> ptr<i32>`
pub fn box_i32(value: i32) -> *mut c_void {
    Box::into_raw(Box::new(value)).cast()
}
310
// Moves `value` to the heap; the caller owns the allocation (free it as a `Box<u64>`).
/// `box_u64(value: u64) -> ptr<u64>`
pub fn box_u64(value: u64) -> *mut c_void {
    Box::into_raw(Box::new(value)).cast()
}
316
// Moves the (possibly null) pointer `value` to the heap; the caller owns the allocation
// (free it as a `Box<*const c_void>`).
/// `box_ptr(value: opt_ptr<T>) -> ptr<opt_ptr<T>>`
pub fn box_ptr(value: *const c_void) -> *mut c_void {
    Box::into_raw(Box::new(value)).cast()
}
322
// C-compatible (`#[repr(C)]`) pointer + length view of an array of pointers.
// Buffers built by `from_slice` own a leaked `Box<[*const c_void]>`; there is no `Drop`,
// and `Clone` copies the pointer, not the array.
// NOTE(review): whoever frees the array must rebuild the box from `pointers`/`length`.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct CffiPointerBuffer {
    // Start of the array (dangling but non-null for empty buffers built in Rust).
    pub pointers: *const *const c_void,
    // Number of elements in the array.
    pub length: usize,
}

impl CffiPointerBuffer {
    // Borrows the contents as a slice. An empty buffer yields `&[]` whatever `pointers`
    // is, so a C-side `{NULL, 0}` is accepted. Otherwise `pointers` must point to `length`
    // initialized elements that outlive the borrow; this is not checked.
    pub fn as_slice(&self) -> &[*const c_void] {
        if self.length == 0 {
            // `from_raw_parts` requires a non-null, aligned pointer even for an empty
            // slice, which a C caller may not provide.
            return &[];
        }
        // SAFETY: non-empty buffer; validity of `pointers`/`length` per the contract above.
        unsafe { std::slice::from_raw_parts(self.pointers, self.length) }
    }

    // Leaks `pointers` into a buffer that owns (and never frees) the allocation.
    pub fn from_slice(pointers: Box<[*const c_void]>) -> Self {
        let length = pointers.len();
        // Take the address from `Box::into_raw` itself. A pointer obtained via `as_ptr()`
        // before the box is moved into `into_raw` is invalidated by that move under Rust's
        // aliasing rules for `Box` (Miri reports it).
        let raw: *mut [*const c_void] = Box::into_raw(pointers);
        Self {
            pointers: raw as *const *const c_void,
            length,
        }
    }

    // Creates an empty buffer: dangling (non-null, aligned) pointer, no allocation.
    pub fn new_empty() -> Self {
        Self {
            pointers: NonNull::dangling().as_ptr(),
            length: 0,
        }
    }
}

// SAFETY: the buffer is only a pointer + length; sending or sharing it across threads
// does not itself access the elements. Users of `as_slice` must ensure the pointed-to
// array is valid on the thread that reads it.
unsafe impl Send for CffiPointerBuffer {}
unsafe impl Sync for CffiPointerBuffer {}