use super::lang::Lang;
use crate::mutators::peephole::{
EG,
eggsy::{RandomExtractor, encoder::rebuild::build_expr_inner},
};
use egg::{AstSize, Id, Language, RecExpr};
use rand::{Rng, SeedableRng, prelude::SmallRng};
use std::{cell::RefCell, rc::Rc};
/// Lazily enumerates expressions represented by e-class `id`, up to `depth` levels deep.
///
/// Every expression is materialised into the shared, append-only `recexpr` buffer, and
/// the iterator yields the buffer `Id` of each expression's root node. The buffer is
/// never cleared, so nodes of previously yielded expressions remain in it; callers that
/// want a standalone expression must copy it out (see [`lazy_expand_aux`]).
///
/// Enumeration strategy, per e-node `n` of the class:
/// * `depth == 0`: no enumeration — a single expression is produced with
///   `RandomExtractor::extract_smallest` using the `AstSize` cost function.
/// * `n` has no children: `n` itself.
/// * one child: `n` over every expansion of the child.
/// * two children: the cartesian product of both children's expansions, left-major;
///   the right child is expanded anew (with new RNG draws) for every left expansion.
/// * more than two children: only the *first* expansion of each child is used, so such
///   a node contributes exactly one expression.
///
/// Children of `UnfoldI32` / `UnfoldI64` nodes are expanded with depth 0; all other
/// children with `depth - 1`.
///
/// The class's e-nodes are visited in stored order, rotated by a random offset drawn from
/// `rnd` when this function is called. Child expansions are only created (and draw from
/// `rnd`) once iteration reaches their parent node, so the produced sequence depends on
/// the RNG state and on how far the iterator is consumed.
///
/// # Panics
/// Panics if the e-class has no nodes (`random_range` on an empty range), or if
/// `extract_smallest` yields no result at depth 0 (its result is unwrapped).
pub fn lazy_expand<'a>(
    id: Id,
    egraph: Rc<EG>,
    depth: u32,
    rnd: Rc<RefCell<SmallRng>>,
    recexpr: Rc<RefCell<RecExpr<Lang>>>,
) -> Box<dyn Iterator<Item = Id> + 'a> {
    if depth == 0 {
        let extractor = RandomExtractor::new(&egraph, AstSize);
        let smallest = extractor
            .extract_smallest(id, &recexpr, |a, b, c| build_expr_inner(a, b, c))
            .unwrap();
        return Box::new(std::iter::once(smallest));
    }
    let nodes = egraph[id].nodes.clone();
    let count = nodes.len();
    // Rotate the starting node so different RNG states explore the class in different orders.
    let split_at = rnd.borrow_mut().random_range(0..count);
    let order = (split_at..count).chain(0..split_at);
    let expansions = order
        .map(move |i| nodes[i].clone())
        .flat_map(move |mut node| {
            // Operands of unfold nodes are extracted (depth 0) rather than enumerated.
            let child_depth = if matches!(node, Lang::UnfoldI32(_) | Lang::UnfoldI64(_)) {
                0
            } else {
                depth - 1
            };
            let children = node.children_mut();
            let iter: Box<dyn Iterator<Item = Id>> = match children.len() {
                0 => Box::new(std::iter::once(recexpr.borrow_mut().add(node))),
                1 => {
                    let child = children[0];
                    let rec = recexpr.clone();
                    Box::new(
                        lazy_expand(
                            child,
                            egraph.clone(),
                            child_depth,
                            rnd.clone(),
                            recexpr.clone(),
                        )
                        .map(move |child_id| {
                            let mut node = node.clone();
                            node.children_mut()[0] = child_id;
                            rec.borrow_mut().add(node)
                        }),
                    )
                }
                2 => {
                    let (left, right) = (children[0], children[1]);
                    let rec = recexpr.clone();
                    // Owned handles for the per-left-expansion right-hand enumeration.
                    let right_egraph = egraph.clone();
                    let right_rnd = rnd.clone();
                    let right_buf = recexpr.clone();
                    Box::new(
                        lazy_expand(
                            left,
                            egraph.clone(),
                            child_depth,
                            rnd.clone(),
                            recexpr.clone(),
                        )
                        .flat_map(move |left_id| {
                            // Cartesian product, left-major: re-expand the right child
                            // for every left expansion.
                            std::iter::repeat(left_id).zip(lazy_expand(
                                right,
                                right_egraph.clone(),
                                child_depth,
                                right_rnd.clone(),
                                right_buf.clone(),
                            ))
                        })
                        .map(move |(left_id, right_id)| {
                            let mut node = node.clone();
                            let slots = node.children_mut();
                            slots[0] = left_id;
                            slots[1] = right_id;
                            rec.borrow_mut().add(node)
                        }),
                    )
                }
                _ => {
                    // Wide nodes: only the first expansion of each child is taken, so this
                    // node contributes a single expression.
                    for child in children.iter_mut() {
                        *child = lazy_expand(
                            *child,
                            egraph.clone(),
                            child_depth,
                            rnd.clone(),
                            recexpr.clone(),
                        )
                        .next()
                        .expect("lazy_expand yields at least one expression per e-class");
                    }
                    Box::new(std::iter::once(recexpr.borrow_mut().add(node)))
                }
            };
            iter
        });
    Box::new(expansions)
}
/// Convenience wrapper around [`lazy_expand`] that owns all of the shared state.
///
/// Seeds a fresh `SmallRng` from `seed`, wraps `egraph` in an `Rc`, allocates an empty
/// scratch `RecExpr` buffer, and maps every root `Id` produced by `lazy_expand` to a
/// standalone `RecExpr<Lang>` rooted at that `Id`.
///
/// Each yielded expression is a copy of the buffer prefix `..=root`. Because the buffer is
/// append-only and shared by all yields, that prefix may also contain nodes left over from
/// earlier expansions that the root does not reference.
pub fn lazy_expand_aux<'a>(
    id: Id,
    egraph: EG,
    depth: u32,
    seed: u64,
) -> Box<dyn Iterator<Item = RecExpr<Lang>> + 'a> {
    let buffer = Rc::new(RefCell::new(RecExpr::default()));
    let rnd = Rc::new(RefCell::new(SmallRng::seed_from_u64(seed)));
    let egraph = Rc::new(egraph);
    let it = lazy_expand(id, egraph, depth, rnd, Rc::clone(&buffer)).map(move |root| {
        let guard = buffer.borrow();
        let nodes: &[Lang] = guard.as_ref();
        // A RecExpr's root is its last node: cut the copy at the produced root instead of
        // assuming the root was the most recently appended node. Children always precede
        // their parent in the buffer, so the prefix is a well-formed RecExpr.
        RecExpr::from(nodes[..=usize::from(root)].to_vec())
    });
    Box::new(it)
}
#[cfg(test)]
mod tests {
    use crate::ModuleInfo;
    use crate::mutators::peephole::eggsy::{
        analysis::PeepholeMutationAnalysis, expr_enumerator::lazy_expand, lang::Lang,
    };
    use egg::{RecExpr, Rewrite, Runner, rewrite};
    use rand::{SeedableRng, prelude::SmallRng};
    use std::{cell::RefCell, collections::HashSet, rc::Rc};

    /// Small tagged value used to observe the order in which iterator items are produced.
    #[derive(Clone, Copy, Debug)]
    struct Mimi {
        id: usize,
    }

    impl Mimi {
        pub fn used(&self) {
            println!("mimi {}", self.id)
        }
    }

    /// Pins the left-major cartesian-product pattern (`flat_map` + `repeat().zip()`) that
    /// `lazy_expand` uses for two-child nodes.
    #[test]
    fn test_lazy() {
        let left = [Mimi { id: 1 }, Mimi { id: 2 }].into_iter();
        let right = [Mimi { id: 2 }, Mimi { id: 3 }].into_iter();
        let pairs: Vec<(usize, usize)> = left
            .flat_map(|l| std::iter::repeat(l).zip(right.clone()))
            .map(|(l, r)| {
                l.used();
                r.used();
                (l.id, r.id)
            })
            .collect();
        assert_eq!(pairs, vec![(1, 2), (1, 3), (2, 2), (2, 3)]);
    }

    /// Every expression `lazy_expand` produces for a small rewritten e-graph is distinct,
    /// and at least one expression is produced.
    #[test]
    fn test_rec_iterator() {
        // Rule names must be unique: egg reports and schedules rules by name.
        let rules: &[Rewrite<Lang, PeepholeMutationAnalysis>] = &[
            rewrite!("rule"; "?x" => "(i32.add ?x i32.const.0)"),
            rewrite!("rule2"; "(i32.add ?y ?x)" => "(i32.add ?x ?y)"),
            rewrite!("rule3"; "i32.const.0" => "(i32.sub i32.const.50 i32.const.500)"),
        ];
        let expr = "(i32.add i32.const.100 i32.const.200)";
        let empty_wasm = wat::parse_str("(module)").unwrap();
        let info = ModuleInfo::new(&empty_wasm).unwrap();
        let analysis = PeepholeMutationAnalysis::new(&info, vec![]);
        let runner = Runner::<Lang, PeepholeMutationAnalysis, ()>::new(analysis)
            .with_iter_limit(1)
            .with_expr(&expr.parse().unwrap())
            .run(rules);
        let mut egraph = runner.egraph;
        let root = egraph.add_expr(&expr.parse().unwrap());
        let rnd = Rc::new(RefCell::new(SmallRng::seed_from_u64(0)));
        let recexpr = Rc::new(RefCell::new(RecExpr::default()));
        let it = lazy_expand(root, Rc::new(egraph), 10, rnd, recexpr.clone());
        let mut seen = HashSet::new();
        for _root in it.take(100_000) {
            let printed = recexpr.borrow().to_string();
            assert!(seen.insert(printed.clone()), "duplicate expression: {printed}");
        }
        assert!(!seen.is_empty(), "lazy_expand produced no expressions");
    }
}