use crate::ir::{Ident, Program};
use crate::optimizer::program_shape_facts::ProgramShapeFacts;
use rustc_hash::{FxHashMap, FxHashSet};
use std::sync::Arc;
mod type_facts;
mod use_facts;
use self::use_facts::derive_use_facts;
/// Bundle of analysis facts derived from one `Program`, tagged with that program's fingerprint.
///
/// Every fact family is stored as `Option<Arc<_>>`: `None` means the family was not derived
/// (see `derive_shape_and_use` / `derive_use_only`) or has been dropped by one of the
/// `invalidate_*` methods. Cloning copies the fingerprint and clones the `Arc` handles.
#[derive(Default, Clone, Debug)]
pub struct FactSubstrate {
/// Value of `Program::fingerprint()` recorded when the substrate was derived; the
/// freshness predicates compare it against the fingerprint of the program they are given.
fingerprint: [u8; 32],
/// Shape facts produced by `ProgramShapeFacts::derive`.
pub shape: Option<Arc<ProgramShapeFacts>>,
/// Use facts produced by `derive_use_facts`.
pub use_facts: Option<Arc<UseFacts>>,
/// Handle to the same map as `UseFacts::var_counts` when built by the `derive*` constructors.
pub use_counts: Option<Arc<FxHashMap<Ident, usize>>>,
/// Type facts produced by `type_facts::derive`.
pub type_map: Option<Arc<TypeFacts>>,
}
/// Type information collected for a program; stored in `FactSubstrate::type_map`.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub struct TypeFacts {
/// Data type recorded per identifier.
pub var_types: FxHashMap<Ident, crate::ir::DataType>,
/// Data type recorded per `u64` key.
/// NOTE(review): the key is presumably an expression id or hash — confirm in `type_facts`.
pub expr_types: FxHashMap<u64, crate::ir::DataType>,
}
/// Usage statistics collected for a program; stored in `FactSubstrate::use_facts`.
///
/// NOTE(review): the exact counting rules live in the `use_facts` submodule, which is not
/// visible here; field notes below state only what this file's code relies on.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub struct UseFacts {
/// Per-identifier counts; `FactSubstrate::use_count_of` reads this map, and the `derive*`
/// constructors share the same `Arc` as `FactSubstrate::use_counts`.
pub var_counts: Arc<FxHashMap<Ident, usize>>,
/// Per-buffer read tally; summed with `buffer_writes` by `access_count`.
pub buffer_reads: FxHashMap<Ident, usize>,
/// Per-buffer write tally; summed with `buffer_reads` by `access_count`.
pub buffer_writes: FxHashMap<Ident, usize>,
/// Three counters per buffer, one per axis slot; `dominant_index_axis` returns the slot
/// holding the largest non-zero counter.
pub buffer_index_axes: FxHashMap<Ident, [usize; 3]>,
/// NOTE(review): presumably the buffers each variable depends on — confirm in `use_facts`.
pub var_buffer_deps: FxHashMap<Ident, FxHashSet<Ident>>,
/// NOTE(review): presumably the buffers feeding writes to each buffer — confirm in `use_facts`.
pub buffer_write_deps: FxHashMap<Ident, FxHashSet<Ident>>,
/// NOTE(review): presumably buffers used as indirect-dispatch arguments — confirm in `use_facts`.
pub indirect_dispatch_buffers: FxHashSet<Ident>,
/// NOTE(review): presumably set when the program contains a construct the analysis cannot
/// see through — confirm in `use_facts`.
pub has_opaque: bool,
}
impl UseFacts {
    /// Returns the axis slot with the highest counter recorded for `buffer`.
    ///
    /// Yields `None` when the buffer has no entry in `buffer_index_axes` or when all of
    /// its counters are zero. When several slots share the highest counter, the lowest
    /// axis index is returned.
    #[must_use]
    pub fn dominant_index_axis(&self, buffer: &Ident) -> Option<u8> {
        let per_axis = self.buffer_index_axes.get(buffer)?;
        let mut winner: Option<(usize, usize)> = None;
        for (axis, &hits) in per_axis.iter().enumerate() {
            // Strictly-greater keeps the earliest axis on ties.
            let beats_current = winner.map_or(true, |(_, best_hits)| hits > best_hits);
            if hits > 0 && beats_current {
                winner = Some((axis, hits));
            }
        }
        winner.and_then(|(axis, _)| u8::try_from(axis).ok())
    }

    /// Returns reads plus writes recorded for `buffer`; a buffer missing from either
    /// table contributes zero for that table.
    #[must_use]
    pub fn access_count(&self, buffer: &Ident) -> usize {
        let reads = self.buffer_reads.get(buffer).map_or(0, |&n| n);
        let writes = self.buffer_writes.get(buffer).map_or(0, |&n| n);
        reads + writes
    }
}
// Per-thread, single-entry memo slots used by the `FactSubstrate::derive*_cached` methods.
// Each slot holds the fingerprint its entry was derived for together with the entry itself;
// there is one slot per derivation flavor so the flavors do not evict each other.
thread_local! {
// Slot for `derive_cached` (all fact families).
static FACT_SUBSTRATE_CACHE_FULL: std::cell::RefCell<Option<([u8; 32], FactSubstrate)>> =
const { std::cell::RefCell::new(None) };
// Slot for `derive_shape_and_use_cached` (no type map).
static FACT_SUBSTRATE_CACHE_SHAPE_USE: std::cell::RefCell<Option<([u8; 32], FactSubstrate)>> =
const { std::cell::RefCell::new(None) };
// Slot for `derive_use_only_cached` (use facts and use counts only).
static FACT_SUBSTRATE_CACHE_USE_ONLY: std::cell::RefCell<Option<([u8; 32], FactSubstrate)>> =
const { std::cell::RefCell::new(None) };
}
impl FactSubstrate {
#[must_use]
pub fn derive(program: &Program) -> Self {
let fp = program.fingerprint();
let use_facts = derive_use_facts(program);
Self {
fingerprint: fp,
shape: Some(Arc::new(ProgramShapeFacts::derive(program))),
use_counts: Some(Arc::clone(&use_facts.var_counts)),
use_facts: Some(Arc::new(use_facts)),
type_map: Some(Arc::new(type_facts::derive(program))),
}
}
#[must_use]
pub fn derive_cached(program: &Program) -> Self {
let fp = program.fingerprint();
FACT_SUBSTRATE_CACHE_FULL.with(|cell| {
if let Some((cached_fp, ref cached)) = *cell.borrow() {
if cached_fp == fp {
return cached.clone();
}
}
let fresh = Self::derive(program);
*cell.borrow_mut() = Some((fp, fresh.clone()));
fresh
})
}
#[must_use]
pub fn derive_shape_and_use(program: &Program) -> Self {
let fp = program.fingerprint();
let use_facts = derive_use_facts(program);
Self {
fingerprint: fp,
shape: Some(Arc::new(ProgramShapeFacts::derive(program))),
use_counts: Some(Arc::clone(&use_facts.var_counts)),
use_facts: Some(Arc::new(use_facts)),
type_map: None,
}
}
#[must_use]
pub fn derive_shape_and_use_cached(program: &Program) -> Self {
let fp = program.fingerprint();
FACT_SUBSTRATE_CACHE_SHAPE_USE.with(|cell| {
if let Some((cached_fp, ref cached)) = *cell.borrow() {
if cached_fp == fp {
return cached.clone();
}
}
let fresh = Self::derive_shape_and_use(program);
*cell.borrow_mut() = Some((fp, fresh.clone()));
fresh
})
}
#[must_use]
pub fn derive_use_only(program: &Program) -> Self {
let use_facts = derive_use_facts(program);
Self {
fingerprint: program.fingerprint(),
shape: None,
use_counts: Some(Arc::clone(&use_facts.var_counts)),
use_facts: Some(Arc::new(use_facts)),
type_map: None,
}
}
#[must_use]
pub fn derive_use_only_cached(program: &Program) -> Self {
let fp = program.fingerprint();
FACT_SUBSTRATE_CACHE_USE_ONLY.with(|cell| {
if let Some((cached_fp, ref cached)) = *cell.borrow() {
if cached_fp == fp {
return cached.clone();
}
}
let fresh = Self::derive_use_only(program);
*cell.borrow_mut() = Some((fp, fresh.clone()));
fresh
})
}
pub fn invalidate(&mut self) {
self.invalidate_shape();
self.invalidate_use_facts();
self.invalidate_type_map();
}
pub fn invalidate_shape(&mut self) {
self.shape = None;
}
pub fn invalidate_use_facts(&mut self) {
self.use_facts = None;
self.use_counts = None;
}
pub fn invalidate_type_map(&mut self) {
self.type_map = None;
}
#[must_use]
pub fn is_fresh_for(&self, program: &Program) -> bool {
self.fingerprint == program.fingerprint()
&& self.shape.is_some()
&& self.use_facts.is_some()
&& self.use_counts.is_some()
&& self.type_map.is_some()
}
#[must_use]
pub fn has_fresh_use_facts_for(&self, program: &Program) -> bool {
self.fingerprint == program.fingerprint() && self.use_facts.is_some()
}
#[must_use]
pub fn has_fresh_shape_and_use_for(&self, program: &Program) -> bool {
self.fingerprint == program.fingerprint()
&& self.shape.is_some()
&& self.use_facts.is_some()
&& self.use_counts.is_some()
}
#[must_use]
pub fn use_facts(&self) -> Option<&UseFacts> {
self.use_facts.as_deref()
}
#[must_use]
pub fn use_counts(&self) -> Option<&FxHashMap<Ident, usize>> {
self.use_counts.as_deref()
}
#[must_use]
pub fn use_count_of(&self, name: &Ident) -> usize {
self.use_facts()
.and_then(|facts| facts.var_counts.get(name))
.copied()
.or_else(|| self.use_counts().and_then(|m| m.get(name)).copied())
.unwrap_or(0)
}
}
#[cfg(test)]
mod tests;