1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use anyhow::{Context, anyhow, bail};
6
7use crate::classify::Classifier;
8use crate::extract::ExtractionResult;
9use crate::graph::Graph;
10use crate::module::ModuleMap;
11
/// Project-wide settings handed to the project-level [`LanguagePlugin`] hooks.
#[derive(Clone, Debug)]
pub struct ProjectContext {
    /// The path exactly as the caller supplied it.
    pub input_path: PathBuf,
    /// `input_path` after `std::fs::canonicalize`, or `input_path` verbatim
    /// when canonicalization fails (e.g. the path does not exist).
    pub project_root: PathBuf,
    /// Whether index-store support is enabled; [`ProjectContext::new`] turns it on.
    pub index_store_enabled: bool,
}

impl ProjectContext {
    /// Builds a context for `input_path`.
    ///
    /// The root is resolved with [`std::fs::canonicalize`]. Any resolution
    /// error is deliberately ignored and the unresolved input path is used
    /// as the root instead, so this constructor never fails.
    pub fn new(input_path: &Path) -> Self {
        let project_root = std::fs::canonicalize(input_path)
            .ok()
            .unwrap_or_else(|| input_path.to_path_buf());

        ProjectContext {
            input_path: input_path.to_path_buf(),
            project_root,
            index_store_enabled: true,
        }
    }

    /// Returns `true` when the (resolved) project root is an existing
    /// regular file rather than a directory.
    ///
    /// This queries the filesystem on every call; inaccessible or missing
    /// paths report `false`.
    pub fn is_single_file(&self) -> bool {
        let root: &Path = &self.project_root;
        root.is_file()
    }
}
33
/// Per-file information passed to [`LanguagePlugin::extract`].
///
/// NOTE(review): the construction site is not visible here; the field notes
/// below describe the apparent intent and should be confirmed there.
#[derive(Clone, Debug)]
pub struct FileContext {
    /// Path originally given to the tool (presumably the project's input path).
    pub input_path: PathBuf,
    /// Root of the project this file belongs to.
    pub project_root: PathBuf,
    /// Location of the file, presumably relative to `project_root`.
    pub relative_path: PathBuf,
    /// Location of the file, presumably fully resolved.
    pub absolute_path: PathBuf,
    /// Module the file belongs to, if one was determined.
    pub module_name: Option<String>,
    /// Whether index-store support is enabled for this run.
    pub index_store_enabled: bool,
}
43
44pub trait GraphPass: Send + Sync {
45 fn apply(&self, graph: Graph) -> Graph;
46}
47
48pub trait LanguagePlugin: Send + Sync {
49 fn id(&self) -> &'static str;
50 fn extensions(&self) -> &'static [&'static str];
51
52 fn prepare_project(&self, _context: &ProjectContext) -> anyhow::Result<()> {
53 Ok(())
54 }
55
56 fn finish_project(&self, _context: &ProjectContext) -> anyhow::Result<()> {
57 Ok(())
58 }
59
60 fn discover_modules(&self, _context: &ProjectContext) -> anyhow::Result<ModuleMap> {
61 Ok(ModuleMap::new())
62 }
63
64 fn extract(&self, source: &[u8], context: &FileContext) -> anyhow::Result<ExtractionResult>;
65
66 fn stamp_module(
67 &self,
68 result: ExtractionResult,
69 module_name: Option<&str>,
70 ) -> ExtractionResult {
71 crate::pipeline::stamp_module(result, module_name)
72 }
73
74 fn classifiers(&self) -> Vec<Box<dyn Classifier>> {
75 Vec::new()
76 }
77
78 fn graph_passes(&self) -> Vec<Box<dyn GraphPass>> {
79 Vec::new()
80 }
81}
82
83pub struct LanguageRegistry {
84 plugins: Vec<Arc<dyn LanguagePlugin>>,
85 plugins_by_extension: HashMap<String, Arc<dyn LanguagePlugin>>,
86}
87
88impl LanguageRegistry {
89 pub fn new() -> Self {
90 Self {
91 plugins: Vec::new(),
92 plugins_by_extension: HashMap::new(),
93 }
94 }
95
96 pub fn register<P>(&mut self, plugin: P) -> anyhow::Result<()>
97 where
98 P: LanguagePlugin + 'static,
99 {
100 let plugin = Arc::new(plugin) as Arc<dyn LanguagePlugin>;
101 for extension in plugin.extensions() {
102 if let Some(existing) = self.plugins_by_extension.get(*extension) {
103 bail!(
104 "language plugin '{}' conflicts with '{}' for extension '{}'",
105 plugin.id(),
106 existing.id(),
107 extension
108 );
109 }
110 }
111
112 for extension in plugin.extensions() {
113 self.plugins_by_extension
114 .insert((*extension).to_string(), Arc::clone(&plugin));
115 }
116 self.plugins.push(plugin);
117 Ok(())
118 }
119
120 pub fn supported_extensions(&self) -> Vec<String> {
121 let mut extensions: Vec<_> = self.plugins_by_extension.keys().cloned().collect();
122 extensions.sort();
123 extensions
124 }
125
126 pub fn plugin_for_extension(&self, extension: &str) -> Option<Arc<dyn LanguagePlugin>> {
127 self.plugins_by_extension.get(extension).cloned()
128 }
129
130 pub fn plugin_for_path(&self, path: &Path) -> anyhow::Result<Arc<dyn LanguagePlugin>> {
131 let extension = path
132 .extension()
133 .and_then(|ext| ext.to_str())
134 .ok_or_else(|| anyhow!("unsupported language for file: {}", path.display()))?;
135 self.plugin_for_extension(extension)
136 .ok_or_else(|| anyhow!("unsupported language for file: {}", path.display()))
137 }
138
139 pub fn plugins(&self) -> &[Arc<dyn LanguagePlugin>] {
140 &self.plugins
141 }
142
143 pub fn collect_classifiers(&self) -> Vec<Box<dyn Classifier>> {
144 self.plugins
145 .iter()
146 .flat_map(|plugin| plugin.classifiers())
147 .collect()
148 }
149
150 pub fn collect_graph_passes(&self) -> Vec<Box<dyn GraphPass>> {
151 self.plugins
152 .iter()
153 .flat_map(|plugin| plugin.graph_passes())
154 .collect()
155 }
156
157 pub fn prepare_plugins(&self, context: &ProjectContext) -> anyhow::Result<()> {
158 for plugin in &self.plugins {
159 plugin
160 .prepare_project(context)
161 .with_context(|| format!("failed to prepare plugin '{}'", plugin.id()))?;
162 }
163 Ok(())
164 }
165
166 pub fn finish_plugins(&self, context: &ProjectContext) -> anyhow::Result<()> {
167 for plugin in &self.plugins {
168 plugin
169 .finish_project(context)
170 .with_context(|| format!("failed to finish plugin '{}'", plugin.id()))?;
171 }
172 Ok(())
173 }
174}
175
176impl Default for LanguageRegistry {
177 fn default() -> Self {
178 Self::new()
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extract::ExtractionResult;
    use std::path::Path;

    /// Minimal plugin: a fixed id and extension list, an empty extraction,
    /// and default implementations for every optional hook.
    struct TestPlugin {
        id: &'static str,
        exts: &'static [&'static str],
    }

    impl LanguagePlugin for TestPlugin {
        fn id(&self) -> &'static str {
            self.id
        }

        fn extensions(&self) -> &'static [&'static str] {
            self.exts
        }

        fn extract(
            &self,
            _source: &[u8],
            _context: &FileContext,
        ) -> anyhow::Result<ExtractionResult> {
            Ok(ExtractionResult::new())
        }
    }

    #[test]
    fn rejects_duplicate_extensions() {
        let mut registry = LanguageRegistry::new();
        registry
            .register(TestPlugin {
                id: "first",
                exts: &["rs"],
            })
            .unwrap();

        let error = registry
            .register(TestPlugin {
                id: "second",
                exts: &["rs"],
            })
            .unwrap_err();

        // Pin the full message: both plugin ids and the contested extension.
        assert_eq!(
            error.to_string(),
            "language plugin 'second' conflicts with 'first' for extension 'rs'"
        );
    }

    #[test]
    fn rejected_registration_leaves_registry_unchanged() {
        let mut registry = LanguageRegistry::new();
        registry
            .register(TestPlugin {
                id: "first",
                exts: &["rs"],
            })
            .unwrap();

        // "py" is free but "rs" conflicts; nothing from "second" may leak in.
        assert!(registry
            .register(TestPlugin {
                id: "second",
                exts: &["py", "rs"],
            })
            .is_err());

        assert_eq!(registry.plugins().len(), 1);
        assert_eq!(registry.supported_extensions(), vec!["rs"]);
        assert!(registry.plugin_for_extension("py").is_none());
        assert_eq!(registry.plugin_for_extension("rs").unwrap().id(), "first");
    }

    #[test]
    fn resolves_plugins_by_path_extension() {
        let mut registry = LanguageRegistry::new();
        registry
            .register(TestPlugin {
                id: "swift",
                exts: &["swift"],
            })
            .unwrap();
        registry
            .register(TestPlugin {
                id: "c",
                exts: &["h", "c"],
            })
            .unwrap();

        assert_eq!(registry.supported_extensions(), vec!["c", "h", "swift"]);
        assert_eq!(
            registry
                .plugin_for_path(Path::new("src/main.swift"))
                .unwrap()
                .id(),
            "swift"
        );
        assert_eq!(
            registry
                .plugin_for_path(Path::new("include/a.h"))
                .unwrap()
                .id(),
            "c"
        );

        // No extension at all.
        let error = match registry.plugin_for_path(Path::new("Makefile")) {
            Ok(plugin) => panic!("expected no plugin, got '{}'", plugin.id()),
            Err(error) => error,
        };
        assert!(error.to_string().contains("unsupported language"));

        // Extension present but unclaimed.
        assert!(registry.plugin_for_path(Path::new("script.py")).is_err());
    }

    #[test]
    fn default_hooks_succeed_and_contribute_nothing() {
        let mut registry = LanguageRegistry::default();
        registry
            .register(TestPlugin {
                id: "first",
                exts: &["rs"],
            })
            .unwrap();

        let context = ProjectContext::new(Path::new("does-not-exist"));
        registry.prepare_plugins(&context).unwrap();
        registry.finish_plugins(&context).unwrap();
        assert!(registry.collect_classifiers().is_empty());
        assert!(registry.collect_graph_passes().is_empty());
    }
}