use serde::{Deserialize, Serialize};
use std::fmt;
/// Shape of a single tile: an `m × n × k` block plus an alignment value.
///
/// The tile-size helpers treat the A tile as `m·k` elements, the B tile as
/// `k·n` elements and the C tile as `m·n` elements.
///
/// NOTE(review): the derived `Deserialize` accepts any field values, so it
/// bypasses the non-zero / power-of-two checks that `TcbGeometry::new` and
/// `TcbGeometry::with_alignment` assert. Consider `#[serde(try_from = ...)]`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct TcbGeometry {
    /// Dimension shared by the A (`m·k`) and C (`m·n`) tiles.
    pub m: u32,
    /// Dimension shared by the B (`k·n`) and C (`m·n`) tiles.
    pub n: u32,
    /// Dimension shared by the A (`m·k`) and B (`k·n`) tiles.
    pub k: u32,
    /// Alignment value; the checked constructor requires a power of two.
    pub alignment: u32,
}
impl TcbGeometry {
    /// Alignment used by [`TcbGeometry::new`].
    const DEFAULT_ALIGNMENT: u32 = 16;

    /// Bytes per element assumed by the tile-size and intensity calculations.
    const ELEMENT_BYTES: usize = 4;

    /// Creates an `m × n × k` geometry with the default alignment of 16.
    ///
    /// # Panics
    ///
    /// Panics if `m`, `n`, or `k` is zero.
    #[must_use]
    pub fn new(m: u32, n: u32, k: u32) -> Self {
        // Delegating keeps the dimension check (and its message) in one place;
        // the default alignment is a power of two, so that assert always passes.
        Self::with_alignment(m, n, k, Self::DEFAULT_ALIGNMENT)
    }

    /// Creates an `m × n × k` geometry with an explicit `alignment`.
    ///
    /// # Panics
    ///
    /// Panics if `m`, `n`, or `k` is zero, or if `alignment` is not a power of
    /// two (this includes `0`).
    #[must_use]
    pub fn with_alignment(m: u32, n: u32, k: u32, alignment: u32) -> Self {
        assert!(m > 0 && n > 0 && k > 0, "TCB dimensions must be non-zero");
        assert!(alignment.is_power_of_two(), "Alignment must be power of 2");
        Self { m, n, k, alignment }
    }

    /// Ratio of floating-point operations to bytes loaded for one tile.
    ///
    /// FLOPs are counted as `2·m·n·k`. Bytes are the A (`m·k`) and B (`k·n`)
    /// tiles at 4 bytes per element; the C tile is not counted. The math is
    /// done in `f64` so the products cannot overflow, then narrowed to `f32`.
    #[must_use]
    pub fn arithmetic_intensity(&self) -> f32 {
        let (m, n, k) = (f64::from(self.m), f64::from(self.n), f64::from(self.k));
        let flops = 2.0 * m * n * k;
        let bytes = (m * k + k * n) * 4.0;
        (flops / bytes) as f32
    }

    /// Number of elements in the `m × n` (C) tile.
    #[must_use]
    pub fn total_elements(&self) -> u64 {
        // (2^32 - 1)^2 < 2^64, so this product cannot overflow.
        u64::from(self.m) * u64::from(self.n)
    }

    /// Total floating-point operations for the tile, `2·m·n·k`.
    ///
    /// Saturates at `u64::MAX`. For very large dimensions the exact count does
    /// not fit in `u64`, and plain multiplication would panic in debug builds
    /// and silently wrap in release builds.
    #[must_use]
    pub fn total_flops(&self) -> u64 {
        2u64.saturating_mul(u64::from(self.m))
            .saturating_mul(u64::from(self.n))
            .saturating_mul(u64::from(self.k))
    }

    /// Returns `true` when `k` is a multiple of 256.
    /// (Presumably the Q4_K super-block length; confirm against the quantizer.)
    #[must_use]
    pub fn is_q4k_aligned(&self) -> bool {
        self.k % 256 == 0
    }

    /// Returns `true` when `k` is a multiple of 32.
    /// (Presumably the Q4_0 block length; confirm against the quantizer.)
    #[must_use]
    pub fn is_q4_0_aligned(&self) -> bool {
        self.k % 32 == 0
    }

    /// Size in bytes of the A tile (`m·k` elements at 4 bytes each).
    /// Saturates at `usize::MAX`.
    #[must_use]
    pub fn a_tile_bytes(&self) -> usize {
        Self::tile_bytes(self.m, self.k)
    }

    /// Size in bytes of the B tile (`k·n` elements at 4 bytes each).
    /// Saturates at `usize::MAX`.
    #[must_use]
    pub fn b_tile_bytes(&self) -> usize {
        Self::tile_bytes(self.k, self.n)
    }

    /// Size in bytes of the C tile (`m·n` elements at 4 bytes each).
    /// Saturates at `usize::MAX`.
    #[must_use]
    pub fn c_tile_bytes(&self) -> usize {
        Self::tile_bytes(self.m, self.n)
    }

    /// Returns `true` if the A and B tiles together fit in `cache_bytes`.
    ///
    /// The required size is computed exactly in `u128`. Geometries whose byte
    /// counts overflow `usize` therefore report `false`, instead of wrapping
    /// (release) or panicking (debug).
    #[must_use]
    pub fn fits_in_cache(&self, cache_bytes: usize) -> bool {
        let (m, n, k) = (u128::from(self.m), u128::from(self.n), u128::from(self.k));
        // At most 2 · (2^32)^2 · 4 = 2^67, far below u128::MAX.
        let needed = (m * k + k * n) * Self::ELEMENT_BYTES as u128;
        // usize is at most 64 bits on every Rust target, so this widening is lossless.
        needed <= cache_bytes as u128
    }

    /// `rows · cols · ELEMENT_BYTES`, saturating at `usize::MAX` instead of
    /// overflowing.
    fn tile_bytes(rows: u32, cols: u32) -> usize {
        // u32 -> usize is lossless on 32- and 64-bit targets.
        (rows as usize)
            .saturating_mul(cols as usize)
            .saturating_mul(Self::ELEMENT_BYTES)
    }
}
impl Default for TcbGeometry {
    /// A small cubic 4×4×4 tile with an alignment of 16.
    fn default() -> Self {
        const SIDE: u32 = 4;
        Self {
            m: SIDE,
            n: SIDE,
            k: SIDE,
            alignment: 16,
        }
    }
}
impl fmt::Display for TcbGeometry {
    /// Renders as `TCB(m×n×k, align=A, AI=x.xx)`, with the arithmetic
    /// intensity shown to two decimal places.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { m, n, k, alignment } = *self;
        let intensity = self.arithmetic_intensity();
        write!(
            f,
            "TCB({}×{}×{}, align={}, AI={:.2})",
            m, n, k, alignment, intensity
        )
    }
}
/// Tiling level, ordered from the largest to the smallest nominal cache
/// budget returned by `TcbLevel::typical_cache_bytes`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum TcbLevel {
    /// Outermost level (32 MiB nominal cache budget).
    Macro,
    /// Intermediate level (256 KiB nominal cache budget).
    Midi,
    /// Innermost level (32 KiB nominal cache budget).
    Micro,
}
impl TcbLevel {
#[must_use]
pub fn typical_cache_bytes(&self) -> usize {
match self {
TcbLevel::Macro => 32 * 1024 * 1024, TcbLevel::Midi => 256 * 1024, TcbLevel::Micro => 32 * 1024, }
}
}