Skip to main content

seqc/ast/
program.rs

1//! Program-level AST methods: word-call validation, auto-generated variant
2//! constructors (`Make-Variant`), and type fix-up for union types declared
3//! in stack effects.
4
5use crate::types::{Effect, StackType, Type};
6
7use super::{Program, Statement, WordDef};
8
9impl Program {
10    pub fn new() -> Self {
11        Program {
12            includes: Vec::new(),
13            unions: Vec::new(),
14            words: Vec::new(),
15        }
16    }
17
18    pub fn find_word(&self, name: &str) -> Option<&WordDef> {
19        self.words.iter().find(|w| w.name == name)
20    }
21
22    /// Validate that all word calls reference either a defined word or a built-in
23    pub fn validate_word_calls(&self) -> Result<(), String> {
24        self.validate_word_calls_with_externals(&[])
25    }
26
27    /// Validate that all word calls reference a defined word, built-in, or external word.
28    ///
29    /// The `external_words` parameter should contain names of words available from
30    /// external sources (e.g., included modules) that should be considered valid.
31    pub fn validate_word_calls_with_externals(
32        &self,
33        external_words: &[&str],
34    ) -> Result<(), String> {
35        // List of known runtime built-ins
36        // IMPORTANT: Keep this in sync with codegen.rs WordCall matching
37        let builtins = [
38            // I/O operations
39            "io.write",
40            "io.write-line",
41            "io.read-line",
42            "io.read-line+",
43            "io.read-n",
44            "int->string",
45            "symbol->string",
46            "string->symbol",
47            // Command-line arguments
48            "args.count",
49            "args.at",
50            // File operations
51            "file.slurp",
52            "file.exists?",
53            "file.for-each-line+",
54            "file.spit",
55            "file.append",
56            "file.delete",
57            "file.size",
58            // Directory operations
59            "dir.exists?",
60            "dir.make",
61            "dir.delete",
62            "dir.list",
63            // String operations
64            "string.concat",
65            "string.length",
66            "string.byte-length",
67            "string.char-at",
68            "string.substring",
69            "char->string",
70            "string.find",
71            "string.split",
72            "string.contains",
73            "string.starts-with",
74            "string.empty?",
75            "string.trim",
76            "string.chomp",
77            "string.to-upper",
78            "string.to-lower",
79            "string.equal?",
80            "string.join",
81            "string.json-escape",
82            "string->int",
83            // Symbol operations
84            "symbol.=",
85            // Encoding operations
86            "encoding.base64-encode",
87            "encoding.base64-decode",
88            "encoding.base64url-encode",
89            "encoding.base64url-decode",
90            "encoding.hex-encode",
91            "encoding.hex-decode",
92            // Crypto operations
93            "crypto.sha256",
94            "crypto.hmac-sha256",
95            "crypto.constant-time-eq",
96            "crypto.random-bytes",
97            "crypto.random-int",
98            "crypto.uuid4",
99            "crypto.aes-gcm-encrypt",
100            "crypto.aes-gcm-decrypt",
101            "crypto.pbkdf2-sha256",
102            "crypto.ed25519-keypair",
103            "crypto.ed25519-sign",
104            "crypto.ed25519-verify",
105            // HTTP client operations
106            "http.get",
107            "http.post",
108            "http.put",
109            "http.delete",
110            // List operations
111            "list.make",
112            "list.push",
113            "list.get",
114            "list.set",
115            "list.map",
116            "list.filter",
117            "list.fold",
118            "list.each",
119            "list.length",
120            "list.empty?",
121            "list.reverse",
122            "list.first",
123            "list.last",
124            // Map operations
125            "map.make",
126            "map.get",
127            "map.set",
128            "map.has?",
129            "map.remove",
130            "map.keys",
131            "map.values",
132            "map.size",
133            "map.empty?",
134            "map.each",
135            "map.fold",
136            // Variant operations
137            "variant.field-count",
138            "variant.tag",
139            "variant.field-at",
140            "variant.append",
141            "variant.first",
142            "variant.last",
143            "variant.init",
144            "variant.make-0",
145            "variant.make-1",
146            "variant.make-2",
147            "variant.make-3",
148            "variant.make-4",
149            // SON wrap aliases
150            "wrap-0",
151            "wrap-1",
152            "wrap-2",
153            "wrap-3",
154            "wrap-4",
155            // Integer arithmetic operations
156            "i.add",
157            "i.subtract",
158            "i.multiply",
159            "i.divide",
160            "i.modulo",
161            // Terse integer arithmetic
162            "i.+",
163            "i.-",
164            "i.*",
165            "i./",
166            "i.%",
167            // Integer comparison operations (return 0 or 1)
168            "i.=",
169            "i.<",
170            "i.>",
171            "i.<=",
172            "i.>=",
173            "i.<>",
174            // Integer comparison operations (verbose form)
175            "i.eq",
176            "i.lt",
177            "i.gt",
178            "i.lte",
179            "i.gte",
180            "i.neq",
181            // Stack operations (simple - no parameters)
182            "dup",
183            "drop",
184            "swap",
185            "over",
186            "rot",
187            "nip",
188            "tuck",
189            "2dup",
190            "3drop",
191            "pick",
192            "roll",
193            // Aux stack operations
194            ">aux",
195            "aux>",
196            // Boolean operations
197            "and",
198            "or",
199            "not",
200            // Bitwise operations
201            "band",
202            "bor",
203            "bxor",
204            "bnot",
205            "i.neg",
206            "negate",
207            // Arithmetic sugar (resolved to concrete ops by typechecker)
208            "+",
209            "-",
210            "*",
211            "/",
212            "%",
213            "=",
214            "<",
215            ">",
216            "<=",
217            ">=",
218            "<>",
219            "shl",
220            "shr",
221            "popcount",
222            "clz",
223            "ctz",
224            "int-bits",
225            // Channel operations
226            "chan.make",
227            "chan.send",
228            "chan.receive",
229            "chan.close",
230            "chan.yield",
231            // Quotation operations
232            "call",
233            // Dataflow combinators
234            "dip",
235            "keep",
236            "bi",
237            "if",
238            "strand.spawn",
239            "strand.weave",
240            "strand.resume",
241            "strand.weave-cancel",
242            "yield",
243            "cond",
244            // TCP operations
245            "tcp.listen",
246            "tcp.accept",
247            "tcp.read",
248            "tcp.write",
249            "tcp.close",
250            // UDP operations
251            "udp.bind",
252            "udp.send-to",
253            "udp.receive-from",
254            "udp.close",
255            // OS operations
256            "os.getenv",
257            "os.home-dir",
258            "os.current-dir",
259            "os.path-exists",
260            "os.path-is-file",
261            "os.path-is-dir",
262            "os.path-join",
263            "os.path-parent",
264            "os.path-filename",
265            "os.exit",
266            "os.name",
267            "os.arch",
268            // Signal handling
269            "signal.trap",
270            "signal.received?",
271            "signal.pending?",
272            "signal.default",
273            "signal.ignore",
274            "signal.clear",
275            "signal.SIGINT",
276            "signal.SIGTERM",
277            "signal.SIGHUP",
278            "signal.SIGPIPE",
279            "signal.SIGUSR1",
280            "signal.SIGUSR2",
281            "signal.SIGCHLD",
282            "signal.SIGALRM",
283            "signal.SIGCONT",
284            // Terminal operations
285            "terminal.raw-mode",
286            "terminal.read-char",
287            "terminal.read-char?",
288            "terminal.width",
289            "terminal.height",
290            "terminal.flush",
291            // Float arithmetic operations (verbose form)
292            "f.add",
293            "f.subtract",
294            "f.multiply",
295            "f.divide",
296            // Float arithmetic operations (terse form)
297            "f.+",
298            "f.-",
299            "f.*",
300            "f./",
301            // Float comparison operations (symbol form)
302            "f.=",
303            "f.<",
304            "f.>",
305            "f.<=",
306            "f.>=",
307            "f.<>",
308            // Float comparison operations (verbose form)
309            "f.eq",
310            "f.lt",
311            "f.gt",
312            "f.lte",
313            "f.gte",
314            "f.neq",
315            // Type conversions
316            "int->float",
317            "float->int",
318            "float->string",
319            "string->float",
320            // Byte construction (binary protocol encoders)
321            "int.to-bytes-i32-be",
322            "float.to-bytes-f32-be",
323            // Test framework operations
324            "test.init",
325            "test.set-name",
326            "test.finish",
327            "test.has-failures",
328            "test.assert",
329            "test.assert-not",
330            "test.assert-eq",
331            "test.assert-eq-str",
332            "test.fail",
333            "test.pass-count",
334            "test.fail-count",
335            // Time operations
336            "time.now",
337            "time.nanos",
338            "time.sleep-ms",
339            // SON serialization
340            "son.dump",
341            "son.dump-pretty",
342            // Stack introspection (for REPL)
343            "stack.dump",
344            // Regex operations
345            "regex.match?",
346            "regex.find",
347            "regex.find-all",
348            "regex.replace",
349            "regex.replace-all",
350            "regex.captures",
351            "regex.split",
352            "regex.valid?",
353            // Compression operations
354            "compress.gzip",
355            "compress.gzip-level",
356            "compress.gunzip",
357            "compress.zstd",
358            "compress.zstd-level",
359            "compress.unzstd",
360        ];
361
362        for word in &self.words {
363            self.validate_statements(&word.body, &word.name, &builtins, external_words)?;
364        }
365
366        Ok(())
367    }
368
369    /// Helper to validate word calls in a list of statements (recursively)
370    fn validate_statements(
371        &self,
372        statements: &[Statement],
373        word_name: &str,
374        builtins: &[&str],
375        external_words: &[&str],
376    ) -> Result<(), String> {
377        for statement in statements {
378            match statement {
379                Statement::WordCall { name, .. } => {
380                    // Check if it's a built-in
381                    if builtins.contains(&name.as_str()) {
382                        continue;
383                    }
384                    // Check if it's a user-defined word
385                    if self.find_word(name).is_some() {
386                        continue;
387                    }
388                    // Check if it's an external word (from includes)
389                    if external_words.contains(&name.as_str()) {
390                        continue;
391                    }
392                    // Undefined word!
393                    return Err(format!(
394                        "Undefined word '{}' called in word '{}'. \
395                         Did you forget to define it or misspell a built-in?",
396                        name, word_name
397                    ));
398                }
399                Statement::If {
400                    then_branch,
401                    else_branch,
402                    span: _,
403                } => {
404                    // Recursively validate both branches
405                    self.validate_statements(then_branch, word_name, builtins, external_words)?;
406                    if let Some(eb) = else_branch {
407                        self.validate_statements(eb, word_name, builtins, external_words)?;
408                    }
409                }
410                Statement::Quotation { body, .. } => {
411                    // Recursively validate quotation body
412                    self.validate_statements(body, word_name, builtins, external_words)?;
413                }
414                Statement::Match { arms, span: _ } => {
415                    // Recursively validate each match arm's body
416                    for arm in arms {
417                        self.validate_statements(&arm.body, word_name, builtins, external_words)?;
418                    }
419                }
420                _ => {} // Literals don't need validation
421            }
422        }
423        Ok(())
424    }
425
    /// Maximum number of fields a variant can have (limited by runtime support).
    ///
    /// `generate_constructors` returns an error for any variant that declares
    /// more fields than this.
    pub const MAX_VARIANT_FIELDS: usize = 12;
430
431    /// Generate helper words for union types:
432    /// 1. Constructors: `Make-VariantName` - creates variant instances
433    /// 2. Predicates: `is-VariantName?` - tests if value is a specific variant
434    /// 3. Accessors: `VariantName-fieldname` - extracts field values (RFC #345)
435    ///
436    /// Example: For `union Message { Get { chan: Int } }`
437    /// Generates:
438    ///   `: Make-Get ( Int -- Message ) :Get variant.make-1 ;`
439    ///   `: is-Get? ( Message -- Bool ) variant.tag :Get symbol.= ;`
440    ///   `: Get-chan ( Message -- Int ) 0 variant.field-at ;`
441    ///
442    /// Returns an error if any variant exceeds the maximum field count.
443    pub fn generate_constructors(&mut self) -> Result<(), String> {
444        let mut new_words = Vec::new();
445
446        for union_def in &self.unions {
447            for variant in &union_def.variants {
448                let field_count = variant.fields.len();
449
450                // Check field count limit before generating constructor
451                if field_count > Self::MAX_VARIANT_FIELDS {
452                    return Err(format!(
453                        "Variant '{}' in union '{}' has {} fields, but the maximum is {}. \
454                         Consider grouping fields into nested union types.",
455                        variant.name,
456                        union_def.name,
457                        field_count,
458                        Self::MAX_VARIANT_FIELDS
459                    ));
460                }
461
462                // 1. Generate constructor: Make-VariantName
463                let constructor_name = format!("Make-{}", variant.name);
464                let mut input_stack = StackType::RowVar("a".to_string());
465                for field in &variant.fields {
466                    let field_type = parse_type_name(&field.type_name);
467                    input_stack = input_stack.push(field_type);
468                }
469                let output_stack =
470                    StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
471                let effect = Effect::new(input_stack, output_stack);
472                let body = vec![
473                    Statement::Symbol(variant.name.clone()),
474                    Statement::WordCall {
475                        name: format!("variant.make-{}", field_count),
476                        span: None,
477                    },
478                ];
479                new_words.push(WordDef {
480                    name: constructor_name,
481                    effect: Some(effect),
482                    body,
483                    source: variant.source.clone(),
484                    allowed_lints: vec![],
485                });
486
487                // 2. Generate predicate: is-VariantName?
488                // Effect: ( UnionType -- Bool )
489                // Body: variant.tag :VariantName symbol.=
490                let predicate_name = format!("is-{}?", variant.name);
491                let predicate_input =
492                    StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
493                let predicate_output = StackType::RowVar("a".to_string()).push(Type::Bool);
494                let predicate_effect = Effect::new(predicate_input, predicate_output);
495                let predicate_body = vec![
496                    Statement::WordCall {
497                        name: "variant.tag".to_string(),
498                        span: None,
499                    },
500                    Statement::Symbol(variant.name.clone()),
501                    Statement::WordCall {
502                        name: "symbol.=".to_string(),
503                        span: None,
504                    },
505                ];
506                new_words.push(WordDef {
507                    name: predicate_name,
508                    effect: Some(predicate_effect),
509                    body: predicate_body,
510                    source: variant.source.clone(),
511                    allowed_lints: vec![],
512                });
513
514                // 3. Generate field accessors: VariantName-fieldname
515                // Effect: ( UnionType -- FieldType )
516                // Body: N variant.field-at
517                for (index, field) in variant.fields.iter().enumerate() {
518                    let accessor_name = format!("{}-{}", variant.name, field.name);
519                    let field_type = parse_type_name(&field.type_name);
520                    let accessor_input = StackType::RowVar("a".to_string())
521                        .push(Type::Union(union_def.name.clone()));
522                    let accessor_output = StackType::RowVar("a".to_string()).push(field_type);
523                    let accessor_effect = Effect::new(accessor_input, accessor_output);
524                    let accessor_body = vec![
525                        Statement::IntLiteral(index as i64),
526                        Statement::WordCall {
527                            name: "variant.field-at".to_string(),
528                            span: None,
529                        },
530                    ];
531                    new_words.push(WordDef {
532                        name: accessor_name,
533                        effect: Some(accessor_effect),
534                        body: accessor_body,
535                        source: variant.source.clone(), // Use variant's source for field accessors
536                        allowed_lints: vec![],
537                    });
538                }
539            }
540        }
541
542        self.words.extend(new_words);
543        Ok(())
544    }
545
546    /// RFC #345: Fix up type variables in stack effects that should be union types
547    ///
548    /// When parsing files with includes, type variables like "Message" in
549    /// `( Message -- Int )` may be parsed as `Type::Var("Message")` if the
550    /// union definition is in an included file. After resolving includes,
551    /// we know all union names and can convert these to `Type::Union("Message")`.
552    ///
553    /// This ensures proper nominal type checking for union types across files.
554    pub fn fixup_union_types(&mut self) {
555        // Collect all union names from the program
556        let union_names: std::collections::HashSet<String> =
557            self.unions.iter().map(|u| u.name.clone()).collect();
558
559        // Fix up types in all word effects
560        for word in &mut self.words {
561            if let Some(ref mut effect) = word.effect {
562                Self::fixup_stack_type(&mut effect.inputs, &union_names);
563                Self::fixup_stack_type(&mut effect.outputs, &union_names);
564            }
565        }
566    }
567
568    /// Recursively fix up types in a stack type
569    fn fixup_stack_type(stack: &mut StackType, union_names: &std::collections::HashSet<String>) {
570        match stack {
571            StackType::Empty | StackType::RowVar(_) => {}
572            StackType::Cons { rest, top } => {
573                Self::fixup_type(top, union_names);
574                Self::fixup_stack_type(rest, union_names);
575            }
576        }
577    }
578
579    /// Fix up a single type, converting Type::Var to Type::Union if it matches a union name
580    fn fixup_type(ty: &mut Type, union_names: &std::collections::HashSet<String>) {
581        match ty {
582            Type::Var(name) if union_names.contains(name) => {
583                *ty = Type::Union(name.clone());
584            }
585            Type::Quotation(effect) => {
586                Self::fixup_stack_type(&mut effect.inputs, union_names);
587                Self::fixup_stack_type(&mut effect.outputs, union_names);
588            }
589            Type::Closure { effect, captures } => {
590                Self::fixup_stack_type(&mut effect.inputs, union_names);
591                Self::fixup_stack_type(&mut effect.outputs, union_names);
592                for cap in captures {
593                    Self::fixup_type(cap, union_names);
594                }
595            }
596            _ => {}
597        }
598    }
599}
600
601/// Parse a type name string into a Type
602/// Used by constructor generation to build stack effects
603fn parse_type_name(name: &str) -> Type {
604    match name {
605        "Int" => Type::Int,
606        "Float" => Type::Float,
607        "Bool" => Type::Bool,
608        "String" => Type::String,
609        "Channel" => Type::Channel,
610        other => Type::Union(other.to_string()),
611    }
612}
613
614impl Default for Program {
615    fn default() -> Self {
616        Self::new()
617    }
618}