use crate::error::Result;
use std::future::Future;
use tokio::task::JoinHandle;
/// Owning handle to a collective operation running on a spawned tokio task.
///
/// The task is aborted when this handle is dropped without having been
/// driven to completion through [`CollectiveHandle::wait`], so discarding
/// the handle cancels the operation.
#[must_use = "dropping a CollectiveHandle aborts its task; call `.wait().await` to run it to completion"]
pub struct CollectiveHandle {
    // `Some` while this handle still owns the task; cleared only once the task
    // has finished, so `Drop` knows whether there is anything left to abort.
    inner: Option<JoinHandle<Result<()>>>,
}
impl CollectiveHandle {
pub(crate) fn spawn(fut: impl Future<Output = Result<()>> + Send + 'static) -> Self {
Self {
inner: Some(tokio::spawn(fut)),
}
}
pub async fn wait(mut self) -> Result<()> {
let handle = self
.inner
.take()
.expect("CollectiveHandle already consumed");
handle.await.map_err(|e| {
crate::error::NexarError::transport(format!("collective task panicked: {e}"))
})?
}
pub fn is_finished(&self) -> bool {
self.inner.as_ref().is_none_or(|h| h.is_finished())
}
}
/// Aborts the spawned task if this handle still owns one, so a dropped
/// `CollectiveHandle` never leaves its collective running unobserved.
impl Drop for CollectiveHandle {
    fn drop(&mut self) {
        let Some(task) = self.inner.as_ref() else {
            // Nothing is held; the task has already been reaped.
            return;
        };
        task.abort();
    }
}
/// A set of [`CollectiveHandle`]s that are awaited together with `wait_all`.
///
/// Dropping the group drops every handle it still holds. Because of
/// `CollectiveHandle`'s `Drop` impl, this aborts any of those tasks that are
/// still running.
pub struct CollectiveGroup {
// Handles in the order they were pushed; `wait_all` awaits them in this order.
handles: Vec<CollectiveHandle>,
}
impl CollectiveGroup {
    /// Creates a group holding no handles.
    pub fn new() -> Self {
        let handles = Vec::new();
        Self { handles }
    }

    /// Appends a handle; it will be awaited by `wait_all` after all earlier ones.
    pub fn push(&mut self, h: CollectiveHandle) {
        let handles = &mut self.handles;
        handles.push(h);
    }

    /// Waits for every handle in the group, in push order.
    ///
    /// All handles are awaited even after one fails. Stopping early would drop
    /// the remaining handles, which aborts their tasks.
    ///
    /// # Errors
    ///
    /// Returns the first error in push order (not completion order). Later
    /// errors are discarded.
    pub async fn wait_all(self) -> Result<()> {
        let mut outcome = Ok(());
        for handle in self.handles {
            let result = handle.wait().await;
            // Keep only the earliest failure; once `outcome` is an error it is frozen.
            if outcome.is_ok() {
                outcome = result;
            }
        }
        outcome
    }
}
impl Default for CollectiveGroup {
fn default() -> Self {
Self::new()
}
}