1use crate::ast::{Program, Statement};
44
45pub fn lower_literal_if_combinators(program: &mut Program) {
49 for word in &mut program.words {
50 rewrite_statements(&mut word.body);
51 }
52}
53
54fn rewrite_statements(statements: &mut Vec<Statement>) {
55 let mut i = 0;
56 while i < statements.len() {
57 if i + 2 < statements.len() && is_inline_triple(&statements[i..i + 3]) {
58 let if_span = match &statements[i + 2] {
59 Statement::WordCall { span, .. } => span.clone(),
60 _ => None,
61 };
62 statements.remove(i + 2);
63 let mut else_quot = statements.remove(i + 1);
64 let mut then_quot = statements.remove(i);
65
66 let mut then_body = match &mut then_quot {
67 Statement::Quotation { body, .. } => std::mem::take(body),
68 _ => panic!("normalize: is_inline_triple guard accepted a non-Quotation"),
69 };
70 let mut else_body = match &mut else_quot {
71 Statement::Quotation { body, .. } => std::mem::take(body),
72 _ => panic!("normalize: is_inline_triple guard accepted a non-Quotation"),
73 };
74 rewrite_statements(&mut then_body);
75 rewrite_statements(&mut else_body);
76
77 statements.insert(
78 i,
79 Statement::If {
80 then_branch: then_body,
81 else_branch: Some(else_body),
82 span: if_span,
83 },
84 );
85 i += 1;
86 continue;
87 }
88
89 match &mut statements[i] {
90 Statement::If {
91 then_branch,
92 else_branch,
93 ..
94 } => {
95 rewrite_statements(then_branch);
96 if let Some(eb) = else_branch.as_mut() {
97 rewrite_statements(eb);
98 }
99 }
100 Statement::Match { arms, .. } => {
101 for arm in arms {
102 rewrite_statements(&mut arm.body);
103 }
104 }
105 Statement::Quotation { body, .. } => {
106 rewrite_statements(body);
107 }
108 _ => {}
109 }
110 i += 1;
111 }
112}
113
114fn is_inline_triple(triple: &[Statement]) -> bool {
115 matches!(
116 (&triple[0], &triple[1], &triple[2]),
117 (
118 Statement::Quotation { .. },
119 Statement::Quotation { .. },
120 Statement::WordCall { name, .. },
121 ) if name == "if"
122 )
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{Program, WordDef};

    /// Builds a quotation literal with the given id and body and no span.
    fn quot(id: usize, body: Vec<Statement>) -> Statement {
        Statement::Quotation {
            id,
            body,
            span: None,
        }
    }

    /// Builds a span-less call to the word `name`.
    fn word_call(name: &str) -> Statement {
        Statement::WordCall {
            name: name.to_string(),
            span: None,
        }
    }

    /// Runs the lowering pass over a single-word program and returns that
    /// word's rewritten body.
    fn lower_body(body: Vec<Statement>) -> Vec<Statement> {
        let mut program = Program {
            includes: vec![],
            unions: vec![],
            words: vec![WordDef {
                name: "test".to_string(),
                effect: None,
                body,
                source: None,
                allowed_lints: vec![],
            }],
        };
        lower_literal_if_combinators(&mut program);
        program.words.into_iter().next().unwrap().body
    }

    /// Builds the `Statement::If` expected from lowering `[then] [else] if`.
    fn lowered_if(then_branch: Vec<Statement>, else_branch: Vec<Statement>) -> Statement {
        Statement::If {
            then_branch,
            else_branch: Some(else_branch),
            span: None,
        }
    }

    #[test]
    fn rewrites_literal_triple_to_statement_if() {
        let body = vec![
            Statement::BoolLiteral(true),
            quot(0, vec![Statement::IntLiteral(1)]),
            quot(1, vec![Statement::IntLiteral(2)]),
            word_call("if"),
        ];
        let lowered = lower_body(body);
        assert_eq!(lowered.len(), 2);
        assert!(matches!(lowered[0], Statement::BoolLiteral(true)));
        match &lowered[1] {
            Statement::If {
                then_branch,
                else_branch,
                ..
            } => {
                assert_eq!(then_branch, &vec![Statement::IntLiteral(1)]);
                assert_eq!(
                    else_branch.as_deref(),
                    Some(&[Statement::IntLiteral(2)][..])
                );
            }
            other => panic!("expected Statement::If, got {:?}", other),
        }
    }

    #[test]
    fn rewrites_nested_triples() {
        // Inner: true [10] [20] if
        let inner_triple = vec![
            Statement::BoolLiteral(true),
            quot(2, vec![Statement::IntLiteral(10)]),
            quot(3, vec![Statement::IntLiteral(20)]),
            word_call("if"),
        ];
        let body = vec![
            Statement::BoolLiteral(false),
            quot(0, inner_triple),
            quot(1, vec![Statement::IntLiteral(99)]),
            word_call("if"),
        ];
        let lowered = lower_body(body);
        assert_eq!(lowered.len(), 2);
        match &lowered[1] {
            Statement::If { then_branch, .. } => {
                assert_eq!(then_branch.len(), 2);
                assert!(matches!(then_branch[0], Statement::BoolLiteral(true)));
                assert!(matches!(then_branch[1], Statement::If { .. }));
            }
            other => panic!("expected outer Statement::If, got {:?}", other),
        }
    }

    #[test]
    fn leaves_dynamic_dispatch_alone() {
        // The else branch comes from a word call, not a quotation literal.
        let body = vec![
            Statement::BoolLiteral(true),
            quot(0, vec![Statement::IntLiteral(1)]),
            word_call("my-word"),
            word_call("if"),
        ];
        let original = body.clone();
        let lowered = lower_body(body);
        assert_eq!(lowered, original);
    }

    #[test]
    fn leaves_non_if_word_call_alone() {
        // Two quotations followed by a word other than `if`.
        let body = vec![
            Statement::IntLiteral(3),
            quot(0, vec![Statement::IntLiteral(1)]),
            quot(1, vec![Statement::IntLiteral(2)]),
            word_call("times"),
        ];
        let original = body.clone();
        let lowered = lower_body(body);
        assert_eq!(lowered, original);
    }

    #[test]
    fn recurses_into_quotation_body() {
        // A triple inside a quotation that is not itself an `if` branch.
        let inner_triple = vec![
            Statement::BoolLiteral(true),
            quot(1, vec![Statement::IntLiteral(1)]),
            quot(2, vec![Statement::IntLiteral(2)]),
            word_call("if"),
        ];
        let body = vec![quot(0, inner_triple), word_call("my-word")];
        let lowered = lower_body(body);
        match &lowered[0] {
            Statement::Quotation { body, .. } => {
                assert_eq!(body.len(), 2);
                assert!(matches!(body[1], Statement::If { .. }));
            }
            other => panic!("expected outer Quotation, got {:?}", other),
        }
    }

    #[test]
    fn rewrites_consecutive_triples() {
        let body = vec![
            Statement::BoolLiteral(true),
            quot(0, vec![Statement::IntLiteral(1)]),
            quot(1, vec![Statement::IntLiteral(2)]),
            word_call("if"),
            Statement::BoolLiteral(false),
            quot(2, vec![Statement::IntLiteral(3)]),
            quot(3, vec![Statement::IntLiteral(4)]),
            word_call("if"),
        ];
        let expected = vec![
            Statement::BoolLiteral(true),
            lowered_if(vec![Statement::IntLiteral(1)], vec![Statement::IntLiteral(2)]),
            Statement::BoolLiteral(false),
            lowered_if(vec![Statement::IntLiteral(3)], vec![Statement::IntLiteral(4)]),
        ];
        assert_eq!(lower_body(body), expected);
    }

    #[test]
    fn only_last_two_quotations_form_the_triple() {
        let body = vec![
            quot(0, vec![Statement::IntLiteral(0)]),
            quot(1, vec![Statement::IntLiteral(1)]),
            quot(2, vec![Statement::IntLiteral(2)]),
            word_call("if"),
        ];
        let expected = vec![
            quot(0, vec![Statement::IntLiteral(0)]),
            lowered_if(vec![Statement::IntLiteral(1)], vec![Statement::IntLiteral(2)]),
        ];
        assert_eq!(lower_body(body), expected);
    }

    #[test]
    fn recurses_into_existing_if_branches() {
        let triple = vec![
            quot(0, vec![Statement::IntLiteral(1)]),
            quot(1, vec![Statement::IntLiteral(2)]),
            word_call("if"),
        ];
        let body = vec![
            Statement::BoolLiteral(true),
            Statement::If {
                then_branch: triple.clone(),
                else_branch: Some(triple),
                span: None,
            },
        ];
        let lowered = lower_body(body);
        assert_eq!(lowered.len(), 2);
        match &lowered[1] {
            Statement::If {
                then_branch,
                else_branch,
                ..
            } => {
                assert!(matches!(then_branch.as_slice(), [Statement::If { .. }]));
                assert!(matches!(
                    else_branch.as_deref(),
                    Some([Statement::If { .. }])
                ));
            }
            other => panic!("expected Statement::If, got {:?}", other),
        }
    }

    #[test]
    fn leaves_short_bodies_alone() {
        let bodies: Vec<Vec<Statement>> = vec![
            vec![],
            vec![word_call("if")],
            vec![quot(0, vec![]), word_call("if")],
        ];
        for body in bodies {
            let original = body.clone();
            assert_eq!(lower_body(body), original);
        }
    }
}