use core::cell::{Ref, RefCell, RefMut};
use alloc::{boxed::Box, vec::Vec};
use downcast_rs::{Downcast, impl_downcast};
use rustc_hash::{FxHashMap, FxHashSet};
use crate::{
context::{Context, Ptr},
irbuild::IRStatus,
op::{Op, OpInterfaceMarker, op_impls},
operation::Operation,
result::Result,
};
/// Outcome of running a [`Pass`] on an operation.
///
/// Records whether the IR was modified and which analysis types (identified by their
/// [`core::any::TypeId`]) the pass declares still valid. The default value preserves no
/// analyses and uses `IRStatus`'s default for `ir_changed`.
#[derive(Default)]
pub struct PassResult {
    /// Whether the pass modified the IR.
    pub ir_changed: IRStatus,
    /// Type ids of the analyses that remain valid after the pass.
    preserved_analyses: FxHashSet<core::any::TypeId>,
}

impl PassResult {
    /// Marks analysis type `A` as preserved by the pass that produced this result.
    pub fn set_preserved<A: Analysis + 'static>(&mut self) {
        let analysis_id = core::any::TypeId::of::<A>();
        self.preserved_analyses.insert(analysis_id);
    }
}
/// A unit of work that is run on an operation and reports a [`PassResult`].
pub trait Pass {
    /// Returns the name of this pass.
    fn name(&self) -> &str;
    /// Runs the pass on `op`.
    ///
    /// `analyses` holds the cached analyses available to the pass. The returned
    /// [`PassResult`] reports whether the IR changed and which analysis types the pass
    /// preserved; callers such as [`PassManager`] use it to drop stale analyses.
    fn run(
        &self,
        op: Ptr<Operation>,
        ctx: &mut Context,
        analyses: &mut AnalysisManager,
    ) -> Result<PassResult>;
}
/// A [`Pass`] that further passes can be added to.
pub trait PassGroup: Pass {
    /// Adds `pass` to this group.
    fn add_pass(&mut self, pass: impl Pass + 'static);
}
/// An ordered pipeline of passes that is run on the operations nested inside an operation.
#[derive(Default)]
pub struct PassManager {
    // The registered passes, in the order they were added.
    passes: Vec<Box<dyn Pass>>,
}
impl Pass for PassManager {
fn name(&self) -> &str {
"pass_manager"
}
fn run(
&self,
op: Ptr<Operation>,
ctx: &mut Context,
analyses: &mut AnalysisManager,
) -> Result<PassResult> {
use crate::linked_list::ContainsLinkedList;
let mut pass_res = PassResult::default();
let regions = op.deref(ctx).regions().collect::<Vec<_>>();
for region in regions {
let blocks = region.deref(ctx).iter(ctx).collect::<Vec<_>>();
for block in blocks {
let ops = block.deref(ctx).iter(ctx).collect::<Vec<_>>();
for nested_op in ops {
for pass in &self.passes {
let res = pass.run(nested_op, ctx, analyses)?;
pass_res.ir_changed |= res.ir_changed;
analyses.retain_preserved(&res);
}
}
}
}
let preserved_analyses = analyses.list_analyses();
pass_res.preserved_analyses = preserved_analyses;
Ok(pass_res)
}
}
impl PassManager {
pub fn add_pass(&mut self, pass: impl Pass + 'static) {
self.passes.push(Box::new(pass));
}
}
impl PassGroup for PassManager {
fn add_pass(&mut self, pass: impl Pass + 'static) {
self.add_pass(pass);
}
}
/// Decides whether a [`GuardedPass`] runs its wrapped pass on a given operation.
pub trait Guard {
    /// Returns `true` if the guarded pass may run on `op`.
    fn is_allowed(&self, op: Ptr<Operation>, ctx: &Context) -> bool;
}
/// A [`Guard`] that admits only operations for which `Operation::is_op::<T>` holds.
pub struct OpGuard<T: Op> {
    _marker: core::marker::PhantomData<T>,
}

// Written by hand: `#[derive(Default)]` would add an unwanted `T: Default` bound.
impl<T> Default for OpGuard<T>
where
    T: Op,
{
    fn default() -> Self {
        OpGuard {
            _marker: Default::default(),
        }
    }
}

impl<T> Guard for OpGuard<T>
where
    T: Op,
{
    fn is_allowed(&self, op: Ptr<Operation>, ctx: &Context) -> bool {
        Operation::is_op::<T>(op, ctx)
    }
}
/// A [`Guard`] that admits only operations implementing the op interface `T`.
pub struct OpInterfaceGuard<T: ?Sized + OpInterfaceMarker + 'static> {
    _marker: core::marker::PhantomData<T>,
}

// Written by hand: `#[derive(Default)]` would require `T: Sized + Default`.
impl<T> Default for OpInterfaceGuard<T>
where
    T: ?Sized + OpInterfaceMarker + 'static,
{
    fn default() -> Self {
        OpInterfaceGuard {
            _marker: Default::default(),
        }
    }
}

impl<T> Guard for OpInterfaceGuard<T>
where
    T: ?Sized + OpInterfaceMarker + 'static,
{
    fn is_allowed(&self, op: Ptr<Operation>, ctx: &Context) -> bool {
        // Recover the concrete op behind `op`, then check it against interface `T`.
        let concrete_op = Operation::get_op_dyn(op, ctx);
        op_impls::<T>(&*concrete_op)
    }
}
/// Wraps `pass` so that it only runs on operations admitted by `guard`.
///
/// For rejected operations nothing runs and a default [`PassResult`] is returned.
#[derive(Default)]
pub struct GuardedPass<P: Pass, G: Guard> {
    pass: P,
    guard: G,
}

impl<P, G> Pass for GuardedPass<P, G>
where
    P: Pass,
    G: Guard,
{
    fn name(&self) -> &str {
        // The wrapper is transparent and reports the wrapped pass's name.
        self.pass.name()
    }

    fn run(
        &self,
        op: Ptr<Operation>,
        ctx: &mut Context,
        analyses: &mut AnalysisManager,
    ) -> Result<PassResult> {
        if !self.guard.is_allowed(op, ctx) {
            return Ok(PassResult::default());
        }
        self.pass.run(op, ctx, analyses)
    }
}

// `P: PassGroup` already implies `P: Pass` through the supertrait.
impl<P, G> PassGroup for GuardedPass<P, G>
where
    P: PassGroup,
    G: Guard,
{
    fn add_pass(&mut self, pass: impl Pass + 'static) {
        self.pass.add_pass(pass);
    }
}
/// A pass that runs `P` only on operations accepted by [`OpGuard<T>`].
pub type OpPass<P, T> = GuardedPass<P, OpGuard<T>>;
/// A pass that runs `P` only on operations implementing the op interface `T`.
pub type OpInterfacePass<P, T> = GuardedPass<P, OpInterfaceGuard<T>>;
/// A [`PassManager`] that only runs on operations accepted by [`OpGuard<T>`].
pub type OpPassManager<T> = OpPass<PassManager, T>;
/// A [`PassManager`] that only runs on operations implementing the op interface `T`.
pub type OpInterfacePassManager<T> = OpInterfacePass<PassManager, T>;
/// A result computed for an operation and cached by an [`AnalysisManager`].
///
/// Cached analyses are stored as `dyn Analysis`. The `Downcast` supertrait lets the manager
/// recover the concrete type.
pub trait Analysis: Downcast {
    /// Returns the name of this analysis.
    fn name(&self) -> &str;
    /// Computes this analysis for `op`.
    ///
    /// `analyses` may be used to obtain other analyses this one depends on.
    // `Self: Sized` is required to return `Self` and keeps `dyn Analysis` usable.
    fn compute(op: Ptr<Operation>, ctx: &Context, analyses: &mut AnalysisManager) -> Result<Self>
    where
        Self: Sized;
}
// Provides `downcast_ref` / `downcast_mut` on `dyn Analysis`.
impl_downcast!(Analysis);
/// Cache key: the concrete analysis type together with the operation it was computed for.
type AnalysisManagerKey = (core::any::TypeId, Ptr<Operation>);
/// Cache of computed analyses, keyed by analysis type and operation.
#[derive(Default)]
pub struct AnalysisManager {
    // Each entry sits in its own `RefCell`, so borrows of it are checked at runtime.
    analyses: FxHashMap<AnalysisManagerKey, Box<RefCell<dyn Analysis>>>,
}
impl AnalysisManager {
pub fn compute_analysis<A: Analysis + 'static>(
&mut self,
op: Ptr<Operation>,
ctx: &Context,
) -> Result<()> {
let key = (core::any::TypeId::of::<A>(), op);
if !self.analyses.contains_key(&key) {
let analysis = A::compute(op, ctx, self)?;
self.analyses.insert(key, Box::new(RefCell::new(analysis)));
}
Ok(())
}
pub fn get_analysis_mut<'a, A: Analysis + 'static>(
&'a mut self,
op: Ptr<Operation>,
ctx: &Context,
) -> Result<RefMut<'a, A>> {
self.compute_analysis::<A>(op, ctx)?;
let key = (core::any::TypeId::of::<A>(), op);
let analysis = self.analyses.get(&key).unwrap();
Ok(RefMut::map(analysis.borrow_mut(), |a| {
a.downcast_mut::<A>().unwrap()
}))
}
pub fn get_analysis<'a, A: Analysis + 'static>(
&'a mut self,
op: Ptr<Operation>,
ctx: &Context,
) -> Result<Ref<'a, A>> {
self.compute_analysis::<A>(op, ctx)?;
let key = (core::any::TypeId::of::<A>(), op);
let analysis = self.analyses.get(&key).unwrap();
Ok(Ref::map(analysis.borrow(), |a| {
a.downcast_ref::<A>().unwrap()
}))
}
pub fn try_get_analysis<'a, A: Analysis + 'static>(
&'a self,
op: Ptr<Operation>,
) -> Option<Ref<'a, A>> {
let key = (core::any::TypeId::of::<A>(), op);
self.analyses
.get(&key)
.map(|analysis| Ref::map(analysis.borrow(), |a| a.downcast_ref::<A>().unwrap()))
}
pub fn try_get_analysis_mut<'a, A: Analysis + 'static>(
&'a self,
op: Ptr<Operation>,
) -> Option<RefMut<'a, A>> {
let key = (core::any::TypeId::of::<A>(), op);
self.analyses
.get(&key)
.map(|analysis| RefMut::map(analysis.borrow_mut(), |a| a.downcast_mut::<A>().unwrap()))
}
pub fn retain_preserved(&mut self, pass_res: &PassResult) {
if pass_res.ir_changed == IRStatus::Unchanged {
return;
}
self.analyses
.retain(|(type_id, _), _| pass_res.preserved_analyses.contains(type_id));
}
fn list_analyses(&self) -> FxHashSet<core::any::TypeId> {
self.analyses.keys().map(|(type_id, _)| *type_id).collect()
}
}