use midenc_session::Session;
use super::{AnalysisKey, AnalysisManager, PassInfo};
use crate::diagnostics::Report;
/// The outcome of running a rewrite: `Ok(())` on success, or a diagnostic [`Report`] on failure.
pub type RewriteResult = Result<(), Report>;
/// The signature of a plain function/closure that can act as a rewrite over entities of type `T`.
///
/// Both `dyn FnMut(..)` of this shape and `Box` of it implement [`RewritePass`] in this module.
pub type RewriteFn<T> = dyn FnMut(&mut T, &mut AnalysisManager, &Session) -> RewriteResult;
/// A rewrite pass that also carries [`PassInfo`] metadata (flag, summary, description).
///
/// There is nothing to implement by hand: every type implementing both [`PassInfo`] and
/// [`RewritePass`] gets this trait automatically via the blanket impl below.
pub trait RewritePassInfo: PassInfo + RewritePass {}

impl<P: PassInfo + RewritePass> RewritePassInfo for P {}
/// A pass that rewrites an IR entity of type [`RewritePass::Entity`] in place.
///
/// Drivers such as [`RewriteSet`] and [`ModuleRewritePassAdapter`] first ask
/// [`should_apply`](RewritePass::should_apply) and only call [`apply`](RewritePass::apply) when
/// it answers `true`.
pub trait RewritePass {
    /// The kind of entity this pass rewrites.
    ///
    /// It must be an [`AnalysisKey`] so that analyses cached for a rewritten entity can be
    /// addressed (and invalidated) by key.
    type Entity: AnalysisKey;

    /// Decide whether this pass should run on `entity`.
    ///
    /// The default accepts every entity; override it to filter cheaply before `apply`.
    fn should_apply(&self, _entity: &Self::Entity, _session: &Session) -> bool {
        true
    }

    /// Rewrite `entity` in place.
    ///
    /// # Errors
    ///
    /// Returns a [`Report`] describing why the rewrite failed.
    fn apply(
        &mut self,
        entity: &mut Self::Entity,
        analyses: &mut AnalysisManager,
        session: &Session,
    ) -> RewriteResult;

    /// Combine this pass with `next` into a [`RewriteSet`] that runs `self` first, then `next`.
    fn chain<R: RewritePass<Entity = Self::Entity> + 'static>(
        self,
        next: R,
    ) -> RewriteSet<Self::Entity>
    where
        Self: Sized + 'static,
    {
        RewriteSet::pair(self, next)
    }
}
/// A boxed, sized pass behaves exactly like the pass it contains.
impl<P, T> RewritePass for Box<P>
where
    T: AnalysisKey,
    P: RewritePass<Entity = T>,
{
    type Entity = T;

    fn should_apply(&self, entity: &Self::Entity, session: &Session) -> bool {
        // `&Box<P>` deref-coerces to `&P` here.
        <P as RewritePass>::should_apply(self, entity, session)
    }

    fn apply(
        &mut self,
        entity: &mut Self::Entity,
        analyses: &mut AnalysisManager,
        session: &Session,
    ) -> RewriteResult {
        // `&mut Box<P>` deref-coerces to `&mut P` here.
        <P as RewritePass>::apply(self, entity, analyses, session)
    }

    /// Builds a set that holds the inner pass directly (no extra layer of boxing), then `next`.
    fn chain<R>(self, next: R) -> RewriteSet<Self::Entity>
    where
        Self: Sized + 'static,
        R: RewritePass<Entity = Self::Entity> + 'static,
    {
        let mut set: RewriteSet<T> = self.into();
        set.push(next);
        set
    }
}
impl<T> RewritePass for Box<dyn RewritePass<Entity = T>>
where
T: AnalysisKey,
{
type Entity = T;
#[inline]
fn apply(
&mut self,
entity: &mut Self::Entity,
analyses: &mut AnalysisManager,
session: &Session,
) -> RewriteResult {
(**self).apply(entity, analyses, session)
}
}
/// A boxed closure of the [`RewriteFn`] shape can be used directly as a rewrite pass.
///
/// `should_apply` keeps the trait default, so the closure is invoked for every entity.
impl<T> RewritePass for Box<dyn FnMut(&mut T, &mut AnalysisManager, &Session) -> RewriteResult>
where
    T: AnalysisKey,
{
    type Entity = T;

    #[inline]
    fn apply(
        &mut self,
        entity: &mut Self::Entity,
        analyses: &mut AnalysisManager,
        session: &Session,
    ) -> RewriteResult {
        // Call the boxed closure through a mutable place expression.
        (**self)(entity, analyses, session)
    }
}
/// An unboxed `dyn FnMut` of the [`RewriteFn`] shape can be used as a rewrite pass,
/// e.g. through a `&mut` reference.
///
/// `should_apply` keeps the trait default, so the closure is invoked for every entity.
impl<T> RewritePass for dyn FnMut(&mut T, &mut AnalysisManager, &Session) -> RewriteResult
where
    T: AnalysisKey,
{
    type Entity = T;

    #[inline]
    fn apply(
        &mut self,
        entity: &mut Self::Entity,
        analyses: &mut AnalysisManager,
        session: &Session,
    ) -> RewriteResult {
        (*self)(entity, analyses, session)
    }
}
/// Lifts a rewrite pass over [`crate::Function`] into a rewrite pass over [`crate::Module`].
///
/// When applied to a module, the wrapped pass is run on the module's functions for which its
/// `should_apply` returns `true` (see this type's `RewritePass` impl).
pub struct ModuleRewritePassAdapter<R>(R);
impl<R> Default for ModuleRewritePassAdapter<R>
where
R: RewritePass<Entity = crate::Function> + Default,
{
fn default() -> Self {
Self(R::default())
}
}
impl<R> ModuleRewritePassAdapter<R>
where
R: RewritePass<Entity = crate::Function>,
{
pub const fn new(pass: R) -> Self {
Self(pass)
}
}
/// The adapter reports the same flag, summary and description as the pass it wraps.
impl<R> PassInfo for ModuleRewritePassAdapter<R>
where
    R: PassInfo,
{
    const FLAG: &'static str = R::FLAG;
    const SUMMARY: &'static str = R::SUMMARY;
    const DESCRIPTION: &'static str = R::DESCRIPTION;
}
/// Runs the wrapped function-level pass over the functions of a module.
impl<R> RewritePass for ModuleRewritePassAdapter<R>
where
    R: RewritePass<Entity = crate::Function>,
{
    type Entity = crate::Module;

    /// Apply the wrapped pass to each function in `module` for which its `should_apply` holds.
    ///
    /// Each function is taken out of the module with the cursor's `remove`, so it can be
    /// mutated on its own. It is then put back with `insert_before`, which (assuming the usual
    /// `remove`/`insert_before` cursor semantics) restores its original position.
    ///
    /// Function analyses are invalidated for every function the pass ran on. If the pass did
    /// not run on any function, all module analyses are marked as preserved.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by the wrapped pass. The failing function is reinserted
    /// into the module before the error is returned, so the module is never left missing it.
    fn apply(
        &mut self,
        module: &mut Self::Entity,
        analyses: &mut AnalysisManager,
        session: &Session,
    ) -> RewriteResult {
        let mut cursor = module.cursor_mut();
        let mut dirty = false;
        while let Some(mut function) = cursor.remove() {
            let outcome = if self.0.should_apply(&function, session) {
                dirty = true;
                let outcome = self.0.apply(&mut function, analyses, session);
                // The function may have been (partially) rewritten even if the pass failed,
                // so its cached analyses are stale either way.
                analyses.invalidate::<crate::Function>(&function.id);
                outcome
            } else {
                Ok(())
            };
            // Always put the function back before propagating any error; otherwise it would
            // be dropped here and silently disappear from the module.
            cursor.insert_before(function);
            outcome?;
        }
        if !dirty {
            analyses.mark_all_preserved::<crate::Module>(&module.name);
        }
        Ok(())
    }
}
/// An ordered collection of type-erased rewrite passes over entities of type `T`.
///
/// When the set itself is applied (via its `RewritePass` impl), its members run in the order
/// they were added.
pub struct RewriteSet<T> {
    /// The member passes, in application order.
    rewrites: Vec<Box<dyn RewritePass<Entity = T>>>,
}
/// An empty set; applying it is a no-op.
impl<T> Default for RewriteSet<T> {
    fn default() -> Self {
        let rewrites = Vec::new();
        Self { rewrites }
    }
}
impl<T> RewriteSet<T>
where
    T: AnalysisKey,
{
    /// Create a set that applies `a` and then `b`.
    pub fn pair<A, B>(a: A, b: B) -> Self
    where
        A: RewritePass<Entity = T> + 'static,
        B: RewritePass<Entity = T> + 'static,
    {
        let first: Box<dyn RewritePass<Entity = T>> = Box::new(a);
        let second: Box<dyn RewritePass<Entity = T>> = Box::new(b);
        Self {
            rewrites: vec![first, second],
        }
    }

    /// Add `rewrite` to the end of the set, boxing it as a trait object.
    pub fn push<R>(&mut self, rewrite: R)
    where
        R: RewritePass<Entity = T> + 'static,
    {
        let boxed: Box<dyn RewritePass<Entity = T>> = Box::new(rewrite);
        self.rewrites.push(boxed);
    }

    /// Move every pass out of `other` onto the end of this set, leaving `other` empty.
    pub fn append(&mut self, other: &mut Self) {
        self.rewrites.extend(other.rewrites.drain(..));
    }

    /// Add already-boxed passes to the end of the set, in iteration order.
    pub fn extend(&mut self, iter: impl IntoIterator<Item = Box<dyn RewritePass<Entity = T>>>) {
        Extend::extend(&mut self.rewrites, iter);
    }
}
/// Consumes the set, yielding its boxed passes in application order.
impl<T> IntoIterator for RewriteSet<T>
where
    T: AnalysisKey,
{
    type Item = Box<dyn RewritePass<Entity = T>>;
    type IntoIter = alloc::vec::IntoIter<Box<dyn RewritePass<Entity = T>>>;

    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        let rewrites = self.rewrites;
        rewrites.into_iter()
    }
}
/// A set containing just one already type-erased pass.
impl<T> From<Box<dyn RewritePass<Entity = T>>> for RewriteSet<T>
where
    T: AnalysisKey,
{
    fn from(rewrite: Box<dyn RewritePass<Entity = T>>) -> Self {
        let rewrites = vec![rewrite];
        Self { rewrites }
    }
}
/// A set containing just one boxed pass; the box is reused as-is, only its type is erased.
impl<T, R> From<Box<R>> for RewriteSet<T>
where
    T: AnalysisKey,
    R: RewritePass<Entity = T> + 'static,
{
    fn from(rewrite: Box<R>) -> Self {
        // Unsizing coercion: Box<R> -> Box<dyn RewritePass<Entity = T>>.
        let erased: Box<dyn RewritePass<Entity = T>> = rewrite;
        Self {
            rewrites: vec![erased],
        }
    }
}
/// A set of rewrites is itself a rewrite that runs its members in order.
impl<T> RewritePass for RewriteSet<T>
where
    T: AnalysisKey,
{
    type Entity = T;

    /// Apply each member pass to `entity`, in insertion order.
    ///
    /// A member is skipped when its own `should_apply` returns `false`. After every member
    /// that runs, the entity's analyses are invalidated through
    /// `AnalysisManager::invalidate`, so the next member never sees analyses computed before
    /// the previous rewrite.
    ///
    /// # Errors
    ///
    /// Stops at the first member that fails and returns its error. Later members are not run.
    fn apply(
        &mut self,
        entity: &mut Self::Entity,
        analyses: &mut AnalysisManager,
        session: &Session,
    ) -> RewriteResult {
        for pass in self.rewrites.iter_mut() {
            // Dispatch through the trait object itself (`**pass`). A plain `pass.should_apply`
            // resolves to the `Box<dyn RewritePass>` impl, whose `should_apply` may be the
            // trait default (`true`) and so would ignore the member's own filter.
            if !(**pass).should_apply(entity, session) {
                continue;
            }
            (**pass).apply(entity, analyses, session)?;
            analyses.invalidate::<T>(&entity.key());
        }
        Ok(())
    }

    /// Appends `next` to this set instead of nesting the set inside a new one.
    fn chain<R>(mut self, next: R) -> RewriteSet<Self::Entity>
    where
        Self: Sized + 'static,
        R: RewritePass<Entity = Self::Entity> + 'static,
    {
        self.push(next);
        self
    }
}
/// A registration record for a rewrite pass over entities of type `T`.
///
/// It holds the pass's metadata and a constructor that produces a fresh boxed instance.
/// Registrations for `crate::Module` and `crate::Function` are declared as `inventory`
/// collections at the end of this file.
#[doc(hidden)]
pub struct RewritePassRegistration<T> {
    /// The pass's flag (`PassInfo::FLAG` when built with `new`).
    pub name: &'static str,
    /// The pass's short summary (`PassInfo::SUMMARY` when built with `new`).
    pub summary: &'static str,
    /// The pass's long description (`PassInfo::DESCRIPTION` when built with `new`).
    pub description: &'static str,
    /// Builds a new, type-erased instance of the registered pass.
    ctor: fn() -> Box<dyn RewritePass<Entity = T>>,
}
impl<T> RewritePassRegistration<T> {
    /// Create a registration for pass `P`.
    ///
    /// The metadata is taken from `P`'s [`PassInfo`] constants, and the constructor builds
    /// `P::default()`.
    pub const fn new<P>() -> Self
    where
        P: RewritePass<Entity = T> + PassInfo + Default + 'static,
    {
        let ctor: fn() -> Box<dyn RewritePass<Entity = T>> = dyn_rewrite_pass_ctor::<P>;
        Self {
            name: P::FLAG,
            summary: P::SUMMARY,
            description: P::DESCRIPTION,
            ctor,
        }
    }

    /// The flag name this pass is registered under.
    #[inline]
    pub const fn name(&self) -> &'static str {
        let Self { name, .. } = self;
        *name
    }

    /// The short, one-line summary of the pass.
    #[inline]
    pub const fn summary(&self) -> &'static str {
        let Self { summary, .. } = self;
        *summary
    }

    /// The long-form description of the pass.
    #[inline]
    pub const fn description(&self) -> &'static str {
        let Self { description, .. } = self;
        *description
    }

    /// Construct a fresh, boxed instance of the registered pass.
    #[inline]
    pub fn get(&self) -> Box<dyn RewritePass<Entity = T>> {
        let construct = self.ctor;
        construct()
    }
}
/// Build a default instance of `P` and erase it into a boxed [`RewritePass`] trait object.
///
/// This is a named function, rather than a closure, so that a `RewritePassRegistration`
/// can store it as a plain `fn` pointer.
fn dyn_rewrite_pass_ctor<P>() -> Box<dyn RewritePass<Entity = <P as RewritePass>::Entity>>
where
    P: RewritePass + Default + 'static,
{
    let pass = P::default();
    Box::new(pass)
}
// Declare `RewritePassRegistration` for modules and for functions as `inventory` collection
// types. Entries are presumably submitted elsewhere via `inventory::submit!` — verify.
inventory::collect!(RewritePassRegistration<crate::Module>);
inventory::collect!(RewritePassRegistration<crate::Function>);