Skip to main content

grapha_core/
plugin.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use anyhow::{Context, anyhow, bail};
6
7use crate::classify::Classifier;
8use crate::extract::ExtractionResult;
9use crate::graph::Graph;
10use crate::module::ModuleMap;
11use crate::semantic::SemanticDocument;
12
/// Context passed to the project-level plugin hooks
/// (`prepare_project`, `finish_project`, `discover_modules`).
#[derive(Clone, Debug)]
pub struct ProjectContext {
    /// The path exactly as supplied to [`ProjectContext::new`].
    pub input_path: PathBuf,
    /// `input_path` after `std::fs::canonicalize`, or an unmodified copy of
    /// `input_path` when canonicalization fails. May name a single file.
    pub project_root: PathBuf,
    /// Set to `true` by [`ProjectContext::new`]; its effect is decided by the
    /// code that reads it (not visible in this module).
    pub index_store_enabled: bool,
}

impl ProjectContext {
    /// Builds a context for `input_path`.
    ///
    /// The project root is the canonical form of `input_path`. If
    /// canonicalization fails (for example because the path does not exist),
    /// the path is kept as given instead of returning an error.
    pub fn new(input_path: &Path) -> Self {
        let project_root = match std::fs::canonicalize(input_path) {
            Ok(canonical) => canonical,
            Err(_) => input_path.to_path_buf(),
        };

        Self {
            input_path: input_path.to_path_buf(),
            project_root,
            index_store_enabled: true,
        }
    }

    /// Returns `true` when the project root refers to a regular file rather
    /// than a directory.
    ///
    /// Symlinks are followed. Any metadata error (missing path, permission
    /// problem) is reported as `false`.
    pub fn is_single_file(&self) -> bool {
        std::fs::metadata(&self.project_root)
            .map(|meta| meta.is_file())
            .unwrap_or(false)
    }
}
34
/// Per-file information handed to [`LanguagePlugin::extract`] and
/// [`LanguagePlugin::extract_semantics`].
///
/// All fields are filled in by the caller. The notes below describe the
/// apparent intent from the field names. Confirm against the construction site.
#[derive(Clone, Debug)]
pub struct FileContext {
    /// Presumably the path originally passed in for the run (TODO confirm).
    pub input_path: PathBuf,
    /// Presumably the same root as `ProjectContext::project_root` (TODO confirm).
    pub project_root: PathBuf,
    /// Path of the file, presumably relative to `project_root`.
    pub relative_path: PathBuf,
    /// Path of the file, presumably absolute.
    pub absolute_path: PathBuf,
    /// Module the file belongs to, or `None` when no module is known.
    pub module_name: Option<String>,
    /// Flag carried through for plugins to consult; semantics defined by readers.
    pub index_store_enabled: bool,
}
44
45pub trait GraphPass: Send + Sync {
46    fn apply(&self, graph: Graph) -> Graph;
47}
48
49pub trait LanguagePlugin: Send + Sync {
50    fn id(&self) -> &'static str;
51    fn extensions(&self) -> &'static [&'static str];
52
53    fn prepare_project(&self, _context: &ProjectContext) -> anyhow::Result<()> {
54        Ok(())
55    }
56
57    fn finish_project(&self, _context: &ProjectContext) -> anyhow::Result<()> {
58        Ok(())
59    }
60
61    fn discover_modules(&self, _context: &ProjectContext) -> anyhow::Result<ModuleMap> {
62        Ok(ModuleMap::new())
63    }
64
65    fn extract(&self, source: &[u8], context: &FileContext) -> anyhow::Result<ExtractionResult>;
66
67    fn extract_semantics(
68        &self,
69        source: &[u8],
70        context: &FileContext,
71    ) -> anyhow::Result<SemanticDocument> {
72        Ok(SemanticDocument::from_extraction_result(
73            self.extract(source, context)?,
74        ))
75    }
76
77    fn stamp_module(
78        &self,
79        result: ExtractionResult,
80        module_name: Option<&str>,
81    ) -> ExtractionResult {
82        crate::pipeline::stamp_module(result, module_name)
83    }
84
85    fn stamp_semantic_module(
86        &self,
87        document: SemanticDocument,
88        module_name: Option<&str>,
89    ) -> SemanticDocument {
90        crate::pipeline::stamp_semantic_module(document, module_name)
91    }
92
93    fn classifiers(&self) -> Vec<Box<dyn Classifier>> {
94        Vec::new()
95    }
96
97    fn graph_passes(&self) -> Vec<Box<dyn GraphPass>> {
98        Vec::new()
99    }
100}
101
102pub struct LanguageRegistry {
103    plugins: Vec<Arc<dyn LanguagePlugin>>,
104    plugins_by_extension: HashMap<String, Arc<dyn LanguagePlugin>>,
105}
106
107impl LanguageRegistry {
108    pub fn new() -> Self {
109        Self {
110            plugins: Vec::new(),
111            plugins_by_extension: HashMap::new(),
112        }
113    }
114
115    pub fn register<P>(&mut self, plugin: P) -> anyhow::Result<()>
116    where
117        P: LanguagePlugin + 'static,
118    {
119        let plugin = Arc::new(plugin) as Arc<dyn LanguagePlugin>;
120        for extension in plugin.extensions() {
121            if let Some(existing) = self.plugins_by_extension.get(*extension) {
122                bail!(
123                    "language plugin '{}' conflicts with '{}' for extension '{}'",
124                    plugin.id(),
125                    existing.id(),
126                    extension
127                );
128            }
129        }
130
131        for extension in plugin.extensions() {
132            self.plugins_by_extension
133                .insert((*extension).to_string(), Arc::clone(&plugin));
134        }
135        self.plugins.push(plugin);
136        Ok(())
137    }
138
139    pub fn supported_extensions(&self) -> Vec<String> {
140        let mut extensions: Vec<_> = self.plugins_by_extension.keys().cloned().collect();
141        extensions.sort();
142        extensions
143    }
144
145    pub fn plugin_for_extension(&self, extension: &str) -> Option<Arc<dyn LanguagePlugin>> {
146        self.plugins_by_extension.get(extension).cloned()
147    }
148
149    pub fn plugin_for_path(&self, path: &Path) -> anyhow::Result<Arc<dyn LanguagePlugin>> {
150        let extension = path
151            .extension()
152            .and_then(|ext| ext.to_str())
153            .ok_or_else(|| anyhow!("unsupported language for file: {}", path.display()))?;
154        self.plugin_for_extension(extension)
155            .ok_or_else(|| anyhow!("unsupported language for file: {}", path.display()))
156    }
157
158    pub fn plugins(&self) -> &[Arc<dyn LanguagePlugin>] {
159        &self.plugins
160    }
161
162    pub fn collect_classifiers(&self) -> Vec<Box<dyn Classifier>> {
163        self.plugins
164            .iter()
165            .flat_map(|plugin| plugin.classifiers())
166            .collect()
167    }
168
169    pub fn collect_graph_passes(&self) -> Vec<Box<dyn GraphPass>> {
170        self.plugins
171            .iter()
172            .flat_map(|plugin| plugin.graph_passes())
173            .collect()
174    }
175
176    pub fn prepare_plugins(&self, context: &ProjectContext) -> anyhow::Result<()> {
177        for plugin in &self.plugins {
178            plugin
179                .prepare_project(context)
180                .with_context(|| format!("failed to prepare plugin '{}'", plugin.id()))?;
181        }
182        Ok(())
183    }
184
185    pub fn finish_plugins(&self, context: &ProjectContext) -> anyhow::Result<()> {
186        for plugin in &self.plugins {
187            plugin
188                .finish_project(context)
189                .with_context(|| format!("failed to finish plugin '{}'", plugin.id()))?;
190        }
191        Ok(())
192    }
193}
194
195impl Default for LanguageRegistry {
196    fn default() -> Self {
197        Self::new()
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extract::ExtractionResult;
    use std::path::Path;

    /// Minimal plugin whose id and extensions are chosen per test. Extraction
    /// always succeeds with an empty result.
    struct TestPlugin {
        id: &'static str,
        exts: &'static [&'static str],
    }

    impl LanguagePlugin for TestPlugin {
        fn id(&self) -> &'static str {
            self.id
        }

        fn extensions(&self) -> &'static [&'static str] {
            self.exts
        }

        fn extract(
            &self,
            _source: &[u8],
            _context: &FileContext,
        ) -> anyhow::Result<ExtractionResult> {
            Ok(ExtractionResult::new())
        }
    }

    #[test]
    fn rejects_duplicate_extensions() {
        let mut registry = LanguageRegistry::new();
        registry
            .register(TestPlugin {
                id: "first",
                exts: &["rs"],
            })
            .unwrap();

        let error = registry
            .register(TestPlugin {
                id: "second",
                exts: &["rs"],
            })
            .unwrap_err();

        // Pin the full message: it must name both plugins and the extension.
        assert_eq!(
            error.to_string(),
            "language plugin 'second' conflicts with 'first' for extension 'rs'"
        );
    }

    #[test]
    fn rejected_registration_leaves_registry_unchanged() {
        let mut registry = LanguageRegistry::new();
        registry
            .register(TestPlugin {
                id: "first",
                exts: &["rs"],
            })
            .unwrap();

        // "toml" is free but "rs" is taken. The whole plugin must be rejected,
        // including the extension that did not conflict.
        let error = registry
            .register(TestPlugin {
                id: "second",
                exts: &["toml", "rs"],
            })
            .unwrap_err();
        assert!(error.to_string().contains("conflicts"));

        assert_eq!(registry.plugins().len(), 1);
        assert_eq!(registry.supported_extensions(), vec!["rs".to_string()]);
        assert!(registry.plugin_for_extension("toml").is_none());
    }

    #[test]
    fn resolves_plugins_by_extension_and_path() {
        let mut registry = LanguageRegistry::new();
        registry
            .register(TestPlugin {
                id: "swift",
                exts: &["swift"],
            })
            .unwrap();
        registry
            .register(TestPlugin {
                id: "rust",
                exts: &["rs"],
            })
            .unwrap();

        // Registration order is preserved; extension listing is sorted.
        let ids: Vec<_> = registry.plugins().iter().map(|plugin| plugin.id()).collect();
        assert_eq!(ids, ["swift", "rust"]);
        assert_eq!(
            registry.supported_extensions(),
            vec!["rs".to_string(), "swift".to_string()]
        );

        assert_eq!(
            registry.plugin_for_extension("rs").map(|plugin| plugin.id()),
            Some("rust")
        );
        assert!(registry.plugin_for_extension("py").is_none());

        let plugin = registry.plugin_for_path(Path::new("src/lib.rs")).unwrap();
        assert_eq!(plugin.id(), "rust");
        let plugin = registry.plugin_for_path(Path::new("App/View.swift")).unwrap();
        assert_eq!(plugin.id(), "swift");

        // Both a missing extension and an unclaimed one are errors.
        assert!(registry.plugin_for_path(Path::new("Makefile")).is_err());
        assert!(registry.plugin_for_path(Path::new("notes.txt")).is_err());
    }
}