cubecl_ir/
runtime_properties.rs1use crate::{Matrix, MatrixIdent, MatrixLayout, TypeHash};
2
3#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
7#[derive(Debug, Clone, PartialEq, Eq, TypeHash, Default)]
8pub struct TargetProperties {
9 pub mma: MmaProperties,
10}
11
12#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
13#[derive(Debug, Clone, PartialEq, Eq, TypeHash)]
14pub struct MmaProperties {
15 pub register_size_bits: u32,
17 pub const_plane_size: u32,
19 pub register_layout_a: MatrixLayout,
21 pub register_layout_b: MatrixLayout,
23 pub register_layout_acc: MatrixLayout,
25
26 pub register_duplication_a: u32,
28 pub register_duplication_b: u32,
30 pub register_duplication_acc: u32,
32 #[cfg_attr(feature = "serde", serde(skip))]
33 pub contiguous_elements: ContiguousElements,
34}
35
36#[derive(Clone)]
37pub struct ContiguousElements {
38 inner: alloc::rc::Rc<dyn Fn(MatrixIdent, Matrix) -> u32>,
39}
40
41impl ContiguousElements {
42 pub fn new(func: impl Fn(MatrixIdent, Matrix) -> u32 + 'static) -> Self {
43 Self {
44 inner: alloc::rc::Rc::new(func),
45 }
46 }
47
48 pub fn apply(&self, ident: MatrixIdent, matrix: Matrix) -> u32 {
49 (self.inner)(ident, matrix)
50 }
51}
52
53impl Default for ContiguousElements {
54 fn default() -> Self {
55 Self {
56 inner: alloc::rc::Rc::new(|_, _| 2),
57 }
58 }
59}
60
61impl core::fmt::Debug for ContiguousElements {
62 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
63 f.debug_struct("ContiguousElements").finish()
64 }
65}
66
67impl Eq for ContiguousElements {}
68impl PartialEq for ContiguousElements {
69 fn eq(&self, other: &Self) -> bool {
70 alloc::rc::Rc::ptr_eq(&self.inner, &other.inner)
71 }
72}
73
74impl TypeHash for ContiguousElements {
75 fn write_hash(hasher: &mut impl core::hash::Hasher) {
76 hasher.write_i32(0);
77 }
78}
79
80impl Default for MmaProperties {
81 fn default() -> Self {
82 Self {
83 register_size_bits: 32,
84 const_plane_size: 32,
85 register_layout_a: MatrixLayout::RowMajor,
86 register_layout_b: MatrixLayout::ColMajor,
87 register_layout_acc: MatrixLayout::RowMajor,
88 register_duplication_a: 1,
89 register_duplication_b: 1,
90 register_duplication_acc: 1,
91 contiguous_elements: Default::default(),
92 }
93 }
94}