coreml_native/async_bridge.rs

//! Runtime-agnostic bridge from Apple completion handlers to Rust futures.
//!
//! Provides [`CompletionFuture<T>`] -- a [`Future`] that resolves when an
//! Objective-C completion handler fires. Works with any async runtime
//! (tokio, async-std, smol) or can be blocked on synchronously.

use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};

use crate::error::Result;

/// Shared state between the completion handler and the future.
///
/// Lives behind an `Arc<Mutex<..>>` shared by the sender and future halves.
struct Shared<T> {
    // Completion result; becomes `Some` when the handler fires and is
    // `take()`n by the consumer, so it is observed at most once.
    value: Option<Result<T>>,
    // Waker registered by the most recent poll (or by `block_on`);
    // consumed and invoked by the sender when it delivers the value.
    waker: Option<Waker>,
}
19
/// A future that resolves when an Apple completion handler fires.
///
/// Created by [`completion_channel`]. The sender half is passed into
/// the Objective-C block; the future half is returned to the caller.
pub struct CompletionFuture<T> {
    // State shared with the sender; the Mutex mediates access between the
    // polling thread and the completion-handler thread.
    shared: Arc<Mutex<Shared<T>>>,
}
27
// Safety: The Arc<Mutex<>> provides thread-safe interior mutability.
// T must be Send because the completion handler fires on a GCD queue
// (different thread) and sends T to the future's polling thread.
//
// NOTE(review): this impl looks redundant -- `Arc<Mutex<Shared<T>>>` is
// already `Send` when `Shared<T>: Send`, so the compiler should derive
// `Send` automatically for `T: Send` (assuming the crate `Error` type is
// `Send`). If the auto-impl does not apply, investigate why rather than
// relying on this override. TODO confirm and consider removing.
unsafe impl<T: Send> Send for CompletionFuture<T> {}
32
33impl<T: Send> Future for CompletionFuture<T> {
34    type Output = Result<T>;
35
36    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
37        let mut shared = self.shared.lock().unwrap();
38        if let Some(value) = shared.value.take() {
39            Poll::Ready(value)
40        } else {
41            shared.waker = Some(cx.waker().clone());
42            Poll::Pending
43        }
44    }
45}
46
47impl<T: Send> CompletionFuture<T> {
48    /// Block the current thread until the completion handler fires.
49    ///
50    /// This is a convenience for callers who don't have an async runtime.
51    /// Uses a condvar internally -- no external dependencies required.
52    pub fn block_on(self) -> Result<T> {
53        // Fast path: value is already available.
54        {
55            let mut shared = self.shared.lock().unwrap();
56            if let Some(value) = shared.value.take() {
57                return value;
58            }
59        }
60
61        // Slow path: wait on a condvar.
62        let pair = Arc::new((std::sync::Mutex::new(false), std::sync::Condvar::new()));
63        let pair_for_waker = pair.clone();
64
65        {
66            let mut shared = self.shared.lock().unwrap();
67            // Re-check after acquiring lock.
68            if let Some(value) = shared.value.take() {
69                return value;
70            }
71            let waker = condvar_waker(pair_for_waker);
72            shared.waker = Some(waker);
73        }
74
75        // Wait for the waker to fire.
76        let (lock, cvar) = &*pair;
77        let mut ready = lock.lock().unwrap();
78        while !*ready {
79            ready = cvar.wait(ready).unwrap();
80        }
81
82        let mut shared = self.shared.lock().unwrap();
83        shared.value.take().expect("waker fired but no value was set")
84    }
85}
86
87/// Creates a channel pair: a [`CompletionSender`] and a [`CompletionFuture`].
88///
89/// The sender is designed to be called exactly once from inside an
90/// Objective-C completion handler block.
91pub(crate) fn completion_channel<T: Send>() -> (CompletionSender<T>, CompletionFuture<T>) {
92    let shared = Arc::new(Mutex::new(Shared {
93        value: None,
94        waker: None,
95    }));
96
97    let sender = CompletionSender {
98        shared: shared.clone(),
99    };
100
101    let future = CompletionFuture { shared };
102
103    (sender, future)
104}
105
/// Sender half of the completion channel.
///
/// Call [`send()`](CompletionSender::send) from within the ObjC completion
/// handler block. Consumes self to enforce exactly-once semantics.
pub(crate) struct CompletionSender<T> {
    // Same shared slot as the paired CompletionFuture.
    shared: Arc<Mutex<Shared<T>>>,
}
113
// Safety: CompletionSender is designed to be moved into a block2 closure
// and called on a GCD dispatch queue (different thread). The Arc<Mutex<>>
// provides thread safety.
//
// NOTE(review): both impls look redundant -- `Arc<Mutex<Shared<T>>>` is
// `Send + Sync` whenever `Shared<T>: Send`, which holds for `T: Send`
// (assuming the crate `Error` type is `Send`), so the compiler should
// derive these automatically. TODO confirm and consider removing.
unsafe impl<T: Send> Send for CompletionSender<T> {}
unsafe impl<T: Send> Sync for CompletionSender<T> {}
119
120impl<T: Send> CompletionSender<T> {
121    /// Send the completion result, waking the future if it's being polled.
122    pub fn send(self, value: Result<T>) {
123        let mut shared = self.shared.lock().unwrap();
124        shared.value = Some(value);
125        if let Some(waker) = shared.waker.take() {
126            waker.wake();
127        }
128    }
129}
130
/// Create a [`Waker`] that signals a condvar when woken.
///
/// The waker holds an `Arc` reference to the `(Mutex<bool>, Condvar)` pair.
/// When woken, it sets the bool to `true` and calls `notify_one()`.
fn condvar_waker(
    pair: Arc<(std::sync::Mutex<bool>, std::sync::Condvar)>,
) -> Waker {
    use std::task::Wake;

    /// Adapter so `std::task::Wake` (stable since 1.51) can derive the
    /// RawWaker vtable for us. This replaces a hand-written unsafe vtable:
    /// the standard trait manages the Arc refcounts (clone/wake/wake_by_ref/
    /// drop) correctly with no `unsafe` code at all.
    struct CondvarWake(Arc<(std::sync::Mutex<bool>, std::sync::Condvar)>);

    impl Wake for CondvarWake {
        fn wake(self: Arc<Self>) {
            self.wake_by_ref();
        }

        fn wake_by_ref(self: &Arc<Self>) {
            let (lock, cvar) = &*self.0;
            // Set the flag while holding the lock so a waiter cannot miss
            // the notification between its flag check and `cvar.wait()`.
            let mut ready = lock.lock().unwrap();
            *ready = true;
            cvar.notify_one();
        }
    }

    Waker::from(Arc::new(CondvarWake(pair)))
}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{Error, ErrorKind};

    // Sender fires from another thread after a delay; block_on must park
    // until the value arrives (exercises the condvar slow path).
    #[test]
    fn send_then_block_on() {
        let (sender, future) = completion_channel::<String>();

        std::thread::spawn(move || {
            std::thread::sleep(std::time::Duration::from_millis(10));
            sender.send(Ok("hello".to_string()));
        });

        let result = future.block_on().unwrap();
        assert_eq!(result, "hello");
    }

    // An Err sent through the channel surfaces as the future's error.
    #[test]
    fn error_propagation() {
        let (sender, future) = completion_channel::<String>();

        std::thread::spawn(move || {
            sender.send(Err(Error::new(ErrorKind::ModelLoad, "test error")));
        });

        let err = future.block_on().unwrap_err();
        assert_eq!(err.kind(), &ErrorKind::ModelLoad);
    }

    #[test]
    fn immediate_value() {
        let (sender, future) = completion_channel::<i32>();
        // Value is set before block_on -- exercises the fast path.
        sender.send(Ok(42));
        assert_eq!(future.block_on().unwrap(), 42);
    }

    // Drives the future by hand through the Future trait -- no async
    // runtime involved, just a no-op waker and manual poll calls.
    #[test]
    fn poll_via_future_trait() {
        use std::task::{RawWaker, RawWakerVTable};

        // Minimal noop waker for manual polling.
        fn noop_waker() -> Waker {
            unsafe fn clone(_: *const ()) -> RawWaker {
                RawWaker::new(std::ptr::null(), &NOOP_VTABLE)
            }
            unsafe fn noop(_: *const ()) {}
            static NOOP_VTABLE: RawWakerVTable =
                RawWakerVTable::new(clone, noop, noop, noop);
            unsafe {
                Waker::from_raw(RawWaker::new(std::ptr::null(), &NOOP_VTABLE))
            }
        }

        let (sender, mut future) = completion_channel::<u64>();
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        // First poll: pending.
        let pinned = Pin::new(&mut future);
        assert!(pinned.poll(&mut cx).is_pending());

        // Send value.
        sender.send(Ok(99));

        // Second poll: ready.
        let pinned = Pin::new(&mut future);
        match pinned.poll(&mut cx) {
            Poll::Ready(Ok(v)) => assert_eq!(v, 99),
            other => panic!("expected Ready(Ok(99)), got {other:?}"),
        }
    }

    // Many channels completing at staggered times; catches cross-channel
    // interference and lost wakeups.
    #[test]
    fn concurrent_stress() {
        // Spawn many channels concurrently to test for races.
        let handles: Vec<_> = (0..50)
            .map(|i| {
                let (sender, future) = completion_channel::<i32>();
                let h = std::thread::spawn(move || {
                    std::thread::sleep(std::time::Duration::from_micros(i * 10));
                    sender.send(Ok(i as i32));
                });
                (h, future)
            })
            .collect();

        for (i, (handle, future)) in handles.into_iter().enumerate() {
            let val = future.block_on().unwrap();
            assert_eq!(val, i as i32);
            handle.join().unwrap();
        }
    }
}