aprender/code/
ast.rs

1//! Abstract Syntax Tree representation for code analysis
2//!
3//! Provides lightweight AST node types for code2vec path extraction.
4//! This is not a full parser - it's designed to work with pre-parsed AST data.
5
6use std::fmt;
7
/// Kinds of AST nodes recognised by the code-analysis layer.
///
/// Every variant renders to a short label through its [`fmt::Display`]
/// implementation (for example `Function` is written as `Func`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AstNodeType {
    /// Function or method definition
    Function,
    /// Function/method parameter
    Parameter,
    /// Return statement or expression
    Return,
    /// Variable declaration
    Variable,
    /// Assignment expression
    Assignment,
    /// Binary operation (e.g., +, -, *, /)
    BinaryOp,
    /// Unary operation (e.g., !, -)
    UnaryOp,
    /// If/else conditional
    Conditional,
    /// Loop construct (for, while, loop)
    Loop,
    /// Function call expression
    Call,
    /// Literal value (number, string, bool)
    Literal,
    /// Array/vector access
    Index,
    /// Field access (e.g., obj.field)
    FieldAccess,
    /// Block of statements
    Block,
    /// Type annotation
    TypeAnnotation,
    /// Generic type parameter
    Generic,
    /// Match/switch expression
    Match,
    /// Match arm
    MatchArm,
    /// Struct definition
    Struct,
    /// Enum definition
    Enum,
    /// Trait/interface definition
    Trait,
    /// Implementation block
    Impl,
    /// Module declaration
    Module,
    /// Import/use statement
    Import,
}

impl fmt::Display for AstNodeType {
    /// Writes the variant's short label.
    ///
    /// The label is emitted through [`fmt::Formatter::pad`], so width, fill,
    /// alignment and precision requested by the caller (e.g. `{:>8}`) are
    /// honoured. With a plain `{}` the output is exactly the bare label.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Kept exhaustive (no `_` arm) so that adding a variant forces an
        // explicit label choice at compile time.
        let label = match self {
            Self::Function => "Func",
            Self::Parameter => "Param",
            Self::Return => "Ret",
            Self::Variable => "Var",
            Self::Assignment => "Assign",
            Self::BinaryOp => "BinOp",
            Self::UnaryOp => "UnOp",
            Self::Conditional => "Cond",
            Self::Loop => "Loop",
            Self::Call => "Call",
            Self::Literal => "Lit",
            Self::Index => "Idx",
            Self::FieldAccess => "Field",
            Self::Block => "Block",
            Self::TypeAnnotation => "Type",
            Self::Generic => "Gen",
            Self::Match => "Match",
            Self::MatchArm => "Arm",
            Self::Struct => "Struct",
            Self::Enum => "Enum",
            Self::Trait => "Trait",
            Self::Impl => "Impl",
            Self::Module => "Mod",
            Self::Import => "Import",
        };
        f.pad(label)
    }
}
92
/// Kinds of tokens (terminal nodes in the AST).
///
/// Every variant renders to a short label through its [`fmt::Display`]
/// implementation (for example `Identifier` is written as `Id`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TokenType {
    /// Identifier (variable name, function name, etc.)
    Identifier,
    /// Numeric literal
    Number,
    /// String literal
    String,
    /// Boolean literal
    Boolean,
    /// Keyword (if, else, fn, let, etc.)
    Keyword,
    /// Operator (+, -, *, /, etc.)
    Operator,
    /// Punctuation (parentheses, braces, etc.)
    Punctuation,
    /// Type name
    TypeName,
    /// Comment
    Comment,
}

impl fmt::Display for TokenType {
    /// Writes the variant's short label.
    ///
    /// The label is emitted through [`fmt::Formatter::pad`], so width, fill,
    /// alignment and precision requested by the caller (e.g. `{:<6}`) are
    /// honoured. With a plain `{}` the output is exactly the bare label.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Kept exhaustive (no `_` arm) so that adding a variant forces an
        // explicit label choice at compile time.
        let label = match self {
            Self::Identifier => "Id",
            Self::Number => "Num",
            Self::String => "Str",
            Self::Boolean => "Bool",
            Self::Keyword => "Kw",
            Self::Operator => "Op",
            Self::Punctuation => "Punct",
            Self::TypeName => "Type",
            Self::Comment => "Comment",
        };
        f.pad(label)
    }
}
132
133/// A token (terminal node) in the AST
134#[derive(Debug, Clone, PartialEq, Eq, Hash)]
135pub struct Token {
136    /// Type of token
137    token_type: TokenType,
138    /// Token value/content
139    value: String,
140}
141
142impl Token {
143    /// Create a new token
144    #[must_use]
145    pub fn new(token_type: TokenType, value: impl Into<String>) -> Self {
146        Self {
147            token_type,
148            value: value.into(),
149        }
150    }
151
152    /// Get the token type
153    #[must_use]
154    pub fn token_type(&self) -> TokenType {
155        self.token_type
156    }
157
158    /// Get the token value
159    #[must_use]
160    pub fn value(&self) -> &str {
161        &self.value
162    }
163}
164
165impl fmt::Display for Token {
166    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167        write!(f, "{}:{}", self.token_type, self.value)
168    }
169}
170
171/// A node in the Abstract Syntax Tree
172#[derive(Debug, Clone)]
173pub struct AstNode {
174    /// Type of AST node
175    node_type: AstNodeType,
176    /// Node value (e.g., function name, variable name)
177    value: String,
178    /// Child nodes
179    children: Vec<AstNode>,
180    /// Optional token for terminal nodes
181    token: Option<Token>,
182}
183
184impl AstNode {
185    /// Create a new AST node
186    #[must_use]
187    pub fn new(node_type: AstNodeType, value: impl Into<String>) -> Self {
188        Self {
189            node_type,
190            value: value.into(),
191            children: Vec::new(),
192            token: None,
193        }
194    }
195
196    /// Create a terminal node with a token
197    #[must_use]
198    pub fn terminal(token: Token) -> Self {
199        Self {
200            node_type: AstNodeType::Literal,
201            value: token.value().to_string(),
202            children: Vec::new(),
203            token: Some(token),
204        }
205    }
206
207    /// Get the node type
208    #[must_use]
209    pub fn node_type(&self) -> AstNodeType {
210        self.node_type
211    }
212
213    /// Get the node value
214    #[must_use]
215    pub fn value(&self) -> &str {
216        &self.value
217    }
218
219    /// Get the children of this node
220    #[must_use]
221    pub fn children(&self) -> &[AstNode] {
222        &self.children
223    }
224
225    /// Get mutable access to children
226    pub fn children_mut(&mut self) -> &mut Vec<AstNode> {
227        &mut self.children
228    }
229
230    /// Add a child node
231    pub fn add_child(&mut self, child: AstNode) {
232        self.children.push(child);
233    }
234
235    /// Check if this is a terminal node (leaf)
236    #[must_use]
237    pub fn is_terminal(&self) -> bool {
238        self.children.is_empty()
239    }
240
241    /// Get the token if this is a terminal node
242    #[must_use]
243    pub fn token(&self) -> Option<&Token> {
244        self.token.as_ref()
245    }
246
247    /// Count all nodes in the subtree (including self)
248    #[must_use]
249    pub fn node_count(&self) -> usize {
250        1 + self.children.iter().map(AstNode::node_count).sum::<usize>()
251    }
252
253    /// Get the depth of the tree
254    #[must_use]
255    pub fn depth(&self) -> usize {
256        if self.children.is_empty() {
257            1
258        } else {
259            1 + self.children.iter().map(AstNode::depth).max().unwrap_or(0)
260        }
261    }
262
263    /// Collect all terminal nodes (leaves)
264    #[must_use]
265    pub fn terminals(&self) -> Vec<&AstNode> {
266        if self.is_terminal() {
267            vec![self]
268        } else {
269            self.children.iter().flat_map(AstNode::terminals).collect()
270        }
271    }
272}
273
274impl fmt::Display for AstNode {
275    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276        write!(f, "{}:{}", self.node_type, self.value)
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the short labels of a few node types.
    #[test]
    fn test_ast_node_type_display() {
        let expected = [
            (AstNodeType::Function, "Func"),
            (AstNodeType::Parameter, "Param"),
            (AstNodeType::Return, "Ret"),
        ];
        for (node_type, label) in expected {
            assert_eq!(node_type.to_string(), label);
        }
    }

    /// Spot-checks the short labels of a few token types.
    #[test]
    fn test_token_type_display() {
        let expected = [
            (TokenType::Identifier, "Id"),
            (TokenType::Number, "Num"),
            (TokenType::String, "Str"),
        ];
        for (token_type, label) in expected {
            assert_eq!(token_type.to_string(), label);
        }
    }

    /// A token reports the type and value it was built with and prints as
    /// `<type>:<value>`.
    #[test]
    fn test_token_creation() {
        let tok = Token::new(TokenType::Identifier, "my_var");

        assert_eq!(tok.token_type(), TokenType::Identifier);
        assert_eq!(tok.value(), "my_var");
        assert_eq!(tok.to_string(), "Id:my_var");
    }

    /// A freshly created node has no children and therefore counts as a leaf.
    #[test]
    fn test_ast_node_creation() {
        let fresh = AstNode::new(AstNodeType::Function, "calculate");

        assert_eq!(fresh.node_type(), AstNodeType::Function);
        assert_eq!(fresh.value(), "calculate");
        assert!(fresh.children().is_empty());
        assert!(fresh.is_terminal());
    }

    /// Adding children updates the child list, node count and depth.
    #[test]
    fn test_ast_node_with_children() {
        let mut func = AstNode::new(AstNodeType::Function, "add");
        let parts = [
            (AstNodeType::Parameter, "x"),
            (AstNodeType::Parameter, "y"),
            (AstNodeType::Return, "result"),
        ];
        for (kind, name) in parts {
            func.add_child(AstNode::new(kind, name));
        }

        assert_eq!(func.children().len(), 3);
        assert!(!func.is_terminal());
        assert_eq!(func.node_count(), 4);
        assert_eq!(func.depth(), 2);
    }

    /// A node built from a token is a leaf and exposes that token.
    #[test]
    fn test_terminal_node() {
        let leaf = AstNode::terminal(Token::new(TokenType::Number, "42"));

        assert!(leaf.is_terminal());
        assert!(leaf.token().is_some());
        assert_eq!(leaf.token().map(Token::value), Some("42"));
    }

    /// `terminals` returns one entry per leaf.
    #[test]
    fn test_collect_terminals() {
        let mut func = AstNode::new(AstNodeType::Function, "test");
        for name in ["a", "b"] {
            func.add_child(AstNode::new(AstNodeType::Parameter, name));
        }

        assert_eq!(func.terminals().len(), 2);
    }

    /// A four-level chain has depth 4 and four nodes.
    #[test]
    fn test_deep_tree() {
        // Assemble from the innermost node outwards.
        let mut conditional = AstNode::new(AstNodeType::Conditional, "if");
        conditional.add_child(AstNode::new(AstNodeType::Return, "early"));

        let mut body = AstNode::new(AstNodeType::Block, "body");
        body.add_child(conditional);

        let mut root = AstNode::new(AstNodeType::Function, "deep");
        root.add_child(body);

        assert_eq!(root.depth(), 4);
        assert_eq!(root.node_count(), 4);
    }
}