use crate::ir::Op;
use anyhow::Result;
use std::sync::Arc;
use std::sync::RwLock;
mod func_to_llvm;
mod mlir_to_llvmir;
pub use func_to_llvm::ConvertFuncToLLVM;
pub use mlir_to_llvmir::ConvertMLIRToLLVMIR;
/// The op a rewrite step hands back to its caller after something changed.
///
/// Carried inside [`RewriteResult::Changed`]; it is either the op a [`Rewrite`] produced,
/// or (when a nested op was rewritten) the unchanged-identity root passed through by the
/// driver. Equality is identity: two `ChangedOp`s are equal only when they point at the
/// same `Arc` allocation, regardless of the op's contents.
pub struct ChangedOp(pub Arc<RwLock<dyn Op>>);

impl ChangedOp {
    /// Wraps `op` without copying it; the result shares ownership of the op.
    pub fn new(op: Arc<RwLock<dyn Op>>) -> Self {
        ChangedOp(op)
    }
}

impl PartialEq for ChangedOp {
    /// Returns `true` when both handles refer to the same allocation.
    fn eq(&self, other: &Self) -> bool {
        // Compare thin data addresses only. `Arc::ptr_eq` on a `dyn` pointer also compared
        // vtable pointers before Rust 1.72, and one concrete type can have several vtable
        // copies (e.g. across codegen units), which could make the same op compare unequal.
        std::ptr::eq(
            Arc::as_ptr(&self.0) as *const (),
            Arc::as_ptr(&other.0) as *const (),
        )
    }
}
/// Outcome of a rewrite attempt on an op or an op tree.
#[derive(PartialEq)]
pub enum RewriteResult {
    /// Something was rewritten; carries the op the caller should continue from.
    Changed(ChangedOp),
    /// No rewrite applied.
    Unchanged,
}

impl RewriteResult {
    /// Returns `true` for [`RewriteResult::Changed`] and `false` for
    /// [`RewriteResult::Unchanged`].
    pub fn is_changed(&self) -> bool {
        match self {
            Self::Changed(_) => true,
            Self::Unchanged => false,
        }
    }
}
/// A single rewrite pattern that the rewrite driver ([`apply_rewrites`]) tries on ops.
pub trait Rewrite {
    /// Reports whether this rewrite applies to `op`.
    ///
    /// The driver only calls [`Rewrite::rewrite`] on `op` after this returns `Ok(true)`.
    fn is_match(&self, op: Arc<RwLock<dyn Op>>) -> Result<bool>;

    /// Rewrites `op`, returning [`RewriteResult::Unchanged`] when nothing was modified.
    ///
    /// When `op` is the current root of a driver round, the op carried by
    /// [`RewriteResult::Changed`] becomes the root for the next round. For nested ops the
    /// driver discards that op and only records that something changed, so changes to
    /// nested ops presumably have to be made in place through the lock — confirm against
    /// the implementations.
    fn rewrite(&self, op: Arc<RwLock<dyn Op>>) -> Result<RewriteResult>;
}
fn apply_rewrites_helper(
root: Arc<RwLock<dyn Op>>,
rewrites: &[&dyn Rewrite],
) -> Result<RewriteResult> {
for rewrite in rewrites {
let ops = root.try_read().unwrap().ops();
for nested_op in ops.iter() {
let result = apply_rewrites_helper(nested_op.clone(), rewrites)?;
if result.is_changed() {
let root_passthrough = ChangedOp::new(root.clone());
let root_passthrough = RewriteResult::Changed(root_passthrough);
return Ok(root_passthrough);
}
}
if rewrite.is_match(root.clone())? {
let root_rewrite = rewrite.rewrite(root.clone())?;
if root_rewrite.is_changed() {
return Ok(root_rewrite);
}
}
}
Ok(RewriteResult::Unchanged)
}
pub fn apply_rewrites(
root: Arc<RwLock<dyn Op>>,
rewrites: &[&dyn Rewrite],
) -> Result<RewriteResult> {
let max_iterations = 16;
let mut root = root;
let mut has_changed = false;
for _ in 0..max_iterations {
let result = apply_rewrites_helper(root.clone(), rewrites)?;
match result {
RewriteResult::Changed(changed) => {
has_changed = true;
root = changed.0;
}
RewriteResult::Unchanged => {
if has_changed {
let op = ChangedOp::new(root);
return Ok(RewriteResult::Changed(op));
} else {
return Ok(result);
}
}
}
}
anyhow::bail!("too many rewrite iterations");
}
/// A named conversion pass over an op tree.
///
/// Both items are associated functions (no `self`), so a pass is chosen by its type
/// rather than passed around as a value.
pub trait Pass {
    /// Returns the pass's name.
    fn name() -> &'static str;

    /// Runs the pass on `op`, reporting whether anything changed.
    ///
    /// NOTE(review): implementations presumably delegate to [`apply_rewrites`]; confirm in
    /// the concrete passes.
    fn convert(op: Arc<RwLock<dyn Op>>) -> Result<RewriteResult>;
}