1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637
use std::fmt; use indexmap::IndexMap; use instant::{Duration, Instant}; use log::*; use crate::{EGraph, Id, Language, Metadata, RecExpr, Rewrite, SearchMatches}; /// Data generated by running a [`Runner`] one iteration. /// /// If the `serde-1` feature is enabled, this implements /// [`serde::Serialize`][ser], which is useful if you want to output /// this as a JSON or some other format. /// /// [`Runner`]: trait.Runner.html /// [ser]: https://docs.rs/serde/latest/serde/trait.Serialize.html #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize))] #[non_exhaustive] pub struct Iteration { /// The number of enodes in the egraph at the start of this /// iteration. pub egraph_nodes: usize, /// The number of eclasses in the egraph at the start of this /// iteration. pub egraph_classes: usize, /// A map from rule name to number of times it was _newly_ applied /// in this iteration. pub applied: IndexMap<String, usize>, /// Seconds spent searching in this iteration. pub search_time: f64, /// Seconds spent applying rules in this iteration. pub apply_time: f64, /// Seconds spent [`rebuild`](struct.EGraph.html#method.rebuild)ing /// the egraph in this iteration. pub rebuild_time: f64, // TODO optionally put best cost back in there // pub best_cost: Cost, } /// Data generated by running a [`Runner`] to completion. /// /// If the `serde-1` feature is enabled, this implements /// [`serde::Serialize`][ser], which is useful if you want to output /// this as a JSON or some other format. /// /// [`Runner`]: trait.Runner.html /// [ser]: https://docs.rs/serde/latest/serde/trait.Serialize.html #[derive(Debug, Clone)] #[cfg_attr( feature = "serde-1", derive(serde::Serialize), serde(bound(serialize = " L: Language + std::fmt::Display, E: serde::Serialize ")) )] #[non_exhaustive] pub struct RunReport<L, E> { /// The initial expression added to the egraph. pub initial_expr: RecExpr<L>, /// The eclass id of the initial expression added to the egraph. pub initial_expr_eclass: Id, // pub initial_cost: Cost, /// The data generated by each [`Iteration`](struct.Iteration.html). pub iterations: Vec<Iteration>, // pub final_expr: RecExpr<L>, // pub final_cost: Cost, /// The total time spent running rules pub rules_time: f64, // pub extract_time: f64, /// The reason the [`Runner`](trait.Runner.html) stop iterating. pub stop_reason: E, // metrics // pub ast_size: usize, // pub ast_depth: usize, } /** Faciliates running rewrites over an [`EGraph`]. One use for [`EGraph`]s is as the basis of a rewriting system. Since an egraph never "forgets" state when applying a [`Rewrite`], you can apply many rewrites many times quite efficiently. After the egraph is "full" (the rewrites can no longer find new equalities) or some other condition, the egraph compactly represents many, many equivalent expressions. At this point, the egraph is ready for extraction (see [`Extractor`]) which can pick the represented expression that's best according to some cost function. This technique is called [equality saturation](https://www.cs.cornell.edu/~ross/publications/eqsat/) in general. However, there can be many challenges in implementing this "outer loop" of applying rewrites, mostly revolving around which rules to run and when to stop. Implementing the [`Runner`] trait allows you to customize this outer loop in many ways. Many of [`Runner`]s method have default implementation, and these call the various hooks ([`pre_step`], [`during_step`], [`post_step`]) during their operation. [`SimpleRunner`] is `egg`'s provided [`Runner`] that has reasonable defaults and implements many useful things like saturation checking, an egraph size limits, and rule back off. Consider using [`SimpleRunner`] before implementing your own [`Runner`]. [`EGraph`]: struct.EGraph.html [`Extractor`]: struct.Extractor.html [`SimpleRunner`]: struct.SimpleRunner.html [`Runner`]: trait.Runner.html [`pre_step`]: trait.Runner.html#method.pre_step [`during_step`]: trait.Runner.html#method.during_step [`post_step`]: trait.Runner.html#method.post_step */ pub trait Runner<L, M> where L: Language, M: Metadata<L>, { /// The type of an error that should stop the runner. /// /// This will be recorded in /// [`RunReport`](struct.RunReport.html#structfield.stop_reason). type Error: fmt::Debug; // TODO make it so Runners can add fields to Iteration data /// The pre-iteration hook. If this returns an error, then the /// search will stop. Useful for checking stop conditions or /// updating `Runner` state. /// /// Default implementation simply returns `Ok(())`. fn pre_step(&mut self, _egraph: &mut EGraph<L, M>) -> Result<(), Self::Error> { Ok(()) } /// The post-iteration hook. If this returns an error, then the /// search will stop. Useful for checking stop conditions or /// updating `Runner` state. /// /// Default implementation simply returns `Ok(())`. fn post_step( &mut self, _iteration: &Iteration, _egraph: &mut EGraph<L, M>, ) -> Result<(), Self::Error> { Ok(()) } /// The intra-iteration hook. If this returns an error, then the /// search will stop. Useful for checking stop conditions. /// This is called after search each rule and after applying each rule. /// /// Default implementation simply returns `Ok(())`. fn during_step(&mut self, _egraph: &EGraph<L, M>) -> Result<(), Self::Error> { Ok(()) } /// A hook allowing you to customize rewrite search behavior. /// Useful to implement rule management. /// /// Default implementation just calls /// [`Rewrite::search`](struct.Rewrite.html#method.search). fn search_rewrite( &mut self, egraph: &mut EGraph<L, M>, rewrite: &Rewrite<L, M>, ) -> Vec<SearchMatches> { rewrite.search(egraph) } /// A hook allowing you to customize rewrite application behavior. /// Useful to implement rule management. /// /// Default implementation just calls /// [`Rewrite::apply`](struct.Rewrite.html#method.apply) /// and returns number of new applications. fn apply_rewrite( &mut self, egraph: &mut EGraph<L, M>, rewrite: &Rewrite<L, M>, matches: Vec<SearchMatches>, ) -> usize { rewrite.apply(egraph, &matches).len() } /// Run the rewrites once on the egraph. /// /// It first searches all the rules using the [`search_rewrite`] wrapper. /// Then it applies all the rules using the [`apply_rewrite`] wrapper. /// /// ## Expectations /// /// After searching or applying a rule, this should call /// [`during_step`], returning immediately if it returns an error. /// This should _not_ call [`pre_step`] or [`post_step`], those /// should be called by the [`run`] method. /// /// Default implementation just calls /// [`Rewrite::apply`](struct.Rewrite.html#method.apply) /// and returns number of new applications. /// /// ## Default implementation /// /// The default implementation is probably good enough. /// It conforms to all the above expectations, and it performs the /// necessary bookkeeping to return an [`Iteration`]. /// It additionally performs useful logging at the debug and info /// levels. /// If you're using [`env_logger`](https://docs.rs/env_logger/) /// (which the tests of `egg` do), /// see its documentation on how to see the logs. /// /// [`search_rewrite`]: trait.Runner.html#method.search_rewrite /// [`apply_rewrite`]: trait.Runner.html#method.apply_rewrite /// [`pre_step`]: trait.Runner.html#method.pre_step /// [`during_step`]: trait.Runner.html#method.during_step /// [`post_step`]: trait.Runner.html#method.post_step /// [`run`]: trait.Runner.html#method.run /// [`Iteration`]: struct.Iteration.html fn step( &mut self, egraph: &mut EGraph<L, M>, rules: &[Rewrite<L, M>], ) -> Result<Iteration, Self::Error> { let egraph_nodes = egraph.total_size(); let egraph_classes = egraph.number_of_classes(); trace!("EGraph {:?}", egraph.dump()); let search_time = Instant::now(); let mut matches = Vec::new(); for rule in rules.iter() { let ms = self.search_rewrite(egraph, rule); matches.push(ms); self.during_step(egraph)? } let search_time = search_time.elapsed().as_secs_f64(); info!("Search time: {}", search_time); let apply_time = Instant::now(); let mut applied = IndexMap::new(); for (rw, ms) in rules.iter().zip(matches) { let total_matches: usize = ms.iter().map(|m| m.substs.len()).sum(); if total_matches == 0 { continue; } debug!("Applying {} {} times", rw.name(), total_matches); let actually_matched = self.apply_rewrite(egraph, rw, ms); if actually_matched > 0 { if let Some(count) = applied.get_mut(rw.name()) { *count += 1; } else { applied.insert(rw.name().to_owned(), 1); } debug!("Applied {} {} times", rw.name(), actually_matched); } self.during_step(egraph)? } let apply_time = apply_time.elapsed().as_secs_f64(); info!("Apply time: {}", apply_time); let rebuild_time = Instant::now(); egraph.rebuild(); let rebuild_time = rebuild_time.elapsed().as_secs_f64(); info!("Rebuild time: {}", rebuild_time); info!( "Size: n={}, e={}", egraph.total_size(), egraph.number_of_classes() ); trace!("Running post_step..."); Ok(Iteration { applied, egraph_nodes, egraph_classes, search_time, apply_time, rebuild_time, // best_cost, }) } /// Run the rewrites on the egraph until an error. /// /// This should call [`pre_step`], [`step`], and [`post_step`] in /// a loop, in that order, until one of them returns an error. /// It returns the completed [`Iteration`]s and the error that /// caused it to stop. /// /// The default implementation does these things. /// /// [`pre_step`]: trait.Runner.html#method.pre_step /// [`step`]: trait.Runner.html#method.step /// [`post_step`]: trait.Runner.html#method.post_step /// [`Iteration`]: struct.Iteration.html fn run( &mut self, egraph: &mut EGraph<L, M>, rules: &[Rewrite<L, M>], ) -> (Vec<Iteration>, Self::Error) { let mut iterations = vec![]; let mut fn_loop = || -> Result<(), Self::Error> { loop { trace!("Running pre_step..."); self.pre_step(egraph)?; trace!("Running step..."); iterations.push(self.step(egraph, rules)?); trace!("Running post_step..."); self.post_step(iterations.last().unwrap(), egraph)?; } }; let stop_reason = fn_loop().unwrap_err(); info!("Stopping {:?}", stop_reason); (iterations, stop_reason) } /// Given an initial expression, make and egraph and [`run`] the /// rules on it. /// /// The default implementation does exactly this, also performing /// the bookkeeping needed to return a [`RunReport`]. /// /// [`run`]: trait.Runner.html#method.run /// [`RunReport`]: struct.RunReport.html fn run_expr( &mut self, initial_expr: RecExpr<L>, rules: &[Rewrite<L, M>], ) -> (EGraph<L, M>, RunReport<L, Self::Error>) { // let initial_cost = calculate_cost(&initial_expr); // info!("Without empty: {}", initial_expr.pretty(80)); let (mut egraph, initial_expr_eclass) = EGraph::from_expr(&initial_expr); let rules_time = Instant::now(); let (iterations, stop_reason) = self.run(&mut egraph, rules); let rules_time = rules_time.elapsed().as_secs_f64(); // let extract_time = Instant::now(); // let best = Extractor::new(&egraph).find_best(root); // let extract_time = extract_time.elapsed().as_secs_f64(); // info!("Extract time: {}", extract_time); // info!("Initial cost: {}", initial_cost); // info!("Final cost: {}", best.cost); // info!("Final: {}", best.expr.pretty(80)); let report = RunReport { iterations, rules_time, // extract_time, stop_reason, // ast_size: best.expr.ast_size(), // ast_depth: best.expr.ast_depth(), initial_expr, initial_expr_eclass: egraph.find(initial_expr_eclass), // initial_cost, // final_cost: best.cost, // final_expr: best.expr, }; (egraph, report) } } /** A reasonable default [`Runner`]. [`SimpleRunner`] is a [`Runner`], so it runs rewrites over an [`EGraph`]. This implementation offers several conveniences to prevent rewriting from behaving badly and eating your computer: - Saturation checking [`SimpleRunner`] checks to see if any of the rules added anything new to the [`EGraph`]. If none did, then it stops, returning [`SimpleRunnerError::Saturated`](enum.SimpleRunnerError.html#variant.Saturated). - Iteration limits You can set a upper limit of iterations to do in case the search doesn't stop for some other reason. If this limit is hit, it stops with [`SimpleRunnerError::IterationLimit`](enum.SimpleRunnerError.html#variant.IterationLimit). - [`EGraph`] size limit You can set a upper limit on the number of enodes in the egraph. If this limit is hit, it stops with [`SimpleRunnerError::NodeLimit`](enum.SimpleRunnerError.html#variant.NodeLimit). - Time limit You can set a time limit on the runner. If this limit is hit, it stops with [`SimpleRunnerError::TimeLimit`](enum.SimpleRunnerError.html#variant.TimeLimit). - Rule backoff Some rules enable themselves, blowing up the [`EGraph`] and preventing other rewrites from running as many times. To prevent this, [`SimpleRunner`] implements exponentional rule backoff. For each rewrite, there exists a configurable initial match limit. If a rewrite search yield more than this limit, then we ban this rule for number of iterations, double its limit, and double the time it will be banned next time. This seems effective at preventing explosive rules like associativity from taking an unfair amount of resources. [`SimpleRunner`]: struct.SimpleRunner.html [`Runner`]: trait.Runner.html [`EGraph`]: struct.EGraph.html # Example ``` use egg::{*, rewrite as rw}; define_language! { enum SimpleLanguage { Num(i32), Add = "+", Mul = "*", Symbol(String), } } let rules: &[Rewrite<SimpleLanguage, ()>] = &[ rw!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), rw!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"), rw!("add-0"; "(+ ?a 0)" => "?a"), rw!("mul-0"; "(* ?a 0)" => "0"), rw!("mul-1"; "(* ?a 1)" => "?a"), ]; let start = "(+ 0 (* 1 foo))".parse().unwrap(); // SimpleRunner is customizable in the builder pattern style. let (egraph, report) = SimpleRunner::default() .with_iter_limit(10) .with_node_limit(10_000) .run_expr(start, &rules); println!( "Stopped after {} iterations, reason: {:?}", report.iterations.len(), report.stop_reason ); ``` */ pub struct SimpleRunner { iter_limit: usize, node_limit: usize, time_limit: Duration, start_time: Instant, i: usize, stats: IndexMap<String, RuleStats>, initial_match_limit: usize, ban_length: usize, } struct RuleStats { times_applied: usize, banned_until: usize, times_banned: usize, } impl Default for SimpleRunner { fn default() -> Self { Self { iter_limit: 30, node_limit: 10_000, i: 0, start_time: Instant::now(), time_limit: Duration::from_secs(60), stats: Default::default(), initial_match_limit: 1_000, ban_length: 5, } } } impl SimpleRunner { /// Sets the iteration limit. Default: 30 pub fn with_iter_limit(self, iter_limit: usize) -> Self { Self { iter_limit, ..self } } /// Sets the egraph size limit (in enodes). Default: 10,000 pub fn with_node_limit(self, node_limit: usize) -> Self { Self { node_limit, ..self } } /// Sets the runner time limit. Default: 60 seconds pub fn with_time_limit(self, time_limit: Duration) -> Self { Self { time_limit, ..self } } /// Sets the initial match limit before a rule is banned. Default: 1,000 /// /// Setting this to a really big number will effectively disable /// rule backoff. pub fn with_initial_match_limit(self, initial_match_limit: usize) -> Self { Self { initial_match_limit, ..self } } } /// Error returned by [`SimpleRunner`] when it stops. /// /// [`SimpleRunner`]: struct.SimpleRunner.html #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize))] pub enum SimpleRunnerError { /// The egraph saturated, i.e., there was an iteration where we /// didn't learn anything new from applying the rules. Saturated, /// The iteration limit was hit. The data is the iteration limit. IterationLimit(usize), /// The enode limit was hit. The data is the enode limit. NodeLimit(usize), /// The time limit was hit. The data is the time limit in seconds. TimeLimit(f64), } impl<L, M> Runner<L, M> for SimpleRunner where L: Language, M: Metadata<L>, { type Error = SimpleRunnerError; fn pre_step(&mut self, egraph: &mut EGraph<L, M>) -> Result<(), Self::Error> { info!( "\n\nIteration {}, n={}, e={}", self.i, egraph.total_size(), egraph.number_of_classes() ); if self.i >= self.iter_limit { Err(SimpleRunnerError::IterationLimit(self.i)) } else { Ok(()) } } fn during_step(&mut self, egraph: &EGraph<L, M>) -> Result<(), Self::Error> { let size = egraph.total_size(); let elapsed = self.start_time.elapsed(); if size > self.node_limit { Err(SimpleRunnerError::NodeLimit(size)) } else if elapsed > self.time_limit { Err(SimpleRunnerError::TimeLimit(elapsed.as_secs_f64())) } else { Ok(()) } } fn post_step( &mut self, iteration: &Iteration, _egraph: &mut EGraph<L, M>, ) -> Result<(), Self::Error> { let is_banned = |s: &RuleStats| s.banned_until > self.i; let any_bans = self.stats.values().any(is_banned); self.i += 1; if !any_bans && iteration.applied.is_empty() { Err(SimpleRunnerError::Saturated) } else { Ok(()) } } fn search_rewrite( &mut self, egraph: &mut EGraph<L, M>, rewrite: &Rewrite<L, M>, ) -> Vec<SearchMatches> { if let Some(limit) = self.stats.get_mut(rewrite.name()) { if self.i < limit.banned_until { debug!( "Skipping {} ({}-{}), banned until {}...", rewrite.name(), limit.times_applied, limit.times_banned, limit.banned_until, ); return vec![]; } let matches = rewrite.search(egraph); let total_len: usize = matches.iter().map(|m| m.substs.len()).sum(); let threshold = self.initial_match_limit << limit.times_banned; if total_len > threshold { let ban_length = self.ban_length << limit.times_banned; limit.times_banned += 1; limit.banned_until = self.i + ban_length; info!( "Banning {} ({}-{}) for {} iters: {} < {}", rewrite.name(), limit.times_applied, limit.times_banned, ban_length, threshold, total_len, ); vec![] } else { limit.times_applied += 1; matches } } else { self.stats.insert( rewrite.name().into(), RuleStats { times_applied: 0, banned_until: 0, times_banned: 0, }, ); rewrite.search(egraph) } } }