Skip to main content

grapha_core/
pipeline.rs

1use std::path::{Path, PathBuf};
2
3use anyhow::Context;
4
5use crate::classify::{CompositeClassifier, classify_extraction_result, classify_graph};
6use crate::discover;
7use crate::extract::ExtractionResult;
8use crate::graph::Graph;
9use crate::merge;
10use crate::module::ModuleMap;
11use crate::normalize::normalize_graph;
12use crate::plugin::{FileContext, GraphPass, LanguageRegistry, ProjectContext};
13
14pub fn project_context(path: &Path) -> ProjectContext {
15    ProjectContext::new(path)
16}
17
18pub fn discover_files(path: &Path, registry: &LanguageRegistry) -> anyhow::Result<Vec<PathBuf>> {
19    discover::discover_files(path, &registry.supported_extensions())
20}
21
/// Computes the path of `file` relative to the user-supplied `input_path`.
///
/// * When `input_path` is an existing directory, the result is `file` with
///   the `input_path` prefix removed.
/// * Otherwise the result is just the final component of `file`.
///
/// In both cases, if the relative form cannot be derived (the prefix does not
/// match, or `file` has no final component), `file` itself is returned.
pub fn relative_path_for_input(input_path: &Path, file: &Path) -> PathBuf {
    // Note: `is_dir` queries the filesystem, so a missing path counts as "not a directory".
    let relative: Option<&Path> = if input_path.is_dir() {
        file.strip_prefix(input_path).ok()
    } else {
        file.file_name().map(Path::new)
    };

    relative.unwrap_or(file).to_path_buf()
}
31
32pub fn prepare_plugins(
33    registry: &LanguageRegistry,
34    context: &ProjectContext,
35) -> anyhow::Result<()> {
36    registry.prepare_plugins(context)
37}
38
39pub fn finish_plugins(registry: &LanguageRegistry, context: &ProjectContext) -> anyhow::Result<()> {
40    registry.finish_plugins(context)
41}
42
43pub fn discover_modules(
44    registry: &LanguageRegistry,
45    context: &ProjectContext,
46) -> anyhow::Result<ModuleMap> {
47    let mut modules = ModuleMap::new();
48    for plugin in registry.plugins() {
49        modules.merge(
50            plugin
51                .discover_modules(context)
52                .with_context(|| format!("failed to discover modules for '{}'", plugin.id()))?,
53        );
54    }
55    Ok(modules.with_fallback(&context.project_root))
56}
57
58pub fn file_context(context: &ProjectContext, modules: &ModuleMap, file: &Path) -> FileContext {
59    let relative_path = relative_path_for_input(&context.input_path, file);
60    let absolute_path =
61        std::fs::canonicalize(file).unwrap_or_else(|_| context.project_root.join(&relative_path));
62    let module_name = modules.module_for_file(&absolute_path).or_else(|| {
63        relative_path
64            .components()
65            .next()
66            .and_then(|component| component.as_os_str().to_str())
67            .map(|segment| segment.to_string())
68    });
69
70    FileContext {
71        input_path: context.input_path.clone(),
72        project_root: context.project_root.clone(),
73        relative_path,
74        absolute_path,
75        module_name,
76        index_store_enabled: context.index_store_enabled,
77    }
78}
79
80pub fn extract_with_registry(
81    registry: &LanguageRegistry,
82    source: &[u8],
83    context: &FileContext,
84) -> anyhow::Result<ExtractionResult> {
85    let plugin = registry.plugin_for_path(&context.relative_path)?;
86    let result = plugin.extract(source, context)?;
87    Ok(plugin.stamp_module(result, context.module_name.as_deref()))
88}
89
90pub fn stamp_module(result: ExtractionResult, module_name: Option<&str>) -> ExtractionResult {
91    let Some(module_name) = module_name else {
92        return result;
93    };
94
95    let nodes = result
96        .nodes
97        .into_iter()
98        .map(|mut node| {
99            node.module = Some(module_name.to_string());
100            node
101        })
102        .collect();
103
104    ExtractionResult {
105        nodes,
106        edges: result.edges,
107        imports: result.imports,
108    }
109}
110
111pub fn build_graph(
112    results: Vec<ExtractionResult>,
113    classifier: &CompositeClassifier,
114    graph_passes: &[Box<dyn GraphPass>],
115) -> Graph {
116    let preclassified_results = results
117        .into_iter()
118        .map(|result| classify_extraction_result(result, classifier))
119        .collect();
120    let mut graph = classify_graph(&merge::merge(preclassified_results), classifier);
121    for pass in graph_passes {
122        graph = pass.apply(graph);
123    }
124    normalize_graph(graph)
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::classify::{Classification, Classifier};
    use crate::extract::ExtractionResult;
    use crate::graph::{
        Edge, EdgeKind, FlowDirection, Graph, Node, NodeKind, NodeRole, Span, TerminalKind,
        Visibility,
    };
    use crate::plugin::{GraphPass, LanguagePlugin};
    use std::collections::HashMap;

    /// Builds the public `main` function node shared by the fixtures below;
    /// only the id and file differ between callers.
    fn main_function_node(id: String, file: PathBuf) -> Node {
        Node {
            id,
            kind: NodeKind::Function,
            name: "main".to_string(),
            file,
            span: Span {
                start: [0, 0],
                end: [1, 0],
            },
            visibility: Visibility::Public,
            metadata: HashMap::new(),
            role: None,
            signature: None,
            doc_comment: None,
            module: None,
            snippet: None,
        }
    }

    /// Plugin for `.rs` files that ignores the source and emits a single
    /// node identified by the file's relative path.
    struct TestPlugin;

    impl LanguagePlugin for TestPlugin {
        fn id(&self) -> &'static str {
            "test"
        }

        fn extensions(&self) -> &'static [&'static str] {
            &["rs"]
        }

        fn extract(
            &self,
            _source: &[u8],
            context: &FileContext,
        ) -> anyhow::Result<ExtractionResult> {
            let mut extraction = ExtractionResult::new();
            extraction.nodes.push(main_function_node(
                context.relative_path.to_string_lossy().to_string(),
                context.relative_path.clone(),
            ));
            Ok(extraction)
        }
    }

    /// Classifier that labels every call target as a network read.
    struct NetworkClassifier;

    impl Classifier for NetworkClassifier {
        fn classify(
            &self,
            _call_target: &str,
            _context: &crate::classify::ClassifyContext,
        ) -> Option<Classification> {
            Some(Classification {
                terminal_kind: TerminalKind::Network,
                direction: FlowDirection::Read,
                operation: "HTTP".to_string(),
            })
        }
    }

    /// Graph pass that marks the first node (if any) as an entry point.
    struct EntryPass;

    impl GraphPass for EntryPass {
        fn apply(&self, mut graph: Graph) -> Graph {
            if let Some(first) = graph.nodes.first_mut() {
                first.role = Some(NodeRole::EntryPoint);
            }
            graph
        }
    }

    #[test]
    fn extract_with_registry_stamps_module() {
        let mut registry = LanguageRegistry::new();
        registry.register(TestPlugin).unwrap();

        // Lay out `<tmp>/src/main.rs` on disk so canonicalization succeeds.
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path();
        let src_dir = root.join("src");
        std::fs::create_dir_all(&src_dir).unwrap();
        let file = src_dir.join("main.rs");
        std::fs::write(&file, "fn main() {}").unwrap();

        let project = ProjectContext {
            input_path: root.to_path_buf(),
            project_root: root.to_path_buf(),
            index_store_enabled: true,
        };

        // Map the `src` directory to a module named "core".
        let mut modules = ModuleMap::new();
        modules.modules.insert("core".to_string(), vec![src_dir]);
        let modules = modules.with_fallback(root);

        let ctx = file_context(&project, &modules, &file);
        let extraction = extract_with_registry(&registry, b"fn main() {}", &ctx).unwrap();

        assert_eq!(extraction.nodes[0].module.as_deref(), Some("core"));
    }

    #[test]
    fn build_graph_runs_classifier_then_graph_pass() {
        let call_edge = Edge {
            source: "src::main".to_string(),
            target: "reqwest::get".to_string(),
            kind: EdgeKind::Calls,
            confidence: 1.0,
            direction: None,
            operation: None,
            condition: None,
            async_boundary: None,
            provenance: Vec::new(),
        };
        let extraction = ExtractionResult {
            nodes: vec![main_function_node(
                "src::main".to_string(),
                PathBuf::from("main.rs"),
            )],
            edges: vec![call_edge],
            imports: Vec::new(),
        };

        let classifier = CompositeClassifier::new(vec![Box::new(NetworkClassifier)]);
        let passes: Vec<Box<dyn GraphPass>> = vec![Box::new(EntryPass)];
        let graph = build_graph(vec![extraction], &classifier, &passes);

        assert_eq!(graph.edges[0].direction, Some(FlowDirection::Read));
        assert_eq!(graph.nodes[0].role, Some(NodeRole::EntryPoint));
    }
}