use std::collections::HashMap;
use std::time::{Duration, Instant};
use anyhow::Result;
use nixl_sys::{Agent as NixlAgent, NotificationMap, XferRequest};
use tokio::sync::{mpsc, oneshot};
use tokio::time::interval;
use tracing::warn;
use uuid::Uuid;
/// Registration message asking the notification processor to track one transfer.
///
/// Sent over the `mpsc` channel consumed by `process_nixl_notification_events`.
pub struct RegisterNixlNotification {
    /// Identifier of the transfer; the processor matches incoming notification
    /// payloads against this value after parsing them as UUIDs.
    pub uuid: Uuid,
    /// The transfer request associated with `uuid`. The processor stores it
    /// alongside the registration until the transfer is resolved.
    pub xfer_req: XferRequest,
    /// Completion channel; receives `Ok(())` once a notification carrying
    /// `uuid` has been observed.
    pub done: oneshot::Sender<Result<()>>,
}
/// Bookkeeping for a registered transfer whose completion notification has not
/// been seen yet.
struct OutstandingTransfer {
    // Never read after being stored, hence the lint override.
    // NOTE(review): presumably held so the request stays alive until the
    // transfer completes — confirm against `XferRequest`'s drop semantics.
    #[allow(dead_code)]
    xfer_req: XferRequest,
    /// Signalled with `Ok(())` when the matching notification arrives.
    done: oneshot::Sender<Result<()>>,
    /// When the registration was received by the processor.
    arrived_at: Instant,
    /// When the most recent slow-transfer warning was logged, if any.
    last_warned_at: Option<Instant>,
}
/// Logs a rate-limited warning for a transfer that has been pending too long.
///
/// A warning is emitted when more than 60 seconds have passed since
/// `arrived_at` and either no warning has been logged before or the previous
/// one is more than 30 seconds old.
///
/// Returns the value to store as the transfer's `last_warned_at`: the current
/// instant if this call logged a warning, otherwise the `last_warned_at` that
/// was passed in, unchanged.
fn check_and_warn_slow_transfer(
    uuid: &Uuid,
    arrived_at: Instant,
    last_warned_at: Option<Instant>,
) -> Option<Instant> {
    const SLOW_THRESHOLD: Duration = Duration::from_secs(60);
    const REPEAT_AFTER: Duration = Duration::from_secs(30);

    let pending_for = arrived_at.elapsed();
    if pending_for <= SLOW_THRESHOLD {
        return last_warned_at;
    }

    // A transfer that has never been warned about is always due for one.
    let warned_recently = match last_warned_at {
        Some(previous) => previous.elapsed() <= REPEAT_AFTER,
        None => false,
    };
    if warned_recently {
        return last_warned_at;
    }

    warn!(
        uuid = %uuid,
        elapsed_secs = pending_for.as_secs(),
        "Transfer has been pending for over 1 minute"
    );
    Some(Instant::now())
}
pub async fn process_nixl_notification_events(
agent: NixlAgent,
mut rx: mpsc::Receiver<RegisterNixlNotification>,
) {
let mut outstanding: HashMap<Uuid, OutstandingTransfer> = HashMap::new();
let mut check_interval = interval(Duration::from_millis(1));
loop {
tokio::select! {
notification = rx.recv() => {
match notification {
Some(notif) => {
outstanding.insert(notif.uuid, OutstandingTransfer {
xfer_req: notif.xfer_req,
done: notif.done,
arrived_at: Instant::now(),
last_warned_at: None,
});
}
None => {
break;
}
}
}
_ = check_interval.tick(), if !outstanding.is_empty() => {
let mut notif_map = match NotificationMap::new() {
Ok(map) => map,
Err(e) => {
warn!(error = %e, "Failed to create notification map");
continue;
}
};
if let Err(e) = agent.get_notifications(&mut notif_map, None) {
warn!(error = %e, "Failed to fetch NIXL notifications");
continue;
}
let notifications = match notif_map.take_notifs() {
Ok(notifs) => notifs,
Err(e) => {
warn!(error = %e, "Failed to extract notifications from map");
continue;
}
};
let mut completed = Vec::new();
for (_agent_name, notif_strings) in notifications {
for notif_str in notif_strings {
if let Ok(notif_uuid) = Uuid::parse_str(¬if_str) {
if outstanding.contains_key(¬if_uuid) {
completed.push(notif_uuid);
} else {
warn!(
uuid = %notif_uuid,
"Received notification for transfer not in outstanding map (early arrival)"
);
}
}
}
}
for (uuid, transfer) in outstanding.iter_mut() {
if !completed.contains(uuid) {
transfer.last_warned_at = check_and_warn_slow_transfer(
uuid,
transfer.arrived_at,
transfer.last_warned_at,
);
}
}
for uuid in completed {
if let Some(transfer) = outstanding.remove(&uuid) {
let _ = transfer.done.send(Ok(()));
}
}
}
}
}
while !outstanding.is_empty() {
check_interval.tick().await;
let mut notif_map = match NotificationMap::new() {
Ok(map) => map,
Err(_) => continue,
};
if let Ok(()) = agent.get_notifications(&mut notif_map, None)
&& let Ok(notifications) = notif_map.take_notifs()
{
let mut completed = Vec::new();
for (_agent_name, notif_strings) in notifications {
for notif_str in notif_strings {
if let Ok(notif_uuid) = Uuid::parse_str(¬if_str)
&& outstanding.contains_key(¬if_uuid)
{
completed.push(notif_uuid);
}
}
}
for uuid in completed {
if let Some(transfer) = outstanding.remove(&uuid) {
let _ = transfer.done.send(Ok(()));
}
}
}
}
}