1use crate::error::AstError;
8use crate::error::AstResult;
9use crate::types::AstNode;
10use crate::types::ParsedAst;
11use dashmap::DashMap;
12use std::path::Path;
13use tree_sitter::Parser;
14
/// Source languages that this module can name.
///
/// Each variant is mapped to a tree-sitter grammar by `Language::parser` and
/// to a display name by `Language::name`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum Language {
    Rust,
    Python,
    JavaScript,
    TypeScript,
    Go,
    Java,
    C,
    Cpp,
    CSharp,
    Bash,
    Ruby,
    Php,
    Lua,
    Haskell,
    Elixir,
    Scala,
    OCaml,
    Clojure,
    Zig,
    Swift,
    Kotlin,
    ObjectiveC,
    // The variants below are the ones `Language::is_fallback` reports as
    // fallbacks; `Language::parser` maps each of them to the Bash grammar.
    R,
    Julia,
    Dart,
    Wgsl,
    Glsl,
}
53
54impl Language {
55 pub fn parser(&self) -> tree_sitter::Language {
57 match self {
58 Self::Rust => tree_sitter_rust::LANGUAGE.into(),
60 Self::Python => tree_sitter_python::LANGUAGE.into(),
61 Self::JavaScript => tree_sitter_javascript::LANGUAGE.into(),
62 Self::TypeScript => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
63 Self::Go => tree_sitter_go::LANGUAGE.into(),
64 Self::Java => tree_sitter_java::LANGUAGE.into(),
65 Self::C => tree_sitter_c::LANGUAGE.into(),
66 Self::Cpp => tree_sitter_cpp::LANGUAGE.into(),
67 Self::CSharp => tree_sitter_c_sharp::LANGUAGE.into(),
68 Self::Bash => tree_sitter_bash::LANGUAGE.into(),
70 Self::Ruby => tree_sitter_ruby::LANGUAGE.into(),
71 Self::Php => tree_sitter_php::LANGUAGE_PHP.into(),
72 Self::Lua => tree_sitter_lua::LANGUAGE.into(),
73 Self::Haskell => tree_sitter_haskell::LANGUAGE.into(),
75 Self::Elixir => tree_sitter_elixir::LANGUAGE.into(),
76 Self::Scala => tree_sitter_scala::LANGUAGE.into(),
77 Self::OCaml => tree_sitter_ocaml::LANGUAGE_OCAML.into(),
78 Self::Clojure => tree_sitter_clojure::LANGUAGE.into(),
79 Self::Zig => tree_sitter_zig::LANGUAGE.into(),
81 Self::Swift => tree_sitter_swift::LANGUAGE.into(),
82 Self::Kotlin => tree_sitter_kotlin_ng::LANGUAGE.into(),
83 Self::ObjectiveC => tree_sitter_objc::LANGUAGE.into(),
84 Self::R => tree_sitter_bash::LANGUAGE.into(), Self::Julia => tree_sitter_bash::LANGUAGE.into(), Self::Dart => tree_sitter_bash::LANGUAGE.into(), Self::Wgsl => tree_sitter_bash::LANGUAGE.into(), Self::Glsl => tree_sitter_bash::LANGUAGE.into(), }
91 }
92
93 pub const fn is_fallback(&self) -> bool {
95 matches!(
96 self,
97 Self::R | Self::Julia | Self::Dart | Self::Wgsl | Self::Glsl
98 )
99 }
100
101 pub const fn actual_parser_name(&self) -> &'static str {
103 if !self.is_fallback() {
104 return self.name();
105 }
106
107 "Bash (fallback)"
109 }
110
111 pub const fn name(&self) -> &'static str {
113 match self {
114 Self::Rust => "Rust",
115 Self::Python => "Python",
116 Self::JavaScript => "JavaScript",
117 Self::TypeScript => "TypeScript",
118 Self::Go => "Go",
119 Self::Java => "Java",
120 Self::C => "C",
121 Self::Cpp => "C++",
122 Self::CSharp => "C#",
123 Self::Bash => "Bash",
124 Self::Ruby => "Ruby",
125 Self::Php => "PHP",
126 Self::Lua => "Lua",
127 Self::Haskell => "Haskell",
128 Self::Elixir => "Elixir",
129 Self::Scala => "Scala",
130 Self::OCaml => "OCaml",
131 Self::Clojure => "Clojure",
132 Self::Zig => "Zig",
133 Self::Swift => "Swift",
134 Self::Kotlin => "Kotlin",
135 Self::ObjectiveC => "Objective-C",
136 Self::R => "R",
137 Self::Julia => "Julia",
138 Self::Dart => "Dart",
139 Self::Wgsl => "WGSL",
140 Self::Glsl => "GLSL",
141 }
142 }
143}
144
/// Detects a file's language from its path and parses source text with
/// tree-sitter.
pub struct LanguageRegistry {
    /// Map from a `Language` to a tree-sitter `Parser`.
    parsers: DashMap<Language, Parser>,
}
149
150impl std::fmt::Debug for LanguageRegistry {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 f.debug_struct("LanguageRegistry")
153 .field("parsers_count", &self.parsers.len())
154 .finish()
155 }
156}
157
158impl LanguageRegistry {
159 pub fn new() -> Self {
161 let registry = Self {
162 parsers: DashMap::new(),
163 };
164
165 let common_langs = [
167 Language::Rust,
168 Language::Python,
169 Language::JavaScript,
170 Language::TypeScript,
171 Language::Go,
172 Language::Java,
173 Language::C,
174 Language::Cpp,
175 ];
176
177 for lang in common_langs {
178 let _ = registry.get_or_create_parser(lang);
179 }
180
181 registry
182 }
183
184 pub fn detect_language(&self, path: &Path) -> AstResult<Language> {
187 let extension = path
188 .extension()
189 .and_then(|e| e.to_str())
190 .ok_or_else(|| AstError::LanguageDetectionFailed(path.display().to_string()))?;
191
192 let lang = match extension {
193 "rs" => Language::Rust,
195 "py" | "pyi" => Language::Python,
196 "js" | "mjs" | "cjs" => Language::JavaScript,
197 "ts" | "mts" | "cts" => Language::TypeScript,
198 "tsx" | "jsx" => Language::TypeScript,
199 "go" => Language::Go,
200 "java" => Language::Java,
201 "c" | "h" => Language::C,
202 "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "c++" => Language::Cpp,
203 "cs" => Language::CSharp,
204 "sh" | "bash" | "zsh" | "fish" => Language::Bash,
205 "rb" => Language::Ruby,
206 "php" => Language::Php,
207 "lua" => Language::Lua,
208 "hs" | "lhs" => Language::Haskell,
209 "ex" | "exs" => Language::Elixir,
210 "scala" | "sc" => Language::Scala,
211 "ml" | "mli" => Language::OCaml,
212 "clj" | "cljs" | "cljc" => Language::Clojure,
213 "zig" => Language::Zig,
214 "swift" => Language::Swift,
215 "kt" | "kts" => Language::Kotlin,
216 "m" | "mm" => Language::ObjectiveC,
217 "r" | "R" => Language::R,
218 "jl" => Language::Julia,
219 "dart" => Language::Dart,
220 "wgsl" => Language::Wgsl,
221 "glsl" | "vert" | "frag" => Language::Glsl,
222
223 "toml" | "yaml" | "yml" | "json" | "jsonc" | "xml" | "html" | "htm" | "css"
225 | "scss" | "sass" | "less" | "md" | "markdown" | "tex" | "latex" | "rst" | "sql"
226 | "graphql" | "gql" | "proto" | "dockerfile" | "makefile" | "cmake" | "hcl" | "tf"
227 | "tfvars" | "nix" => {
228 return Err(AstError::UnsupportedLanguage(format!(
229 "{} files should use patch-based editing, not AST operations",
230 extension
231 )));
232 }
233
234 _ => return Err(AstError::UnsupportedLanguage(extension.to_string())),
235 };
236
237 Ok(lang)
238 }
239
240 fn get_or_create_parser(&self, language: Language) -> Parser {
242 let mut parser = Parser::new();
244 let _ = parser.set_language(&language.parser());
246 parser
247 }
248
249 pub fn parse(&self, language: &Language, source: &str) -> AstResult<ParsedAst> {
251 let mut parser = self.get_or_create_parser(*language);
252
253 let tree = parser
254 .parse(source, None)
255 .ok_or_else(|| AstError::ParserError("Failed to parse source code".to_string()))?;
256
257 let root = tree.root_node();
258 let root_node = AstNode {
259 kind: root.kind().to_string(),
260 start_byte: root.start_byte(),
261 end_byte: root.end_byte(),
262 start_position: (root.start_position().row, root.start_position().column),
263 end_position: (root.end_position().row, root.end_position().column),
264 children_count: root.child_count(),
265 };
266
267 Ok(ParsedAst {
268 tree,
269 source: source.to_string(),
270 language: *language,
271 root_node,
272 })
273 }
274
275 pub fn stats(&self) -> LanguageRegistryStats {
277 LanguageRegistryStats {
278 loaded_parsers: self.parsers.len(),
279 total_languages: 27, }
281 }
282}
283
284impl Default for LanguageRegistry {
285 fn default() -> Self {
286 Self::new()
287 }
288}
289
/// Counters returned by `LanguageRegistry::stats`.
#[derive(Debug, Clone)]
pub struct LanguageRegistryStats {
    /// Number of entries in the registry's parser map when the stats were taken.
    pub loaded_parsers: usize,
    /// Number of languages the registry reports as known.
    pub total_languages: usize,
}
296
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Known code extensions resolve to their language; config, markup and
    /// extension-less files are rejected.
    #[test]
    fn test_language_detection() {
        let registry = LanguageRegistry::new();

        let supported = [
            ("test.rs", Language::Rust),
            ("test.py", Language::Python),
            ("test.js", Language::JavaScript),
            ("test.ts", Language::TypeScript),
        ];
        for (file_name, expected) in supported {
            let detected = registry
                .detect_language(&PathBuf::from(file_name))
                .unwrap();
            assert_eq!(detected, expected, "wrong language for {}", file_name);
        }

        let rejected = ["test.toml", "test.yaml", "test.md", "Dockerfile"];
        for file_name in rejected {
            let outcome = registry.detect_language(&PathBuf::from(file_name));
            assert!(outcome.is_err(), "{} should be rejected", file_name);
        }
    }

    /// A small Rust program parses into a non-empty `source_file` tree.
    #[test]
    fn test_parsing() {
        let registry = LanguageRegistry::new();

        let rust_code = r#"
fn main() {
    println!("Hello, world!");
}
"#;

        let parsed = registry.parse(&Language::Rust, rust_code).unwrap();
        let root = &parsed.root_node;

        assert_eq!(parsed.language, Language::Rust);
        assert_eq!(root.kind, "source_file");
        assert!(root.children_count > 0);
    }

    /// Only the languages without a grammar of their own report as fallbacks.
    #[test]
    fn test_fallback_detection() {
        for native in [Language::Rust, Language::Python] {
            assert!(!native.is_fallback(), "{:?} is not a fallback", native);
        }

        for fallback in [Language::R, Language::Julia, Language::Wgsl] {
            assert!(fallback.is_fallback(), "{:?} is a fallback", fallback);
        }
    }
}