1use crate::types::{Effect, StackType, Type};
7use std::path::PathBuf;
8
/// A position in a source file: the file path plus a start and end line.
///
/// The `Display` implementation prints each stored line number plus one, so
/// stored values are treated as zero-based when rendered.
#[derive(Debug, Clone, PartialEq)]
pub struct SourceLocation {
    /// Path of the file this location refers to.
    pub file: PathBuf,
    /// First line of the location.
    pub start_line: usize,
    /// Last line of the location; equals `start_line` for a single-line location.
    pub end_line: usize,
}

impl SourceLocation {
    /// Creates a location covering the single line `line` of `file`
    /// (`start_line` and `end_line` are both set to `line`).
    pub fn new(file: PathBuf, line: usize) -> Self {
        Self::span(file, line, line)
    }

    /// Creates a location running from `start_line` to `end_line` of `file`.
    ///
    /// # Panics
    /// In builds with debug assertions enabled, panics when
    /// `start_line > end_line`. Other builds store the values unchecked.
    pub fn span(file: PathBuf, start_line: usize, end_line: usize) -> Self {
        debug_assert!(
            start_line <= end_line,
            "SourceLocation: start_line ({}) must be <= end_line ({})",
            start_line,
            end_line
        );
        Self {
            file,
            start_line,
            end_line,
        }
    }

    /// Returns the first line of the location (`start_line`).
    pub fn line(&self) -> usize {
        self.start_line
    }
}

impl std::fmt::Display for SourceLocation {
    /// Writes `path:line` for a single-line location and `path:start-end`
    /// otherwise, adding one to every stored line number.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let path = self.file.display();
        let first = self.start_line + 1;
        if self.end_line == self.start_line {
            return write!(f, "{}:{}", path, first);
        }
        write!(f, "{}:{}-{}", path, first, self.end_line + 1)
    }
}
65
/// An include directive, tagged by kind; every variant carries its target as a string.
// NOTE(review): the per-variant meanings below are inferred from the variant
// names only — confirm against the parser/resolver.
#[derive(Debug, Clone, PartialEq)]
pub enum Include {
    /// Presumably an include of a standard-library module.
    Std(String),
    /// Presumably an include of a file located relative to the including file.
    Relative(String),
    /// Presumably an include of a foreign-function-interface (FFI) binding.
    Ffi(String),
}
76
/// One field of a union variant: a field name plus the name of its type.
#[derive(Debug, Clone, PartialEq)]
pub struct UnionField {
    /// Field name.
    pub name: String,
    /// Type name as written in source; `parse_type_name` turns it into a
    /// `Type` when constructors are generated.
    pub type_name: String,
}
88
/// One variant of a union definition.
#[derive(Debug, Clone, PartialEq)]
pub struct UnionVariant {
    /// Variant name; the generated constructor word is named `Make-<name>`.
    pub name: String,
    /// Fields of the variant, in declaration order.
    pub fields: Vec<UnionField>,
    /// Where the variant was defined, when known.
    pub source: Option<SourceLocation>,
}
97
/// A union type definition: a named set of variants.
#[derive(Debug, Clone, PartialEq)]
pub struct UnionDef {
    /// Name of the union type.
    pub name: String,
    /// The union's variants.
    pub variants: Vec<UnionVariant>,
    /// Where the union was defined, when known.
    pub source: Option<SourceLocation>,
}
113
/// The pattern of a match arm.
#[derive(Debug, Clone, PartialEq)]
pub enum Pattern {
    /// A variant referenced by name only.
    Variant(String),

    /// A variant referenced by name together with a list of binding names.
    // NOTE(review): presumably the bindings name the variant's fields inside
    // the arm body — confirm against the parser/type checker.
    VariantWithBindings { name: String, bindings: Vec<String> },
}
127
/// One arm of a `Statement::Match`: a pattern and the statements attached to it.
#[derive(Debug, Clone, PartialEq)]
pub struct MatchArm {
    /// Pattern selecting this arm.
    pub pattern: Pattern,
    /// Statements belonging to this arm.
    pub body: Vec<Statement>,
}
134
/// A whole program: its include directives, union definitions, and word definitions.
#[derive(Debug, Clone, PartialEq)]
pub struct Program {
    /// Include directives.
    pub includes: Vec<Include>,
    /// Union type definitions.
    pub unions: Vec<UnionDef>,
    /// Word definitions; `generate_constructors` appends the generated
    /// `Make-<Variant>` words here.
    pub words: Vec<WordDef>,
}
141
/// A word definition: a named sequence of statements.
#[derive(Debug, Clone, PartialEq)]
pub struct WordDef {
    /// Name of the word.
    pub name: String,
    /// Stack effect of the word, if one is present.
    pub effect: Option<Effect>,
    /// Statements making up the word's body.
    pub body: Vec<Statement>,
    /// Where the word was defined, when known.
    pub source: Option<SourceLocation>,
    /// NOTE(review): presumably the names of lints suppressed for this word —
    /// confirm against the linter.
    pub allowed_lints: Vec<String>,
}
155
/// Position of a piece of source text: a line, a column, and a length.
// NOTE(review): whether line/column are zero- or one-based, and the unit of
// `length`, are not established in this file — confirm with the parser.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Span {
    /// Line of the spanned text.
    pub line: usize,
    /// Column at which the spanned text starts.
    pub column: usize,
    /// Length of the spanned text.
    pub length: usize,
}

impl Span {
    /// Creates a span from its three components.
    pub fn new(line: usize, column: usize, length: usize) -> Self {
        Self { line, column, length }
    }
}
176
/// Region of source text occupied by a quotation, given as a start and an end
/// line/column pair.
///
/// As implemented by [`QuotationSpan::contains`], the start position is
/// inclusive and the end column is exclusive.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct QuotationSpan {
    /// Line on which the region starts.
    pub start_line: usize,
    /// Column on `start_line` at which the region starts (inclusive).
    pub start_column: usize,
    /// Line on which the region ends.
    pub end_line: usize,
    /// Column on `end_line` at which the region ends (exclusive).
    pub end_column: usize,
}

impl QuotationSpan {
    /// Creates a span from its start and end positions.
    pub fn new(start_line: usize, start_column: usize, end_line: usize, end_column: usize) -> Self {
        Self {
            start_line,
            start_column,
            end_line,
            end_column,
        }
    }

    /// Returns `true` when the position (`line`, `column`) lies inside the span.
    ///
    /// A position is inside when its line is within `start_line..=end_line`,
    /// it is not before `start_column` on the first line, and it is strictly
    /// before `end_column` on the last line.
    pub fn contains(&self, line: usize, column: usize) -> bool {
        let within_lines = (self.start_line..=self.end_line).contains(&line);
        let before_start = line == self.start_line && column < self.start_column;
        let at_or_past_end = line == self.end_line && column >= self.end_column;
        within_lines && !before_start && !at_or_past_end
    }
}
214
/// One statement inside a word body, branch, quotation, or match arm.
#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
    /// Integer literal.
    IntLiteral(i64),

    /// Floating-point literal.
    FloatLiteral(f64),

    /// Boolean literal.
    BoolLiteral(bool),

    /// String literal.
    StringLiteral(String),

    /// Symbol literal. Generated variant constructors emit the variant name
    /// as a symbol immediately before the `variant.make-N` call.
    Symbol(String),

    /// Call to the word called `name`. `span` is the position of the call
    /// when known; generated constructor bodies use `None`.
    WordCall { name: String, span: Option<Span> },

    /// Conditional with a mandatory `then` branch and an optional `else` branch.
    // NOTE(review): no condition expression is stored — presumably the
    // condition is taken from the stack at run time; confirm.
    If {
        /// Statements of the `then` branch.
        then_branch: Vec<Statement>,
        /// Statements of the `else` branch, if there is one.
        else_branch: Option<Vec<Statement>>,
    },

    /// A nested block of statements.
    Quotation {
        /// NOTE(review): presumably a unique identifier assigned to the
        /// quotation by the parser — confirm.
        id: usize,
        /// Statements inside the quotation.
        body: Vec<Statement>,
        /// Source region of the quotation, when known.
        span: Option<QuotationSpan>,
    },

    /// Pattern match made up of arms.
    // NOTE(review): no scrutinee is stored — presumably the matched value is
    // taken from the stack at run time; confirm.
    Match {
        /// The arms of the match.
        arms: Vec<MatchArm>,
    },
}
283
284impl Program {
285 pub fn new() -> Self {
286 Program {
287 includes: Vec::new(),
288 unions: Vec::new(),
289 words: Vec::new(),
290 }
291 }
292
293 pub fn find_union(&self, name: &str) -> Option<&UnionDef> {
295 self.unions.iter().find(|u| u.name == name)
296 }
297
298 pub fn find_word(&self, name: &str) -> Option<&WordDef> {
299 self.words.iter().find(|w| w.name == name)
300 }
301
302 pub fn validate_word_calls(&self) -> Result<(), String> {
304 self.validate_word_calls_with_externals(&[])
305 }
306
    /// Checks that every word call in every word of the program resolves to
    /// a built-in listed below, a word defined in this program, or one of
    /// `external_words`.
    ///
    /// Words are checked in definition order and checking stops at the first
    /// unresolved call.
    ///
    /// # Errors
    /// Returns a message naming the unresolved word and the word whose body
    /// contains the call.
    pub fn validate_word_calls_with_externals(
        &self,
        external_words: &[&str],
    ) -> Result<(), String> {
        // Names accepted as built-ins. The group labels below only reflect the
        // name prefixes; they carry no meaning for the lookup itself.
        let builtins = [
            // io.* and basic conversions
            "io.write",
            "io.write-line",
            "io.read-line",
            "io.read-line+",
            "io.read-n",
            "int->string",
            "symbol->string",
            "string->symbol",
            // args.*
            "args.count",
            "args.at",
            // file.*
            "file.slurp",
            "file.exists?",
            "file.for-each-line+",
            // string.* and related
            "string.concat",
            "string.length",
            "string.byte-length",
            "string.char-at",
            "string.substring",
            "char->string",
            "string.find",
            "string.split",
            "string.contains",
            "string.starts-with",
            "string.empty?",
            "string.trim",
            "string.chomp",
            "string.to-upper",
            "string.to-lower",
            "string.equal?",
            "string.json-escape",
            "string->int",
            "symbol.=",
            // encoding.*
            "encoding.base64-encode",
            "encoding.base64-decode",
            "encoding.base64url-encode",
            "encoding.base64url-decode",
            "encoding.hex-encode",
            "encoding.hex-decode",
            // crypto.*
            "crypto.sha256",
            "crypto.hmac-sha256",
            "crypto.constant-time-eq",
            "crypto.random-bytes",
            "crypto.random-int",
            "crypto.uuid4",
            "crypto.aes-gcm-encrypt",
            "crypto.aes-gcm-decrypt",
            "crypto.pbkdf2-sha256",
            "crypto.ed25519-keypair",
            "crypto.ed25519-sign",
            "crypto.ed25519-verify",
            // http.*
            "http.get",
            "http.post",
            "http.put",
            "http.delete",
            // list.*
            "list.make",
            "list.push",
            "list.get",
            "list.set",
            "list.map",
            "list.filter",
            "list.fold",
            "list.each",
            "list.length",
            "list.empty?",
            // map.*
            "map.make",
            "map.get",
            "map.set",
            "map.has?",
            "map.remove",
            "map.keys",
            "map.values",
            "map.size",
            "map.empty?",
            // variant.* (variant.make-N is what generated constructors call)
            "variant.field-count",
            "variant.tag",
            "variant.field-at",
            "variant.append",
            "variant.last",
            "variant.init",
            "variant.make-0",
            "variant.make-1",
            "variant.make-2",
            "variant.make-3",
            "variant.make-4",
            // wrap-N
            "wrap-0",
            "wrap-1",
            "wrap-2",
            "wrap-3",
            "wrap-4",
            // i.* arithmetic, named form
            "i.add",
            "i.subtract",
            "i.multiply",
            "i.divide",
            "i.modulo",
            // i.* arithmetic, symbolic form
            "i.+",
            "i.-",
            "i.*",
            "i./",
            "i.%",
            // i.* comparison, symbolic form
            "i.=",
            "i.<",
            "i.>",
            "i.<=",
            "i.>=",
            "i.<>",
            // i.* comparison, named form
            "i.eq",
            "i.lt",
            "i.gt",
            "i.lte",
            "i.gte",
            "i.neq",
            // stack shuffling
            "dup",
            "drop",
            "swap",
            "over",
            "rot",
            "nip",
            "tuck",
            "2dup",
            "3drop",
            "pick",
            "roll",
            // boolean
            "and",
            "or",
            "not",
            // bitwise
            "band",
            "bor",
            "bxor",
            "bnot",
            "shl",
            "shr",
            "popcount",
            "clz",
            "ctz",
            "int-bits",
            // chan.*
            "chan.make",
            "chan.send",
            "chan.receive",
            "chan.close",
            "chan.yield",
            // call, strand.*, yield, cond
            "call",
            "strand.spawn",
            "strand.weave",
            "strand.resume",
            "strand.weave-cancel",
            "yield",
            "cond",
            // tcp.*
            "tcp.listen",
            "tcp.accept",
            "tcp.read",
            "tcp.write",
            "tcp.close",
            // os.*
            "os.getenv",
            "os.home-dir",
            "os.current-dir",
            "os.path-exists",
            "os.path-is-file",
            "os.path-is-dir",
            "os.path-join",
            "os.path-parent",
            "os.path-filename",
            "os.exit",
            "os.name",
            "os.arch",
            // terminal.*
            "terminal.raw-mode",
            "terminal.read-char",
            "terminal.read-char?",
            "terminal.width",
            "terminal.height",
            "terminal.flush",
            // f.* arithmetic, named form
            "f.add",
            "f.subtract",
            "f.multiply",
            "f.divide",
            // f.* arithmetic, symbolic form
            "f.+",
            "f.-",
            "f.*",
            "f./",
            // f.* comparison, symbolic form
            "f.=",
            "f.<",
            "f.>",
            "f.<=",
            "f.>=",
            "f.<>",
            // f.* comparison, named form
            "f.eq",
            "f.lt",
            "f.gt",
            "f.lte",
            "f.gte",
            "f.neq",
            // float conversions
            "int->float",
            "float->int",
            "float->string",
            "string->float",
            // test.*
            "test.init",
            "test.finish",
            "test.has-failures",
            "test.assert",
            "test.assert-not",
            "test.assert-eq",
            "test.assert-eq-str",
            "test.fail",
            "test.pass-count",
            "test.fail-count",
            // time.*
            "time.now",
            "time.nanos",
            "time.sleep-ms",
            // son.*
            "son.dump",
            "son.dump-pretty",
            // stack.*
            "stack.dump",
            // regex.*
            "regex.match?",
            "regex.find",
            "regex.find-all",
            "regex.replace",
            "regex.replace-all",
            "regex.captures",
            "regex.split",
            "regex.valid?",
            // compress.*
            "compress.gzip",
            "compress.gzip-level",
            "compress.gunzip",
            "compress.zstd",
            "compress.zstd-level",
            "compress.unzstd",
        ];

        // Validate each word body; `?` stops at the first unresolved call.
        for word in &self.words {
            self.validate_statements(&word.body, &word.name, &builtins, external_words)?;
        }

        Ok(())
    }
585
586 fn validate_statements(
588 &self,
589 statements: &[Statement],
590 word_name: &str,
591 builtins: &[&str],
592 external_words: &[&str],
593 ) -> Result<(), String> {
594 for statement in statements {
595 match statement {
596 Statement::WordCall { name, .. } => {
597 if builtins.contains(&name.as_str()) {
599 continue;
600 }
601 if self.find_word(name).is_some() {
603 continue;
604 }
605 if external_words.contains(&name.as_str()) {
607 continue;
608 }
609 return Err(format!(
611 "Undefined word '{}' called in word '{}'. \
612 Did you forget to define it or misspell a built-in?",
613 name, word_name
614 ));
615 }
616 Statement::If {
617 then_branch,
618 else_branch,
619 } => {
620 self.validate_statements(then_branch, word_name, builtins, external_words)?;
622 if let Some(eb) = else_branch {
623 self.validate_statements(eb, word_name, builtins, external_words)?;
624 }
625 }
626 Statement::Quotation { body, .. } => {
627 self.validate_statements(body, word_name, builtins, external_words)?;
629 }
630 Statement::Match { arms } => {
631 for arm in arms {
633 self.validate_statements(&arm.body, word_name, builtins, external_words)?;
634 }
635 }
636 _ => {} }
638 }
639 Ok(())
640 }
641
    /// Maximum number of fields a union variant may have.
    /// `generate_constructors` returns an error for any variant with more
    /// fields; the limit matches the highest `variant.make-N` built-in that
    /// this file lists (`variant.make-4`).
    pub const MAX_VARIANT_FIELDS: usize = 4;
646
647 pub fn generate_constructors(&mut self) -> Result<(), String> {
657 let mut new_words = Vec::new();
658
659 for union_def in &self.unions {
660 for variant in &union_def.variants {
661 let constructor_name = format!("Make-{}", variant.name);
662 let field_count = variant.fields.len();
663
664 if field_count > Self::MAX_VARIANT_FIELDS {
666 return Err(format!(
667 "Variant '{}' in union '{}' has {} fields, but the maximum is {}. \
668 Consider grouping fields into nested union types.",
669 variant.name,
670 union_def.name,
671 field_count,
672 Self::MAX_VARIANT_FIELDS
673 ));
674 }
675
676 let mut input_stack = StackType::RowVar("a".to_string());
679 for field in &variant.fields {
680 let field_type = parse_type_name(&field.type_name);
681 input_stack = input_stack.push(field_type);
682 }
683
684 let output_stack =
686 StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
687
688 let effect = Effect::new(input_stack, output_stack);
689
690 let body = vec![
694 Statement::Symbol(variant.name.clone()),
695 Statement::WordCall {
696 name: format!("variant.make-{}", field_count),
697 span: None, },
699 ];
700
701 new_words.push(WordDef {
702 name: constructor_name,
703 effect: Some(effect),
704 body,
705 source: variant.source.clone(),
706 allowed_lints: vec![],
707 });
708 }
709 }
710
711 self.words.extend(new_words);
712 Ok(())
713 }
714}
715
716fn parse_type_name(name: &str) -> Type {
719 match name {
720 "Int" => Type::Int,
721 "Float" => Type::Float,
722 "Bool" => Type::Bool,
723 "String" => Type::String,
724 "Channel" => Type::Channel,
725 other => Type::Union(other.to_string()),
726 }
727}
728
impl Default for Program {
    /// Returns the same empty program as `Program::new`.
    fn default() -> Self {
        Self::new()
    }
}
734
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a call to `name` with no span.
    fn call(name: &str) -> Statement {
        Statement::WordCall {
            name: name.to_string(),
            span: None,
        }
    }

    /// Builds a word with the given body and no effect, source, or allowed lints.
    fn word(name: &str, body: Vec<Statement>) -> WordDef {
        WordDef {
            name: name.to_string(),
            effect: None,
            body,
            source: None,
            allowed_lints: vec![],
        }
    }

    /// Builds a program that contains only the given words.
    fn program_with_words(words: Vec<WordDef>) -> Program {
        Program {
            includes: vec![],
            unions: vec![],
            words,
        }
    }

    /// Builds a union field from a name and a type name.
    fn field(name: &str, type_name: &str) -> UnionField {
        UnionField {
            name: name.to_string(),
            type_name: type_name.to_string(),
        }
    }

    #[test]
    fn test_validate_builtin_words() {
        let program = program_with_words(vec![word(
            "main",
            vec![
                Statement::IntLiteral(2),
                Statement::IntLiteral(3),
                call("i.add"),
                call("io.write-line"),
            ],
        )]);

        // Calls to built-ins must be accepted.
        assert!(program.validate_word_calls().is_ok());
    }

    #[test]
    fn test_validate_user_defined_words() {
        let program = program_with_words(vec![
            word("helper", vec![Statement::IntLiteral(42)]),
            word("main", vec![call("helper")]),
        ]);

        // Calls to words defined in the same program must be accepted.
        assert!(program.validate_word_calls().is_ok());
    }

    #[test]
    fn test_validate_undefined_word() {
        let program = program_with_words(vec![word("main", vec![call("undefined_word")])]);

        // The error must name both the missing word and the calling word.
        let result = program.validate_word_calls();
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(error.contains("undefined_word"));
        assert!(error.contains("main"));
    }

    #[test]
    fn test_validate_misspelled_builtin() {
        let program = program_with_words(vec![word("main", vec![call("wrte_line")])]);

        // The error must name the unknown word and hint at a misspelling.
        let result = program.validate_word_calls();
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert!(error.contains("wrte_line"));
        assert!(error.contains("misspell"));
    }

    #[test]
    fn test_generate_constructors() {
        let mut program = Program {
            includes: vec![],
            unions: vec![UnionDef {
                name: "Message".to_string(),
                variants: vec![
                    UnionVariant {
                        name: "Get".to_string(),
                        fields: vec![field("response-chan", "Int")],
                        source: None,
                    },
                    UnionVariant {
                        name: "Put".to_string(),
                        fields: vec![field("value", "String"), field("response-chan", "Int")],
                        source: None,
                    },
                ],
                source: None,
            }],
            words: vec![],
        };

        program.generate_constructors().unwrap();

        // One constructor per variant.
        assert_eq!(program.words.len(), 2);

        let make_get = program
            .find_word("Make-Get")
            .expect("Make-Get should exist");
        assert_eq!(make_get.name, "Make-Get");
        assert!(make_get.effect.is_some());
        let effect = make_get.effect.as_ref().unwrap();
        // The constructor leaves the union type on top of the row variable.
        assert_eq!(
            format!("{:?}", effect.outputs),
            "Cons { rest: RowVar(\"a\"), top: Union(\"Message\") }"
        );

        let make_put = program
            .find_word("Make-Put")
            .expect("Make-Put should exist");
        assert_eq!(make_put.name, "Make-Put");
        assert!(make_put.effect.is_some());

        // Make-Get body: the tag symbol, then variant.make-1 (one field).
        assert_eq!(make_get.body.len(), 2);
        match &make_get.body[0] {
            Statement::Symbol(s) if s == "Get" => {}
            other => panic!("Expected Symbol(\"Get\") for variant tag, got {:?}", other),
        }
        match &make_get.body[1] {
            Statement::WordCall { name, span: None } if name == "variant.make-1" => {}
            _ => panic!("Expected WordCall(variant.make-1)"),
        }

        // Make-Put body: the tag symbol, then variant.make-2 (two fields).
        assert_eq!(make_put.body.len(), 2);
        match &make_put.body[0] {
            Statement::Symbol(s) if s == "Put" => {}
            other => panic!("Expected Symbol(\"Put\") for variant tag, got {:?}", other),
        }
        match &make_put.body[1] {
            Statement::WordCall { name, span: None } if name == "variant.make-2" => {}
            _ => panic!("Expected WordCall(variant.make-2)"),
        }
    }
}