use std::ffi::{c_void, CStr};
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Condvar, Mutex};
use std::task::{Context, Poll, Waker};
/// State shared between a [`SyncCompletion`] handle and its context pointer.
struct SyncCompletionState<T> {
    /// Set together with `result` by the completing side; `wait` blocks until it is true.
    completed: bool,
    /// The delivered outcome; moved out by `wait`.
    result: Option<Result<T, String>>,
}

/// A one-shot, blocking completion used to turn a callback-style API into a
/// synchronous call.
///
/// [`SyncCompletion::new`] returns the handle plus an opaque context pointer. The
/// pointer is handed to the completing side, which passes it to exactly one of
/// [`SyncCompletion::complete_ok`], [`SyncCompletion::complete_err`] or
/// [`SyncCompletion::complete_with_result`]. [`SyncCompletion::wait`] blocks the
/// calling thread until that happens.
pub struct SyncCompletion<T> {
    inner: Arc<(Mutex<SyncCompletionState<T>>, Condvar)>,
}

/// Opaque context pointer. Each pointer produced by a constructor owns one strong
/// reference to the shared state, which is released when a `complete_*` call
/// consumes it.
pub type SyncCompletionPtr = *mut c_void;

impl<T> SyncCompletion<T> {
    /// Creates a pending completion and the context pointer that completes it.
    ///
    /// The returned pointer carries its own strong reference to the shared state.
    /// It must eventually be passed to exactly one `complete_*` function of
    /// `SyncCompletion<T>` (same `T`). If it never is, the state leaks and
    /// [`SyncCompletion::wait`] blocks forever.
    #[must_use]
    pub fn new() -> (Self, SyncCompletionPtr) {
        let inner = Arc::new((
            Mutex::new(SyncCompletionState {
                completed: false,
                result: None,
            }),
            Condvar::new(),
        ));
        // Hand a second strong reference to the completing side as a raw pointer.
        let raw = Arc::into_raw(Arc::clone(&inner));
        (Self { inner }, raw as SyncCompletionPtr)
    }

    /// Blocks the current thread until the completion is signalled, then returns
    /// the delivered result.
    ///
    /// This blocks forever if the context pointer is never completed.
    ///
    /// # Errors
    ///
    /// Returns the `Err` delivered by the completing side. It also returns an
    /// error if the state was marked completed without a stored result.
    pub fn wait(self) -> Result<T, String> {
        let (lock, cvar) = &*self.inner;
        // The protected state is plain data, so a poisoned lock is still usable.
        let guard = lock.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
        let mut state = cvar
            .wait_while(guard, |s| !s.completed)
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        state
            .result
            .take()
            .unwrap_or_else(|| Err("Completion signaled without result".to_string()))
    }

    /// Completes the completion behind `context` with `Ok(value)`.
    ///
    /// # Safety
    ///
    /// Same contract as [`SyncCompletion::complete_with_result`].
    pub unsafe fn complete_ok(context: SyncCompletionPtr, value: T) {
        // SAFETY: the caller upholds the `complete_with_result` contract.
        unsafe { Self::complete_with_result(context, Ok(value)) };
    }

    /// Completes the completion behind `context` with `Err(error)`.
    ///
    /// # Safety
    ///
    /// Same contract as [`SyncCompletion::complete_with_result`].
    pub unsafe fn complete_err(context: SyncCompletionPtr, error: String) {
        // SAFETY: the caller upholds the `complete_with_result` contract.
        unsafe { Self::complete_with_result(context, Err(error)) };
    }

    /// Stores `result` and wakes the thread blocked in [`SyncCompletion::wait`].
    ///
    /// A null `context` is ignored.
    ///
    /// # Safety
    ///
    /// `context` must be null or a pointer returned by [`SyncCompletion::new`] for
    /// this same `T` that has not already been passed to a `complete_*` function.
    /// This call consumes the strong reference the pointer carries, so completing
    /// the same pointer twice is undefined behavior.
    pub unsafe fn complete_with_result(context: SyncCompletionPtr, result: Result<T, String>) {
        if context.is_null() {
            return;
        }
        // SAFETY: per the caller contract, `context` came from `Arc::into_raw` in
        // `new` with exactly this pointee type and has not been reclaimed yet.
        let inner = unsafe {
            Arc::from_raw(context.cast::<(Mutex<SyncCompletionState<T>>, Condvar)>())
        };
        let (lock, cvar) = &*inner;
        {
            let mut state = lock.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
            state.completed = true;
            state.result = Some(result);
        }
        cvar.notify_one();
        // `inner` drops here, releasing the reference the context pointer owned.
    }
}

impl<T> Default for SyncCompletion<T> {
    /// Creates a completion whose context pointer is released immediately.
    ///
    /// Nothing can ever complete the returned value, so calling
    /// [`SyncCompletion::wait`] on it blocks forever.
    fn default() -> Self {
        let (completion, context) = Self::new();
        // SAFETY: `context` was produced by `Arc::into_raw` in `new` just above,
        // with this exact pointee type, and was never shared. Reclaiming it here is
        // its only release; without this the extra strong reference would leak.
        drop(unsafe {
            Arc::from_raw(context.cast::<(Mutex<SyncCompletionState<T>>, Condvar)>())
        });
        completion
    }
}
/// State shared between an [`AsyncCompletionFuture`] and its context pointer.
struct AsyncCompletionState<T> {
    /// Outcome delivered by the completing side; moved out by the first ready poll.
    result: Option<Result<T, String>>,
    /// Waker registered by the most recent pending poll.
    waker: Option<Waker>,
}

/// Holder for the associated functions that create and complete async completions.
///
/// It stores no data; `T` is carried only through `PhantomData`.
pub struct AsyncCompletion<T> {
    _marker: std::marker::PhantomData<T>,
}

/// Future that resolves once its matching context pointer is completed.
pub struct AsyncCompletionFuture<T> {
    inner: Arc<Mutex<AsyncCompletionState<T>>>,
}

impl<T> AsyncCompletion<T> {
    /// Creates a pending future and the context pointer that completes it.
    ///
    /// The pointer carries its own strong reference to the shared state. It must
    /// eventually be passed to exactly one `complete_*` function of
    /// `AsyncCompletion<T>` (same `T`). If it never is, the state leaks and the
    /// future never resolves.
    #[must_use]
    pub fn create() -> (AsyncCompletionFuture<T>, SyncCompletionPtr) {
        let inner = Arc::new(Mutex::new(AsyncCompletionState {
            result: None,
            waker: None,
        }));
        // Hand a second strong reference to the completing side as a raw pointer.
        let raw = Arc::into_raw(Arc::clone(&inner));
        (AsyncCompletionFuture { inner }, raw as SyncCompletionPtr)
    }

    /// Completes the future behind `context` with `Ok(value)`.
    ///
    /// # Safety
    ///
    /// Same contract as [`AsyncCompletion::complete_with_result`].
    pub unsafe fn complete_ok(context: SyncCompletionPtr, value: T) {
        // SAFETY: the caller upholds the `complete_with_result` contract.
        unsafe { Self::complete_with_result(context, Ok(value)) };
    }

    /// Completes the future behind `context` with `Err(error)`.
    ///
    /// # Safety
    ///
    /// Same contract as [`AsyncCompletion::complete_with_result`].
    pub unsafe fn complete_err(context: SyncCompletionPtr, error: String) {
        // SAFETY: the caller upholds the `complete_with_result` contract.
        unsafe { Self::complete_with_result(context, Err(error)) };
    }

    /// Stores `result` and wakes the task last registered by a pending poll.
    ///
    /// A null `context` is ignored.
    ///
    /// # Safety
    ///
    /// `context` must be null or a pointer returned by [`AsyncCompletion::create`]
    /// for this same `T` that has not already been passed to a `complete_*`
    /// function. This call consumes the strong reference the pointer carries, so
    /// completing the same pointer twice is undefined behavior.
    pub unsafe fn complete_with_result(context: SyncCompletionPtr, result: Result<T, String>) {
        if context.is_null() {
            return;
        }
        // SAFETY: per the caller contract, `context` came from `Arc::into_raw` in
        // `create` with exactly this pointee type and has not been reclaimed yet.
        let inner = unsafe { Arc::from_raw(context.cast::<Mutex<AsyncCompletionState<T>>>()) };
        let waker = {
            // The protected state is plain data, so a poisoned lock is still usable.
            let mut state = inner.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
            state.result = Some(result);
            state.waker.take()
        };
        // Wake only after the guard is released: the woken task may poll this
        // future, and take the same lock, right away.
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}

impl<T> Future for AsyncCompletionFuture<T> {
    type Output = Result<T, String>;

    /// Returns the delivered result if one is present. Otherwise it registers the
    /// current task's waker and stays pending.
    ///
    /// Polling again after `Ready` has been returned stays pending.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut state = self.inner.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some(result) = state.result.take() {
            return Poll::Ready(result);
        }
        // Skip the clone when the registered waker already wakes this task.
        let current = cx.waker();
        let needs_update = match &state.waker {
            Some(existing) => !existing.will_wake(current),
            None => true,
        };
        if needs_update {
            state.waker = Some(current.clone());
        }
        Poll::Pending
    }
}
/// Converts a C error message into an owned Rust `String`.
///
/// A null `msg` yields `"Unknown error"`. Bytes that are not valid UTF-8 are
/// replaced with `U+FFFD`, so the rest of the message is kept.
///
/// # Safety
///
/// `msg` must be null or point to a NUL-terminated C string that stays valid and
/// unmodified for the duration of this call.
#[must_use]
pub unsafe fn error_from_cstr(msg: *const i8) -> String {
    if msg.is_null() {
        return "Unknown error".to_string();
    }
    // `CStr::from_ptr` takes `*const c_char`, which is not `i8` on every target
    // (it is `u8` on e.g. aarch64 Linux); `cast` adapts to whichever it is.
    // SAFETY: `msg` is non-null and, per the caller contract, a valid C string.
    let c_str = unsafe { CStr::from_ptr(msg.cast()) };
    c_str.to_string_lossy().into_owned()
}
/// Blocking completion for operations that report only success or failure.
pub type UnitCompletion = SyncCompletion<()>;

impl UnitCompletion {
    /// C-ABI entry point that completes the `UnitCompletion` behind `context`.
    ///
    /// When `success` is true the completion resolves to `Ok(())`. Otherwise it
    /// resolves to `Err` with the message read from `msg` by `error_from_cstr`.
    ///
    /// `context` must be null or an uncompleted pointer from
    /// `SyncCompletion::<()>::new`. `msg` must be null or a valid NUL-terminated
    /// string.
    // NOTE(review): kept as a safe `extern "C" fn`, presumably so it matches a safe
    // foreign callback pointer type. The pointer contract above is therefore upheld
    // by the foreign caller, which is why the clippy lint is silenced.
    #[allow(clippy::not_unsafe_ptr_arg_deref)]
    pub extern "C" fn callback(context: *mut c_void, success: bool, msg: *const i8) {
        let outcome = if success {
            Ok(())
        } else {
            Err(unsafe { error_from_cstr(msg) })
        };
        unsafe { Self::complete_with_result(context, outcome) };
    }
}