Skip to main content

grapha_core/
pipeline.rs

1use std::path::{Path, PathBuf};
2
3use anyhow::Context;
4
5use crate::classify::{CompositeClassifier, classify_extraction_result, classify_graph};
6use crate::discover;
7use crate::extract::ExtractionResult;
8use crate::graph::Graph;
9use crate::merge;
10use crate::module::ModuleMap;
11use crate::normalize::normalize_graph;
12use crate::plugin::{FileContext, GraphPass, LanguageRegistry, ProjectContext};
13use crate::semantic::SemanticDocument;
14
15pub fn project_context(path: &Path) -> ProjectContext {
16    ProjectContext::new(path)
17}
18
19pub fn discover_files(path: &Path, registry: &LanguageRegistry) -> anyhow::Result<Vec<PathBuf>> {
20    discover::discover_files(path, &registry.supported_extensions())
21}
22
/// Computes the path under which `file` is reported, relative to `input_path`.
///
/// * When `input_path` is an existing directory, the result is `file` with the
///   `input_path` prefix removed. If `file` does not start with that prefix it is
///   returned unchanged.
/// * Otherwise the result is just the final component of `file`. If `file` has no
///   final component (for example `..`), it is returned unchanged.
pub fn relative_path_for_input(input_path: &Path, file: &Path) -> PathBuf {
    // Pick the borrowed view first and convert to an owned path exactly once.
    let relative: &Path = if input_path.is_dir() {
        match file.strip_prefix(input_path) {
            Ok(stripped) => stripped,
            Err(_) => file,
        }
    } else {
        match file.file_name() {
            Some(name) => Path::new(name),
            None => file,
        }
    };
    relative.to_path_buf()
}
32
33pub fn prepare_plugins(
34    registry: &LanguageRegistry,
35    context: &ProjectContext,
36) -> anyhow::Result<()> {
37    registry.prepare_plugins(context)
38}
39
40pub fn finish_plugins(registry: &LanguageRegistry, context: &ProjectContext) -> anyhow::Result<()> {
41    registry.finish_plugins(context)
42}
43
44pub fn discover_modules(
45    registry: &LanguageRegistry,
46    context: &ProjectContext,
47) -> anyhow::Result<ModuleMap> {
48    let mut modules = ModuleMap::new();
49    for plugin in registry.plugins() {
50        modules.merge(
51            plugin
52                .discover_modules(context)
53                .with_context(|| format!("failed to discover modules for '{}'", plugin.id()))?,
54        );
55    }
56    Ok(modules.with_fallback(&context.project_root))
57}
58
59pub fn file_context(context: &ProjectContext, modules: &ModuleMap, file: &Path) -> FileContext {
60    let relative_path = relative_path_for_input(&context.input_path, file);
61    let absolute_path =
62        std::fs::canonicalize(file).unwrap_or_else(|_| context.project_root.join(&relative_path));
63    let module_name = modules.module_for_file(&absolute_path).or_else(|| {
64        relative_path
65            .components()
66            .next()
67            .and_then(|component| component.as_os_str().to_str())
68            .map(|segment| segment.to_string())
69    });
70
71    FileContext {
72        input_path: context.input_path.clone(),
73        project_root: context.project_root.clone(),
74        relative_path,
75        absolute_path,
76        module_name,
77        index_store_enabled: context.index_store_enabled,
78    }
79}
80
81pub fn extract_with_registry(
82    registry: &LanguageRegistry,
83    source: &[u8],
84    context: &FileContext,
85) -> anyhow::Result<ExtractionResult> {
86    Ok(lower_semantics(extract_semantics_with_registry(
87        registry, source, context,
88    )?))
89}
90
91pub fn extract_semantics_with_registry(
92    registry: &LanguageRegistry,
93    source: &[u8],
94    context: &FileContext,
95) -> anyhow::Result<SemanticDocument> {
96    let plugin = registry.plugin_for_path(&context.relative_path)?;
97    let document = plugin.extract_semantics(source, context)?;
98    Ok(plugin.stamp_semantic_module(document, context.module_name.as_deref()))
99}
100
101pub fn stamp_module(result: ExtractionResult, module_name: Option<&str>) -> ExtractionResult {
102    let Some(module_name) = module_name else {
103        return result;
104    };
105
106    let nodes = result
107        .nodes
108        .into_iter()
109        .map(|mut node| {
110            node.module = Some(module_name.to_string());
111            node
112        })
113        .collect();
114
115    ExtractionResult {
116        nodes,
117        edges: result.edges,
118        imports: result.imports,
119    }
120}
121
122pub fn stamp_semantic_module(
123    document: SemanticDocument,
124    module_name: Option<&str>,
125) -> SemanticDocument {
126    document.stamp_module(module_name)
127}
128
129pub fn lower_semantics(document: SemanticDocument) -> ExtractionResult {
130    document.into_extraction_result()
131}
132
133pub fn build_graph(
134    results: Vec<ExtractionResult>,
135    classifier: &CompositeClassifier,
136    graph_passes: &[Box<dyn GraphPass>],
137) -> Graph {
138    let preclassified_results = results
139        .into_iter()
140        .map(|result| classify_extraction_result(result, classifier))
141        .collect();
142    let mut graph = classify_graph(&merge::merge(preclassified_results), classifier);
143    for pass in graph_passes {
144        graph = pass.apply(graph);
145    }
146    normalize_graph(graph)
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use crate::classify::{Classification, Classifier};
    use crate::extract::ExtractionResult;
    use crate::graph::{
        Edge, EdgeKind, FlowDirection, Graph, Node, NodeKind, NodeRole, Span, TerminalKind,
        Visibility,
    };
    use crate::plugin::{GraphPass, LanguagePlugin};
    use std::collections::HashMap;

    /// Builds the public `main` function node used as a fixture by the tests below.
    /// Only the id and the file differ between call sites.
    fn main_function_node(id: String, file: PathBuf) -> Node {
        Node {
            id,
            kind: NodeKind::Function,
            name: "main".to_string(),
            file,
            span: Span {
                start: [0, 0],
                end: [1, 0],
            },
            visibility: Visibility::Public,
            metadata: HashMap::new(),
            role: None,
            signature: None,
            doc_comment: None,
            module: None,
            snippet: None,
        }
    }

    /// Plugin for `.rs` files that ignores the source bytes and emits a single node
    /// whose id and file are the relative path from the file context.
    struct TestPlugin;

    impl LanguagePlugin for TestPlugin {
        fn id(&self) -> &'static str {
            "test"
        }

        fn extensions(&self) -> &'static [&'static str] {
            &["rs"]
        }

        fn extract(
            &self,
            _source: &[u8],
            context: &FileContext,
        ) -> anyhow::Result<ExtractionResult> {
            let node = main_function_node(
                context.relative_path.to_string_lossy().to_string(),
                context.relative_path.clone(),
            );
            let mut extraction = ExtractionResult::new();
            extraction.nodes.push(node);
            Ok(extraction)
        }
    }

    /// Classifier that labels every call target as a network read with operation "HTTP".
    struct NetworkClassifier;

    impl Classifier for NetworkClassifier {
        fn classify(
            &self,
            _call_target: &str,
            _context: &crate::classify::ClassifyContext,
        ) -> Option<Classification> {
            Some(Classification {
                terminal_kind: TerminalKind::Network,
                direction: FlowDirection::Read,
                operation: "HTTP".to_string(),
            })
        }
    }

    /// Graph pass that marks the first node, if there is one, as an entry point.
    struct EntryPass;

    impl GraphPass for EntryPass {
        fn apply(&self, mut graph: Graph) -> Graph {
            if let Some(first) = graph.nodes.first_mut() {
                first.role = Some(NodeRole::EntryPoint);
            }
            graph
        }
    }

    #[test]
    fn extract_with_registry_stamps_module() {
        let mut registry = LanguageRegistry::new();
        registry.register(TestPlugin).unwrap();

        // Lay out `<tmp>/src/main.rs` on disk so the file can be canonicalized.
        let dir = tempfile::tempdir().unwrap();
        let src_dir = dir.path().join("src");
        std::fs::create_dir_all(&src_dir).unwrap();
        let file = src_dir.join("main.rs");
        std::fs::write(&file, "fn main() {}").unwrap();

        let project = ProjectContext {
            input_path: dir.path().to_path_buf(),
            project_root: dir.path().to_path_buf(),
            index_store_enabled: true,
        };

        // Map the `src` directory to the module named "core".
        let mut modules = ModuleMap::new();
        modules.modules.insert("core".to_string(), vec![src_dir]);
        let modules = modules.with_fallback(dir.path());

        let ctx = file_context(&project, &modules, &file);
        let extraction = extract_with_registry(&registry, b"fn main() {}", &ctx).unwrap();

        assert_eq!(extraction.nodes[0].module.as_deref(), Some("core"));
    }

    #[test]
    fn build_graph_runs_classifier_then_graph_pass() {
        let call_edge = Edge {
            source: "src::main".to_string(),
            target: "reqwest::get".to_string(),
            kind: EdgeKind::Calls,
            confidence: 1.0,
            direction: None,
            operation: None,
            condition: None,
            async_boundary: None,
            provenance: Vec::new(),
        };
        let extraction = ExtractionResult {
            nodes: vec![main_function_node(
                "src::main".to_string(),
                PathBuf::from("main.rs"),
            )],
            edges: vec![call_edge],
            imports: Vec::new(),
        };

        let classifier = CompositeClassifier::new(vec![Box::new(NetworkClassifier)]);
        let graph = build_graph(vec![extraction], &classifier, &[Box::new(EntryPass)]);

        // The classifier filled in the edge direction and the pass tagged the node.
        assert_eq!(graph.edges[0].direction, Some(FlowDirection::Read));
        assert_eq!(graph.nodes[0].role, Some(NodeRole::EntryPoint));
    }
}
289}