use super::{
BlockChecksum, FillPattern, PhysicalLayout, StorageKind, compute_block_checksums,
fill_blocks, transfer_blocks,
};
use super::context::TransferContext;
use anyhow::{Result, anyhow};
use std::collections::HashMap;
/// Outcome of a [`RoundTripTest`] run.
///
/// Records the checksums captured at both ends of the round trip, the
/// `(source, destination)` block pairs that were compared, and every pair whose
/// checksums disagreed.
#[derive(Debug)]
pub struct RoundTripTestResult {
    /// Checksums of the filled source blocks, keyed by source block id.
    pub source_checksums: HashMap<usize, BlockChecksum>,
    /// Checksums of the destination blocks after the second transfer, keyed by
    /// destination block id.
    pub dest_checksums: HashMap<usize, BlockChecksum>,
    /// `(source_id, destination_id)` pairs, in the order the mappings were added.
    pub block_mapping: Vec<(usize, usize)>,
    /// `true` when no mismatches were recorded.
    pub success: bool,
    /// `(source_id, destination_id)` pairs whose checksums differed.
    pub mismatches: Vec<(usize, usize)>,
}
impl RoundTripTestResult {
    /// Whether every compared block pair produced matching checksums.
    pub fn is_success(&self) -> bool {
        self.success
    }

    /// Number of `(source, destination)` block pairs that were compared.
    pub fn num_blocks(&self) -> usize {
        self.block_mapping.len()
    }

    /// Builds a one-line (pass) or multi-line (fail) human-readable summary.
    ///
    /// On failure the summary lists the mismatched `(source, destination)` pairs.
    pub fn report(&self) -> String {
        let total = self.num_blocks();

        if !self.success {
            return format!(
                "Round-trip test FAILED: {}/{} blocks mismatched\nMismatches: {:?}",
                self.mismatches.len(),
                total,
                self.mismatches
            );
        }

        format!(
            "Round-trip test PASSED: {}/{} blocks verified successfully",
            total, total
        )
    }
}
/// A source -> intermediate -> destination transfer check.
///
/// The mapped source blocks are filled with a [`FillPattern`], copied to the
/// intermediate layout, copied on to the destination layout, and then the
/// source and destination checksums are compared pair by pair.
pub struct RoundTripTest {
    /// Layout whose blocks are filled and checksummed first; must be local.
    source: PhysicalLayout,
    /// Layout the data passes through between the two transfers.
    intermediate: PhysicalLayout,
    /// Layout whose blocks are checksummed at the end; must be local.
    destination: PhysicalLayout,
    /// `(source_id, intermediate_id, destination_id)` triples, in insertion order.
    block_mapping: Vec<(usize, usize, usize)>,
    /// Pattern used to fill the source blocks before the first transfer.
    fill_pattern: FillPattern,
}
impl RoundTripTest {
    /// Creates a round-trip test over three layouts with no block mappings and
    /// the [`FillPattern::Sequential`] fill pattern.
    ///
    /// Data travels `source -> intermediate -> destination`. `run` fills and
    /// checksums the source directly and checksums the destination directly;
    /// both are required to be local. The intermediate layout's locality is not
    /// checked here.
    ///
    /// # Errors
    ///
    /// Returns an error if `source` or `destination` is remote.
    pub fn new(
        source: PhysicalLayout,
        intermediate: PhysicalLayout,
        destination: PhysicalLayout,
    ) -> Result<Self> {
        if source.is_remote() {
            return Err(anyhow!("Source layout must be local"));
        }
        if destination.is_remote() {
            return Err(anyhow!("Destination layout must be local"));
        }
        Ok(Self {
            source,
            intermediate,
            destination,
            block_mapping: Vec::new(),
            fill_pattern: FillPattern::Sequential,
        })
    }

    /// Replaces the pattern used to fill the source blocks.
    #[must_use]
    pub fn with_fill_pattern(mut self, pattern: FillPattern) -> Self {
        self.fill_pattern = pattern;
        self
    }

    /// Appends one `(source, intermediate, destination)` block mapping.
    #[must_use]
    pub fn add_block_mapping(
        mut self,
        src_id: usize,
        intermediate_id: usize,
        dst_id: usize,
    ) -> Self {
        self.block_mapping.push((src_id, intermediate_id, dst_id));
        self
    }

    /// Appends every `(source, intermediate, destination)` triple in `mappings`.
    #[must_use]
    pub fn with_block_mappings(mut self, mappings: &[(usize, usize, usize)]) -> Self {
        self.block_mapping.extend_from_slice(mappings);
        self
    }

    /// Executes the round trip and compares source and destination checksums.
    ///
    /// Steps:
    /// 1. Fill the mapped source blocks with the configured pattern and checksum them.
    /// 2. Transfer source -> intermediate and await completion.
    /// 3. Transfer intermediate -> destination and await completion.
    /// 4. Checksum the destination blocks and compare each `(source, destination)` pair.
    ///
    /// Checksum mismatches are reported in the returned [`RoundTripTestResult`];
    /// they are not errors.
    ///
    /// # Errors
    ///
    /// Returns an error if no block mappings were added, if filling, checksumming
    /// or either transfer fails, or if the computed checksum maps lack an entry
    /// for a mapped block.
    pub async fn run(self, ctx: &TransferContext) -> Result<RoundTripTestResult> {
        if self.block_mapping.is_empty() {
            return Err(anyhow!("No block mappings specified"));
        }

        // Split the triples into per-layout id lists once. Each list is shared by
        // the transfer that writes those blocks and the step that later reads them.
        let src_ids: Vec<usize> = self.block_mapping.iter().map(|&(src, _, _)| src).collect();
        let inter_ids: Vec<usize> = self
            .block_mapping
            .iter()
            .map(|&(_, inter, _)| inter)
            .collect();
        let dst_ids: Vec<usize> = self.block_mapping.iter().map(|&(_, _, dst)| dst).collect();

        fill_blocks(&self.source, &src_ids, self.fill_pattern)?;
        let source_checksums = compute_block_checksums(&self.source, &src_ids)?;

        // Each transfer yields a notification that must be awaited before the
        // next hop reads the data it wrote.
        let notification = transfer_blocks(
            &self.source,
            &self.intermediate,
            &src_ids,
            &inter_ids,
            ctx,
        )?;
        notification.await?;

        let notification = transfer_blocks(
            &self.intermediate,
            &self.destination,
            &inter_ids,
            &dst_ids,
            ctx,
        )?;
        notification.await?;

        let dest_checksums = compute_block_checksums(&self.destination, &dst_ids)?;

        // Use checked lookups so a missing checksum surfaces as an error instead
        // of a panic from `HashMap` indexing.
        let mut mismatches = Vec::new();
        for &(src_id, _, dst_id) in &self.block_mapping {
            let src_checksum = source_checksums.get(&src_id).ok_or_else(|| {
                anyhow!("No checksum computed for source block {}", src_id)
            })?;
            let dst_checksum = dest_checksums.get(&dst_id).ok_or_else(|| {
                anyhow!("No checksum computed for destination block {}", dst_id)
            })?;
            if src_checksum != dst_checksum {
                mismatches.push((src_id, dst_id));
            }
        }

        let success = mismatches.is_empty();
        let block_mapping: Vec<(usize, usize)> = self
            .block_mapping
            .iter()
            .map(|&(src, _, dst)| (src, dst))
            .collect();

        Ok(RoundTripTestResult {
            source_checksums,
            dest_checksums,
            block_mapping,
            success,
            mismatches,
        })
    }
}
// Build this module only under `cargo test` with the `testing-cuda` feature.
// `cfg` accepts a single predicate, so the two conditions are combined with
// `all(..)`. The feature key is spelled `feature` (singular).
#[cfg(all(test, feature = "testing-cuda"))]
mod tests {
    use super::*;
    use crate::block_manager::v2::layout::{
        FullyContiguousLayout, Layout, LayoutConfig, MemoryRegion, OwnedMemoryRegion,
    };
    use std::sync::Arc;

    // NOTE(review): `create_test_layout` is not defined in this module or in the
    // imports above. From the call sites it presumably returns
    // `(layout, backing_memory)`; confirm where it is meant to come from.
    //
    // The backing memory is always bound to a named `_..._mem` variable rather
    // than `_`. A bare `_` in a `let` destructuring drops that element at the end
    // of the statement, freeing the memory while the layout still refers to it.

    /// Placeholder for a real transfer context; panics via `todo!` until wired up.
    fn create_test_context() -> TransferContext {
        todo!("Create test context - requires CUDA/NIXL setup")
    }

    #[tokio::test]
    async fn test_round_trip_host_to_host() {
        let (src_layout, _src_mem) = create_test_layout(4);
        let (inter_layout, _inter_mem) = create_test_layout(4);
        let (dst_layout, _dst_mem) = create_test_layout(4);
        let source = PhysicalLayout::new_local(src_layout, StorageKind::System);
        let intermediate = PhysicalLayout::new_local(inter_layout, StorageKind::Pinned);
        let destination = PhysicalLayout::new_local(dst_layout, StorageKind::System);
        let test = RoundTripTest::new(source, intermediate, destination)
            .unwrap()
            .with_fill_pattern(FillPattern::Sequential)
            .add_block_mapping(0, 0, 0)
            .add_block_mapping(1, 1, 1)
            .add_block_mapping(2, 2, 2)
            .add_block_mapping(3, 3, 3);
        let ctx = create_test_context();
        let result = test.run(&ctx).await.unwrap();
        assert!(result.is_success(), "{}", result.report());
        assert_eq!(result.num_blocks(), 4);
    }

    #[tokio::test]
    async fn test_round_trip_different_block_ids() {
        let (src_layout, _src_mem) = create_test_layout(8);
        let (inter_layout, _inter_mem) = create_test_layout(8);
        let (dst_layout, _dst_mem) = create_test_layout(8);
        let source = PhysicalLayout::new_local(src_layout, StorageKind::System);
        let intermediate = PhysicalLayout::new_local(inter_layout, StorageKind::Pinned);
        let destination = PhysicalLayout::new_local(dst_layout, StorageKind::System);
        let test = RoundTripTest::new(source, intermediate, destination)
            .unwrap()
            .with_fill_pattern(FillPattern::BlockBased)
            .with_block_mappings(&[(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)]);
        let ctx = create_test_context();
        let result = test.run(&ctx).await.unwrap();
        assert!(result.is_success(), "{}", result.report());
        assert_eq!(result.num_blocks(), 4);
    }

    #[test]
    fn test_round_trip_builder() {
        let (src_layout, _src_mem) = create_test_layout(4);
        let (inter_layout, _inter_mem) = create_test_layout(4);
        let (dst_layout, _dst_mem) = create_test_layout(4);
        let source = PhysicalLayout::new_local(src_layout, StorageKind::System);
        let intermediate = PhysicalLayout::new_local(inter_layout, StorageKind::Pinned);
        let destination = PhysicalLayout::new_local(dst_layout, StorageKind::System);
        let test = RoundTripTest::new(source, intermediate, destination)
            .unwrap()
            .with_fill_pattern(FillPattern::Constant(42))
            .add_block_mapping(0, 0, 1)
            .add_block_mapping(1, 1, 2);
        assert_eq!(test.block_mapping.len(), 2);
    }

    #[test]
    fn test_round_trip_requires_local_source() {
        let (src_layout, _src_mem) = create_test_layout(1);
        let (inter_layout, _inter_mem) = create_test_layout(1);
        let (dst_layout, _dst_mem) = create_test_layout(1);
        let source =
            PhysicalLayout::new_remote(src_layout, StorageKind::System, "remote".to_string());
        let intermediate = PhysicalLayout::new_local(inter_layout, StorageKind::Pinned);
        let destination = PhysicalLayout::new_local(dst_layout, StorageKind::System);
        let result = RoundTripTest::new(source, intermediate, destination);
        assert!(result.is_err());
    }

    #[test]
    fn test_round_trip_requires_local_destination() {
        let (src_layout, _src_mem) = create_test_layout(1);
        let (inter_layout, _inter_mem) = create_test_layout(1);
        let (dst_layout, _dst_mem) = create_test_layout(1);
        let source = PhysicalLayout::new_local(src_layout, StorageKind::System);
        let intermediate = PhysicalLayout::new_local(inter_layout, StorageKind::Pinned);
        let destination =
            PhysicalLayout::new_remote(dst_layout, StorageKind::System, "remote".to_string());
        let result = RoundTripTest::new(source, intermediate, destination);
        assert!(result.is_err());
    }
}