Skip to main content

grapha_core/
pipeline.rs

1use std::path::{Path, PathBuf};
2
3use anyhow::Context;
4
5use crate::classify::{CompositeClassifier, classify_extraction_result, classify_graph};
6use crate::discover;
7use crate::extract::ExtractionResult;
8use crate::graph::Graph;
9use crate::merge;
10use crate::module::ModuleMap;
11use crate::normalize::normalize_graph;
12use crate::plugin::{FileContext, GraphPass, LanguageRegistry, ProjectContext};
13
14pub fn project_context(path: &Path) -> ProjectContext {
15    ProjectContext::new(path)
16}
17
18pub fn discover_files(path: &Path, registry: &LanguageRegistry) -> anyhow::Result<Vec<PathBuf>> {
19    discover::discover_files(path, &registry.supported_extensions())
20}
21
/// Computes the path of `file` relative to the user-supplied `input_path`.
///
/// * When `input_path` is an existing directory, the result is `file` with the
///   `input_path` prefix removed; if `file` does not start with that prefix it
///   is returned unchanged.
/// * Otherwise the result is just the final component of `file`; if `file` has
///   no final component (for example `..`), it is returned unchanged.
pub fn relative_path_for_input(input_path: &Path, file: &Path) -> PathBuf {
    // Non-directory input: only the file name is meaningful.
    if !input_path.is_dir() {
        return match file.file_name() {
            Some(name) => PathBuf::from(name),
            None => file.to_path_buf(),
        };
    }

    // Directory input: express the file relative to that directory when possible.
    match file.strip_prefix(input_path) {
        Ok(stripped) => stripped.to_path_buf(),
        Err(_) => file.to_path_buf(),
    }
}
31
32pub fn prepare_plugins(
33    registry: &LanguageRegistry,
34    context: &ProjectContext,
35) -> anyhow::Result<()> {
36    registry.prepare_plugins(context)
37}
38
39pub fn discover_modules(
40    registry: &LanguageRegistry,
41    context: &ProjectContext,
42) -> anyhow::Result<ModuleMap> {
43    let mut modules = ModuleMap::new();
44    for plugin in registry.plugins() {
45        modules.merge(
46            plugin
47                .discover_modules(context)
48                .with_context(|| format!("failed to discover modules for '{}'", plugin.id()))?,
49        );
50    }
51    Ok(modules.with_fallback(&context.project_root))
52}
53
54pub fn file_context(context: &ProjectContext, modules: &ModuleMap, file: &Path) -> FileContext {
55    let relative_path = relative_path_for_input(&context.input_path, file);
56    let absolute_path =
57        std::fs::canonicalize(file).unwrap_or_else(|_| context.project_root.join(&relative_path));
58    let module_name = modules.module_for_file(&absolute_path).or_else(|| {
59        relative_path
60            .components()
61            .next()
62            .and_then(|component| component.as_os_str().to_str())
63            .map(|segment| segment.to_string())
64    });
65
66    FileContext {
67        input_path: context.input_path.clone(),
68        project_root: context.project_root.clone(),
69        relative_path,
70        absolute_path,
71        module_name,
72    }
73}
74
75pub fn extract_with_registry(
76    registry: &LanguageRegistry,
77    source: &[u8],
78    context: &FileContext,
79) -> anyhow::Result<ExtractionResult> {
80    let plugin = registry.plugin_for_path(&context.relative_path)?;
81    let result = plugin.extract(source, context)?;
82    Ok(plugin.stamp_module(result, context.module_name.as_deref()))
83}
84
85pub fn stamp_module(result: ExtractionResult, module_name: Option<&str>) -> ExtractionResult {
86    let Some(module_name) = module_name else {
87        return result;
88    };
89
90    let nodes = result
91        .nodes
92        .into_iter()
93        .map(|mut node| {
94            node.module = Some(module_name.to_string());
95            node
96        })
97        .collect();
98
99    ExtractionResult {
100        nodes,
101        edges: result.edges,
102        imports: result.imports,
103    }
104}
105
106pub fn build_graph(
107    results: Vec<ExtractionResult>,
108    classifier: &CompositeClassifier,
109    graph_passes: &[Box<dyn GraphPass>],
110) -> Graph {
111    let preclassified_results = results
112        .into_iter()
113        .map(|result| classify_extraction_result(result, classifier))
114        .collect();
115    let mut graph = classify_graph(&merge::merge(preclassified_results), classifier);
116    for pass in graph_passes {
117        graph = pass.apply(graph);
118    }
119    normalize_graph(graph)
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::classify::{Classification, Classifier};
    use crate::extract::ExtractionResult;
    use crate::graph::{
        Edge, EdgeKind, FlowDirection, Graph, Node, NodeKind, NodeRole, Span, TerminalKind,
        Visibility,
    };
    use crate::plugin::{GraphPass, LanguagePlugin};
    use std::collections::HashMap;

    /// Builds the public `main` function node used by the fixtures in this module.
    fn main_function_node(id: String, file: PathBuf) -> Node {
        Node {
            id,
            kind: NodeKind::Function,
            name: "main".to_string(),
            file,
            span: Span {
                start: [0, 0],
                end: [1, 0],
            },
            visibility: Visibility::Public,
            metadata: HashMap::new(),
            role: None,
            signature: None,
            doc_comment: None,
            module: None,
            snippet: None,
        }
    }

    /// Plugin stub that declares the `rs` extension and emits a single node
    /// whose id and file are the relative path of the extracted file.
    struct TestPlugin;

    impl LanguagePlugin for TestPlugin {
        fn id(&self) -> &'static str {
            "test"
        }

        fn extensions(&self) -> &'static [&'static str] {
            &["rs"]
        }

        fn extract(
            &self,
            _source: &[u8],
            context: &FileContext,
        ) -> anyhow::Result<ExtractionResult> {
            let node_id = context.relative_path.to_string_lossy().to_string();
            let mut extraction = ExtractionResult::new();
            extraction
                .nodes
                .push(main_function_node(node_id, context.relative_path.clone()));
            Ok(extraction)
        }
    }

    /// Classifier stub that labels every call target as a network read.
    struct NetworkClassifier;

    impl Classifier for NetworkClassifier {
        fn classify(
            &self,
            _call_target: &str,
            _context: &crate::classify::ClassifyContext,
        ) -> Option<Classification> {
            let classification = Classification {
                terminal_kind: TerminalKind::Network,
                direction: FlowDirection::Read,
                operation: "HTTP".to_string(),
            };
            Some(classification)
        }
    }

    /// Graph pass stub that marks the first node, if any, as an entry point.
    struct EntryPass;

    impl GraphPass for EntryPass {
        fn apply(&self, mut graph: Graph) -> Graph {
            if let Some(first) = graph.nodes.first_mut() {
                first.role = Some(NodeRole::EntryPoint);
            }
            graph
        }
    }

    #[test]
    fn extract_with_registry_stamps_module() {
        let mut registry = LanguageRegistry::new();
        registry.register(TestPlugin).unwrap();

        // Lay out `<tmp>/src/main.rs` on disk so the file can be canonicalized.
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path().to_path_buf();
        let src_dir = root.join("src");
        std::fs::create_dir_all(&src_dir).unwrap();
        let file = src_dir.join("main.rs");
        std::fs::write(&file, "fn main() {}").unwrap();

        let project = ProjectContext {
            input_path: root.clone(),
            project_root: root,
        };
        let mut modules = ModuleMap::new();
        modules.modules.insert("core".to_string(), vec![src_dir]);
        let ctx = file_context(&project, &modules, &file);

        let result = extract_with_registry(&registry, b"fn main() {}", &ctx).unwrap();
        assert_eq!(result.nodes[0].module.as_deref(), Some("core"));
    }

    #[test]
    fn build_graph_runs_classifier_then_graph_pass() {
        let entry_id = "src::main";
        let call_edge = Edge {
            source: entry_id.to_string(),
            target: "reqwest::get".to_string(),
            kind: EdgeKind::Calls,
            confidence: 1.0,
            direction: None,
            operation: None,
            condition: None,
            async_boundary: None,
            provenance: Vec::new(),
        };
        let result = ExtractionResult {
            nodes: vec![main_function_node(
                entry_id.to_string(),
                PathBuf::from("main.rs"),
            )],
            edges: vec![call_edge],
            imports: Vec::new(),
        };

        let classifier = CompositeClassifier::new(vec![Box::new(NetworkClassifier)]);
        let passes: Vec<Box<dyn GraphPass>> = vec![Box::new(EntryPass)];
        let graph = build_graph(vec![result], &classifier, &passes);

        assert_eq!(graph.edges[0].direction, Some(FlowDirection::Read));
        assert_eq!(graph.nodes[0].role, Some(NodeRole::EntryPoint));
    }
}
260}