use std::cell::RefCell;
use std::collections::VecDeque;
use std::sync::Arc;
use crate::ir_inner::model::expr::Ident;
use crate::ir_inner::model::program::Program;
use rustc_hash::FxHashMap;
use vyre_spec::data_type::DataType;
use super::shape_facts;
/// Key under which [`ProgramShapeFacts`] are stored in a
/// `crate::optimizer::AnalysisCache` (used by `derive_into_cache` / `from_cache`).
pub const ANALYSIS_KEY: &str = "program_shape_facts";
/// Maximum number of program fingerprints kept in the per-thread shape-fact
/// cache; once exceeded, the oldest inserted entry is evicted first.
const SHAPE_FACT_CACHE_CAP: usize = 64;
/// Size and layout facts derived for one declared buffer.
///
/// Produced by [`ProgramShapeFacts::derive`], which intersects the buffer's
/// declared element count with the bounds implied by its optional shape
/// predicate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BufferShapeFacts {
    /// Buffer name as declared in the program.
    pub name: Ident,
    /// Binding slot copied from the buffer declaration.
    pub binding: u32,
    /// Element data type.
    pub dtype: DataType,
    /// Element count from the declaration; `0` contributes no bound on the count.
    pub declared_count: u32,
    /// Size of one element in bytes, or `None` when the element type has no fixed size.
    pub element_size_bytes: Option<usize>,
    /// Lower bound on the element count (`0` when nothing is known).
    pub min_count: u32,
    /// Upper bound on the element count, or `None` when unbounded.
    pub max_count: Option<u32>,
    /// `true` when `max_count == Some(min_count)`, i.e. the count is known exactly.
    pub is_fixed_count: bool,
    /// `min_count` times the element size (saturating); `0` when the element size is unknown.
    pub min_bytes: u64,
    /// `max_count` times the element size (saturating); `None` when either is unknown.
    pub max_bytes: Option<u64>,
    /// Element size in bytes (1 when unknown, zero, or not representable as `u32`)
    /// multiplied, saturating, by the element-count alignment implied by the
    /// shape predicate.
    pub byte_alignment: u32,
}
impl BufferShapeFacts {
    /// Returns `true` when the buffer is known to hold at least one element.
    #[must_use]
    #[inline]
    pub fn is_non_empty(&self) -> bool {
        self.min_count > 0
    }

    /// Returns `true` when the buffer is known to span at least `bytes` bytes.
    #[must_use]
    #[inline]
    pub fn min_bytes_at_least(&self, bytes: u64) -> bool {
        self.min_bytes >= bytes
    }

    /// Returns `true` when every admissible element count is a multiple of
    /// `lane_count`, i.e. the elements split evenly into `lane_count`-sized groups.
    ///
    /// * A `lane_count` of 0 is never vectorizable.
    /// * Fixed-count buffers are checked against their exact count.
    /// * Otherwise, the count alignment is recovered from `byte_alignment` by
    ///   dividing out the element size. An unknown or zero size is treated as 1,
    ///   matching how `byte_alignment` is derived.
    #[must_use]
    pub fn vectorizable_at(&self, lane_count: u32) -> bool {
        if lane_count == 0 {
            return false;
        }
        if self.is_fixed_count {
            return self.max_count.is_some_and(|count| count % lane_count == 0);
        }
        // Convert without truncation: an `as u32` cast could wrap an oversized
        // element size to an arbitrary divisor (or 0). If it does not fit, nothing
        // can be proven here, so conservatively report "not vectorizable".
        let Ok(element_size) = u32::try_from(self.element_size_bytes.unwrap_or(0).max(1)) else {
            return false;
        };
        // `element_size >= 1`, so this division cannot fail.
        let count_alignment = self.byte_alignment / element_size;
        count_alignment % lane_count == 0
    }
}
/// Shape facts for every buffer of a program, keyed by buffer name.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProgramShapeFacts {
    // One entry per declared buffer name; if two declarations share a name,
    // the later one's facts replace the earlier one's.
    by_name: FxHashMap<Ident, BufferShapeFacts>,
}
impl ProgramShapeFacts {
#[must_use]
pub fn derive(program: &Program) -> Self {
Self::derive_arc(program).as_ref().clone()
}
#[must_use]
pub fn derive_arc(program: &Program) -> Arc<Self> {
let fingerprint = crate::optimizer::fingerprint_program(program);
if let Some(facts) = SHAPE_FACT_CACHE.with(|cache| cache.borrow().get(fingerprint)) {
return facts;
}
let facts = Arc::new(Self::derive_uncached(program));
SHAPE_FACT_CACHE.with(|cache| cache.borrow_mut().insert(fingerprint, Arc::clone(&facts)));
facts
}
fn derive_uncached(program: &Program) -> Self {
let mut by_name = FxHashMap::default();
by_name.reserve(program.buffers().len());
for decl in program.buffers() {
let name = Ident::from(decl.name.as_ref());
let dtype = decl.element.clone();
let element_size_bytes = dtype.size_bytes();
let declared_count = decl.count;
let predicate = decl.shape_predicate();
let predicate_min = predicate.map_or(0, shape_facts::min_count);
let predicate_max = predicate.and_then(shape_facts::max_count);
let static_min = if declared_count > 0 {
declared_count
} else {
0
};
let static_max = if declared_count > 0 {
Some(declared_count)
} else {
None
};
let min_count = predicate_min.max(static_min);
let max_count = match (predicate_max, static_max) {
(Some(a), Some(b)) => Some(a.min(b)),
(Some(a), None) | (None, Some(a)) => Some(a),
(None, None) => None,
};
let is_fixed_count = max_count.is_some() && Some(min_count) == max_count;
let element_byte_alignment = element_size_bytes
.and_then(|n| u32::try_from(n).ok())
.unwrap_or(1)
.max(1);
let count_alignment = predicate.map(count_alignment_from_predicate).unwrap_or(1);
let byte_alignment = element_byte_alignment.saturating_mul(count_alignment);
let element_size_u64 = element_size_bytes.map(|n| n as u64);
let min_bytes = match element_size_u64 {
Some(esz) => esz.saturating_mul(u64::from(min_count)),
None => 0,
};
let max_bytes = match (element_size_u64, max_count) {
(Some(esz), Some(count)) => Some(esz.saturating_mul(u64::from(count))),
_ => None,
};
by_name.insert(
name.clone(),
BufferShapeFacts {
name,
binding: decl.binding,
dtype,
declared_count,
element_size_bytes,
min_count,
max_count,
is_fixed_count,
min_bytes,
max_bytes,
byte_alignment,
},
);
}
Self { by_name }
}
#[must_use]
#[inline]
pub fn get(&self, name: &Ident) -> Option<&BufferShapeFacts> {
self.by_name.get(name)
}
pub fn iter(&self) -> impl Iterator<Item = (&Ident, &BufferShapeFacts)> {
self.by_name.iter()
}
#[must_use]
#[inline]
pub fn len(&self) -> usize {
self.by_name.len()
}
#[must_use]
#[inline]
pub fn is_empty(&self) -> bool {
self.by_name.is_empty()
}
pub fn derive_into_cache(program: &Program, cache: &mut crate::optimizer::AnalysisCache) {
cache.insert(ANALYSIS_KEY, Self::derive_arc(program).as_ref().clone());
}
#[must_use]
pub fn from_cache(cache: &crate::optimizer::AnalysisCache) -> Option<&Self> {
cache.get::<Self>(ANALYSIS_KEY)
}
}
/// Bounded memo of derived [`ProgramShapeFacts`], keyed by program fingerprint.
///
/// Holds at most `SHAPE_FACT_CACHE_CAP` entries. When the cap is exceeded, the
/// oldest inserted fingerprint is evicted first.
#[derive(Default)]
struct ShapeFactCache {
    /// Cached facts, shared via `Arc` so a hit does not copy the map.
    by_fingerprint: FxHashMap<u64, Arc<ProgramShapeFacts>>,
    /// Fingerprints in insertion order (front = oldest). Re-inserting an
    /// existing fingerprint replaces its value but does not move it.
    order: VecDeque<u64>,
}
impl ShapeFactCache {
    /// Returns a shared handle to the facts cached under `fingerprint`, if any.
    fn get(&self, fingerprint: u64) -> Option<Arc<ProgramShapeFacts>> {
        self.by_fingerprint.get(&fingerprint).map(Arc::clone)
    }

    /// Caches `facts` under `fingerprint`, evicting the oldest entries once the
    /// number of tracked fingerprints exceeds `SHAPE_FACT_CACHE_CAP`.
    ///
    /// Replacing an existing fingerprint keeps its original eviction position.
    fn insert(&mut self, fingerprint: u64, facts: Arc<ProgramShapeFacts>) {
        let is_new = self.by_fingerprint.insert(fingerprint, facts).is_none();
        if is_new {
            self.order.push_back(fingerprint);
        }
        while self.order.len() > SHAPE_FACT_CACHE_CAP {
            let Some(oldest) = self.order.pop_front() else {
                break;
            };
            self.by_fingerprint.remove(&oldest);
        }
    }
}
// Per-thread shape-fact cache consulted by `ProgramShapeFacts::derive_arc`.
// Because it is thread-local it needs only a `RefCell` (no locking), but each
// thread warms its own independent copy.
thread_local! {
    static SHAPE_FACT_CACHE: RefCell<ShapeFactCache> = RefCell::new(ShapeFactCache::default());
}
/// Returns a value that every element count admitted by `predicate` is
/// guaranteed to be a multiple of (always at least 1).
///
/// The result feeds `BufferShapeFacts::byte_alignment`, so it must never be
/// overstated: a too-large alignment would let `vectorizable_at` accept counts
/// that do not divide evenly into lanes.
///
/// * `MultipleOf(n)`, `ModEquals { modulus: n, remainder: 0 }` and `Exactly(n)`
///   prove divisibility by `n` (for `n > 0`).
/// * `And(a, b)`: a count divisible by both sides' alignments is divisible by
///   their least common multiple.
/// * `Or(a, b)`: a count only has to satisfy one side, so only the greatest
///   common divisor of the two alignments is guaranteed. Taking the minimum
///   would be unsound: `Or(MultipleOf(4), MultipleOf(6))` admits 6, which is
///   not a multiple of 4.
/// * Anything else proves nothing beyond 1.
fn count_alignment_from_predicate(
    predicate: &crate::ir_inner::model::program::ShapePredicate,
) -> u32 {
    use crate::ir_inner::model::program::ShapePredicate;
    match predicate {
        ShapePredicate::MultipleOf(n) if *n > 0 => *n,
        ShapePredicate::ModEquals { modulus, remainder } if *modulus > 0 && *remainder == 0 => {
            *modulus
        }
        ShapePredicate::Exactly(n) if *n > 0 => *n,
        ShapePredicate::And(a, b) => alignment_lcm(
            count_alignment_from_predicate(a),
            count_alignment_from_predicate(b),
        ),
        ShapePredicate::Or(a, b) => alignment_gcd(
            count_alignment_from_predicate(a),
            count_alignment_from_predicate(b),
        )
        .max(1),
        _ => 1,
    }
}

/// Greatest common divisor (Euclid's algorithm); `alignment_gcd(x, 0) == x`.
fn alignment_gcd(mut a: u32, mut b: u32) -> u32 {
    while b != 0 {
        let remainder = a % b;
        a = b;
        b = remainder;
    }
    a
}

/// Least common multiple of two alignments (at least 1).
///
/// If the exact LCM overflows `u32`, falls back to the larger input. That value
/// still divides the true LCM, so the fallback remains a sound alignment.
fn alignment_lcm(a: u32, b: u32) -> u32 {
    if a == 0 || b == 0 {
        return a.max(b).max(1);
    }
    let gcd = alignment_gcd(a, b);
    (a / gcd).checked_mul(b).unwrap_or_else(|| a.max(b))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir_inner::model::program::{BufferDecl, ShapePredicate};
    use vyre::ir::{DataType as VyreDataType, Node, Program as VyreProgram};

    /// Wraps `buffers` in a minimal 64x1x1 program whose body only returns.
    fn program_with(buffers: Vec<BufferDecl>) -> VyreProgram {
        VyreProgram::wrapped(buffers, [64, 1, 1], vec![Node::return_()])
    }

    /// Two fixed-size u32 buffers of 64 elements each: `input` (binding 0) and
    /// `out` (binding 1, output byte range 0..256).
    fn two_fixed_buffers_program() -> VyreProgram {
        program_with(vec![
            BufferDecl::read("input", 0, VyreDataType::U32).with_count(64),
            BufferDecl::output("out", 1, VyreDataType::U32)
                .with_count(64)
                .with_output_byte_range(0..256),
        ])
    }

    #[test]
    fn derive_records_every_buffer() {
        let facts = ProgramShapeFacts::derive(&two_fixed_buffers_program());
        assert_eq!(facts.len(), 2);
        for name in ["input", "out"] {
            assert!(facts.get(&Ident::from(name)).is_some());
        }
    }

    #[test]
    fn declared_count_pins_min_and_max() {
        let facts = ProgramShapeFacts::derive(&two_fixed_buffers_program());
        let input = facts
            .get(&Ident::from("input"))
            .expect("Fix: input fact must exist");
        assert_eq!(input.min_count, 64);
        assert_eq!(input.max_count, Some(64));
        assert!(input.is_fixed_count);
        assert_eq!(input.element_size_bytes, Some(4));
        assert_eq!(input.min_bytes, 256);
        assert_eq!(input.max_bytes, Some(256));
        assert!(input.is_non_empty());
    }

    #[test]
    fn shape_predicate_at_least_widens_lower_bound_only() {
        let program = program_with(vec![BufferDecl::read("input", 0, VyreDataType::U32)
            .with_shape_predicate(ShapePredicate::AtLeast(32))]);
        let facts = ProgramShapeFacts::derive(&program);
        let input = facts.get(&Ident::from("input")).unwrap();
        assert_eq!(input.min_count, 32);
        assert_eq!(input.max_count, None);
        assert!(!input.is_fixed_count);
        assert!(input.is_non_empty());
    }

    #[test]
    fn shape_predicate_multiple_of_proves_byte_alignment() {
        let program = program_with(vec![BufferDecl::read("input", 0, VyreDataType::U32)
            .with_shape_predicate(ShapePredicate::MultipleOf(16))]);
        let facts = ProgramShapeFacts::derive(&program);
        let input = facts.get(&Ident::from("input")).unwrap();
        // 4-byte elements x 16-element count alignment.
        assert_eq!(input.byte_alignment, 64);
        for lanes in [4, 8, 16] {
            assert!(input.vectorizable_at(lanes));
        }
    }

    #[test]
    fn vectorizable_at_uses_exact_runtime_predicate_count() {
        let program = program_with(vec![BufferDecl::read("input", 0, VyreDataType::U32)
            .with_shape_predicate(ShapePredicate::Exactly(96))]);
        let facts = ProgramShapeFacts::derive(&program);
        let input = facts.get(&Ident::from("input")).unwrap();
        assert!(input.is_fixed_count);
        assert_eq!(input.declared_count, 0);
        assert!(input.vectorizable_at(32));
        assert!(!input.vectorizable_at(64));
    }

    #[test]
    fn variable_size_dtype_leaves_bytes_unbounded() {
        let program = program_with(vec![BufferDecl::read("input", 0, VyreDataType::Tensor)]);
        let facts = ProgramShapeFacts::derive(&program);
        let input = facts.get(&Ident::from("input")).unwrap();
        assert_eq!(input.element_size_bytes, None);
        assert_eq!(input.min_bytes, 0);
        assert_eq!(input.max_bytes, None);
        assert!(!input.is_fixed_count);
    }

    #[test]
    fn fixed_count_program_proves_byte_capacity_exactly() {
        let facts = ProgramShapeFacts::derive(&two_fixed_buffers_program());
        let out = facts.get(&Ident::from("out")).unwrap();
        assert_eq!(out.min_bytes, 256);
        assert_eq!(out.max_bytes, Some(256));
        assert!(out.min_bytes_at_least(128));
        assert!(!out.min_bytes_at_least(512));
    }

    #[test]
    fn cache_round_trip_returns_same_facts() {
        use crate::optimizer::AnalysisCache;
        let program = two_fixed_buffers_program();
        let mut cache = AnalysisCache::new();
        ProgramShapeFacts::derive_into_cache(&program, &mut cache);
        let cached = ProgramShapeFacts::from_cache(&cache).expect("Fix: facts must round-trip");
        assert_eq!(cached.len(), 2);
        let direct = ProgramShapeFacts::derive(&program);
        assert_eq!(cached, &direct);
    }

    #[test]
    fn shape_fact_cache_eviction_is_fifo_without_shifting() {
        let total = SHAPE_FACT_CACHE_CAP as u64 + 2;
        let mut cache = ShapeFactCache::default();
        (0..total).for_each(|fingerprint| {
            cache.insert(fingerprint, Arc::new(ProgramShapeFacts::default()));
        });
        assert_eq!(cache.order.len(), SHAPE_FACT_CACHE_CAP);
        // The two oldest fingerprints were evicted; the newest survives.
        assert!(!cache.by_fingerprint.contains_key(&0));
        assert!(!cache.by_fingerprint.contains_key(&1));
        assert!(cache.by_fingerprint.contains_key(&(total - 1)));
        assert_eq!(cache.order.front().copied(), Some(2));
    }
}