use crate::{Matrix, MatrixIdent, MatrixLayout, TypeHash, VectorSize};
/// Top-level bundle of target properties.
///
/// Currently holds a single group, [`MmaProperties`]. `Default` is derived, so a
/// default value uses `MmaProperties::default()` for `mma`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash, Default)]
pub struct TargetProperties {
/// MMA-related properties of the target.
// NOTE(review): "MMA" presumably stands for matrix-multiply-accumulate — confirm.
pub mma: MmaProperties,
}
/// MMA-related properties of a target.
///
/// With the `serde` feature enabled every field is (de)serialized except
/// `contiguous_elements`, which is skipped and therefore takes its `Default`
/// value on deserialization.
// NOTE(review): the per-field descriptions below are inferred from the field
// names only; confirm against the code that consumes them.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
pub struct MmaProperties {
/// Register size, in bits.
pub register_size_bits: usize,
/// Plane size treated as a constant.
// NOTE(review): unit is presumably "units per plane" — confirm.
pub const_plane_size: u32,
/// Register layout used for matrix A.
pub register_layout_a: MatrixLayout,
/// Register layout used for matrix B.
pub register_layout_b: MatrixLayout,
/// Register layout used for the accumulator matrix.
pub register_layout_acc: MatrixLayout,
/// Register duplication factor for matrix A.
// NOTE(review): exact meaning of "duplication" is not visible here — confirm.
pub register_duplication_a: usize,
/// Register duplication factor for matrix B.
pub register_duplication_b: usize,
/// Register duplication factor for the accumulator matrix.
pub register_duplication_acc: usize,
/// Callback mapping a `(MatrixIdent, Matrix)` pair to a `VectorSize`.
///
/// Not serialized: the field is marked `serde(skip)`.
#[cfg_attr(feature = "serde", serde(skip))]
pub contiguous_elements: ContiguousElements,
}
/// Reference-counted wrapper around a `Fn(MatrixIdent, Matrix) -> VectorSize`
/// callback.
///
/// `Clone` is derived, so cloning copies the `Rc` handle and the clone shares
/// the same underlying callback rather than duplicating it.
#[derive(Clone)]
pub struct ContiguousElements {
// Type-erased callback; kept private so it is only reachable via `new`/`apply`.
inner: alloc::rc::Rc<dyn Fn(MatrixIdent, Matrix) -> VectorSize>,
}
impl ContiguousElements {
    /// Wraps `func` in a fresh reference-counted, type-erased callback.
    ///
    /// The `'static` bound is required because the closure is stored behind
    /// an `Rc<dyn Fn(..)>`.
    pub fn new(func: impl Fn(MatrixIdent, Matrix) -> VectorSize + 'static) -> Self {
        let inner: alloc::rc::Rc<dyn Fn(MatrixIdent, Matrix) -> VectorSize> =
            alloc::rc::Rc::new(func);
        Self { inner }
    }

    /// Invokes the stored callback with `ident` and `matrix` and returns its
    /// result unchanged.
    pub fn apply(&self, ident: MatrixIdent, matrix: Matrix) -> VectorSize {
        let callback: &dyn Fn(MatrixIdent, Matrix) -> VectorSize = &*self.inner;
        callback(ident, matrix)
    }
}
impl Default for ContiguousElements {
    /// Builds a callback that ignores both arguments and always returns `2`.
    ///
    /// Every call allocates a new `Rc`, so two default values do not share
    /// the same allocation.
    fn default() -> Self {
        Self::new(|_ident, _matrix| 2)
    }
}
impl core::fmt::Debug for ContiguousElements {
    /// Prints only the type name; the wrapped closure has no `Debug`
    /// representation, so no fields are shown.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("ContiguousElements")
    }
}
// Marker impl: promotes the hand-written `PartialEq` for `ContiguousElements`
// to a full equivalence relation so containing types can derive `Eq`.
impl Eq for ContiguousElements {}
impl PartialEq for ContiguousElements {
    /// Identity comparison: two values are equal only when they point at the
    /// same `Rc` allocation (for example, a value and its clone).
    ///
    /// The closures themselves are never compared, so separately constructed
    /// values are unequal even when they wrap identical logic.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.inner, &other.inner);
        alloc::rc::Rc::ptr_eq(lhs, rhs)
    }
}
impl TypeHash for ContiguousElements {
    /// Feeds a fixed `i32` zero into `hasher`; nothing about the wrapped
    /// closure is hashed.
    fn write_hash(hasher: &mut impl core::hash::Hasher) {
        // Constant contribution of this type to the hash.
        const TYPE_MARKER: i32 = 0;
        hasher.write_i32(TYPE_MARKER);
    }
}
impl Default for MmaProperties {
fn default() -> Self {
Self {
register_size_bits: 32,
const_plane_size: 32,
register_layout_a: MatrixLayout::RowMajor,
register_layout_b: MatrixLayout::ColMajor,
register_layout_acc: MatrixLayout::RowMajor,
register_duplication_a: 1,
register_duplication_b: 1,
register_duplication_acc: 1,
contiguous_elements: Default::default(),
}
}
}