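//! `grapha_core/plugin.rs`: the `LanguagePlugin` and `GraphPass` extension points, plus the
//! `LanguageRegistry` that routes files to plugins by their file extension.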

use std::collections::HashMap;
use std::fmt;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use anyhow::{Context, anyhow, bail};

use crate::classify::Classifier;
use crate::extract::ExtractionResult;
use crate::graph::Graph;
use crate::module::ModuleMap;

/// Project-wide paths handed to every plugin hook that operates on the whole project.
#[derive(Debug, Clone)]
pub struct ProjectContext {
    /// The path exactly as the caller supplied it.
    pub input_path: PathBuf,
    /// `input_path` resolved through `std::fs::canonicalize`, or an unmodified copy of
    /// `input_path` when canonicalization fails (for example because the path does not exist).
    pub project_root: PathBuf,
}

impl ProjectContext {
    /// Builds a context for `input_path`, resolving it to an absolute, canonical root when the
    /// filesystem allows it.
    pub fn new(input_path: &Path) -> Self {
        let as_given = input_path.to_path_buf();
        let resolved = match std::fs::canonicalize(input_path) {
            Ok(canonical) => canonical,
            // Canonicalization is best-effort: fall back to the path as given.
            Err(_) => as_given.clone(),
        };
        Self {
            input_path: as_given,
            project_root: resolved,
        }
    }

    /// Returns `true` only if `project_root` currently exists and is a regular file
    /// (symlinks are followed); directories and missing paths yield `false`.
    pub fn is_single_file(&self) -> bool {
        std::fs::metadata(&self.project_root)
            .map(|meta| meta.is_file())
            .unwrap_or(false)
    }
}

/// Per-file paths and module information passed to [`LanguagePlugin::extract`].
#[derive(Clone, Debug)]
pub struct FileContext {
    /// The path the project was opened with (NOTE(review): presumably the same value as
    /// `ProjectContext::input_path` — confirm at the construction site).
    pub input_path: PathBuf,
    /// Root of the project this file belongs to.
    pub project_root: PathBuf,
    /// The file's path, presumably relative to `project_root` — verify against the caller.
    pub relative_path: PathBuf,
    /// The file's absolute path on disk (as constructed by the caller).
    pub absolute_path: PathBuf,
    /// Name of the module the file was assigned to, if any.
    pub module_name: Option<String>,
}

41pub trait GraphPass: Send + Sync {
42    fn apply(&self, graph: Graph) -> Graph;
43}
44
45pub trait LanguagePlugin: Send + Sync {
46    fn id(&self) -> &'static str;
47    fn extensions(&self) -> &'static [&'static str];
48
49    fn prepare_project(&self, _context: &ProjectContext) -> anyhow::Result<()> {
50        Ok(())
51    }
52
53    fn discover_modules(&self, _context: &ProjectContext) -> anyhow::Result<ModuleMap> {
54        Ok(ModuleMap::new())
55    }
56
57    fn extract(&self, source: &[u8], context: &FileContext) -> anyhow::Result<ExtractionResult>;
58
59    fn stamp_module(
60        &self,
61        result: ExtractionResult,
62        module_name: Option<&str>,
63    ) -> ExtractionResult {
64        crate::pipeline::stamp_module(result, module_name)
65    }
66
67    fn classifiers(&self) -> Vec<Box<dyn Classifier>> {
68        Vec::new()
69    }
70
71    fn graph_passes(&self) -> Vec<Box<dyn GraphPass>> {
72        Vec::new()
73    }
74}
75
76pub struct LanguageRegistry {
77    plugins: Vec<Arc<dyn LanguagePlugin>>,
78    plugins_by_extension: HashMap<String, Arc<dyn LanguagePlugin>>,
79}
80
81impl LanguageRegistry {
82    pub fn new() -> Self {
83        Self {
84            plugins: Vec::new(),
85            plugins_by_extension: HashMap::new(),
86        }
87    }
88
89    pub fn register<P>(&mut self, plugin: P) -> anyhow::Result<()>
90    where
91        P: LanguagePlugin + 'static,
92    {
93        let plugin = Arc::new(plugin) as Arc<dyn LanguagePlugin>;
94        for extension in plugin.extensions() {
95            if let Some(existing) = self.plugins_by_extension.get(*extension) {
96                bail!(
97                    "language plugin '{}' conflicts with '{}' for extension '{}'",
98                    plugin.id(),
99                    existing.id(),
100                    extension
101                );
102            }
103        }
104
105        for extension in plugin.extensions() {
106            self.plugins_by_extension
107                .insert((*extension).to_string(), Arc::clone(&plugin));
108        }
109        self.plugins.push(plugin);
110        Ok(())
111    }
112
113    pub fn supported_extensions(&self) -> Vec<String> {
114        let mut extensions: Vec<_> = self.plugins_by_extension.keys().cloned().collect();
115        extensions.sort();
116        extensions
117    }
118
119    pub fn plugin_for_extension(&self, extension: &str) -> Option<Arc<dyn LanguagePlugin>> {
120        self.plugins_by_extension.get(extension).cloned()
121    }
122
123    pub fn plugin_for_path(&self, path: &Path) -> anyhow::Result<Arc<dyn LanguagePlugin>> {
124        let extension = path
125            .extension()
126            .and_then(|ext| ext.to_str())
127            .ok_or_else(|| anyhow!("unsupported language for file: {}", path.display()))?;
128        self.plugin_for_extension(extension)
129            .ok_or_else(|| anyhow!("unsupported language for file: {}", path.display()))
130    }
131
132    pub fn plugins(&self) -> &[Arc<dyn LanguagePlugin>] {
133        &self.plugins
134    }
135
136    pub fn collect_classifiers(&self) -> Vec<Box<dyn Classifier>> {
137        self.plugins
138            .iter()
139            .flat_map(|plugin| plugin.classifiers())
140            .collect()
141    }
142
143    pub fn collect_graph_passes(&self) -> Vec<Box<dyn GraphPass>> {
144        self.plugins
145            .iter()
146            .flat_map(|plugin| plugin.graph_passes())
147            .collect()
148    }
149
150    pub fn prepare_plugins(&self, context: &ProjectContext) -> anyhow::Result<()> {
151        for plugin in &self.plugins {
152            plugin
153                .prepare_project(context)
154                .with_context(|| format!("failed to prepare plugin '{}'", plugin.id()))?;
155        }
156        Ok(())
157    }
158}
159
160impl Default for LanguageRegistry {
161    fn default() -> Self {
162        Self::new()
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extract::ExtractionResult;

    /// Minimal plugin with a fixed id and extension list; extraction always yields an empty
    /// result.
    struct TestPlugin {
        id: &'static str,
        exts: &'static [&'static str],
    }

    impl LanguagePlugin for TestPlugin {
        fn id(&self) -> &'static str {
            self.id
        }

        fn extensions(&self) -> &'static [&'static str] {
            self.exts
        }

        fn extract(
            &self,
            _source: &[u8],
            _context: &FileContext,
        ) -> anyhow::Result<ExtractionResult> {
            Ok(ExtractionResult::new())
        }
    }

    #[test]
    fn rejects_duplicate_extensions() {
        let mut registry = LanguageRegistry::new();
        registry
            .register(TestPlugin {
                id: "first",
                exts: &["rs"],
            })
            .unwrap();

        let error = registry
            .register(TestPlugin {
                id: "second",
                exts: &["rs"],
            })
            .unwrap_err();

        assert!(error.to_string().contains("conflicts"));
    }

    #[test]
    fn rejected_registration_leaves_registry_unchanged() {
        let mut registry = LanguageRegistry::new();
        registry
            .register(TestPlugin {
                id: "first",
                exts: &["rs"],
            })
            .unwrap();

        // "py" is free but "rs" conflicts: nothing from the second plugin may be registered.
        let error = registry
            .register(TestPlugin {
                id: "second",
                exts: &["py", "rs"],
            })
            .unwrap_err();

        assert!(error.to_string().contains("'rs'"));
        assert!(registry.plugin_for_extension("py").is_none());
        assert_eq!(registry.plugins().len(), 1);
        assert_eq!(
            registry.plugin_for_extension("rs").map(|plugin| plugin.id()),
            Some("first")
        );
    }

    #[test]
    fn resolves_plugins_by_path_extension() {
        let mut registry = LanguageRegistry::new();
        registry
            .register(TestPlugin {
                id: "rust",
                exts: &["rs"],
            })
            .unwrap();

        let plugin = registry.plugin_for_path(Path::new("src/lib.rs")).unwrap();
        assert_eq!(plugin.id(), "rust");
        assert!(registry.plugin_for_path(Path::new("Makefile")).is_err());
        assert!(registry.plugin_for_path(Path::new("script.py")).is_err());
    }

    #[test]
    fn supported_extensions_are_sorted() {
        let mut registry = LanguageRegistry::new();
        registry
            .register(TestPlugin {
                id: "swift",
                exts: &["swift"],
            })
            .unwrap();
        registry
            .register(TestPlugin {
                id: "c",
                exts: &["h", "c"],
            })
            .unwrap();

        assert_eq!(registry.supported_extensions(), vec!["c", "h", "swift"]);
    }
}