use super::Aborter;
use commonware_utils::sync::Mutex;
use std::{
mem,
sync::{Arc, Weak},
};
/// A node in a tree of abortable tasks.
///
/// Nodes are handed out as `Arc<Tree>`; all mutable state lives in
/// [`TreeInner`] behind a mutex.
pub(crate) struct Tree {
// Lock-protected state of this node (parent link, children, task, abort flag).
inner: Mutex<TreeInner>,
}
/// Mutable state of a [`Tree`] node, always accessed through the node's mutex.
struct TreeInner {
// Strong reference to the parent node (`None` for a root). It is never read
// in this file; presumably it exists only to keep ancestors alive while a
// descendant exists -- TODO confirm.
_parent: Option<Arc<Tree>>,
// Weak references to child nodes. Entries whose node has been dropped are
// pruned whenever a new child is added.
children: Vec<Weak<Tree>>,
// Aborter for the task registered on this node, if any.
task: Option<Aborter>,
// Set once the node has been aborted. Afterwards registrations are rejected
// and further abort requests are no-ops.
aborted: bool,
}
impl TreeInner {
    /// Creates the state for a node with the given parent link and initial
    /// abort flag. The node starts with no children and no registered task.
    const fn new(parent: Option<Arc<Tree>>, aborted: bool) -> Self {
        Self {
            aborted,
            task: None,
            children: Vec::new(),
            _parent: parent,
        }
    }

    /// Records `child` as a child of this node.
    ///
    /// Only a weak reference is stored. Links to children that have already
    /// been dropped are discarded first so the list does not grow unbounded.
    fn child(&mut self, child: &Arc<Tree>) {
        let link = Arc::downgrade(child);
        self.children.retain(|existing| existing.strong_count() != 0);
        self.children.push(link);
    }

    /// Stores `aborter` as this node's task.
    ///
    /// Returns the aborter back as `Err` when the node is already aborted so
    /// the caller can abort it outside the lock.
    ///
    /// # Panics
    ///
    /// Panics if the node is not aborted and a task is already registered.
    fn register(&mut self, aborter: Aborter) -> Result<(), Aborter> {
        if self.aborted {
            Err(aborter)
        } else {
            assert!(self.task.is_none(), "task already registered");
            self.task = Some(aborter);
            Ok(())
        }
    }

    /// Marks the node as aborted and hands its task and child links to the
    /// caller.
    ///
    /// Returns `None` if the node was already aborted; otherwise the node's
    /// task (if any) and its list of children, both removed from the node.
    fn abort(&mut self) -> Option<(Option<Aborter>, Vec<Weak<Tree>>)> {
        // Set the flag unconditionally; the previous value tells us whether
        // this call is the one that performed the transition.
        if mem::replace(&mut self.aborted, true) {
            return None;
        }
        let task = self.task.take();
        let children = mem::take(&mut self.children);
        Some((task, children))
    }
}
impl Tree {
    /// Creates a new root node: no parent, not aborted.
    pub(crate) fn root() -> Arc<Self> {
        let inner = Mutex::new(TreeInner::new(None, false));
        Arc::new(Self { inner })
    }

    /// Creates a new node whose parent is `parent`.
    ///
    /// Returns the new node together with a flag that is `true` when the
    /// parent was already aborted. In that case the child is created in the
    /// aborted state and is not added to the parent's list of children.
    pub(crate) fn child(parent: &Arc<Self>) -> (Arc<Self>, bool) {
        // The parent's lock is held for the whole function so that reading
        // its abort flag and linking the child happen atomically.
        let mut guard = parent.inner.lock();
        let aborted = guard.aborted;
        let state = TreeInner::new(Some(Arc::clone(parent)), aborted);
        let node = Arc::new(Self {
            inner: Mutex::new(state),
        });
        if !aborted {
            guard.child(&node);
        }
        (node, aborted)
    }

    /// Registers `aborter` as the task of this node.
    ///
    /// If the node has already been aborted, the aborter is aborted
    /// immediately instead of being stored.
    ///
    /// # Panics
    ///
    /// Panics if the node is not aborted and already has a registered task.
    pub(crate) fn register(self: &Arc<Self>, aborter: Aborter) {
        // The lock guard is a temporary of this statement, so it is released
        // before the rejected aborter (if any) is invoked below.
        let outcome = self.inner.lock().register(aborter);
        if let Err(rejected) = outcome {
            rejected.abort();
        }
    }

    /// Aborts this node's task and, recursively, every child that is still
    /// alive. Does nothing if the node was already aborted.
    pub(crate) fn abort(self: &Arc<Self>) {
        // Detach the task and children under the lock, then release it (the
        // guard is a temporary of this statement) before running any aborter.
        let detached = self.inner.lock().abort();
        if let Some((task, children)) = detached {
            if let Some(aborter) = task {
                aborter.abort();
            }
            // Children that were dropped in the meantime fail to upgrade and
            // are skipped.
            children
                .into_iter()
                .filter_map(|weak| weak.upgrade())
                .for_each(|node| node.abort());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::MetricHandle;
    use futures::future::{pending, AbortHandle, Abortable};
    use prometheus_client::metrics::gauge::Gauge;

    /// Builds an `Aborter` paired with a never-completing abortable future
    /// whose `is_aborted()` reports whether the aborter was triggered.
    fn aborter() -> (Aborter, Abortable<futures::future::Pending<()>>) {
        let (handle, registration) = AbortHandle::new_pair();
        let future = Abortable::new(pending::<()>(), registration);
        let metric = MetricHandle::new(Gauge::default());
        (Aborter::new(handle, metric), future)
    }

    /// Aborting a node aborts its own task and the task of its child.
    #[test]
    fn abort_cascades_to_children() {
        let root = Tree::root();

        let (parent, parent_was_aborted) = Tree::child(&root);
        assert!(!parent_was_aborted, "parent node unexpectedly aborted");
        let (parent_task, parent_future) = aborter();
        parent.register(parent_task);

        let (child, child_was_aborted) = Tree::child(&parent);
        assert!(!child_was_aborted, "child node unexpectedly aborted");
        let (child_task, child_future) = aborter();
        child.register(child_task);

        parent.abort();

        assert!(parent_future.is_aborted(), "parent was not aborted");
        assert!(child_future.is_aborted(), "child was not aborted");
    }

    /// Aborting one child of a node leaves the other child's task untouched.
    #[test]
    fn idle_child_survives_descendant_abort() {
        let root = Tree::root();

        let (parent, parent_was_aborted) = Tree::child(&root);
        assert!(!parent_was_aborted, "parent node unexpectedly aborted");

        let (first, first_was_aborted) = Tree::child(&parent);
        assert!(!first_was_aborted, "child1 node unexpectedly aborted");
        let (second, second_was_aborted) = Tree::child(&parent);
        assert!(!second_was_aborted, "child2 node unexpectedly aborted");

        let (first_task, first_future) = aborter();
        first.register(first_task);
        let (second_task, second_future) = aborter();
        second.register(second_task);

        second.abort();

        assert!(second_future.is_aborted(), "child2 was not aborted");
        assert!(
            !first_future.is_aborted(),
            "child1 was aborted by descendant task"
        );
    }
}