1use crate::lcnf::*;
6use std::collections::{HashMap, HashSet};
7
8use super::functions::*;
9use std::collections::VecDeque;
10
/// Options controlling which join-point sub-passes run and how far they go.
#[derive(Debug, Clone)]
pub struct JoinPointConfig {
    /// Inclusive upper bound on a let value's size for it to count as a small join.
    pub max_join_size: usize,
    /// Enables the small-join inlining sub-pass.
    pub inline_small_joins: bool,
    /// Enables rewriting `let x = f(..); return x` into a tail call.
    pub detect_tail_calls: bool,
    /// Enables the contification sub-pass.
    pub enable_contification: bool,
    /// Enables floating a let binding into the single case branch that uses it.
    pub float_join_points: bool,
    /// Enables removal of unused, pure let bindings.
    pub eliminate_dead_joins: bool,
    /// Maximum number of optimization rounds run per declaration.
    pub max_iterations: usize,
}
29#[derive(Debug, Clone, Default)]
31pub struct OJoinConfig {
32 pub(super) entries: std::collections::HashMap<String, String>,
33}
34impl OJoinConfig {
35 pub fn new() -> Self {
36 OJoinConfig::default()
37 }
38 pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
39 self.entries.insert(key.into(), value.into());
40 }
41 pub fn get(&self, key: &str) -> Option<&str> {
42 self.entries.get(key).map(|s| s.as_str())
43 }
44 pub fn get_bool(&self, key: &str) -> bool {
45 matches!(self.get(key), Some("true") | Some("1") | Some("yes"))
46 }
47 pub fn get_int(&self, key: &str) -> Option<i64> {
48 self.get(key)?.parse().ok()
49 }
50 pub fn len(&self) -> usize {
51 self.entries.len()
52 }
53 pub fn is_empty(&self) -> bool {
54 self.entries.is_empty()
55 }
56}
57#[derive(Debug, Clone, Default)]
59pub struct OJoinFeatures {
60 pub(super) flags: std::collections::HashSet<String>,
61}
62impl OJoinFeatures {
63 pub fn new() -> Self {
64 OJoinFeatures::default()
65 }
66 pub fn enable(&mut self, flag: impl Into<String>) {
67 self.flags.insert(flag.into());
68 }
69 pub fn disable(&mut self, flag: &str) {
70 self.flags.remove(flag);
71 }
72 pub fn is_enabled(&self, flag: &str) -> bool {
73 self.flags.contains(flag)
74 }
75 pub fn len(&self) -> usize {
76 self.flags.len()
77 }
78 pub fn is_empty(&self) -> bool {
79 self.flags.is_empty()
80 }
81 pub fn union(&self, other: &OJoinFeatures) -> OJoinFeatures {
82 OJoinFeatures {
83 flags: self.flags.union(&other.flags).cloned().collect(),
84 }
85 }
86 pub fn intersection(&self, other: &OJoinFeatures) -> OJoinFeatures {
87 OJoinFeatures {
88 flags: self.flags.intersection(&other.flags).cloned().collect(),
89 }
90 }
91}
/// Stateless helpers that fold constant arithmetic, bitwise, and boolean operations.
///
/// Integer helpers whose operation can overflow or divide by zero return
/// `None` in those cases instead of panicking.
#[allow(dead_code)]
pub struct OJConstantFoldingHelper;
impl OJConstantFoldingHelper {
    /// `a + b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_add_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_add(b)
    }
    /// `a - b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_sub_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_sub(b)
    }
    /// `a * b`, or `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_mul_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_mul(b)
    }
    /// `a / b`, or `None` when `b == 0` or the division overflows (`i64::MIN / -1`).
    #[allow(dead_code)]
    pub fn fold_div_i64(a: i64, b: i64) -> Option<i64> {
        // `checked_div` already rejects a zero divisor, so no separate test is needed.
        a.checked_div(b)
    }
    /// Floating-point sum `a + b`.
    #[allow(dead_code)]
    pub fn fold_add_f64(a: f64, b: f64) -> f64 {
        a + b
    }
    /// Floating-point product `a * b`.
    #[allow(dead_code)]
    pub fn fold_mul_f64(a: f64, b: f64) -> f64 {
        a * b
    }
    /// `-a`, or `None` on overflow (`a == i64::MIN`).
    #[allow(dead_code)]
    pub fn fold_neg_i64(a: i64) -> Option<i64> {
        a.checked_neg()
    }
    /// Logical negation.
    #[allow(dead_code)]
    pub fn fold_not_bool(a: bool) -> bool {
        !a
    }
    /// Logical conjunction.
    #[allow(dead_code)]
    pub fn fold_and_bool(a: bool, b: bool) -> bool {
        a && b
    }
    /// Logical disjunction.
    #[allow(dead_code)]
    pub fn fold_or_bool(a: bool, b: bool) -> bool {
        a || b
    }
    /// `a << b`, or `None` when the shift amount `b` is 64 or more.
    #[allow(dead_code)]
    pub fn fold_shl_i64(a: i64, b: u32) -> Option<i64> {
        a.checked_shl(b)
    }
    /// `a >> b`, or `None` when the shift amount `b` is 64 or more.
    #[allow(dead_code)]
    pub fn fold_shr_i64(a: i64, b: u32) -> Option<i64> {
        a.checked_shr(b)
    }
    /// `a % b`, or `None` when `b == 0` or the operation overflows (`i64::MIN % -1`).
    #[allow(dead_code)]
    pub fn fold_rem_i64(a: i64, b: i64) -> Option<i64> {
        // The plain `%` operator panics for `i64::MIN % -1`; `checked_rem`
        // reports that case (and a zero divisor) as `None`.
        a.checked_rem(b)
    }
    /// Bitwise AND.
    #[allow(dead_code)]
    pub fn fold_bitand_i64(a: i64, b: i64) -> i64 {
        a & b
    }
    /// Bitwise OR.
    #[allow(dead_code)]
    pub fn fold_bitor_i64(a: i64, b: i64) -> i64 {
        a | b
    }
    /// Bitwise XOR.
    #[allow(dead_code)]
    pub fn fold_bitxor_i64(a: i64, b: i64) -> i64 {
        a ^ b
    }
    /// Bitwise complement.
    #[allow(dead_code)]
    pub fn fold_bitnot_i64(a: i64) -> i64 {
        !a
    }
}
/// Record describing one call site.
///
/// NOTE(review): no code in this part of the file reads or writes these
/// fields; the per-field notes below are inferred from names — confirm
/// against the code that builds this record.
#[derive(Debug, Clone)]
pub struct CallSiteInfo {
    /// Presumably the name of the function containing the call.
    pub(super) caller: String,
    /// Presumably whether the call sits in tail position.
    pub(super) is_tail: bool,
    /// Presumably the number of arguments passed at the call.
    pub(super) arg_count: usize,
    /// Presumably the callee variable when the callee is a local variable.
    pub(super) callee_var: Option<LcnfVarId>,
}
184#[allow(dead_code)]
185#[derive(Debug, Clone)]
186pub struct OJDepGraph {
187 pub(super) nodes: Vec<u32>,
188 pub(super) edges: Vec<(u32, u32)>,
189}
190impl OJDepGraph {
191 #[allow(dead_code)]
192 pub fn new() -> Self {
193 OJDepGraph {
194 nodes: Vec::new(),
195 edges: Vec::new(),
196 }
197 }
198 #[allow(dead_code)]
199 pub fn add_node(&mut self, id: u32) {
200 if !self.nodes.contains(&id) {
201 self.nodes.push(id);
202 }
203 }
204 #[allow(dead_code)]
205 pub fn add_dep(&mut self, dep: u32, dependent: u32) {
206 self.add_node(dep);
207 self.add_node(dependent);
208 self.edges.push((dep, dependent));
209 }
210 #[allow(dead_code)]
211 pub fn dependents_of(&self, node: u32) -> Vec<u32> {
212 self.edges
213 .iter()
214 .filter(|(d, _)| *d == node)
215 .map(|(_, dep)| *dep)
216 .collect()
217 }
218 #[allow(dead_code)]
219 pub fn dependencies_of(&self, node: u32) -> Vec<u32> {
220 self.edges
221 .iter()
222 .filter(|(_, dep)| *dep == node)
223 .map(|(d, _)| *d)
224 .collect()
225 }
226 #[allow(dead_code)]
227 pub fn topological_sort(&self) -> Vec<u32> {
228 let mut in_degree: std::collections::HashMap<u32, u32> = std::collections::HashMap::new();
229 for &n in &self.nodes {
230 in_degree.insert(n, 0);
231 }
232 for (_, dep) in &self.edges {
233 *in_degree.entry(*dep).or_insert(0) += 1;
234 }
235 let mut queue: std::collections::VecDeque<u32> = self
236 .nodes
237 .iter()
238 .filter(|&&n| in_degree[&n] == 0)
239 .copied()
240 .collect();
241 let mut result = Vec::new();
242 while let Some(node) = queue.pop_front() {
243 result.push(node);
244 for dep in self.dependents_of(node) {
245 let cnt = in_degree.entry(dep).or_insert(0);
246 *cnt = cnt.saturating_sub(1);
247 if *cnt == 0 {
248 queue.push_back(dep);
249 }
250 }
251 }
252 result
253 }
254 #[allow(dead_code)]
255 pub fn has_cycle(&self) -> bool {
256 self.topological_sort().len() < self.nodes.len()
257 }
258}
259#[derive(Debug, Clone, Default)]
261pub struct JoinPointStats {
262 pub joins_created: usize,
264 pub joins_inlined: usize,
266 pub joins_eliminated: usize,
268 pub tail_calls_detected: usize,
270 pub functions_contified: usize,
272 pub joins_floated: usize,
274 pub iterations: usize,
276}
277impl JoinPointStats {
278 pub(super) fn total_changes(&self) -> usize {
279 self.joins_created
280 + self.joins_inlined
281 + self.joins_eliminated
282 + self.tail_calls_detected
283 + self.functions_contified
284 + self.joins_floated
285 }
286}
287#[derive(Debug, Default)]
289pub struct OJoinSourceBuffer {
290 pub(super) buf: String,
291 pub(super) indent_level: usize,
292 pub(super) indent_str: String,
293}
294impl OJoinSourceBuffer {
295 pub fn new() -> Self {
296 OJoinSourceBuffer {
297 buf: String::new(),
298 indent_level: 0,
299 indent_str: " ".to_string(),
300 }
301 }
302 pub fn with_indent(mut self, indent: impl Into<String>) -> Self {
303 self.indent_str = indent.into();
304 self
305 }
306 pub fn push_line(&mut self, line: &str) {
307 for _ in 0..self.indent_level {
308 self.buf.push_str(&self.indent_str);
309 }
310 self.buf.push_str(line);
311 self.buf.push('\n');
312 }
313 pub fn push_raw(&mut self, s: &str) {
314 self.buf.push_str(s);
315 }
316 pub fn indent(&mut self) {
317 self.indent_level += 1;
318 }
319 pub fn dedent(&mut self) {
320 self.indent_level = self.indent_level.saturating_sub(1);
321 }
322 pub fn as_str(&self) -> &str {
323 &self.buf
324 }
325 pub fn len(&self) -> usize {
326 self.buf.len()
327 }
328 pub fn is_empty(&self) -> bool {
329 self.buf.is_empty()
330 }
331 pub fn line_count(&self) -> usize {
332 self.buf.lines().count()
333 }
334 pub fn into_string(self) -> String {
335 self.buf
336 }
337 pub fn reset(&mut self) {
338 self.buf.clear();
339 self.indent_level = 0;
340 }
341}
342#[derive(Debug, Default)]
344pub struct OJoinNameScope {
345 pub(super) declared: std::collections::HashSet<String>,
346 pub(super) depth: usize,
347 pub(super) parent: Option<Box<OJoinNameScope>>,
348}
349impl OJoinNameScope {
350 pub fn new() -> Self {
351 OJoinNameScope::default()
352 }
353 pub fn declare(&mut self, name: impl Into<String>) -> bool {
354 self.declared.insert(name.into())
355 }
356 pub fn is_declared(&self, name: &str) -> bool {
357 self.declared.contains(name)
358 }
359 pub fn push_scope(self) -> Self {
360 OJoinNameScope {
361 declared: std::collections::HashSet::new(),
362 depth: self.depth + 1,
363 parent: Some(Box::new(self)),
364 }
365 }
366 pub fn pop_scope(self) -> Self {
367 *self.parent.unwrap_or_default()
368 }
369 pub fn depth(&self) -> usize {
370 self.depth
371 }
372 pub fn len(&self) -> usize {
373 self.declared.len()
374 }
375}
376#[allow(dead_code)]
377#[derive(Debug, Clone)]
378pub struct OJAnalysisCache {
379 pub(super) entries: std::collections::HashMap<String, OJCacheEntry>,
380 pub(super) max_size: usize,
381 pub(super) hits: u64,
382 pub(super) misses: u64,
383}
384impl OJAnalysisCache {
385 #[allow(dead_code)]
386 pub fn new(max_size: usize) -> Self {
387 OJAnalysisCache {
388 entries: std::collections::HashMap::new(),
389 max_size,
390 hits: 0,
391 misses: 0,
392 }
393 }
394 #[allow(dead_code)]
395 pub fn get(&mut self, key: &str) -> Option<&OJCacheEntry> {
396 if self.entries.contains_key(key) {
397 self.hits += 1;
398 self.entries.get(key)
399 } else {
400 self.misses += 1;
401 None
402 }
403 }
404 #[allow(dead_code)]
405 pub fn insert(&mut self, key: String, data: Vec<u8>) {
406 if self.entries.len() >= self.max_size {
407 if let Some(oldest) = self.entries.keys().next().cloned() {
408 self.entries.remove(&oldest);
409 }
410 }
411 self.entries.insert(
412 key.clone(),
413 OJCacheEntry {
414 key,
415 data,
416 timestamp: 0,
417 valid: true,
418 },
419 );
420 }
421 #[allow(dead_code)]
422 pub fn invalidate(&mut self, key: &str) {
423 if let Some(entry) = self.entries.get_mut(key) {
424 entry.valid = false;
425 }
426 }
427 #[allow(dead_code)]
428 pub fn clear(&mut self) {
429 self.entries.clear();
430 }
431 #[allow(dead_code)]
432 pub fn hit_rate(&self) -> f64 {
433 let total = self.hits + self.misses;
434 if total == 0 {
435 return 0.0;
436 }
437 self.hits as f64 / total as f64
438 }
439 #[allow(dead_code)]
440 pub fn size(&self) -> usize {
441 self.entries.len()
442 }
443}
444#[derive(Debug, Default)]
446pub struct OJoinIdGen {
447 pub(super) next: u32,
448}
449impl OJoinIdGen {
450 pub fn new() -> Self {
451 OJoinIdGen::default()
452 }
453 pub fn next_id(&mut self) -> u32 {
454 let id = self.next;
455 self.next += 1;
456 id
457 }
458 pub fn peek_next(&self) -> u32 {
459 self.next
460 }
461 pub fn reset(&mut self) {
462 self.next = 0;
463 }
464 pub fn skip(&mut self, n: u32) {
465 self.next += n;
466 }
467}
/// Key for incremental work, made of a content hash and a configuration hash.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct OJoinIncrKey {
    pub content_hash: u64,
    pub config_hash: u64,
}
impl OJoinIncrKey {
    /// Builds a key from a content hash and a configuration hash.
    pub fn new(content: u64, config: u64) -> Self {
        Self {
            content_hash: content,
            config_hash: config,
        }
    }
    /// Mixes both hashes into one value: the content hash is multiplied
    /// (wrapping) by a fixed odd constant and XOR-ed with the config hash.
    pub fn combined_hash(&self) -> u64 {
        let spread = self.content_hash.wrapping_mul(0x9e3779b97f4a7c15);
        spread ^ self.config_hash
    }
    /// Reports whether both hashes equal those of `other`.
    pub fn matches(&self, other: &OJoinIncrKey) -> bool {
        if self.content_hash != other.content_hash {
            return false;
        }
        self.config_hash == other.config_hash
    }
}
488#[derive(Debug, Clone)]
490pub struct OJoinDiagMsg {
491 pub severity: OJoinDiagSeverity,
492 pub pass: String,
493 pub message: String,
494}
495impl OJoinDiagMsg {
496 pub fn error(pass: impl Into<String>, msg: impl Into<String>) -> Self {
497 OJoinDiagMsg {
498 severity: OJoinDiagSeverity::Error,
499 pass: pass.into(),
500 message: msg.into(),
501 }
502 }
503 pub fn warning(pass: impl Into<String>, msg: impl Into<String>) -> Self {
504 OJoinDiagMsg {
505 severity: OJoinDiagSeverity::Warning,
506 pass: pass.into(),
507 message: msg.into(),
508 }
509 }
510 pub fn note(pass: impl Into<String>, msg: impl Into<String>) -> Self {
511 OJoinDiagMsg {
512 severity: OJoinDiagSeverity::Note,
513 pass: pass.into(),
514 message: msg.into(),
515 }
516 }
517}
/// Version number with major, minor, patch and an optional pre-release tag.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct OJoinVersion {
    pub major: u32,
    pub minor: u32,
    pub patch: u32,
    pub pre: Option<String>,
}
impl OJoinVersion {
    /// Builds a version with no pre-release tag.
    pub fn new(major: u32, minor: u32, patch: u32) -> Self {
        Self {
            major,
            minor,
            patch,
            pre: None,
        }
    }
    /// Returns this version with its pre-release tag set to `pre`.
    pub fn with_pre(self, pre: impl Into<String>) -> Self {
        Self {
            pre: Some(pre.into()),
            ..self
        }
    }
    /// `true` when the version carries no pre-release tag.
    pub fn is_stable(&self) -> bool {
        matches!(self.pre, None)
    }
    /// `true` when the major versions are equal and this minor is at least
    /// `other`'s; patch and pre-release tag are ignored.
    pub fn is_compatible_with(&self, other: &OJoinVersion) -> bool {
        if self.major != other.major {
            return false;
        }
        self.minor >= other.minor
    }
}
546#[derive(Debug, Default)]
548pub struct OJoinDiagCollector {
549 pub(super) msgs: Vec<OJoinDiagMsg>,
550}
551impl OJoinDiagCollector {
552 pub fn new() -> Self {
553 OJoinDiagCollector::default()
554 }
555 pub fn emit(&mut self, d: OJoinDiagMsg) {
556 self.msgs.push(d);
557 }
558 pub fn has_errors(&self) -> bool {
559 self.msgs
560 .iter()
561 .any(|d| d.severity == OJoinDiagSeverity::Error)
562 }
563 pub fn errors(&self) -> Vec<&OJoinDiagMsg> {
564 self.msgs
565 .iter()
566 .filter(|d| d.severity == OJoinDiagSeverity::Error)
567 .collect()
568 }
569 pub fn warnings(&self) -> Vec<&OJoinDiagMsg> {
570 self.msgs
571 .iter()
572 .filter(|d| d.severity == OJoinDiagSeverity::Warning)
573 .collect()
574 }
575 pub fn len(&self) -> usize {
576 self.msgs.len()
577 }
578 pub fn is_empty(&self) -> bool {
579 self.msgs.is_empty()
580 }
581 pub fn clear(&mut self) {
582 self.msgs.clear();
583 }
584}
585#[derive(Debug, Default)]
587pub struct OJoinProfiler {
588 pub(super) timings: Vec<OJoinPassTiming>,
589}
590impl OJoinProfiler {
591 pub fn new() -> Self {
592 OJoinProfiler::default()
593 }
594 pub fn record(&mut self, t: OJoinPassTiming) {
595 self.timings.push(t);
596 }
597 pub fn total_elapsed_us(&self) -> u64 {
598 self.timings.iter().map(|t| t.elapsed_us).sum()
599 }
600 pub fn slowest_pass(&self) -> Option<&OJoinPassTiming> {
601 self.timings.iter().max_by_key(|t| t.elapsed_us)
602 }
603 pub fn num_passes(&self) -> usize {
604 self.timings.len()
605 }
606 pub fn profitable_passes(&self) -> Vec<&OJoinPassTiming> {
607 self.timings.iter().filter(|t| t.is_profitable()).collect()
608 }
609}
610#[allow(dead_code)]
611pub struct OJPassRegistry {
612 pub(super) configs: Vec<OJPassConfig>,
613 pub(super) stats: std::collections::HashMap<String, OJPassStats>,
614}
615impl OJPassRegistry {
616 #[allow(dead_code)]
617 pub fn new() -> Self {
618 OJPassRegistry {
619 configs: Vec::new(),
620 stats: std::collections::HashMap::new(),
621 }
622 }
623 #[allow(dead_code)]
624 pub fn register(&mut self, config: OJPassConfig) {
625 self.stats
626 .insert(config.pass_name.clone(), OJPassStats::new());
627 self.configs.push(config);
628 }
629 #[allow(dead_code)]
630 pub fn enabled_passes(&self) -> Vec<&OJPassConfig> {
631 self.configs.iter().filter(|c| c.enabled).collect()
632 }
633 #[allow(dead_code)]
634 pub fn get_stats(&self, name: &str) -> Option<&OJPassStats> {
635 self.stats.get(name)
636 }
637 #[allow(dead_code)]
638 pub fn total_passes(&self) -> usize {
639 self.configs.len()
640 }
641 #[allow(dead_code)]
642 pub fn enabled_count(&self) -> usize {
643 self.enabled_passes().len()
644 }
645 #[allow(dead_code)]
646 pub fn update_stats(&mut self, name: &str, changes: u64, time_ms: u64, iter: u32) {
647 if let Some(stats) = self.stats.get_mut(name) {
648 stats.record_run(changes, time_ms, iter);
649 }
650 }
651}
/// Driver for the join-point optimizations over LCNF declarations.
pub struct JoinPointOptimizer {
    /// Switches and limits selecting which sub-passes run.
    pub(super) config: JoinPointConfig,
    /// Counters updated by the sub-passes as they make changes.
    pub(super) stats: JoinPointStats,
    /// Next raw value handed out by `fresh_id`.
    pub(super) next_id: u64,
}
658impl JoinPointOptimizer {
659 pub fn new(config: JoinPointConfig) -> Self {
661 JoinPointOptimizer {
662 config,
663 stats: JoinPointStats::default(),
664 next_id: 1000,
665 }
666 }
    /// Borrows the statistics accumulated so far.
    pub fn stats(&self) -> &JoinPointStats {
        &self.stats
    }
671 pub(super) fn fresh_id(&mut self) -> LcnfVarId {
673 let id = self.next_id;
674 self.next_id += 1;
675 LcnfVarId(id)
676 }
677 pub(super) fn optimize_decl(&mut self, decl: &mut LcnfFunDecl) {
679 for _ in 0..self.config.max_iterations {
680 let changes_before = self.stats.total_changes();
681 if self.config.detect_tail_calls {
682 self.detect_tail_calls_in_expr(&mut decl.body, &decl.name);
683 }
684 if self.config.inline_small_joins {
685 self.inline_small_joins(&mut decl.body);
686 }
687 if self.config.eliminate_dead_joins {
688 self.eliminate_dead_joins(&mut decl.body);
689 }
690 if self.config.enable_contification {
691 self.contify_functions(&mut decl.body);
692 }
693 if self.config.float_join_points {
694 self.float_joins(&mut decl.body);
695 }
696 self.stats.iterations += 1;
697 if self.stats.total_changes() == changes_before {
698 break;
699 }
700 }
701 }
702 pub(super) fn detect_tail_calls_in_expr(&mut self, expr: &mut LcnfExpr, _current_fn: &str) {
704 let should_convert = if let LcnfExpr::Let {
705 id,
706 value: LcnfLetValue::App(func, args),
707 body,
708 ..
709 } = &*expr
710 {
711 if let LcnfExpr::Return(LcnfArg::Var(ret_var)) = body.as_ref() {
712 if *ret_var == *id {
713 Some((func.clone(), args.clone()))
714 } else {
715 None
716 }
717 } else {
718 None
719 }
720 } else {
721 None
722 };
723 if let Some((func, args)) = should_convert {
724 *expr = LcnfExpr::TailCall(func, args);
725 self.stats.tail_calls_detected += 1;
726 return;
727 }
728 match expr {
729 LcnfExpr::Let { body, .. } => {
730 self.detect_tail_calls_in_expr(body, _current_fn);
731 }
732 LcnfExpr::Case { alts, default, .. } => {
733 for alt in alts.iter_mut() {
734 self.detect_tail_calls_in_expr(&mut alt.body, _current_fn);
735 }
736 if let Some(def) = default {
737 self.detect_tail_calls_in_expr(def, _current_fn);
738 }
739 }
740 LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(_, _) => {}
741 }
742 }
743 pub(super) fn inline_small_joins(&mut self, expr: &mut LcnfExpr) {
745 let small_joins = self.find_small_joins(expr);
746 if !small_joins.is_empty() {
747 self.apply_join_inlining(expr, &small_joins);
748 }
749 }
750 pub(super) fn find_small_joins(&self, expr: &LcnfExpr) -> HashMap<LcnfVarId, LcnfLetValue> {
752 let mut joins = HashMap::new();
753 match expr {
754 LcnfExpr::Let {
755 id, value, body, ..
756 } => {
757 let size = self.value_size(value);
758 if size <= self.config.max_join_size {
759 joins.insert(*id, value.clone());
760 }
761 joins.extend(self.find_small_joins(body));
762 }
763 LcnfExpr::Case { alts, default, .. } => {
764 for alt in alts {
765 joins.extend(self.find_small_joins(&alt.body));
766 }
767 if let Some(def) = default {
768 joins.extend(self.find_small_joins(def));
769 }
770 }
771 _ => {}
772 }
773 joins
774 }
775 pub(super) fn value_size(&self, value: &LcnfLetValue) -> usize {
777 match value {
778 LcnfLetValue::Lit(_)
779 | LcnfLetValue::Erased
780 | LcnfLetValue::FVar(_)
781 | LcnfLetValue::Reset(_)
782 | LcnfLetValue::Reuse(_, _, _, _) => 1,
783 LcnfLetValue::Proj(_, _, _) => 1,
784 LcnfLetValue::App(_, args) => 1 + args.len(),
785 LcnfLetValue::Ctor(_, _, args) => 1 + args.len(),
786 }
787 }
788 pub(super) fn apply_join_inlining(
790 &mut self,
791 expr: &mut LcnfExpr,
792 joins: &HashMap<LcnfVarId, LcnfLetValue>,
793 ) {
794 match expr {
795 LcnfExpr::Let {
796 id, value, body, ..
797 } => {
798 if let LcnfLetValue::FVar(ref fvar) = value {
799 if let Some(replacement) = joins.get(fvar) {
800 if *id != *fvar {
801 *value = replacement.clone();
802 self.stats.joins_inlined += 1;
803 }
804 }
805 }
806 self.apply_join_inlining(body, joins);
807 }
808 LcnfExpr::Case { alts, default, .. } => {
809 for alt in alts.iter_mut() {
810 self.apply_join_inlining(&mut alt.body, joins);
811 }
812 if let Some(def) = default {
813 self.apply_join_inlining(def, joins);
814 }
815 }
816 _ => {}
817 }
818 }
    /// Removes let bindings of `expr` that are unused and pure.
    ///
    /// The used-variable set is computed once, up front, by
    /// `collect_used_vars`, and is not refreshed while bindings are removed.
    pub(super) fn eliminate_dead_joins(&mut self, expr: &mut LcnfExpr) {
        // Computed before the mutable walk so the borrow of `expr` has ended.
        let used = collect_used_vars(expr);
        self.remove_dead_lets(expr, &used);
    }
824 pub(super) fn remove_dead_lets(&mut self, expr: &mut LcnfExpr, used: &HashSet<LcnfVarId>) {
826 loop {
827 let mut changed = false;
828 if let LcnfExpr::Let {
829 id, value, body, ..
830 } = expr
831 {
832 if !used.contains(id) && is_pure_value(value) {
833 *expr = *body.clone();
834 self.stats.joins_eliminated += 1;
835 changed = true;
836 }
837 }
838 if !changed {
839 break;
840 }
841 }
842 match expr {
843 LcnfExpr::Let { body, .. } => {
844 self.remove_dead_lets(body, used);
845 }
846 LcnfExpr::Case { alts, default, .. } => {
847 for alt in alts.iter_mut() {
848 self.remove_dead_lets(&mut alt.body, used);
849 }
850 if let Some(def) = default {
851 self.remove_dead_lets(def, used);
852 }
853 }
854 _ => {}
855 }
856 }
857 pub(super) fn contify_functions(&mut self, expr: &mut LcnfExpr) {
859 let tail_uses = analyze_tail_uses(expr, true);
860 let candidates: Vec<LcnfVarId> = tail_uses
861 .iter()
862 .filter(|(_, use_kind)| **use_kind == TailUse::TailOnly)
863 .map(|(var, _)| *var)
864 .collect();
865 if !candidates.is_empty() {
866 self.mark_contified(expr, &candidates);
867 }
868 }
869 pub(super) fn mark_contified(&mut self, expr: &mut LcnfExpr, candidates: &[LcnfVarId]) {
871 match expr {
872 LcnfExpr::Let { id, body, .. } => {
873 if candidates.contains(id) {
874 self.stats.functions_contified += 1;
875 }
876 self.mark_contified(body, candidates);
877 }
878 LcnfExpr::Case { alts, default, .. } => {
879 for alt in alts.iter_mut() {
880 self.mark_contified(&mut alt.body, candidates);
881 }
882 if let Some(def) = default {
883 self.mark_contified(def, candidates);
884 }
885 }
886 _ => {}
887 }
888 }
889 pub(super) fn float_joins(&mut self, expr: &mut LcnfExpr) {
891 let moved = self.try_float_into_case(expr);
892 if moved {
893 self.stats.joins_floated += 1;
894 }
895 match expr {
896 LcnfExpr::Let { body, .. } => {
897 self.float_joins(body);
898 }
899 LcnfExpr::Case { alts, default, .. } => {
900 for alt in alts.iter_mut() {
901 self.float_joins(&mut alt.body);
902 }
903 if let Some(def) = default {
904 self.float_joins(def);
905 }
906 }
907 _ => {}
908 }
909 }
    /// Moves a let binding into the only case branch that uses it.
    ///
    /// Applies when `expr` is a `Let` whose body is a `Case` and exactly one
    /// branch (counting the default) mentions the bound id according to
    /// `expr_uses_var`. The binding is then re-created at the top of that
    /// branch and `expr` becomes the `Case`. Returns `true` when the rewrite
    /// happened and `false` when `expr` was left as it was.
    ///
    /// NOTE(review): only the branch bodies are checked for uses of the bound
    /// id; the case scrutinee is not. If the scrutinee can mention the id, the
    /// rewrite would move the binding below its use — confirm.
    /// NOTE(review): the bound value is moved without a purity check, so it is
    /// no longer evaluated on the other branches — confirm this is intended.
    pub(super) fn try_float_into_case(&mut self, expr: &mut LcnfExpr) -> bool {
        // Read-only check first, so `expr` is only taken apart when the
        // rewrite is certain to go ahead.
        let can_float = if let LcnfExpr::Let { id, body, .. } = &*expr {
            if let LcnfExpr::Case { alts, default, .. } = body.as_ref() {
                let use_count = alts.iter().filter(|a| expr_uses_var(&a.body, *id)).count()
                    + default
                        .as_ref()
                        .map(|d| usize::from(expr_uses_var(d, *id)))
                        .unwrap_or(0);
                use_count == 1
            } else {
                false
            }
        } else {
            false
        };
        if !can_float {
            return false;
        }
        // Take ownership of the node; `Unreachable` is a temporary placeholder.
        let old = std::mem::replace(expr, LcnfExpr::Unreachable);
        if let LcnfExpr::Let {
            id,
            name,
            ty,
            value,
            body,
        } = old
        {
            if let LcnfExpr::Case {
                scrutinee,
                scrutinee_ty,
                mut alts,
                mut default,
            } = *body
            {
                if let Some(idx) = alts.iter().position(|a| expr_uses_var(&a.body, id)) {
                    // The single use is in alternative `idx`: wrap its body in the binding.
                    let old_body = std::mem::replace(&mut alts[idx].body, LcnfExpr::Unreachable);
                    alts[idx].body = LcnfExpr::Let {
                        id,
                        name,
                        ty,
                        value,
                        body: Box::new(old_body),
                    };
                } else if let Some(def) = default.take() {
                    // No alternative uses it, so the single use is in the default branch.
                    default = Some(Box::new(LcnfExpr::Let {
                        id,
                        name,
                        ty,
                        value,
                        body: def,
                    }));
                }
                *expr = LcnfExpr::Case {
                    scrutinee,
                    scrutinee_ty,
                    alts,
                    default,
                };
                return true;
            }
        }
        // Not reachable after a successful `can_float` check above.
        false
    }
981}
/// Dominator-tree storage indexed by node number.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct OJDominatorTree {
    /// Immediate dominator of each node; `None` until set.
    pub idom: Vec<Option<u32>>,
    /// Per-node child lists; not written by any method here, callers fill them in.
    pub dom_children: Vec<Vec<u32>>,
    /// Per-node depth; not written by any method here, callers fill it in.
    pub dom_depth: Vec<u32>,
}
impl OJDominatorTree {
    /// Creates storage for `size` nodes with no dominators set and all depths zero.
    #[allow(dead_code)]
    pub fn new(size: usize) -> Self {
        OJDominatorTree {
            idom: vec![None; size],
            dom_children: vec![Vec::new(); size],
            dom_depth: vec![0; size],
        }
    }
    /// Sets the immediate dominator of `node` to `idom`.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not below the size given to `new`.
    #[allow(dead_code)]
    pub fn set_idom(&mut self, node: usize, idom: u32) {
        self.idom[node] = Some(idom);
    }
    /// Reports whether `a` dominates `b`, following immediate dominators upward from `b`.
    ///
    /// A node always dominates itself. The walk gives up with `false` when it
    /// reaches a node with no dominator, a node that is its own dominator, an
    /// index outside the tree, or a cycle in the stored links. It never
    /// panics and always terminates.
    #[allow(dead_code)]
    pub fn dominates(&self, a: usize, b: usize) -> bool {
        if a == b {
            return true;
        }
        let mut cur = b;
        // A well-formed chain visits each node at most once, so more steps
        // than there are nodes means the links form a cycle that misses `a`.
        for _ in 0..self.idom.len() {
            match self.idom.get(cur).copied().flatten() {
                Some(parent) => {
                    let parent = parent as usize;
                    if parent == a {
                        return true;
                    }
                    if parent == cur {
                        return false;
                    }
                    cur = parent;
                }
                None => return false,
            }
        }
        false
    }
    /// Stored depth of `node`, or zero when `node` is out of range.
    #[allow(dead_code)]
    pub fn depth(&self, node: usize) -> u32 {
        self.dom_depth.get(node).copied().unwrap_or(0)
    }
}
1022#[allow(dead_code)]
1023#[derive(Debug, Clone)]
1024pub struct OJWorklist {
1025 pub(super) items: std::collections::VecDeque<u32>,
1026 pub(super) in_worklist: std::collections::HashSet<u32>,
1027}
1028impl OJWorklist {
1029 #[allow(dead_code)]
1030 pub fn new() -> Self {
1031 OJWorklist {
1032 items: std::collections::VecDeque::new(),
1033 in_worklist: std::collections::HashSet::new(),
1034 }
1035 }
1036 #[allow(dead_code)]
1037 pub fn push(&mut self, item: u32) -> bool {
1038 if self.in_worklist.insert(item) {
1039 self.items.push_back(item);
1040 true
1041 } else {
1042 false
1043 }
1044 }
1045 #[allow(dead_code)]
1046 pub fn pop(&mut self) -> Option<u32> {
1047 let item = self.items.pop_front()?;
1048 self.in_worklist.remove(&item);
1049 Some(item)
1050 }
1051 #[allow(dead_code)]
1052 pub fn is_empty(&self) -> bool {
1053 self.items.is_empty()
1054 }
1055 #[allow(dead_code)]
1056 pub fn len(&self) -> usize {
1057 self.items.len()
1058 }
1059 #[allow(dead_code)]
1060 pub fn contains(&self, item: u32) -> bool {
1061 self.in_worklist.contains(&item)
1062 }
1063}
/// Timing and size measurements for one pass run.
#[derive(Debug, Clone)]
pub struct OJoinPassTiming {
    pub pass_name: String,
    pub elapsed_us: u64,
    pub items_processed: usize,
    pub bytes_before: usize,
    pub bytes_after: usize,
}
impl OJoinPassTiming {
    /// Builds a record from the pass name, elapsed microseconds, item count,
    /// and the sizes before and after the pass.
    pub fn new(
        pass_name: impl Into<String>,
        elapsed_us: u64,
        items: usize,
        before: usize,
        after: usize,
    ) -> Self {
        let pass_name: String = pass_name.into();
        Self {
            pass_name,
            elapsed_us,
            items_processed: items,
            bytes_before: before,
            bytes_after: after,
        }
    }
    /// Items processed per second; `0.0` when no time elapsed.
    pub fn throughput_mps(&self) -> f64 {
        match self.elapsed_us {
            0 => 0.0,
            micros => self.items_processed as f64 / (micros as f64 / 1_000_000.0),
        }
    }
    /// `bytes_after / bytes_before`; `1.0` when `bytes_before` is zero.
    pub fn size_ratio(&self) -> f64 {
        match self.bytes_before {
            0 => 1.0,
            before => self.bytes_after as f64 / before as f64,
        }
    }
    /// `true` when the size ratio is at most `1.05`.
    pub fn is_profitable(&self) -> bool {
        let ratio = self.size_ratio();
        ratio <= 1.05
    }
}
/// Counters describing one emission run.
#[derive(Debug, Clone, Default)]
pub struct OJoinEmitStats {
    pub bytes_emitted: usize,
    pub items_emitted: usize,
    pub errors: usize,
    pub warnings: usize,
    pub elapsed_ms: u64,
}
impl OJoinEmitStats {
    /// Creates a record with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }
    /// Bytes emitted per second; `0.0` when no time elapsed.
    pub fn throughput_bps(&self) -> f64 {
        match self.elapsed_ms {
            0 => 0.0,
            millis => self.bytes_emitted as f64 / (millis as f64 / 1000.0),
        }
    }
    /// `true` when no errors were counted; warnings do not matter.
    pub fn is_clean(&self) -> bool {
        matches!(self.errors, 0)
    }
}
1131#[derive(Debug, Clone, PartialEq, Eq)]
1133pub enum TailUse {
1134 Unused,
1136 TailOnly,
1138 NonTail,
1140 Mixed,
1142}
1143impl TailUse {
1144 pub(super) fn merge(&self, other: &TailUse) -> TailUse {
1145 match (self, other) {
1146 (TailUse::Unused, x) | (x, TailUse::Unused) => x.clone(),
1147 (TailUse::TailOnly, TailUse::TailOnly) => TailUse::TailOnly,
1148 (TailUse::NonTail, TailUse::NonTail) => TailUse::NonTail,
1149 _ => TailUse::Mixed,
1150 }
1151 }
1152}
/// Severity of a diagnostic.
///
/// The derived ordering follows declaration order: `Note < Warning < Error`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum OJoinDiagSeverity {
    /// Lowest severity.
    Note,
    /// Middle severity.
    Warning,
    /// Highest severity.
    Error,
}
/// Per-block liveness sets, each vector indexed by block number.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct OJLivenessInfo {
    pub live_in: Vec<std::collections::HashSet<u32>>,
    pub live_out: Vec<std::collections::HashSet<u32>>,
    pub defs: Vec<std::collections::HashSet<u32>>,
    pub uses: Vec<std::collections::HashSet<u32>>,
}
impl OJLivenessInfo {
    /// Creates empty sets for `block_count` blocks.
    #[allow(dead_code)]
    pub fn new(block_count: usize) -> Self {
        let blank: Vec<std::collections::HashSet<u32>> =
            vec![std::collections::HashSet::new(); block_count];
        Self {
            live_in: blank.clone(),
            live_out: blank.clone(),
            defs: blank.clone(),
            uses: blank,
        }
    }
    /// Records that `block` defines `var`; out-of-range blocks are ignored.
    #[allow(dead_code)]
    pub fn add_def(&mut self, block: usize, var: u32) {
        if let Some(set) = self.defs.get_mut(block) {
            set.insert(var);
        }
    }
    /// Records that `block` uses `var`; out-of-range blocks are ignored.
    #[allow(dead_code)]
    pub fn add_use(&mut self, block: usize, var: u32) {
        if let Some(set) = self.uses.get_mut(block) {
            set.insert(var);
        }
    }
    /// Reports whether `var` is in the live-in set of `block`; `false` if out of range.
    #[allow(dead_code)]
    pub fn is_live_in(&self, block: usize, var: u32) -> bool {
        match self.live_in.get(block) {
            Some(set) => set.contains(&var),
            None => false,
        }
    }
    /// Reports whether `var` is in the live-out set of `block`; `false` if out of range.
    #[allow(dead_code)]
    pub fn is_live_out(&self, block: usize, var: u32) -> bool {
        match self.live_out.get(block) {
            Some(set) => set.contains(&var),
            None => false,
        }
    }
}
/// One entry stored in `OJAnalysisCache`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct OJCacheEntry {
    /// Key the entry is stored under.
    pub key: String,
    /// Cached payload bytes.
    pub data: Vec<u8>,
    /// Timestamp slot; `OJAnalysisCache::insert` always writes `0`.
    pub timestamp: u64,
    /// Set to `true` on insert and cleared by `OJAnalysisCache::invalidate`.
    pub valid: bool,
}
1213#[derive(Debug)]
1215pub struct OJoinEventLog {
1216 pub(super) entries: std::collections::VecDeque<String>,
1217 pub(super) capacity: usize,
1218}
1219impl OJoinEventLog {
1220 pub fn new(capacity: usize) -> Self {
1221 OJoinEventLog {
1222 entries: std::collections::VecDeque::with_capacity(capacity),
1223 capacity,
1224 }
1225 }
1226 pub fn push(&mut self, event: impl Into<String>) {
1227 if self.entries.len() >= self.capacity {
1228 self.entries.pop_front();
1229 }
1230 self.entries.push_back(event.into());
1231 }
1232 pub fn iter(&self) -> impl Iterator<Item = &String> {
1233 self.entries.iter()
1234 }
1235 pub fn len(&self) -> usize {
1236 self.entries.len()
1237 }
1238 pub fn is_empty(&self) -> bool {
1239 self.entries.is_empty()
1240 }
1241 pub fn capacity(&self) -> usize {
1242 self.capacity
1243 }
1244 pub fn clear(&mut self) {
1245 self.entries.clear();
1246 }
1247}
/// Running totals for one optimization pass.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct OJPassStats {
    pub total_runs: u32,
    pub successful_runs: u32,
    pub total_changes: u64,
    pub time_ms: u64,
    pub iterations_used: u32,
}
impl OJPassStats {
    /// Creates statistics with every counter at zero.
    #[allow(dead_code)]
    pub fn new() -> Self {
        OJPassStats::default()
    }
    /// Records one run.
    ///
    /// Bumps both run counters, adds `changes` and `time_ms` to their totals,
    /// and overwrites `iterations_used` with `iterations`.
    #[allow(dead_code)]
    pub fn record_run(&mut self, changes: u64, time_ms: u64, iterations: u32) {
        self.total_runs += 1;
        self.successful_runs += 1;
        self.total_changes += changes;
        self.time_ms += time_ms;
        self.iterations_used = iterations;
    }
    /// Mean number of changes per recorded run; `0.0` before any run.
    #[allow(dead_code)]
    pub fn average_changes_per_run(&self) -> f64 {
        match self.total_runs {
            0 => 0.0,
            runs => self.total_changes as f64 / runs as f64,
        }
    }
    /// Ratio of successful runs to all runs; `0.0` before any run.
    #[allow(dead_code)]
    pub fn success_rate(&self) -> f64 {
        match self.total_runs {
            0 => 0.0,
            runs => self.successful_runs as f64 / runs as f64,
        }
    }
    /// One-line summary of runs, changes and accumulated time.
    #[allow(dead_code)]
    pub fn format_summary(&self) -> String {
        let Self {
            total_runs,
            successful_runs,
            total_changes,
            time_ms,
            ..
        } = self;
        format!(
            "Runs: {}/{}, Changes: {}, Time: {}ms",
            successful_runs, total_runs, total_changes, time_ms
        )
    }
}
1292#[allow(dead_code)]
1293#[derive(Debug, Clone)]
1294pub struct OJPassConfig {
1295 pub phase: OJPassPhase,
1296 pub enabled: bool,
1297 pub max_iterations: u32,
1298 pub debug_output: bool,
1299 pub pass_name: String,
1300}
1301impl OJPassConfig {
1302 #[allow(dead_code)]
1303 pub fn new(name: impl Into<String>, phase: OJPassPhase) -> Self {
1304 OJPassConfig {
1305 phase,
1306 enabled: true,
1307 max_iterations: 10,
1308 debug_output: false,
1309 pass_name: name.into(),
1310 }
1311 }
1312 #[allow(dead_code)]
1313 pub fn disabled(mut self) -> Self {
1314 self.enabled = false;
1315 self
1316 }
1317 #[allow(dead_code)]
1318 pub fn with_debug(mut self) -> Self {
1319 self.debug_output = true;
1320 self
1321 }
1322 #[allow(dead_code)]
1323 pub fn max_iter(mut self, n: u32) -> Self {
1324 self.max_iterations = n;
1325 self
1326 }
1327}
/// Phase of the pipeline a pass belongs to.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum OJPassPhase {
    Analysis,
    Transformation,
    Verification,
    Cleanup,
}
impl OJPassPhase {
    /// Lower-case name of the phase.
    #[allow(dead_code)]
    pub fn name(&self) -> &str {
        let label: &'static str = match *self {
            OJPassPhase::Cleanup => "cleanup",
            OJPassPhase::Verification => "verification",
            OJPassPhase::Transformation => "transformation",
            OJPassPhase::Analysis => "analysis",
        };
        label
    }
    /// `true` for the `Transformation` and `Cleanup` phases, `false` otherwise.
    #[allow(dead_code)]
    pub fn is_modifying(&self) -> bool {
        match self {
            OJPassPhase::Transformation | OJPassPhase::Cleanup => true,
            OJPassPhase::Analysis | OJPassPhase::Verification => false,
        }
    }
}