use super::*;
use utils::*;
use zmq::*;

use crate::block_manager::{
    BasicMetadata, Storage,
    block::{
        Block, BlockDataProvider, BlockDataProviderMut, ReadableBlock, WritableBlock,
        data::local::LocalBlockData,
        locality,
        transfer::{TransferContext, WriteTo, WriteToStrategy},
    },
    connector::scheduler::{SchedulingDecision, TransferSchedulerClient},
    offload::MAX_TRANSFER_BATCH_SIZE,
    storage::{DeviceStorage, DiskStorage, Local, PinnedStorage},
};

use anyhow::Result;
use async_trait::async_trait;
use futures::future::try_join_all;
use nixl_sys::NixlDescriptor;
use std::{any::Any, sync::Arc};

type LocalBlock<S, M> = Block<S, locality::Local, M>;
type LocalBlockDataList<S> = Vec<LocalBlockData<S>>;
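
/// Splits oversized block transfer requests into batches of at most
/// `MAX_TRANSFER_BATCH_SIZE` blocks and executes the batches concurrently.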
#[derive(Clone, Debug)]
pub struct ConnectorTransferBatcher {
max_batch_size: usize,
}
impl ConnectorTransferBatcher {
pub fn new() -> Self {
Self {
max_batch_size: MAX_TRANSFER_BATCH_SIZE,
}
}
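
    /// Execute `request`, splitting it into concurrent sub-requests when it
    /// contains more blocks than `max_batch_size`.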
pub async fn execute_batched_transfer(
&self,
handler: &BlockTransferHandler,
request: BlockTransferRequest,
) -> Result<()> {
let blocks = request.blocks();
let num_blocks = blocks.len();
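        // Requests that already fit within a single batch are transferred directly.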
if num_blocks <= self.max_batch_size {
return handler.execute_transfer_direct(request).await;
}
let batches = blocks.chunks(self.max_batch_size);
let batch_futures: Vec<_> = batches
.map(|batch| {
let batch_request = BlockTransferRequest {
from_pool: *request.from_pool(),
to_pool: *request.to_pool(),
blocks: batch.to_vec(),
connector_req: None,
};
handler.execute_transfer_direct(batch_request)
})
.collect();
tracing::debug!("Executing {} batches concurrently", batch_futures.len());
match try_join_all(batch_futures).await {
Ok(_) => Ok(()),
Err(e) => {
tracing::error!("Batched connector transfer failed: {}", e);
Err(e)
}
}
}
}
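
/// Handles block transfer requests between the device, host, and disk block pools.
/// Each pool is optional; a transfer that touches a pool which was not configured
/// fails with an error.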
#[derive(Clone)]
pub struct BlockTransferHandler {
device: Option<LocalBlockDataList<DeviceStorage>>,
host: Option<LocalBlockDataList<PinnedStorage>>,
disk: Option<LocalBlockDataList<DiskStorage>>,
context: Arc<TransferContext>,
scheduler_client: Option<TransferSchedulerClient>,
batcher: ConnectorTransferBatcher,
}
impl BlockTransferHandler {
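    /// Build a handler from the optional device, host, and disk block sets,
    /// along with the transfer context and an optional scheduler client used
    /// for connector-scheduled transfers.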
pub fn new(
device_blocks: Option<Vec<LocalBlock<DeviceStorage, BasicMetadata>>>,
host_blocks: Option<Vec<LocalBlock<PinnedStorage, BasicMetadata>>>,
disk_blocks: Option<Vec<LocalBlock<DiskStorage, BasicMetadata>>>,
context: Arc<TransferContext>,
scheduler_client: Option<TransferSchedulerClient>,
) -> Result<Self> {
Ok(Self {
device: Self::get_local_data(device_blocks),
host: Self::get_local_data(host_blocks),
disk: Self::get_local_data(disk_blocks),
context,
scheduler_client,
batcher: ConnectorTransferBatcher::new(),
})
}
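
    /// Extract the locally addressable `LocalBlockData` from each block.
    ///
    /// Panics if a block's data is not `LocalBlockData<S>`.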
fn get_local_data<S: Storage>(
blocks: Option<Vec<LocalBlock<S, BasicMetadata>>>,
) -> Option<LocalBlockDataList<S>> {
blocks.map(|blocks| {
blocks
.into_iter()
.map(|b| {
let block_data = b.block_data() as &dyn Any;
block_data
.downcast_ref::<LocalBlockData<S>>()
.unwrap()
.clone()
})
.collect()
})
}
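
    /// Resolve the (source, target) block indices in `request` against the given
    /// pools and start the copy, returning a receiver that is signalled when the
    /// transfer completes.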
async fn begin_transfer<Source, Target>(
&self,
source_pool_list: &Option<LocalBlockDataList<Source>>,
target_pool_list: &Option<LocalBlockDataList<Target>>,
request: BlockTransferRequest,
) -> Result<tokio::sync::oneshot::Receiver<()>>
where
Source: Storage + NixlDescriptor,
Target: Storage + NixlDescriptor,
LocalBlockData<Source>:
ReadableBlock<StorageType = Source> + Local + WriteToStrategy<LocalBlockData<Target>>,
LocalBlockData<Target>: WritableBlock<StorageType = Target>,
LocalBlockData<Source>: BlockDataProvider<Locality = locality::Local>,
LocalBlockData<Target>: BlockDataProviderMut<Locality = locality::Local>,
{
        let Some(source_pool_list) = source_pool_list else {
            return Err(anyhow::anyhow!("Source block pool not initialized"));
        };
        let Some(target_pool_list) = target_pool_list else {
            return Err(anyhow::anyhow!("Target block pool not initialized"));
        };
let source_idxs = request.blocks().iter().map(|(from, _)| *from);
let target_idxs = request.blocks().iter().map(|(_, to)| *to);
let sources: Vec<LocalBlockData<Source>> = source_idxs
.map(|idx| source_pool_list[idx].clone())
.collect();
let mut targets: Vec<LocalBlockData<Target>> = target_idxs
.map(|idx| target_pool_list[idx].clone())
.collect();
match sources.write_to(&mut targets, self.context.clone()) {
Ok(channel) => Ok(channel),
Err(e) => {
tracing::error!("Failed to write to blocks: {:?}", e);
Err(e.into())
}
}
}
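
    /// Execute a transfer request, routing it through the batcher so that
    /// oversized requests are split into concurrent batches.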
pub async fn execute_transfer(&self, request: BlockTransferRequest) -> Result<()> {
self.batcher.execute_batched_transfer(self, request).await
}
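
    /// Execute a single, already batch-sized request by dispatching on its
    /// (source pool, target pool) pair.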
pub async fn execute_transfer_direct(&self, request: BlockTransferRequest) -> Result<()> {
tracing::debug!(
"Performing transfer of {} blocks from {:?} to {:?}",
request.blocks().len(),
request.from_pool(),
request.to_pool()
);
tracing::debug!("request: {request:#?}");
        let notify = match (request.from_pool(), request.to_pool()) {
            (BlockTransferPool::Device, BlockTransferPool::Host) => {
                self.begin_transfer(&self.device, &self.host, request).await
            }
            (BlockTransferPool::Device, BlockTransferPool::Disk) => {
                self.begin_transfer(&self.device, &self.disk, request).await
            }
            (BlockTransferPool::Host, BlockTransferPool::Device) => {
                self.begin_transfer(&self.host, &self.device, request).await
            }
            (BlockTransferPool::Host, BlockTransferPool::Disk) => {
                self.begin_transfer(&self.host, &self.disk, request).await
            }
            (BlockTransferPool::Disk, BlockTransferPool::Device) => {
                self.begin_transfer(&self.disk, &self.device, request).await
            }
            (from, to) => {
                return Err(anyhow::anyhow!(
                    "Unsupported transfer direction: {:?} -> {:?}",
                    from,
                    to
                ));
            }
        }?;
notify.await?;
Ok(())
}
}
#[async_trait]
impl Handler for BlockTransferHandler {
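    /// Handle an incoming block transfer request message.
    ///
    /// If the request carries a connector request, the transfer is first registered
    /// with the scheduler and its completion (or failure) is reported back through
    /// the returned handle; otherwise the transfer is executed directly.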
async fn handle(&self, mut message: MessageHandle) -> Result<()> {
if message.data.len() != 1 {
return Err(anyhow::anyhow!(
"Block transfer request must have exactly one data element"
));
}
let mut request: BlockTransferRequest = serde_json::from_slice(&message.data[0])?;
let result = if let Some(req) = request.connector_req.take() {
let operation_id = req.uuid;
tracing::debug!(
request_id = %req.request_id,
operation_id = %operation_id,
"scheduling transfer"
);
let client = self
.scheduler_client
.as_ref()
.expect("scheduler client is required")
.clone();
let handle = client.schedule_transfer(req).await?;
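            // Only the `Execute` decision is supported here.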
assert_eq!(handle.scheduler_decision(), SchedulingDecision::Execute);
match self.execute_transfer(request).await {
Ok(_) => {
handle.mark_complete(Ok(())).await;
Ok(())
}
Err(e) => {
handle.mark_complete(Err(anyhow::anyhow!("{}", e))).await;
Err(e)
}
}
} else {
self.execute_transfer(request).await
};
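        // Acknowledge the message regardless of the outcome; any transfer error is
        // still propagated to the caller.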
message.ack().await?;
result
}
}