use crate::ir::spaces;
use crate::ir::Op;
use crate::shared::Shared;
use crate::shared::SharedExt;
use anyhow::Result;
use rayon::prelude::*;
use std::sync::Arc;
use tracing::debug;
mod cf_to_llvm;
mod experimental_to_mlir;
mod func_to_llvm;
mod mlir_to_llvmir;
mod mlir_to_wat;
mod scf_to_cf;
pub use cf_to_llvm::ConvertCFToLLVM;
pub use experimental_to_mlir::ConvertExperimentalToMLIR;
pub use func_to_llvm::ConvertFuncToLLVM;
pub use mlir_to_llvmir::ConvertMLIRToLLVMIR;
pub use mlir_to_wat::ConvertMLIRToWat;
pub use scf_to_cf::ConvertSCFToCF;
/// An op reported by a rewrite as having been produced or modified.
///
/// Equality is identity-based: two `ChangedOp`s are equal only when they wrap
/// the very same shared allocation (see the `PartialEq` impl below); the op's
/// contents are never compared.
pub struct ChangedOp {
    /// The op the rewrite produced, or the op under which a change happened.
    pub op: Shared<dyn Op>,
}

impl ChangedOp {
    /// Wrap `op` as the subject of a successful rewrite.
    pub fn new(op: Shared<dyn Op>) -> Self {
        Self { op }
    }
}

impl PartialEq for ChangedOp {
    /// Pointer identity of the wrapped shared ops, via `Arc::ptr_eq`.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.op, &other.op);
        Arc::ptr_eq(lhs, rhs)
    }
}
/// Outcome of applying a [`Rewrite`] to an op.
#[derive(PartialEq)]
pub enum RewriteResult {
    /// Something was rewritten; carries the op the caller should continue from.
    Changed(ChangedOp),
    /// Nothing was rewritten.
    Unchanged,
}

impl RewriteResult {
    /// Borrow the changed op if this is [`RewriteResult::Changed`], else `None`.
    ///
    /// Despite the name this returns an `Option`, not a `bool`; callers test it
    /// with `.is_some()`.
    pub fn is_changed(&self) -> Option<&ChangedOp> {
        if let Self::Changed(changed) = self {
            Some(changed)
        } else {
            None
        }
    }
}
/// A single IR transformation that the rewrite driver tries on each op.
///
/// Implementors must be `Send + Sync` because the driver may visit sibling ops
/// from multiple rayon worker threads.
pub trait Rewrite: Send + Sync {
    /// Attempt the transformation on `op`.
    ///
    /// Return `Unchanged` when it does not apply, `Changed` with the resulting
    /// op when it does, or an error to abort the whole rewrite run.
    fn rewrite(&self, op: Shared<dyn Op>) -> Result<RewriteResult>;

    /// Whether the direct children of an op may be visited in parallel for
    /// this rewrite.
    fn parallelizable(&self) -> bool;

    /// Static name of this rewrite, used in debug logging.
    fn name(&self) -> &'static str;
}
/// Run [`apply_rewrite`] on one child (`nested_op`) of `root`, one debug-log
/// indentation level deeper.
///
/// If the child or anything beneath it was rewritten, the change is reported as
/// `Changed(root)`, i.e. attributed to the parent rather than to the nested
/// replacement, so callers further up see a change on the op they passed in.
/// An unchanged child yields `Unchanged`.
///
/// # Errors
/// Any error from the nested [`apply_rewrite`] call is propagated unchanged.
fn apply_rewrite_helper(
    root: Shared<dyn Op>,
    rewrite: &dyn Rewrite,
    parallel: bool,
    nested_op: &Shared<dyn Op>,
    indent: i32,
) -> Result<RewriteResult> {
    let nested = apply_rewrite(nested_op.clone(), rewrite, parallel, indent + 1)?;
    Ok(match nested {
        // Hand back the parent, not the nested result, so the caller's view stays rooted at `root`.
        RewriteResult::Changed(_) => RewriteResult::Changed(ChangedOp::new(root)),
        RewriteResult::Unchanged => RewriteResult::Unchanged,
    })
}
/// Try `rewrite` on `root`, and, if `root` itself is left unchanged, on its
/// nested ops in order.
///
/// At most one change is reported per call:
/// - `root`'s own rewrite result when that rewrite changed something; or
/// - `Changed(root)` for the first child (in child order) whose subtree was
///   rewritten; or
/// - `Unchanged` when nothing applied.
///
/// When `parallel` is true the direct children of `root` are searched with
/// rayon. Deeper levels are always searched sequentially. `indent` only
/// controls the indentation of debug log lines.
///
/// # Errors
/// Returns an error from `root`'s rewrite, or the first child result (in child
/// order) that is an error before any change is found.
///
/// NOTE(review): in parallel mode rayon's `find_first` may still evaluate
/// children after the first hit. Those children can be rewritten (or fail)
/// without their result being reported. This is presumably acceptable because
/// `apply_rewrites` loops until nothing changes; confirm.
fn apply_rewrite(
    root: Shared<dyn Op>,
    rewrite: &dyn Rewrite,
    parallel: bool,
    indent: i32,
) -> Result<RewriteResult> {
    debug!(
        "{}Matching {} with {}",
        spaces(indent),
        root.rd().name(),
        rewrite.name()
    );
    let root_rewrite = rewrite.rewrite(root.clone())?;
    if root_rewrite.is_changed().is_some() {
        debug!("{}----> Changed", spaces(indent));
        return Ok(root_rewrite);
    }
    // Stop at the first child that either changed or failed.
    fn finder(result: &Result<RewriteResult>) -> bool {
        !matches!(result, Ok(RewriteResult::Unchanged))
    }
    let ops = root.rd().ops();
    // Only this level may run in parallel; recursion below is always sequential.
    let nested_parallel = false;
    let first_changed = if parallel {
        ops.par_iter()
            .map(|nested_op| {
                apply_rewrite_helper(root.clone(), rewrite, nested_parallel, nested_op, indent)
            })
            .find_first(finder)
    } else {
        ops.iter()
            .map(|nested_op| {
                apply_rewrite_helper(root.clone(), rewrite, nested_parallel, nested_op, indent)
            })
            .find(finder)
    };
    // A hit is already the result we want to return (a change or an error).
    first_changed.unwrap_or(Ok(RewriteResult::Unchanged))
}
/// Try each rewrite in `rewrites` against `root`, in order, and return the
/// first change.
///
/// Each rewrite decides via [`Rewrite::parallelizable`] whether `root`'s direct
/// children are visited in parallel. Returns `Unchanged` if no rewrite
/// applies.
///
/// # Errors
/// The first error from any rewrite aborts the search and is returned.
fn apply_rewrites_helper(
    root: Shared<dyn Op>,
    rewrites: &[&dyn Rewrite],
    indent: i32,
) -> Result<RewriteResult> {
    for &candidate in rewrites {
        let use_parallel = candidate.parallelizable();
        let outcome = apply_rewrite(root.clone(), candidate, use_parallel, indent)?;
        if let RewriteResult::Changed(_) = outcome {
            return Ok(outcome);
        }
    }
    Ok(RewriteResult::Unchanged)
}
/// Repeatedly apply `rewrites` to `root` until a full sweep changes nothing.
///
/// Each sweep stops at the first rewrite that changes something (see
/// `apply_rewrites_helper`). The next sweep starts from the op that sweep
/// reported.
///
/// Returns `Changed` with the final op if any sweep changed something, and
/// `Unchanged` otherwise. After 10240 sweeps that all changed something, it
/// logs a warning and returns `Changed` with the current op.
///
/// # Errors
/// The first error from any rewrite aborts the loop and is returned.
pub fn apply_rewrites(root: Shared<dyn Op>, rewrites: &[&dyn Rewrite]) -> Result<RewriteResult> {
    // Safety valve against rewrite sets that never reach a fixpoint.
    const MAX_ITERATIONS: usize = 10_240;

    let mut current = root;
    let mut any_change = false;
    for _ in 0..MAX_ITERATIONS {
        match apply_rewrites_helper(current.clone(), rewrites, 0)? {
            RewriteResult::Changed(changed) => {
                any_change = true;
                current = changed.op;
            }
            RewriteResult::Unchanged if any_change => {
                return Ok(RewriteResult::Changed(ChangedOp::new(current)));
            }
            RewriteResult::Unchanged => return Ok(RewriteResult::Unchanged),
        }
    }
    tracing::warn!("Too many rewrite iterations");
    Ok(RewriteResult::Changed(ChangedOp::new(current)))
}
/// A named conversion that can be run on an op.
pub trait Pass {
    /// Run the conversion on `op`, reporting through [`RewriteResult`] whether
    /// anything changed.
    fn convert(op: Shared<dyn Op>) -> Result<RewriteResult>;

    /// Name of this pass.
    const NAME: &'static str;
}
/// Replace an op of concrete type `A` with an op of type `B` built from the
/// same operation data.
///
/// If `op` is not an `A`, returns `Unchanged` without touching anything.
///
/// Otherwise:
/// - a `B` is built from a clone of `A`'s `operation()` (presumably an `Arc`
///   handle, so both share it);
/// - `A` is swapped out via `replace`;
/// - the new op is returned as `Changed`.
///
/// NOTE(review): the read guard on `op` is held while `replace` runs. This
/// assumes `replace` never needs to lock `op` itself; confirm.
pub fn simple_op_rewrite<A: Op + 'static, B: Op + 'static>(
    op: Shared<dyn Op>,
) -> Result<RewriteResult> {
    let guard = op.rd();
    if let Some(matched) = guard.as_any().downcast_ref::<A>() {
        let operation = matched.operation().clone();
        let replacement = Shared::new(B::from_operation_arc(operation).into());
        matched.replace(replacement.clone());
        Ok(RewriteResult::Changed(ChangedOp::new(replacement)))
    } else {
        Ok(RewriteResult::Unchanged)
    }
}