use crate::ir::{Expr, Ident, Node, Program};
use crate::optimizer::program_shape_facts::ProgramShapeFacts;
use rustc_hash::{FxHashMap, FxHashSet};
use smallvec::SmallVec;
use std::sync::Arc;
mod type_facts;
/// Bundle of analysis facts derived from one `Program`, tagged with that program's fingerprint.
///
/// Every fact family is optional: the `derive*` constructors fill in different subsets and
/// `invalidate` clears all of them. The `*_fresh_*` methods report whether the stored facts
/// still belong to a given program.
#[derive(Default, Clone, Debug)]
pub struct FactSubstrate {
/// Value of `Program::fingerprint()` captured when the facts were derived.
fingerprint: [u8; 32],
/// Shape facts produced by `ProgramShapeFacts::derive_arc`, when derived.
pub shape: Option<Arc<ProgramShapeFacts>>,
/// Variable / buffer usage facts produced by `derive_use_facts`, when derived.
pub use_facts: Option<Arc<UseFacts>>,
/// Per-variable use counts; the constructors in this file share this `Arc` with
/// `UseFacts::var_counts`.
pub use_counts: Option<Arc<FxHashMap<Ident, usize>>>,
/// Type facts produced by `type_facts::derive`, when derived.
pub type_map: Option<Arc<TypeFacts>>,
}
/// Type information recorded for a program (populated by the `type_facts` submodule).
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub struct TypeFacts {
/// Data type recorded for each variable name.
pub var_types: FxHashMap<Ident, crate::ir::DataType>,
/// Data type recorded per expression, keyed by a `u64`.
// NOTE(review): the key is presumably an expression id or hash assigned in `type_facts` — confirm there.
pub expr_types: FxHashMap<u64, crate::ir::DataType>,
}
/// Immutable summary of how a program uses its variables and buffers.
///
/// Built by `derive_use_facts` through a `UseFactBuilder`.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub struct UseFacts {
/// Number of `Expr::Var` occurrences per variable name.
pub var_counts: Arc<FxHashMap<Ident, usize>>,
/// Number of read sites per buffer (loads, `BufLen`, atomics, async copy sources,
/// indirect-dispatch count buffers).
pub buffer_reads: FxHashMap<Ident, usize>,
/// Number of write sites per buffer (stores, atomics, async copy destinations).
pub buffer_writes: FxHashMap<Ident, usize>,
/// Per buffer, how many `InvocationId` / `LocalId` expressions appear in its index
/// expressions, bucketed by axis (axis values above 2 are counted in the last slot).
pub buffer_index_axes: FxHashMap<Ident, [usize; 3]>,
/// Per assigned variable, the buffers its value depends on: buffers read by the assigned
/// expression, dependencies of variables it references, and the enclosing control dependencies.
pub var_buffer_deps: FxHashMap<Ident, FxHashSet<Ident>>,
/// Per written buffer, the buffers that the written index/value (or async copy source,
/// offset and size) and the enclosing control flow depend on. Buffers whose writes have no
/// dependencies get no entry.
pub buffer_write_deps: FxHashMap<Ident, FxHashSet<Ident>>,
/// Buffers used as the count buffer of an `IndirectDispatch` node.
pub indirect_dispatch_buffers: FxHashSet<Ident>,
/// True when an opaque node or opaque expression was encountered during the walk.
pub has_opaque: bool,
}
impl UseFacts {
#[must_use]
pub fn dominant_index_axis(&self, buffer: &Ident) -> Option<u8> {
let axes = self.buffer_index_axes.get(buffer)?;
axes.iter()
.enumerate()
.max_by_key(|&(axis, count)| (*count, std::cmp::Reverse(axis)))
.and_then(|(axis, count)| (*count > 0).then_some(axis as u8))
}
#[must_use]
pub fn access_count(&self, buffer: &Ident) -> usize {
self.buffer_reads.get(buffer).copied().unwrap_or(0)
+ self.buffer_writes.get(buffer).copied().unwrap_or(0)
}
}
/// Mutable accumulator filled while walking a program; converted into `UseFacts` by `finish`.
///
/// Fields mirror `UseFacts` one-to-one, except that `var_counts` is not yet wrapped in an `Arc`.
#[derive(Default)]
struct UseFactBuilder {
// Number of `Expr::Var` occurrences per variable.
var_counts: FxHashMap<Ident, usize>,
// Read-site count per buffer.
buffer_reads: FxHashMap<Ident, usize>,
// Write-site count per buffer.
buffer_writes: FxHashMap<Ident, usize>,
// Per-buffer counts of `InvocationId` / `LocalId` occurrences in index expressions, by axis.
buffer_index_axes: FxHashMap<Ident, [usize; 3]>,
// Buffers each assigned variable depends on; also consulted while walking later expressions.
var_buffer_deps: FxHashMap<Ident, FxHashSet<Ident>>,
// Buffers that writes to each buffer depend on.
buffer_write_deps: FxHashMap<Ident, FxHashSet<Ident>>,
// Count buffers of `IndirectDispatch` nodes.
indirect_dispatch_buffers: FxHashSet<Ident>,
// Set once an opaque node or expression is seen.
has_opaque: bool,
}
impl UseFactBuilder {
    /// Consumes the builder and produces the immutable `UseFacts`, wrapping the variable
    /// use counts in an `Arc` so they can be shared without copying.
    fn finish(self) -> UseFacts {
        let Self {
            var_counts,
            buffer_reads,
            buffer_writes,
            buffer_index_axes,
            var_buffer_deps,
            buffer_write_deps,
            indirect_dispatch_buffers,
            has_opaque,
        } = self;
        UseFacts {
            var_counts: Arc::new(var_counts),
            buffer_reads,
            buffer_writes,
            buffer_index_axes,
            var_buffer_deps,
            buffer_write_deps,
            indirect_dispatch_buffers,
            has_opaque,
        }
    }
}
impl FactSubstrate {
    /// Derives every fact family for `program`: shape, use facts, use counts and types.
    #[must_use]
    pub fn derive(program: &Program) -> Self {
        // Same derivation order as the partial constructor, with the type facts computed last.
        let mut substrate = Self::derive_shape_and_use(program);
        substrate.type_map = Some(Arc::new(type_facts::derive(program)));
        substrate
    }

    /// Derives shape facts, use facts and use counts for `program`; `type_map` is left `None`.
    #[must_use]
    pub fn derive_shape_and_use(program: &Program) -> Self {
        let fingerprint = program.fingerprint();
        let use_facts = Arc::new(derive_use_facts(program));
        let shape = ProgramShapeFacts::derive_arc(program);
        // The use counts share the allocation already owned by the use facts.
        let use_counts = Arc::clone(&use_facts.var_counts);
        Self {
            fingerprint,
            shape: Some(shape),
            use_facts: Some(use_facts),
            use_counts: Some(use_counts),
            type_map: None,
        }
    }

    /// Derives only use facts and use counts for `program`; `shape` and `type_map` are left
    /// `None`.
    #[must_use]
    pub fn derive_use_only(program: &Program) -> Self {
        let use_facts = Arc::new(derive_use_facts(program));
        let use_counts = Arc::clone(&use_facts.var_counts);
        Self {
            fingerprint: program.fingerprint(),
            shape: None,
            use_facts: Some(use_facts),
            use_counts: Some(use_counts),
            type_map: None,
        }
    }

    /// Drops every cached fact family. The stored fingerprint is left untouched; the
    /// freshness checks fail afterwards because the facts themselves are gone.
    pub fn invalidate(&mut self) {
        self.type_map = None;
        self.use_counts = None;
        self.use_facts = None;
        self.shape = None;
    }

    /// Returns `true` when the fingerprint matches `program` and all four fact families are
    /// present.
    #[must_use]
    pub fn is_fresh_for(&self, program: &Program) -> bool {
        self.has_fresh_shape_and_use_for(program) && self.type_map.is_some()
    }

    /// Returns `true` when the fingerprint matches `program` and use facts are present.
    #[must_use]
    pub fn has_fresh_use_facts_for(&self, program: &Program) -> bool {
        let same_program = self.fingerprint == program.fingerprint();
        same_program && self.use_facts.is_some()
    }

    /// Returns `true` when the fingerprint matches `program` and shape facts, use facts and
    /// use counts are all present (type facts are not required).
    #[must_use]
    pub fn has_fresh_shape_and_use_for(&self, program: &Program) -> bool {
        if self.fingerprint != program.fingerprint() {
            return false;
        }
        self.shape.is_some() && self.use_facts.is_some() && self.use_counts.is_some()
    }

    /// Borrows the cached use facts, if any.
    #[must_use]
    pub fn use_facts(&self) -> Option<&UseFacts> {
        self.use_facts.as_deref()
    }

    /// Borrows the cached per-variable use counts, if any.
    #[must_use]
    pub fn use_counts(&self) -> Option<&FxHashMap<Ident, usize>> {
        self.use_counts.as_deref()
    }

    /// Returns how many times `name` is used, or 0 when no count is known.
    ///
    /// The counts inside the use facts are consulted first; the standalone `use_counts` map
    /// is the fallback when the use facts are absent or do not mention `name`.
    #[must_use]
    pub fn use_count_of(&self, name: &Ident) -> usize {
        let from_use_facts = self
            .use_facts()
            .and_then(|facts| facts.var_counts.get(name));
        let located =
            from_use_facts.or_else(|| self.use_counts().and_then(|counts| counts.get(name)));
        located.copied().unwrap_or(0)
    }
}
/// Computes the `UseFacts` for `program` by walking its entry node list with an empty set of
/// control dependencies.
fn derive_use_facts(program: &Program) -> UseFacts {
    let no_control_deps = FxHashSet::default();
    let mut builder = UseFactBuilder::default();
    derive_nodes_uses(program.entry(), &mut builder, &no_control_deps);
    builder.finish()
}
fn derive_nodes_uses(nodes: &[Node], facts: &mut UseFactBuilder, control_deps: &FxHashSet<Ident>) {
for node in nodes {
match node {
Node::Let { name, value } | Node::Assign { name, value } => {
let mut deps = record_expr_uses_and_buffer_deps(value, facts);
deps.extend(control_deps.iter().cloned());
facts
.var_buffer_deps
.entry(name.clone())
.or_default()
.extend(deps);
}
Node::Store {
buffer,
index,
value,
} => {
*facts.buffer_writes.entry(buffer.clone()).or_insert(0) += 1;
let mut deps = record_expr_uses_and_buffer_deps(index, facts);
count_index_axes(index, buffer, facts);
deps.extend(record_expr_uses_and_buffer_deps(value, facts));
deps.extend(control_deps.iter().cloned());
add_buffer_write_deps(facts, buffer, deps);
}
Node::If {
cond,
then,
otherwise,
} => {
let cond_deps = record_expr_uses_and_buffer_deps(cond, facts);
let branch_control = union_deps(control_deps, &cond_deps);
derive_nodes_uses(then, facts, &branch_control);
derive_nodes_uses(otherwise, facts, &branch_control);
}
Node::Loop { from, to, body, .. } => {
let mut loop_deps = record_expr_uses_and_buffer_deps(from, facts);
loop_deps.extend(record_expr_uses_and_buffer_deps(to, facts));
let loop_control = union_deps(control_deps, &loop_deps);
derive_nodes_uses(body, facts, &loop_control);
}
Node::Block(nodes) => {
derive_nodes_uses(nodes, facts, control_deps);
}
Node::Region { body, .. } => {
derive_nodes_uses(body, facts, control_deps);
}
Node::AsyncLoad {
source,
destination,
offset,
size,
..
} => {
*facts.buffer_reads.entry(source.clone()).or_insert(0) += 1;
*facts.buffer_writes.entry(destination.clone()).or_insert(0) += 1;
let mut deps = record_expr_uses_and_buffer_deps(offset, facts);
deps.extend(record_expr_uses_and_buffer_deps(size, facts));
deps.extend(control_deps.iter().cloned());
deps.insert(source.clone());
add_buffer_write_deps(facts, destination, deps);
}
Node::AsyncStore {
source,
destination,
offset,
size,
..
} => {
*facts.buffer_reads.entry(source.clone()).or_insert(0) += 1;
*facts.buffer_writes.entry(destination.clone()).or_insert(0) += 1;
let mut deps = record_expr_uses_and_buffer_deps(offset, facts);
deps.extend(record_expr_uses_and_buffer_deps(size, facts));
deps.extend(control_deps.iter().cloned());
deps.insert(source.clone());
add_buffer_write_deps(facts, destination, deps);
}
Node::Trap { address, .. } => {
record_expr_uses_and_buffer_deps(address, facts);
}
Node::IndirectDispatch { count_buffer, .. } => {
facts.indirect_dispatch_buffers.insert(count_buffer.clone());
*facts.buffer_reads.entry(count_buffer.clone()).or_insert(0) += 1;
}
Node::Opaque(_) => {
facts.has_opaque = true;
}
Node::Return | Node::Barrier { .. } | Node::AsyncWait { .. } | Node::Resume { .. } => {}
}
}
}
/// Walks `expr` iteratively, updating the use counters in `facts`, and returns the set of
/// buffers the expression depends on.
///
/// A buffer is a dependency when the expression loads from it, queries its length, performs
/// an atomic on it, or references a variable whose recorded dependencies include it.
fn record_expr_uses_and_buffer_deps(expr: &Expr, facts: &mut UseFactBuilder) -> FxHashSet<Ident> {
    let mut buffer_deps = FxHashSet::default();
    let mut pending: SmallVec<[&Expr; 16]> = SmallVec::new();
    pending.push(expr);
    while let Some(current) = pending.pop() {
        match current {
            Expr::Opaque(_) => {
                facts.has_opaque = true;
            }
            Expr::Var(name) => {
                *facts.var_counts.entry(name.clone()).or_default() += 1;
                // Inherit whatever buffers the variable was recorded as depending on so far.
                if let Some(known) = facts.var_buffer_deps.get(name) {
                    buffer_deps.extend(known.iter().cloned());
                }
            }
            Expr::BufLen { buffer } => {
                *facts.buffer_reads.entry(buffer.clone()).or_default() += 1;
                buffer_deps.insert(buffer.clone());
            }
            Expr::Load { buffer, index } => {
                *facts.buffer_reads.entry(buffer.clone()).or_default() += 1;
                count_index_axes(index, buffer, facts);
                buffer_deps.insert(buffer.clone());
            }
            Expr::Atomic { buffer, index, .. } => {
                // An atomic counts as both a read and a write of its buffer.
                // NOTE(review): unlike `Node::Store`, no `buffer_write_deps` entry is recorded
                // for the atomic's write — confirm whether that is intended.
                *facts.buffer_reads.entry(buffer.clone()).or_default() += 1;
                *facts.buffer_writes.entry(buffer.clone()).or_default() += 1;
                count_index_axes(index, buffer, facts);
                buffer_deps.insert(buffer.clone());
            }
            _ => {}
        }
        // Sub-expressions (including load/atomic indices) are visited in later iterations.
        push_expr_children(current, &mut pending);
    }
    buffer_deps
}
/// Returns a freshly allocated set containing every identifier from `a` and from `b`.
fn union_deps(a: &FxHashSet<Ident>, b: &FxHashSet<Ident>) -> FxHashSet<Ident> {
    let mut merged = FxHashSet::default();
    // Upper bound on the result size; saturating so pathological sizes cannot overflow.
    merged.reserve(a.len().saturating_add(b.len()));
    merged.extend(a.iter().chain(b.iter()).cloned());
    merged
}
/// Merges `deps` into the write-dependency set of `buffer`.
///
/// Nothing is recorded when `deps` is empty, so a buffer whose writes depend on no other
/// buffer gets no entry in `buffer_write_deps`.
fn add_buffer_write_deps(facts: &mut UseFactBuilder, buffer: &Ident, deps: FxHashSet<Ident>) {
    if !deps.is_empty() {
        let recorded = facts.buffer_write_deps.entry(buffer.clone()).or_default();
        recorded.extend(deps);
    }
}
/// Counts, per axis, the `InvocationId` and `LocalId` expressions found anywhere inside
/// `index` and adds them to the axis counters of `buffer`.
///
/// Axis values above 2 are counted in the last slot. A counter row is created for `buffer`
/// only when at least one such expression is found.
fn count_index_axes(index: &Expr, buffer: &Ident, facts: &mut UseFactBuilder) {
    let mut pending: SmallVec<[&Expr; 16]> = SmallVec::new();
    pending.push(index);
    while let Some(current) = pending.pop() {
        match current {
            Expr::InvocationId { axis } | Expr::LocalId { axis } => {
                let counters = facts
                    .buffer_index_axes
                    .entry(buffer.clone())
                    .or_insert([0; 3]);
                // Clamp so an out-of-range axis lands in the last slot instead of being lost.
                let slot_index = usize::from(*axis).min(2);
                if let Some(slot) = counters.get_mut(slot_index) {
                    *slot += 1;
                }
            }
            _ => {}
        }
        push_expr_children(current, &mut pending);
    }
}
/// Pushes the direct sub-expressions of `expr` onto `stack` for an iterative traversal.
///
/// The match is exhaustive on purpose: adding a new `Expr` variant forces a decision here
/// about which children must be visited.
fn push_expr_children<'a>(expr: &'a Expr, stack: &mut SmallVec<[&'a Expr; 16]>) {
    match expr {
        // Leaves: nothing to descend into.
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::LitBool(_)
        | Expr::Var(_)
        | Expr::BufLen { .. }
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. }
        | Expr::SubgroupLocalId
        | Expr::SubgroupSize
        | Expr::Opaque(_) => {}

        // One child.
        Expr::Load { index: child, .. } | Expr::UnOp { operand: child, .. } => stack.push(child),
        Expr::Cast { value: inner, .. } => stack.push(inner),
        Expr::SubgroupBallot { cond: predicate } => stack.push(predicate),
        Expr::SubgroupAdd { value: operand } => stack.push(operand),

        // Two children.
        Expr::BinOp {
            left: lhs,
            right: rhs,
            ..
        } => {
            stack.push(lhs);
            stack.push(rhs);
        }
        Expr::SubgroupShuffle {
            value: payload,
            lane: lane_expr,
        } => {
            stack.push(payload);
            stack.push(lane_expr);
        }

        // Three children.
        Expr::Select {
            cond: predicate,
            true_val: when_true,
            false_val: when_false,
        } => {
            stack.push(predicate);
            stack.push(when_true);
            stack.push(when_false);
        }
        Expr::Fma {
            a: first,
            b: second,
            c: third,
        } => {
            stack.push(first);
            stack.push(second);
            stack.push(third);
        }

        // Variable number of children.
        Expr::Call { args: call_args, .. } => stack.extend(call_args),
        Expr::Atomic {
            index: slot,
            expected: comparand,
            value: operand,
            ..
        } => {
            stack.push(slot);
            // The comparand is optional, so it is pushed only when present.
            if let Some(comparand) = comparand {
                stack.push(comparand);
            }
            stack.push(operand);
        }
    }
}
#[cfg(test)]
mod tests;