pub mod factory;
pub mod locality;
pub mod data;
pub mod registry;
pub mod state;
pub mod transfer;
pub use data::{BlockData, BlockDataExt, BlockDataProvider, BlockDataProviderMut, view};
pub use locality::LocalityProvider;
pub use crate::tokens::TokenBlockError;
pub use anyhow::Result;
pub use registry::{GlobalRegistry, RegistrationHandle};
pub use state::{BlockState, BlockStateInvalid};
use crate::block_manager::{
state::KvBlockManagerState as BlockManager,
storage::{Local, Remote, Storage, StorageTypeProvider},
};
use crate::tokens::{SaltHash, SequenceHash, Token, TokenBlock, Tokens};
use super::{
WorkerID,
events::PublishHandle,
layout::{BlockLayout, LayoutError, LayoutType},
storage::StorageType,
};
use derive_getters::Getters;
use std::{
fmt::Debug,
ops::{Deref, DerefMut},
sync::Arc,
};
use thiserror::Error;
/// Capability-token module.
///
/// `PrivateToken` can only be constructed inside this crate, so public trait
/// methods that require it (e.g. [`MaybeReturnableBlock::try_take_block`])
/// cannot be called by downstream users even though the traits are public.
pub mod private {
    #[derive(Clone, Copy)]
    pub struct PrivateToken;
}
/// Index of a block within its block set.
pub type BlockId = usize;
/// Identifier of a block set.
pub type BlockSetId = usize;
/// Convenience alias for results carrying a [`BlockError`].
pub type BlockResult<T> = std::result::Result<T, BlockError>;
/// Errors produced by block construction, state transitions, and data access.
#[derive(Debug, Error)]
pub enum BlockError {
    /// Underlying layout failure, forwarded verbatim.
    #[error(transparent)]
    Layout(#[from] LayoutError),
    /// Operation attempted in an incompatible [`BlockState`].
    #[error("Invalid state: {0}")]
    InvalidState(String),
    /// Block id out of range for its block set.
    #[error("Invalid block ID: {0}")]
    InvalidBlockID(BlockId),
    #[error("Misconfigured block data parallelism: {0}")]
    MisconfiguredBlockDataParallelism(String),
    #[error("Incompatible storage type: {0}")]
    IncompatibleStorageType(String),
    #[error("Views are not available on logical blocks")]
    ViewsNotAvailableOnLogicalBlocks,
    /// Catch-all for errors bubbled up through `anyhow`.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
    /// An [`ImmutableBlock`] may carry at most one duplicate handle.
    #[error("Immutable block already has a duplicate")]
    IncompatibleImmutableBlock,
}
/// Per-block bookkeeping attached to every [`Block`]; drives pool reuse and
/// offload decisions.
pub trait BlockMetadata: Default + std::fmt::Debug + Clone + Ord + Send + Sync + 'static {
    /// Called when the block is handed out from the pool at logical time `tick`.
    fn on_acquired(&mut self, tick: u64);
    /// Called when the block is returned to the pool at logical time `tick`.
    fn on_returned(&mut self, tick: u64);
    /// Clear metadata when the owning block is reset.
    fn reset_metadata(&mut self);
    /// Priority used for offload ordering; `None` means "do not offload".
    fn offload_priority(&self) -> Option<u64>;
    /// Copy of this metadata with the given priority applied.
    fn with_priority(&self, priority: u32) -> Self;
}
/// Handles that may relinquish their underlying [`Block`]s back to the pool.
pub trait MaybeReturnableBlock<S: Storage, L: LocalityProvider, M: BlockMetadata> {
    /// True when this handle is the sole owner, so the block(s) can be taken.
    fn is_returnable(&self) -> bool;
    /// Take ownership of the inner block(s); crate-gated via [`private::PrivateToken`].
    fn try_take_block(self, token: private::PrivateToken) -> Option<Vec<Block<S, L, M>>>;
}
/// Marker: a block whose data may be written.
pub trait WritableBlock: BlockDataProviderMut {}
/// Marker: a block whose data may be read.
pub trait ReadableBlock: BlockDataProvider {}
/// Marker: a collection of readable blocks.
pub trait ReadableBlocks {}
impl<T: ReadableBlock> ReadableBlocks for Vec<T> {}
impl<T: ReadableBlock> ReadableBlocks for [T] {}
impl<T: ReadableBlock> ReadableBlocks for &[T] {}
/// Marker: a collection of writable blocks.
pub trait WritableBlocks {}
impl<T: WritableBlock> WritableBlocks for Vec<T> {}
impl<T: WritableBlock> WritableBlocks for [T] {}
impl<T: WritableBlock> WritableBlocks for &[T] {}
/// Borrow a container as a plain slice of blocks.
pub trait AsBlockSlice<'a, B: 'a> {
    fn as_block_slice(&'a self) -> &'a [B];
}
/// Borrow a container as a mutable slice of blocks.
pub trait AsBlockMutSlice<'a, B: 'a> {
    fn as_block_mut_slice(&'a mut self) -> &'a mut [B];
}
/// Conversion into a writable block collection, resolved against the manager.
pub trait IntoWritableBlocks<Locality: LocalityProvider, M: BlockMetadata> {
    type Output: WritableBlocks;
    fn into_writable_blocks(self, manager: &BlockManager<Locality, M>)
    -> BlockResult<Self::Output>;
}
// Identity conversion: anything already `WritableBlocks` passes through untouched.
impl<T: WritableBlocks, Locality: LocalityProvider, M: BlockMetadata>
    IntoWritableBlocks<Locality, M> for T
{
    type Output = T;
    fn into_writable_blocks(
        self,
        _manager: &BlockManager<Locality, M>,
    ) -> BlockResult<Self::Output> {
        Ok(self)
    }
}
/// Conversion into a readable block collection, resolved against the manager.
pub trait IntoReadableBlocks<Locality: LocalityProvider, M: BlockMetadata> {
    type Output: ReadableBlocks;
    fn into_readable_blocks(self, manager: &BlockManager<Locality, M>)
    -> BlockResult<Self::Output>;
}
// Identity conversion: anything already `ReadableBlocks` passes through untouched.
impl<T: ReadableBlocks, Locality: LocalityProvider, M: BlockMetadata>
    IntoReadableBlocks<Locality, M> for T
{
    type Output = T;
    fn into_readable_blocks(
        self,
        _manager: &BlockManager<Locality, M>,
    ) -> BlockResult<Self::Output> {
        Ok(self)
    }
}
/// A single KV-cache block: storage-backed data plus token-sequence state
/// and pool metadata.
#[derive(Debug)]
pub struct Block<S: Storage, L: LocalityProvider, M: BlockMetadata> {
    // Locality-specific view of the underlying storage for this block.
    data: L::BlockData<S>,
    // Pool bookkeeping (priority, acquire/return ticks, ...).
    metadata: M,
    // Token-sequence lifecycle: Reset -> Partial -> Complete -> Registered.
    state: BlockState,
    // Back-reference to the owning manager; `None` for unmanaged blocks.
    manager: Option<Arc<BlockManager<L, M>>>,
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Block<S, L, M> {
    /// Create a block over `data` in the `Reset` state with no owning manager.
    pub fn new(data: L::BlockData<S>, metadata: M) -> BlockResult<Self> {
        Ok(Self {
            data,
            metadata,
            state: BlockState::Reset,
            manager: None,
        })
    }
    /// Sequence hash of a `Complete` or `Registered` block.
    ///
    /// # Errors
    /// [`BlockError::InvalidState`] for `Reset`/`Partial` blocks.
    pub fn sequence_hash(&self) -> Result<SequenceHash, BlockError> {
        match self.state() {
            BlockState::Complete(state) => Ok(state.token_block().sequence_hash()),
            BlockState::Registered(state, _) => Ok(state.sequence_hash()),
            _ => Err(BlockError::InvalidState(
                "Block is not complete nor registered.".to_string(),
            )),
        }
    }
    /// Parent sequence hash (None for the first block of a sequence).
    ///
    /// # Errors
    /// [`BlockError::InvalidState`] for `Reset`/`Partial` blocks.
    pub fn parent_sequence_hash(&self) -> Result<Option<SequenceHash>, BlockError> {
        match self.state() {
            BlockState::Complete(state) => Ok(state.token_block().parent_sequence_hash()),
            BlockState::Registered(state, _) => Ok(state.parent_sequence_hash()),
            _ => Err(BlockError::InvalidState(
                "Block is not complete nor registered.".to_string(),
            )),
        }
    }
    /// Return to the `Reset` state and clear metadata for reuse.
    pub fn reset(&mut self) {
        self.state = BlockState::Reset;
        self.metadata.reset_metadata();
    }
    /// Begin a new token sequence (Reset -> Partial) using `salt_hash`.
    pub fn init_sequence(&mut self, salt_hash: SaltHash) -> Result<()> {
        Ok(self
            .state
            .initialize_sequence(self.page_size(), salt_hash)?)
    }
    /// Append a single token; fails unless the block is `Partial` with room.
    pub fn add_token(&mut self, token: Token) -> Result<()> {
        self.state.add_token(token)
    }
    /// Append as many tokens as fit; returns the tokens that did not fit.
    pub fn add_tokens(&mut self, tokens: Tokens) -> Result<Tokens> {
        self.state.add_tokens(tokens)
    }
    /// Remove the last token; fails when empty or not `Partial`.
    pub fn pop_token(&mut self) -> Result<()> {
        self.state.pop_token()
    }
    /// Remove the last `count` tokens; fails if fewer are present.
    pub fn pop_tokens(&mut self, count: usize) -> Result<()> {
        self.state.pop_tokens(count)
    }
    /// Finalize a full `Partial` block (Partial -> Complete).
    pub fn commit(&mut self) -> Result<()> {
        self.state.commit()
    }
    /// Apply a pre-built, exactly page-sized [`TokenBlock`] to a `Reset` block.
    pub fn apply_token_block(&mut self, token_block: TokenBlock) -> Result<()> {
        // Size must match exactly; a mismatch would corrupt hashing invariants.
        if self.page_size() != token_block.tokens().len() {
            return Err(BlockStateInvalid(format!(
                "TokenBlock size ({}) does not match Block page size ({})",
                token_block.tokens().len(),
                self.page_size()
            ))
            .into());
        }
        self.state.apply_token_block(token_block)
    }
    /// Number of tokens currently held.
    // NOTE(review): when the state reports no length (e.g. Reset), this falls
    // back to page_size(), which also makes `is_full()` true for such blocks —
    // confirm this is intended.
    pub fn len(&self) -> usize {
        match self.state.len() {
            Some(len) => len,
            None => self.page_size(),
        }
    }
    /// Remaining token capacity.
    pub fn remaining(&self) -> usize {
        self.state.remaining()
    }
    /// True when no tokens are held.
    pub fn is_empty(&self) -> bool {
        self.state.is_empty()
    }
    /// True when the block holds exactly `page_size()` tokens.
    pub fn is_full(&self) -> bool {
        self.len() == self.page_size()
    }
    /// Tokens currently held, when the state tracks them.
    pub fn tokens(&self) -> Option<&Tokens> {
        self.state.tokens()
    }
    /// Attach the owning manager (enables e.g. offload enqueueing).
    pub(crate) fn set_manager(&mut self, manager: Arc<BlockManager<L, M>>) {
        self.manager = Some(manager);
    }
    /// The owning manager, if one was attached.
    pub(crate) fn manager(&self) -> Option<&Arc<BlockManager<L, M>>> {
        self.manager.as_ref()
    }
    /// Pool metadata for this block.
    pub fn metadata(&self) -> &M {
        &self.metadata
    }
    /// Replace the pool metadata wholesale.
    pub fn update_metadata(&mut self, metadata: M) {
        self.metadata = metadata;
    }
    /// Overwrite the lifecycle state directly (test/internal use).
    #[allow(dead_code)]
    pub(crate) fn update_state(&mut self, state: BlockState) {
        self.state = state;
    }
    /// Current lifecycle state.
    pub fn state(&self) -> &BlockState {
        &self.state
    }
    /// Mutable access to the lifecycle state.
    pub fn state_mut(&mut self) -> &mut BlockState {
        &mut self.state
    }
    /// A `Block` always represents exactly one block.
    pub fn num_blocks(&self) -> usize {
        1
    }
    /// Index of this block within its block set.
    pub fn block_id(&self) -> BlockId {
        self.data.block_id()
    }
    /// Number of layers in the backing layout.
    pub fn num_layers(&self) -> usize {
        self.data.num_layers()
    }
    /// Tokens per block in the backing layout.
    pub fn page_size(&self) -> usize {
        self.data.page_size()
    }
    /// Size of the innermost dimension.
    pub fn inner_dim(&self) -> usize {
        self.data.num_inner_dims()
    }
    /// Number of outer dimensions (e.g. K/V).
    pub fn num_outer_dims(&self) -> usize {
        self.data.num_outer_dims()
    }
    /// Forward an acquire event to the metadata.
    pub(crate) fn metadata_on_acquired(&mut self, tick: u64) {
        self.metadata.on_acquired(tick);
    }
    /// Forward a return event to the metadata.
    pub(crate) fn metadata_on_returned(&mut self, tick: u64) {
        self.metadata.on_returned(tick);
    }
}
/// Crate-internal extension: registration of a block with the global registry.
pub(crate) trait PrivateBlockExt {
    /// Register this block's sequence hash; may yield a publish handle for
    /// event emission.
    fn register(
        &mut self,
        registry: &mut registry::BlockRegistry,
    ) -> Result<Option<PublishHandle>, registry::BlockRegistrationError>;
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> PrivateBlockExt for Block<S, L, M> {
    fn register(
        &mut self,
        registry: &mut registry::BlockRegistry,
    ) -> Result<Option<PublishHandle>, registry::BlockRegistrationError> {
        // Delegates entirely to the registry; state transitions happen there.
        registry.register_block(&mut self.state)
    }
}
// A `Block` lives in this worker's address space.
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Local for Block<S, L, M> {}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> StorageTypeProvider for Block<S, L, M> {
    type StorageType = S;
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockDataProvider for Block<S, L, M> {
    type Locality = L;
    fn block_data(&self) -> &impl BlockDataExt<S> {
        &self.data
    }
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockDataProviderMut for Block<S, L, M> {
    type Locality = L;
    fn block_data_mut(&mut self) -> &mut impl BlockDataExt<S> {
        &mut self.data
    }
}
/// Object-safe facade over the token-sequence operations of a block.
///
/// Mirrors the inherent methods on [`Block`]; see those for semantics.
pub trait BlockExt {
    /// Reset to the `Reset` state.
    fn reset(&mut self);
    /// Begin a new token sequence with `salt_hash`.
    fn init_sequence(&mut self, salt_hash: SaltHash) -> Result<()>;
    /// Append one token.
    fn add_token(&mut self, token: Token) -> Result<()>;
    /// Append tokens; returns those that did not fit.
    fn add_tokens(&mut self, tokens: Tokens) -> Result<Tokens>;
    /// Remove the last token.
    fn pop_token(&mut self) -> Result<()>;
    /// Remove the last `count` tokens.
    fn pop_tokens(&mut self, count: usize) -> Result<()>;
    /// Finalize a full partial block.
    fn commit(&mut self) -> Result<()>;
    /// Apply a pre-built, page-sized token block.
    fn apply_token_block(&mut self, token_block: TokenBlock) -> Result<()>;
    /// Number of tokens currently held.
    fn len(&self) -> usize;
    /// Remaining capacity.
    fn remaining(&self) -> usize;
    /// True when no tokens are held.
    fn is_empty(&self) -> bool;
    /// True when at full capacity.
    fn is_full(&self) -> bool;
    /// Tokens currently held, if tracked.
    fn tokens(&self) -> Option<&Tokens>;
}
/// Default [`BlockMetadata`]: a priority plus acquire/return timestamps.
///
/// The derived `Ord` orders by `priority`, then `returned_tick`, then
/// `acquired_tick` (field declaration order), which determines eviction order.
#[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Getters)]
pub struct BasicMetadata {
    // Offload priority; 0 after reset.
    #[getter(copy)]
    priority: u32,
    // Logical time the block was last returned to the pool.
    #[getter(copy)]
    returned_tick: u64,
    // Logical time the block was last acquired from the pool.
    #[getter(copy)]
    acquired_tick: u64,
}
impl BasicMetadata {
pub fn update_priority(&self, priority: u32) -> Self {
BasicMetadata {
priority,
returned_tick: self.returned_tick,
acquired_tick: self.acquired_tick,
}
}
}
impl BlockMetadata for BasicMetadata {
    /// Record the acquire time.
    fn on_acquired(&mut self, tick: u64) {
        self.acquired_tick = tick;
    }
    /// Record the return time.
    fn on_returned(&mut self, tick: u64) {
        self.returned_tick = tick;
    }
    /// Clear the priority.
    // NOTE(review): acquired/returned ticks are intentionally left untouched
    // here (presumably to preserve recency ordering across reuse) — confirm.
    fn reset_metadata(&mut self) {
        self.priority = 0;
    }
    /// Always offloadable, at the stored priority.
    fn offload_priority(&self) -> Option<u64> {
        Some(self.priority as u64)
    }
    /// Delegates to [`BasicMetadata::update_priority`].
    fn with_priority(&self, priority: u32) -> Self {
        self.update_priority(priority)
    }
}
#[cfg(test)]
mod basic_metadata_tests {
    use super::*;
    // with_priority should surface through offload_priority unchanged.
    #[test]
    fn test_basic_metadata_with_priority() {
        let metadata = BasicMetadata::default();
        let updated = metadata.with_priority(75);
        assert_eq!(updated.offload_priority(), Some(75));
    }
    // with_priority must not disturb the acquire/return timestamps.
    #[test]
    fn test_basic_metadata_with_priority_preserves_ticks() {
        let mut metadata = BasicMetadata::default();
        metadata.on_acquired(100);
        metadata.on_returned(200);
        let updated = metadata.with_priority(50);
        assert_eq!(updated.priority(), 50);
        assert_eq!(updated.acquired_tick(), 100);
        assert_eq!(updated.returned_tick(), 200);
    }
}
/// An owned layout plus identity, convertible into individual [`Block`]s.
#[derive(Debug)]
pub struct Blocks<L: BlockLayout, M: BlockMetadata> {
    // Boxed so the layout can later be moved into an `Arc<dyn BlockLayout>`.
    layout: Box<L>,
    // Metadata type is fixed at construction but holds no data here.
    metadata: std::marker::PhantomData<M>,
    block_set_idx: usize,
    worker_id: WorkerID,
}
impl<L: BlockLayout + 'static, M: BlockMetadata> Blocks<L, M> {
    /// Wrap `layout` with its block-set index and owning worker id.
    pub fn new(layout: L, block_set_idx: usize, worker_id: WorkerID) -> BlockResult<Self> {
        let layout = Box::new(layout);
        Ok(Self {
            layout,
            metadata: std::marker::PhantomData,
            block_set_idx,
            worker_id,
        })
    }
    /// Consume self, producing one local [`Block`] per slot of the layout.
    pub fn into_blocks(self) -> BlockResult<Vec<Block<L::StorageType, locality::Local, M>>> {
        // Unbox and re-wrap as a shared trait object so every block can hold it.
        let layout: Arc<dyn BlockLayout<StorageType = L::StorageType>> = Arc::new(*self.layout);
        layout_to_blocks(layout, self.block_set_idx, self.worker_id)
    }
}
/// Construct one [`Block`] per slot of `layout`, sharing the layout via `Arc`.
///
/// Blocks are indexed `0..layout.num_blocks()` and receive default metadata.
/// Collection short-circuits on the first construction error.
pub(crate) fn layout_to_blocks<S: Storage, M: BlockMetadata>(
    layout: Arc<dyn BlockLayout<StorageType = S>>,
    block_set_idx: usize,
    worker_id: WorkerID,
) -> BlockResult<Vec<Block<S, locality::Local, M>>> {
    (0..layout.num_blocks())
        .map(|idx| {
            let data = BlockData::new(layout.clone(), idx, block_set_idx, worker_id);
            Block::new(data, M::default())
        })
        .collect()
}
/// Exclusive handle to a pool [`Block`]; on drop the block is sent back to
/// the pool through `return_tx`.
pub struct MutableBlock<S: Storage, L: LocalityProvider, M: BlockMetadata> {
    // `Some` until taken (by Drop or try_take_block).
    block: Option<Block<S, L, M>>,
    // Channel back to the owning pool; used in Drop.
    return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, L, M>>,
    // Optional link keeping an ancestor handle alive; unwound in Drop.
    parent: Option<Arc<MutableBlock<S, L, M>>>,
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> StorageTypeProvider
    for MutableBlock<S, L, M>
{
    type StorageType = S;
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockDataProvider
    for MutableBlock<S, L, M>
{
    type Locality = L;
    /// Panics if the inner block has already been taken.
    fn block_data(&self) -> &impl BlockDataExt<S> {
        &self.block.as_ref().expect("block was dropped").data
    }
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockDataProviderMut
    for MutableBlock<S, L, M>
{
    type Locality = L;
    /// Panics if the inner block has already been taken.
    fn block_data_mut(&mut self) -> &mut impl BlockDataExt<S> {
        &mut self.block.as_mut().expect("block was dropped").data
    }
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Local for MutableBlock<S, L, M> {}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> MutableBlock<S, L, M> {
    /// Wrap `block` with the pool's return channel; no parent initially.
    pub(crate) fn new(
        block: Block<S, L, M>,
        return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, L, M>>,
    ) -> Self {
        Self {
            block: Some(block),
            return_tx,
            parent: None,
        }
    }
    /// Keep `parent` alive for as long as this handle lives.
    pub fn set_parent(&mut self, parent: Arc<MutableBlock<S, L, M>>) {
        self.parent = Some(parent);
    }
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> std::fmt::Debug for MutableBlock<S, L, M> {
    /// Render the live block's identity, or a placeholder once the inner
    /// block has been taken.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(block) = &self.block {
            write!(
                f,
                "MutableBlock(storage_type: {:?}, block_id: {}, sequence_hash: {:?})",
                block.block_data().storage_type(),
                block.block_id(),
                block.sequence_hash().ok()
            )
        } else {
            write!(f, "MutableBlock(block: None)")
        }
    }
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Drop for MutableBlock<S, L, M> {
    /// Return the inner block to the pool and unwind the parent chain.
    fn drop(&mut self) {
        tracing::debug!("drop: {:?}", self);
        // Send the block back unless it was already taken via try_take_block.
        if let Some(block) = self.block.take()
            && self.return_tx.send(block).is_err()
        {
            tracing::warn!("block pool shutdown before block was returned");
        }
        // Unwind the parent chain iteratively rather than letting nested Arc
        // drops recurse (presumably to bound stack depth on long chains —
        // confirm). Each uniquely-owned parent is unwrapped and its own parent
        // link detached before its Drop runs; a parent still shared with other
        // handles ends the walk, since those handles keep the rest alive.
        let mut current_parent = self.parent.take();
        while let Some(arc_parent) = current_parent {
            match Arc::try_unwrap(arc_parent) {
                Ok(mut parent) => {
                    current_parent = parent.parent.take();
                }
                Err(_) => {
                    break;
                }
            }
        }
    }
}
// Smart-pointer-style access to the inner `Block`.
// Panics if the inner block has already been taken.
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Deref for MutableBlock<S, L, M> {
    type Target = Block<S, L, M>;
    fn deref(&self) -> &Self::Target {
        self.block.as_ref().expect("block was dropped")
    }
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> DerefMut for MutableBlock<S, L, M> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.block.as_mut().expect("block was dropped")
    }
}
// Trivial slice adapters so `&[MutableBlock]` / `Vec<MutableBlock>` satisfy
// the `AsBlockSlice` / `AsBlockMutSlice` abstractions.
impl<'a, S: Storage + 'a, L: LocalityProvider + 'a, M: BlockMetadata>
    AsBlockSlice<'a, MutableBlock<S, L, M>> for [MutableBlock<S, L, M>]
{
    fn as_block_slice(&'a self) -> &'a [MutableBlock<S, L, M>] {
        self
    }
}
impl<'a, S: Storage + 'a, L: LocalityProvider + 'a, M: BlockMetadata>
    AsBlockSlice<'a, MutableBlock<S, L, M>> for Vec<MutableBlock<S, L, M>>
{
    fn as_block_slice(&'a self) -> &'a [MutableBlock<S, L, M>] {
        self.as_slice()
    }
}
impl<'a, S: Storage + 'a, L: LocalityProvider + 'a, M: BlockMetadata>
    AsBlockMutSlice<'a, MutableBlock<S, L, M>> for [MutableBlock<S, L, M>]
{
    fn as_block_mut_slice(&'a mut self) -> &'a mut [MutableBlock<S, L, M>] {
        self
    }
}
impl<'a, S: Storage + 'a, L: LocalityProvider + 'a, M: BlockMetadata>
    AsBlockMutSlice<'a, MutableBlock<S, L, M>> for Vec<MutableBlock<S, L, M>>
{
    fn as_block_mut_slice(&'a mut self) -> &'a mut [MutableBlock<S, L, M>] {
        self.as_mut_slice()
    }
}
// A single MutableBlock converts to a one-element writable collection.
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> IntoWritableBlocks<L, M>
    for MutableBlock<S, L, M>
{
    type Output = Vec<MutableBlock<S, L, M>>;
    fn into_writable_blocks(self, _manager: &BlockManager<L, M>) -> BlockResult<Self::Output> {
        Ok(vec![self])
    }
}
// A single MutableBlock converts to a one-element readable collection.
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> IntoReadableBlocks<L, M>
    for MutableBlock<S, L, M>
{
    type Output = Vec<MutableBlock<S, L, M>>;
    fn into_readable_blocks(self, _manager: &BlockManager<L, M>) -> BlockResult<Self::Output> {
        Ok(vec![self])
    }
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> MaybeReturnableBlock<S, L, M>
    for MutableBlock<S, L, M>
{
    /// Returnable while the inner block has not been taken.
    fn is_returnable(&self) -> bool {
        self.block.is_some()
    }
    /// Take the inner block; its `Drop` will then skip the pool send.
    fn try_take_block(mut self, _: private::PrivateToken) -> Option<Vec<Block<S, L, M>>> {
        self.block.take().map(|block| vec![block])
    }
}
/// Shared, read-only handle to a registered block.
///
/// Cheap to clone (bumps the inner `Arc`); caches the sequence hash at
/// construction so it stays available without re-locking state.
pub struct ImmutableBlock<S: Storage, L: LocalityProvider, M: BlockMetadata> {
    // Shared ownership of the underlying mutable handle.
    block: Arc<MutableBlock<S, L, M>>,
    // Hash captured from the block at construction time.
    sequence_hash: SequenceHash,
    // Optional second physical copy of the same logical block (at most one).
    duplicate: Option<Arc<MutableBlock<S, L, M>>>,
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> std::fmt::Debug
    for ImmutableBlock<S, L, M>
{
    // Panics if the shared inner block was already taken.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "ImmutableBlock(storage: {:?}, block_id: {}, sequence_hash: {})",
            self.block
                .block
                .as_ref()
                .expect("block was dropped")
                .block_data()
                .storage_type(),
            self.block_id(),
            self.sequence_hash
        )
    }
}
// Manual Clone: a derive would wrongly require S, L and M to be Clone;
// only the Arc handles and the cached hash are copied.
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Clone for ImmutableBlock<S, L, M> {
    fn clone(&self) -> Self {
        Self {
            block: self.block.clone(),
            sequence_hash: self.sequence_hash,
            duplicate: self.duplicate.clone(),
        }
    }
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> ImmutableBlock<S, L, M> {
    /// Wrap a shared handle, caching its sequence hash.
    ///
    /// # Panics
    /// If the block is not yet `Complete`/`Registered` (no hash available).
    pub(crate) fn new(block: Arc<MutableBlock<S, L, M>>) -> Self {
        let sequence_hash = block.sequence_hash().expect("block is in the wrong state");
        Self {
            block,
            sequence_hash,
            duplicate: None,
        }
    }
    /// Attach a duplicate physical copy; at most one is permitted.
    pub(crate) fn with_duplicate(
        self,
        duplicate: Arc<MutableBlock<S, L, M>>,
    ) -> Result<Self, BlockError> {
        if self.duplicate.is_some() {
            Err(BlockError::IncompatibleImmutableBlock)
        } else {
            Ok(Self {
                block: self.block,
                sequence_hash: self.sequence_hash,
                duplicate: Some(duplicate),
            })
        }
    }
    /// Shared handle to the underlying mutable block.
    pub(crate) fn mutable_block(&self) -> &Arc<MutableBlock<S, L, M>> {
        &self.block
    }
    /// Sequence hash captured at construction.
    pub fn sequence_hash(&self) -> SequenceHash {
        self.sequence_hash
    }
    /// Block id, preferring the duplicate's id when one is attached.
    pub fn block_id(&self) -> BlockId {
        match &self.duplicate {
            Some(duplicate) => duplicate.block_id(),
            None => self.block.block_id(),
        }
    }
    /// Whether a duplicate copy is attached.
    #[allow(unused)]
    pub(crate) fn is_duplicate(&self) -> bool {
        self.duplicate.is_some()
    }
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> StorageTypeProvider
    for ImmutableBlock<S, L, M>
{
    type StorageType = S;
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockDataProvider
    for ImmutableBlock<S, L, M>
{
    type Locality = L;
    /// Panics if the shared inner block was already taken.
    fn block_data(&self) -> &impl BlockDataExt<S> {
        &self.block.block.as_ref().expect("block was dropped").data
    }
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Local for ImmutableBlock<S, L, M> {}
// Read-only smart-pointer access to the inner `Block`.
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> Deref for ImmutableBlock<S, L, M> {
    type Target = Block<S, L, M>;
    fn deref(&self) -> &Self::Target {
        self.block
            .as_ref()
            .block
            .as_ref()
            .expect("block was dropped")
    }
}
// A single ImmutableBlock converts to a one-element readable collection.
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> IntoReadableBlocks<L, M>
    for ImmutableBlock<S, L, M>
{
    type Output = Vec<ImmutableBlock<S, L, M>>;
    fn into_readable_blocks(self, _manager: &BlockManager<L, M>) -> BlockResult<Self::Output> {
        Ok(vec![self])
    }
}
// Trivial slice adapters for immutable block collections.
impl<'a, S: Storage + 'a, L: LocalityProvider + 'a, M: BlockMetadata>
    AsBlockSlice<'a, ImmutableBlock<S, L, M>> for [ImmutableBlock<S, L, M>]
{
    fn as_block_slice(&'a self) -> &'a [ImmutableBlock<S, L, M>] {
        self
    }
}
impl<'a, S: Storage + 'a, L: LocalityProvider + 'a, M: BlockMetadata>
    AsBlockSlice<'a, ImmutableBlock<S, L, M>> for Vec<ImmutableBlock<S, L, M>>
{
    fn as_block_slice(&'a self) -> &'a [ImmutableBlock<S, L, M>] {
        self.as_slice()
    }
}
impl<S: Storage + 'static, L: LocalityProvider, M: BlockMetadata> ImmutableBlock<S, L, M> {
    /// Ask the owning manager to schedule this block for offload at `priority`.
    ///
    /// Unmanaged blocks log a warning and return `Ok(())` rather than failing.
    pub async fn enqueue_offload(&self, priority: u64) -> Result<()> {
        let Some(manager) = self.manager() else {
            tracing::warn!("Block is not managed. Unable to enqueue offload.");
            return Ok(());
        };
        manager.enqueue_offload_block(self, priority).await?;
        Ok(())
    }
}
impl<S: Storage, L: LocalityProvider, M: BlockMetadata> MaybeReturnableBlock<S, L, M>
    for ImmutableBlock<S, L, M>
{
    /// Returnable when this handle uniquely owns the relevant `Arc`.
    // NOTE(review): with a duplicate attached, only the duplicate's strong
    // count is consulted; the primary `block` Arc may still be shared —
    // confirm that is the intended ownership model.
    fn is_returnable(&self) -> bool {
        match &self.duplicate {
            Some(duplicate) => Arc::strong_count(duplicate) == 1,
            None => Arc::strong_count(&self.block) == 1,
        }
    }
    /// Unwrap whichever handles are uniquely owned and take their blocks.
    ///
    /// Returns `None` when neither the primary nor the duplicate could be
    /// reclaimed (other clones still exist).
    fn try_take_block(mut self, token: private::PrivateToken) -> Option<Vec<Block<S, L, M>>> {
        // Arc::try_unwrap only succeeds for uniquely-owned handles; shared
        // ones are silently skipped (their other owners keep them alive).
        let blocks = [
            Arc::try_unwrap(self.block).ok(),
            self.duplicate
                .take()
                .and_then(|duplicate| Arc::try_unwrap(duplicate).ok()),
        ];
        let blocks = blocks
            .into_iter()
            .flatten()
            .filter_map(|block| block.try_take_block(token))
            .flatten()
            .collect::<Vec<_>>();
        if blocks.is_empty() {
            None
        } else {
            Some(blocks)
        }
    }
}
// Blanket impls: any data provider is readable; any mutable provider writable.
impl<B: BlockDataProvider> ReadableBlock for B {}
impl<B: BlockDataProviderMut> WritableBlock for B {}
pub mod nixl {
use super::*;
use super::view::{BlockKind, Kind, LayerKind};
use super::super::{
WorkerID,
layout::nixl::{NixlLayout, SerializedNixlBlockLayout},
storage::nixl::{MemType, NixlRegisterableStorage, NixlStorage},
};
use derive_getters::{Dissolve, Getters};
use nixl_sys::{Agent as NixlAgent, MemoryRegion, NixlDescriptor, OptArgs};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
    /// Type-level marker distinguishing mutable from immutable NIXL views.
    pub trait MutabilityKind: Debug + Clone + Copy + Send + Sync + 'static {}
    /// Marker: the described memory may be written.
    #[derive(Debug, Clone, Copy)]
    pub struct IsMutable;
    impl MutabilityKind for IsMutable {}
    /// Marker: the described memory is read-only.
    #[derive(Debug, Clone, Copy)]
    pub struct IsImmutable;
    impl MutabilityKind for IsImmutable {}
    impl<L: NixlLayout, M: BlockMetadata> Blocks<L, M>
    where
        L::StorageType: NixlRegisterableStorage,
    {
        /// Register the whole layout's memory with the given NIXL agent.
        pub fn nixl_register(
            &mut self,
            agent: &NixlAgent,
            opt_args: Option<&OptArgs>,
        ) -> anyhow::Result<()> {
            self.layout.nixl_register(agent, opt_args)
        }
    }
#[derive(Copy, Clone)] pub struct NixlMemoryDescriptor<'a, K: Kind, M: MutabilityKind> {
addr: u64,
size: usize,
mem_type: MemType,
device_id: u64,
_lifetime: std::marker::PhantomData<&'a ()>, _kind: std::marker::PhantomData<K>, _mutability: std::marker::PhantomData<M>, }
pub(crate) fn short_type_name<T>() -> &'static str {
let name = core::any::type_name::<T>();
name.split("::").last().unwrap_or(name)
}
impl<K: Kind, M: MutabilityKind> Debug for NixlMemoryDescriptor<'_, K, M> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NixlMemoryDescriptor")
.field("addr", &self.addr)
.field("size", &self.size)
.field("mem_type", &self.mem_type)
.field("device_id", &self.device_id)
.field("kind", &short_type_name::<K>()) .field("mutability", &short_type_name::<M>())
.finish()
}
}
    impl<K: Kind, M: MutabilityKind> NixlMemoryDescriptor<'_, K, M> {
        /// Build a descriptor over `[addr, addr + size)` with the given NIXL
        /// memory type and device id.
        #[inline]
        pub(crate) fn new(addr: u64, size: usize, mem_type: MemType, device_id: u64) -> Self {
            Self {
                addr,
                size,
                mem_type,
                device_id,
                _lifetime: std::marker::PhantomData,
                _kind: std::marker::PhantomData,
                _mutability: std::marker::PhantomData,
            }
        }
    }
    impl<K: Kind, M: MutabilityKind> MemoryRegion for NixlMemoryDescriptor<'_, K, M> {
        /// # Safety
        /// The caller must ensure `addr` still refers to live, registered
        /// memory of at least `size` bytes.
        unsafe fn as_ptr(&self) -> *const u8 {
            self.addr as *const u8
        }
        fn size(&self) -> usize {
            self.size
        }
    }
    impl<K: Kind, M: MutabilityKind> NixlDescriptor for NixlMemoryDescriptor<'_, K, M> {
        fn mem_type(&self) -> MemType {
            self.mem_type
        }
        fn device_id(&self) -> u64 {
            self.device_id
        }
    }
    /// Read-only NIXL descriptor access for block data.
    pub trait NixlBlockDataImmutable<S: Storage + NixlDescriptor>: BlockDataExt<S> {
        /// Descriptor covering the entire block.
        fn as_block_descriptor(
            &self,
        ) -> BlockResult<NixlMemoryDescriptor<'_, BlockKind, IsImmutable>>;
        /// Descriptor covering one (layer, outer-dim) slice of the block.
        fn as_layer_descriptor(
            &self,
            layer_idx: usize,
            outer_idx: usize,
        ) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>>;
    }
    impl<S: Storage + NixlDescriptor> NixlBlockDataImmutable<S> for BlockData<S> {
        fn as_block_descriptor(
            &self,
        ) -> BlockResult<NixlMemoryDescriptor<'_, BlockKind, IsImmutable>> {
            // Views carry the address/size; conversion attaches NIXL typing.
            Ok(self.block_view()?.as_nixl_descriptor())
        }
        fn as_layer_descriptor(
            &self,
            layer_idx: usize,
            outer_idx: usize,
        ) -> BlockResult<NixlMemoryDescriptor<'_, LayerKind, IsImmutable>> {
            Ok(self.layer_view(layer_idx, outer_idx)?.as_nixl_descriptor())
        }
    }
    /// Errors raised while (de)serializing NIXL block-set metadata.
    #[derive(Debug, Error)]
    pub enum NixlSerializationError {
        #[error("Serialization failed: {0}")]
        Serialize(#[from] serde_json::Error),
    }
    /// Opaque, wire-ready encoding of a [`NixlBlockSet`] (JSON bytes).
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct SerializedNixlBlockSet(Vec<u8>);
    impl TryFrom<&NixlBlockSet> for SerializedNixlBlockSet {
        type Error = NixlSerializationError;
        /// Serialize a borrowed block set to JSON bytes.
        fn try_from(value: &NixlBlockSet) -> Result<Self, Self::Error> {
            let bytes = serde_json::to_vec(value)?;
            Ok(SerializedNixlBlockSet(bytes))
        }
    }
impl TryFrom<NixlBlockSet> for SerializedNixlBlockSet {
type Error = NixlSerializationError;
fn try_from(value: NixlBlockSet) -> Result<Self, Self::Error> {
let bytes = serde_json::to_vec(&value)?;
Ok(SerializedNixlBlockSet(bytes))
}
}
    impl TryFrom<&SerializedNixlBlockSet> for NixlBlockSet {
        type Error = NixlSerializationError;
        /// Decode a borrowed serialized block set from its JSON bytes.
        fn try_from(value: &SerializedNixlBlockSet) -> Result<Self, Self::Error> {
            let block_set = serde_json::from_slice(&value.0)?;
            Ok(block_set)
        }
    }
    impl TryFrom<SerializedNixlBlockSet> for NixlBlockSet {
        type Error = NixlSerializationError;
        /// Decode an owned serialized block set from its JSON bytes.
        fn try_from(value: SerializedNixlBlockSet) -> Result<Self, Self::Error> {
            let block_set = serde_json::from_slice(&value.0)?;
            Ok(block_set)
        }
    }
    /// Collection of serialized NIXL layouts for one worker, keyed by
    /// block-set index, plus the worker's NIXL agent metadata.
    #[derive(Clone, serde::Serialize, serde::Deserialize, Dissolve)]
    pub struct NixlBlockSet {
        // Serialized layout per block-set index.
        block_sets: HashMap<usize, SerializedNixlBlockLayout>,
        // Raw NIXL agent metadata blob for this worker.
        nixl_metadata: Vec<u8>,
        // Owning worker.
        worker_id: u64,
    }
    impl NixlBlockSet {
        /// Empty block set for `worker_id`.
        pub fn new(worker_id: u64) -> Self {
            Self {
                block_sets: HashMap::new(),
                nixl_metadata: Vec::new(),
                worker_id,
            }
        }
        /// Owning worker id.
        pub fn worker_id(&self) -> u64 {
            self.worker_id
        }
        /// All registered serialized layouts, keyed by block-set index.
        pub fn block_sets(&self) -> &HashMap<usize, SerializedNixlBlockLayout> {
            &self.block_sets
        }
        /// Insert (or replace) the layout for `block_set_idx`.
        pub fn add_block_set(
            &mut self,
            block_set_idx: usize,
            serialized_layout: SerializedNixlBlockLayout,
        ) {
            self.block_sets.insert(block_set_idx, serialized_layout);
        }
        /// Raw NIXL agent metadata.
        // NOTE(review): returning `&Vec<u8>` rather than `&[u8]`; kept for
        // interface stability.
        pub fn get_nixl_metadata(&self) -> &Vec<u8> {
            &self.nixl_metadata
        }
        /// Replace the NIXL agent metadata blob.
        pub fn set_nixl_metadata(&mut self, nixl_metadata: Vec<u8>) {
            self.nixl_metadata = nixl_metadata;
        }
    }
    /// View over a remote worker's block layout, from which typed
    /// [`RemoteBlock`] handles are minted.
    #[derive(Debug, Clone)]
    pub struct RemoteBlocks {
        // Shared layout describing the remote storage.
        layout: Arc<dyn BlockLayout<StorageType = NixlStorage>>,
        block_set_idx: usize,
        worker_id: WorkerID,
    }
impl RemoteBlocks {
pub fn new(
layout: Arc<dyn BlockLayout<StorageType = NixlStorage>>,
block_set_idx: usize,
worker_id: WorkerID,
) -> Self {
Self {
layout,
block_set_idx,
worker_id,
}
}
pub fn from_serialized(
serialized: SerializedNixlBlockLayout,
block_set_idx: usize,
worker_id: WorkerID,
) -> BlockResult<Self> {
let layout = serialized.deserialize()?;
Ok(Self::new(layout, block_set_idx, worker_id))
}
pub fn block<M: MutabilityKind>(&self, block_idx: usize) -> BlockResult<RemoteBlock<M>> {
if block_idx >= self.layout.num_blocks() {
return Err(BlockError::InvalidState(format!(
"block index out of bounds: {} >= {}",
block_idx,
self.layout.num_blocks()
)));
}
Ok(RemoteBlock::new(
self.layout.clone(),
block_idx,
self.block_set_idx,
self.worker_id,
))
}
pub fn layout(&self) -> &dyn BlockLayout<StorageType = NixlStorage> {
self.layout.as_ref()
}
}
    /// Read-only handle to a remote block.
    pub type ImmutableRemoteBlock = RemoteBlock<IsImmutable>;
    /// Writable handle to a remote block.
    pub type MutableRemoteBlock = RemoteBlock<IsMutable>;
    /// A block living in another worker's memory, addressed through NIXL.
    pub struct RemoteBlock<M: MutabilityKind> {
        data: BlockData<NixlStorage>,
        // Compile-time mutability marker; no runtime data.
        _mutability: std::marker::PhantomData<M>,
    }
    impl<M: MutabilityKind> Remote for RemoteBlock<M> {}
    impl<M: MutabilityKind> RemoteBlock<M> {
        /// Handle to block `block_idx` of the given remote layout.
        pub fn new(
            layout: Arc<dyn BlockLayout<StorageType = NixlStorage>>,
            block_idx: usize,
            block_set_idx: usize,
            worker_id: WorkerID,
        ) -> Self {
            let data = BlockData::new(layout, block_idx, block_set_idx, worker_id);
            Self {
                data,
                _mutability: std::marker::PhantomData,
            }
        }
    }
    impl<M: MutabilityKind> StorageTypeProvider for RemoteBlock<M> {
        type StorageType = NixlStorage;
    }
    impl<M: MutabilityKind> BlockDataProvider for RemoteBlock<M> {
        type Locality = locality::Local;
        fn block_data(&self) -> &impl BlockDataExt<NixlStorage> {
            &self.data
        }
    }
    // Mutable data access only for handles marked `IsMutable`.
    impl BlockDataProviderMut for RemoteBlock<IsMutable> {
        type Locality = locality::Local;
        fn block_data_mut(&mut self) -> &mut impl BlockDataExt<NixlStorage> {
            &mut self.data
        }
    }
    // Trivial slice adapters for remote block collections.
    impl<'a, M: MutabilityKind> AsBlockSlice<'a, RemoteBlock<M>> for [RemoteBlock<M>] {
        fn as_block_slice(&'a self) -> &'a [RemoteBlock<M>] {
            self
        }
    }
    impl<'a, M: MutabilityKind> AsBlockSlice<'a, RemoteBlock<M>> for Vec<RemoteBlock<M>> {
        fn as_block_slice(&'a self) -> &'a [RemoteBlock<M>] {
            self.as_slice()
        }
    }
    impl<'a> AsBlockMutSlice<'a, RemoteBlock<IsMutable>> for [RemoteBlock<IsMutable>] {
        fn as_block_mut_slice(&'a mut self) -> &'a mut [RemoteBlock<IsMutable>] {
            self
        }
    }
    impl<'a> AsBlockMutSlice<'a, RemoteBlock<IsMutable>> for Vec<RemoteBlock<IsMutable>> {
        fn as_block_mut_slice(&'a mut self) -> &'a mut [RemoteBlock<IsMutable>] {
            self.as_mut_slice()
        }
    }
    /// Runtime (value-level) counterpart of the `MutabilityKind` markers,
    /// used in wire-format descriptors.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
    pub enum BlockMutability {
        Immutable,
        Mutable,
    }
    /// Wire-format identification of a single block on some worker.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    pub struct BlockDescriptor {
        pub worker_id: WorkerID,
        pub block_set_idx: usize,
        pub block_idx: usize,
        pub mutability: BlockMutability,
    }
    /// Wire-format identification of several blocks that share a worker,
    /// block set, and mutability.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Getters)]
    pub struct BlockDescriptorList {
        #[getter(copy)]
        worker_id: WorkerID,
        #[getter(copy)]
        block_set_idx: usize,
        #[getter(copy)]
        mutability: BlockMutability,
        // Indices of the described blocks within the block set.
        block_indices: Vec<usize>,
    }
    /// Errors building or serializing a [`BlockDescriptorList`].
    #[derive(Debug, Error)]
    pub enum BlockDescriptorSetError {
        #[error("Input block list cannot be empty")]
        EmptyInput,
        #[error("Blocks in the input list are not homogeneous (worker_id, block_set_idx mismatch)")]
        NotHomogeneous,
        #[error("Serialization failed: {0}")]
        SerializationError(#[from] serde_json::Error),
        #[error(
            "An invalid block handle was encountered (block may have been dropped prematurely)"
        )]
        InvalidBlockHandle,
    }
}
#[cfg(test)]
mod tests {
use super::*;
use super::super::layout::tests::setup_layout;
use crate::tokens::{TokenBlockSequence, Tokens};
    // Page size used by all test blocks.
    const BLOCK_SIZE: u32 = 4;
    // Arbitrary salt for sequence-hash computation.
    const SALT_HASH: SaltHash = 12345;
    /// Fresh block in the `Reset` state backed by the shared test layout.
    fn create_reset_block() -> Block<impl Storage, locality::Local, BasicMetadata> {
        let layout = setup_layout(None).unwrap();
        let data = BlockData::new(Arc::new(layout), 0, 42, 0);
        Block::new(data, BasicMetadata::default()).unwrap()
    }
    /// A fully populated `TokenBlock` holding `BLOCK_SIZE` tokens.
    fn create_full_token_block() -> TokenBlock {
        let tokens = Tokens::from(vec![1, 2, 3, 4]);
        let salt_hash = SALT_HASH;
        let block_size = BLOCK_SIZE;
        let (mut blocks, _) =
            TokenBlockSequence::split_tokens(tokens.as_ref(), block_size, salt_hash);
        blocks.pop().unwrap()
    }
#[test]
fn test_block_state_transitions_and_ops() {
let mut block = create_reset_block();
assert!(matches!(block.state(), BlockState::Reset));
assert!(block.add_token(1).is_err(), "Append on Reset should fail");
assert!(
block.add_tokens(Tokens::from(vec![1])).is_err(),
"Extend on Reset should fail"
);
assert!(block.commit().is_err(), "Commit on Reset should fail");
assert!(block.pop_token().is_err(), "Pop on Reset should fail");
assert!(
block.pop_tokens(1).is_err(),
"Pop tokens on Reset should fail"
);
assert!(block.init_sequence(SALT_HASH).is_ok());
assert!(matches!(block.state(), BlockState::Partial(_)));
let invalid_block = create_full_token_block();
assert!(
block.apply_token_block(invalid_block).is_err(),
"Apply block on Partial should fail"
);
assert!(block.add_token(1).is_ok()); assert!(block.add_token(2).is_ok()); assert!(block.add_tokens(Tokens::from(vec![3])).is_ok()); assert_eq!(block.len(), 3);
let new_tokens = Tokens::from(vec![4, 5]);
assert_eq!(block.add_tokens(new_tokens.clone()).unwrap().as_ref(), &[5]);
assert!(block.add_tokens(Tokens::from(vec![4])).is_ok()); assert_eq!(block.len(), BLOCK_SIZE as usize);
assert!(block.add_token(5).is_err(), "Append on full Partial block");
assert!(block.pop_token().is_ok()); assert_eq!(block.len(), 3);
assert!(block.pop_tokens(2).is_ok()); assert_eq!(block.len(), 1);
assert!(block.pop_tokens(2).is_err(), "Pop too many tokens");
assert_eq!(block.len(), 1);
assert!(block.pop_token().is_ok()); assert_eq!(block.len(), 0);
assert!(block.is_empty());
assert!(block.add_tokens(Tokens::from(vec![1, 2, 3, 4])).is_ok());
assert_eq!(block.len(), BLOCK_SIZE as usize);
assert!(block.commit().is_ok());
assert!(matches!(block.state(), BlockState::Complete(_)));
assert_eq!(block.tokens().unwrap().as_ref(), &[1, 2, 3, 4]);
assert!(
block.init_sequence(SALT_HASH).is_err(),
"Init sequence on Complete should fail"
);
assert!(
block.add_token(5).is_err(),
"Append on Complete should fail"
);
assert!(
block.add_tokens(Tokens::from(vec![5])).is_err(),
"Extend on Complete should fail"
);
assert!(block.commit().is_err(), "Commit on Complete should fail");
assert!(block.pop_token().is_err(), "Pop on Complete should fail");
assert!(
block.pop_tokens(1).is_err(),
"Pop tokens on Complete should fail"
);
let invalid_block = create_full_token_block();
assert!(
block.apply_token_block(invalid_block).is_err(),
"Apply block on Complete should fail"
);
block.reset();
assert!(matches!(block.state(), BlockState::Reset));
let full_block = create_full_token_block();
assert!(block.apply_token_block(full_block.clone()).is_ok());
assert!(matches!(block.state(), BlockState::Complete(_)));
let applied_tokens = block.tokens().unwrap();
assert_eq!(applied_tokens, full_block.tokens());
let mut non_reset_block = create_reset_block();
non_reset_block.init_sequence(SALT_HASH).unwrap(); assert!(
non_reset_block.apply_token_block(full_block).is_err(),
"Apply block to non-reset state"
);
}
#[test]
fn test_block_state_incomplete_commit() {
let mut partial_block = create_reset_block();
partial_block.init_sequence(SALT_HASH).unwrap();
partial_block.add_token(1).unwrap();
partial_block.add_tokens(Tokens::from(vec![2, 3])).unwrap();
assert_eq!(partial_block.len(), 3);
assert!(
partial_block.commit().is_err(),
"Commit on incomplete Partial block"
);
}
#[test]
fn test_error_types() {
let mut block = create_reset_block();
block.init_sequence(SALT_HASH).unwrap();
block.add_tokens(Tokens::from(vec![1, 2, 3, 4])).unwrap();
let append_err = block.add_token(5).unwrap_err();
assert!(append_err.is::<TokenBlockError>());
assert_eq!(
*append_err.downcast_ref::<TokenBlockError>().unwrap(),
TokenBlockError::Full
);
let new_tokens = Tokens::from(vec![5]);
let ret_tokens = block.add_tokens(new_tokens.clone()).unwrap();
assert_eq!(new_tokens, ret_tokens);
block.commit().unwrap();
let commit_err = block.commit().unwrap_err();
assert!(commit_err.is::<BlockStateInvalid>());
block.reset();
block.init_sequence(SALT_HASH).unwrap();
let pop_err = block.pop_token().unwrap_err();
assert!(pop_err.is::<TokenBlockError>());
assert_eq!(
*pop_err.downcast_ref::<TokenBlockError>().unwrap(),
TokenBlockError::Empty
);
let pop_tokens_err = block.pop_tokens(1).unwrap_err();
assert!(pop_tokens_err.is::<TokenBlockError>());
assert_eq!(
*pop_tokens_err.downcast_ref::<TokenBlockError>().unwrap(),
TokenBlockError::InsufficientTokens
);
block.add_token(1).unwrap();
let commit_incomplete_err = block.commit().unwrap_err();
assert!(commit_incomplete_err.is::<TokenBlockError>());
assert_eq!(
*commit_incomplete_err
.downcast_ref::<TokenBlockError>()
.unwrap(),
TokenBlockError::Incomplete
);
}
}