use crate::device::{Device, DeviceType};
use std::marker::PhantomData;
/// Compile-time marker trait for a device backend.
///
/// Implementors are zero-sized tag types (e.g. [`PhantomCpu`],
/// [`PhantomCuda`]) that pin a runtime [`DeviceType`] into the type system,
/// letting device mismatches be caught by the compiler instead of at runtime.
pub trait PhantomDevice: 'static + std::fmt::Debug + Send + Sync {
    /// The concrete runtime device type this marker stands for.
    const DEVICE_TYPE: DeviceType;

    /// Returns [`Self::DEVICE_TYPE`] as a value.
    fn device_type() -> DeviceType {
        Self::DEVICE_TYPE
    }

    /// True when `Self` and `Other` denote the exact same device
    /// (same backend *and* same index).
    fn is_compatible<Other: PhantomDevice>() -> bool {
        Self::DEVICE_TYPE == Other::DEVICE_TYPE
    }

    /// Human-readable backend name, e.g. `"CPU"` or `"CUDA"`.
    fn device_name() -> &'static str;

    /// True for every backend other than the CPU.
    fn requires_gpu() -> bool {
        !matches!(Self::DEVICE_TYPE, DeviceType::Cpu)
    }

    /// True when the backend supports peer-to-peer transfers
    /// (currently only CUDA).
    fn supports_p2p() -> bool {
        matches!(Self::DEVICE_TYPE, DeviceType::Cuda(_))
    }
}
/// Zero-sized marker for the CPU device.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PhantomCpu;
impl PhantomDevice for PhantomCpu {
    const DEVICE_TYPE: DeviceType = DeviceType::Cpu;
    fn device_name() -> &'static str {
        "CPU"
    }
}
/// Zero-sized marker for the CUDA device with compile-time index `INDEX`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PhantomCuda<const INDEX: usize>;
impl<const INDEX: usize> PhantomDevice for PhantomCuda<INDEX> {
    const DEVICE_TYPE: DeviceType = DeviceType::Cuda(INDEX);
    fn device_name() -> &'static str {
        "CUDA"
    }
}
/// Zero-sized marker for the Metal device with compile-time index `INDEX`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PhantomMetal<const INDEX: usize>;
impl<const INDEX: usize> PhantomDevice for PhantomMetal<INDEX> {
    const DEVICE_TYPE: DeviceType = DeviceType::Metal(INDEX);
    fn device_name() -> &'static str {
        "Metal"
    }
}
/// Zero-sized marker for the WebGPU device with compile-time index `INDEX`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PhantomWgpu<const INDEX: usize>;
impl<const INDEX: usize> PhantomDevice for PhantomWgpu<INDEX> {
    const DEVICE_TYPE: DeviceType = DeviceType::Wgpu(INDEX);
    fn device_name() -> &'static str {
        "WebGPU"
    }
}
/// Device handle whose compile-time phantom type `P` is guaranteed by
/// [`DeviceHandle::new`] to match the runtime type of the boxed device.
#[derive(Debug)]
pub struct DeviceHandle<P: PhantomDevice> {
    // Runtime device object; its `device_type()` equals `P::DEVICE_TYPE`
    // unless the handle was created via `new_unchecked`.
    device: Box<dyn Device>,
    // Zero-sized marker tying the handle to the phantom device type.
    _phantom: PhantomData<P>,
}
impl<P: PhantomDevice> DeviceHandle<P> {
    /// Wraps `device`, verifying that its runtime type matches `P`.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` when the runtime device type
    /// differs from `P::DEVICE_TYPE`.
    pub fn new(device: Box<dyn Device>) -> Result<Self, crate::error::TorshError> {
        if device.device_type() != P::DEVICE_TYPE {
            return Err(crate::error::TorshError::InvalidArgument(format!(
                "Device type mismatch: expected {:?}, got {:?}",
                P::DEVICE_TYPE,
                device.device_type()
            )));
        }
        Ok(Self {
            device,
            _phantom: PhantomData,
        })
    }
    /// Wraps `device` without checking its runtime type.
    ///
    /// # Safety
    /// The caller must guarantee `device.device_type() == P::DEVICE_TYPE`;
    /// otherwise every phantom-typed query on this handle reports a type
    /// that disagrees with the underlying device.
    pub unsafe fn new_unchecked(device: Box<dyn Device>) -> Self {
        Self {
            device,
            _phantom: PhantomData,
        }
    }
    /// Shared access to the underlying device.
    pub fn device(&self) -> &dyn Device {
        self.device.as_ref()
    }
    /// Exclusive access to the underlying device.
    pub fn device_mut(&mut self) -> &mut dyn Device {
        self.device.as_mut()
    }
    /// The device type implied by the phantom parameter `P`.
    pub fn phantom_device_type() -> DeviceType {
        P::DEVICE_TYPE
    }
    /// True when `P` is the CPU device.
    pub const fn is_cpu() -> bool {
        matches!(P::DEVICE_TYPE, DeviceType::Cpu)
    }
    /// True for every phantom device type other than CPU.
    pub const fn is_gpu() -> bool {
        !matches!(P::DEVICE_TYPE, DeviceType::Cpu)
    }
    /// True when `P` is a CUDA device (any index).
    pub const fn is_cuda() -> bool {
        matches!(P::DEVICE_TYPE, DeviceType::Cuda(_))
    }
    /// True when `P` is a Metal device (any index).
    pub const fn is_metal() -> bool {
        matches!(P::DEVICE_TYPE, DeviceType::Metal(_))
    }
    /// True when `P` is a WebGPU device (any index).
    pub const fn is_wgpu() -> bool {
        matches!(P::DEVICE_TYPE, DeviceType::Wgpu(_))
    }
    /// Re-types the handle as `DeviceHandle<Q>` when the underlying
    /// device's runtime type matches `Q::DEVICE_TYPE`.
    ///
    /// # Errors
    /// On mismatch, returns the unchanged handle together with an
    /// `InvalidArgument` error so the caller keeps ownership.
    pub fn cast<Q: PhantomDevice>(
        self,
    ) -> Result<DeviceHandle<Q>, (Self, crate::error::TorshError)> {
        // Compare against the *runtime* device type. Reporting it in the
        // error (rather than `P::DEVICE_TYPE`) stays accurate even when the
        // handle was built with `new_unchecked` and `P` disagrees with the
        // actual device.
        let actual = self.device.device_type();
        if actual != Q::DEVICE_TYPE {
            let error = crate::error::TorshError::InvalidArgument(format!(
                "Cannot cast device from {:?} to {:?}",
                actual,
                Q::DEVICE_TYPE
            ));
            return Err((self, error));
        }
        Ok(DeviceHandle {
            device: self.device,
            _phantom: PhantomData,
        })
    }
    /// Re-types the handle without checking the runtime device type.
    ///
    /// # Safety
    /// The caller must guarantee the underlying device's runtime type is
    /// `Q::DEVICE_TYPE`; see [`Self::new_unchecked`].
    pub unsafe fn cast_unchecked<Q: PhantomDevice>(self) -> DeviceHandle<Q> {
        DeviceHandle {
            device: self.device,
            _phantom: PhantomData,
        }
    }
    /// Consumes the handle and returns the boxed device.
    pub fn into_device(self) -> Box<dyn Device> {
        self.device
    }
}
impl<P: PhantomDevice> Clone for DeviceHandle<P> {
    /// Clones the handle by cloning the underlying device via
    /// `Device::clone_device`.
    ///
    /// # Panics
    /// Panics when `clone_device` fails — `Clone` has no way to report
    /// errors, so the failure is surfaced via `expect`.
    fn clone(&self) -> Self {
        let cloned_device = self.device.clone_device().expect("Failed to clone device");
        Self {
            device: cloned_device,
            _phantom: PhantomData,
        }
    }
}
/// Type-level compatibility relation between two phantom device types.
pub trait DeviceCompatible<Other> {
    /// Whether `Self` and `Other` are compatible.
    const COMPATIBLE: bool;
    /// Short human-readable explanation of the verdict.
    fn compatibility_info() -> &'static str;
}
/// Every phantom device type is trivially compatible with itself.
impl<P: PhantomDevice> DeviceCompatible<P> for P {
    const COMPATIBLE: bool = true;
    fn compatibility_info() -> &'static str {
        "Same device type - always compatible"
    }
}
/// An operation that can run on a device tagged with phantom type `P`.
pub trait DeviceOperation<P: PhantomDevice> {
    /// Result produced by a successful execution.
    type Output;
    /// Static requirements the target device must satisfy.
    type Requirements: DeviceRequirements;
    /// Runs the operation against the given handle.
    fn execute(device: &DeviceHandle<P>) -> Result<Self::Output, crate::error::TorshError>;
    /// Whether the operation is statically supported, derived from its
    /// requirements marker.
    const SUPPORTED: bool = Self::Requirements::SATISFIED_BY_DEVICE;
}
/// Static description of what an operation needs from a device.
pub trait DeviceRequirements {
    /// Whether the requirement is considered satisfied.
    const SATISFIED_BY_DEVICE: bool;
    /// Human-readable description of the requirement.
    fn description() -> &'static str;
}
// NOTE(review): `SATISFIED_BY_DEVICE` is a fixed constant on each marker —
// it never consults an actual phantom device type. As written, operations
// whose `Requirements` are `RequiresGpu`/`RequiresCpu`/`RequiresCuda` always
// report `SUPPORTED = false`. Confirm whether device-specific impls exist
// elsewhere in the crate or whether this is intentional.
/// Requirement marker: the operation needs a GPU-backed device.
#[derive(Debug, Clone, Copy)]
pub struct RequiresGpu;
impl DeviceRequirements for RequiresGpu {
    const SATISFIED_BY_DEVICE: bool = false;
    fn description() -> &'static str {
        "Requires GPU device"
    }
}
/// Requirement marker: the operation needs the CPU device.
#[derive(Debug, Clone, Copy)]
pub struct RequiresCpu;
impl DeviceRequirements for RequiresCpu {
    const SATISFIED_BY_DEVICE: bool = false;
    fn description() -> &'static str {
        "Requires CPU device"
    }
}
/// Requirement marker: the operation needs a CUDA device.
#[derive(Debug, Clone, Copy)]
pub struct RequiresCuda;
impl DeviceRequirements for RequiresCuda {
    const SATISFIED_BY_DEVICE: bool = false;
    fn description() -> &'static str {
        "Requires CUDA device"
    }
}
/// Requirement marker: the operation runs on any device.
#[derive(Debug, Clone, Copy)]
pub struct NoRequirements;
impl DeviceRequirements for NoRequirements {
    const SATISFIED_BY_DEVICE: bool = true;
    fn description() -> &'static str {
        "No specific device requirements"
    }
}
/// Type-level predicate: do `P1` and `P2` name the same physical device?
#[derive(Debug)]
pub struct SameDevice<P1: PhantomDevice, P2: PhantomDevice> {
    _phantom: PhantomData<(P1, P2)>,
}
impl<P1: PhantomDevice, P2: PhantomDevice> SameDevice<P1, P2> {
    /// True only when both phantom types resolve to the same backend *and*
    /// the same device index.
    pub fn is_satisfied() -> bool {
        // NOTE(review): the catch-all arm means any `DeviceType` variant not
        // listed here compares as "different" even against itself — revisit
        // if new variants are ever added to `DeviceType`.
        match (P1::DEVICE_TYPE, P2::DEVICE_TYPE) {
            (DeviceType::Cpu, DeviceType::Cpu) => true,
            (DeviceType::Cuda(a), DeviceType::Cuda(b)) => a == b,
            (DeviceType::Metal(a), DeviceType::Metal(b)) => a == b,
            (DeviceType::Wgpu(a), DeviceType::Wgpu(b)) => a == b,
            _ => false,
        }
    }
}
#[derive(Debug)]
pub struct TransferCompatible<P1: PhantomDevice, P2: PhantomDevice> {
_phantom: PhantomData<(P1, P2)>,
}
impl<P1: PhantomDevice, P2: PhantomDevice> TransferCompatible<P1, P2> {
pub const SUPPORTED: bool = true;
pub fn transfer_cost() -> u32 {
match (P1::DEVICE_TYPE, P2::DEVICE_TYPE) {
(DeviceType::Cpu, DeviceType::Cpu) => 0,
(DeviceType::Cuda(a), DeviceType::Cuda(b)) if a == b => 0,
(DeviceType::Metal(a), DeviceType::Metal(b)) if a == b => 0,
(DeviceType::Wgpu(a), DeviceType::Wgpu(b)) if a == b => 0,
(DeviceType::Cpu, DeviceType::Cuda(_)) => 100,
(DeviceType::Cuda(_), DeviceType::Cpu) => 100,
(DeviceType::Cpu, DeviceType::Metal(_)) => 80,
(DeviceType::Metal(_), DeviceType::Cpu) => 80,
_ => 200, }
}
}
/// Homogeneous collection of [`DeviceHandle`]s that all share the phantom
/// device type `P`.
#[derive(Debug)]
pub struct PhantomDeviceManager<P: PhantomDevice> {
    // `P` is already anchored by this field's element type, so no separate
    // `PhantomData<P>` marker field is needed.
    handles: Vec<DeviceHandle<P>>,
}
impl<P: PhantomDevice> PhantomDeviceManager<P> {
    /// Creates an empty manager.
    pub fn new() -> Self {
        Self {
            handles: Vec::new(),
        }
    }
    /// Registers a handle. Only handles of the matching phantom type can be
    /// added — enforced by the type system.
    pub fn add_device(&mut self, handle: DeviceHandle<P>) {
        self.handles.push(handle);
    }
    /// Number of registered handles.
    pub fn device_count(&self) -> usize {
        self.handles.len()
    }
    /// Borrow of the handle at `index`, or `None` when out of bounds.
    pub fn get_device(&self, index: usize) -> Option<&DeviceHandle<P>> {
        self.handles.get(index)
    }
    /// Mutable borrow of the handle at `index`, or `None` when out of bounds.
    pub fn get_device_mut(&mut self, index: usize) -> Option<&mut DeviceHandle<P>> {
        self.handles.get_mut(index)
    }
    /// Removes and returns the handle at `index` (shifting later handles
    /// left), or `None` when out of bounds.
    pub fn remove_device(&mut self, index: usize) -> Option<DeviceHandle<P>> {
        if index < self.handles.len() {
            Some(self.handles.remove(index))
        } else {
            None
        }
    }
    /// All registered handles, in insertion order.
    pub fn devices(&self) -> &[DeviceHandle<P>] {
        &self.handles
    }
    /// Drops every registered handle.
    pub fn clear(&mut self) {
        self.handles.clear();
    }
    /// Runs the operation `Op` once per registered device and collects the
    /// per-device results, in registration order.
    ///
    /// The `_operation` value itself is unused — `Op::execute` is an
    /// associated function — but the parameter is kept for backward
    /// compatibility and to drive type inference at call sites. The former
    /// `Op: Clone` bound was never used and has been dropped (a pure
    /// relaxation, so existing callers are unaffected).
    pub fn execute_on_all<Op>(
        &self,
        _operation: Op,
    ) -> Vec<Result<Op::Output, crate::error::TorshError>>
    where
        Op: DeviceOperation<P>,
    {
        self.handles
            .iter()
            .map(|handle| Op::execute(handle))
            .collect()
    }
}
impl<P: PhantomDevice> Default for PhantomDeviceManager<P> {
    /// Equivalent to [`PhantomDeviceManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Free-function conveniences over the phantom-typed device API.
pub mod utils {
    use super::*;
    /// Wraps `device` in a phantom-typed handle; see [`DeviceHandle::new`]
    /// for the error conditions.
    pub fn create_phantom_handle<P: PhantomDevice>(
        device: Box<dyn Device>,
    ) -> Result<DeviceHandle<P>, crate::error::TorshError> {
        DeviceHandle::<P>::new(device)
    }
    /// True when both phantom types denote the exact same device.
    pub fn check_phantom_compatibility<P1: PhantomDevice, P2: PhantomDevice>() -> bool {
        P1::DEVICE_TYPE == P2::DEVICE_TYPE
    }
    /// Relative cost of moving data from `P1` to `P2`; see
    /// [`TransferCompatible::transfer_cost`].
    pub fn phantom_transfer_cost<P1: PhantomDevice, P2: PhantomDevice>() -> u32 {
        TransferCompatible::<P1, P2>::transfer_cost()
    }
    /// Creates an empty [`PhantomDeviceManager`] for device type `P`.
    pub fn create_phantom_manager<P: PhantomDevice>() -> PhantomDeviceManager<P> {
        PhantomDeviceManager::new()
    }
    /// Whether operation `Op` is statically supported on device `P`.
    pub fn verify_operation_support<P: PhantomDevice, Op: DeviceOperation<P>>() -> bool {
        Op::SUPPORTED
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::core::Device;
    use std::any::Any;
    /// Minimal `Device` implementation that only records its declared type;
    /// all other trait methods are no-ops or canned successes.
    #[derive(Debug)]
    struct MockDevice {
        device_type: DeviceType,
    }
    impl MockDevice {
        fn new(device_type: DeviceType) -> Self {
            Self { device_type }
        }
    }
    impl Device for MockDevice {
        fn device_type(&self) -> DeviceType {
            self.device_type
        }
        fn name(&self) -> &str {
            "Mock Device"
        }
        fn is_available(&self) -> Result<bool, crate::error::TorshError> {
            Ok(true)
        }
        fn capabilities(
            &self,
        ) -> Result<crate::device::DeviceCapabilities, crate::error::TorshError> {
            crate::device::DeviceCapabilities::detect(self.device_type)
        }
        fn synchronize(&self) -> Result<(), crate::error::TorshError> {
            Ok(())
        }
        fn reset(&self) -> Result<(), crate::error::TorshError> {
            Ok(())
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn as_any_mut(&mut self) -> &mut dyn Any {
            self
        }
        fn clone_device(&self) -> Result<Box<dyn Device>, crate::error::TorshError> {
            Ok(Box::new(MockDevice::new(self.device_type)))
        }
    }
    /// Each marker type reports the expected `DeviceType` and name.
    #[test]
    fn test_phantom_device_markers() {
        assert_eq!(PhantomCpu::device_type(), DeviceType::Cpu);
        assert_eq!(PhantomCuda::<0>::device_type(), DeviceType::Cuda(0));
        assert_eq!(PhantomMetal::<1>::device_type(), DeviceType::Metal(1));
        assert_eq!(PhantomWgpu::<2>::device_type(), DeviceType::Wgpu(2));
        assert_eq!(PhantomCpu::device_name(), "CPU");
        assert_eq!(PhantomCuda::<0>::device_name(), "CUDA");
        assert_eq!(PhantomMetal::<0>::device_name(), "Metal");
        assert_eq!(PhantomWgpu::<0>::device_name(), "WebGPU");
    }
    /// `requires_gpu` is true for all non-CPU backends; `supports_p2p`
    /// only for CUDA.
    #[test]
    fn test_phantom_device_properties() {
        assert!(!PhantomCpu::requires_gpu());
        assert!(PhantomCuda::<0>::requires_gpu());
        assert!(PhantomMetal::<0>::requires_gpu());
        assert!(PhantomWgpu::<0>::requires_gpu());
        assert!(!PhantomCpu::supports_p2p());
        assert!(PhantomCuda::<0>::supports_p2p());
        assert!(!PhantomMetal::<0>::supports_p2p());
        assert!(!PhantomWgpu::<0>::supports_p2p());
    }
    /// A CPU-typed handle wraps a CPU device and classifies itself correctly.
    #[test]
    fn test_device_handle() {
        let mock_device = Box::new(MockDevice::new(DeviceType::Cpu));
        let _handle =
            DeviceHandle::<PhantomCpu>::new(mock_device).expect("DeviceHandle::new should succeed");
        assert_eq!(
            DeviceHandle::<PhantomCpu>::phantom_device_type(),
            DeviceType::Cpu
        );
        assert!(DeviceHandle::<PhantomCpu>::is_cpu());
        assert!(!DeviceHandle::<PhantomCpu>::is_gpu());
        assert!(!DeviceHandle::<PhantomCpu>::is_cuda());
    }
    /// Constructing a CPU-typed handle around a CUDA device must fail.
    #[test]
    fn test_device_handle_type_mismatch() {
        let mock_device = Box::new(MockDevice::new(DeviceType::Cuda(0)));
        let result = DeviceHandle::<PhantomCpu>::new(mock_device);
        assert!(result.is_err());
    }
    /// Compatibility requires identical backend *and* index.
    #[test]
    fn test_device_compatibility() {
        assert!(PhantomCpu::is_compatible::<PhantomCpu>());
        assert!(!PhantomCpu::is_compatible::<PhantomCuda<0>>());
        assert!(PhantomCuda::<0>::is_compatible::<PhantomCuda<0>>());
        assert!(!PhantomCuda::<0>::is_compatible::<PhantomCuda<1>>());
    }
    /// Add/get/remove round-trip on the manager.
    #[test]
    fn test_phantom_device_manager() {
        let mut manager = PhantomDeviceManager::<PhantomCpu>::new();
        assert_eq!(manager.device_count(), 0);
        let mock_device = Box::new(MockDevice::new(DeviceType::Cpu));
        let handle =
            DeviceHandle::<PhantomCpu>::new(mock_device).expect("DeviceHandle::new should succeed");
        manager.add_device(handle);
        assert_eq!(manager.device_count(), 1);
        assert!(manager.get_device(0).is_some());
        assert!(manager.get_device(1).is_none());
        let removed = manager.remove_device(0);
        assert!(removed.is_some());
        assert_eq!(manager.device_count(), 0);
    }
    /// Same-device transfers cost 0; CPU<->GPU costs match the table in
    /// `TransferCompatible::transfer_cost`.
    #[test]
    fn test_transfer_cost_constants() {
        assert_eq!(
            TransferCompatible::<PhantomCpu, PhantomCpu>::transfer_cost(),
            0
        );
        assert_eq!(
            TransferCompatible::<PhantomCpu, PhantomCuda<0>>::transfer_cost(),
            100
        );
        assert_eq!(
            TransferCompatible::<PhantomCpu, PhantomMetal<0>>::transfer_cost(),
            80
        );
    }
    /// The `utils` wrappers delegate to the underlying implementations.
    #[test]
    fn test_utils_functions() {
        assert!(utils::check_phantom_compatibility::<PhantomCpu, PhantomCpu>());
        assert!(!utils::check_phantom_compatibility::<
            PhantomCpu,
            PhantomCuda<0>,
        >());
        let cost = utils::phantom_transfer_cost::<PhantomCpu, PhantomCuda<0>>();
        assert_eq!(cost, 100);
        let manager = utils::create_phantom_manager::<PhantomCpu>();
        assert_eq!(manager.device_count(), 0);
    }
}
/// Fixed-size group of `N` handles that all share phantom device type `P`.
#[derive(Debug)]
pub struct DeviceGroup<P: PhantomDevice, const N: usize> {
    devices: [DeviceHandle<P>; N],
}
impl<P: PhantomDevice, const N: usize> DeviceGroup<P, N> {
    /// Builds a group from exactly `N` handles.
    pub fn new(devices: [DeviceHandle<P>; N]) -> Self {
        Self { devices }
    }
    /// Group size, known at compile time.
    pub const fn device_count() -> usize {
        N
    }
    /// Handle at `index`, or `None` when out of bounds.
    pub fn get(&self, index: usize) -> Option<&DeviceHandle<P>> {
        self.devices.get(index)
    }
    /// Mutable handle at `index`, or `None` when out of bounds.
    pub fn get_mut(&mut self, index: usize) -> Option<&mut DeviceHandle<P>> {
        self.devices.get_mut(index)
    }
    /// Iterates over the handles.
    pub fn iter(&self) -> impl Iterator<Item = &DeviceHandle<P>> {
        self.devices.iter()
    }
    /// Iterates mutably over the handles.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut DeviceHandle<P>> {
        self.devices.iter_mut()
    }
    /// Applies `f` to every handle and collects the results in order.
    ///
    /// NOTE(review): despite the name and the `Sync`/`Send` bounds, this
    /// currently runs *sequentially* on the calling thread — the bounds
    /// appear to anticipate a parallel implementation that does not exist
    /// yet.
    pub fn parallel_execute<F, R>(&self, f: F) -> Vec<R>
    where
        F: Fn(&DeviceHandle<P>) -> R + Sync,
        R: Send,
    {
        self.devices.iter().map(f).collect()
    }
    /// Whether the group's device type supports peer-to-peer transfers.
    pub fn supports_p2p() -> bool {
        P::supports_p2p()
    }
    /// Diagnostic name, e.g. `"DeviceGroup<CUDA, 4>"`.
    pub fn group_type_name() -> String {
        format!("DeviceGroup<{}, {}>", P::device_name(), N)
    }
}
/// Peer-to-peer transfer characteristics between `Self` and `Other`.
pub trait PeerToPeerOps<Other: PhantomDevice>: PhantomDevice {
    /// Whether a direct P2P path exists between the two devices.
    const P2P_SUPPORTED: bool;
    /// Link bandwidth; units are not stated in this file (the magnitudes
    /// suggest MB/s — confirm against the rest of the crate).
    fn p2p_bandwidth() -> u32;
    /// Link latency; units are not stated in this file (presumably
    /// microseconds — confirm).
    fn p2p_latency() -> u32;
}
/// CUDA-to-CUDA: P2P is always considered available between any two
/// CUDA indices.
impl<const I1: usize, const I2: usize> PeerToPeerOps<PhantomCuda<I2>> for PhantomCuda<I1> {
    const P2P_SUPPORTED: bool = true;
    fn p2p_bandwidth() -> u32 {
        // Same index: both values are 0 — presumably because no transfer
        // occurs at all for the same device (verify intent).
        if I1 == I2 {
            0
        } else {
            50_000
        }
    }
    fn p2p_latency() -> u32 {
        if I1 == I2 {
            0
        } else {
            5
        }
    }
}
/// Metal-to-Metal: "P2P" only applies within the same device index.
impl<const I1: usize, const I2: usize> PeerToPeerOps<PhantomMetal<I2>> for PhantomMetal<I1> {
    const P2P_SUPPORTED: bool = I1 == I2;
    fn p2p_bandwidth() -> u32 {
        if I1 == I2 {
            400_000
        } else {
            0
        }
    }
    fn p2p_latency() -> u32 {
        if I1 == I2 {
            1
        } else {
            0
        }
    }
}
/// Compile-time description of a multi-device communication topology.
pub trait DeviceTopology {
    /// Number of devices participating in the topology.
    const DEVICE_COUNT: usize;
    /// Whether all-reduce collectives are supported.
    const SUPPORTS_ALLREDUCE: bool;
    /// Whether broadcast collectives are supported.
    const SUPPORTS_BROADCAST: bool;
    /// Human-readable topology name.
    fn topology_name() -> &'static str;
    /// Effective all-reduce bandwidth; units are not stated in this file
    /// (magnitudes suggest MB/s — confirm against the rest of the crate).
    fn allreduce_bandwidth() -> u32;
}
/// Ring topology over `N` devices of type `P`.
#[derive(Debug)]
pub struct RingTopology<P: PhantomDevice, const N: usize> {
    _phantom: PhantomData<P>,
}
impl<P: PhantomDevice, const N: usize> DeviceTopology for RingTopology<P, N> {
    const DEVICE_COUNT: usize = N;
    const SUPPORTS_ALLREDUCE: bool = true;
    const SUPPORTS_BROADCAST: bool = true;
    fn topology_name() -> &'static str {
        "Ring"
    }
    fn allreduce_bandwidth() -> u32 {
        // P2P-capable devices (CUDA) get the higher figure.
        if P::supports_p2p() {
            25_000
        } else {
            5_000
        }
    }
}
/// Tree topology over `N` devices of type `P`.
#[derive(Debug)]
pub struct TreeTopology<P: PhantomDevice, const N: usize> {
    _phantom: PhantomData<P>,
}
impl<P: PhantomDevice, const N: usize> DeviceTopology for TreeTopology<P, N> {
    const DEVICE_COUNT: usize = N;
    const SUPPORTS_ALLREDUCE: bool = true;
    const SUPPORTS_BROADCAST: bool = true;
    fn topology_name() -> &'static str {
        "Tree"
    }
    fn allreduce_bandwidth() -> u32 {
        if P::supports_p2p() {
            40_000
        } else {
            8_000
        }
    }
}
/// Fully-connected (all-to-all) topology over `N` devices of type `P`.
#[derive(Debug)]
pub struct AllToAllTopology<P: PhantomDevice, const N: usize> {
    _phantom: PhantomData<P>,
}
impl<P: PhantomDevice, const N: usize> DeviceTopology for AllToAllTopology<P, N> {
    const DEVICE_COUNT: usize = N;
    const SUPPORTS_ALLREDUCE: bool = true;
    const SUPPORTS_BROADCAST: bool = true;
    fn topology_name() -> &'static str {
        "AllToAll"
    }
    fn allreduce_bandwidth() -> u32 {
        if P::supports_p2p() {
            100_000
        } else {
            15_000
        }
    }
}
/// A phantom-typed device handle together with optional placement hints
/// (preferred NUMA node and CPU affinity set).
#[derive(Debug)]
pub struct TypedDeviceAffinity<P: PhantomDevice> {
    device_handle: DeviceHandle<P>,
    preferred_numa_node: Option<usize>,
    cpu_affinity: Option<Vec<usize>>,
}
impl<P: PhantomDevice> TypedDeviceAffinity<P> {
    /// Wraps `device_handle` with no placement hints set.
    pub fn new(device_handle: DeviceHandle<P>) -> Self {
        Self {
            device_handle,
            preferred_numa_node: None,
            cpu_affinity: None,
        }
    }
    /// Builder step: records `node` as the preferred NUMA node.
    pub fn with_numa_node(self, node: usize) -> Self {
        Self {
            preferred_numa_node: Some(node),
            ..self
        }
    }
    /// Builder step: records `cpus` as the CPU affinity set.
    pub fn with_cpu_affinity(self, cpus: Vec<usize>) -> Self {
        Self {
            cpu_affinity: Some(cpus),
            ..self
        }
    }
    /// The wrapped device handle.
    pub fn device(&self) -> &DeviceHandle<P> {
        &self.device_handle
    }
    /// The preferred NUMA node, if one was set.
    pub fn numa_node(&self) -> Option<usize> {
        self.preferred_numa_node
    }
    /// The CPU affinity set, if one was set.
    pub fn cpu_affinity(&self) -> Option<&[usize]> {
        self.cpu_affinity.as_deref()
    }
    /// True when the phantom type is the CPU device.
    pub const fn is_cpu_device() -> bool {
        matches!(P::DEVICE_TYPE, DeviceType::Cpu)
    }
    /// Scores placement quality against `target_numa`: 100 for an exact
    /// NUMA match, 30 for a mismatch, and a neutral 50 when no preference
    /// was recorded.
    pub fn locality_score(&self, target_numa: usize) -> u32 {
        match self.preferred_numa_node {
            None => 50,
            Some(node) => {
                if node == target_numa {
                    100
                } else {
                    30
                }
            }
        }
    }
}
/// Static description of a data transfer from device `PSrc` to `PDst`.
#[derive(Debug)]
pub struct CrossDeviceOp<PSrc: PhantomDevice, PDst: PhantomDevice> {
    _phantom: PhantomData<(PSrc, PDst)>,
}
impl<PSrc: PhantomDevice, PDst: PhantomDevice> CrossDeviceOp<PSrc, PDst> {
    /// Whether the transfer is supported; delegates to
    /// [`TransferCompatible`] (currently always `true`).
    pub const SUPPORTED: bool = TransferCompatible::<PSrc, PDst>::SUPPORTED;
    /// Relative cost of the transfer; see
    /// [`TransferCompatible::transfer_cost`].
    pub fn transfer_cost() -> u32 {
        TransferCompatible::<PSrc, PDst>::transfer_cost()
    }
    /// Name of the mechanism used to move data between the two devices.
    pub fn transfer_strategy() -> &'static str {
        match (PSrc::DEVICE_TYPE, PDst::DEVICE_TYPE) {
            (DeviceType::Cpu, DeviceType::Cpu) => "memcpy",
            (DeviceType::Cuda(_), DeviceType::Cuda(_)) => "peer-to-peer",
            (DeviceType::Cpu, DeviceType::Cuda(_)) => "pinned-transfer",
            (DeviceType::Cuda(_), DeviceType::Cpu) => "staged-readback",
            (DeviceType::Metal(_), DeviceType::Metal(_)) => "unified-memory",
            (DeviceType::Cpu, DeviceType::Metal(_)) => "shared-memory",
            // Any other pairing falls back to a staged copy through host
            // memory.
            _ => "staged-transfer",
        }
    }
    /// True only for Metal-to-Metal transfers (any indices).
    pub const fn supports_zero_copy() -> bool {
        matches!(
            (PSrc::DEVICE_TYPE, PDst::DEVICE_TYPE),
            (DeviceType::Metal(_), DeviceType::Metal(_))
        )
    }
}
/// Device-compatibility assertions.
///
/// NOTE(review): despite the module name, these checks execute (and panic)
/// at *runtime* — nothing here is evaluated by the compiler.
pub mod compile_time {
    use super::*;
    /// Panics unless `P1` and `P2` denote the same physical device.
    pub fn assert_same_device<P1: PhantomDevice, P2: PhantomDevice>() {
        assert!(
            SameDevice::<P1, P2>::is_satisfied(),
            "Device types must match"
        );
    }
    /// Panics when `P` is not a GPU-backed device.
    pub fn assert_gpu<P: PhantomDevice>() {
        assert!(P::requires_gpu(), "Operation requires GPU device");
    }
    /// Panics when `P` is not the CPU device.
    pub fn assert_cpu<P: PhantomDevice>() {
        assert!(!P::requires_gpu(), "Operation requires CPU device");
    }
    /// Panics when peer-to-peer transfer is unsupported between `P1` and
    /// `P2`.
    pub fn assert_p2p<P1, P2>()
    where
        P1: PhantomDevice + PeerToPeerOps<P2>,
        P2: PhantomDevice,
    {
        assert!(
            P1::P2P_SUPPORTED,
            "P2P not supported between these device types"
        );
    }
}
#[cfg(test)]
mod advanced_tests {
    use super::*;
    use crate::device::implementations::CpuDevice;
    /// A two-device CPU group exposes both handles and bounds-checks `get`.
    #[test]
    fn test_device_group() {
        let cpu_device = Box::new(CpuDevice::new());
        let handle1 =
            DeviceHandle::<PhantomCpu>::new(cpu_device).expect("DeviceHandle::new should succeed");
        let cpu_device2 = Box::new(CpuDevice::new());
        let handle2 =
            DeviceHandle::<PhantomCpu>::new(cpu_device2).expect("DeviceHandle::new should succeed");
        let group = DeviceGroup::new([handle1, handle2]);
        assert_eq!(DeviceGroup::<PhantomCpu, 2>::device_count(), 2);
        assert!(group.get(0).is_some());
        assert!(group.get(1).is_some());
        assert!(group.get(2).is_none());
    }
    /// Cross-index CUDA P2P reports support and non-zero bandwidth/latency.
    #[test]
    fn test_p2p_cuda() {
        assert!(PhantomCuda::<0>::supports_p2p());
        assert!(<PhantomCuda<0> as PeerToPeerOps<PhantomCuda<1>>>::P2P_SUPPORTED);
        let bandwidth = <PhantomCuda<0> as PeerToPeerOps<PhantomCuda<1>>>::p2p_bandwidth();
        assert!(bandwidth > 0);
        let latency = <PhantomCuda<0> as PeerToPeerOps<PhantomCuda<1>>>::p2p_latency();
        assert!(latency > 0);
    }
    /// Topology constants and relative bandwidth ordering
    /// (all-to-all >= tree for the same device type).
    #[test]
    fn test_device_topology() {
        assert_eq!(RingTopology::<PhantomCuda<0>, 4>::DEVICE_COUNT, 4);
        assert!(RingTopology::<PhantomCuda<0>, 4>::SUPPORTS_ALLREDUCE);
        assert!(RingTopology::<PhantomCuda<0>, 4>::SUPPORTS_BROADCAST);
        let bandwidth = RingTopology::<PhantomCuda<0>, 4>::allreduce_bandwidth();
        assert!(bandwidth > 0);
        assert_eq!(TreeTopology::<PhantomCuda<0>, 8>::DEVICE_COUNT, 8);
        let tree_bandwidth = TreeTopology::<PhantomCuda<0>, 8>::allreduce_bandwidth();
        assert!(tree_bandwidth > 0);
        assert_eq!(AllToAllTopology::<PhantomCuda<0>, 4>::DEVICE_COUNT, 4);
        let all2all_bandwidth = AllToAllTopology::<PhantomCuda<0>, 4>::allreduce_bandwidth();
        assert!(all2all_bandwidth >= tree_bandwidth);
    }
    /// Builder hints round-trip and feed `locality_score` as documented.
    #[test]
    fn test_typed_device_affinity() {
        let cpu_device = Box::new(CpuDevice::new());
        let handle =
            DeviceHandle::<PhantomCpu>::new(cpu_device).expect("DeviceHandle::new should succeed");
        let affinity = TypedDeviceAffinity::new(handle)
            .with_numa_node(0)
            .with_cpu_affinity(vec![0, 1, 2, 3]);
        assert_eq!(affinity.numa_node(), Some(0));
        assert_eq!(affinity.cpu_affinity(), Some(&[0, 1, 2, 3][..]));
        assert_eq!(affinity.locality_score(0), 100);
        assert_eq!(affinity.locality_score(1), 30);
    }
    /// CPU->CUDA uses the pinned-transfer strategy at cost 100; Metal->Metal
    /// is zero-copy; CPU->CPU is free.
    #[test]
    fn test_cross_device_op() {
        assert!(CrossDeviceOp::<PhantomCpu, PhantomCuda<0>>::SUPPORTED);
        assert_eq!(
            CrossDeviceOp::<PhantomCpu, PhantomCuda<0>>::transfer_cost(),
            100
        );
        assert_eq!(
            CrossDeviceOp::<PhantomCpu, PhantomCuda<0>>::transfer_strategy(),
            "pinned-transfer"
        );
        assert!(CrossDeviceOp::<PhantomMetal<0>, PhantomMetal<0>>::supports_zero_copy());
        assert_eq!(CrossDeviceOp::<PhantomCpu, PhantomCpu>::transfer_cost(), 0);
    }
    /// Const/static classification queries resolve as expected.
    #[test]
    fn test_compile_time_validation() {
        assert_eq!(
            CrossDeviceOp::<PhantomCpu, PhantomCpu>::transfer_strategy(),
            "memcpy"
        );
        assert!(DeviceGroup::<PhantomCuda<0>, 4>::supports_p2p());
        assert!(!DeviceGroup::<PhantomCpu, 4>::supports_p2p());
    }
}