1use std::fmt;
7
/// The syntactic kind of an [`AstNode`].
///
/// Each variant's [`Display`](fmt::Display) output is the short label noted on
/// that variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AstNodeType {
    /// Function (label `Func`).
    Function,
    /// Parameter (label `Param`).
    Parameter,
    /// Return (label `Ret`).
    Return,
    /// Variable (label `Var`).
    Variable,
    /// Assignment (label `Assign`).
    Assignment,
    /// Binary operation (label `BinOp`).
    BinaryOp,
    /// Unary operation (label `UnOp`).
    UnaryOp,
    /// Conditional (label `Cond`).
    Conditional,
    /// Loop (label `Loop`).
    Loop,
    /// Call (label `Call`).
    Call,
    /// Literal (label `Lit`); the kind produced by [`AstNode::terminal`].
    Literal,
    /// Index (label `Idx`).
    Index,
    /// Field access (label `Field`).
    FieldAccess,
    /// Block (label `Block`).
    Block,
    /// Type annotation (label `Type`).
    TypeAnnotation,
    /// Generic (label `Gen`).
    Generic,
    /// Match (label `Match`).
    Match,
    /// Match arm (label `Arm`).
    MatchArm,
    /// Struct (label `Struct`).
    Struct,
    /// Enum (label `Enum`).
    Enum,
    /// Trait (label `Trait`).
    Trait,
    /// Impl (label `Impl`).
    Impl,
    /// Module (label `Mod`).
    Module,
    /// Import (label `Import`).
    Import,
}

impl fmt::Display for AstNodeType {
    /// Writes the variant's short label (e.g. `Func` for [`AstNodeType::Function`]).
    ///
    /// Width, fill, alignment, and precision flags are honored, so
    /// `format!("{:>6}", AstNodeType::Function)` yields `"  Func"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Function => "Func",
            Self::Parameter => "Param",
            Self::Return => "Ret",
            Self::Variable => "Var",
            Self::Assignment => "Assign",
            Self::BinaryOp => "BinOp",
            Self::UnaryOp => "UnOp",
            Self::Conditional => "Cond",
            Self::Loop => "Loop",
            Self::Call => "Call",
            Self::Literal => "Lit",
            Self::Index => "Idx",
            Self::FieldAccess => "Field",
            Self::Block => "Block",
            Self::TypeAnnotation => "Type",
            Self::Generic => "Gen",
            Self::Match => "Match",
            Self::MatchArm => "Arm",
            Self::Struct => "Struct",
            Self::Enum => "Enum",
            Self::Trait => "Trait",
            Self::Impl => "Impl",
            Self::Module => "Mod",
            Self::Import => "Import",
        };
        // `pad` applies the caller's width/fill/alignment/precision; a nested
        // `write!(f, "{label}")` would format with empty flags and drop them.
        f.pad(label)
    }
}
92
/// The lexical category of a [`Token`].
///
/// Each variant's [`Display`](fmt::Display) output is the short label noted on
/// that variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TokenType {
    /// Identifier (label `Id`).
    Identifier,
    /// Number (label `Num`).
    Number,
    /// String (label `Str`).
    String,
    /// Boolean (label `Bool`).
    Boolean,
    /// Keyword (label `Kw`).
    Keyword,
    /// Operator (label `Op`).
    Operator,
    /// Punctuation (label `Punct`).
    Punctuation,
    /// Type name (label `Type`).
    TypeName,
    /// Comment (label `Comment`).
    Comment,
}

impl fmt::Display for TokenType {
    /// Writes the variant's short label (e.g. `Id` for [`TokenType::Identifier`]).
    ///
    /// Width, fill, alignment, and precision flags are honored, so
    /// `format!("{:>4}", TokenType::Operator)` yields `"  Op"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Identifier => "Id",
            Self::Number => "Num",
            Self::String => "Str",
            Self::Boolean => "Bool",
            Self::Keyword => "Kw",
            Self::Operator => "Op",
            Self::Punctuation => "Punct",
            Self::TypeName => "Type",
            Self::Comment => "Comment",
        };
        // `pad` applies the caller's width/fill/alignment/precision; a nested
        // `write!(f, "{label}")` would format with empty flags and drop them.
        f.pad(label)
    }
}
132
133#[derive(Debug, Clone, PartialEq, Eq, Hash)]
135pub struct Token {
136 token_type: TokenType,
138 value: String,
140}
141
142impl Token {
143 #[must_use]
145 pub fn new(token_type: TokenType, value: impl Into<String>) -> Self {
146 Self {
147 token_type,
148 value: value.into(),
149 }
150 }
151
152 #[must_use]
154 pub fn token_type(&self) -> TokenType {
155 self.token_type
156 }
157
158 #[must_use]
160 pub fn value(&self) -> &str {
161 &self.value
162 }
163}
164
165impl fmt::Display for Token {
166 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167 write!(f, "{}:{}", self.token_type, self.value)
168 }
169}
170
171#[derive(Debug, Clone)]
173pub struct AstNode {
174 node_type: AstNodeType,
176 value: String,
178 children: Vec<AstNode>,
180 token: Option<Token>,
182}
183
184impl AstNode {
185 #[must_use]
187 pub fn new(node_type: AstNodeType, value: impl Into<String>) -> Self {
188 Self {
189 node_type,
190 value: value.into(),
191 children: Vec::new(),
192 token: None,
193 }
194 }
195
196 #[must_use]
198 pub fn terminal(token: Token) -> Self {
199 Self {
200 node_type: AstNodeType::Literal,
201 value: token.value().to_string(),
202 children: Vec::new(),
203 token: Some(token),
204 }
205 }
206
207 #[must_use]
209 pub fn node_type(&self) -> AstNodeType {
210 self.node_type
211 }
212
213 #[must_use]
215 pub fn value(&self) -> &str {
216 &self.value
217 }
218
219 #[must_use]
221 pub fn children(&self) -> &[AstNode] {
222 &self.children
223 }
224
225 pub fn children_mut(&mut self) -> &mut Vec<AstNode> {
227 &mut self.children
228 }
229
230 pub fn add_child(&mut self, child: AstNode) {
232 self.children.push(child);
233 }
234
235 #[must_use]
237 pub fn is_terminal(&self) -> bool {
238 self.children.is_empty()
239 }
240
241 #[must_use]
243 pub fn token(&self) -> Option<&Token> {
244 self.token.as_ref()
245 }
246
247 #[must_use]
249 pub fn node_count(&self) -> usize {
250 1 + self.children.iter().map(AstNode::node_count).sum::<usize>()
251 }
252
253 #[must_use]
255 pub fn depth(&self) -> usize {
256 if self.children.is_empty() {
257 1
258 } else {
259 1 + self.children.iter().map(AstNode::depth).max().unwrap_or(0)
260 }
261 }
262
263 #[must_use]
265 pub fn terminals(&self) -> Vec<&AstNode> {
266 if self.is_terminal() {
267 vec![self]
268 } else {
269 self.children.iter().flat_map(AstNode::terminals).collect()
270 }
271 }
272}
273
274impl fmt::Display for AstNode {
275 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276 write!(f, "{}:{}", self.node_type, self.value)
277 }
278}
279
#[cfg(test)]
mod tests {
    // Unit tests for the token and AST node types defined in this module.
    use super::*;

    #[test]
    fn test_ast_node_type_display() {
        assert_eq!(AstNodeType::Function.to_string(), "Func");
        assert_eq!(AstNodeType::Parameter.to_string(), "Param");
        assert_eq!(AstNodeType::Return.to_string(), "Ret");
        assert_eq!(AstNodeType::MatchArm.to_string(), "Arm");
        assert_eq!(AstNodeType::Import.to_string(), "Import");
    }

    #[test]
    fn test_token_type_display() {
        assert_eq!(TokenType::Identifier.to_string(), "Id");
        assert_eq!(TokenType::Number.to_string(), "Num");
        assert_eq!(TokenType::String.to_string(), "Str");
        assert_eq!(TokenType::Comment.to_string(), "Comment");
    }

    #[test]
    fn test_token_creation() {
        let token = Token::new(TokenType::Identifier, "my_var");
        assert_eq!(token.token_type(), TokenType::Identifier);
        assert_eq!(token.value(), "my_var");
        assert_eq!(token.to_string(), "Id:my_var");
    }

    #[test]
    fn test_token_equality() {
        assert_eq!(
            Token::new(TokenType::Number, "1"),
            Token::new(TokenType::Number, String::from("1"))
        );
        assert_ne!(
            Token::new(TokenType::Number, "1"),
            Token::new(TokenType::String, "1")
        );
    }

    #[test]
    fn test_ast_node_creation() {
        let node = AstNode::new(AstNodeType::Function, "calculate");
        assert_eq!(node.node_type(), AstNodeType::Function);
        assert_eq!(node.value(), "calculate");
        assert!(node.children().is_empty());
        assert!(node.is_terminal());
    }

    #[test]
    fn test_ast_node_display() {
        let func = AstNode::new(AstNodeType::Function, "main");
        assert_eq!(func.to_string(), "Func:main");
    }

    #[test]
    fn test_ast_node_with_children() {
        let mut func = AstNode::new(AstNodeType::Function, "add");
        func.add_child(AstNode::new(AstNodeType::Parameter, "x"));
        func.add_child(AstNode::new(AstNodeType::Parameter, "y"));
        func.add_child(AstNode::new(AstNodeType::Return, "result"));

        assert_eq!(func.children().len(), 3);
        assert!(!func.is_terminal());
        assert_eq!(func.node_count(), 4);
        assert_eq!(func.depth(), 2);
    }

    #[test]
    fn test_children_mut() {
        let mut block = AstNode::new(AstNodeType::Block, "b");
        block
            .children_mut()
            .push(AstNode::new(AstNodeType::Call, "g"));
        assert_eq!(block.children().len(), 1);
        assert_eq!(block.children()[0].value(), "g");
    }

    #[test]
    fn test_terminal_node() {
        let token = Token::new(TokenType::Number, "42");
        let node = AstNode::terminal(token);

        assert!(node.is_terminal());
        assert!(node.token().is_some());
        assert_eq!(node.token().map(Token::value), Some("42"));
        assert_eq!(node.node_type(), AstNodeType::Literal);
        assert_eq!(node.value(), "42");
        assert_eq!(node.to_string(), "Lit:42");
    }

    #[test]
    fn test_leaf_metrics() {
        let leaf = AstNode::new(AstNodeType::Literal, "1");
        assert_eq!(leaf.node_count(), 1);
        assert_eq!(leaf.depth(), 1);

        let terminals = leaf.terminals();
        assert_eq!(terminals.len(), 1);
        assert!(std::ptr::eq(terminals[0], &leaf));
    }

    #[test]
    fn test_collect_terminals() {
        let mut func = AstNode::new(AstNodeType::Function, "test");
        func.add_child(AstNode::new(AstNodeType::Parameter, "a"));
        func.add_child(AstNode::new(AstNodeType::Parameter, "b"));

        let terminals = func.terminals();
        assert_eq!(terminals.len(), 2);
    }

    #[test]
    fn test_terminals_preserve_left_to_right_order() {
        let mut inner = AstNode::new(AstNodeType::Block, "body");
        inner.add_child(AstNode::new(AstNodeType::Variable, "x"));
        inner.add_child(AstNode::new(AstNodeType::Variable, "y"));

        let mut root = AstNode::new(AstNodeType::Function, "f");
        root.add_child(inner);
        root.add_child(AstNode::new(AstNodeType::Return, "z"));

        let values: Vec<&str> = root.terminals().into_iter().map(AstNode::value).collect();
        assert_eq!(values, ["x", "y", "z"]);
    }

    #[test]
    fn test_deep_tree() {
        let mut root = AstNode::new(AstNodeType::Function, "deep");
        let mut level1 = AstNode::new(AstNodeType::Block, "body");
        let mut level2 = AstNode::new(AstNodeType::Conditional, "if");
        level2.add_child(AstNode::new(AstNodeType::Return, "early"));
        level1.add_child(level2);
        root.add_child(level1);

        assert_eq!(root.depth(), 4);
        assert_eq!(root.node_count(), 4);
    }
}