mod handle;
mod local;
mod metadata;
mod remote;
pub use handle::LayoutHandle;
pub use metadata::{SerializedLayout, WorkerAddress};
pub(crate) use local::LocalLayout;
pub(crate) use metadata::LocalLayoutDescriptor;
pub(crate) use remote::RemoteLayout;
use crate::block_manager::v2::memory::StorageKind;
use crate::block_manager::v2::physical::layout::PhysicalLayout;
use crate::block_manager::v2::physical::transfer::TransferContext;
use crate::block_manager::v2::physical::transfer::context::TransferCompleteNotification;
use crate::block_manager::v2::physical::transfer::nixl_agent::NixlAgent;
use crate::block_manager::v2::physical::transfer::options::TransferOptions;
use anyhow::{Result, anyhow, bail};
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU16, Ordering};
use std::sync::{Arc, RwLock};
/// Front end for registering local layouts, exchanging layout metadata with
/// remote workers, and issuing transfers between registered layouts.
///
/// Cloning is cheap: both fields are `Arc`s, so every clone shares the same
/// registry and transfer context.
#[derive(Clone)]
pub struct TransportManager {
    /// Local and imported remote layouts keyed by handle. Registration and
    /// import take the write lock; transfer lookups and handle listing take
    /// the read lock.
    registry: Arc<RwLock<LayoutRegistry>>,
    /// Transfer context supplying the worker id and NIXL agent, and handed to
    /// the transfer executor for every transfer.
    context: Arc<TransferContext>,
}
impl TransportManager {
    /// Starts building the transfer configuration that a `TransportManager`
    /// is created from (delegates to [`TransferContext::builder`]).
    pub fn builder() -> crate::block_manager::v2::physical::transfer::context::TransferConfigBuilder
    {
        TransferContext::builder()
    }

    /// Wraps an already-built transfer context, creating an empty layout
    /// registry that shares the context's NIXL agent and worker id.
    pub(crate) fn from_context(context: TransferContext) -> Self {
        let id = context.worker_id();
        let agent = context.nixl_agent().clone();
        let shared_context = Arc::new(context);
        Self {
            registry: Arc::new(RwLock::new(LayoutRegistry::new(agent, id))),
            context: shared_context,
        }
    }

    /// Registers a local physical layout and returns the handle assigned to it.
    pub fn register_layout(&self, layout: PhysicalLayout) -> Result<LayoutHandle> {
        let mut guard = self.registry.write().unwrap();
        guard.register_local(layout)
    }

    /// Serializes this worker's NIXL metadata plus its exportable local
    /// layouts so a remote worker can import them.
    pub fn export_metadata(&self) -> Result<SerializedLayout> {
        let guard = self.registry.read().unwrap();
        guard.export_metadata()
    }

    /// Imports metadata exported by a remote worker and returns the handles
    /// of the layouts it described.
    pub fn import_metadata(&self, metadata: SerializedLayout) -> Result<Vec<LayoutHandle>> {
        let mut guard = self.registry.write().unwrap();
        guard.import_metadata(metadata)
    }

    /// Transfers `src_blocks` of the layout behind `src_handle` into
    /// `dst_blocks` of the layout behind `dst_handle`.
    ///
    /// # Errors
    ///
    /// Fails if either handle is unknown to the registry (the source is
    /// checked first), or if the transfer executor reports an error.
    pub fn execute_transfer(
        &self,
        src_handle: LayoutHandle,
        src_blocks: &[usize],
        dst_handle: LayoutHandle,
        dst_blocks: &[usize],
        options: TransferOptions,
    ) -> Result<TransferCompleteNotification> {
        // Copy both layouts out so the registry lock is released before the
        // transfer is issued.
        let guard = self.registry.read().unwrap();
        let source = guard
            .get_layout(src_handle)
            .cloned()
            .ok_or_else(|| anyhow!("invalid source handle: {}", src_handle))?;
        let destination = guard
            .get_layout(dst_handle)
            .cloned()
            .ok_or_else(|| anyhow!("invalid destination handle: {}", dst_handle))?;
        drop(guard);

        super::transfer::executor::execute_transfer(
            &source,
            &destination,
            src_blocks,
            dst_blocks,
            options,
            &self.context,
        )
    }

    /// Worker id of the underlying transfer context.
    pub fn worker_id(&self) -> u64 {
        self.context.worker_id()
    }

    /// Handles of every locally registered layout (in unspecified order).
    pub fn get_local_handles(&self) -> Vec<LayoutHandle> {
        let guard = self.registry.read().unwrap();
        guard.local_handles()
    }

    /// Handles of every imported remote layout (in unspecified order).
    pub fn get_remote_handles(&self) -> Vec<LayoutHandle> {
        let guard = self.registry.read().unwrap();
        guard.remote_handles()
    }

    /// Shared transfer context used for every transfer.
    pub fn context(&self) -> &Arc<TransferContext> {
        &self.context
    }

    /// Host-to-device CUDA stream of the transfer context (CUDA tests only).
    #[cfg(all(test, feature = "testing-cuda"))]
    pub(crate) fn h2d_stream(&self) -> &Arc<cudarc::driver::CudaStream> {
        self.context.h2d_stream()
    }

    /// Device-to-host CUDA stream of the transfer context (CUDA tests only).
    // Allowed to be unused: kept available for CUDA tests even when none call it.
    #[cfg(all(test, feature = "testing-cuda"))]
    #[allow(dead_code)]
    pub(crate) fn d2h_stream(&self) -> &Arc<cudarc::driver::CudaStream> {
        self.context.d2h_stream()
    }

    /// CUDA context of the transfer context (CUDA tests only).
    #[cfg(all(test, feature = "testing-cuda"))]
    pub(crate) fn cuda_context(&self) -> &Arc<cudarc::driver::CudaContext> {
        self.context.cuda_context()
    }

    /// Hands `event` to the transfer context and returns the notification it
    /// produces (CUDA tests only).
    #[cfg(all(test, feature = "testing-cuda"))]
    pub(crate) fn register_cuda_event(
        &self,
        event: cudarc::driver::CudaEvent,
    ) -> TransferCompleteNotification {
        self.context.register_cuda_event(event)
    }
}
/// Bookkeeping for the layouts known to one worker: layouts registered
/// locally, layouts imported from remote workers, and which remote workers
/// have already been imported.
#[derive(Debug)]
pub(crate) struct LayoutRegistry {
    /// Agent used to export this worker's NIXL metadata and to load remote
    /// workers' metadata.
    nixl_agent: NixlAgent,
    /// Worker id embedded in every locally issued `LayoutHandle`.
    worker_id: u64,
    /// Next layout id to issue for a local registration.
    next_layout_id: AtomicU16,
    /// Layouts registered on this worker, keyed by their handle.
    local_layouts: HashMap<LayoutHandle, LocalLayout>,
    /// Layouts imported from remote workers, keyed by their (remote) handle.
    remote_layouts: HashMap<LayoutHandle, RemoteLayout>,
    /// `(nixl_agent_name, worker_id)` of each remote already imported; used to
    /// reject importing the same remote twice.
    loaded_remotes: HashSet<(String, u64)>,
}
#[expect(dead_code)]
impl LayoutRegistry {
    /// Creates an empty registry.
    ///
    /// `nixl_agent` is used to export this worker's NIXL metadata and to load
    /// remote workers' metadata; `worker_id` is embedded in every handle issued
    /// by [`register_local`](Self::register_local).
    pub(crate) fn new(nixl_agent: NixlAgent, worker_id: u64) -> Self {
        Self {
            nixl_agent,
            worker_id,
            next_layout_id: AtomicU16::new(0),
            local_layouts: HashMap::new(),
            remote_layouts: HashMap::new(),
            loaded_remotes: HashSet::new(),
        }
    }

    /// Registers a local layout and returns the handle assigned to it.
    ///
    /// Layout ids are issued sequentially starting at 0.
    ///
    /// # Errors
    ///
    /// Fails once ids `0..=65534` have all been issued (65535 layouts). A
    /// failed call does not advance the counter, so later calls keep failing
    /// rather than wrapping around and re-issuing ids already in use.
    pub(crate) fn register_local(&mut self, layout: PhysicalLayout) -> Result<LayoutHandle> {
        // `fetch_update` leaves the counter untouched when `checked_add`
        // overflows. A plain `fetch_add` would wrap the counter back to 0 after
        // the overflow error, so the next call would re-issue id 0 and
        // `HashMap::insert` would silently replace that existing layout.
        let layout_id = self
            .next_layout_id
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |id| id.checked_add(1))
            .map_err(|_| {
                anyhow!("Layout ID overflow: maximum number of layouts (65535) reached")
            })?;
        let handle = LayoutHandle::new(self.worker_id, layout_id);
        let local_layout = LocalLayout::new(handle, layout);
        self.local_layouts.insert(handle, local_layout);
        Ok(handle)
    }

    /// Packs this worker's NIXL metadata, its address, and the descriptors of
    /// every local layout whose storage is `System`, `Device`, or `Pinned`.
    ///
    /// Layouts with any other storage kind are skipped. Layout order follows
    /// `HashMap` iteration and is therefore unspecified.
    ///
    /// # Errors
    ///
    /// Fails if NIXL cannot produce local metadata, a layout cannot be turned
    /// into a descriptor, or packing fails.
    pub(crate) fn export_metadata(&self) -> Result<SerializedLayout> {
        let nixl_metadata = self
            .nixl_agent
            .get_local_md()
            .map_err(|e| anyhow!("failed to get NIXL local metadata: {:?}", e))?;
        let worker_address = WorkerAddress::new(self.worker_id, self.nixl_agent.name().to_string());
        let mut serialized_layouts = Vec::new();
        for (handle, local_layout) in &self.local_layouts {
            let location = local_layout.layout().location();
            if matches!(
                location,
                StorageKind::System | StorageKind::Device(_) | StorageKind::Pinned
            ) {
                let serialized = local_layout
                    .layout()
                    .to_descriptor()
                    .map_err(|e| anyhow!("failed to serialize layout {}: {}", handle, e))?;
                serialized_layouts.push(LocalLayoutDescriptor::new(*handle, serialized));
            }
        }
        SerializedLayout::pack(worker_address, nixl_metadata, serialized_layouts)
    }

    /// Imports metadata produced by a remote worker's
    /// [`export_metadata`](Self::export_metadata) and returns the handles of
    /// the imported layouts.
    ///
    /// The remote's NIXL metadata is loaded into this registry's agent before
    /// the layouts are reconstructed. Layouts are added to the registry only
    /// after all of them have been reconstructed, so a failing descriptor does
    /// not leave a partial import in `remote_layouts`.
    ///
    /// # Errors
    ///
    /// Fails if the payload cannot be unpacked; if the remote
    /// `(agent name, worker id)` was already imported; if an incoming handle is
    /// already held by a previously imported remote layout; if NIXL rejects the
    /// metadata or reports a different agent name; or if a layout descriptor
    /// cannot be reconstructed.
    pub(crate) fn import_metadata(
        &mut self,
        metadata: SerializedLayout,
    ) -> Result<Vec<LayoutHandle>> {
        let inner = metadata.unpack()?;
        let remote_key = (
            inner.worker_address.nixl_agent_name.clone(),
            inner.worker_address.worker_id,
        );
        if self.loaded_remotes.contains(&remote_key) {
            bail!(
                "Remote worker already loaded: {} (worker_id={})",
                remote_key.0,
                remote_key.1
            );
        }

        // Detect handle collisions before any side effect. Two different
        // remotes sharing a worker id would otherwise silently overwrite each
        // other's entries in `remote_layouts`.
        for descriptor in &inner.layouts {
            let handle = descriptor.handle;
            if self.remote_layouts.contains_key(&handle) {
                bail!(
                    "remote layout handle {} from {} (worker_id={}) is already registered by another remote worker",
                    handle,
                    remote_key.0,
                    remote_key.1
                );
            }
        }

        let returned_agent_name = self
            .nixl_agent
            .load_remote_md(&inner.nixl_metadata)
            .map_err(|e| anyhow!("failed to load remote NIXL metadata: {:?}", e))?;
        if returned_agent_name != inner.worker_address.nixl_agent_name {
            bail!(
                "Agent name mismatch: expected '{}', got '{}'",
                inner.worker_address.nixl_agent_name,
                returned_agent_name
            );
        }

        // Reconstruct every layout first; `collect` stops at the first error,
        // before anything has been inserted into the registry.
        let staged = inner
            .layouts
            .into_iter()
            .map(|descriptor| {
                let handle = descriptor.handle;
                PhysicalLayout::from_descriptor(descriptor.layout)
                    .map(|layout| (handle, layout))
                    .map_err(|e| anyhow!("failed to reconstruct layout {}: {}", handle, e))
            })
            .collect::<Result<Vec<_>>>()?;

        let mut imported_handles = Vec::with_capacity(staged.len());
        for (handle, layout) in staged {
            self.remote_layouts
                .insert(handle, RemoteLayout::new(handle, layout));
            imported_handles.push(handle);
        }
        self.loaded_remotes.insert(remote_key);
        Ok(imported_handles)
    }

    /// Looks up a locally registered layout.
    pub(crate) fn get_local(&self, handle: LayoutHandle) -> Option<&LocalLayout> {
        self.local_layouts.get(&handle)
    }

    /// Looks up an imported remote layout.
    pub(crate) fn get_remote(&self, handle: LayoutHandle) -> Option<&RemoteLayout> {
        self.remote_layouts.get(&handle)
    }

    /// Returns the physical layout behind `handle`, checking local layouts
    /// first and then remote ones.
    pub(crate) fn get_layout(&self, handle: LayoutHandle) -> Option<&PhysicalLayout> {
        self.local_layouts
            .get(&handle)
            .map(|l| l.layout())
            .or_else(|| self.remote_layouts.get(&handle).map(|r| r.layout()))
    }

    /// Whether `handle` refers to a locally registered layout.
    pub(crate) fn is_local(&self, handle: LayoutHandle) -> bool {
        self.local_layouts.contains_key(&handle)
    }

    /// Whether `handle` refers to an imported remote layout.
    pub(crate) fn is_remote(&self, handle: LayoutHandle) -> bool {
        self.remote_layouts.contains_key(&handle)
    }

    /// Number of locally registered layouts.
    pub(crate) fn local_count(&self) -> usize {
        self.local_layouts.len()
    }

    /// Number of imported remote layouts.
    pub(crate) fn remote_count(&self) -> usize {
        self.remote_layouts.len()
    }

    /// Worker id embedded in locally issued handles.
    pub(crate) fn worker_id(&self) -> u64 {
        self.worker_id
    }

    /// Handles of all local layouts (unspecified order).
    pub(crate) fn local_handles(&self) -> Vec<LayoutHandle> {
        self.local_layouts.keys().copied().collect()
    }

    /// Handles of all remote layouts (unspecified order).
    pub(crate) fn remote_handles(&self) -> Vec<LayoutHandle> {
        self.remote_layouts.keys().copied().collect()
    }
}
// Unit tests for `LayoutRegistry`. They construct real `NixlAgent`s, so they
// are compiled only with the `testing-nixl` feature.
#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
    use super::*;
    use crate::block_manager::v2::physical::layout::LayoutConfig;
    use crate::block_manager::v2::physical::transfer::nixl_agent::NixlAgent;

    /// Creates a NIXL agent called `name` with an empty required-backend list.
    fn make_test_agent(name: &str) -> NixlAgent {
        NixlAgent::require_backends(name, &[]).expect("failed to create wrapped agent")
    }

    /// Builds a small, fully contiguous, system-memory layout:
    /// 2 blocks x 2 layers, outer_dim 2, page_size 4, inner_dim 8, 2-byte elements.
    fn make_test_layout(agent: &NixlAgent) -> PhysicalLayout {
        let config = LayoutConfig::builder()
            .num_blocks(2)
            .num_layers(2)
            .outer_dim(2)
            .page_size(4)
            .inner_dim(8)
            .dtype_width_bytes(2)
            .build()
            .unwrap();
        PhysicalLayout::builder(agent.clone())
            .with_config(config)
            .fully_contiguous()
            .allocate_system()
            .build()
            .unwrap()
    }

    // A fresh registry reports its worker id and holds no layouts.
    #[test]
    fn test_manager_creation() {
        let agent = make_test_agent("test-manager");
        let manager = LayoutRegistry::new(agent, 42);
        assert_eq!(manager.worker_id(), 42);
        assert_eq!(manager.local_count(), 0);
        assert_eq!(manager.remote_count(), 0);
    }

    // The first local registration gets layout id 0 and the registry's worker id.
    #[test]
    fn test_register_local() {
        let agent = make_test_agent("test-register");
        let mut manager = LayoutRegistry::new(agent.clone(), 100);
        let layout = make_test_layout(&agent);
        let handle = manager.register_local(layout).unwrap();
        assert_eq!(handle.worker_id(), 100);
        assert_eq!(handle.layout_id(), 0);
        assert_eq!(manager.local_count(), 1);
        assert!(manager.is_local(handle));
        assert!(!manager.is_remote(handle));
    }

    // Layout ids are issued sequentially in registration order.
    #[test]
    fn test_register_multiple_locals() {
        let agent = make_test_agent("test-multiple");
        let mut manager = LayoutRegistry::new(agent.clone(), 1);
        let handle1 = manager.register_local(make_test_layout(&agent)).unwrap();
        let handle2 = manager.register_local(make_test_layout(&agent)).unwrap();
        let handle3 = manager.register_local(make_test_layout(&agent)).unwrap();
        assert_eq!(handle1.layout_id(), 0);
        assert_eq!(handle2.layout_id(), 1);
        assert_eq!(handle3.layout_id(), 2);
        assert_eq!(manager.local_count(), 3);
    }

    // Exporting from one registry and importing into another exposes the
    // source's handles as remote layouts on the destination.
    // NOTE(review): ignored without a recorded reason; presumably it needs a
    // NIXL setup able to load remote metadata. Confirm before un-ignoring.
    #[test]
    #[ignore]
    fn test_export_import_roundtrip() {
        let source_agent = make_test_agent("source");
        let mut source_manager = LayoutRegistry::new(source_agent.clone(), 1);
        let handle1 = source_manager
            .register_local(make_test_layout(&source_agent))
            .unwrap();
        let handle2 = source_manager
            .register_local(make_test_layout(&source_agent))
            .unwrap();
        let metadata = source_manager.export_metadata().unwrap();
        assert!(!metadata.is_empty());
        let dest_agent = make_test_agent("dest");
        let mut dest_manager = LayoutRegistry::new(dest_agent, 2);
        let imported_handles = dest_manager.import_metadata(metadata).unwrap();
        assert_eq!(imported_handles.len(), 2);
        assert_eq!(dest_manager.remote_count(), 2);
        assert!(dest_manager.is_remote(handle1));
        assert!(dest_manager.is_remote(handle2));
        assert!(dest_manager.get_remote(handle1).is_some());
        assert!(dest_manager.get_remote(handle2).is_some());
        assert!(dest_manager.get_layout(handle1).is_some());
    }

    // Importing the same remote (same agent name and worker id) twice fails
    // with an "already loaded" error. The second payload is rebuilt from the
    // raw bytes because `import_metadata` takes its argument by value.
    // NOTE(review): ignored without a recorded reason; see the roundtrip test.
    #[test]
    #[ignore]
    fn test_import_duplicate_remote_fails() {
        let source_agent = make_test_agent("source2");
        let mut source_manager = LayoutRegistry::new(source_agent.clone(), 10);
        source_manager
            .register_local(make_test_layout(&source_agent))
            .unwrap();
        let metadata = source_manager.export_metadata().unwrap();
        let dest_agent = make_test_agent("dest2");
        let mut dest_manager = LayoutRegistry::new(dest_agent, 20);
        let metadata_clone = SerializedLayout::from_bytes(metadata.as_bytes().clone());
        dest_manager.import_metadata(metadata).unwrap();
        let result = dest_manager.import_metadata(metadata_clone);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("already loaded"));
    }

    // `local_handles` returns every registered local handle (order unspecified).
    #[test]
    fn test_get_layout_handles() {
        let agent = make_test_agent("test-handles");
        let mut manager = LayoutRegistry::new(agent.clone(), 5);
        let h1 = manager.register_local(make_test_layout(&agent)).unwrap();
        let h2 = manager.register_local(make_test_layout(&agent)).unwrap();
        let handles = manager.local_handles();
        assert_eq!(handles.len(), 2);
        assert!(handles.contains(&h1));
        assert!(handles.contains(&h2));
    }
}