Skip to main content

grapha_core/
plugin.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use anyhow::{Context, anyhow, bail};
6
7use crate::classify::Classifier;
8use crate::extract::ExtractionResult;
9use crate::graph::Graph;
10use crate::module::ModuleMap;
11
/// Project-wide settings passed to the project-level plugin hooks
/// (`prepare_project`, `finish_project`, `discover_modules`).
#[derive(Debug, Clone)]
pub struct ProjectContext {
    /// The path exactly as supplied by the caller.
    pub input_path: PathBuf,
    /// Canonicalized `input_path`, or `input_path` verbatim when canonicalization failed.
    /// May point at a single file rather than a directory (see `is_single_file`).
    pub project_root: PathBuf,
    /// Index-store toggle; `ProjectContext::new` sets it to `true`. How it is honoured
    /// is up to each plugin.
    pub index_store_enabled: bool,
}
18
19impl ProjectContext {
20    pub fn new(input_path: &Path) -> Self {
21        Self {
22            input_path: input_path.to_path_buf(),
23            project_root: std::fs::canonicalize(input_path)
24                .unwrap_or_else(|_| input_path.to_path_buf()),
25            index_store_enabled: true,
26        }
27    }
28
29    pub fn is_single_file(&self) -> bool {
30        self.project_root.is_file()
31    }
32}
33
/// Per-file information handed to [`LanguagePlugin::extract`].
#[derive(Debug, Clone)]
pub struct FileContext {
    /// Presumably mirrors `ProjectContext::input_path` — confirm at the construction site.
    pub input_path: PathBuf,
    /// Presumably mirrors `ProjectContext::project_root` — confirm at the construction site.
    pub project_root: PathBuf,
    /// Presumably the file's path relative to `project_root` — TODO confirm.
    pub relative_path: PathBuf,
    /// Presumably the file's absolute path on disk — TODO confirm.
    pub absolute_path: PathBuf,
    /// Module the file belongs to, if one was determined.
    pub module_name: Option<String>,
    /// Index-store toggle; presumably copied from `ProjectContext::index_store_enabled`.
    pub index_store_enabled: bool,
}
43
44pub trait GraphPass: Send + Sync {
45    fn apply(&self, graph: Graph) -> Graph;
46}
47
48pub trait LanguagePlugin: Send + Sync {
49    fn id(&self) -> &'static str;
50    fn extensions(&self) -> &'static [&'static str];
51
52    fn prepare_project(&self, _context: &ProjectContext) -> anyhow::Result<()> {
53        Ok(())
54    }
55
56    fn finish_project(&self, _context: &ProjectContext) -> anyhow::Result<()> {
57        Ok(())
58    }
59
60    fn discover_modules(&self, _context: &ProjectContext) -> anyhow::Result<ModuleMap> {
61        Ok(ModuleMap::new())
62    }
63
64    fn extract(&self, source: &[u8], context: &FileContext) -> anyhow::Result<ExtractionResult>;
65
66    fn stamp_module(
67        &self,
68        result: ExtractionResult,
69        module_name: Option<&str>,
70    ) -> ExtractionResult {
71        crate::pipeline::stamp_module(result, module_name)
72    }
73
74    fn classifiers(&self) -> Vec<Box<dyn Classifier>> {
75        Vec::new()
76    }
77
78    fn graph_passes(&self) -> Vec<Box<dyn GraphPass>> {
79        Vec::new()
80    }
81}
82
83pub struct LanguageRegistry {
84    plugins: Vec<Arc<dyn LanguagePlugin>>,
85    plugins_by_extension: HashMap<String, Arc<dyn LanguagePlugin>>,
86}
87
88impl LanguageRegistry {
89    pub fn new() -> Self {
90        Self {
91            plugins: Vec::new(),
92            plugins_by_extension: HashMap::new(),
93        }
94    }
95
96    pub fn register<P>(&mut self, plugin: P) -> anyhow::Result<()>
97    where
98        P: LanguagePlugin + 'static,
99    {
100        let plugin = Arc::new(plugin) as Arc<dyn LanguagePlugin>;
101        for extension in plugin.extensions() {
102            if let Some(existing) = self.plugins_by_extension.get(*extension) {
103                bail!(
104                    "language plugin '{}' conflicts with '{}' for extension '{}'",
105                    plugin.id(),
106                    existing.id(),
107                    extension
108                );
109            }
110        }
111
112        for extension in plugin.extensions() {
113            self.plugins_by_extension
114                .insert((*extension).to_string(), Arc::clone(&plugin));
115        }
116        self.plugins.push(plugin);
117        Ok(())
118    }
119
120    pub fn supported_extensions(&self) -> Vec<String> {
121        let mut extensions: Vec<_> = self.plugins_by_extension.keys().cloned().collect();
122        extensions.sort();
123        extensions
124    }
125
126    pub fn plugin_for_extension(&self, extension: &str) -> Option<Arc<dyn LanguagePlugin>> {
127        self.plugins_by_extension.get(extension).cloned()
128    }
129
130    pub fn plugin_for_path(&self, path: &Path) -> anyhow::Result<Arc<dyn LanguagePlugin>> {
131        let extension = path
132            .extension()
133            .and_then(|ext| ext.to_str())
134            .ok_or_else(|| anyhow!("unsupported language for file: {}", path.display()))?;
135        self.plugin_for_extension(extension)
136            .ok_or_else(|| anyhow!("unsupported language for file: {}", path.display()))
137    }
138
139    pub fn plugins(&self) -> &[Arc<dyn LanguagePlugin>] {
140        &self.plugins
141    }
142
143    pub fn collect_classifiers(&self) -> Vec<Box<dyn Classifier>> {
144        self.plugins
145            .iter()
146            .flat_map(|plugin| plugin.classifiers())
147            .collect()
148    }
149
150    pub fn collect_graph_passes(&self) -> Vec<Box<dyn GraphPass>> {
151        self.plugins
152            .iter()
153            .flat_map(|plugin| plugin.graph_passes())
154            .collect()
155    }
156
157    pub fn prepare_plugins(&self, context: &ProjectContext) -> anyhow::Result<()> {
158        for plugin in &self.plugins {
159            plugin
160                .prepare_project(context)
161                .with_context(|| format!("failed to prepare plugin '{}'", plugin.id()))?;
162        }
163        Ok(())
164    }
165
166    pub fn finish_plugins(&self, context: &ProjectContext) -> anyhow::Result<()> {
167        for plugin in &self.plugins {
168            plugin
169                .finish_project(context)
170                .with_context(|| format!("failed to finish plugin '{}'", plugin.id()))?;
171        }
172        Ok(())
173    }
174}
175
176impl Default for LanguageRegistry {
177    fn default() -> Self {
178        Self::new()
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extract::ExtractionResult;

    /// Minimal plugin with a fixed id and extension list; extraction returns an empty result.
    struct TestPlugin {
        id: &'static str,
        exts: &'static [&'static str],
    }

    impl LanguagePlugin for TestPlugin {
        fn id(&self) -> &'static str {
            self.id
        }

        fn extensions(&self) -> &'static [&'static str] {
            self.exts
        }

        fn extract(
            &self,
            _source: &[u8],
            _context: &FileContext,
        ) -> anyhow::Result<ExtractionResult> {
            Ok(ExtractionResult::new())
        }
    }

    #[test]
    fn rejects_duplicate_extensions() {
        let mut registry = LanguageRegistry::new();
        registry
            .register(TestPlugin {
                id: "first",
                exts: &["rs"],
            })
            .unwrap();

        let error = registry
            .register(TestPlugin {
                id: "second",
                exts: &["rs"],
            })
            .unwrap_err();

        assert!(error.to_string().contains("conflicts"));
    }

    #[test]
    fn rejected_registration_leaves_registry_unchanged() {
        let mut registry = LanguageRegistry::new();
        registry
            .register(TestPlugin {
                id: "first",
                exts: &["py"],
            })
            .unwrap();

        // "rs" is free but "py" is taken, so nothing from "second" may be registered.
        let outcome = registry.register(TestPlugin {
            id: "second",
            exts: &["rs", "py"],
        });

        assert!(outcome.is_err());
        assert!(registry.plugin_for_extension("rs").is_none());
        assert_eq!(registry.plugins().len(), 1);
        assert_eq!(registry.supported_extensions(), vec!["py"]);
    }

    #[test]
    fn resolves_plugins_by_path_extension() {
        let mut registry = LanguageRegistry::new();
        registry
            .register(TestPlugin {
                id: "swift",
                exts: &["swift"],
            })
            .unwrap();
        registry
            .register(TestPlugin {
                id: "rust",
                exts: &["rs"],
            })
            .unwrap();

        let plugin = registry
            .plugin_for_path(std::path::Path::new("src/main.rs"))
            .expect("a .rs file should resolve to the rust plugin");
        assert_eq!(plugin.id(), "rust");
        assert_eq!(registry.supported_extensions(), vec!["rs", "swift"]);
    }

    #[test]
    fn reports_unsupported_paths() {
        let mut registry = LanguageRegistry::new();
        registry
            .register(TestPlugin {
                id: "rust",
                exts: &["rs"],
            })
            .unwrap();

        for path in ["README", "notes.txt"] {
            let error = registry
                .plugin_for_path(std::path::Path::new(path))
                .err()
                .expect("path without a registered extension must be rejected");
            assert!(error.to_string().contains("unsupported language"));
        }
    }
}