Skip to main content

cjc_ast/
node_utils.rs

//! AST Node Utilities — pure query methods on existing types
//!
//! Adds convenience methods to `Expr`, `Block`, and `Program` via new `impl`
//! blocks. These are read-only, side-effect-free query methods.
//!
//! ## Design decisions
//!
//! - **New `impl` blocks** — safe in Rust; does not modify existing impls
//! - **No mutation** — all methods take `&self` and return computed values
//! - **No dependencies** — uses only types from this crate
11
12use crate::{Block, DeclKind, Expr, ExprKind, Program};
13
14// ---------------------------------------------------------------------------
15// Expr utilities
16// ---------------------------------------------------------------------------
17
18impl Expr {
19    /// Return the number of direct child expressions.
20    ///
21    /// Only counts immediate children, not transitive descendants.
22    /// Leaf nodes (literals, identifiers, `NaLit`) return 0.
23    /// For `Call` expressions the count includes both the callee and all arguments.
24    ///
25    /// # Returns
26    ///
27    /// The number of direct child expression sub-trees.
28    pub fn child_count(&self) -> usize {
29        match &self.kind {
30            ExprKind::IntLit(_)
31            | ExprKind::FloatLit(_)
32            | ExprKind::StringLit(_)
33            | ExprKind::ByteStringLit(_)
34            | ExprKind::ByteCharLit(_)
35            | ExprKind::RawStringLit(_)
36            | ExprKind::RawByteStringLit(_)
37            | ExprKind::RegexLit { .. }
38            | ExprKind::BoolLit(_)
39            | ExprKind::NaLit
40            | ExprKind::Ident(_)
41            | ExprKind::Col(_) => 0,
42
43            ExprKind::FStringLit(segs) => segs.iter().filter(|(_, e)| e.is_some()).count(),
44            ExprKind::TensorLit { rows } => rows.iter().map(|r| r.len()).sum(),
45            ExprKind::Unary { .. } | ExprKind::Try(_) => 1,
46            ExprKind::Binary { .. }
47            | ExprKind::Assign { .. }
48            | ExprKind::CompoundAssign { .. }
49            | ExprKind::Pipe { .. }
50            | ExprKind::Index { .. } => 2,
51            ExprKind::Field { .. } => 1,
52            ExprKind::MultiIndex { object: _, indices } => 1 + indices.len(),
53            ExprKind::Call { args, .. } => 1 + args.len(), // callee + args
54            ExprKind::IfExpr { .. } => 1,                  // condition
55            ExprKind::Block(_) => 0,
56            ExprKind::StructLit { fields, .. } => fields.len(),
57            ExprKind::ArrayLit(elems) => elems.len(),
58            ExprKind::TupleLit(elems) => elems.len(),
59            ExprKind::Lambda { .. } => 1,                  // body
60            ExprKind::Match { arms, .. } => 1 + arms.len(), // scrutinee + arms
61            ExprKind::VariantLit { fields, .. } => fields.len(),
62            ExprKind::Cast { .. } => 1, // the inner expression
63        }
64    }
65
66    /// Return `true` if this expression is a literal value.
67    ///
68    /// Covers integer, float, string (including byte-string, raw-string, and
69    /// raw-byte-string variants), boolean, `NA`, and regex literals.
70    /// Collection literals (`ArrayLit`, `TupleLit`, `TensorLit`) are **not**
71    /// considered literals by this method because they contain sub-expressions.
72    pub fn is_literal(&self) -> bool {
73        matches!(
74            &self.kind,
75            ExprKind::IntLit(_)
76                | ExprKind::FloatLit(_)
77                | ExprKind::StringLit(_)
78                | ExprKind::ByteStringLit(_)
79                | ExprKind::ByteCharLit(_)
80                | ExprKind::RawStringLit(_)
81                | ExprKind::RawByteStringLit(_)
82                | ExprKind::BoolLit(_)
83                | ExprKind::NaLit
84                | ExprKind::RegexLit { .. }
85        )
86    }
87
88    /// Return `true` if this expression is a valid assignment target (place expression).
89    ///
90    /// Place expressions are identifiers, field accesses, and index accesses.
91    /// These are the only forms that may appear on the left-hand side of an
92    /// assignment.
93    pub fn is_place(&self) -> bool {
94        matches!(
95            &self.kind,
96            ExprKind::Ident(_) | ExprKind::Field { .. } | ExprKind::Index { .. }
97        )
98    }
99
100    /// Return `true` if this expression is a compound (non-leaf) expression.
101    ///
102    /// Compound expressions contain sub-expressions that require recursive
103    /// evaluation: binary/unary operations, calls, matches, if-expressions,
104    /// blocks, pipes, and lambdas.
105    pub fn is_compound(&self) -> bool {
106        matches!(
107            &self.kind,
108            ExprKind::Binary { .. }
109                | ExprKind::Unary { .. }
110                | ExprKind::Call { .. }
111                | ExprKind::Match { .. }
112                | ExprKind::IfExpr { .. }
113                | ExprKind::Block(_)
114                | ExprKind::Pipe { .. }
115                | ExprKind::Lambda { .. }
116        )
117    }
118}
119
120// ---------------------------------------------------------------------------
121// Block utilities
122// ---------------------------------------------------------------------------
123
124impl Block {
125    /// Return `true` if the block has no statements and no trailing expression.
126    ///
127    /// An empty block `{}` contributes nothing to the program and may be
128    /// flagged as a warning by the validator.
129    pub fn is_empty(&self) -> bool {
130        self.stmts.is_empty() && self.expr.is_none()
131    }
132
133    /// Return the number of statements in this block.
134    ///
135    /// Does not count the trailing expression (if any). Use
136    /// [`has_trailing_expr`](Block::has_trailing_expr) to check for that separately.
137    pub fn stmt_count(&self) -> usize {
138        self.stmts.len()
139    }
140
141    /// Return `true` if the block has a trailing expression.
142    ///
143    /// A trailing expression (the final expression without a semicolon)
144    /// determines the block's value when used in expression position.
145    pub fn has_trailing_expr(&self) -> bool {
146        self.expr.is_some()
147    }
148}
149
150// ---------------------------------------------------------------------------
151// Program utilities
152// ---------------------------------------------------------------------------
153
154impl Program {
155    /// Return the number of function declarations in the program.
156    ///
157    /// Counts top-level `fn` declarations plus methods inside `impl` blocks.
158    /// Does not count lambdas or closures.
159    pub fn function_count(&self) -> usize {
160        let mut count = 0;
161        for decl in &self.declarations {
162            match &decl.kind {
163                DeclKind::Fn(_) => count += 1,
164                DeclKind::Impl(i) => count += i.methods.len(),
165                _ => {}
166            }
167        }
168        count
169    }
170
171    /// Return the number of `struct` declarations in the program.
172    pub fn struct_count(&self) -> usize {
173        self.declarations
174            .iter()
175            .filter(|d| matches!(&d.kind, DeclKind::Struct(_)))
176            .count()
177    }
178
179    /// Return `true` if there is a top-level function named `"main"`.
180    ///
181    /// The CJC runtime uses the presence of a `main` function to determine
182    /// the program entry point.
183    pub fn has_main_function(&self) -> bool {
184        self.declarations.iter().any(|d| {
185            if let DeclKind::Fn(f) = &d.kind {
186                f.name.name == "main"
187            } else {
188                false
189            }
190        })
191    }
192}
193
194// ---------------------------------------------------------------------------
195// Tests
196// ---------------------------------------------------------------------------
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    /// Wrap an `ExprKind` in an `Expr` carrying a dummy span.
    fn dummy_expr(kind: ExprKind) -> Expr {
        Expr {
            kind,
            span: Span::dummy(),
        }
    }

    /// `1 + 2` — a minimal compound expression with two literal children.
    fn one_plus_two() -> Expr {
        dummy_expr(ExprKind::Binary {
            op: BinOp::Add,
            left: Box::new(dummy_expr(ExprKind::IntLit(1))),
            right: Box::new(dummy_expr(ExprKind::IntLit(2))),
        })
    }

    /// An empty block `{}`.
    fn empty_block() -> Block {
        Block {
            stmts: vec![],
            expr: None,
            span: Span::dummy(),
        }
    }

    /// A private, parameterless `fn <name>() {}` declaration.
    fn fn_decl(name: &'static str) -> Decl {
        Decl {
            kind: DeclKind::Fn(FnDecl {
                name: Ident::dummy(name),
                type_params: vec![],
                params: vec![],
                return_type: None,
                body: empty_block(),
                is_nogc: false,
                effect_annotation: None,
                decorators: vec![],
                vis: Visibility::Private,
            }),
            span: Span::dummy(),
        }
    }

    /// A private, field-less `struct <name> {}` declaration.
    fn struct_decl(name: &'static str) -> Decl {
        Decl {
            kind: DeclKind::Struct(StructDecl {
                name: Ident::dummy(name),
                type_params: vec![],
                fields: vec![],
                vis: Visibility::Private,
            }),
            span: Span::dummy(),
        }
    }

    #[test]
    fn test_expr_child_count() {
        assert_eq!(dummy_expr(ExprKind::IntLit(1)).child_count(), 0);
        assert_eq!(dummy_expr(ExprKind::NaLit).child_count(), 0);
        assert_eq!(one_plus_two().child_count(), 2);
    }

    #[test]
    fn test_expr_is_literal() {
        assert!(dummy_expr(ExprKind::IntLit(1)).is_literal());
        assert!(dummy_expr(ExprKind::FloatLit(1.0)).is_literal());
        assert!(dummy_expr(ExprKind::BoolLit(true)).is_literal());
        assert!(dummy_expr(ExprKind::NaLit).is_literal());
        assert!(!dummy_expr(ExprKind::Ident(Ident::dummy("x"))).is_literal());
        // A compound expression built from literals is not itself a literal.
        assert!(!one_plus_two().is_literal());
    }

    #[test]
    fn test_expr_is_place() {
        assert!(dummy_expr(ExprKind::Ident(Ident::dummy("x"))).is_place());
        assert!(!dummy_expr(ExprKind::IntLit(1)).is_place());
        assert!(!one_plus_two().is_place());
    }

    #[test]
    fn test_expr_is_compound() {
        assert!(one_plus_two().is_compound());
        assert!(!dummy_expr(ExprKind::IntLit(1)).is_compound());
    }

    #[test]
    fn test_block_utils() {
        let empty = empty_block();
        assert!(empty.is_empty());
        assert_eq!(empty.stmt_count(), 0);
        assert!(!empty.has_trailing_expr());

        let with_expr = Block {
            stmts: vec![Stmt {
                kind: StmtKind::Expr(dummy_expr(ExprKind::IntLit(1))),
                span: Span::dummy(),
            }],
            expr: Some(Box::new(dummy_expr(ExprKind::IntLit(2)))),
            span: Span::dummy(),
        };
        assert!(!with_expr.is_empty());
        assert_eq!(with_expr.stmt_count(), 1);
        assert!(with_expr.has_trailing_expr());

        // A trailing expression alone is enough to make the block non-empty.
        let trailing_only = Block {
            stmts: vec![],
            expr: Some(Box::new(dummy_expr(ExprKind::IntLit(3)))),
            span: Span::dummy(),
        };
        assert!(!trailing_only.is_empty());
        assert_eq!(trailing_only.stmt_count(), 0);
        assert!(trailing_only.has_trailing_expr());
    }

    #[test]
    fn test_program_utils() {
        let program = Program {
            declarations: vec![fn_decl("main"), struct_decl("Point")],
        };
        assert_eq!(program.function_count(), 1);
        assert_eq!(program.struct_count(), 1);
        assert!(program.has_main_function());
    }

    #[test]
    fn test_program_utils_empty() {
        let program = Program {
            declarations: vec![],
        };
        assert_eq!(program.function_count(), 0);
        assert_eq!(program.struct_count(), 0);
        assert!(!program.has_main_function());
    }

    #[test]
    fn test_program_main_must_be_a_function() {
        // A struct named `main` must not be mistaken for the entry point.
        let program = Program {
            declarations: vec![fn_decl("helper"), struct_decl("main")],
        };
        assert_eq!(program.function_count(), 1);
        assert_eq!(program.struct_count(), 1);
        assert!(!program.has_main_function());
    }
}