use super::*;
use super::{block::Block, config::NixlOptions};
use cudarc::driver::CudaStream;
use std::sync::Arc;
pub struct TransferContext {
nixl_agent: Option<NixlAgent>,
stream: Arc<CudaStream>,
}
impl TransferContext {
pub fn new(nixl_agent: Option<NixlAgent>, stream: Arc<CudaStream>) -> Self {
Self { nixl_agent, stream }
}
pub fn nixl_agent(&self) -> Option<&NixlAgent> {
self.nixl_agent.as_ref()
}
pub fn stream(&self) -> &Arc<CudaStream> {
&self.stream
}
}
#[allow(dead_code)]
pub struct KvBlockManagerState<Metadata: BlockMetadata> {
worker_id: WorkerID,
cancellation_token: CancellationToken,
nixl_agent: Option<NixlAgent>,
nixl_backends: HashMap<String, Arc<nixl_sys::Backend>>,
host_pool: Option<BlockPool<PinnedStorage, Metadata>>,
device_pool: Option<BlockPool<DeviceStorage, Metadata>>,
local_block_set: NixlBlockSet,
remote_block_sets: RwLock<HashMap<WorkerID, HashMap<usize, RemoteBlocks>>>,
}
impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
pub fn new(config: KvBlockManagerConfig) -> Result<Arc<Self>> {
config
.runtime
.validate()
.context("Validating runtime config")?;
config.model.validate().context("Validating model config")?;
let worker_id = config.runtime.worker_id;
let cancellation_token = config.runtime.cancellation_token;
let mut nixl_backends: HashMap<String, Arc<nixl_sys::Backend>> = HashMap::new();
let nixl_agent = match config.runtime.nixl {
NixlOptions::Enabled => {
tracing::debug!("Creating NIXL agent");
let agent = NixlAgent::new(&worker_id.to_string())?;
tracing::debug!("Creating NIXL backends");
let (_ucx_mem_list1, ucx_params) = agent.get_plugin_params("UCX")?;
let backend = agent.create_backend("UCX", &ucx_params)?;
nixl_backends.insert("UCX".to_string(), Arc::new(backend));
Some(agent)
}
NixlOptions::EnabledWithAgent(agent) => Some(agent),
NixlOptions::Disabled => None,
};
let model = &config.model;
let mut layout_builder = LayoutConfig::builder();
layout_builder
.num_layers(model.num_layers)
.page_size(model.page_size)
.inner_dim(model.inner_dim)
.dtype(model.dtype);
let mut next_block_set_idx = 0;
let mut local_block_set = block::nixl::NixlBlockSet::new(worker_id);
let (host_pool, host_blocks) = if let Some(config) = config.host_layout {
next_block_set_idx += 1;
tracing::debug!("Constructing host pool.");
let layout = create_layout(layout_builder.clone(), config, nixl_agent.as_ref())?;
local_block_set.add_block_set(next_block_set_idx, layout.serialize()?);
let (pool, blocks) = create_block_pool::<_, Metadata>(
layout,
next_block_set_idx,
cancellation_token.clone(),
worker_id,
)?;
(Some(pool), Some(blocks))
} else {
tracing::debug!("No host layout provided; will not allocate host blocks.");
(None, None)
};
let (device_pool, device_blocks) = if let Some(config) = config.device_layout {
next_block_set_idx += 1;
tracing::debug!("Constructing device pool.");
let layout = create_layout(layout_builder.clone(), config, nixl_agent.as_ref())?;
local_block_set.add_block_set(next_block_set_idx, layout.serialize()?);
let (pool, blocks) = create_block_pool::<_, Metadata>(
layout,
next_block_set_idx,
cancellation_token.clone(),
worker_id,
)?;
(Some(pool), Some(blocks))
} else {
tracing::debug!("No device layout provided; will not allocate device blocks.");
(None, None)
};
if let Some(nixl_agent) = &nixl_agent {
tracing::debug!("Finalize NixlBlockSet: adding NIXL metadata.");
local_block_set.set_nixl_metadata(nixl_agent.get_local_md()?);
}
let state = Arc::new(Self {
worker_id,
cancellation_token,
nixl_agent,
nixl_backends,
host_pool,
device_pool,
local_block_set,
remote_block_sets: RwLock::new(HashMap::new()),
});
if let Some(mut blocks) = host_blocks {
blocks.iter_mut().for_each(|block| {
block.set_manager(state.clone());
});
state
.host_pool
.as_ref()
.unwrap()
.add_blocks_blocking(blocks)?;
}
if let Some(mut blocks) = device_blocks {
blocks.iter_mut().for_each(|block| {
block.set_manager(state.clone());
});
state
.device_pool
.as_ref()
.unwrap()
.add_blocks_blocking(blocks)?;
}
Ok(state)
}
pub fn export_local_blockset(&self) -> Result<SerializedNixlBlockSet> {
SerializedNixlBlockSet::try_from(&self.local_block_set)
.context("Failed to serialize local blockset")
}
pub fn import_remote_blockset(
&self,
serialized_blockset: SerializedNixlBlockSet,
) -> Result<()> {
let remote = NixlBlockSet::try_from(serialized_blockset)
.context("Failed to deserialize remote blockset")?;
let (block_sets, metadata, worker_id) = remote.dissolve();
tracing::debug!("Importing remote blockset from worker {}", worker_id);
assert_ne!(
worker_id, self.worker_id,
"Cannot import blockset from self"
);
let agent = self
.nixl_agent
.as_ref()
.ok_or_else(|| anyhow::anyhow!("NIXL agent not initialized"))?;
let mut remote_block_sets = self.remote_block_sets.write().unwrap();
if remote_block_sets.contains_key(&worker_id) {
anyhow::bail!(
"Worker ID {} already exists; cannot update remote blockset",
worker_id
);
}
let mut inner_map = HashMap::new();
for (block_set_idx, block_set_layout) in block_sets {
let remote_blocks =
RemoteBlocks::from_serialized(block_set_layout.clone(), block_set_idx, worker_id)?;
let layout = remote_blocks.layout();
let storage = layout.storage();
let storage = storage
.first()
.ok_or_else(|| anyhow::anyhow!("No storage found in remote blockset"))?;
match storage.mem_type() {
MemType::Dram => {
tracing::trace!(block_set_idx, "Detected Host/DRAM remote descriptor");
}
MemType::Vram => {
tracing::trace!(block_set_idx, "Detected GPU/Device/VRAM remote descriptor");
}
_ => {
tracing::warn!(
block_set_idx,
"Detected unknown remote descriptor; skipping blockset..."
);
continue;
}
}
inner_map.insert(block_set_idx, remote_blocks);
}
let agent_id = agent
.load_remote_md(&metadata)
.context("Loading remote metadata")?;
let agent_id: WorkerID =
agent_id .parse() .context("Failed to parse agent ID string into WorkerID (u64)")?;
assert_eq!(agent_id, worker_id, "Mismatch with remote worker ID");
remote_block_sets.insert(worker_id, inner_map);
Ok(())
}
pub fn get_remote_blocks_immutable(
&self,
bds: &BlockDescriptorList,
) -> Result<Vec<RemoteBlock<IsImmutable>>> {
self.get_remote_blocks::<IsImmutable>(bds)
}
pub fn get_remote_blocks_mutable(
&self,
bds: &BlockDescriptorList,
) -> Result<Vec<RemoteBlock<IsMutable>>> {
if bds.mutability() == BlockMutability::Mutable {
self.get_remote_blocks::<IsMutable>(bds)
} else {
anyhow::bail!("Cannot get mutable remote blocks for immutable block descriptor set");
}
}
fn get_remote_blocks<M: MutabilityKind>(
&self,
bds: &BlockDescriptorList,
) -> Result<Vec<RemoteBlock<M>>> {
let remote_block_sets = self.remote_block_sets.read().unwrap();
let remote_blocks = remote_block_sets
.get(&bds.worker_id())
.and_then(|map| map.get(&bds.block_set_idx()))
.ok_or_else(|| {
anyhow::anyhow!(
"No remote blockset found for worker {} and block_set_idx {}",
bds.worker_id(),
bds.block_set_idx()
)
})?;
let blocks: Vec<block::nixl::RemoteBlock<M>> = bds
.block_indices()
.iter()
.map(|block_idx| remote_blocks.block(*block_idx))
.collect::<Result<Vec<_>, _>>()?;
Ok(blocks)
}
pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> {
self.host_pool.as_ref()
}
pub fn device(&self) -> Option<&BlockPool<DeviceStorage, Metadata>> {
self.device_pool.as_ref()
}
pub fn worker_id(&self) -> WorkerID {
self.worker_id
}
}
impl<Metadata: BlockMetadata> std::fmt::Debug for KvBlockManagerState<Metadata> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "KvBlockManagerState")
}
}
fn create_layout<S: Storage + NixlRegisterableStorage>(
mut builder: LayoutConfigBuilder,
config: KvManagerLayoutConfig<S>,
nixl_agent: Option<&NixlAgent>,
) -> Result<Arc<dyn NixlLayout<StorageType = S>>> {
let layout = builder.num_blocks(config.num_blocks).build()?;
if let Some(storage) = config.storage {
let mut layout = layout.create_layout(config.layout_type, storage)?;
if let Some(nixl_agent) = nixl_agent {
layout.nixl_register(nixl_agent, None)?;
}
return Ok(Arc::new(layout));
}
if let Some(allocator) = config.allocator {
let mut layout = layout.allocate_layout(config.layout_type, allocator)?;
if let Some(nixl_agent) = nixl_agent {
layout.nixl_register(nixl_agent, None)?;
}
return Ok(Arc::new(layout));
}
anyhow::bail!("failed to create layout");
}
#[expect(clippy::type_complexity)]
fn create_block_pool<S: Storage + NixlRegisterableStorage, M: BlockMetadata>(
layout: Arc<dyn NixlLayout<StorageType = S>>,
block_set_idx: usize,
cancellation_token: CancellationToken,
worker_id: WorkerID,
) -> Result<(BlockPool<S, M>, Vec<Block<S, M>>)> {
let blocks = block::layout_to_blocks::<_, M>(layout, block_set_idx, worker_id)?;
let pool = BlockPool::<S, M>::builder()
.cancel_token(cancellation_token)
.build()?;
Ok((pool, blocks))
}