use std::fmt::Display;
use crate::shared::AtomicKind;
use super::{Dialect, Elem};
/// A value type made of an element type `elem` repeated `vectorization` times.
///
/// The total footprint is reported by `Item::size` as `elem.size() * vectorization`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Item<D: Dialect> {
    /// Element (lane) type of the item.
    pub(crate) elem: Elem<D>,
    /// Number of `elem` lanes; `Item::scalar` sets this to `1`.
    pub(crate) vectorization: usize,
    /// Flag carried through unchanged by packing/unpacking.
    /// NOTE(review): its exact meaning is defined by the dialect — confirm.
    pub(crate) native: bool,
}
/// Renders the item in the target dialect's syntax by delegating entirely to
/// the dialect's `compile_item` hook.
impl<D: Dialect> Display for Item<D> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        <D as Dialect>::compile_item(formatter, self)
    }
}
impl<D: Dialect> Item<D> {
    /// Returns the element (lane) type of this item.
    pub fn elem(&self) -> &Elem<D> {
        &self.elem
    }

    /// Total size of the item: `elem.size()` times the vectorization factor,
    /// in whatever unit `Elem::size` reports (presumably bytes — confirm).
    pub const fn size(&self) -> usize {
        self.elem.size() * self.vectorization
    }

    /// Creates an item with an explicit vectorization factor.
    pub fn new(elem: Elem<D>, vectorization: usize, native: bool) -> Self {
        Self {
            elem,
            vectorization,
            native,
        }
    }

    /// Creates a non-vectorized item (vectorization factor of `1`).
    pub fn scalar(elem: Elem<D>, native: bool) -> Self {
        Self::new(elem, 1, native)
    }

    /// Whether dialect `D` allows pairs of lanes to be packed into a paired
    /// element (see [`Item::optimized`]). Depends only on `D`, not on `self`.
    pub fn can_be_optimized(&self) -> bool {
        D::item_can_be_optimized()
    }

    /// Returns `true` when the element is a packed pair type: `F16x2`, `BF16x2`,
    /// their atomic forms, `FP4x2`, `FP6x2` or `FP8x2`.
    pub fn is_optimized(&self) -> bool {
        Self::unpacked_pair_elem(&self.elem).is_some()
    }

    /// Returns `true` when the element is any atomic type.
    pub fn is_atomic(&self) -> bool {
        matches!(self.elem, Elem::Atomic(_))
    }

    /// Packs pairs of lanes into a paired element type, halving the
    /// vectorization factor.
    ///
    /// The item is returned unchanged when the dialect does not support
    /// packing, when the vectorization factor is odd, or when the element type
    /// has no paired counterpart.
    pub fn optimized(&self) -> Item<D> {
        if !self.can_be_optimized() || !self.vectorization.is_multiple_of(2) {
            return *self;
        }
        match Self::packed_pair_elem(&self.elem) {
            Some(packed) => Item::new(packed, self.vectorization / 2, self.native),
            None => *self,
        }
    }

    /// Number of logical values packed into one element, as reported by the
    /// element type.
    pub fn packing_factor(&self) -> usize {
        self.elem.packing_factor()
    }

    /// Inverse of [`Item::optimized`]: unpacks a paired element type back into
    /// its scalar element and doubles the vectorization factor.
    ///
    /// Atomic pair types (`Atomic(F16x2)`, `Atomic(BF16x2)`) are unpacked as
    /// well, so that `is_optimized()` is `false` on the result.
    ///
    /// Items that are not packed are returned unchanged.
    pub fn de_optimized(&self) -> Self {
        match Self::unpacked_pair_elem(&self.elem) {
            Some(unpacked) => Item::new(unpacked, self.vectorization * 2, self.native),
            None => *self,
        }
    }

    /// Maps a scalar element type to its packed pair counterpart, if any.
    ///
    /// Must remain the exact inverse of `unpacked_pair_elem`.
    fn packed_pair_elem(elem: &Elem<D>) -> Option<Elem<D>> {
        match *elem {
            Elem::F16 => Some(Elem::F16x2),
            Elem::BF16 => Some(Elem::BF16x2),
            Elem::Atomic(AtomicKind::F16) => Some(Elem::Atomic(AtomicKind::F16x2)),
            Elem::Atomic(AtomicKind::BF16) => Some(Elem::Atomic(AtomicKind::BF16x2)),
            Elem::FP4(kind) => Some(Elem::FP4x2(kind)),
            Elem::FP6(kind) => Some(Elem::FP6x2(kind)),
            Elem::FP8(kind) => Some(Elem::FP8x2(kind)),
            _ => None,
        }
    }

    /// Maps a packed pair element type back to its scalar element, if any.
    ///
    /// Must remain the exact inverse of `packed_pair_elem`.
    fn unpacked_pair_elem(elem: &Elem<D>) -> Option<Elem<D>> {
        match *elem {
            Elem::F16x2 => Some(Elem::F16),
            Elem::BF16x2 => Some(Elem::BF16),
            Elem::Atomic(AtomicKind::F16x2) => Some(Elem::Atomic(AtomicKind::F16)),
            Elem::Atomic(AtomicKind::BF16x2) => Some(Elem::Atomic(AtomicKind::BF16)),
            Elem::FP4x2(kind) => Some(Elem::FP4(kind)),
            Elem::FP6x2(kind) => Some(Elem::FP6(kind)),
            Elem::FP8x2(kind) => Some(Elem::FP8(kind)),
            _ => None,
        }
    }
}