use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use tokio::sync::{
Notify,
watch::{Receiver, Sender, channel, error::SendError},
};
use super::status::MappingUid;
#[derive(Debug, Clone, thiserror::Error)]
#[error("completion handle dropped before signaling")]
pub struct CompletionDroppedError;
/// Producer side of the completion signal created by [`completion_pair`].
///
/// Cloning yields another producer for the same underlying watch channel.
#[derive(Clone, Debug)]
pub struct CompletionHandle(Sender<Option<MappingUid>>);

impl CompletionHandle {
    /// Publishes `mapping_uid` to every [`WaitHandle`] linked to this handle.
    ///
    /// # Errors
    ///
    /// Forwards the [`SendError`] from the underlying watch channel, which tokio
    /// reports when no receiver is left to observe the value.
    pub fn complete(&self, mapping_uid: MappingUid) -> Result<(), SendError<Option<MappingUid>>> {
        let Self(sender) = self;
        sender.send(Some(mapping_uid))
    }
}
/// Consumer side of the completion signal created by [`completion_pair`].
///
/// Every clone owns its own receiver, so several tasks can wait on the same
/// completion independently.
#[derive(Clone, Debug)]
pub struct WaitHandle(Receiver<Option<MappingUid>>);

impl WaitHandle {
    /// Waits until a mapping uid has been published and returns it.
    ///
    /// Returns immediately if a uid is already available.
    ///
    /// # Errors
    ///
    /// Returns [`CompletionDroppedError`] if the channel closes (every
    /// [`CompletionHandle`] dropped) before a uid was published.
    pub async fn wait(&mut self) -> Result<MappingUid, CompletionDroppedError> {
        self.0
            .wait_for(Option::is_some)
            .await
            .map_err(|_| CompletionDroppedError)
            // The predicate guarantees `Some`, so this `ok_or` is only a
            // defensive fallback.
            .and_then(|published| published.ok_or(CompletionDroppedError))
    }

    /// Returns the currently published mapping uid without waiting, or `None`
    /// if nothing has been published yet.
    #[must_use]
    pub fn try_get(&self) -> Option<MappingUid> {
        let Self(receiver) = self;
        *receiver.borrow()
    }
}
/// Creates a linked [`CompletionHandle`] / [`WaitHandle`] pair.
///
/// The shared slot starts out empty (`None`), so [`WaitHandle::try_get`]
/// returns `None` until [`CompletionHandle::complete`] is called.
#[must_use]
pub fn completion_pair() -> (CompletionHandle, WaitHandle) {
    let (tx, rx) = channel(None);
    let completion = CompletionHandle(tx);
    let waiter = WaitHandle(rx);
    (completion, waiter)
}
/// A counting barrier: tasks can wait until every completion handle it has
/// issued has been dropped.
///
/// Each [`DeletionCompletionHandle`] (obtained from
/// [`DeletionBarrier::completion_handle`] or by cloning one) counts as one
/// outstanding registration until it is dropped. [`DeletionWaitHandle::wait`]
/// resolves once the count is back to zero. Clones of the barrier share the same
/// counter.
#[derive(Clone, Debug)]
pub struct DeletionBarrier {
    inner: Arc<DeletionBarrierInner>,
}

/// Shared state behind a [`DeletionBarrier`] and every handle derived from it.
#[derive(Debug)]
struct DeletionBarrierInner {
    /// Number of live [`DeletionCompletionHandle`]s.
    count: AtomicUsize,
    /// Signalled when `count` drops to zero.
    notify: Notify,
}
impl DeletionBarrier {
    /// Creates a barrier with no outstanding completion handles.
    #[must_use]
    pub fn new() -> Self {
        let shared = DeletionBarrierInner {
            count: AtomicUsize::new(0),
            notify: Notify::new(),
        };
        Self {
            inner: Arc::new(shared),
        }
    }

    /// Registers one more outstanding completion and returns the handle that
    /// represents it.
    ///
    /// The registration is released when that handle is dropped. Clones of the
    /// handle register independently.
    #[must_use]
    pub fn completion_handle(&self) -> DeletionCompletionHandle {
        let shared = &self.inner;
        shared.count.fetch_add(1, Ordering::SeqCst);
        DeletionCompletionHandle {
            inner: Arc::clone(shared),
        }
    }

    /// Returns a handle that waits until no completion handles remain.
    #[must_use]
    pub fn wait_handle(&self) -> DeletionWaitHandle {
        let inner = Arc::clone(&self.inner);
        DeletionWaitHandle { inner }
    }
}
impl Default for DeletionBarrier {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug)]
pub struct DeletionCompletionHandle {
inner: Arc<DeletionBarrierInner>,
}
impl Clone for DeletionCompletionHandle {
fn clone(&self) -> Self {
self.inner.count.fetch_add(1, Ordering::SeqCst);
Self {
inner: Arc::clone(&self.inner),
}
}
}
impl Drop for DeletionCompletionHandle {
fn drop(&mut self) {
let prev = self.inner.count.fetch_sub(1, Ordering::SeqCst);
if prev == 1 {
self.inner.notify.notify_waiters();
}
}
}
/// Waits for a [`DeletionBarrier`] to have no outstanding completion handles.
#[derive(Debug, Clone)]
pub struct DeletionWaitHandle {
    inner: Arc<DeletionBarrierInner>,
}

impl DeletionWaitHandle {
    /// Resolves once no [`DeletionCompletionHandle`] for this barrier is alive.
    ///
    /// Returns immediately if the count is already zero.
    pub async fn wait(&self) {
        loop {
            // Create the `Notified` future *before* reading the counter.
            // `notify_waiters` stores no permit and only wakes `Notified`
            // futures that already exist. If the counter were checked first,
            // the last handle could be dropped in between the check and the
            // subscription. That notification would be lost and this task
            // would wait forever.
            let notified = self.inner.notify.notified();
            if self.inner.count.load(Ordering::SeqCst) == 0 {
                return;
            }
            // Re-check after waking: new handles may have been issued since
            // the notification was sent.
            notified.await;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_completion_handle_wakes_waiters() {
        let (completion, wait) = completion_pair();
        let mut wait_clone = wait.clone();
        let task = tokio::spawn(async move { wait_clone.wait().await });
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        completion.complete(MappingUid::new_for_test(0)).unwrap();
        let result = task.await.unwrap().unwrap();
        assert_eq!(result, MappingUid::new_for_test(0));
    }

    #[tokio::test]
    async fn test_try_get_returns_none_before_completion() {
        let (_completion, wait) = completion_pair();
        assert_eq!(wait.try_get(), None);
    }

    #[tokio::test]
    async fn test_wait_handle_is_cloneable() {
        let (completion, wait) = completion_pair();
        let wait1 = wait.clone();
        let wait2 = wait.clone();
        completion.complete(MappingUid::new_for_test(42)).unwrap();
        assert_eq!(wait1.try_get(), Some(MappingUid::new_for_test(42)));
        assert_eq!(wait2.try_get(), Some(MappingUid::new_for_test(42)));
    }

    /// Dropping the only completion handle without signaling must surface as an
    /// error rather than hanging.
    #[tokio::test]
    async fn test_wait_errors_when_completion_dropped_unsignaled() {
        let (completion, mut wait) = completion_pair();
        drop(completion);
        assert!(wait.wait().await.is_err());
    }

    /// A value published before the completion handle is dropped is still
    /// delivered to waiters.
    #[tokio::test]
    async fn test_wait_sees_value_published_before_drop() {
        let (completion, mut wait) = completion_pair();
        completion.complete(MappingUid::new_for_test(7)).unwrap();
        drop(completion);
        assert_eq!(wait.wait().await.unwrap(), MappingUid::new_for_test(7));
    }

    #[tokio::test]
    async fn test_deletion_wait_returns_immediately_without_handles() {
        let barrier = DeletionBarrier::new();
        let waiter = barrier.wait_handle();
        tokio::time::timeout(tokio::time::Duration::from_secs(1), waiter.wait())
            .await
            .expect("wait should not block when no completion handles exist");
    }

    #[tokio::test]
    async fn test_deletion_barrier_waits_for_every_handle() {
        let barrier = DeletionBarrier::new();
        let first = barrier.completion_handle();
        let second = first.clone();
        let waiter = barrier.wait_handle();
        let task = tokio::spawn(async move { waiter.wait().await });

        drop(first);
        tokio::task::yield_now().await;
        // `second` is still alive, so the barrier must not have released yet.
        assert!(!task.is_finished());

        drop(second);
        tokio::time::timeout(tokio::time::Duration::from_secs(1), task)
            .await
            .expect("barrier should release once every handle is dropped")
            .unwrap();
    }
}