// coreml_native/async_bridge.rs
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};

use crate::error::Result;

14struct Shared<T> {
16 value: Option<Result<T>>,
17 waker: Option<Waker>,
18}
19
20pub struct CompletionFuture<T> {
25 shared: Arc<Mutex<Shared<T>>>,
26}
27
28unsafe impl<T: Send> Send for CompletionFuture<T> {}
32
33impl<T: Send> Future for CompletionFuture<T> {
34 type Output = Result<T>;
35
36 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
37 let mut shared = self.shared.lock().unwrap();
38 if let Some(value) = shared.value.take() {
39 Poll::Ready(value)
40 } else {
41 shared.waker = Some(cx.waker().clone());
42 Poll::Pending
43 }
44 }
45}
46
47impl<T: Send> CompletionFuture<T> {
48 pub fn block_on(self) -> Result<T> {
53 {
55 let mut shared = self.shared.lock().unwrap();
56 if let Some(value) = shared.value.take() {
57 return value;
58 }
59 }
60
61 let pair = Arc::new((std::sync::Mutex::new(false), std::sync::Condvar::new()));
63 let pair_for_waker = pair.clone();
64
65 {
66 let mut shared = self.shared.lock().unwrap();
67 if let Some(value) = shared.value.take() {
69 return value;
70 }
71 let waker = condvar_waker(pair_for_waker);
72 shared.waker = Some(waker);
73 }
74
75 let (lock, cvar) = &*pair;
77 let mut ready = lock.lock().unwrap();
78 while !*ready {
79 ready = cvar.wait(ready).unwrap();
80 }
81
82 let mut shared = self.shared.lock().unwrap();
83 shared.value.take().expect("waker fired but no value was set")
84 }
85}
86
87pub(crate) fn completion_channel<T: Send>() -> (CompletionSender<T>, CompletionFuture<T>) {
92 let shared = Arc::new(Mutex::new(Shared {
93 value: None,
94 waker: None,
95 }));
96
97 let sender = CompletionSender {
98 shared: shared.clone(),
99 };
100
101 let future = CompletionFuture { shared };
102
103 (sender, future)
104}
105
106pub(crate) struct CompletionSender<T> {
111 shared: Arc<Mutex<Shared<T>>>,
112}
113
114unsafe impl<T: Send> Send for CompletionSender<T> {}
118unsafe impl<T: Send> Sync for CompletionSender<T> {}
119
120impl<T: Send> CompletionSender<T> {
121 pub fn send(self, value: Result<T>) {
123 let mut shared = self.shared.lock().unwrap();
124 shared.value = Some(value);
125 if let Some(waker) = shared.waker.take() {
126 waker.wake();
127 }
128 }
129}
130
/// Builds a [`Waker`] that, when woken, sets the flag in `pair` to `true` and
/// calls `notify_one` on its condition variable.
///
/// This is how [`CompletionFuture::block_on`] turns a task wake-up into a
/// condvar signal. It is implemented with the safe [`std::task::Wake`] trait
/// instead of a hand-written `RawWakerVTable`: cloning and dropping the waker
/// manage the `Arc` refcount automatically.
fn condvar_waker(pair: Arc<(std::sync::Mutex<bool>, std::sync::Condvar)>) -> Waker {
    /// Wake target that owns a handle to the flag/condvar pair.
    struct CondvarSignal(Arc<(std::sync::Mutex<bool>, std::sync::Condvar)>);

    impl Wake for CondvarSignal {
        fn wake(self: Arc<Self>) {
            Wake::wake_by_ref(&self);
        }

        fn wake_by_ref(self: &Arc<Self>) {
            let (lock, cvar) = &*self.0;
            let mut ready = lock.lock().unwrap();
            *ready = true;
            cvar.notify_one();
        }
    }

    Waker::from(Arc::new(CondvarSignal(pair)))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{Error, ErrorKind};
    use std::thread;
    use std::time::Duration;

    /// A value sent from another thread after a short delay reaches `block_on`.
    #[test]
    fn send_then_block_on() {
        let (tx, rx) = completion_channel::<String>();

        thread::spawn(move || {
            thread::sleep(Duration::from_millis(10));
            tx.send(Ok("hello".to_string()));
        });

        assert_eq!(rx.block_on().unwrap(), "hello");
    }

    /// Errors pass through the channel unchanged.
    #[test]
    fn error_propagation() {
        let (tx, rx) = completion_channel::<String>();

        thread::spawn(move || {
            tx.send(Err(Error::new(ErrorKind::ModelLoad, "test error")));
        });

        let failure = rx.block_on().unwrap_err();
        assert_eq!(failure.kind(), &ErrorKind::ModelLoad);
    }

    /// `block_on` returns at once when the value was sent beforehand.
    #[test]
    fn immediate_value() {
        let (tx, rx) = completion_channel::<i32>();
        tx.send(Ok(42));
        assert_eq!(rx.block_on().unwrap(), 42);
    }

    /// Driving the future manually through `Future::poll`.
    #[test]
    fn poll_via_future_trait() {
        use std::task::{RawWaker, RawWakerVTable};

        fn noop_waker() -> Waker {
            unsafe fn noop_clone(_data: *const ()) -> RawWaker {
                RawWaker::new(std::ptr::null(), &VTABLE)
            }
            unsafe fn do_nothing(_data: *const ()) {}
            static VTABLE: RawWakerVTable =
                RawWakerVTable::new(noop_clone, do_nothing, do_nothing, do_nothing);

            // SAFETY: every vtable entry ignores its data pointer, so null is valid.
            unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
        }

        let (tx, mut rx) = completion_channel::<u64>();
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        assert!(Pin::new(&mut rx).poll(&mut cx).is_pending());

        tx.send(Ok(99));

        match Pin::new(&mut rx).poll(&mut cx) {
            Poll::Ready(Ok(v)) => assert_eq!(v, 99),
            other => panic!("expected Ready(Ok(99)), got {other:?}"),
        }
    }

    /// Many independent channels completed from many threads.
    #[test]
    fn concurrent_stress() {
        let pending: Vec<_> = (0..50u64)
            .map(|i| {
                let (tx, rx) = completion_channel::<i32>();
                let worker = thread::spawn(move || {
                    thread::sleep(Duration::from_micros(i * 10));
                    tx.send(Ok(i as i32));
                });
                (worker, rx)
            })
            .collect();

        for (idx, (worker, rx)) in pending.into_iter().enumerate() {
            assert_eq!(rx.block_on().unwrap(), idx as i32);
            worker.join().unwrap();
        }
    }
}