/// Private supertrait module that seals `MpiDatatype`.
///
/// `Sealed` is `pub` so it may appear in a public trait's bounds, but the
/// module itself is private, so only code in this module tree can implement it.
mod sealed {
    /// Marker required by `MpiDatatype`.
    pub trait Sealed {}
}

/// Private supertrait module that seals `MpiIndexedDatatype`.
mod sealed_indexed {
    /// Marker required by `MpiIndexedDatatype`.
    pub trait Sealed {}
}

/// Private supertrait module that seals `BytePermutable`.
mod sealed_byte {
    /// Marker required by `BytePermutable`.
    pub trait Sealed {}
}
/// Numeric identifier for each element datatype this module knows about.
///
/// `#[repr(i32)]` pins every discriminant. The values must stay in sync with
/// the matching C-side defines, so they are spelled out explicitly rather than
/// left to declaration order.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[repr(i32)]
pub enum DatatypeTag {
    /// Tag of `f32`.
    F32 = 0,
    /// Tag of `f64`.
    F64 = 1,
    /// Tag of `i32`.
    I32 = 2,
    /// Tag of `i64`.
    I64 = 3,
    /// Tag of `u8`.
    U8 = 4,
    /// Tag of `u32`.
    U32 = 5,
    /// Tag of `u64`.
    U64 = 6,
    /// Tag of the `FloatInt` (`f32` value, `i32` index) pair.
    FloatInt = 7,
    /// Tag of the `DoubleInt` (`f64` value, `i32` index) pair.
    DoubleInt = 8,
    /// Tag of the `LongInt` (`i64` value, `i32` index) pair.
    LongInt = 9,
    /// Tag of the `Int2` (`i32` value, `i32` index) pair.
    Int2 = 10,
    /// Tag of the `ShortInt` (`i16` value, `i32` index) pair.
    ShortInt = 11,
    /// Tag of the `LongDoubleInt` (opaque 16-byte value, `i32` index) pair.
    LongDoubleInt = 12,
    /// Untyped bytes; presumably the tag for `BytePermutable` payloads — confirm.
    Byte = 13,
}
/// A primitive element type associated with a fixed [`DatatypeTag`].
///
/// The trait is sealed: its `sealed::Sealed` supertrait lives in a private
/// module, so the set of implementors is closed to the impls in this file.
pub trait MpiDatatype
where
    Self: sealed::Sealed + Copy + Send + 'static,
{
    /// Tag identifying `Self`.
    const TAG: DatatypeTag;
}
macro_rules! impl_mpi_datatype {
($ty:ty, $tag:expr) => {
impl sealed::Sealed for $ty {}
impl MpiDatatype for $ty {
const TAG: DatatypeTag = $tag;
}
};
}
impl_mpi_datatype!(f32, DatatypeTag::F32);
impl_mpi_datatype!(f64, DatatypeTag::F64);
impl_mpi_datatype!(i32, DatatypeTag::I32);
impl_mpi_datatype!(i64, DatatypeTag::I64);
impl_mpi_datatype!(u8, DatatypeTag::U8);
impl_mpi_datatype!(u32, DatatypeTag::U32);
impl_mpi_datatype!(u64, DatatypeTag::U64);
/// A `#[repr(C)]` value/index pair type associated with a fixed [`DatatypeTag`].
///
/// The trait is sealed through `sealed_indexed::Sealed`, which is kept
/// separate from the primitive seal so the two trait families cannot share
/// implementors by accident.
pub trait MpiIndexedDatatype
where
    Self: sealed_indexed::Sealed + Copy + Send + 'static,
{
    /// Tag identifying `Self`.
    const TAG: DatatypeTag;
}
/// `f32` value paired with an `i32` index, laid out with `#[repr(C)]`.
///
/// Presumably mirrors MPI's `MPI_FLOAT_INT` pair type (as used by
/// `MPI_MAXLOC`/`MPI_MINLOC`) — confirm against the C side. Because `value` is
/// a float, only `PartialEq` is derived.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(C)]
pub struct FloatInt {
    /// Payload value.
    pub value: f32,
    /// Index paired with `value`.
    pub index: i32,
}
/// `f64` value paired with an `i32` index, laid out with `#[repr(C)]`.
///
/// Presumably mirrors MPI's `MPI_DOUBLE_INT` pair type — confirm against the C
/// side. The trailing `index` is followed by tail padding up to the struct's
/// alignment, exactly as a C compiler would lay it out.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(C)]
pub struct DoubleInt {
    /// Payload value.
    pub value: f64,
    /// Index paired with `value`.
    pub index: i32,
}
/// `i64` value paired with an `i32` index, laid out with `#[repr(C)]`.
///
/// Presumably mirrors MPI's `MPI_LONG_INT` pair type — confirm against the C
/// side. Both fields are integers, so unlike the floating-point pairs this type
/// also implements `Eq` and `Hash` and can serve as a set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(C)]
pub struct LongInt {
    /// Payload value.
    pub value: i64,
    /// Index paired with `value`.
    pub index: i32,
}
/// Two `i32`s (value, index), laid out with `#[repr(C)]`.
///
/// Presumably mirrors MPI's `MPI_2INT` pair type — confirm against the C side.
/// Both fields are integers, so this type also implements `Eq` and `Hash`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(C)]
pub struct Int2 {
    /// Payload value.
    pub value: i32,
    /// Index paired with `value`.
    pub index: i32,
}
/// `i16` value paired with an `i32` index, laid out with `#[repr(C)]`.
///
/// Presumably mirrors MPI's `MPI_SHORT_INT` pair type — confirm against the C
/// side. `#[repr(C)]` inserts padding between `value` and `index` so `index`
/// is `i32`-aligned. Both fields are integers, so this type also implements
/// `Eq` and `Hash`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(C)]
pub struct ShortInt {
    /// Payload value.
    pub value: i16,
    /// Index paired with `value`.
    pub index: i32,
}
/// Opaque 16-byte value paired with an `i32` index, 16-byte aligned.
///
/// Rust has no `long double`, so `value` is raw storage. With `align(16)`, the
/// `index` lands at offset 16 and the struct is padded to 32 bytes.
///
/// NOTE(review): presumably mirrors MPI's `MPI_LONG_DOUBLE_INT`. That only
/// matches C where `long double` occupies 16 bytes with 16-byte alignment
/// (e.g. x86_64 System V, aarch64 Linux). On targets where `long double` is
/// the same as `double` (e.g. MSVC, Apple arm64), the C layout differs —
/// confirm which targets are supported.
///
/// Because `value` is a byte array, the derived `PartialEq` compares bit
/// patterns, not numeric values. Any unused bytes of the C representation
/// (presumably 6 bytes for the x87 80-bit format) take part in the comparison.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(C, align(16))]
pub struct LongDoubleInt {
    /// Raw bytes of the C `long double` value.
    pub value: [u8; 16],
    /// Index paired with `value`.
    pub index: i32,
}
/// Implements the sealed `MpiIndexedDatatype` trait for one pair struct.
///
/// Invoked as `impl_mpi_indexed_datatype!(Type, DatatypeTag::Variant)`. It
/// emits the private `sealed_indexed::Sealed` marker impl together with the
/// public trait impl.
macro_rules! impl_mpi_indexed_datatype {
    ($pair:ty, $variant:expr) => {
        impl sealed_indexed::Sealed for $pair {}

        impl MpiIndexedDatatype for $pair {
            const TAG: DatatypeTag = $variant;
        }
    };
}

impl_mpi_indexed_datatype! { FloatInt, DatatypeTag::FloatInt }
impl_mpi_indexed_datatype! { DoubleInt, DatatypeTag::DoubleInt }
impl_mpi_indexed_datatype! { LongInt, DatatypeTag::LongInt }
impl_mpi_indexed_datatype! { Int2, DatatypeTag::Int2 }
impl_mpi_indexed_datatype! { ShortInt, DatatypeTag::ShortInt }
impl_mpi_indexed_datatype! { LongDoubleInt, DatatypeTag::LongDoubleInt }
/// Marker for element types that may be handled as an opaque run of bytes.
///
/// Sealed through `sealed_byte::Sealed`. The impls in this file cover fixed-width
/// integers and (possibly nested) arrays of them.
///
/// NOTE(review): presumably paired with `DatatypeTag::Byte`. Confirm that
/// "no padding, every bit pattern valid" is the intended contract before
/// adding new implementors.
pub trait BytePermutable
where
    Self: sealed_byte::Sealed + Copy + Send + 'static,
{
}
/// Implements the sealed `BytePermutable` marker (and its private
/// `sealed_byte::Sealed` supertrait) for one type.
macro_rules! impl_byte_permutable {
    ($target:ty) => {
        impl sealed_byte::Sealed for $target {}
        impl BytePermutable for $target {}
    };
}

// Unsigned integers.
impl_byte_permutable! { u8 }
impl_byte_permutable! { u16 }
impl_byte_permutable! { u32 }
impl_byte_permutable! { u64 }
// Signed integers.
impl_byte_permutable! { i8 }
impl_byte_permutable! { i16 }
impl_byte_permutable! { i32 }
impl_byte_permutable! { i64 }

// Fixed-size arrays of permutable elements are themselves permutable. Rust
// arrays have no padding between elements. Because the impl is recursive,
// nested arrays such as `[[u8; 4]; 2]` qualify as well.
impl<T, const N: usize> sealed_byte::Sealed for [T; N] where T: BytePermutable {}
impl<T, const N: usize> BytePermutable for [T; N] where T: BytePermutable {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `DatatypeTag` variant in declaration order; position `i` must
    /// carry discriminant `i`.
    const ALL_TAGS: [DatatypeTag; 14] = [
        DatatypeTag::F32,
        DatatypeTag::F64,
        DatatypeTag::I32,
        DatatypeTag::I64,
        DatatypeTag::U8,
        DatatypeTag::U32,
        DatatypeTag::U64,
        DatatypeTag::FloatInt,
        DatatypeTag::DoubleInt,
        DatatypeTag::LongInt,
        DatatypeTag::Int2,
        DatatypeTag::ShortInt,
        DatatypeTag::LongDoubleInt,
        DatatypeTag::Byte,
    ];

    /// Compile-time guard: adding a `DatatypeTag` variant makes this match
    /// non-exhaustive, forcing `ALL_TAGS` (and the tests using it) to be
    /// revisited. Never called, hence the `dead_code` allowance.
    #[allow(dead_code)]
    fn all_tags_is_exhaustive(tag: DatatypeTag) {
        match tag {
            DatatypeTag::F32
            | DatatypeTag::F64
            | DatatypeTag::I32
            | DatatypeTag::I64
            | DatatypeTag::U8
            | DatatypeTag::U32
            | DatatypeTag::U64
            | DatatypeTag::FloatInt
            | DatatypeTag::DoubleInt
            | DatatypeTag::LongInt
            | DatatypeTag::Int2
            | DatatypeTag::ShortInt
            | DatatypeTag::LongDoubleInt
            | DatatypeTag::Byte => {}
        }
    }

    /// Byte offset of `field` inside `base`; `field` must borrow from `base`.
    fn offset_in<T, F>(base: &T, field: &F) -> usize {
        field as *const F as usize - base as *const T as usize
    }

    #[test]
    fn tag_values_match_c_defines() {
        assert_eq!(DatatypeTag::F32 as i32, 0);
        assert_eq!(DatatypeTag::F64 as i32, 1);
        assert_eq!(DatatypeTag::I32 as i32, 2);
        assert_eq!(DatatypeTag::I64 as i32, 3);
        assert_eq!(DatatypeTag::U8 as i32, 4);
        assert_eq!(DatatypeTag::U32 as i32, 5);
        assert_eq!(DatatypeTag::U64 as i32, 6);
    }

    #[test]
    fn datatype_tags_match_c_defines() {
        assert_eq!(f32::TAG as i32, 0);
        assert_eq!(f64::TAG as i32, 1);
        assert_eq!(i32::TAG as i32, 2);
        assert_eq!(i64::TAG as i32, 3);
        assert_eq!(u8::TAG as i32, 4);
        assert_eq!(u32::TAG as i32, 5);
        assert_eq!(u64::TAG as i32, 6);
    }

    #[test]
    fn datatype_tag_values_are_sequential() {
        // Covers every variant, not just the primitive ones.
        for (expected, tag) in ALL_TAGS.iter().enumerate() {
            assert_eq!(*tag as i32, expected as i32, "{:?}", tag);
        }
        assert_eq!(DatatypeTag::Byte as i32, 13);
    }

    #[test]
    fn trait_is_implemented() {
        fn assert_mpi_datatype<T: MpiDatatype>() {}
        assert_mpi_datatype::<f32>();
        assert_mpi_datatype::<f64>();
        assert_mpi_datatype::<i32>();
        assert_mpi_datatype::<i64>();
        assert_mpi_datatype::<u8>();
        assert_mpi_datatype::<u32>();
        assert_mpi_datatype::<u64>();
    }

    #[test]
    fn datatype_tag_debug_format() {
        assert_eq!(format!("{:?}", DatatypeTag::F32), "F32");
        assert_eq!(format!("{:?}", DatatypeTag::F64), "F64");
        assert_eq!(format!("{:?}", DatatypeTag::I32), "I32");
        assert_eq!(format!("{:?}", DatatypeTag::I64), "I64");
        assert_eq!(format!("{:?}", DatatypeTag::U8), "U8");
        assert_eq!(format!("{:?}", DatatypeTag::U32), "U32");
        assert_eq!(format!("{:?}", DatatypeTag::U64), "U64");
        assert_eq!(format!("{:?}", DatatypeTag::FloatInt), "FloatInt");
        assert_eq!(format!("{:?}", DatatypeTag::LongDoubleInt), "LongDoubleInt");
        assert_eq!(format!("{:?}", DatatypeTag::Byte), "Byte");
    }

    #[test]
    fn datatype_tag_clone_hash() {
        use std::collections::HashSet;
        let tag = DatatypeTag::F64;
        let cloned = tag;
        assert_eq!(cloned, DatatypeTag::F64);
        let mut set = HashSet::new();
        set.insert(DatatypeTag::F32);
        set.insert(DatatypeTag::F64);
        set.insert(DatatypeTag::F32);
        assert_eq!(set.len(), 2);
        let all: HashSet<DatatypeTag> = ALL_TAGS.iter().copied().collect();
        assert_eq!(all.len(), ALL_TAGS.len(), "tags must be pairwise distinct");
    }

    #[test]
    fn indexed_datatype_tags_match_c_defines() {
        assert_eq!(DatatypeTag::FloatInt as i32, 7);
        assert_eq!(DatatypeTag::DoubleInt as i32, 8);
        assert_eq!(DatatypeTag::LongInt as i32, 9);
        assert_eq!(DatatypeTag::Int2 as i32, 10);
        assert_eq!(DatatypeTag::ShortInt as i32, 11);
        assert_eq!(DatatypeTag::LongDoubleInt as i32, 12);
        assert_eq!(FloatInt::TAG as i32, 7);
        assert_eq!(DoubleInt::TAG as i32, 8);
        assert_eq!(LongInt::TAG as i32, 9);
        assert_eq!(Int2::TAG as i32, 10);
        assert_eq!(ShortInt::TAG as i32, 11);
        assert_eq!(LongDoubleInt::TAG as i32, 12);
        assert_eq!(DatatypeTag::Byte as i32, 13);
    }

    #[test]
    fn byte_datatype_tag_is_13() {
        assert_eq!(DatatypeTag::Byte as i32, 13);
    }

    #[test]
    fn byte_permutable_implemented_for_integer_primitives() {
        fn assert_byte_permutable<T: BytePermutable>() {}
        assert_byte_permutable::<u8>();
        assert_byte_permutable::<u16>();
        assert_byte_permutable::<u32>();
        assert_byte_permutable::<u64>();
        assert_byte_permutable::<i8>();
        assert_byte_permutable::<i16>();
        assert_byte_permutable::<i32>();
        assert_byte_permutable::<i64>();
        assert_byte_permutable::<[u64; 4]>();
        assert_byte_permutable::<[[i8; 2]; 3]>();
    }

    #[test]
    fn indexed_datatype_struct_layouts() {
        use std::mem::{align_of, size_of};
        assert_eq!(size_of::<FloatInt>(), 8, "FloatInt size");
        assert_eq!(align_of::<FloatInt>(), 4, "FloatInt align");

        // Lower bounds use the target's own primitive alignment: f64/i64 are
        // only 4-byte aligned on some 32-bit targets (e.g. i686 Linux), so a
        // hard-coded 8 would fail there even though the layout matches C.
        assert!(
            size_of::<DoubleInt>() >= 12,
            "DoubleInt must hold at least f64 + i32"
        );
        assert!(
            align_of::<DoubleInt>() >= align_of::<f64>(),
            "DoubleInt alignment must be at least f64 alignment"
        );
        #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
        {
            assert_eq!(size_of::<DoubleInt>(), 16, "DoubleInt size on x86_64/aarch64");
            assert_eq!(align_of::<DoubleInt>(), 8, "DoubleInt align on x86_64/aarch64");
        }

        assert!(
            size_of::<LongInt>() >= 12,
            "LongInt must hold at least i64 + i32"
        );
        assert!(
            align_of::<LongInt>() >= align_of::<i64>(),
            "LongInt alignment must be at least i64 alignment"
        );
        #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
        {
            assert_eq!(size_of::<LongInt>(), 16, "LongInt size on x86_64/aarch64");
            assert_eq!(align_of::<LongInt>(), 8, "LongInt align on x86_64/aarch64");
        }

        assert_eq!(size_of::<Int2>(), 8, "Int2 size");
        assert_eq!(align_of::<Int2>(), 4, "Int2 align");

        assert!(
            size_of::<ShortInt>() >= 6,
            "ShortInt must hold at least i16 + i32"
        );
        assert!(
            align_of::<ShortInt>() >= align_of::<i32>(),
            "ShortInt alignment must be at least i32 alignment"
        );
        #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
        {
            assert_eq!(size_of::<ShortInt>(), 8, "ShortInt size on x86_64/aarch64");
            assert_eq!(align_of::<ShortInt>(), 4, "ShortInt align on x86_64/aarch64");
        }

        // `align(16)` plus 16 value bytes fixes this layout on every target:
        // the index sits at offset 16 and the struct is padded to 32 bytes.
        // (The previous `align >= 1` check could never fail.)
        assert_eq!(size_of::<LongDoubleInt>(), 32, "LongDoubleInt size");
        assert_eq!(align_of::<LongDoubleInt>(), 16, "LongDoubleInt align");
    }

    #[test]
    fn indexed_datatype_index_offsets() {
        // `#[repr(C)]` places `index` right after `value`, rounded up to i32
        // alignment — the same offset a C compiler would use.
        let fi = FloatInt { value: 0.0, index: 0 };
        assert_eq!(offset_in(&fi, &fi.index), 4, "FloatInt.index offset");
        let di = DoubleInt { value: 0.0, index: 0 };
        assert_eq!(offset_in(&di, &di.index), 8, "DoubleInt.index offset");
        let li = LongInt { value: 0, index: 0 };
        assert_eq!(offset_in(&li, &li.index), 8, "LongInt.index offset");
        let ii = Int2 { value: 0, index: 0 };
        assert_eq!(offset_in(&ii, &ii.index), 4, "Int2.index offset");
        let si = ShortInt { value: 0, index: 0 };
        assert_eq!(offset_in(&si, &si.index), 4, "ShortInt.index offset");
        let ldi = LongDoubleInt { value: [0; 16], index: 0 };
        assert_eq!(offset_in(&ldi, &ldi.index), 16, "LongDoubleInt.index offset");
    }

    #[test]
    fn indexed_datatype_trait_is_implemented() {
        fn assert_indexed<T: MpiIndexedDatatype>() {}
        assert_indexed::<FloatInt>();
        assert_indexed::<DoubleInt>();
        assert_indexed::<LongInt>();
        assert_indexed::<Int2>();
        assert_indexed::<ShortInt>();
        assert_indexed::<LongDoubleInt>();
    }

    #[test]
    fn indexed_and_primitive_traits_are_disjoint() {
        fn assert_primitive<T: MpiDatatype>() {}
        assert_primitive::<f64>();
        assert_primitive::<i32>();
        fn assert_indexed<T: MpiIndexedDatatype>() {}
        assert_indexed::<DoubleInt>();
        assert_indexed::<Int2>();

        // Trait bounds alone cannot show non-overlap, so check the tags: no
        // primitive type shares a tag with an indexed pair type.
        let primitive = [
            f32::TAG,
            f64::TAG,
            i32::TAG,
            i64::TAG,
            u8::TAG,
            u32::TAG,
            u64::TAG,
        ];
        let indexed = [
            FloatInt::TAG,
            DoubleInt::TAG,
            LongInt::TAG,
            Int2::TAG,
            ShortInt::TAG,
            LongDoubleInt::TAG,
        ];
        for tag in &primitive {
            assert!(
                !indexed.contains(tag),
                "{:?} is used by both a primitive and an indexed type",
                tag
            );
        }
    }
}