/// Divisibility hint: a divisor together with the cap it is clamped to.
///
/// The type is plain `Copy` data; equality and hashing cover both fields.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DivHint {
    /// Divisor recorded for the described value.
    pub divisor: i32,
    /// Upper bound used when clamping `divisor`.
    pub max: i32,
}
impl Default for DivHint {
fn default() -> Self {
Self {
divisor: 1,
max: 16,
}
}
}
impl DivHint {
    /// Builds a hint from an integer value.
    ///
    /// The divisor is the largest power of two dividing `val`, capped at 16.
    /// Zero yields 16. The result is always in `1..=16`, including for
    /// `i32::MIN`.
    pub fn from_value(val: i32) -> Self {
        let raw = max_pow2_divisor_unclamped(val);
        // Guard against a non-positive result: the lowest set bit of
        // `i32::MIN` is the sign bit, which is negative when viewed as an
        // `i32`. Such a value is a multiple of 16, so report the cap.
        let divisor = if raw <= 0 { 16 } else { raw.min(16) };
        Self { divisor, max: 16 }
    }

    /// Builds a hint from a base address.
    ///
    /// The divisor is the largest power of two dividing `ptr`, capped at 16.
    /// It is derived from the 64-bit address directly instead of a truncating
    /// cast to `i32`, so an address whose lowest set bit is bit 31 yields 16
    /// rather than a negative divisor. A zero address yields 16.
    pub fn from_ptr(ptr: u64) -> Self {
        // `trailing_zeros` is 64 for a zero address; capping the shift at four
        // bits maps that (and every alignment of 16 or more) to 16.
        let shift = ptr.trailing_zeros().min(4);
        Self {
            divisor: 1 << shift,
            max: 16,
        }
    }

    /// Returns a copy with cap `max`.
    ///
    /// The divisor is lowered to `max` when it exceeds it and is never raised.
    pub fn with_max(self, max: i32) -> Self {
        Self {
            divisor: self.divisor.min(max),
            max,
        }
    }
}
/// Specialization facts for a strided buffer, as filled in by `compute_spec`.
///
/// The three vectors hold one entry per dimension, in dimension order.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SpecializationBits {
    /// Divisibility hint for each dimension's extent.
    pub shape_div: Vec<DivHint>,
    /// Divisibility hint for each dimension's stride.
    pub stride_div: Vec<DivHint>,
    /// Whether each dimension's stride is exactly `1`.
    pub stride_one: Vec<bool>,
    /// Divisibility hint for the base address.
    pub base_ptr_div: DivHint,
    /// Whether the stride check found the addressed elements to be non-overlapping.
    pub elements_disjoint: bool,
}
/// Returns a power of two that divides `val`, without applying the cap of 16.
///
/// * `0` returns 16.
/// * `i32::MIN` returns `1 << 30`: its largest power-of-two divisor, 2^31,
///   does not fit in an `i32`, and the lowest-set-bit trick would produce a
///   negative number. `1 << 30` is still a true divisor and clamps to the
///   same value.
/// * Every other input returns its largest power-of-two divisor (the value of
///   its lowest set bit).
///
/// The result is always positive.
fn max_pow2_divisor_unclamped(val: i32) -> i32 {
    if val == 0 {
        16
    } else {
        // For non-zero input `trailing_zeros` is in 0..=31, and 31 only for
        // `i32::MIN`; capping at 30 keeps the shifted bit out of the sign bit.
        1 << val.trailing_zeros().min(30)
    }
}
/// Returns the largest power of two that divides `val`, capped at 16.
///
/// Zero returns 16. The result is always in `1..=16`, including for
/// `i32::MIN`.
pub fn max_pow2_divisor(val: i32) -> i32 {
    match max_pow2_divisor_unclamped(val) {
        // A non-positive raw value can only come from `i32::MIN`, whose
        // lowest set bit is the sign bit; that input is a multiple of 16.
        raw if raw <= 0 => 16,
        raw => raw.min(16),
    }
}
/// Names of the integer scalar types: signed first, then unsigned.
pub const INTEGER_SCALAR_TYPES: &[&str] = &[
    "i8", "i16", "i32", "i64", // signed
    "u8", "u16", "u32", "u64", // unsigned
];

/// Names of the floating-point scalar types.
pub const FLOAT_SCALAR_TYPES: &[&str] = &[
    "f16", "bf16", "f32", "f64",
];

/// Names of the scalar types that are neither integer nor floating-point.
pub const OTHER_SCALAR_TYPES: &[&str] = &[
    "bool",
];
/// Reports whether `ty` is exactly one of the names in `INTEGER_SCALAR_TYPES`.
///
/// The comparison is case-sensitive and does no trimming.
pub fn is_integer_scalar(ty: &str) -> bool {
    INTEGER_SCALAR_TYPES.iter().any(|&name| name == ty)
}
/// Reports whether `ty` names any scalar type.
///
/// The name is looked up in `INTEGER_SCALAR_TYPES`, then `FLOAT_SCALAR_TYPES`,
/// then `OTHER_SCALAR_TYPES`. The comparison is case-sensitive.
pub fn is_scalar(ty: &str) -> bool {
    INTEGER_SCALAR_TYPES
        .iter()
        .chain(FLOAT_SCALAR_TYPES)
        .chain(OTHER_SCALAR_TYPES)
        .any(|&name| name == ty)
}
/// Computes the specialization facts for a strided buffer.
///
/// * `base_ptr` - base address; its hint comes from `DivHint::from_ptr`.
/// * `shape` - extent of each dimension.
/// * `strides` - stride of each dimension; only the first `shape.len()`
///   entries are used.
/// * `_dtype_bytes` - currently unused.
///
/// Each dimension contributes a `DivHint::from_value` hint for its extent and
/// its stride, plus a flag for `stride == 1`.
///
/// `elements_disjoint` is the result of a stride-ordering check. The
/// `(stride, extent)` pairs are sorted ascending; the smallest stride must be
/// positive, and every later stride must be positive and at least the
/// previous stride multiplied by the previous extent. With no dimensions the
/// check passes.
///
/// # Panics
///
/// Panics if `strides` has fewer entries than `shape`.
pub fn compute_spec(
    base_ptr: u64,
    shape: &[i32],
    strides: &[i32],
    _dtype_bytes: i32,
) -> SpecializationBits {
    let ndim = shape.len();
    assert!(
        strides.len() >= ndim,
        "compute_spec: got {} strides for {} dimensions",
        strides.len(),
        ndim
    );
    // Any strides beyond the described dimensions are ignored.
    let strides = &strides[..ndim];

    let mut by_stride: Vec<(i32, i32)> = strides
        .iter()
        .copied()
        .zip(shape.iter().copied())
        .collect();
    by_stride.sort_unstable();

    let smallest_is_positive = by_stride.first().map_or(true, |&(stride, _)| stride > 0);
    let extents_nest = by_stride.windows(2).all(|pair| {
        let (inner_stride, inner_dim) = pair[0];
        let (outer_stride, _) = pair[1];
        // Widen to i64: `stride * extent` can exceed `i32::MAX`, which would
        // panic in debug builds and wrap to a wrong answer in release builds.
        outer_stride > 0
            && i64::from(outer_stride) >= i64::from(inner_stride) * i64::from(inner_dim)
    });

    SpecializationBits {
        shape_div: shape.iter().map(|&dim| DivHint::from_value(dim)).collect(),
        stride_div: strides
            .iter()
            .map(|&stride| DivHint::from_value(stride))
            .collect(),
        stride_one: strides.iter().map(|&stride| stride == 1).collect(),
        base_ptr_div: DivHint::from_ptr(base_ptr),
        elements_disjoint: smallest_is_positive && extents_nest,
    }
}
// Unit tests for this module are loaded from `specialization_tests.rs` and are
// compiled only when building with `cfg(test)`.
#[cfg(test)]
#[path = "specialization_tests.rs"]
mod tests;