1#![allow(clippy::result_large_err)]
13
14use cljrs_types::error::CljxError::SerializationError;
15use cljrs_types::error::CljxResult;
16use cljrs_types::span::Span;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::fmt;
20use std::sync::Arc;
21
/// Identifier of an IR variable (an instruction result or a parameter binding).
///
/// Printed as `v<N>` by its `Display` impl; fresh ids are handed out by
/// `IrFunction::fresh_var`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct VarId(pub u32);
27
28impl fmt::Display for VarId {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 write!(f, "v{}", self.0)
31 }
32}
33
/// Identifier of a basic block within an `IrFunction`.
///
/// Printed as `bb<N>` by its `Display` impl; fresh ids are handed out by
/// `IrFunction::fresh_block`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct BlockId(pub u32);
37
38impl fmt::Display for BlockId {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 write!(f, "bb{}", self.0)
41 }
42}
43
/// Core-library functions the compiler recognizes by identity and calls via
/// `Inst::CallKnown` rather than a generic `Inst::Call`.
///
/// Variant names mirror the Clojure functions they stand for. The name-to-variant
/// mapping is not in this file (presumably `known.cljrs` — confirm). A numeric
/// suffix (`Reduce2`, `Range1`, `Partition3`, `Into3`, …) presumably selects a
/// specific call arity. Each variant's effect class is given by `KnownFn::effect`.
///
/// Variant order is part of the serialized (postcard) format: enum variants are
/// encoded by index, so append new variants at the end instead of inserting
/// them into an existing group.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum KnownFn {
    // Collection constructors.
    Vector,
    HashMap,
    HashSet,
    List,

    // Persistent collection updates and lookups.
    Assoc,
    Dissoc,
    Conj,
    Disj,
    Get,
    Nth,
    Count,
    Contains,

    // Transients.
    Transient,
    AssocBang,
    ConjBang,
    PersistentBang,

    // Sequence operations.
    First,
    Rest,
    Next,
    Cons,
    Seq,
    LazySeq,

    // Arithmetic.
    Add,
    Sub,
    Mul,
    Div,
    Rem,

    // Comparison.
    Eq,
    Lt,
    Gt,
    Lte,
    Gte,

    // Type predicates.
    IsNil,
    IsSeq,
    IsVector,
    IsMap,

    // String conversion.
    Str,

    // Dereference and identity comparison.
    Deref,
    Identical,

    // Printing.
    Println,
    Pr,

    // Atom operations.
    AtomDeref,
    AtomReset,
    AtomSwap,

    // Function application.
    Apply,

    // Higher-order sequence functions.
    Reduce2,
    Reduce3,
    Map,
    Filter,
    Mapv,
    Filterv,
    Some,
    Every,
    Into,
    Into3,

    // Further sequence functions and function combinators.
    GroupBy,
    Partition2,
    Partition3,
    Partition4,
    Frequencies,
    Keep,
    Remove,
    MapIndexed,
    Zipmap,
    Juxt,
    Comp,
    Partial,
    Complement,

    // Sequence construction, slicing and ordering.
    Concat,
    Range1,
    Range2,
    Range3,
    Take,
    Drop,
    Reverse,
    Sort,
    SortBy,

    // Map utilities.
    Keys,
    Vals,
    Merge,
    Update,
    GetIn,
    AssocIn,

    // More type predicates. These are kept apart from the group above,
    // presumably so that earlier serialized variant indices stay stable.
    IsNumber,
    IsString,
    IsKeyword,
    IsSymbol,
    IsBool,
    IsInt,

    // More printing.
    Prn,
    Print,

    // Atom construction.
    Atom,

    // Special forms / macros lowered to known calls (assumed from naming — confirm).
    TryCatchFinally,
    SetBangVar,
    WithBindings,
    WithOutStr,
}
193
/// Side-effect class of an instruction or of a known-function call.
///
/// Produced by `Inst::effect` and `KnownFn::effect`. Presumably consumed by the
/// optimization / escape-analysis passes to decide what may be removed or
/// reordered — TODO confirm against those passes. No `Ord` is derived, so
/// nothing here enforces an ordering between the classes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Effect {
    /// No observable effect (e.g. `Inst::Const`, `Inst::LoadLocal`, `Inst::Phi`).
    Pure,
    /// Allocates a new value (e.g. the `Inst::Alloc*` instructions).
    Alloc,
    /// Reads shared state (e.g. `Inst::LoadGlobal`, `Inst::Deref`).
    HeapRead,
    /// Writes shared state (e.g. `Inst::DefVar`, `Inst::SetBang`).
    HeapWrite,
    /// Performs I/O (the printing known functions).
    IO,
    /// Unknown effects: the class given to opaque calls such as `Inst::Call`.
    UnknownCall,
}
215
/// A literal constant, materialized by `Inst::Const`.
///
/// Derives `PartialEq` but not `Eq` because `Double` holds an `f64`.
/// Variant order is part of the serialized (postcard) format; append new
/// variants at the end.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Const {
    /// `nil`.
    Nil,
    /// A boolean literal.
    Bool(bool),
    /// A 64-bit integer literal.
    Long(i64),
    /// A 64-bit floating-point literal.
    Double(f64),
    /// A string literal.
    Str(Arc<str>),
    /// A keyword (NOTE(review): whether the leading `:` is stored is not visible here).
    Keyword(Arc<str>),
    /// A symbol literal.
    Symbol(Arc<str>),
    /// A character literal.
    Char(char),
}
231
/// A single IR instruction.
///
/// For every value-producing variant, the first `VarId` field is the
/// destination (see `Inst::dst`). The remaining `VarId`s are operands
/// (see `Inst::uses`).
///
/// Variant order is part of the serialized (postcard) format: variants are
/// encoded by index, so only append new variants at the end.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Inst {
    /// `dst = const c` — materializes a literal.
    Const(VarId, Const),

    /// `dst = load_local name` — reads a local binding by name.
    LoadLocal(VarId, Arc<str>),

    /// `dst = load_global ns/name`.
    ///
    /// NOTE(review): how this differs from `LoadVar` (the var's value vs. the
    /// var object itself?) is not visible in this file — confirm.
    LoadGlobal(VarId, Arc<str>, Arc<str>),
    /// `dst = load_var ns/name`.
    LoadVar(VarId, Arc<str>, Arc<str>),
    /// `dst = alloc_vec [elems…]` — allocates a vector from the listed elements.
    AllocVector(VarId, Vec<VarId>),

    /// `dst = alloc_map [(key, value)…]` — allocates a map from key/value pairs.
    AllocMap(VarId, Vec<(VarId, VarId)>),

    /// `dst = alloc_set [elems…]`.
    AllocSet(VarId, Vec<VarId>),

    /// `dst = alloc_list [elems…]`.
    AllocList(VarId, Vec<VarId>),

    /// `dst = cons head tail`.
    AllocCons(VarId, VarId, VarId),
    /// `dst = closure template captures…` — creates a closure from a static
    /// template and the captured variables.
    AllocClosure(VarId, ClosureTemplate, Vec<VarId>),

    /// `dst = call_known f args…` — calls a recognized core function.
    CallKnown(VarId, KnownFn, Vec<VarId>),

    /// `dst = call callee args…` — calls a function value held in `callee`.
    Call(VarId, VarId, Vec<VarId>),
    /// `dst = call_direct name args…` — calls a function by name (presumably a
    /// compiled arity such as those in `ClosureTemplate::arity_fn_names` — confirm).
    CallDirect(VarId, Arc<str>, Vec<VarId>),
    /// `dst = deref src`.
    Deref(VarId, VarId),

    /// `dst = def ns/name value` — defines a var.
    DefVar(VarId, Arc<str>, Arc<str>, VarId),
    /// `set! var value` — no destination.
    SetBang(VarId, VarId),
    /// `throw value` — no destination.
    Throw(VarId),

    /// `dst = phi [(pred_block, value)…]` — selects a value by predecessor block.
    Phi(VarId, Vec<(BlockId, VarId)>),

    /// `recur args…` — no destination. NOTE(review): coexists with
    /// `Terminator::RecurJump`; which one lowering emits is not visible here.
    Recur(Vec<VarId>),

    /// Source-position marker; has no destination or operands.
    SourceLoc(Span),

    /// `dst = region_start` — opens an allocation region. `dst` is the region
    /// handle later passed to `RegionAlloc` and `RegionEnd`.
    RegionStart(VarId),

    /// `dst = region_alloc region kind operands…` — allocates a collection of
    /// `kind` inside `region`.
    RegionAlloc(VarId, VarId, RegionAllocKind, Vec<VarId>),

    /// `region_end region` — closes a region opened by `RegionStart`.
    RegionEnd(VarId),
}
320
/// Kind of collection allocated by `Inst::RegionAlloc`.
///
/// Printed in lowercase by its `Display` impl. Variant order is part of the
/// serialized (postcard) format; append new variants at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RegionAllocKind {
    /// A vector.
    Vector,
    /// A map.
    Map,
    /// A set.
    Set,
    /// A list.
    List,
    /// A cons cell.
    Cons,
}
335
336impl fmt::Display for RegionAllocKind {
337 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
338 match self {
339 Self::Vector => write!(f, "vector"),
340 Self::Map => write!(f, "map"),
341 Self::Set => write!(f, "set"),
342 Self::List => write!(f, "list"),
343 Self::Cons => write!(f, "cons"),
344 }
345 }
346}
347
/// Static description of a closure, embedded in `Inst::AllocClosure`.
///
/// The per-arity vectors (`arity_fn_names`, `param_counts`, `is_variadic`)
/// appear to be parallel, one entry per arity. NOTE(review): this is not
/// enforced by the type — confirm with the lowering code. Field order is part
/// of the serialized (postcard) format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClosureTemplate {
    /// Closure name, if any (printed by `Inst`'s `Display`).
    pub name: Option<Arc<str>>,
    /// Name of the compiled function for each arity (e.g. `"inner__0"` in this
    /// file's tests).
    pub arity_fn_names: Vec<Arc<str>>,
    /// Parameter count for each arity. Presumably the fixed parameters only for
    /// variadic arities — confirm.
    pub param_counts: Vec<usize>,
    /// Whether each arity is variadic.
    pub is_variadic: Vec<bool>,
    /// Names of the captured variables. Presumably parallel to the capture
    /// operands of `Inst::AllocClosure` — confirm.
    pub capture_names: Vec<Arc<str>>,
}
364
/// How control leaves a basic block.
///
/// Variant order is part of the serialized (postcard) format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Terminator {
    /// Unconditional jump to the given block.
    Jump(BlockId),

    /// Conditional branch on `cond` to `then_block` or `else_block`.
    ///
    /// NOTE(review): which values count as true (Clojure truthiness?) is decided
    /// by the consumer of the IR — confirm.
    Branch {
        cond: VarId,
        then_block: BlockId,
        else_block: BlockId,
    },

    /// Returns the value held in the given variable.
    Return(VarId),

    /// Jumps back to `target` for a `recur`, passing `args` (presumably bound by
    /// `target`'s phi nodes — confirm).
    RecurJump { target: BlockId, args: Vec<VarId> },

    /// Marks a block end that control never reaches.
    Unreachable,
}
389
/// A basic block: phi nodes, a body of instructions, and one terminator.
///
/// Field order is part of the serialized (postcard) format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Block {
    /// This block's id.
    pub id: BlockId,
    /// Phi instructions, kept separate from `insts` and printed before them.
    /// Presumably only `Inst::Phi`, although the type does not enforce it.
    pub phis: Vec<Inst>,
    /// Body instructions, in order.
    pub insts: Vec<Inst>,
    /// Control transfer at the end of the block.
    pub terminator: Terminator,
}
403
/// A function lowered to IR: parameters plus a list of basic blocks.
///
/// Field order is part of the serialized (postcard) format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IrFunction {
    /// Function name; `None` is printed as `<anon>`.
    pub name: Option<Arc<str>>,
    /// Parameter names paired with the variable each one is bound to.
    pub params: Vec<(Arc<str>, VarId)>,
    /// Basic blocks. `block_index` returns `None` when every block's id equals
    /// its position in this vector.
    pub blocks: Vec<Block>,
    /// Next unused variable number; advanced by `fresh_var`.
    pub next_var: u32,
    /// Next unused block number; advanced by `fresh_block`.
    pub next_block: u32,
    /// Source span of the function, if known.
    pub span: Option<Span>,
    /// Nested functions serialized along with this one (presumably the bodies
    /// of closures created inside it — confirm).
    pub subfunctions: Vec<IrFunction>,
}
422
423impl IrFunction {
424 pub fn new(name: Option<Arc<str>>, span: Option<Span>) -> Self {
426 Self {
427 name,
428 params: Vec::new(),
429 blocks: Vec::new(),
430 next_var: 0,
431 next_block: 0,
432 span,
433 subfunctions: Vec::new(),
434 }
435 }
436
437 pub fn fresh_var(&mut self) -> VarId {
439 let id = VarId(self.next_var);
440 self.next_var += 1;
441 id
442 }
443
444 pub fn fresh_block(&mut self) -> BlockId {
446 let id = BlockId(self.next_block);
447 self.next_block += 1;
448 id
449 }
450
451 pub fn block_index(&self) -> Option<Vec<usize>> {
457 let is_identity = self
459 .blocks
460 .iter()
461 .enumerate()
462 .all(|(i, b)| b.id.0 as usize == i);
463 if is_identity {
464 return None; }
466 let max_id = self.blocks.iter().map(|b| b.id.0).max().unwrap_or(0);
468 let mut table = vec![0usize; max_id as usize + 1];
469 for (i, b) in self.blocks.iter().enumerate() {
470 table[b.id.0 as usize] = i;
471 }
472 Some(table)
473 }
474
475 pub fn serialize(&self) -> CljxResult<Vec<u8>> {
476 postcard::to_allocvec(self).map_err(|e| SerializationError {
477 message: e.to_string(),
478 })
479 }
480
481 pub fn deserialize(bytes: &[u8]) -> CljxResult<Self> {
482 postcard::from_bytes(bytes).map_err(|e| SerializationError {
483 message: e.to_string(),
484 })
485 }
486}
487
/// A collection of IR functions keyed by string, encoded as a unit by
/// `serialize_bundle` / `deserialize_bundle`.
///
/// Keys in this file's tests have the form `"ns/name:arity"` (e.g.
/// `"clojure.core/inc:1"`). NOTE(review): confirm this is the convention
/// callers rely on.
#[derive(Debug, Serialize, Deserialize)]
pub struct IrBundle {
    /// Functions by key. `HashMap` iteration order is unspecified (and differs
    /// between map instances), so serializing this field directly is not
    /// byte-stable.
    pub functions: HashMap<String, IrFunction>,
}
501
502impl IrBundle {
503 pub fn new() -> Self {
504 Self {
505 functions: HashMap::new(),
506 }
507 }
508
509 pub fn insert(&mut self, key: String, func: IrFunction) {
511 self.functions.insert(key, func);
512 }
513
514 pub fn get(&self, key: &str) -> Option<&IrFunction> {
516 self.functions.get(key)
517 }
518
519 pub fn len(&self) -> usize {
521 self.functions.len()
522 }
523
524 pub fn is_empty(&self) -> bool {
526 self.functions.is_empty()
527 }
528}
529
530impl Default for IrBundle {
531 fn default() -> Self {
532 Self::new()
533 }
534}
535
536pub fn serialize_bundle(bundle: &IrBundle) -> CljxResult<Vec<u8>> {
538 postcard::to_allocvec(bundle).map_err(|e| SerializationError {
539 message: e.to_string(),
540 })
541}
542
543pub fn deserialize_bundle(bytes: &[u8]) -> CljxResult<IrBundle> {
545 postcard::from_bytes(bytes).map_err(|e| SerializationError {
546 message: e.to_string(),
547 })
548}
549
550impl Inst {
553 pub fn effect(&self) -> Effect {
555 match self {
556 Inst::Const(..) | Inst::LoadLocal(..) | Inst::Phi(..) | Inst::SourceLoc(..) => {
557 Effect::Pure
558 }
559 Inst::LoadGlobal(..) | Inst::LoadVar(..) => Effect::HeapRead,
560 Inst::AllocVector(..)
561 | Inst::AllocMap(..)
562 | Inst::AllocSet(..)
563 | Inst::AllocList(..)
564 | Inst::AllocCons(..)
565 | Inst::AllocClosure(..) => Effect::Alloc,
566 Inst::CallKnown(_, known, _) => known.effect(),
567 Inst::Call(..) | Inst::CallDirect(..) => Effect::UnknownCall,
568 Inst::Deref(..) => Effect::HeapRead,
569 Inst::DefVar(..) => Effect::HeapWrite,
570 Inst::SetBang(..) => Effect::HeapWrite,
571 Inst::Throw(..) => Effect::UnknownCall, Inst::Recur(..) => Effect::Pure,
573 Inst::RegionStart(..) | Inst::RegionEnd(..) => Effect::Alloc,
574 Inst::RegionAlloc(..) => Effect::Alloc,
575 }
576 }
577
578 pub fn dst(&self) -> Option<VarId> {
580 match self {
581 Inst::Const(v, _)
582 | Inst::LoadLocal(v, _)
583 | Inst::LoadGlobal(v, _, _)
584 | Inst::LoadVar(v, _, _)
585 | Inst::AllocVector(v, _)
586 | Inst::AllocMap(v, _)
587 | Inst::AllocSet(v, _)
588 | Inst::AllocList(v, _)
589 | Inst::AllocCons(v, _, _)
590 | Inst::AllocClosure(v, _, _)
591 | Inst::CallKnown(v, _, _)
592 | Inst::Call(v, _, _)
593 | Inst::CallDirect(v, _, _)
594 | Inst::Deref(v, _)
595 | Inst::DefVar(v, _, _, _)
596 | Inst::Phi(v, _)
597 | Inst::RegionStart(v)
598 | Inst::RegionAlloc(v, _, _, _) => Some(*v),
599 Inst::SetBang(..)
600 | Inst::Throw(..)
601 | Inst::Recur(..)
602 | Inst::SourceLoc(..)
603 | Inst::RegionEnd(..) => None,
604 }
605 }
606
607 pub fn uses(&self) -> Vec<VarId> {
609 match self {
610 Inst::Const(..)
611 | Inst::LoadLocal(..)
612 | Inst::LoadGlobal(..)
613 | Inst::LoadVar(..)
614 | Inst::SourceLoc(..) => vec![],
615 Inst::AllocVector(_, elems) | Inst::AllocSet(_, elems) | Inst::AllocList(_, elems) => {
616 elems.clone()
617 }
618 Inst::AllocMap(_, pairs) => pairs.iter().flat_map(|(k, v)| [*k, *v]).collect(),
619 Inst::AllocCons(_, h, t) => vec![*h, *t],
620 Inst::AllocClosure(_, _, captures) => captures.clone(),
621 Inst::CallKnown(_, _, args) => args.clone(),
622 Inst::Call(_, callee, args) => {
623 let mut v = vec![*callee];
624 v.extend(args);
625 v
626 }
627 Inst::CallDirect(_, _, args) => args.clone(),
628 Inst::Deref(_, src) => vec![*src],
629 Inst::DefVar(_, _, _, val) => vec![*val],
630 Inst::SetBang(var, val) => vec![*var, *val],
631 Inst::Throw(val) => vec![*val],
632 Inst::Phi(_, entries) => entries.iter().map(|(_, v)| *v).collect(),
633 Inst::Recur(args) => args.clone(),
634 Inst::RegionStart(..) => vec![],
635 Inst::RegionAlloc(_, region, _, operands) => {
636 let mut v = vec![*region];
637 v.extend(operands);
638 v
639 }
640 Inst::RegionEnd(region) => vec![*region],
641 }
642 }
643}
644
645impl KnownFn {
646 pub fn effect(&self) -> Effect {
648 match self {
649 KnownFn::Get
651 | KnownFn::Nth
652 | KnownFn::Count
653 | KnownFn::Contains
654 | KnownFn::First
655 | KnownFn::Add
656 | KnownFn::Sub
657 | KnownFn::Mul
658 | KnownFn::Div
659 | KnownFn::Rem
660 | KnownFn::Eq
661 | KnownFn::Lt
662 | KnownFn::Gt
663 | KnownFn::Lte
664 | KnownFn::Gte
665 | KnownFn::IsNil
666 | KnownFn::IsSeq
667 | KnownFn::IsVector
668 | KnownFn::IsMap
669 | KnownFn::Identical => Effect::Pure,
670
671 KnownFn::Vector
673 | KnownFn::HashMap
674 | KnownFn::HashSet
675 | KnownFn::List
676 | KnownFn::Assoc
677 | KnownFn::Dissoc
678 | KnownFn::Conj
679 | KnownFn::Disj
680 | KnownFn::Cons
681 | KnownFn::Rest
682 | KnownFn::Next
683 | KnownFn::Seq
684 | KnownFn::LazySeq
685 | KnownFn::Str
686 | KnownFn::Transient
687 | KnownFn::PersistentBang => Effect::Alloc,
688
689 KnownFn::AssocBang | KnownFn::ConjBang => Effect::HeapWrite,
691
692 KnownFn::Deref | KnownFn::AtomDeref => Effect::HeapRead,
694
695 KnownFn::AtomReset | KnownFn::AtomSwap => Effect::HeapWrite,
697
698 KnownFn::Println | KnownFn::Pr => Effect::IO,
700
701 KnownFn::Apply => Effect::UnknownCall,
703
704 KnownFn::Concat
706 | KnownFn::Range1
707 | KnownFn::Range2
708 | KnownFn::Range3
709 | KnownFn::Take
710 | KnownFn::Drop
711 | KnownFn::Reverse => Effect::Alloc,
712
713 KnownFn::Sort | KnownFn::SortBy => Effect::UnknownCall,
715
716 KnownFn::Keys | KnownFn::Vals => Effect::Alloc,
718 KnownFn::Merge | KnownFn::Update | KnownFn::GetIn | KnownFn::AssocIn => Effect::Alloc,
719
720 KnownFn::IsNumber
722 | KnownFn::IsString
723 | KnownFn::IsKeyword
724 | KnownFn::IsSymbol
725 | KnownFn::IsBool
726 | KnownFn::IsInt => Effect::Pure,
727
728 KnownFn::Prn | KnownFn::Print => Effect::IO,
730
731 KnownFn::Atom => Effect::Alloc,
733
734 KnownFn::GroupBy
736 | KnownFn::Partition2
737 | KnownFn::Partition3
738 | KnownFn::Partition4
739 | KnownFn::Keep
740 | KnownFn::Remove
741 | KnownFn::MapIndexed => Effect::UnknownCall,
742
743 KnownFn::Juxt | KnownFn::Comp | KnownFn::Partial | KnownFn::Complement => {
745 Effect::UnknownCall
746 }
747
748 KnownFn::Frequencies | KnownFn::Zipmap => Effect::Alloc,
750
751 KnownFn::Reduce2
753 | KnownFn::Reduce3
754 | KnownFn::Map
755 | KnownFn::Filter
756 | KnownFn::Mapv
757 | KnownFn::Filterv
758 | KnownFn::Some
759 | KnownFn::Every
760 | KnownFn::Into
761 | KnownFn::Into3 => Effect::UnknownCall,
762
763 KnownFn::TryCatchFinally => Effect::UnknownCall,
765
766 KnownFn::SetBangVar => Effect::HeapWrite,
768 KnownFn::WithBindings | KnownFn::WithOutStr => Effect::UnknownCall,
769 }
770 }
771}
772
773impl fmt::Display for IrFunction {
776 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
777 writeln!(
778 f,
779 "fn {}({}):",
780 self.name.as_deref().unwrap_or("<anon>"),
781 self.params
782 .iter()
783 .map(|(name, id)| format!("{name}: {id}"))
784 .collect::<Vec<_>>()
785 .join(", ")
786 )?;
787 for block in &self.blocks {
788 writeln!(f, " {block}")?;
789 }
790 Ok(())
791 }
792}
793
794impl fmt::Display for Block {
795 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
796 writeln!(f, "{}:", self.id)?;
797 for phi in &self.phis {
798 writeln!(f, " {phi}")?;
799 }
800 for inst in &self.insts {
801 writeln!(f, " {inst}")?;
802 }
803 write!(f, " {}", self.terminator)
804 }
805}
806
807impl fmt::Display for Inst {
808 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
809 match self {
810 Inst::Const(dst, c) => write!(f, "{dst} = const {c:?}"),
811 Inst::LoadLocal(dst, name) => write!(f, "{dst} = load_local {name:?}"),
812 Inst::LoadGlobal(dst, ns, name) => write!(f, "{dst} = load_global {ns}/{name}"),
813 Inst::LoadVar(dst, ns, name) => write!(f, "{dst} = load_var {ns}/{name}"),
814 Inst::AllocVector(dst, elems) => write!(f, "{dst} = alloc_vec {elems:?}"),
815 Inst::AllocMap(dst, pairs) => write!(f, "{dst} = alloc_map {pairs:?}"),
816 Inst::AllocSet(dst, elems) => write!(f, "{dst} = alloc_set {elems:?}"),
817 Inst::AllocList(dst, elems) => write!(f, "{dst} = alloc_list {elems:?}"),
818 Inst::AllocCons(dst, h, t) => write!(f, "{dst} = cons {h} {t}"),
819 Inst::AllocClosure(dst, tmpl, captures) => {
820 write!(f, "{dst} = closure {:?} captures={captures:?}", tmpl.name)
821 }
822 Inst::CallKnown(dst, func, args) => write!(f, "{dst} = call_known {func:?} {args:?}"),
823 Inst::Call(dst, callee, args) => write!(f, "{dst} = call {callee} {args:?}"),
824 Inst::CallDirect(dst, name, args) => write!(f, "{dst} = call_direct {name} {args:?}"),
825 Inst::Deref(dst, src) => write!(f, "{dst} = deref {src}"),
826 Inst::DefVar(dst, ns, name, val) => write!(f, "{dst} = def {ns}/{name} {val}"),
827 Inst::SetBang(var, val) => write!(f, "set! {var} {val}"),
828 Inst::Throw(val) => write!(f, "throw {val}"),
829 Inst::Phi(dst, entries) => write!(f, "{dst} = phi {entries:?}"),
830 Inst::Recur(args) => write!(f, "recur {args:?}"),
831 Inst::SourceLoc(span) => write!(f, "# {}:{}:{}", span.file, span.line, span.col),
832 Inst::RegionStart(dst) => write!(f, "{dst} = region_start"),
833 Inst::RegionAlloc(dst, region, kind, operands) => {
834 write!(f, "{dst} = region_alloc {region} {kind} {operands:?}")
835 }
836 Inst::RegionEnd(region) => write!(f, "region_end {region}"),
837 }
838 }
839}
840
841impl fmt::Display for Terminator {
842 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
843 match self {
844 Terminator::Jump(target) => write!(f, "jump {target}"),
845 Terminator::Branch {
846 cond,
847 then_block,
848 else_block,
849 } => write!(f, "branch {cond} then={then_block} else={else_block}"),
850 Terminator::Return(val) => write!(f, "return {val}"),
851 Terminator::RecurJump { target, args } => {
852 write!(f, "recur_jump {target} {args:?}")
853 }
854 Terminator::Unreachable => write!(f, "unreachable"),
855 }
856 }
857}
858
/// Text of `clojure/compiler/ir.cljrs`, embedded at build time with `include_str!`.
pub const COMPILER_IR_SOURCE: &str = include_str!("clojure/compiler/ir.cljrs");

/// Text of `clojure/compiler/known.cljrs`, embedded at build time (presumably
/// the Clojure-side counterpart of [`KnownFn`] — confirm).
pub const COMPILER_KNOWN_SOURCE: &str = include_str!("clojure/compiler/known.cljrs");

/// Text of `clojure/compiler/anf.cljrs`, embedded at build time (presumably
/// the ANF-conversion pass — confirm).
pub const COMPILER_ANF_SOURCE: &str = include_str!("clojure/compiler/anf.cljrs");

/// Text of `clojure/compiler/escape.cljrs`, embedded at build time (presumably
/// the escape-analysis pass behind the `Region*` instructions — confirm).
pub const COMPILER_ESCAPE_SOURCE: &str = include_str!("clojure/compiler/escape.cljrs");

/// Text of `clojure/compiler/optimize.cljrs`, embedded at build time.
pub const COMPILER_OPTIMIZE_SOURCE: &str = include_str!("clojure/compiler/optimize.cljrs");
875
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-block function named `name` that returns `const_val`.
    fn make_test_fn(name: &str, const_val: i64) -> IrFunction {
        let mut f = IrFunction::new(Some(Arc::from(name)), None);
        let dst = f.fresh_var();
        let block_id = f.fresh_block();
        f.blocks.push(Block {
            id: block_id,
            phis: vec![],
            insts: vec![Inst::Const(dst, Const::Long(const_val))],
            terminator: Terminator::Return(dst),
        });
        f
    }

    #[test]
    fn test_ir_function_serialize_roundtrip() {
        let f = make_test_fn("identity", 42);
        let bytes = f.serialize().unwrap();
        let f2 = IrFunction::deserialize(&bytes).unwrap();
        assert_eq!(f2.name.as_deref(), Some("identity"));
        assert_eq!(f2.blocks.len(), 1);
        assert_eq!(f2.next_var, 1);
        assert_eq!(f2.next_block, 1);
        match &f2.blocks[0].insts[0] {
            Inst::Const(_, Const::Long(v)) => assert_eq!(*v, 42),
            other => panic!("expected Const(Long(42)), got {other:?}"),
        }
    }

    #[test]
    fn test_ir_function_with_closure_template() {
        let mut f = IrFunction::new(Some(Arc::from("outer")), None);
        let dst = f.fresh_var();
        let capture = f.fresh_var();
        let block_id = f.fresh_block();
        f.blocks.push(Block {
            id: block_id,
            phis: vec![],
            insts: vec![
                Inst::Const(capture, Const::Str(Arc::from("hello"))),
                Inst::AllocClosure(
                    dst,
                    ClosureTemplate {
                        name: Some(Arc::from("inner")),
                        arity_fn_names: vec![Arc::from("inner__0")],
                        param_counts: vec![1],
                        is_variadic: vec![false],
                        capture_names: vec![Arc::from("x")],
                    },
                    vec![capture],
                ),
            ],
            terminator: Terminator::Return(dst),
        });

        let bytes = f.serialize().unwrap();
        let f2 = IrFunction::deserialize(&bytes).unwrap();
        match &f2.blocks[0].insts[1] {
            Inst::AllocClosure(_, tmpl, captures) => {
                assert_eq!(tmpl.name.as_deref(), Some("inner"));
                assert_eq!(tmpl.arity_fn_names.len(), 1);
                assert_eq!(&*tmpl.arity_fn_names[0], "inner__0");
                assert_eq!(tmpl.param_counts, vec![1]);
                assert_eq!(tmpl.is_variadic, vec![false]);
                assert_eq!(tmpl.capture_names.len(), 1);
                assert_eq!(&*tmpl.capture_names[0], "x");
                assert_eq!(captures.as_slice(), &[capture]);
            }
            other => panic!("expected AllocClosure, got {other:?}"),
        }
    }

    #[test]
    fn test_empty_bundle_roundtrip() {
        let bundle = IrBundle::new();
        assert!(bundle.is_empty());
        let bytes = serialize_bundle(&bundle).unwrap();
        let bundle2 = deserialize_bundle(&bytes).unwrap();
        assert!(bundle2.is_empty());
        assert_eq!(bundle2.len(), 0);
    }

    #[test]
    fn test_bundle_single_function() {
        let mut bundle = IrBundle::new();
        bundle.insert("clojure.core/inc:1".to_string(), make_test_fn("inc", 1));
        assert_eq!(bundle.len(), 1);

        let bytes = serialize_bundle(&bundle).unwrap();
        let bundle2 = deserialize_bundle(&bytes).unwrap();
        assert_eq!(bundle2.len(), 1);

        let f = bundle2.get("clojure.core/inc:1").unwrap();
        assert_eq!(f.name.as_deref(), Some("inc"));
    }

    #[test]
    fn test_bundle_multiple_functions() {
        let mut bundle = IrBundle::new();
        bundle.insert("clojure.core/inc:1".to_string(), make_test_fn("inc", 1));
        bundle.insert("clojure.core/dec:1".to_string(), make_test_fn("dec", -1));
        bundle.insert(
            "clojure.core/identity:1".to_string(),
            make_test_fn("identity", 0),
        );
        assert_eq!(bundle.len(), 3);

        let bytes = serialize_bundle(&bundle).unwrap();
        let bundle2 = deserialize_bundle(&bytes).unwrap();
        assert_eq!(bundle2.len(), 3);

        assert_eq!(
            bundle2.get("clojure.core/inc:1").unwrap().name.as_deref(),
            Some("inc")
        );
        assert_eq!(
            bundle2.get("clojure.core/dec:1").unwrap().name.as_deref(),
            Some("dec")
        );
        assert_eq!(
            bundle2
                .get("clojure.core/identity:1")
                .unwrap()
                .name
                .as_deref(),
            Some("identity")
        );
        assert!(bundle2.get("nonexistent").is_none());
    }

    #[test]
    fn test_bundle_with_complex_ir() {
        let mut f = IrFunction::new(Some(Arc::from("complex")), None);
        let p0 = f.fresh_var();
        let p1 = f.fresh_var();
        f.params = vec![(Arc::from("x"), p0), (Arc::from("y"), p1)];

        // Diamond CFG: entry branches to then/else, both jump to join.
        let entry = f.fresh_block();
        let then_bb = f.fresh_block();
        let else_bb = f.fresh_block();
        let join_bb = f.fresh_block();

        let cond_dst = f.fresh_var();
        f.blocks.push(Block {
            id: entry,
            phis: vec![],
            insts: vec![Inst::CallKnown(cond_dst, KnownFn::IsNil, vec![p0])],
            terminator: Terminator::Branch {
                cond: cond_dst,
                then_block: then_bb,
                else_block: else_bb,
            },
        });

        f.blocks.push(Block {
            id: then_bb,
            phis: vec![],
            insts: vec![],
            terminator: Terminator::Jump(join_bb),
        });

        f.blocks.push(Block {
            id: else_bb,
            phis: vec![],
            insts: vec![],
            terminator: Terminator::Jump(join_bb),
        });

        let phi_dst = f.fresh_var();
        f.blocks.push(Block {
            id: join_bb,
            phis: vec![Inst::Phi(phi_dst, vec![(then_bb, p1), (else_bb, p0)])],
            insts: vec![],
            terminator: Terminator::Return(phi_dst),
        });

        let mut bundle = IrBundle::new();
        bundle.insert("test/complex:2".to_string(), f);

        let bytes = serialize_bundle(&bundle).unwrap();
        let bundle2 = deserialize_bundle(&bytes).unwrap();

        let f2 = bundle2.get("test/complex:2").unwrap();
        assert_eq!(f2.params.len(), 2);
        assert_eq!(f2.blocks.len(), 4);

        match &f2.blocks[0].terminator {
            Terminator::Branch {
                cond,
                then_block,
                else_block,
            } => {
                assert_eq!(*cond, cond_dst);
                assert_eq!(*then_block, then_bb);
                assert_eq!(*else_block, else_bb);
            }
            other => panic!("expected Branch, got {other:?}"),
        }

        assert_eq!(f2.blocks[3].phis.len(), 1);
        match &f2.blocks[3].phis[0] {
            Inst::Phi(dst, entries) => {
                assert_eq!(*dst, phi_dst);
                assert_eq!(entries.len(), 2);
            }
            other => panic!("expected Phi, got {other:?}"),
        }
    }

    #[test]
    fn test_bundle_with_subfunctions() {
        let mut outer = make_test_fn("outer", 100);
        let inner = make_test_fn("inner", 200);
        outer.subfunctions.push(inner);

        let mut bundle = IrBundle::new();
        bundle.insert("test/outer:0".to_string(), outer);

        let bytes = serialize_bundle(&bundle).unwrap();
        let bundle2 = deserialize_bundle(&bytes).unwrap();

        let f = bundle2.get("test/outer:0").unwrap();
        assert_eq!(f.subfunctions.len(), 1);
        assert_eq!(f.subfunctions[0].name.as_deref(), Some("inner"));
    }

    #[test]
    fn test_deserialize_invalid_bytes() {
        let result = IrFunction::deserialize(&[0xFF, 0xFE, 0xFD]);
        assert!(result.is_err());

        let result = deserialize_bundle(&[0xFF, 0xFE, 0xFD]);
        assert!(result.is_err());
    }

    #[test]
    fn test_block_index_identity_is_none() {
        let f = make_test_fn("id", 0);
        assert!(f.block_index().is_none());

        let empty = IrFunction::new(None, None);
        assert!(empty.block_index().is_none());
    }

    #[test]
    fn test_block_index_non_identity() {
        let mut f = IrFunction::new(None, None);
        let v = f.fresh_var();
        for id in [BlockId(2), BlockId(0)] {
            f.blocks.push(Block {
                id,
                phis: vec![],
                insts: vec![],
                terminator: Terminator::Return(v),
            });
        }
        let table = f.block_index().expect("ids are not equal to positions");
        assert_eq!(table.len(), 3);
        assert_eq!(table[2], 0);
        assert_eq!(table[0], 1);
    }

    #[test]
    fn test_inst_dst_uses_and_effect() {
        let (a, b, c, d) = (VarId(0), VarId(1), VarId(2), VarId(3));

        let call = Inst::Call(d, a, vec![b, c]);
        assert_eq!(call.dst(), Some(d));
        assert_eq!(call.uses(), vec![a, b, c]);
        assert_eq!(call.effect(), Effect::UnknownCall);

        let map = Inst::AllocMap(d, vec![(a, b), (c, a)]);
        assert_eq!(map.uses(), vec![a, b, c, a]);
        assert_eq!(map.effect(), Effect::Alloc);

        let set = Inst::SetBang(a, b);
        assert_eq!(set.dst(), None);
        assert_eq!(set.uses(), vec![a, b]);
        assert_eq!(set.effect(), Effect::HeapWrite);

        let phi = Inst::Phi(d, vec![(BlockId(0), a), (BlockId(1), c)]);
        assert_eq!(phi.uses(), vec![a, c]);
        assert_eq!(phi.effect(), Effect::Pure);

        let load = Inst::LoadGlobal(a, Arc::from("ns"), Arc::from("x"));
        assert!(load.uses().is_empty());
        assert_eq!(load.effect(), Effect::HeapRead);

        let nil_check = Inst::CallKnown(d, KnownFn::IsNil, vec![a]);
        assert_eq!(nil_check.effect(), Effect::Pure);
    }

    #[test]
    fn test_display_ir_function() {
        let text = make_test_fn("k", 7).to_string();
        assert!(text.starts_with("fn k():"));
        assert!(text.contains("bb0:"));
        assert!(text.contains("v0 = const Long(7)"));
        assert!(text.contains("return v0"));
    }
}