use super::*;
use crate::block_manager::block::transfer::{
TransferContext, TransferError, WriteToStrategy, handle_local_transfer,
};
use crate::block_manager::storage::{self, nixl::NixlDescriptor};
use std::any::Any;
use tokio::sync::oneshot;
/// Describes where a block's data lives and how transfers between blocks of
/// that locality are carried out.
///
/// Implementors choose the concrete per-storage block-data type via
/// [`LocalityProvider::BlockData`] and provide the transfer entry point used
/// when both the source and target blocks share this locality.
pub trait LocalityProvider: Send + Sync + 'static + std::fmt::Debug {
    /// Concrete block-data type used for blocks of this locality over storage `S`.
    type BlockData<S: Storage>: BlockDataExt<S>;

    /// Starts a transfer from `src_blocks` into `dst_blocks`.
    ///
    /// Returns a oneshot receiver; presumably it is signalled once the transfer
    /// has completed (TODO confirm against implementors).
    ///
    /// # Errors
    /// Returns a [`TransferError`] if the transfer cannot be started.
    fn handle_transfer<RB, WB>(
        src_blocks: &[RB],
        dst_blocks: &mut [WB],
        transfer_ctx: Arc<TransferContext>,
    ) -> Result<oneshot::Receiver<()>, TransferError>
    where
        RB: ReadableBlock
            + WriteToStrategy<WB>
            + storage::Local
            + BlockDataProvider<Locality = Self>,
        WB: WritableBlock + BlockDataProviderMut<Locality = Self>,
        <RB as StorageTypeProvider>::StorageType: NixlDescriptor,
        <WB as StorageTypeProvider>::StorageType: NixlDescriptor;
}
/// Locality marker for blocks whose data is a plain [`BlockData`]; transfers
/// between such blocks are delegated to [`handle_local_transfer`].
#[derive(Debug)]
pub struct Local;
/// `Local` blocks carry a plain [`BlockData`] and hand every transfer straight
/// to [`handle_local_transfer`] without any extra validation.
impl LocalityProvider for Local {
    type BlockData<S: Storage> = BlockData<S>;

    fn handle_transfer<RB, WB>(
        src_blocks: &[RB],
        dst_blocks: &mut [WB],
        transfer_ctx: Arc<TransferContext>,
    ) -> Result<oneshot::Receiver<()>, TransferError>
    where
        RB: ReadableBlock
            + WriteToStrategy<WB>
            + storage::Local
            + BlockDataProvider<Locality = Self>,
        WB: WritableBlock + BlockDataProviderMut<Locality = Self>,
        <RB as StorageTypeProvider>::StorageType: NixlDescriptor,
        <WB as StorageTypeProvider>::StorageType: NixlDescriptor,
    {
        // Pure delegation: the shared local transfer routine does all the work.
        handle_local_transfer(src_blocks, dst_blocks, transfer_ctx)
    }
}
pub use crate::block_manager::block::data::logical::{LogicalBlockData, LogicalResources};
/// Locality for blocks backed by [`LogicalBlockData`], whose transfers are
/// delegated to a shared [`LogicalResources`] instance of type `R`.
///
/// This type carries no data; `R` is tracked only at the type level.
#[derive(Debug)]
pub struct Logical<R: LogicalResources> {
    // Zero-sized marker tying this locality to its resources type `R`.
    _resources: std::marker::PhantomData<R>,
}
impl<R: LogicalResources> Logical<R> {
    /// Collects the [`LogicalResources`] handle held by each block in `blocks`,
    /// in order.
    ///
    /// # Panics
    /// Panics if a block's data is not a `LogicalBlockData<_, R>`. That is
    /// presumably ruled out by the `Locality = Logical<R>` bound, since
    /// `Logical<R>::BlockData<S>` is `LogicalBlockData<S, R>`. A panic here
    /// therefore signals a broken invariant rather than a recoverable error.
    fn load_resources<B: BlockDataProvider<Locality = Logical<R>>>(blocks: &[B]) -> Vec<Arc<R>> {
        blocks
            .iter()
            .map(|block| {
                let any_block = block.block_data() as &dyn Any;
                any_block
                    .downcast_ref::<LogicalBlockData<<B as StorageTypeProvider>::StorageType, R>>()
                    .expect("block with Logical<R> locality must hold LogicalBlockData<_, R>")
                    .resources()
            })
            .collect()
    }

    /// Mutable-access counterpart of [`Self::load_resources`].
    ///
    /// Transfer targets are only required to implement [`BlockDataProviderMut`],
    /// so their data is reached through `block_data_mut`. The resources are
    /// still only read.
    ///
    /// # Panics
    /// Same invariant as [`Self::load_resources`]: panics if a block's data is
    /// not a `LogicalBlockData<_, R>`.
    fn load_resources_mut<B: BlockDataProviderMut<Locality = Logical<R>>>(
        blocks: &mut [B],
    ) -> Vec<Arc<R>> {
        blocks
            .iter_mut()
            .map(|block| {
                let any_block = block.block_data_mut() as &mut dyn Any;
                any_block
                    .downcast_mut::<LogicalBlockData<<B as StorageTypeProvider>::StorageType, R>>()
                    .expect("block with Logical<R> locality must hold LogicalBlockData<_, R>")
                    .resources()
            })
            .collect()
    }
}
impl<R: LogicalResources> LocalityProvider for Logical<R> {
    type BlockData<S: Storage> = LogicalBlockData<S, R>;

    /// Forwards a transfer between logical blocks to the single
    /// [`LogicalResources`] instance they all share.
    ///
    /// If both `sources` and `targets` are empty, a warning is logged and a
    /// receiver that has already been signalled is returned without doing any
    /// work.
    ///
    /// # Errors
    /// - [`TransferError::CountMismatch`] if `sources` and `targets` differ in
    ///   length.
    /// - An error converted from [`anyhow::Error`] if the blocks do not all
    ///   reference the same resources instance. Instances are compared by `Arc`
    ///   pointer identity, not by value.
    fn handle_transfer<RB, WB>(
        sources: &[RB],
        targets: &mut [WB],
        ctx: Arc<TransferContext>,
    ) -> Result<oneshot::Receiver<()>, TransferError>
    where
        RB: ReadableBlock + WriteToStrategy<WB> + storage::Local,
        <RB as StorageTypeProvider>::StorageType: NixlDescriptor,
        <WB as StorageTypeProvider>::StorageType: NixlDescriptor,
        RB: BlockDataProvider<Locality = Self>,
        WB: WritableBlock + BlockDataProviderMut<Locality = Self>,
    {
        if sources.is_empty() && targets.is_empty() {
            tracing::warn!(
                "Logical::handle_transfer called with both sources and targets empty, skipping transfer"
            );
            let (tx, rx) = oneshot::channel();
            // `rx` is still alive in this scope, so the send cannot fail.
            tx.send(()).expect("oneshot receiver is held locally");
            return Ok(rx);
        }

        if sources.len() != targets.len() {
            return Err(TransferError::CountMismatch(sources.len(), targets.len()));
        }

        let source_resources = Self::load_resources(sources);
        let target_resources = Self::load_resources_mut(targets);

        // Stream over all resource handles instead of collecting them. The
        // chain is non-empty: the lengths are equal and not both zero.
        let mut all_resources = source_resources.into_iter().chain(target_resources);
        let common_resource = all_resources
            .next()
            .expect("non-empty transfer must yield at least one resource handle");

        if !all_resources.all(|r| Arc::ptr_eq(&r, &common_resource)) {
            return Err(anyhow::anyhow!("Resources used in a transfer must be the same!").into());
        }

        common_resource.handle_transfer(sources, targets, ctx)
    }
}