use std::collections::HashMap;
use std::time::{Duration, Instant};
use anyhow::Result;
use tokio::sync::{mpsc, oneshot};
use tokio::time::interval;
use tracing::warn;
use uuid::Uuid;
pub mod cuda_event;
pub mod nixl_events;
pub mod nixl_status;
pub mod notification;
pub use cuda_event::CudaEventChecker;
pub use nixl_events::{RegisterNixlNotification, process_nixl_notification_events};
pub use nixl_status::NixlStatusChecker;
pub use notification::TransferCompleteNotification;
/// A pollable source of truth for whether a single transfer has finished.
///
/// Implementations are polled repeatedly by [`process_polling_notifications`]
/// (on a 1 ms interval while any transfer is outstanding).
/// NOTE(review): because of that polling cadence, `is_complete` is presumably
/// expected to be cheap and non-blocking — confirm against the implementors.
pub trait CompletionChecker: Send {
/// Reports the current state of the transfer.
///
/// As consumed by the polling loop in this module:
/// * `Ok(true)`  — the transfer is finished; the waiter is sent `Ok(())`.
/// * `Ok(false)` — the transfer is still pending; it is polled again later.
/// * `Err(e)`    — the status check failed; the waiter is sent `Err(e)` and
///   the transfer is no longer tracked.
fn is_complete(&self) -> Result<bool>;
}
/// Request sent to [`process_polling_notifications`] asking it to watch one
/// transfer and report its outcome.
pub struct RegisterPollingNotification<C: CompletionChecker> {
/// Identifier of the transfer; used as the key under which the polling loop
/// tracks it and as the `uuid` field in log output.
pub uuid: Uuid,
/// Checker that the polling loop queries to learn whether the transfer is done.
pub checker: C,
/// Channel on which the final outcome is delivered: `Ok(())` once the checker
/// reports completion, or the checker's error if a status check fails.
pub done: oneshot::Sender<Result<()>>,
}
/// Book-keeping for a transfer that has been registered with the polling loop
/// but has not yet completed or failed.
struct OutstandingPollingTransfer<C: CompletionChecker> {
/// Checker polled to determine whether the transfer has finished.
checker: C,
/// Channel used to deliver the final outcome to whoever registered the transfer.
done: oneshot::Sender<Result<()>>,
/// Moment the registration was received by the polling loop; the basis for
/// the "pending for over 1 minute" slow-transfer warning.
arrived_at: Instant,
/// When the slow-transfer warning was last emitted for this transfer, or
/// `None` if it has not been emitted yet. Used to rate-limit that warning.
last_warned_at: Option<Instant>,
}
/// Emits a rate-limited warning for a transfer that has been pending too long.
///
/// A warning is logged only when the transfer has been pending for more than
/// 60 seconds *and* either no warning has been logged yet or the previous one
/// is more than 30 seconds old.
///
/// # Parameters
/// * `uuid` — identifier of the transfer, attached to the log record.
/// * `arrived_at` — when the transfer was registered.
/// * `last_warned_at` — when this warning was last emitted, if ever.
///
/// # Returns
/// The timestamp to store as the new "last warned" value: `Some(now)` if a
/// warning was just logged, otherwise `last_warned_at` unchanged.
fn check_and_warn_slow_transfer(
    uuid: &Uuid,
    arrived_at: Instant,
    last_warned_at: Option<Instant>,
) -> Option<Instant> {
    // How long a transfer may be pending before it is considered slow.
    const SLOW_AFTER: Duration = Duration::from_secs(60);
    // Minimum spacing between two warnings for the same transfer.
    const REPEAT_AFTER: Duration = Duration::from_secs(30);

    let pending_for = arrived_at.elapsed();
    if pending_for <= SLOW_AFTER {
        return last_warned_at;
    }

    let warned_recently = match last_warned_at {
        Some(previous) => previous.elapsed() <= REPEAT_AFTER,
        None => false,
    };
    if warned_recently {
        return last_warned_at;
    }

    warn!(
        uuid = %uuid,
        elapsed_secs = pending_for.as_secs(),
        "Transfer has been pending for over 1 minute"
    );
    Some(Instant::now())
}
pub async fn process_polling_notifications<C: CompletionChecker>(
mut rx: mpsc::Receiver<RegisterPollingNotification<C>>,
) {
let mut outstanding: HashMap<Uuid, OutstandingPollingTransfer<C>> = HashMap::new();
let mut check_interval = interval(Duration::from_millis(1));
loop {
tokio::select! {
notification = rx.recv() => {
match notification {
Some(notif) => {
outstanding.insert(notif.uuid, OutstandingPollingTransfer {
checker: notif.checker,
done: notif.done,
arrived_at: Instant::now(),
last_warned_at: None,
});
}
None => {
break;
}
}
}
_ = check_interval.tick(), if !outstanding.is_empty() => {
let mut completed = Vec::new();
for (uuid, transfer) in outstanding.iter_mut() {
match transfer.checker.is_complete() {
Ok(true) => {
completed.push((*uuid, Ok(())));
}
Ok(false) => {
transfer.last_warned_at = check_and_warn_slow_transfer(
uuid,
transfer.arrived_at,
transfer.last_warned_at,
);
}
Err(e) => {
warn!(
uuid = %uuid,
error = %e,
"Transfer status check failed"
);
completed.push((*uuid, Err(e)));
}
}
}
for (uuid, result) in completed {
if let Some(transfer) = outstanding.remove(&uuid) {
let _ = transfer.done.send(result);
}
}
}
}
}
while !outstanding.is_empty() {
check_interval.tick().await;
let mut completed = Vec::new();
for (uuid, transfer) in outstanding.iter() {
match transfer.checker.is_complete() {
Ok(true) => {
completed.push((*uuid, Ok(())));
}
Ok(false) => {
}
Err(e) => {
warn!(
uuid = %uuid,
error = %e,
"Transfer status check failed during shutdown"
);
completed.push((*uuid, Err(e)));
}
}
}
for (uuid, result) in completed {
if let Some(transfer) = outstanding.remove(&uuid) {
let _ = transfer.done.send(result);
}
}
}
}