1use crate::types::{Effect, StackType, Type};
6
7use super::{Program, Statement, WordDef};
8
9impl Program {
10 pub fn new() -> Self {
11 Program {
12 includes: Vec::new(),
13 unions: Vec::new(),
14 words: Vec::new(),
15 }
16 }
17
18 pub fn find_word(&self, name: &str) -> Option<&WordDef> {
19 self.words.iter().find(|w| w.name == name)
20 }
21
22 pub fn validate_word_calls(&self) -> Result<(), String> {
24 self.validate_word_calls_with_externals(&[])
25 }
26
27 pub fn validate_word_calls_with_externals(
32 &self,
33 external_words: &[&str],
34 ) -> Result<(), String> {
35 let builtins = [
38 "io.write",
40 "io.write-line",
41 "io.read-line",
42 "io.read-line+",
43 "io.read-n",
44 "int->string",
45 "symbol->string",
46 "string->symbol",
47 "args.count",
49 "args.at",
50 "file.slurp",
52 "file.exists?",
53 "file.for-each-line+",
54 "file.spit",
55 "file.append",
56 "file.delete",
57 "file.size",
58 "dir.exists?",
60 "dir.make",
61 "dir.delete",
62 "dir.list",
63 "string.concat",
65 "string.length",
66 "string.byte-length",
67 "string.char-at",
68 "string.substring",
69 "char->string",
70 "string.find",
71 "string.split",
72 "string.contains",
73 "string.starts-with",
74 "string.empty?",
75 "string.trim",
76 "string.chomp",
77 "string.to-upper",
78 "string.to-lower",
79 "string.equal?",
80 "string.join",
81 "string.json-escape",
82 "string->int",
83 "symbol.=",
85 "encoding.base64-encode",
87 "encoding.base64-decode",
88 "encoding.base64url-encode",
89 "encoding.base64url-decode",
90 "encoding.hex-encode",
91 "encoding.hex-decode",
92 "crypto.sha256",
94 "crypto.hmac-sha256",
95 "crypto.constant-time-eq",
96 "crypto.random-bytes",
97 "crypto.random-int",
98 "crypto.uuid4",
99 "crypto.aes-gcm-encrypt",
100 "crypto.aes-gcm-decrypt",
101 "crypto.pbkdf2-sha256",
102 "crypto.ed25519-keypair",
103 "crypto.ed25519-sign",
104 "crypto.ed25519-verify",
105 "http.get",
107 "http.post",
108 "http.put",
109 "http.delete",
110 "list.make",
112 "list.push",
113 "list.get",
114 "list.set",
115 "list.map",
116 "list.filter",
117 "list.fold",
118 "list.each",
119 "list.length",
120 "list.empty?",
121 "list.reverse",
122 "map.make",
124 "map.get",
125 "map.set",
126 "map.has?",
127 "map.remove",
128 "map.keys",
129 "map.values",
130 "map.size",
131 "map.empty?",
132 "map.each",
133 "map.fold",
134 "variant.field-count",
136 "variant.tag",
137 "variant.field-at",
138 "variant.append",
139 "variant.last",
140 "variant.init",
141 "variant.make-0",
142 "variant.make-1",
143 "variant.make-2",
144 "variant.make-3",
145 "variant.make-4",
146 "wrap-0",
148 "wrap-1",
149 "wrap-2",
150 "wrap-3",
151 "wrap-4",
152 "i.add",
154 "i.subtract",
155 "i.multiply",
156 "i.divide",
157 "i.modulo",
158 "i.+",
160 "i.-",
161 "i.*",
162 "i./",
163 "i.%",
164 "i.=",
166 "i.<",
167 "i.>",
168 "i.<=",
169 "i.>=",
170 "i.<>",
171 "i.eq",
173 "i.lt",
174 "i.gt",
175 "i.lte",
176 "i.gte",
177 "i.neq",
178 "dup",
180 "drop",
181 "swap",
182 "over",
183 "rot",
184 "nip",
185 "tuck",
186 "2dup",
187 "3drop",
188 "pick",
189 "roll",
190 ">aux",
192 "aux>",
193 "and",
195 "or",
196 "not",
197 "band",
199 "bor",
200 "bxor",
201 "bnot",
202 "i.neg",
203 "negate",
204 "+",
206 "-",
207 "*",
208 "/",
209 "%",
210 "=",
211 "<",
212 ">",
213 "<=",
214 ">=",
215 "<>",
216 "shl",
217 "shr",
218 "popcount",
219 "clz",
220 "ctz",
221 "int-bits",
222 "chan.make",
224 "chan.send",
225 "chan.receive",
226 "chan.close",
227 "chan.yield",
228 "call",
230 "dip",
232 "keep",
233 "bi",
234 "if",
235 "strand.spawn",
236 "strand.weave",
237 "strand.resume",
238 "strand.weave-cancel",
239 "yield",
240 "cond",
241 "tcp.listen",
243 "tcp.accept",
244 "tcp.read",
245 "tcp.write",
246 "tcp.close",
247 "os.getenv",
249 "os.home-dir",
250 "os.current-dir",
251 "os.path-exists",
252 "os.path-is-file",
253 "os.path-is-dir",
254 "os.path-join",
255 "os.path-parent",
256 "os.path-filename",
257 "os.exit",
258 "os.name",
259 "os.arch",
260 "signal.trap",
262 "signal.received?",
263 "signal.pending?",
264 "signal.default",
265 "signal.ignore",
266 "signal.clear",
267 "signal.SIGINT",
268 "signal.SIGTERM",
269 "signal.SIGHUP",
270 "signal.SIGPIPE",
271 "signal.SIGUSR1",
272 "signal.SIGUSR2",
273 "signal.SIGCHLD",
274 "signal.SIGALRM",
275 "signal.SIGCONT",
276 "terminal.raw-mode",
278 "terminal.read-char",
279 "terminal.read-char?",
280 "terminal.width",
281 "terminal.height",
282 "terminal.flush",
283 "f.add",
285 "f.subtract",
286 "f.multiply",
287 "f.divide",
288 "f.+",
290 "f.-",
291 "f.*",
292 "f./",
293 "f.=",
295 "f.<",
296 "f.>",
297 "f.<=",
298 "f.>=",
299 "f.<>",
300 "f.eq",
302 "f.lt",
303 "f.gt",
304 "f.lte",
305 "f.gte",
306 "f.neq",
307 "int->float",
309 "float->int",
310 "float->string",
311 "string->float",
312 "test.init",
314 "test.set-name",
315 "test.finish",
316 "test.has-failures",
317 "test.assert",
318 "test.assert-not",
319 "test.assert-eq",
320 "test.assert-eq-str",
321 "test.fail",
322 "test.pass-count",
323 "test.fail-count",
324 "time.now",
326 "time.nanos",
327 "time.sleep-ms",
328 "son.dump",
330 "son.dump-pretty",
331 "stack.dump",
333 "regex.match?",
335 "regex.find",
336 "regex.find-all",
337 "regex.replace",
338 "regex.replace-all",
339 "regex.captures",
340 "regex.split",
341 "regex.valid?",
342 "compress.gzip",
344 "compress.gzip-level",
345 "compress.gunzip",
346 "compress.zstd",
347 "compress.zstd-level",
348 "compress.unzstd",
349 ];
350
351 for word in &self.words {
352 self.validate_statements(&word.body, &word.name, &builtins, external_words)?;
353 }
354
355 Ok(())
356 }
357
358 fn validate_statements(
360 &self,
361 statements: &[Statement],
362 word_name: &str,
363 builtins: &[&str],
364 external_words: &[&str],
365 ) -> Result<(), String> {
366 for statement in statements {
367 match statement {
368 Statement::WordCall { name, .. } => {
369 if builtins.contains(&name.as_str()) {
371 continue;
372 }
373 if self.find_word(name).is_some() {
375 continue;
376 }
377 if external_words.contains(&name.as_str()) {
379 continue;
380 }
381 return Err(format!(
383 "Undefined word '{}' called in word '{}'. \
384 Did you forget to define it or misspell a built-in?",
385 name, word_name
386 ));
387 }
388 Statement::If {
389 then_branch,
390 else_branch,
391 span: _,
392 } => {
393 self.validate_statements(then_branch, word_name, builtins, external_words)?;
395 if let Some(eb) = else_branch {
396 self.validate_statements(eb, word_name, builtins, external_words)?;
397 }
398 }
399 Statement::Quotation { body, .. } => {
400 self.validate_statements(body, word_name, builtins, external_words)?;
402 }
403 Statement::Match { arms, span: _ } => {
404 for arm in arms {
406 self.validate_statements(&arm.body, word_name, builtins, external_words)?;
407 }
408 }
409 _ => {} }
411 }
412 Ok(())
413 }
414
415 pub const MAX_VARIANT_FIELDS: usize = 12;
419
420 pub fn generate_constructors(&mut self) -> Result<(), String> {
433 let mut new_words = Vec::new();
434
435 for union_def in &self.unions {
436 for variant in &union_def.variants {
437 let field_count = variant.fields.len();
438
439 if field_count > Self::MAX_VARIANT_FIELDS {
441 return Err(format!(
442 "Variant '{}' in union '{}' has {} fields, but the maximum is {}. \
443 Consider grouping fields into nested union types.",
444 variant.name,
445 union_def.name,
446 field_count,
447 Self::MAX_VARIANT_FIELDS
448 ));
449 }
450
451 let constructor_name = format!("Make-{}", variant.name);
453 let mut input_stack = StackType::RowVar("a".to_string());
454 for field in &variant.fields {
455 let field_type = parse_type_name(&field.type_name);
456 input_stack = input_stack.push(field_type);
457 }
458 let output_stack =
459 StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
460 let effect = Effect::new(input_stack, output_stack);
461 let body = vec![
462 Statement::Symbol(variant.name.clone()),
463 Statement::WordCall {
464 name: format!("variant.make-{}", field_count),
465 span: None,
466 },
467 ];
468 new_words.push(WordDef {
469 name: constructor_name,
470 effect: Some(effect),
471 body,
472 source: variant.source.clone(),
473 allowed_lints: vec![],
474 });
475
476 let predicate_name = format!("is-{}?", variant.name);
480 let predicate_input =
481 StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
482 let predicate_output = StackType::RowVar("a".to_string()).push(Type::Bool);
483 let predicate_effect = Effect::new(predicate_input, predicate_output);
484 let predicate_body = vec![
485 Statement::WordCall {
486 name: "variant.tag".to_string(),
487 span: None,
488 },
489 Statement::Symbol(variant.name.clone()),
490 Statement::WordCall {
491 name: "symbol.=".to_string(),
492 span: None,
493 },
494 ];
495 new_words.push(WordDef {
496 name: predicate_name,
497 effect: Some(predicate_effect),
498 body: predicate_body,
499 source: variant.source.clone(),
500 allowed_lints: vec![],
501 });
502
503 for (index, field) in variant.fields.iter().enumerate() {
507 let accessor_name = format!("{}-{}", variant.name, field.name);
508 let field_type = parse_type_name(&field.type_name);
509 let accessor_input = StackType::RowVar("a".to_string())
510 .push(Type::Union(union_def.name.clone()));
511 let accessor_output = StackType::RowVar("a".to_string()).push(field_type);
512 let accessor_effect = Effect::new(accessor_input, accessor_output);
513 let accessor_body = vec![
514 Statement::IntLiteral(index as i64),
515 Statement::WordCall {
516 name: "variant.field-at".to_string(),
517 span: None,
518 },
519 ];
520 new_words.push(WordDef {
521 name: accessor_name,
522 effect: Some(accessor_effect),
523 body: accessor_body,
524 source: variant.source.clone(), allowed_lints: vec![],
526 });
527 }
528 }
529 }
530
531 self.words.extend(new_words);
532 Ok(())
533 }
534
535 pub fn fixup_union_types(&mut self) {
544 let union_names: std::collections::HashSet<String> =
546 self.unions.iter().map(|u| u.name.clone()).collect();
547
548 for word in &mut self.words {
550 if let Some(ref mut effect) = word.effect {
551 Self::fixup_stack_type(&mut effect.inputs, &union_names);
552 Self::fixup_stack_type(&mut effect.outputs, &union_names);
553 }
554 }
555 }
556
557 fn fixup_stack_type(stack: &mut StackType, union_names: &std::collections::HashSet<String>) {
559 match stack {
560 StackType::Empty | StackType::RowVar(_) => {}
561 StackType::Cons { rest, top } => {
562 Self::fixup_type(top, union_names);
563 Self::fixup_stack_type(rest, union_names);
564 }
565 }
566 }
567
568 fn fixup_type(ty: &mut Type, union_names: &std::collections::HashSet<String>) {
570 match ty {
571 Type::Var(name) if union_names.contains(name) => {
572 *ty = Type::Union(name.clone());
573 }
574 Type::Quotation(effect) => {
575 Self::fixup_stack_type(&mut effect.inputs, union_names);
576 Self::fixup_stack_type(&mut effect.outputs, union_names);
577 }
578 Type::Closure { effect, captures } => {
579 Self::fixup_stack_type(&mut effect.inputs, union_names);
580 Self::fixup_stack_type(&mut effect.outputs, union_names);
581 for cap in captures {
582 Self::fixup_type(cap, union_names);
583 }
584 }
585 _ => {}
586 }
587 }
588}
589
590fn parse_type_name(name: &str) -> Type {
593 match name {
594 "Int" => Type::Int,
595 "Float" => Type::Float,
596 "Bool" => Type::Bool,
597 "String" => Type::String,
598 "Channel" => Type::Channel,
599 other => Type::Union(other.to_string()),
600 }
601}
602
603impl Default for Program {
604 fn default() -> Self {
605 Self::new()
606 }
607}