1use std::path::Path;
4
5use proc_macro2::TokenStream;
6use quote::ToTokens;
7use syn::visit::Visit;
8use syn::visit_mut::VisitMut;
9
10use crate::error::SourceResult;
11use crate::ops;
12use crate::visitor::{IdentifierCollector, ImportCollector};
13
/// A parsed Rust source file.
///
/// Wraps a [`syn::File`] syntax tree together with the source text it was
/// parsed from, when that text is available.
#[derive(Debug, Clone)]
pub struct RustAST {
    /// The parsed syntax tree. It can be mutated in place through
    /// `file_mut`, `items_mut`, and `visit_mut`.
    file: syn::File,
    /// The original source text, if known. `parse` sets it; no method in this
    /// file updates it afterwards, so it can drift from `file` once the tree
    /// is mutated.
    source: Option<String>,
}
25
26impl RustAST {
27 pub fn parse(source: &str) -> SourceResult<Self> {
29 let file = syn::parse_file(source)?;
30 Ok(Self {
31 file,
32 source: Some(source.to_string()),
33 })
34 }
35
36 pub fn from_file(path: &Path) -> SourceResult<Self> {
38 let source = std::fs::read_to_string(path)?;
39 Self::parse(&source)
40 }
41
42 pub fn file(&self) -> &syn::File {
44 &self.file
45 }
46
47 pub fn file_mut(&mut self) -> &mut syn::File {
49 &mut self.file
50 }
51
52 pub fn source(&self) -> Option<&str> {
54 self.source.as_deref()
55 }
56
57 pub fn to_string_pretty(&self) -> String {
59 prettyplease::unparse(&self.file)
60 }
61
62 pub fn to_token_stream(&self) -> TokenStream {
64 self.file.to_token_stream()
65 }
66
67 pub fn collect_imports(&self) -> Vec<&syn::ItemUse> {
71 let mut collector = ImportCollector::new();
72 collector.visit_file(&self.file);
73 collector.imports
74 }
75
76 pub fn collect_used_identifiers(&self) -> std::collections::HashSet<String> {
78 let mut collector = IdentifierCollector::new();
79 collector.visit_file(&self.file);
80 collector.identifiers
81 }
82
83 pub fn find_unused_imports(&self) -> Vec<UnusedImport> {
85 ops::RemoveUnusedImports::detect(self)
86 }
87
88 pub fn remove_unused_imports(&mut self) -> Vec<UnusedImport> {
92 ops::RemoveUnusedImports::apply(self)
93 }
94
95 pub fn visit_mut<V: VisitMut>(&mut self, visitor: &mut V) {
97 visitor.visit_file_mut(&mut self.file);
98 }
99
100 pub fn visit<'a, V: Visit<'a>>(&'a self, visitor: &mut V) {
102 visitor.visit_file(&self.file);
103 }
104
105 pub fn items(&self) -> &[syn::Item] {
109 &self.file.items
110 }
111
112 pub fn items_mut(&mut self) -> &mut Vec<syn::Item> {
114 &mut self.file.items
115 }
116
117 pub fn filter_items<F>(&self, predicate: F) -> Vec<&syn::Item>
119 where
120 F: Fn(&syn::Item) -> bool,
121 {
122 self.file.items.iter().filter(|i| predicate(i)).collect()
123 }
124
125 pub fn functions(&self) -> Vec<&syn::ItemFn> {
127 self.file
128 .items
129 .iter()
130 .filter_map(|item| {
131 if let syn::Item::Fn(f) = item {
132 Some(f)
133 } else {
134 None
135 }
136 })
137 .collect()
138 }
139
140 pub fn structs(&self) -> Vec<&syn::ItemStruct> {
142 self.file
143 .items
144 .iter()
145 .filter_map(|item| {
146 if let syn::Item::Struct(s) = item {
147 Some(s)
148 } else {
149 None
150 }
151 })
152 .collect()
153 }
154
155 pub fn impls(&self) -> Vec<&syn::ItemImpl> {
157 self.file
158 .items
159 .iter()
160 .filter_map(|item| {
161 if let syn::Item::Impl(i) = item {
162 Some(i)
163 } else {
164 None
165 }
166 })
167 .collect()
168 }
169}
170
171impl std::fmt::Display for RustAST {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 write!(f, "{}", self.to_string_pretty())
174 }
175}
176
/// An import reported by the unused-import analysis (see
/// `RustAST::find_unused_imports` and `RustAST::remove_unused_imports`).
///
/// This is a plain value type. It implements `PartialEq`, `Eq`, and `Hash`
/// so that results can be compared (e.g. with `assert_eq!`) and
/// de-duplicated in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UnusedImport {
    /// Textual path of the import — presumably the full `use` path such as
    /// `std::io`. TODO confirm against `ops::RemoveUnusedImports`.
    pub path: String,
    /// Name the import binds — presumably the last path segment or its `as`
    /// alias. TODO confirm against `ops::RemoveUnusedImports`.
    pub name: String,
}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a test fixture, panicking with a descriptive message if the
    /// fixture is not valid Rust.
    fn parse_ok(src: &str) -> RustAST {
        RustAST::parse(src).expect("test fixture should be valid Rust")
    }

    #[test]
    fn test_parse_simple() {
        let ast = parse_ok("fn main() {}");
        let fns = ast.functions();
        assert_eq!(fns.len(), 1);
    }

    #[test]
    fn test_parse_with_imports() {
        let ast = parse_ok("use std::io;\nuse std::fs;\nfn main() {}");
        let imports = ast.collect_imports();
        assert_eq!(imports.len(), 2);
    }

    #[test]
    fn test_to_string() {
        let rendered = parse_ok("fn main() {}").to_string();
        assert!(rendered.contains("fn main"));
    }

    #[test]
    fn test_collect_identifiers() {
        let source = r#"
    use std::io;
    fn main() {
        let x = io::stdin();
        println!("{}", x);
    }
    "#;
        let idents = parse_ok(source).collect_used_identifiers();

        for expected in &["io", "x", "println"] {
            assert!(
                idents.contains(*expected),
                "missing identifier `{}`",
                expected
            );
        }
    }
}
227}