1use super::{CodeGen, CodeGenError, mangle_name};
92use crate::ast::{Statement, WordDef};
93use crate::types::{StackType, Type};
94use std::fmt::Write as _;
95
96#[derive(Debug, Clone, Copy, PartialEq, Eq)]
98pub enum RegisterType {
99 I64,
101 Double,
103}
104
105impl RegisterType {
106 pub fn from_type(ty: &Type) -> Option<Self> {
108 match ty {
109 Type::Int | Type::Bool => Some(RegisterType::I64),
110 Type::Float => Some(RegisterType::Double),
111 _ => None,
112 }
113 }
114
115 pub fn llvm_type(&self) -> &'static str {
117 match self {
118 RegisterType::I64 => "i64",
119 RegisterType::Double => "double",
120 }
121 }
122}
123
124#[derive(Debug, Clone)]
126pub struct SpecSignature {
127 pub inputs: Vec<RegisterType>,
129 pub outputs: Vec<RegisterType>,
131}
132
133impl SpecSignature {
134 pub fn suffix(&self) -> String {
138 if self.inputs.len() == 1 && self.outputs.len() == 1 {
139 match (self.inputs[0], self.outputs[0]) {
140 (RegisterType::I64, RegisterType::I64) => "_i64".to_string(),
141 (RegisterType::Double, RegisterType::Double) => "_f64".to_string(),
142 (RegisterType::I64, RegisterType::Double) => "_i64_to_f64".to_string(),
143 (RegisterType::Double, RegisterType::I64) => "_f64_to_i64".to_string(),
144 }
145 } else {
146 let mut suffix = String::new();
148 for ty in &self.inputs {
149 suffix.push('_');
150 suffix.push_str(match ty {
151 RegisterType::I64 => "i",
152 RegisterType::Double => "f",
153 });
154 }
155 suffix.push_str("_to");
156 for ty in &self.outputs {
157 suffix.push('_');
158 suffix.push_str(match ty {
159 RegisterType::I64 => "i",
160 RegisterType::Double => "f",
161 });
162 }
163 suffix
164 }
165 }
166
167 pub fn is_direct_call(&self) -> bool {
169 self.outputs.len() == 1
170 }
171
172 pub fn llvm_return_type(&self) -> String {
177 if self.outputs.len() == 1 {
178 self.outputs[0].llvm_type().to_string()
179 } else {
180 let types: Vec<_> = self.outputs.iter().map(|t| t.llvm_type()).collect();
181 format!("{{ {} }}", types.join(", "))
182 }
183 }
184}
185
186#[derive(Debug, Clone)]
191pub struct RegisterContext {
192 pub values: Vec<(String, RegisterType)>,
194}
195
196impl RegisterContext {
197 pub fn new() -> Self {
199 Self { values: Vec::new() }
200 }
201
202 pub fn from_params(params: &[(String, RegisterType)]) -> Self {
204 Self {
205 values: params.to_vec(),
206 }
207 }
208
209 pub fn push(&mut self, ssa_var: String, ty: RegisterType) {
211 self.values.push((ssa_var, ty));
212 }
213
214 pub fn pop(&mut self) -> Option<(String, RegisterType)> {
216 self.values.pop()
217 }
218
219 #[allow(dead_code)]
221 pub fn peek(&self) -> Option<&(String, RegisterType)> {
222 self.values.last()
223 }
224
225 #[cfg_attr(not(test), allow(dead_code))]
227 pub fn len(&self) -> usize {
228 self.values.len()
229 }
230
231 #[allow(dead_code)]
233 pub fn is_empty(&self) -> bool {
234 self.values.is_empty()
235 }
236
237 pub fn dup(&mut self) {
241 if let Some((ssa, ty)) = self.values.last().cloned() {
242 self.values.push((ssa, ty));
243 }
244 }
245
246 pub fn drop(&mut self) {
248 self.values.pop();
249 }
250
251 pub fn swap(&mut self) {
253 let len = self.values.len();
254 if len >= 2 {
255 self.values.swap(len - 1, len - 2);
256 }
257 }
258
259 pub fn over(&mut self) {
261 let len = self.values.len();
262 if len >= 2 {
263 let a = self.values[len - 2].clone();
264 self.values.push(a);
265 }
266 }
267
268 pub fn rot(&mut self) {
270 let len = self.values.len();
271 if len >= 3 {
272 let a = self.values.remove(len - 3);
273 self.values.push(a);
274 }
275 }
276}
277
278impl Default for RegisterContext {
279 fn default() -> Self {
280 Self::new()
281 }
282}
283
/// Built-in word names that `is_statement_specializable` accepts inside a
/// specialized word body (membership is checked with `contains`).
#[rustfmt::skip]
const SPECIALIZABLE_OPS: &[&str] = &[
    // Integer arithmetic
    "i.+", "i.add", "i.-", "i.subtract", "i.*", "i.multiply",
    "i./", "i.divide", "i.%", "i.mod",
    // Bitwise operations and shifts
    "band", "bor", "bxor", "bnot", "shl", "shr",
    // Bit counting
    "popcount", "clz", "ctz",
    // Numeric conversions
    "int->float", "float->int",
    // Boolean logic
    "and", "or", "not",
    // Integer comparisons
    "i.<", "i.lt", "i.>", "i.gt", "i.<=", "i.lte",
    "i.>=", "i.gte", "i.=", "i.eq", "i.<>", "i.neq",
    // Float arithmetic
    "f.+", "f.add", "f.-", "f.subtract", "f.*", "f.multiply", "f./", "f.divide",
    // Float comparisons
    "f.<", "f.lt", "f.>", "f.gt", "f.<=", "f.lte",
    "f.>=", "f.gte", "f.=", "f.eq", "f.<>", "f.neq",
    // Stack shuffles
    "dup", "drop", "swap", "over", "rot", "nip", "tuck", "pick", "roll",
];
369
370impl CodeGen {
371 pub fn can_specialize(&self, word: &WordDef) -> Option<SpecSignature> {
373 let effect = word.effect.as_ref()?;
375
376 if !effect.is_pure() {
378 return None;
379 }
380
381 let inputs = Self::extract_register_types(&effect.inputs)?;
383 let outputs = Self::extract_register_types(&effect.outputs)?;
384
385 if inputs.is_empty() && outputs.is_empty() {
387 return None;
388 }
389
390 if outputs.is_empty() {
392 return None;
393 }
394
395 if !self.is_body_specializable(&word.body, &word.name) {
397 return None;
398 }
399
400 Some(SpecSignature { inputs, outputs })
401 }
402
403 fn extract_register_types(stack: &StackType) -> Option<Vec<RegisterType>> {
409 let mut types = Vec::new();
410 let mut current = stack;
411
412 loop {
413 match current {
414 StackType::Empty => break,
415 StackType::RowVar(_) => {
416 break;
420 }
421 StackType::Cons { rest, top } => {
422 let reg_ty = RegisterType::from_type(top)?;
423 types.push(reg_ty);
424 current = rest;
425 }
426 }
427 }
428
429 types.reverse();
431 Some(types)
432 }
433
434 fn is_body_specializable(&self, body: &[Statement], word_name: &str) -> bool {
439 let mut prev_was_int_literal = false;
440 for stmt in body {
441 if !self.is_statement_specializable(stmt, word_name, prev_was_int_literal) {
442 return false;
443 }
444 prev_was_int_literal = matches!(stmt, Statement::IntLiteral(_));
446 }
447 true
448 }
449
450 fn is_statement_specializable(
455 &self,
456 stmt: &Statement,
457 word_name: &str,
458 prev_was_int_literal: bool,
459 ) -> bool {
460 match stmt {
461 Statement::IntLiteral(_) => true,
463
464 Statement::FloatLiteral(_) => true,
466
467 Statement::BoolLiteral(_) => true,
469
470 Statement::StringLiteral(_) => false,
472
473 Statement::Symbol(_) => false,
475
476 Statement::Quotation { .. } => false,
478
479 Statement::Match { .. } => false,
481
482 Statement::WordCall { name, .. } => {
484 if name == word_name {
486 return true;
487 }
488
489 if (name == "pick" || name == "roll") && !prev_was_int_literal {
492 return false;
493 }
494
495 if SPECIALIZABLE_OPS.contains(&name.as_str()) {
497 return true;
498 }
499
500 if self.specialized_words.contains_key(name) {
502 return true;
503 }
504
505 false
507 }
508
509 Statement::If {
511 then_branch,
512 else_branch,
513 } => {
514 if !self.is_body_specializable(then_branch, word_name) {
515 return false;
516 }
517 if let Some(else_stmts) = else_branch
518 && !self.is_body_specializable(else_stmts, word_name)
519 {
520 return false;
521 }
522 true
523 }
524 }
525 }
526
527 pub fn codegen_specialized_word(
545 &mut self,
546 word: &WordDef,
547 sig: &SpecSignature,
548 ) -> Result<(), CodeGenError> {
549 let base_name = format!("seq_{}", mangle_name(&word.name));
550 let spec_name = format!("{}{}", base_name, sig.suffix());
551
552 let return_type = if sig.outputs.len() == 1 {
556 sig.outputs[0].llvm_type().to_string()
557 } else {
558 let types: Vec<_> = sig.outputs.iter().map(|t| t.llvm_type()).collect();
560 format!("{{ {} }}", types.join(", "))
561 };
562
563 let params: Vec<String> = sig
565 .inputs
566 .iter()
567 .enumerate()
568 .map(|(i, ty)| format!("{} %arg{}", ty.llvm_type(), i))
569 .collect();
570
571 writeln!(
572 &mut self.output,
573 "define {} @{}({}) {{",
574 return_type,
575 spec_name,
576 params.join(", ")
577 )?;
578 writeln!(&mut self.output, "entry:")?;
579
580 let initial_params: Vec<(String, RegisterType)> = sig
582 .inputs
583 .iter()
584 .enumerate()
585 .map(|(i, ty)| (format!("arg{}", i), *ty))
586 .collect();
587 let mut ctx = RegisterContext::from_params(&initial_params);
588
589 let body_len = word.body.len();
591 let mut prev_int_literal: Option<i64> = None;
592 for (i, stmt) in word.body.iter().enumerate() {
593 let is_last = i == body_len - 1;
594 self.codegen_specialized_statement(
595 &mut ctx,
596 stmt,
597 &word.name,
598 sig,
599 is_last,
600 &mut prev_int_literal,
601 )?;
602 }
603
604 writeln!(&mut self.output, "}}")?;
605 writeln!(&mut self.output)?;
606
607 self.specialized_words
609 .insert(word.name.clone(), sig.clone());
610
611 Ok(())
612 }
613
614 fn codegen_specialized_statement(
616 &mut self,
617 ctx: &mut RegisterContext,
618 stmt: &Statement,
619 word_name: &str,
620 sig: &SpecSignature,
621 is_last: bool,
622 prev_int_literal: &mut Option<i64>,
623 ) -> Result<(), CodeGenError> {
624 let prev_int = *prev_int_literal;
626 *prev_int_literal = None; match stmt {
629 Statement::IntLiteral(n) => {
630 let var = self.fresh_temp();
631 writeln!(&mut self.output, " %{} = add i64 0, {}", var, n)?;
632 ctx.push(var, RegisterType::I64);
633 *prev_int_literal = Some(*n); }
635
636 Statement::FloatLiteral(f) => {
637 let var = self.fresh_temp();
638 let bits = f.to_bits();
643 writeln!(
644 &mut self.output,
645 " %{} = bitcast i64 {} to double",
646 var, bits
647 )?;
648 ctx.push(var, RegisterType::Double);
649 }
650
651 Statement::BoolLiteral(b) => {
652 let var = self.fresh_temp();
653 let val = if *b { 1 } else { 0 };
654 writeln!(&mut self.output, " %{} = add i64 0, {}", var, val)?;
655 ctx.push(var, RegisterType::I64);
656 }
657
658 Statement::WordCall { name, .. } => {
659 self.codegen_specialized_word_call(ctx, name, word_name, sig, is_last, prev_int)?;
660 }
661
662 Statement::If {
663 then_branch,
664 else_branch,
665 } => {
666 self.codegen_specialized_if(
667 ctx,
668 then_branch,
669 else_branch.as_ref(),
670 word_name,
671 sig,
672 is_last,
673 )?;
674 }
675
676 Statement::StringLiteral(_)
678 | Statement::Symbol(_)
679 | Statement::Quotation { .. }
680 | Statement::Match { .. } => {
681 return Err(CodeGenError::Logic(format!(
682 "Non-specializable statement in specialized word: {:?}",
683 stmt
684 )));
685 }
686 }
687
688 let already_returns = match stmt {
691 Statement::If { .. } => true,
692 Statement::WordCall { name, .. } if name == word_name => true,
693 _ => false,
694 };
695 if is_last && !already_returns {
696 self.emit_specialized_return(ctx, sig)?;
697 }
698
699 Ok(())
700 }
701
702 fn codegen_specialized_word_call(
704 &mut self,
705 ctx: &mut RegisterContext,
706 name: &str,
707 word_name: &str,
708 sig: &SpecSignature,
709 is_last: bool,
710 prev_int: Option<i64>,
711 ) -> Result<(), CodeGenError> {
712 match name {
713 "dup" => ctx.dup(),
715 "drop" => ctx.drop(),
716 "swap" => ctx.swap(),
717 "over" => ctx.over(),
718 "rot" => ctx.rot(),
719 "nip" => {
720 ctx.swap();
722 ctx.drop();
723 }
724 "tuck" => {
725 ctx.dup();
727 let b = ctx.pop().unwrap();
728 let b2 = ctx.pop().unwrap();
729 let a = ctx.pop().unwrap();
730 ctx.push(b.0, b.1);
731 ctx.push(a.0, a.1);
732 ctx.push(b2.0, b2.1);
733 }
734 "pick" => {
735 let n = prev_int.ok_or_else(|| {
738 CodeGenError::Logic("pick requires constant N in specialized mode".to_string())
739 })?;
740 if n < 0 {
741 return Err(CodeGenError::Logic(format!(
742 "pick requires non-negative N, got {}",
743 n
744 )));
745 }
746 let n = n as usize;
747 ctx.pop();
749 let len = ctx.values.len();
751 if n >= len {
752 return Err(CodeGenError::Logic(format!(
753 "pick {} but only {} values in context",
754 n, len
755 )));
756 }
757 let (var, ty) = ctx.values[len - 1 - n].clone();
758 ctx.push(var, ty);
759 }
760 "roll" => {
761 let n = prev_int.ok_or_else(|| {
764 CodeGenError::Logic("roll requires constant N in specialized mode".to_string())
765 })?;
766 if n < 0 {
767 return Err(CodeGenError::Logic(format!(
768 "roll requires non-negative N, got {}",
769 n
770 )));
771 }
772 let n = n as usize;
773 ctx.pop();
775 let len = ctx.values.len();
777 if n >= len {
778 return Err(CodeGenError::Logic(format!(
779 "roll {} but only {} values in context",
780 n, len
781 )));
782 }
783 if n > 0 {
784 let val = ctx.values.remove(len - 1 - n);
785 ctx.values.push(val);
786 }
787 }
789
790 "i.+" | "i.add" => {
793 let (b, _) = ctx.pop().unwrap();
794 let (a, _) = ctx.pop().unwrap();
795 let result = self.fresh_temp();
796 writeln!(&mut self.output, " %{} = add i64 %{}, %{}", result, a, b)?;
797 ctx.push(result, RegisterType::I64);
798 }
799 "i.-" | "i.subtract" => {
800 let (b, _) = ctx.pop().unwrap();
801 let (a, _) = ctx.pop().unwrap();
802 let result = self.fresh_temp();
803 writeln!(&mut self.output, " %{} = sub i64 %{}, %{}", result, a, b)?;
804 ctx.push(result, RegisterType::I64);
805 }
806 "i.*" | "i.multiply" => {
807 let (b, _) = ctx.pop().unwrap();
808 let (a, _) = ctx.pop().unwrap();
809 let result = self.fresh_temp();
810 writeln!(&mut self.output, " %{} = mul i64 %{}, %{}", result, a, b)?;
811 ctx.push(result, RegisterType::I64);
812 }
813 "i./" | "i.divide" => {
814 self.emit_specialized_safe_div(ctx, "sdiv")?;
815 }
816 "i.%" | "i.mod" => {
817 self.emit_specialized_safe_div(ctx, "srem")?;
818 }
819
820 "band" => {
822 let (b, _) = ctx.pop().unwrap();
823 let (a, _) = ctx.pop().unwrap();
824 let result = self.fresh_temp();
825 writeln!(&mut self.output, " %{} = and i64 %{}, %{}", result, a, b)?;
826 ctx.push(result, RegisterType::I64);
827 }
828 "bor" => {
829 let (b, _) = ctx.pop().unwrap();
830 let (a, _) = ctx.pop().unwrap();
831 let result = self.fresh_temp();
832 writeln!(&mut self.output, " %{} = or i64 %{}, %{}", result, a, b)?;
833 ctx.push(result, RegisterType::I64);
834 }
835 "bxor" => {
836 let (b, _) = ctx.pop().unwrap();
837 let (a, _) = ctx.pop().unwrap();
838 let result = self.fresh_temp();
839 writeln!(&mut self.output, " %{} = xor i64 %{}, %{}", result, a, b)?;
840 ctx.push(result, RegisterType::I64);
841 }
842 "bnot" => {
843 let (a, _) = ctx.pop().unwrap();
844 let result = self.fresh_temp();
845 writeln!(&mut self.output, " %{} = xor i64 %{}, -1", result, a)?;
847 ctx.push(result, RegisterType::I64);
848 }
849 "shl" => {
850 self.emit_specialized_safe_shift(ctx, true)?;
851 }
852 "shr" => {
853 self.emit_specialized_safe_shift(ctx, false)?;
854 }
855
856 "popcount" => {
858 let (a, _) = ctx.pop().unwrap();
859 let result = self.fresh_temp();
860 writeln!(
861 &mut self.output,
862 " %{} = call i64 @llvm.ctpop.i64(i64 %{})",
863 result, a
864 )?;
865 ctx.push(result, RegisterType::I64);
866 }
867 "clz" => {
868 let (a, _) = ctx.pop().unwrap();
869 let result = self.fresh_temp();
870 writeln!(
872 &mut self.output,
873 " %{} = call i64 @llvm.ctlz.i64(i64 %{}, i1 false)",
874 result, a
875 )?;
876 ctx.push(result, RegisterType::I64);
877 }
878 "ctz" => {
879 let (a, _) = ctx.pop().unwrap();
880 let result = self.fresh_temp();
881 writeln!(
883 &mut self.output,
884 " %{} = call i64 @llvm.cttz.i64(i64 %{}, i1 false)",
885 result, a
886 )?;
887 ctx.push(result, RegisterType::I64);
888 }
889
890 "int->float" => {
892 let (a, _) = ctx.pop().unwrap();
893 let result = self.fresh_temp();
894 writeln!(
895 &mut self.output,
896 " %{} = sitofp i64 %{} to double",
897 result, a
898 )?;
899 ctx.push(result, RegisterType::Double);
900 }
901 "float->int" => {
902 let (a, _) = ctx.pop().unwrap();
903 let result = self.fresh_temp();
904 writeln!(
905 &mut self.output,
906 " %{} = fptosi double %{} to i64",
907 result, a
908 )?;
909 ctx.push(result, RegisterType::I64);
910 }
911
912 "and" => {
914 let (b, _) = ctx.pop().unwrap();
915 let (a, _) = ctx.pop().unwrap();
916 let result = self.fresh_temp();
917 writeln!(&mut self.output, " %{} = and i64 %{}, %{}", result, a, b)?;
918 ctx.push(result, RegisterType::I64);
919 }
920 "or" => {
921 let (b, _) = ctx.pop().unwrap();
922 let (a, _) = ctx.pop().unwrap();
923 let result = self.fresh_temp();
924 writeln!(&mut self.output, " %{} = or i64 %{}, %{}", result, a, b)?;
925 ctx.push(result, RegisterType::I64);
926 }
927 "not" => {
928 let (a, _) = ctx.pop().unwrap();
929 let result = self.fresh_temp();
930 writeln!(&mut self.output, " %{} = xor i64 %{}, 1", result, a)?;
932 ctx.push(result, RegisterType::I64);
933 }
934
935 "i.<" | "i.lt" => self.emit_specialized_icmp(ctx, "slt")?,
937 "i.>" | "i.gt" => self.emit_specialized_icmp(ctx, "sgt")?,
938 "i.<=" | "i.lte" => self.emit_specialized_icmp(ctx, "sle")?,
939 "i.>=" | "i.gte" => self.emit_specialized_icmp(ctx, "sge")?,
940 "i.=" | "i.eq" => self.emit_specialized_icmp(ctx, "eq")?,
941 "i.<>" | "i.neq" => self.emit_specialized_icmp(ctx, "ne")?,
942
943 "f.+" | "f.add" => {
945 let (b, _) = ctx.pop().unwrap();
946 let (a, _) = ctx.pop().unwrap();
947 let result = self.fresh_temp();
948 writeln!(
949 &mut self.output,
950 " %{} = fadd double %{}, %{}",
951 result, a, b
952 )?;
953 ctx.push(result, RegisterType::Double);
954 }
955 "f.-" | "f.subtract" => {
956 let (b, _) = ctx.pop().unwrap();
957 let (a, _) = ctx.pop().unwrap();
958 let result = self.fresh_temp();
959 writeln!(
960 &mut self.output,
961 " %{} = fsub double %{}, %{}",
962 result, a, b
963 )?;
964 ctx.push(result, RegisterType::Double);
965 }
966 "f.*" | "f.multiply" => {
967 let (b, _) = ctx.pop().unwrap();
968 let (a, _) = ctx.pop().unwrap();
969 let result = self.fresh_temp();
970 writeln!(
971 &mut self.output,
972 " %{} = fmul double %{}, %{}",
973 result, a, b
974 )?;
975 ctx.push(result, RegisterType::Double);
976 }
977 "f./" | "f.divide" => {
978 let (b, _) = ctx.pop().unwrap();
979 let (a, _) = ctx.pop().unwrap();
980 let result = self.fresh_temp();
981 writeln!(
982 &mut self.output,
983 " %{} = fdiv double %{}, %{}",
984 result, a, b
985 )?;
986 ctx.push(result, RegisterType::Double);
987 }
988
989 "f.<" | "f.lt" => self.emit_specialized_fcmp(ctx, "olt")?,
991 "f.>" | "f.gt" => self.emit_specialized_fcmp(ctx, "ogt")?,
992 "f.<=" | "f.lte" => self.emit_specialized_fcmp(ctx, "ole")?,
993 "f.>=" | "f.gte" => self.emit_specialized_fcmp(ctx, "oge")?,
994 "f.=" | "f.eq" => self.emit_specialized_fcmp(ctx, "oeq")?,
995 "f.<>" | "f.neq" => self.emit_specialized_fcmp(ctx, "one")?,
996
997 _ if name == word_name => {
999 self.emit_specialized_recursive_call(ctx, word_name, sig, is_last)?;
1000 }
1001
1002 _ if self.specialized_words.contains_key(name) => {
1004 self.emit_specialized_word_dispatch(ctx, name)?;
1005 }
1006
1007 _ => {
1008 return Err(CodeGenError::Logic(format!(
1009 "Unhandled operation in specialized codegen: {}",
1010 name
1011 )));
1012 }
1013 }
1014 Ok(())
1015 }
1016
1017 fn emit_specialized_icmp(
1019 &mut self,
1020 ctx: &mut RegisterContext,
1021 cmp_op: &str,
1022 ) -> Result<(), CodeGenError> {
1023 let (b, _) = ctx.pop().unwrap();
1024 let (a, _) = ctx.pop().unwrap();
1025 let cmp_result = self.fresh_temp();
1026 let result = self.fresh_temp();
1027 writeln!(
1028 &mut self.output,
1029 " %{} = icmp {} i64 %{}, %{}",
1030 cmp_result, cmp_op, a, b
1031 )?;
1032 writeln!(
1033 &mut self.output,
1034 " %{} = zext i1 %{} to i64",
1035 result, cmp_result
1036 )?;
1037 ctx.push(result, RegisterType::I64);
1038 Ok(())
1039 }
1040
1041 fn emit_specialized_fcmp(
1043 &mut self,
1044 ctx: &mut RegisterContext,
1045 cmp_op: &str,
1046 ) -> Result<(), CodeGenError> {
1047 let (b, _) = ctx.pop().unwrap();
1048 let (a, _) = ctx.pop().unwrap();
1049 let cmp_result = self.fresh_temp();
1050 let result = self.fresh_temp();
1051 writeln!(
1052 &mut self.output,
1053 " %{} = fcmp {} double %{}, %{}",
1054 cmp_result, cmp_op, a, b
1055 )?;
1056 writeln!(
1057 &mut self.output,
1058 " %{} = zext i1 %{} to i64",
1059 result, cmp_result
1060 )?;
1061 ctx.push(result, RegisterType::I64);
1062 Ok(())
1063 }
1064
1065 fn emit_specialized_safe_div(
1074 &mut self,
1075 ctx: &mut RegisterContext,
1076 op: &str, ) -> Result<(), CodeGenError> {
1078 let (b, _) = ctx.pop().unwrap(); let (a, _) = ctx.pop().unwrap(); let is_zero = self.fresh_temp();
1083 writeln!(&mut self.output, " %{} = icmp eq i64 %{}, 0", is_zero, b)?;
1084
1085 let (check_overflow, is_overflow) = if op == "sdiv" {
1088 let is_int_min = self.fresh_temp();
1089 let is_neg_one = self.fresh_temp();
1090 let is_overflow = self.fresh_temp();
1091
1092 writeln!(
1094 &mut self.output,
1095 " %{} = icmp eq i64 %{}, -9223372036854775808",
1096 is_int_min, a
1097 )?;
1098 writeln!(
1100 &mut self.output,
1101 " %{} = icmp eq i64 %{}, -1",
1102 is_neg_one, b
1103 )?;
1104 writeln!(
1106 &mut self.output,
1107 " %{} = and i1 %{}, %{}",
1108 is_overflow, is_int_min, is_neg_one
1109 )?;
1110 (true, is_overflow)
1111 } else {
1112 (false, String::new())
1113 };
1114
1115 let ok_label = self.fresh_block("div_ok");
1117 let fail_label = self.fresh_block("div_fail");
1118 let merge_label = self.fresh_block("div_merge");
1119 let overflow_label = if check_overflow {
1120 self.fresh_block("div_overflow")
1121 } else {
1122 String::new()
1123 };
1124
1125 writeln!(
1127 &mut self.output,
1128 " br i1 %{}, label %{}, label %{}",
1129 is_zero,
1130 fail_label,
1131 if check_overflow {
1132 &overflow_label
1133 } else {
1134 &ok_label
1135 }
1136 )?;
1137
1138 if check_overflow {
1140 writeln!(&mut self.output, "{}:", overflow_label)?;
1141 writeln!(
1142 &mut self.output,
1143 " br i1 %{}, label %{}, label %{}",
1144 is_overflow, merge_label, ok_label
1145 )?;
1146 }
1147
1148 writeln!(&mut self.output, "{}:", ok_label)?;
1150 let ok_result = self.fresh_temp();
1151 writeln!(
1152 &mut self.output,
1153 " %{} = {} i64 %{}, %{}",
1154 ok_result, op, a, b
1155 )?;
1156 writeln!(&mut self.output, " br label %{}", merge_label)?;
1157
1158 writeln!(&mut self.output, "{}:", fail_label)?;
1160 writeln!(&mut self.output, " br label %{}", merge_label)?;
1161
1162 writeln!(&mut self.output, "{}:", merge_label)?;
1164 let result_phi = self.fresh_temp();
1165 let success_phi = self.fresh_temp();
1166
1167 if check_overflow {
1168 writeln!(
1171 &mut self.output,
1172 " %{} = phi i64 [ %{}, %{} ], [ 0, %{} ], [ -9223372036854775808, %{} ]",
1173 result_phi, ok_result, ok_label, fail_label, overflow_label
1174 )?;
1175 writeln!(
1176 &mut self.output,
1177 " %{} = phi i64 [ 1, %{} ], [ 0, %{} ], [ 1, %{} ]",
1178 success_phi, ok_label, fail_label, overflow_label
1179 )?;
1180 } else {
1181 writeln!(
1183 &mut self.output,
1184 " %{} = phi i64 [ %{}, %{} ], [ 0, %{} ]",
1185 result_phi, ok_result, ok_label, fail_label
1186 )?;
1187 writeln!(
1188 &mut self.output,
1189 " %{} = phi i64 [ 1, %{} ], [ 0, %{} ]",
1190 success_phi, ok_label, fail_label
1191 )?;
1192 }
1193
1194 ctx.push(result_phi, RegisterType::I64);
1197 ctx.push(success_phi, RegisterType::I64);
1198
1199 Ok(())
1200 }
1201
1202 fn emit_specialized_safe_shift(
1207 &mut self,
1208 ctx: &mut RegisterContext,
1209 is_left: bool, ) -> Result<(), CodeGenError> {
1211 let (b, _) = ctx.pop().unwrap(); let (a, _) = ctx.pop().unwrap(); let is_negative = self.fresh_temp();
1216 writeln!(
1217 &mut self.output,
1218 " %{} = icmp slt i64 %{}, 0",
1219 is_negative, b
1220 )?;
1221
1222 let is_too_large = self.fresh_temp();
1224 writeln!(
1225 &mut self.output,
1226 " %{} = icmp sge i64 %{}, 64",
1227 is_too_large, b
1228 )?;
1229
1230 let is_invalid = self.fresh_temp();
1232 writeln!(
1233 &mut self.output,
1234 " %{} = or i1 %{}, %{}",
1235 is_invalid, is_negative, is_too_large
1236 )?;
1237
1238 let safe_count = self.fresh_temp();
1240 writeln!(
1241 &mut self.output,
1242 " %{} = select i1 %{}, i64 0, i64 %{}",
1243 safe_count, is_invalid, b
1244 )?;
1245
1246 let shift_result = self.fresh_temp();
1248 let op = if is_left { "shl" } else { "lshr" };
1249 writeln!(
1250 &mut self.output,
1251 " %{} = {} i64 %{}, %{}",
1252 shift_result, op, a, safe_count
1253 )?;
1254
1255 let result = self.fresh_temp();
1257 writeln!(
1258 &mut self.output,
1259 " %{} = select i1 %{}, i64 0, i64 %{}",
1260 result, is_invalid, shift_result
1261 )?;
1262
1263 ctx.push(result, RegisterType::I64);
1264 Ok(())
1265 }
1266
1267 fn emit_specialized_recursive_call(
1275 &mut self,
1276 ctx: &mut RegisterContext,
1277 word_name: &str,
1278 sig: &SpecSignature,
1279 is_tail: bool,
1280 ) -> Result<(), CodeGenError> {
1281 let spec_name = format!("seq_{}{}", mangle_name(word_name), sig.suffix());
1282
1283 if ctx.values.len() < sig.inputs.len() {
1285 return Err(CodeGenError::Logic(format!(
1286 "Not enough values in context for recursive call to {}: need {}, have {}",
1287 word_name,
1288 sig.inputs.len(),
1289 ctx.values.len()
1290 )));
1291 }
1292
1293 let mut args = Vec::new();
1295 for _ in 0..sig.inputs.len() {
1296 args.push(ctx.pop().unwrap());
1297 }
1298 args.reverse(); let arg_strs: Vec<String> = args
1302 .iter()
1303 .map(|(var, ty)| format!("{} %{}", ty.llvm_type(), var))
1304 .collect();
1305
1306 let return_type = sig.llvm_return_type();
1307
1308 if is_tail {
1309 let result = self.fresh_temp();
1311 writeln!(
1312 &mut self.output,
1313 " %{} = musttail call {} @{}({})",
1314 result,
1315 return_type,
1316 spec_name,
1317 arg_strs.join(", ")
1318 )?;
1319 writeln!(&mut self.output, " ret {} %{}", return_type, result)?;
1320 } else {
1321 let result = self.fresh_temp();
1323 writeln!(
1324 &mut self.output,
1325 " %{} = call {} @{}({})",
1326 result,
1327 return_type,
1328 spec_name,
1329 arg_strs.join(", ")
1330 )?;
1331
1332 if sig.outputs.len() == 1 {
1333 ctx.push(result, sig.outputs[0]);
1335 } else {
1336 for (i, out_ty) in sig.outputs.iter().enumerate() {
1338 let extracted = self.fresh_temp();
1339 writeln!(
1340 &mut self.output,
1341 " %{} = extractvalue {} %{}, {}",
1342 extracted, return_type, result, i
1343 )?;
1344 ctx.push(extracted, *out_ty);
1345 }
1346 }
1347 }
1348
1349 Ok(())
1350 }
1351
1352 fn emit_specialized_word_dispatch(
1354 &mut self,
1355 ctx: &mut RegisterContext,
1356 name: &str,
1357 ) -> Result<(), CodeGenError> {
1358 let sig = self
1359 .specialized_words
1360 .get(name)
1361 .ok_or_else(|| CodeGenError::Logic(format!("Unknown specialized word: {}", name)))?
1362 .clone();
1363
1364 let spec_name = format!("seq_{}{}", mangle_name(name), sig.suffix());
1365
1366 let mut args = Vec::new();
1368 for _ in 0..sig.inputs.len() {
1369 args.push(ctx.pop().unwrap());
1370 }
1371 args.reverse();
1372
1373 let arg_strs: Vec<String> = args
1375 .iter()
1376 .map(|(var, ty)| format!("{} %{}", ty.llvm_type(), var))
1377 .collect();
1378
1379 let return_type = sig.llvm_return_type();
1380
1381 let result = self.fresh_temp();
1382 writeln!(
1383 &mut self.output,
1384 " %{} = call {} @{}({})",
1385 result,
1386 return_type,
1387 spec_name,
1388 arg_strs.join(", ")
1389 )?;
1390
1391 if sig.outputs.len() == 1 {
1392 ctx.push(result, sig.outputs[0]);
1394 } else {
1395 for (i, out_ty) in sig.outputs.iter().enumerate() {
1397 let extracted = self.fresh_temp();
1398 writeln!(
1399 &mut self.output,
1400 " %{} = extractvalue {} %{}, {}",
1401 extracted, return_type, result, i
1402 )?;
1403 ctx.push(extracted, *out_ty);
1404 }
1405 }
1406
1407 Ok(())
1408 }
1409
1410 fn emit_specialized_return(
1412 &mut self,
1413 ctx: &RegisterContext,
1414 sig: &SpecSignature,
1415 ) -> Result<(), CodeGenError> {
1416 let output_count = sig.outputs.len();
1417
1418 if output_count == 0 {
1419 writeln!(&mut self.output, " ret void")?;
1420 } else if output_count == 1 {
1421 let (var, ty) = ctx
1422 .values
1423 .last()
1424 .ok_or_else(|| CodeGenError::Logic("Empty context at return".to_string()))?;
1425 writeln!(&mut self.output, " ret {} %{}", ty.llvm_type(), var)?;
1426 } else {
1427 if ctx.values.len() < output_count {
1430 return Err(CodeGenError::Logic(format!(
1431 "Not enough values for multi-output return: need {}, have {}",
1432 output_count,
1433 ctx.values.len()
1434 )));
1435 }
1436
1437 let start_idx = ctx.values.len() - output_count;
1439 let return_values: Vec<_> = ctx.values[start_idx..].to_vec();
1440
1441 let struct_type = sig.llvm_return_type();
1443
1444 let mut current_struct = "undef".to_string();
1446 for (i, (var, ty)) in return_values.iter().enumerate() {
1447 let new_struct = self.fresh_temp();
1448 writeln!(
1449 &mut self.output,
1450 " %{} = insertvalue {} {}, {} %{}, {}",
1451 new_struct,
1452 struct_type,
1453 current_struct,
1454 ty.llvm_type(),
1455 var,
1456 i
1457 )?;
1458 current_struct = format!("%{}", new_struct);
1459 }
1460
1461 writeln!(&mut self.output, " ret {} {}", struct_type, current_struct)?;
1462 }
1463 Ok(())
1464 }
1465
1466 fn codegen_specialized_if(
1468 &mut self,
1469 ctx: &mut RegisterContext,
1470 then_branch: &[Statement],
1471 else_branch: Option<&Vec<Statement>>,
1472 word_name: &str,
1473 sig: &SpecSignature,
1474 is_last: bool,
1475 ) -> Result<(), CodeGenError> {
1476 let (cond_var, _) = ctx
1478 .pop()
1479 .ok_or_else(|| CodeGenError::Logic("Empty context at if condition".to_string()))?;
1480
1481 let cmp_result = self.fresh_temp();
1483 writeln!(
1484 &mut self.output,
1485 " %{} = icmp ne i64 %{}, 0",
1486 cmp_result, cond_var
1487 )?;
1488
1489 let then_label = self.fresh_block("if_then");
1491 let else_label = self.fresh_block("if_else");
1492 let merge_label = self.fresh_block("if_merge");
1493
1494 writeln!(
1495 &mut self.output,
1496 " br i1 %{}, label %{}, label %{}",
1497 cmp_result, then_label, else_label
1498 )?;
1499
1500 writeln!(&mut self.output, "{}:", then_label)?;
1502 let mut then_ctx = ctx.clone();
1503 let mut then_prev_int: Option<i64> = None;
1504 for (i, stmt) in then_branch.iter().enumerate() {
1505 let is_stmt_last = i == then_branch.len() - 1 && is_last;
1506 self.codegen_specialized_statement(
1507 &mut then_ctx,
1508 stmt,
1509 word_name,
1510 sig,
1511 is_stmt_last,
1512 &mut then_prev_int,
1513 )?;
1514 }
1515 if is_last && then_branch.is_empty() {
1517 self.emit_specialized_return(&then_ctx, sig)?;
1518 }
1519 let then_result = then_ctx.values.last().cloned();
1520 let then_emitted_return = is_last;
1522 let then_pred = if then_emitted_return {
1523 None
1524 } else {
1525 writeln!(&mut self.output, " br label %{}", merge_label)?;
1526 Some(then_label.clone())
1527 };
1528
1529 writeln!(&mut self.output, "{}:", else_label)?;
1531 let mut else_ctx = ctx.clone();
1532 let mut else_prev_int: Option<i64> = None;
1533 if let Some(else_stmts) = else_branch {
1534 for (i, stmt) in else_stmts.iter().enumerate() {
1535 let is_stmt_last = i == else_stmts.len() - 1 && is_last;
1536 self.codegen_specialized_statement(
1537 &mut else_ctx,
1538 stmt,
1539 word_name,
1540 sig,
1541 is_stmt_last,
1542 &mut else_prev_int,
1543 )?;
1544 }
1545 }
1546 if is_last && (else_branch.is_none() || else_branch.as_ref().is_some_and(|b| b.is_empty()))
1548 {
1549 self.emit_specialized_return(&else_ctx, sig)?;
1550 }
1551 let else_result = else_ctx.values.last().cloned();
1552 let else_emitted_return = is_last;
1554 let else_pred = if else_emitted_return {
1555 None
1556 } else {
1557 writeln!(&mut self.output, " br label %{}", merge_label)?;
1558 Some(else_label.clone())
1559 };
1560
1561 if then_pred.is_some() || else_pred.is_some() {
1563 writeln!(&mut self.output, "{}:", merge_label)?;
1564
1565 if let (
1567 Some(then_p),
1568 Some(else_p),
1569 Some((then_var, then_ty)),
1570 Some((else_var, else_ty)),
1571 ) = (&then_pred, &else_pred, &then_result, &else_result)
1572 {
1573 if then_ty == else_ty {
1574 let phi_result = self.fresh_temp();
1575 writeln!(
1576 &mut self.output,
1577 " %{} = phi {} [ %{}, %{} ], [ %{}, %{} ]",
1578 phi_result,
1579 then_ty.llvm_type(),
1580 then_var,
1581 then_p,
1582 else_var,
1583 else_p
1584 )?;
1585 ctx.values.clear();
1586 ctx.push(phi_result, *then_ty);
1587 }
1588 } else if let (Some(_), Some((then_var, then_ty))) = (&then_pred, &then_result) {
1589 ctx.values.clear();
1591 ctx.push(then_var.clone(), *then_ty);
1592 } else if let (Some(_), Some((else_var, else_ty))) = (&else_pred, &else_result) {
1593 ctx.values.clear();
1595 ctx.push(else_var.clone(), *else_ty);
1596 }
1597
1598 if is_last && (then_pred.is_some() || else_pred.is_some()) {
1600 self.emit_specialized_return(ctx, sig)?;
1601 }
1602 }
1603
1604 Ok(())
1605 }
1606}
1607
#[cfg(test)]
mod tests {
    use super::*;

    /// Int and Bool share the i64 register, Float uses double, and String
    /// has no register form.
    #[test]
    fn test_register_type_from_type() {
        let cases = [
            (Type::Int, Some(RegisterType::I64)),
            (Type::Bool, Some(RegisterType::I64)),
            (Type::Float, Some(RegisterType::Double)),
            (Type::String, None),
        ];
        for (ty, expected) in &cases {
            assert_eq!(RegisterType::from_type(ty), *expected);
        }
    }

    /// One-in/one-out signatures of a single register type use the long
    /// suffix spelling.
    #[test]
    fn test_spec_signature_suffix() {
        let unary = |ty: RegisterType| SpecSignature {
            inputs: vec![ty],
            outputs: vec![ty],
        };

        assert_eq!(unary(RegisterType::I64).suffix(), "_i64");
        assert_eq!(unary(RegisterType::Double).suffix(), "_f64");
    }

    /// swap, dup and drop manipulate the tracked SSA names as expected.
    #[test]
    fn test_register_context_stack_ops() {
        let mut ctx = RegisterContext::new();
        for name in ["a", "b"] {
            ctx.push(name.to_string(), RegisterType::I64);
        }
        assert_eq!(ctx.len(), 2);

        // [a, b] -> [b, a]
        ctx.swap();
        assert_eq!(ctx.values[0].0, "b");
        assert_eq!(ctx.values[1].0, "a");

        // [b, a] -> [b, a, a]
        ctx.dup();
        assert_eq!(ctx.len(), 3);
        assert_eq!(ctx.values[2].0, "a");

        // [b, a, a] -> [b, a]
        ctx.drop();
        assert_eq!(ctx.len(), 2);
    }
}