Skip to main content

nexar/collective/
handle.rs

1use crate::error::Result;
2use std::future::Future;
3use tokio::task::JoinHandle;
4
5/// A handle to a non-blocking collective operation.
6///
7/// The collective runs asynchronously in a spawned task. Call `wait()` to
8/// block until it completes, or check `is_finished()` to poll.
9///
10/// If dropped without calling `wait()`, the background task is aborted to
11/// prevent writes to potentially-freed memory.
12pub struct CollectiveHandle {
13    inner: Option<JoinHandle<Result<()>>>,
14}
15
16impl CollectiveHandle {
17    /// Spawn a future as a non-blocking collective and return a handle.
18    pub(crate) fn spawn(fut: impl Future<Output = Result<()>> + Send + 'static) -> Self {
19        Self {
20            inner: Some(tokio::spawn(fut)),
21        }
22    }
23
24    /// Wait for the collective to complete and propagate any error.
25    pub async fn wait(mut self) -> Result<()> {
26        let handle = self
27            .inner
28            .take()
29            .expect("CollectiveHandle already consumed");
30        handle.await.map_err(|e| {
31            crate::error::NexarError::transport(format!("collective task panicked: {e}"))
32        })?
33    }
34
35    /// Check if the collective has finished (non-blocking).
36    pub fn is_finished(&self) -> bool {
37        self.inner.as_ref().is_none_or(|h| h.is_finished())
38    }
39}
40
41impl Drop for CollectiveHandle {
42    fn drop(&mut self) {
43        if let Some(handle) = &self.inner {
44            handle.abort();
45        }
46    }
47}
48
49/// A group of non-blocking collectives that can be waited on together.
50pub struct CollectiveGroup {
51    handles: Vec<CollectiveHandle>,
52}
53
54impl CollectiveGroup {
55    /// Create an empty group.
56    pub fn new() -> Self {
57        Self {
58            handles: Vec::new(),
59        }
60    }
61
62    /// Add a handle to the group.
63    pub fn push(&mut self, h: CollectiveHandle) {
64        self.handles.push(h);
65    }
66
67    /// Wait for all collectives in the group to complete.
68    ///
69    /// Returns the first error encountered, if any. All tasks are awaited
70    /// regardless of errors.
71    pub async fn wait_all(self) -> Result<()> {
72        let mut first_err = None;
73        for h in self.handles {
74            if let Err(e) = h.wait().await
75                && first_err.is_none()
76            {
77                first_err = Some(e);
78            }
79        }
80        match first_err {
81            Some(e) => Err(e),
82            None => Ok(()),
83        }
84    }
85}
86
87impl Default for CollectiveGroup {
88    fn default() -> Self {
89        Self::new()
90    }
91}