// seqc/ast/program.rs

//! Program-level AST methods: word-call validation, auto-generated union
//! helper words (`Make-Variant` constructors, `is-Variant?` predicates, and
//! `Variant-field` accessors), and type fix-up for union types declared in
//! stack effects.
4
5use crate::types::{Effect, StackType, Type};
6
7use super::{Program, Statement, WordDef};
8
9impl Program {
10    pub fn new() -> Self {
11        Program {
12            includes: Vec::new(),
13            unions: Vec::new(),
14            words: Vec::new(),
15        }
16    }
17
18    pub fn find_word(&self, name: &str) -> Option<&WordDef> {
19        self.words.iter().find(|w| w.name == name)
20    }
21
22    /// Validate that all word calls reference either a defined word or a built-in
23    pub fn validate_word_calls(&self) -> Result<(), String> {
24        self.validate_word_calls_with_externals(&[])
25    }
26
27    /// Validate that all word calls reference a defined word, built-in, or external word.
28    ///
29    /// The `external_words` parameter should contain names of words available from
30    /// external sources (e.g., included modules) that should be considered valid.
31    pub fn validate_word_calls_with_externals(
32        &self,
33        external_words: &[&str],
34    ) -> Result<(), String> {
35        // List of known runtime built-ins
36        // IMPORTANT: Keep this in sync with codegen.rs WordCall matching
37        let builtins = [
38            // I/O operations
39            "io.write",
40            "io.write-line",
41            "io.read-line",
42            "io.read-line+",
43            "io.read-n",
44            "int->string",
45            "symbol->string",
46            "string->symbol",
47            // Command-line arguments
48            "args.count",
49            "args.at",
50            // File operations
51            "file.slurp",
52            "file.exists?",
53            "file.for-each-line+",
54            "file.spit",
55            "file.append",
56            "file.delete",
57            "file.size",
58            // Directory operations
59            "dir.exists?",
60            "dir.make",
61            "dir.delete",
62            "dir.list",
63            // String operations
64            "string.concat",
65            "string.length",
66            "string.byte-length",
67            "string.char-at",
68            "string.substring",
69            "char->string",
70            "string.find",
71            "string.split",
72            "string.contains",
73            "string.starts-with",
74            "string.empty?",
75            "string.trim",
76            "string.chomp",
77            "string.to-upper",
78            "string.to-lower",
79            "string.equal?",
80            "string.join",
81            "string.json-escape",
82            "string->int",
83            // Symbol operations
84            "symbol.=",
85            // Encoding operations
86            "encoding.base64-encode",
87            "encoding.base64-decode",
88            "encoding.base64url-encode",
89            "encoding.base64url-decode",
90            "encoding.hex-encode",
91            "encoding.hex-decode",
92            // Crypto operations
93            "crypto.sha256",
94            "crypto.hmac-sha256",
95            "crypto.constant-time-eq",
96            "crypto.random-bytes",
97            "crypto.random-int",
98            "crypto.uuid4",
99            "crypto.aes-gcm-encrypt",
100            "crypto.aes-gcm-decrypt",
101            "crypto.pbkdf2-sha256",
102            "crypto.ed25519-keypair",
103            "crypto.ed25519-sign",
104            "crypto.ed25519-verify",
105            // HTTP client operations
106            "http.get",
107            "http.post",
108            "http.put",
109            "http.delete",
110            // List operations
111            "list.make",
112            "list.push",
113            "list.get",
114            "list.set",
115            "list.map",
116            "list.filter",
117            "list.fold",
118            "list.each",
119            "list.length",
120            "list.empty?",
121            "list.reverse",
122            // Map operations
123            "map.make",
124            "map.get",
125            "map.set",
126            "map.has?",
127            "map.remove",
128            "map.keys",
129            "map.values",
130            "map.size",
131            "map.empty?",
132            "map.each",
133            "map.fold",
134            // Variant operations
135            "variant.field-count",
136            "variant.tag",
137            "variant.field-at",
138            "variant.append",
139            "variant.last",
140            "variant.init",
141            "variant.make-0",
142            "variant.make-1",
143            "variant.make-2",
144            "variant.make-3",
145            "variant.make-4",
146            // SON wrap aliases
147            "wrap-0",
148            "wrap-1",
149            "wrap-2",
150            "wrap-3",
151            "wrap-4",
152            // Integer arithmetic operations
153            "i.add",
154            "i.subtract",
155            "i.multiply",
156            "i.divide",
157            "i.modulo",
158            // Terse integer arithmetic
159            "i.+",
160            "i.-",
161            "i.*",
162            "i./",
163            "i.%",
164            // Integer comparison operations (return 0 or 1)
165            "i.=",
166            "i.<",
167            "i.>",
168            "i.<=",
169            "i.>=",
170            "i.<>",
171            // Integer comparison operations (verbose form)
172            "i.eq",
173            "i.lt",
174            "i.gt",
175            "i.lte",
176            "i.gte",
177            "i.neq",
178            // Stack operations (simple - no parameters)
179            "dup",
180            "drop",
181            "swap",
182            "over",
183            "rot",
184            "nip",
185            "tuck",
186            "2dup",
187            "3drop",
188            "pick",
189            "roll",
190            // Aux stack operations
191            ">aux",
192            "aux>",
193            // Boolean operations
194            "and",
195            "or",
196            "not",
197            // Bitwise operations
198            "band",
199            "bor",
200            "bxor",
201            "bnot",
202            "i.neg",
203            "negate",
204            // Arithmetic sugar (resolved to concrete ops by typechecker)
205            "+",
206            "-",
207            "*",
208            "/",
209            "%",
210            "=",
211            "<",
212            ">",
213            "<=",
214            ">=",
215            "<>",
216            "shl",
217            "shr",
218            "popcount",
219            "clz",
220            "ctz",
221            "int-bits",
222            // Channel operations
223            "chan.make",
224            "chan.send",
225            "chan.receive",
226            "chan.close",
227            "chan.yield",
228            // Quotation operations
229            "call",
230            // Dataflow combinators
231            "dip",
232            "keep",
233            "bi",
234            "strand.spawn",
235            "strand.weave",
236            "strand.resume",
237            "strand.weave-cancel",
238            "yield",
239            "cond",
240            // TCP operations
241            "tcp.listen",
242            "tcp.accept",
243            "tcp.read",
244            "tcp.write",
245            "tcp.close",
246            // OS operations
247            "os.getenv",
248            "os.home-dir",
249            "os.current-dir",
250            "os.path-exists",
251            "os.path-is-file",
252            "os.path-is-dir",
253            "os.path-join",
254            "os.path-parent",
255            "os.path-filename",
256            "os.exit",
257            "os.name",
258            "os.arch",
259            // Signal handling
260            "signal.trap",
261            "signal.received?",
262            "signal.pending?",
263            "signal.default",
264            "signal.ignore",
265            "signal.clear",
266            "signal.SIGINT",
267            "signal.SIGTERM",
268            "signal.SIGHUP",
269            "signal.SIGPIPE",
270            "signal.SIGUSR1",
271            "signal.SIGUSR2",
272            "signal.SIGCHLD",
273            "signal.SIGALRM",
274            "signal.SIGCONT",
275            // Terminal operations
276            "terminal.raw-mode",
277            "terminal.read-char",
278            "terminal.read-char?",
279            "terminal.width",
280            "terminal.height",
281            "terminal.flush",
282            // Float arithmetic operations (verbose form)
283            "f.add",
284            "f.subtract",
285            "f.multiply",
286            "f.divide",
287            // Float arithmetic operations (terse form)
288            "f.+",
289            "f.-",
290            "f.*",
291            "f./",
292            // Float comparison operations (symbol form)
293            "f.=",
294            "f.<",
295            "f.>",
296            "f.<=",
297            "f.>=",
298            "f.<>",
299            // Float comparison operations (verbose form)
300            "f.eq",
301            "f.lt",
302            "f.gt",
303            "f.lte",
304            "f.gte",
305            "f.neq",
306            // Type conversions
307            "int->float",
308            "float->int",
309            "float->string",
310            "string->float",
311            // Test framework operations
312            "test.init",
313            "test.finish",
314            "test.has-failures",
315            "test.assert",
316            "test.assert-not",
317            "test.assert-eq",
318            "test.assert-eq-str",
319            "test.fail",
320            "test.pass-count",
321            "test.fail-count",
322            // Time operations
323            "time.now",
324            "time.nanos",
325            "time.sleep-ms",
326            // SON serialization
327            "son.dump",
328            "son.dump-pretty",
329            // Stack introspection (for REPL)
330            "stack.dump",
331            // Regex operations
332            "regex.match?",
333            "regex.find",
334            "regex.find-all",
335            "regex.replace",
336            "regex.replace-all",
337            "regex.captures",
338            "regex.split",
339            "regex.valid?",
340            // Compression operations
341            "compress.gzip",
342            "compress.gzip-level",
343            "compress.gunzip",
344            "compress.zstd",
345            "compress.zstd-level",
346            "compress.unzstd",
347        ];
348
349        for word in &self.words {
350            self.validate_statements(&word.body, &word.name, &builtins, external_words)?;
351        }
352
353        Ok(())
354    }
355
356    /// Helper to validate word calls in a list of statements (recursively)
357    fn validate_statements(
358        &self,
359        statements: &[Statement],
360        word_name: &str,
361        builtins: &[&str],
362        external_words: &[&str],
363    ) -> Result<(), String> {
364        for statement in statements {
365            match statement {
366                Statement::WordCall { name, .. } => {
367                    // Check if it's a built-in
368                    if builtins.contains(&name.as_str()) {
369                        continue;
370                    }
371                    // Check if it's a user-defined word
372                    if self.find_word(name).is_some() {
373                        continue;
374                    }
375                    // Check if it's an external word (from includes)
376                    if external_words.contains(&name.as_str()) {
377                        continue;
378                    }
379                    // Undefined word!
380                    return Err(format!(
381                        "Undefined word '{}' called in word '{}'. \
382                         Did you forget to define it or misspell a built-in?",
383                        name, word_name
384                    ));
385                }
386                Statement::If {
387                    then_branch,
388                    else_branch,
389                    span: _,
390                } => {
391                    // Recursively validate both branches
392                    self.validate_statements(then_branch, word_name, builtins, external_words)?;
393                    if let Some(eb) = else_branch {
394                        self.validate_statements(eb, word_name, builtins, external_words)?;
395                    }
396                }
397                Statement::Quotation { body, .. } => {
398                    // Recursively validate quotation body
399                    self.validate_statements(body, word_name, builtins, external_words)?;
400                }
401                Statement::Match { arms, span: _ } => {
402                    // Recursively validate each match arm's body
403                    for arm in arms {
404                        self.validate_statements(&arm.body, word_name, builtins, external_words)?;
405                    }
406                }
407                _ => {} // Literals don't need validation
408            }
409        }
410        Ok(())
411    }
412
    /// Maximum number of fields a variant can have (limited by runtime support).
    ///
    /// `generate_constructors` returns an error for any variant that declares
    /// more fields than this.
    pub const MAX_VARIANT_FIELDS: usize = 12;
417
418    /// Generate helper words for union types:
419    /// 1. Constructors: `Make-VariantName` - creates variant instances
420    /// 2. Predicates: `is-VariantName?` - tests if value is a specific variant
421    /// 3. Accessors: `VariantName-fieldname` - extracts field values (RFC #345)
422    ///
423    /// Example: For `union Message { Get { chan: Int } }`
424    /// Generates:
425    ///   `: Make-Get ( Int -- Message ) :Get variant.make-1 ;`
426    ///   `: is-Get? ( Message -- Bool ) variant.tag :Get symbol.= ;`
427    ///   `: Get-chan ( Message -- Int ) 0 variant.field-at ;`
428    ///
429    /// Returns an error if any variant exceeds the maximum field count.
430    pub fn generate_constructors(&mut self) -> Result<(), String> {
431        let mut new_words = Vec::new();
432
433        for union_def in &self.unions {
434            for variant in &union_def.variants {
435                let field_count = variant.fields.len();
436
437                // Check field count limit before generating constructor
438                if field_count > Self::MAX_VARIANT_FIELDS {
439                    return Err(format!(
440                        "Variant '{}' in union '{}' has {} fields, but the maximum is {}. \
441                         Consider grouping fields into nested union types.",
442                        variant.name,
443                        union_def.name,
444                        field_count,
445                        Self::MAX_VARIANT_FIELDS
446                    ));
447                }
448
449                // 1. Generate constructor: Make-VariantName
450                let constructor_name = format!("Make-{}", variant.name);
451                let mut input_stack = StackType::RowVar("a".to_string());
452                for field in &variant.fields {
453                    let field_type = parse_type_name(&field.type_name);
454                    input_stack = input_stack.push(field_type);
455                }
456                let output_stack =
457                    StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
458                let effect = Effect::new(input_stack, output_stack);
459                let body = vec![
460                    Statement::Symbol(variant.name.clone()),
461                    Statement::WordCall {
462                        name: format!("variant.make-{}", field_count),
463                        span: None,
464                    },
465                ];
466                new_words.push(WordDef {
467                    name: constructor_name,
468                    effect: Some(effect),
469                    body,
470                    source: variant.source.clone(),
471                    allowed_lints: vec![],
472                });
473
474                // 2. Generate predicate: is-VariantName?
475                // Effect: ( UnionType -- Bool )
476                // Body: variant.tag :VariantName symbol.=
477                let predicate_name = format!("is-{}?", variant.name);
478                let predicate_input =
479                    StackType::RowVar("a".to_string()).push(Type::Union(union_def.name.clone()));
480                let predicate_output = StackType::RowVar("a".to_string()).push(Type::Bool);
481                let predicate_effect = Effect::new(predicate_input, predicate_output);
482                let predicate_body = vec![
483                    Statement::WordCall {
484                        name: "variant.tag".to_string(),
485                        span: None,
486                    },
487                    Statement::Symbol(variant.name.clone()),
488                    Statement::WordCall {
489                        name: "symbol.=".to_string(),
490                        span: None,
491                    },
492                ];
493                new_words.push(WordDef {
494                    name: predicate_name,
495                    effect: Some(predicate_effect),
496                    body: predicate_body,
497                    source: variant.source.clone(),
498                    allowed_lints: vec![],
499                });
500
501                // 3. Generate field accessors: VariantName-fieldname
502                // Effect: ( UnionType -- FieldType )
503                // Body: N variant.field-at
504                for (index, field) in variant.fields.iter().enumerate() {
505                    let accessor_name = format!("{}-{}", variant.name, field.name);
506                    let field_type = parse_type_name(&field.type_name);
507                    let accessor_input = StackType::RowVar("a".to_string())
508                        .push(Type::Union(union_def.name.clone()));
509                    let accessor_output = StackType::RowVar("a".to_string()).push(field_type);
510                    let accessor_effect = Effect::new(accessor_input, accessor_output);
511                    let accessor_body = vec![
512                        Statement::IntLiteral(index as i64),
513                        Statement::WordCall {
514                            name: "variant.field-at".to_string(),
515                            span: None,
516                        },
517                    ];
518                    new_words.push(WordDef {
519                        name: accessor_name,
520                        effect: Some(accessor_effect),
521                        body: accessor_body,
522                        source: variant.source.clone(), // Use variant's source for field accessors
523                        allowed_lints: vec![],
524                    });
525                }
526            }
527        }
528
529        self.words.extend(new_words);
530        Ok(())
531    }
532
533    /// RFC #345: Fix up type variables in stack effects that should be union types
534    ///
535    /// When parsing files with includes, type variables like "Message" in
536    /// `( Message -- Int )` may be parsed as `Type::Var("Message")` if the
537    /// union definition is in an included file. After resolving includes,
538    /// we know all union names and can convert these to `Type::Union("Message")`.
539    ///
540    /// This ensures proper nominal type checking for union types across files.
541    pub fn fixup_union_types(&mut self) {
542        // Collect all union names from the program
543        let union_names: std::collections::HashSet<String> =
544            self.unions.iter().map(|u| u.name.clone()).collect();
545
546        // Fix up types in all word effects
547        for word in &mut self.words {
548            if let Some(ref mut effect) = word.effect {
549                Self::fixup_stack_type(&mut effect.inputs, &union_names);
550                Self::fixup_stack_type(&mut effect.outputs, &union_names);
551            }
552        }
553    }
554
555    /// Recursively fix up types in a stack type
556    fn fixup_stack_type(stack: &mut StackType, union_names: &std::collections::HashSet<String>) {
557        match stack {
558            StackType::Empty | StackType::RowVar(_) => {}
559            StackType::Cons { rest, top } => {
560                Self::fixup_type(top, union_names);
561                Self::fixup_stack_type(rest, union_names);
562            }
563        }
564    }
565
566    /// Fix up a single type, converting Type::Var to Type::Union if it matches a union name
567    fn fixup_type(ty: &mut Type, union_names: &std::collections::HashSet<String>) {
568        match ty {
569            Type::Var(name) if union_names.contains(name) => {
570                *ty = Type::Union(name.clone());
571            }
572            Type::Quotation(effect) => {
573                Self::fixup_stack_type(&mut effect.inputs, union_names);
574                Self::fixup_stack_type(&mut effect.outputs, union_names);
575            }
576            Type::Closure { effect, captures } => {
577                Self::fixup_stack_type(&mut effect.inputs, union_names);
578                Self::fixup_stack_type(&mut effect.outputs, union_names);
579                for cap in captures {
580                    Self::fixup_type(cap, union_names);
581                }
582            }
583            _ => {}
584        }
585    }
586}
587
588/// Parse a type name string into a Type
589/// Used by constructor generation to build stack effects
590fn parse_type_name(name: &str) -> Type {
591    match name {
592        "Int" => Type::Int,
593        "Float" => Type::Float,
594        "Bool" => Type::Bool,
595        "String" => Type::String,
596        "Channel" => Type::Channel,
597        other => Type::Union(other.to_string()),
598    }
599}
600
601impl Default for Program {
602    fn default() -> Self {
603        Self::new()
604    }
605}