use super::*;
use anyhow::Result;
use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList, XferOp};
use std::ops::Range;
pub fn write_block_to<'a, Source, Destination>(
src: &'a Source,
dst: &'a mut Destination,
ctx: &TransferContext,
notify: Option<String>,
) -> Result<()>
where
Source: BlockDataProvider,
Destination: BlockDataProviderMut,
{
let src_data = src.block_data(private::PrivateToken);
let dst_data = dst.block_data_mut(private::PrivateToken);
if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() {
let nixl_agent = ctx.nixl_agent().expect("NIXL agent not found");
let remote_worker_id = dst_data.worker_id.to_string();
let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;
let src_desc = src_data.block_view()?.as_nixl_descriptor();
let dst_desc = dst_data.block_view_mut()?.as_nixl_descriptor_mut();
unsafe {
src_dl.add_desc(
src_desc.as_ptr() as usize,
src_desc.size(),
src_desc.device_id(),
)?;
dst_dl.add_desc(
dst_desc.as_ptr() as usize,
dst_desc.size(),
dst_desc.device_id(),
)?;
}
let xfer_req =
nixl_agent.create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &remote_worker_id, None)?;
let mut xfer_args = OptArgs::new()?;
if let Some(notify) = notify {
xfer_args.set_has_notification(true)?;
xfer_args.set_notification_message(notify.as_bytes())?;
}
let mut status = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
tracing::span!(tracing::Level::DEBUG, "Waiting for transfer to complete").in_scope(|| {
while status {
status = nixl_agent.get_xfer_status(&xfer_req).unwrap();
}
});
} else {
assert_eq!(src_data.num_layers(), dst_data.num_layers());
write_layers_to(0..src_data.num_layers(), src, dst, ctx, notify)?;
}
Ok(())
}
/// Writes the layers in `layer_range` from `src` into the same layer indices of `dst` as a
/// single NIXL `Write` transfer, blocking until the agent no longer reports the transfer as
/// in progress.
///
/// One source/destination descriptor pair is added for each layer index in `layer_range`.
///
/// # Arguments
/// * `layer_range` - layer indices to copy; index `i` of `src` is written to index `i` of `dst`.
/// * `src` - block whose layers are read.
/// * `dst` - block whose layers are written; its `worker_id` is passed as the remote
///   identifier to `create_xfer_req`.
/// * `ctx` - transfer context that supplies the NIXL agent.
/// * `notify` - optional message; when present it is attached to the transfer options as a
///   notification message.
///
/// # Errors
/// Returns an error if `layer_view`/`layer_view_mut` fails for an index in `layer_range`, or
/// if any NIXL call fails: descriptor-list creation, descriptor registration, option setup,
/// request creation, posting, or status polling.
///
/// # Panics
/// Panics if `ctx` has no NIXL agent. In debug builds, also panics if a source layer and its
/// destination layer differ in size.
pub fn write_layers_to<'a, Source, Destination>(
    layer_range: Range<usize>,
    src: &'a Source,
    dst: &'a mut Destination,
    ctx: &TransferContext,
    notify: Option<String>,
) -> Result<()>
where
    Source: BlockDataProvider,
    Destination: BlockDataProviderMut,
{
    let src_data = src.block_data(private::PrivateToken);
    let dst_data = dst.block_data_mut(private::PrivateToken);

    let nixl_agent = ctx.nixl_agent().expect("NIXL agent not found");
    let remote_worker_id = dst_data.worker_id.to_string();

    let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
    let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;

    for layer_idx in layer_range {
        let src_view = src_data.layer_view(layer_idx)?;
        let mut dst_view = dst_data.layer_view_mut(layer_idx)?;
        debug_assert_eq!(src_view.size(), dst_view.size());

        let src_desc = src_view.as_nixl_descriptor();
        let dst_desc = dst_view.as_nixl_descriptor_mut();

        // SAFETY: the addresses come from layer views of `src`/`dst`, which are borrowed for
        // the whole call. This function does not return until the polling loop below observes
        // the transfer as no longer in progress. NOTE(review): assumes `add_desc` only
        // requires the region to stay valid for that long — confirm against nixl_sys.
        unsafe {
            src_dl.add_desc(
                src_desc.as_ptr() as usize,
                src_desc.size(),
                src_desc.device_id(),
            )?;
            dst_dl.add_desc(
                dst_desc.as_ptr() as usize,
                dst_desc.size(),
                dst_desc.device_id(),
            )?;
        }
    }

    // The same options (including any notification) go to both request creation and posting.
    let mut xfer_args = OptArgs::new()?;
    if let Some(notify) = notify {
        xfer_args.set_has_notification(true)?;
        xfer_args.set_notification_message(notify.as_bytes())?;
    }

    let xfer_req = nixl_agent.create_xfer_req(
        XferOp::Write,
        &src_dl,
        &dst_dl,
        &remote_worker_id,
        Some(&xfer_args),
    )?;

    // `true` from post/status means the transfer is still running.
    let mut in_progress = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
    tracing::span!(tracing::Level::DEBUG, "Waiting for transfer to complete").in_scope(
        || -> Result<()> {
            // NOTE(review): busy-polls with no backoff; each iteration queries the agent.
            while in_progress {
                in_progress = nixl_agent.get_xfer_status(&xfer_req)?;
            }
            Ok(())
        },
    )?;

    Ok(())
}