1use crate::types::{Effect, StackType, Type};
6
7use super::{Program, Statement, WordDef};
8
9impl Program {
10 pub fn new() -> Self {
11 Program {
12 includes: Vec::new(),
13 unions: Vec::new(),
14 words: Vec::new(),
15 }
16 }
17
18 pub fn find_word(&self, name: &str) -> Option<&WordDef> {
19 self.words.iter().find(|w| w.name == name)
20 }
21
22 pub fn validate_word_calls(&self) -> Result<(), String> {
24 self.validate_word_calls_with_externals(&[])
25 }
26
27 pub fn validate_word_calls_with_externals(
32 &self,
33 external_words: &[&str],
34 ) -> Result<(), String> {
35 let builtins = [
38 "io.write",
40 "io.write-line",
41 "io.read-line",
42 "io.read-line+",
43 "io.read-n",
44 "int->string",
45 "symbol->string",
46 "string->symbol",
47 "args.count",
49 "args.at",
50 "file.slurp",
52 "file.exists?",
53 "file.for-each-line+",
54 "file.spit",
55 "file.append",
56 "file.delete",
57 "file.size",
58 "dir.exists?",
60 "dir.make",
61 "dir.delete",
62 "dir.list",
63 "string.concat",
65 "string.length",
66 "string.byte-length",
67 "string.char-at",
68 "string.substring",
69 "char->string",
70 "string.find",
71 "string.split",
72 "string.contains",
73 "string.starts-with",
74 "string.empty?",
75 "string.trim",
76 "string.chomp",
77 "string.to-upper",
78 "string.to-lower",
79 "string.equal?",
80 "string.join",
81 "string.json-escape",
82 "string->int",
83 "symbol.=",
85 "encoding.base64-encode",
87 "encoding.base64-decode",
88 "encoding.base64url-encode",
89 "encoding.base64url-decode",
90 "encoding.hex-encode",
91 "encoding.hex-decode",
92 "crypto.sha256",
94 "crypto.hmac-sha256",
95 "crypto.constant-time-eq",
96 "crypto.random-bytes",
97 "crypto.random-int",
98 "crypto.uuid4",
99 "crypto.aes-gcm-encrypt",
100 "crypto.aes-gcm-decrypt",
101 "crypto.pbkdf2-sha256",
102 "crypto.ed25519-keypair",
103 "crypto.ed25519-sign",
104 "crypto.ed25519-verify",
105 "http.get",
107 "http.post",
108 "http.put",
109 "http.delete",
110 "list.make",
112 "list.push",
113 "list.get",
114 "list.set",
115 "list.map",
116 "list.filter",
117 "list.fold",
118 "list.each",
119 "list.length",
120 "list.empty?",
121 "list.reverse",
122 "map.make",
124 "map.get",
125 "map.set",
126 "map.has?",
127 "map.remove",
128 "map.keys",
129 "map.values",
130 "map.size",
131 "map.empty?",
132 "map.each",
133 "map.fold",
134 "variant.field-count",
136 "variant.tag",
137 "variant.field-at",
138 "variant.append",
139 "variant.last",
140 "variant.init",
141 "variant.make-0",
142 "variant.make-1",
143 "variant.make-2",
144 "variant.make-3",
145 "variant.make-4",
146 "wrap-0",
148 "wrap-1",
149 "wrap-2",
150 "wrap-3",
151 "wrap-4",
152 "i.add",
154 "i.subtract",
155 "i.multiply",
156 "i.divide",
157 "i.modulo",
158 "i.+",
160 "i.-",
161 "i.*",
162 "i./",
163 "i.%",
164 "i.=",
166 "i.<",
167 "i.>",
168 "i.<=",
169 "i.>=",
170 "i.<>",
171 "i.eq",
173 "i.lt",
174 "i.gt",
175 "i.lte",
176 "i.gte",
177 "i.neq",
178 "dup",
180 "drop",
181 "swap",
182 "over",
183 "rot",
184 "nip",
185 "tuck",
186 "2dup",
187 "3drop",
188 "pick",
189 "roll",
190 ">aux",
192 "aux>",
193 "and",
195 "or",
196 "not",
197 "band",
199 "bor",
200 "bxor",
201 "bnot",
202 "i.neg",
203 "negate",
204 "+",
206 "-",
207 "*",
208 "/",
209 "%",
210 "=",
211 "<",
212 ">",
213 "<=",
214 ">=",
215 "<>",
216 "shl",
217 "shr",
218 "popcount",
219 "clz",
220 "ctz",
221 "int-bits",
222 "chan.make",
224 "chan.send",
225 "chan.receive",
226 "chan.close",
227 "chan.yield",
228 "call",
230 "dip",
232 "keep",
233 "bi",
234 "strand.spawn",
235 "strand.weave",
236 "strand.resume",
237 "strand.weave-cancel",
238 "yield",
239 "cond",
240 "tcp.listen",
242 "tcp.accept",
243 "tcp.read",
244 "tcp.write",
245 "tcp.close",
246 "os.getenv",
248 "os.home-dir",
249 "os.current-dir",
250 "os.path-exists",
251 "os.path-is-file",
252 "os.path-is-dir",
253 "os.path-join",
254 "os.path-parent",
255 "os.path-filename",
256 "os.exit",
257 "os.name",
258 "os.arch",
259 "signal.trap",
261 "signal.received?",
262 "signal.pending?",
263 "signal.default",
264 "signal.ignore",
265 "signal.clear",
266 "signal.SIGINT",
267 "signal.SIGTERM",
268 "signal.SIGHUP",
269 "signal.SIGPIPE",
270 "signal.SIGUSR1",
271 "signal.SIGUSR2",
272 "signal.SIGCHLD",
273 "signal.SIGALRM",
274 "signal.SIGCONT",
275 "terminal.raw-mode",
277 "terminal.read-char",
278 "terminal.read-char?",
279 "terminal.width",
280 "terminal.height",
281 "terminal.flush",
282 "f.add",
284 "f.subtract",
285 "f.multiply",
286 "f.divide",
287 "f.+",
289 "f.-",
290 "f.*",
291 "f./",
292 "f.=",
294 "f.<",
295 "f.>",
296 "f.<=",
297 "f.>=",
298 "f.<>",
299 "f.eq",
301 "f.lt",
302 "f.gt",
303 "f.lte",
304 "f.gte",
305 "f.neq",
306 "int->float",
308 "float->int",
309 "float->string",
310 "string->float",
311 "test.init",
313 "test.set-name",
314 "test.finish",
315 "test.has-failures",
316 "test.assert",
317 "test.assert-not",
318 "test.assert-eq",
319 "test.assert-eq-str",
320 "test.fail",
321 "test.pass-count",
322 "test.fail-count",
323 "time.now",
325 "time.nanos",
326 "time.sleep-ms",
327 "son.dump",
329 "son.dump-pretty",
330 "stack.dump",
332 "regex.match?",
334 "regex.find",
335 "regex.find-all",
336 "regex.replace",
337 "regex.replace-all",
338 "regex.captures",
339 "regex.split",
340 "regex.valid?",
341 "compress.gzip",
343 "compress.gzip-level",
344 "compress.gunzip",
345 "compress.zstd",
346 "compress.zstd-level",
347 "compress.unzstd",
348 ];
349
350 for word in &self.words {
351 self.validate_statements(&word.body, &word.name, &builtins, external_words)?;
352 }
353
354 Ok(())
355 }
356
357 fn validate_statements(
359 &self,
360 statements: &[Statement],
361 word_name: &str,
362 builtins: &[&str],
363 external_words: &[&str],
364 ) -> Result<(), String> {
365 for statement in statements {
366 match statement {
367 Statement::WordCall { name, .. } => {
368 if builtins.contains(&name.as_str()) {
370 continue;
371 }
372 if self.find_word(name).is_some() {
374 continue;
375 }
376 if external_words.contains(&name.as_str()) {
378 continue;
379 }
380 return Err(format!(
382 "Undefined word '{}' called in word '{}'. \
383 Did you forget to define it or misspell a built-in?",
384 name, word_name
385 ));
386 }
387 Statement::If {
388 then_branch,
389 else_branch,
390 span: _,
391 } => {
392 self.validate_statements(then_branch, word_name, builtins, external_words)?;
394 if let Some(eb) = else_branch {
395 self.validate_statements(eb, word_name, builtins, external_words)?;
396 }
397 }
398 Statement::Quotation { body, .. } => {
399 self.validate_statements(body, word_name, builtins, external_words)?;
401 }
402 Statement::Match { arms, span: _ } => {
403 for arm in arms {
405 self.validate_statements(&arm.body, word_name, builtins, external_words)?;
406 }
407 }
408 _ => {} }
410 }
411 Ok(())
412 }
413
414 pub const MAX_VARIANT_FIELDS: usize = 12;
418
419 pub fn generate_constructors(&mut self) -> Result<(), String> {
432 let mut new_words = Vec::new();
433
434 for union_def in &self.unions {
435 for variant in &union_def.variants {
436 let field_count = variant.fields.len();
437
438 if field_count > Self::MAX_VARIANT_FIELDS {
440 return Err(format!(
441 "Variant '{}' in union '{}' has {} fields, but the maximum is {}. \
442 Consider grouping fields into nested union types.",
443 variant.name,
444 union_def.name,
445 field_count,
446 Self::MAX_VARIANT_FIELDS
447 ));
448 }
449
450 let constructor_name = format!("Make-{}", variant.name);
452 let mut input_stack = StackType::RowVar("a".to_string());
453 for field in &variant.fields {
454 let field_type = parse_type_name(&field.type_name);
455 input_stack = input_stack.push(field_type);
456 }
457 let output_stack =
458 StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
459 let effect = Effect::new(input_stack, output_stack);
460 let body = vec![
461 Statement::Symbol(variant.name.clone()),
462 Statement::WordCall {
463 name: format!("variant.make-{}", field_count),
464 span: None,
465 },
466 ];
467 new_words.push(WordDef {
468 name: constructor_name,
469 effect: Some(effect),
470 body,
471 source: variant.source.clone(),
472 allowed_lints: vec![],
473 });
474
475 let predicate_name = format!("is-{}?", variant.name);
479 let predicate_input =
480 StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
481 let predicate_output = StackType::RowVar("a".to_string()).push(Type::Bool);
482 let predicate_effect = Effect::new(predicate_input, predicate_output);
483 let predicate_body = vec![
484 Statement::WordCall {
485 name: "variant.tag".to_string(),
486 span: None,
487 },
488 Statement::Symbol(variant.name.clone()),
489 Statement::WordCall {
490 name: "symbol.=".to_string(),
491 span: None,
492 },
493 ];
494 new_words.push(WordDef {
495 name: predicate_name,
496 effect: Some(predicate_effect),
497 body: predicate_body,
498 source: variant.source.clone(),
499 allowed_lints: vec![],
500 });
501
502 for (index, field) in variant.fields.iter().enumerate() {
506 let accessor_name = format!("{}-{}", variant.name, field.name);
507 let field_type = parse_type_name(&field.type_name);
508 let accessor_input = StackType::RowVar("a".to_string())
509 .push(Type::Union(union_def.name.clone()));
510 let accessor_output = StackType::RowVar("a".to_string()).push(field_type);
511 let accessor_effect = Effect::new(accessor_input, accessor_output);
512 let accessor_body = vec![
513 Statement::IntLiteral(index as i64),
514 Statement::WordCall {
515 name: "variant.field-at".to_string(),
516 span: None,
517 },
518 ];
519 new_words.push(WordDef {
520 name: accessor_name,
521 effect: Some(accessor_effect),
522 body: accessor_body,
523 source: variant.source.clone(), allowed_lints: vec![],
525 });
526 }
527 }
528 }
529
530 self.words.extend(new_words);
531 Ok(())
532 }
533
534 pub fn fixup_union_types(&mut self) {
543 let union_names: std::collections::HashSet<String> =
545 self.unions.iter().map(|u| u.name.clone()).collect();
546
547 for word in &mut self.words {
549 if let Some(ref mut effect) = word.effect {
550 Self::fixup_stack_type(&mut effect.inputs, &union_names);
551 Self::fixup_stack_type(&mut effect.outputs, &union_names);
552 }
553 }
554 }
555
556 fn fixup_stack_type(stack: &mut StackType, union_names: &std::collections::HashSet<String>) {
558 match stack {
559 StackType::Empty | StackType::RowVar(_) => {}
560 StackType::Cons { rest, top } => {
561 Self::fixup_type(top, union_names);
562 Self::fixup_stack_type(rest, union_names);
563 }
564 }
565 }
566
567 fn fixup_type(ty: &mut Type, union_names: &std::collections::HashSet<String>) {
569 match ty {
570 Type::Var(name) if union_names.contains(name) => {
571 *ty = Type::Union(name.clone());
572 }
573 Type::Quotation(effect) => {
574 Self::fixup_stack_type(&mut effect.inputs, union_names);
575 Self::fixup_stack_type(&mut effect.outputs, union_names);
576 }
577 Type::Closure { effect, captures } => {
578 Self::fixup_stack_type(&mut effect.inputs, union_names);
579 Self::fixup_stack_type(&mut effect.outputs, union_names);
580 for cap in captures {
581 Self::fixup_type(cap, union_names);
582 }
583 }
584 _ => {}
585 }
586 }
587}
588
589fn parse_type_name(name: &str) -> Type {
592 match name {
593 "Int" => Type::Int,
594 "Float" => Type::Float,
595 "Bool" => Type::Bool,
596 "String" => Type::String,
597 "Channel" => Type::Channel,
598 other => Type::Union(other.to_string()),
599 }
600}
601
602impl Default for Program {
603 fn default() -> Self {
604 Self::new()
605 }
606}