1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use anyhow::{Context, anyhow, bail};
6
7use crate::classify::Classifier;
8use crate::extract::ExtractionResult;
9use crate::graph::Graph;
10use crate::module::ModuleMap;
11
/// Project-wide information handed to plugin hooks such as
/// `LanguagePlugin::prepare_project` and `LanguagePlugin::discover_modules`.
#[derive(Clone, Debug)]
pub struct ProjectContext {
    /// The path exactly as it was passed to [`ProjectContext::new`].
    pub input_path: PathBuf,
    /// Canonicalized form of `input_path`, or a verbatim copy of it when
    /// canonicalization failed.
    pub project_root: PathBuf,
}

impl ProjectContext {
    /// Builds a context for `input_path`.
    ///
    /// Canonicalization is best-effort: if `std::fs::canonicalize` returns an
    /// error, the path is stored unchanged as the project root instead of
    /// failing.
    pub fn new(input_path: &Path) -> Self {
        let project_root = match std::fs::canonicalize(input_path) {
            Ok(resolved) => resolved,
            Err(_) => input_path.to_path_buf(),
        };

        Self {
            input_path: input_path.to_path_buf(),
            project_root,
        }
    }

    /// Returns `true` when `project_root` currently refers to a regular file.
    ///
    /// This queries the filesystem on every call.
    pub fn is_single_file(&self) -> bool {
        self.project_root.as_path().is_file()
    }
}
31
/// Per-file information passed to `LanguagePlugin::extract`.
///
/// This type is plain data; nothing in this file constructs it, so the field
/// notes below are reviewer assumptions rather than guarantees.
#[derive(Clone, Debug)]
pub struct FileContext {
    /// NOTE(review): presumably the path the user supplied for the run —
    /// confirm against the code that builds this struct.
    pub input_path: PathBuf,
    /// NOTE(review): presumably the same root as `ProjectContext::project_root`
    /// — confirm against the caller.
    pub project_root: PathBuf,
    /// NOTE(review): presumably the file's path relative to `project_root` —
    /// confirm against the caller.
    pub relative_path: PathBuf,
    /// NOTE(review): presumably the file's absolute path on disk — confirm
    /// against the caller.
    pub absolute_path: PathBuf,
    /// Optional module name associated with the file; `None` when there is
    /// no module name.
    pub module_name: Option<String>,
}
40
41pub trait GraphPass: Send + Sync {
42 fn apply(&self, graph: Graph) -> Graph;
43}
44
45pub trait LanguagePlugin: Send + Sync {
46 fn id(&self) -> &'static str;
47 fn extensions(&self) -> &'static [&'static str];
48
49 fn prepare_project(&self, _context: &ProjectContext) -> anyhow::Result<()> {
50 Ok(())
51 }
52
53 fn discover_modules(&self, _context: &ProjectContext) -> anyhow::Result<ModuleMap> {
54 Ok(ModuleMap::new())
55 }
56
57 fn extract(&self, source: &[u8], context: &FileContext) -> anyhow::Result<ExtractionResult>;
58
59 fn stamp_module(
60 &self,
61 result: ExtractionResult,
62 module_name: Option<&str>,
63 ) -> ExtractionResult {
64 crate::pipeline::stamp_module(result, module_name)
65 }
66
67 fn classifiers(&self) -> Vec<Box<dyn Classifier>> {
68 Vec::new()
69 }
70
71 fn graph_passes(&self) -> Vec<Box<dyn GraphPass>> {
72 Vec::new()
73 }
74}
75
76pub struct LanguageRegistry {
77 plugins: Vec<Arc<dyn LanguagePlugin>>,
78 plugins_by_extension: HashMap<String, Arc<dyn LanguagePlugin>>,
79}
80
81impl LanguageRegistry {
82 pub fn new() -> Self {
83 Self {
84 plugins: Vec::new(),
85 plugins_by_extension: HashMap::new(),
86 }
87 }
88
89 pub fn register<P>(&mut self, plugin: P) -> anyhow::Result<()>
90 where
91 P: LanguagePlugin + 'static,
92 {
93 let plugin = Arc::new(plugin) as Arc<dyn LanguagePlugin>;
94 for extension in plugin.extensions() {
95 if let Some(existing) = self.plugins_by_extension.get(*extension) {
96 bail!(
97 "language plugin '{}' conflicts with '{}' for extension '{}'",
98 plugin.id(),
99 existing.id(),
100 extension
101 );
102 }
103 }
104
105 for extension in plugin.extensions() {
106 self.plugins_by_extension
107 .insert((*extension).to_string(), Arc::clone(&plugin));
108 }
109 self.plugins.push(plugin);
110 Ok(())
111 }
112
113 pub fn supported_extensions(&self) -> Vec<String> {
114 let mut extensions: Vec<_> = self.plugins_by_extension.keys().cloned().collect();
115 extensions.sort();
116 extensions
117 }
118
119 pub fn plugin_for_extension(&self, extension: &str) -> Option<Arc<dyn LanguagePlugin>> {
120 self.plugins_by_extension.get(extension).cloned()
121 }
122
123 pub fn plugin_for_path(&self, path: &Path) -> anyhow::Result<Arc<dyn LanguagePlugin>> {
124 let extension = path
125 .extension()
126 .and_then(|ext| ext.to_str())
127 .ok_or_else(|| anyhow!("unsupported language for file: {}", path.display()))?;
128 self.plugin_for_extension(extension)
129 .ok_or_else(|| anyhow!("unsupported language for file: {}", path.display()))
130 }
131
132 pub fn plugins(&self) -> &[Arc<dyn LanguagePlugin>] {
133 &self.plugins
134 }
135
136 pub fn collect_classifiers(&self) -> Vec<Box<dyn Classifier>> {
137 self.plugins
138 .iter()
139 .flat_map(|plugin| plugin.classifiers())
140 .collect()
141 }
142
143 pub fn collect_graph_passes(&self) -> Vec<Box<dyn GraphPass>> {
144 self.plugins
145 .iter()
146 .flat_map(|plugin| plugin.graph_passes())
147 .collect()
148 }
149
150 pub fn prepare_plugins(&self, context: &ProjectContext) -> anyhow::Result<()> {
151 for plugin in &self.plugins {
152 plugin
153 .prepare_project(context)
154 .with_context(|| format!("failed to prepare plugin '{}'", plugin.id()))?;
155 }
156 Ok(())
157 }
158}
159
160impl Default for LanguageRegistry {
161 fn default() -> Self {
162 Self::new()
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extract::ExtractionResult;

    /// Minimal plugin whose id and extensions are supplied by the test.
    struct TestPlugin {
        id: &'static str,
        exts: &'static [&'static str],
    }

    impl LanguagePlugin for TestPlugin {
        fn id(&self) -> &'static str {
            self.id
        }

        fn extensions(&self) -> &'static [&'static str] {
            self.exts
        }

        fn extract(
            &self,
            _source: &[u8],
            _context: &FileContext,
        ) -> anyhow::Result<ExtractionResult> {
            Ok(ExtractionResult::new())
        }
    }

    /// Builds a plugin with the given id that claims the `rs` extension.
    fn rs_plugin(id: &'static str) -> TestPlugin {
        TestPlugin { id, exts: &["rs"] }
    }

    /// A second plugin claiming an already-registered extension is rejected.
    #[test]
    fn rejects_duplicate_extensions() {
        let mut registry = LanguageRegistry::new();
        registry.register(rs_plugin("first")).unwrap();

        let message = registry
            .register(rs_plugin("second"))
            .unwrap_err()
            .to_string();

        assert!(message.contains("conflicts"));
    }
}