rusty_ast/
parser.rs

1use std::fs;
2use std::io;
3use std::path::Path;
4
5use syn::{File, visit::Visit};
6
7use crate::visitor::AstVisitor;
8
9/// parse rust source code to ast
10///
11/// # Arguments
12/// * `source`: &str - rust source code
13///
14/// # Returns
15/// * `Result<syn::File, syn::Error>` - ast
16///
17/// # Errors
18/// * `syn::Error` - parse error
19pub fn parse_rust_source(source: &str) -> Result<syn::File, syn::Error> {
20    syn::parse_file(source)
21}
22
23/// Parse Rust source code from a file into an AST
24///
25/// # Arguments
26/// * `path`: impl AsRef<Path> - path to the rust source file
27///
28/// # Returns
29/// * `io::Result<syn::File>` - ast
30///
31/// # Errors
32/// * `io::Error` - file read error
33/// * `syn::Error` - parse error (wrapped in io::Error)
34pub fn parse_rust_file<P: AsRef<Path>>(path: P) -> io::Result<syn::File> {
35    let source = fs::read_to_string(path)?;
36    let syntax =
37        syn::parse_file(&source).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
38
39    Ok(syntax)
40}
41
42/// print ast
43///
44/// # Arguments
45/// * `file`: &File - ast
46///
47/// # Returns
48/// * `()`
49pub fn print_ast(file: &File) {
50    println!("AST for Rust code:");
51    let mut visitor = AstVisitor::new();
52    visitor.visit_file(file);
53}
54
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `contents` to a fresh temporary file and returns its handle.
    ///
    /// The file is deleted when the handle is dropped, so callers must keep
    /// the returned value alive while they use its path.
    fn write_temp_file(contents: &str) -> NamedTempFile {
        let mut file = NamedTempFile::new().expect("failed to create temp file");
        file.write_all(contents.as_bytes())
            .expect("failed to write temp file");
        // Flush so the contents are on disk before the file is re-opened by path.
        file.flush().expect("failed to flush temp file");
        file
    }

    #[test]
    fn test_parse_rust_file() {
        let file = write_temp_file(
            r#"
            fn test_function() {
                println!("Hello, world!");
            }
        "#,
        );

        let ast = parse_rust_file(file.path()).unwrap();

        assert_eq!(ast.items.len(), 1);
        if let syn::Item::Fn(func) = &ast.items[0] {
            assert_eq!(func.sig.ident.to_string(), "test_function");
        } else {
            panic!("Parsed item is not a function");
        }
    }

    #[test]
    fn test_parse_rust_file_invalid_syntax() {
        let file = write_temp_file("fn invalid_function( {");

        // `unwrap_err` would require `syn::File: Debug` (only available with
        // syn's `extra-traits` feature), so match explicitly instead.
        let err = match parse_rust_file(file.path()) {
            Ok(_) => panic!("Expected parse error for invalid code"),
            Err(err) => err,
        };

        // Parse failures are reported as InvalidData wrapping the syn::Error.
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        assert!(
            err.get_ref()
                .and_then(|inner| inner.downcast_ref::<syn::Error>())
                .is_some(),
            "Expected the wrapped error to be a syn::Error"
        );
    }

    #[test]
    fn test_parse_rust_file_missing_file() {
        let dir = tempfile::tempdir().unwrap();
        let missing = dir.path().join("does_not_exist.rs");

        let err = match parse_rust_file(&missing) {
            Ok(_) => panic!("Expected an error for a missing file"),
            Err(err) => err,
        };

        // Errors from reading the file are propagated unchanged.
        assert_eq!(err.kind(), std::io::ErrorKind::NotFound);
    }

    #[test]
    fn test_print_ast_smoke() {
        let file = parse_rust_source("fn main() { let x = 1 + 2; }").unwrap();
        // Only checks that walking the tree completes without panicking;
        // the printed output is not inspected here.
        print_ast(&file);
    }

    #[test]
    fn test_parse_function() {
        let source = r#"
            fn add(a: i32, b: i32) -> i32 {
                a + b
            }
        "#;

        let file = parse_rust_source(source).unwrap();

        // should be 1 item
        assert_eq!(file.items.len(), 1);

        // item should be a function
        if let syn::Item::Fn(func) = &file.items[0] {
            assert_eq!(func.sig.ident.to_string(), "add");
            assert_eq!(func.sig.inputs.len(), 2); // should be 2 parameters

            // return type should be i32
            if let syn::ReturnType::Type(_, return_type) = &func.sig.output {
                if let syn::Type::Path(type_path) = &**return_type {
                    let path_segment = &type_path.path.segments[0];
                    assert_eq!(path_segment.ident.to_string(), "i32");
                } else {
                    panic!("Return type is not a path");
                }
            } else {
                panic!("Function has no return type");
            }

            // body should contain exactly 1 statement (the tail expression)
            assert_eq!(func.block.stmts.len(), 1);
        } else {
            panic!("Item is not a function");
        }
    }

    #[test]
    fn test_parse_struct() {
        let source = r#"
            struct Point {
                x: f64,
                y: f64,
            }
        "#;

        let file = parse_rust_source(source).unwrap();

        // should be 1 item
        assert_eq!(file.items.len(), 1);

        // item should be a struct
        if let syn::Item::Struct(struct_item) = &file.items[0] {
            assert_eq!(struct_item.ident.to_string(), "Point");

            // should be 2 fields
            assert_eq!(struct_item.fields.iter().count(), 2);

            // collect the fields so each can be inspected in declaration order
            let fields: Vec<_> = struct_item.fields.iter().collect();

            // x field
            assert_eq!(fields[0].ident.as_ref().unwrap().to_string(), "x");
            if let syn::Type::Path(type_path) = &fields[0].ty {
                let path_segment = &type_path.path.segments[0];
                assert_eq!(path_segment.ident.to_string(), "f64");
            } else {
                panic!("Field x is not a path type");
            }

            // y field
            assert_eq!(fields[1].ident.as_ref().unwrap().to_string(), "y");
            if let syn::Type::Path(type_path) = &fields[1].ty {
                let path_segment = &type_path.path.segments[0];
                assert_eq!(path_segment.ident.to_string(), "f64");
            } else {
                panic!("Field y is not a path type");
            }
        } else {
            panic!("Item is not a struct");
        }
    }

    #[test]
    fn test_parse_enum() {
        let source = r#"
            enum Direction {
                North,
                East,
                South,
                West,
            }
        "#;

        let file = parse_rust_source(source).unwrap();

        // should be 1 item
        assert_eq!(file.items.len(), 1);

        // item should be an enum
        if let syn::Item::Enum(enum_item) = &file.items[0] {
            assert_eq!(enum_item.ident.to_string(), "Direction");

            // should be 4 variants
            assert_eq!(enum_item.variants.len(), 4);

            // variants should appear in declaration order
            let variant_names: Vec<String> = enum_item
                .variants
                .iter()
                .map(|v| v.ident.to_string())
                .collect();

            assert_eq!(variant_names, vec!["North", "East", "South", "West"]);
        } else {
            panic!("Item is not an enum");
        }
    }

    #[test]
    fn test_parse_complex_expression() {
        let source = r#"
            fn complex_expr() {
                let result = (10 + 20) * 30 / (5 - 2);
                if result > 100 {
                    println!("Large result: {}", result);
                } else {
                    println!("Small result: {}", result);
                }
            }
        "#;

        let file = parse_rust_source(source).unwrap();

        // should be 1 item
        assert_eq!(file.items.len(), 1);

        // item should be a function
        if let syn::Item::Fn(func) = &file.items[0] {
            assert_eq!(func.sig.ident.to_string(), "complex_expr");

            // should be 2 statements
            assert_eq!(func.block.stmts.len(), 2);

            // first statement should be a variable declaration
            if let syn::Stmt::Local(local) = &func.block.stmts[0] {
                assert!(local.init.is_some());

                // variable name should be `result`
                if let syn::Pat::Ident(pat_ident) = &local.pat {
                    assert_eq!(pat_ident.ident.to_string(), "result");
                } else {
                    panic!("Variable declaration pattern is not an identifier");
                }
            } else {
                panic!("First statement is not a variable declaration");
            }

            // second statement should be an `if` expression
            match &func.block.stmts[1] {
                syn::Stmt::Expr(expr, _) => assert!(
                    matches!(expr, syn::Expr::If(_)),
                    "Second statement is not an if expression"
                ),
                _ => panic!("Second statement is not an expression"),
            }
        } else {
            panic!("Item is not a function");
        }
    }

    #[test]
    fn test_parse_invalid_code() {
        let source = r#"
            fn invalid_function( {
                let x = 10;
            }
        "#;

        let result = parse_rust_source(source);
        assert!(result.is_err(), "Expected parse error for invalid code");
    }

    #[test]
    fn test_parse_multiple_items() {
        let source = r#"
            fn function1() -> i32 { 42 }

            struct MyStruct {
                field: i32,
            }

            fn function2(s: MyStruct) -> i32 {
                s.field
            }
        "#;

        let file = parse_rust_source(source).unwrap();

        // should be 3 items
        assert_eq!(file.items.len(), 3);

        // first item should be a function
        if let syn::Item::Fn(func) = &file.items[0] {
            assert_eq!(func.sig.ident.to_string(), "function1");
        } else {
            panic!("First item is not a function");
        }

        // second item should be a struct
        if let syn::Item::Struct(struct_item) = &file.items[1] {
            assert_eq!(struct_item.ident.to_string(), "MyStruct");
        } else {
            panic!("Second item is not a struct");
        }

        // third item should be a function
        if let syn::Item::Fn(func) = &file.items[2] {
            assert_eq!(func.sig.ident.to_string(), "function2");
        } else {
            panic!("Third item is not a function");
        }
    }
}
308}