Skip to main content

codemem_engine/persistence/
cross_repo.rs

//! Cross-repo persistence: register packages, store unresolved refs,
//! run forward/backward linking, persist cross-namespace edges, detect
//! API endpoints, HTTP client calls and event channels, and match them
//! into cross-service HTTP and event edges.
4
5use super::CrossRepoPersistResult;
6use crate::index::api_surface;
7use crate::index::linker::{self, CrossRepoEdge, PendingRef, RegisteredPackage};
8use crate::index::manifest::ManifestResult;
9use crate::index::resolver::UnresolvedRef;
10use crate::index::symbol::{Reference, Symbol};
11use codemem_core::{CodememError, Edge, RelationshipType};
12use std::collections::HashMap;
13
14impl super::super::CodememEngine {
15    /// Persist cross-repo linking data after `persist_index_results`.
16    ///
17    /// This method runs 3 phases:
18    /// 1. Register packages + store unresolved refs
19    /// 2. Forward/backward cross-repo linking
20    /// 3. API endpoint + client call detection
21    pub fn persist_cross_repo_data(
22        &self,
23        manifests: &ManifestResult,
24        unresolved: &[UnresolvedRef],
25        symbols: &[Symbol],
26        references: &[Reference],
27        namespace: &str,
28    ) -> Result<CrossRepoPersistResult, CodememError> {
29        let mut result = CrossRepoPersistResult::default();
30
31        // 1. Register packages from manifests into the package registry.
32        let packages = linker::extract_packages(manifests, namespace);
33        for pkg in &packages {
34            if let Err(e) = self.storage.upsert_package_registry(
35                &pkg.package_name,
36                &pkg.namespace,
37                &pkg.version,
38                &pkg.manifest,
39            ) {
40                tracing::warn!("Failed to register package {}: {e}", pkg.package_name);
41            } else {
42                result.packages_registered += 1;
43            }
44        }
45
46        // 2. Store unresolved refs for future backward linking by other namespaces.
47        {
48            let batch: Vec<codemem_core::UnresolvedRefData> = unresolved
49                .iter()
50                .map(|uref| codemem_core::UnresolvedRefData {
51                    source_qualified_name: uref.source_node.clone(),
52                    target_name: uref.target_name.clone(),
53                    namespace: namespace.to_string(),
54                    file_path: uref.file_path.clone(),
55                    line: uref.line,
56                    ref_kind: uref.ref_kind.clone(),
57                    package_hint: uref.package_hint.clone(),
58                })
59                .collect();
60            match self.storage.store_unresolved_refs_batch(&batch) {
61                Ok(count) => result.unresolved_refs_stored = count,
62                Err(e) => tracing::warn!("Failed to store unresolved refs batch: {e}"),
63            }
64        }
65
66        // 3. Load existing registered packages and pending refs from storage.
67        let all_registry: Vec<RegisteredPackage> = self
68            .storage
69            .list_registered_packages()
70            .unwrap_or_default()
71            .into_iter()
72            .map(|(name, ns, manifest)| RegisteredPackage {
73                package_name: name,
74                namespace: ns,
75                version: String::new(),
76                manifest,
77            })
78            .collect();
79
80        let package_names: Vec<String> = packages.iter().map(|p| p.package_name.clone()).collect();
81
82        // Convert resolver UnresolvedRef -> linker PendingRef for this namespace.
83        let this_ns_pending: Vec<PendingRef> = unresolved
84            .iter()
85            .map(|uref| PendingRef {
86                id: format!("uref:{namespace}:{}:{}", uref.source_node, uref.target_name),
87                namespace: namespace.to_string(),
88                source_node: uref.source_node.clone(),
89                target_name: uref.target_name.clone(),
90                package_hint: uref.package_hint.clone(),
91                ref_kind: uref.ref_kind.clone(),
92                file_path: Some(uref.file_path.clone()),
93                line: Some(uref.line),
94            })
95            .collect();
96
97        // Load ALL pending refs from storage (for backward linking).
98        let all_pending: Vec<PendingRef> = self
99            .storage
100            .list_pending_unresolved_refs()
101            .unwrap_or_default()
102            .into_iter()
103            .map(|r| PendingRef {
104                id: r.id,
105                namespace: r.namespace,
106                source_node: r.source_node,
107                target_name: r.target_name,
108                package_hint: r.package_hint,
109                ref_kind: r.ref_kind,
110                file_path: Some(r.file_path),
111                line: Some(r.line),
112            })
113            .collect();
114
115        // 4. Forward link: resolve our unresolved refs against other namespaces.
116        //    Pre-build namespace→SymbolMatch index to avoid O(N*M) scans.
117        let ns_symbol_index: HashMap<String, Vec<linker::SymbolMatch>> = {
118            let graph = self.lock_graph()?;
119            let mut index: HashMap<String, Vec<linker::SymbolMatch>> = HashMap::new();
120            for n in graph.get_all_nodes() {
121                if !n.id.starts_with("sym:") {
122                    continue;
123                }
124                let Some(ref ns) = n.namespace else {
125                    continue;
126                };
127                let vis_str = n
128                    .payload
129                    .get("visibility")
130                    .and_then(|v| v.as_str())
131                    .unwrap_or("private");
132                let visibility = match vis_str {
133                    "public" | "Public" => crate::index::symbol::Visibility::Public,
134                    "crate" | "Crate" => crate::index::symbol::Visibility::Crate,
135                    "protected" | "Protected" => crate::index::symbol::Visibility::Protected,
136                    _ => crate::index::symbol::Visibility::Private,
137                };
138                let kind = n
139                    .payload
140                    .get("symbol_kind")
141                    .and_then(|v| v.as_str())
142                    .unwrap_or("unknown")
143                    .to_string();
144                index
145                    .entry(ns.clone())
146                    .or_default()
147                    .push(linker::SymbolMatch {
148                        qualified_name: n.label.clone(),
149                        visibility,
150                        kind,
151                    });
152            }
153            index
154        };
155
156        let resolve_fn = |target_ns: &str, target_name: &str| -> Vec<linker::SymbolMatch> {
157            let Some(symbols) = ns_symbol_index.get(target_ns) else {
158                return Vec::new();
159            };
160            symbols
161                .iter()
162                .filter(|s| {
163                    let label = &s.qualified_name;
164                    // Exact match
165                    if label == target_name {
166                        return true;
167                    }
168                    // Suffix match with separator check (. or ::)
169                    if label.ends_with(target_name) {
170                        let prefix = &label[..label.len() - target_name.len()];
171                        return prefix.ends_with('.') || prefix.ends_with("::");
172                    }
173                    false
174                })
175                .cloned()
176                .collect()
177        };
178
179        let forward_result =
180            linker::forward_link(namespace, &this_ns_pending, &all_registry, &resolve_fn);
181        for edge in &forward_result.forward_edges {
182            if let Err(e) = self.persist_cross_repo_edge(edge) {
183                tracing::warn!("Failed to persist forward edge: {e}");
184            } else {
185                result.forward_edges_created += 1;
186            }
187        }
188
189        // 5. Backward link: resolve other namespaces' pending refs against our symbols.
190        let backward_result =
191            linker::backward_link(namespace, &package_names, &all_pending, symbols);
192        for edge in &backward_result.backward_edges {
193            if let Err(e) = self.persist_cross_repo_edge(edge) {
194                tracing::warn!("Failed to persist backward edge: {e}");
195            } else {
196                result.backward_edges_created += 1;
197            }
198        }
199
200        // 5b. Clean up resolved refs so they don't accumulate.
201        let all_resolved: Vec<&str> = forward_result
202            .resolved_ref_ids
203            .iter()
204            .chain(backward_result.resolved_ref_ids.iter())
205            .map(|s| s.as_str())
206            .collect();
207        for ref_id in &all_resolved {
208            if let Err(e) = self.storage.delete_unresolved_ref(ref_id) {
209                tracing::warn!("Failed to delete resolved ref {ref_id}: {e}");
210            }
211        }
212
213        // ── Phase 3: API Surface ────────────────────────────────────────────
214
215        // 6a. Detect endpoints from decorators/annotations (existing)
216        let mut all_endpoints = api_surface::detect_endpoints(symbols, namespace);
217
218        // 6b. Detect endpoints from call references (Go, Express.js)
219        let ref_endpoints = api_surface::detect_endpoints_from_references(references, namespace);
220        all_endpoints.extend(ref_endpoints);
221
222        result.endpoints_detected = all_endpoints.len();
223        for ep in &all_endpoints {
224            if let Err(e) = self.storage.store_api_endpoint(
225                ep.method.as_deref().unwrap_or("ANY"),
226                &ep.path,
227                &ep.handler,
228                namespace,
229            ) {
230                tracing::warn!(
231                    "Failed to store endpoint {} {}: {e}",
232                    ep.method.as_deref().unwrap_or("ANY"),
233                    ep.path
234                );
235            }
236        }
237
238        // 7. Detect HTTP client calls
239        let client_calls = api_surface::detect_client_calls(references);
240        result.client_calls_detected = client_calls.len();
241        for call in &client_calls {
242            if let Err(e) = self.storage.store_api_client_call(
243                &call.client_library,
244                call.method.as_deref(),
245                &call.caller,
246                namespace,
247            ) {
248                tracing::warn!(
249                    "Failed to store client call to {}: {e}",
250                    call.client_library
251                );
252            }
253        }
254
255        // 8. Detect event channel interactions (Kafka, RabbitMQ, Redis, SQS, etc.)
256        let event_calls = api_surface::detect_event_calls(references, symbols);
257        result.event_channels_detected = event_calls.len();
258        for ec in &event_calls {
259            if let Err(e) = self.storage.store_event_channel(
260                ec.channel.as_deref().unwrap_or("unknown"),
261                &ec.direction,
262                &ec.protocol,
263                &ec.caller,
264                namespace,
265                "",
266            ) {
267                tracing::warn!("Failed to store event channel for {}: {e}", ec.caller);
268            }
269        }
270
271        // ── Phase 4: Cross-service edge matching ──────────────────────────
272
273        // 9a. Match HTTP client calls to detected endpoints across namespaces
274        let all_stored_with_ns = self.get_all_stored_endpoints_with_ns();
275        let all_ep_list: Vec<api_surface::DetectedEndpoint> = all_stored_with_ns
276            .iter()
277            .map(|(ep, _)| ep.clone())
278            .collect();
279        for call in &client_calls {
280            if let Some(url) = &call.url_pattern {
281                if let Some((matched_ep, confidence)) =
282                    api_surface::match_endpoint(url, call.method.as_deref(), &all_ep_list)
283                {
284                    // Find the namespace for this matched endpoint
285                    let ep_ns = all_stored_with_ns
286                        .iter()
287                        .find(|(ep, _)| ep.id == matched_ep.id)
288                        .map(|(_, ns)| ns.as_str());
289                    // Only create cross-namespace edges
290                    if ep_ns != Some(namespace) {
291                        let edge = Edge {
292                            id: format!("http:{}->{}", call.caller, matched_ep.handler),
293                            src: format!("sym:{}", call.caller),
294                            dst: format!("sym:{}", matched_ep.handler),
295                            relationship: RelationshipType::HttpCalls,
296                            weight: confidence * 0.7,
297                            properties: {
298                                let mut p = HashMap::new();
299                                p.insert(
300                                    "cross_namespace".to_string(),
301                                    serde_json::Value::Bool(true),
302                                );
303                                p.insert(
304                                    "path".to_string(),
305                                    serde_json::Value::String(matched_ep.path.clone()),
306                                );
307                                p
308                            },
309                            created_at: chrono::Utc::now(),
310                            valid_from: Some(chrono::Utc::now()),
311                            valid_to: None,
312                        };
313                        if self.storage.insert_graph_edge(&edge).is_ok() {
314                            if let Ok(mut graph) = self.lock_graph() {
315                                let _ = graph.add_edge(edge);
316                            }
317                            result.http_edges_matched += 1;
318                        }
319                    }
320                }
321            }
322        }
323
324        // 9b. Match event producers to consumers across namespaces
325        let all_event_channels = self.storage.list_all_event_channels().unwrap_or_default();
326        let producers: Vec<api_surface::DetectedEventCall> = all_event_channels
327            .iter()
328            .filter(|ec| ec.1 == "publish")
329            .map(|ec| api_surface::DetectedEventCall {
330                caller: ec.3.clone(),
331                channel: Some(ec.0.clone()),
332                direction: "publish".to_string(),
333                protocol: ec.2.clone(),
334                file_path: String::new(),
335                line: 0,
336            })
337            .collect();
338        let consumers: Vec<api_surface::DetectedEventCall> = all_event_channels
339            .iter()
340            .filter(|ec| ec.1 == "subscribe")
341            .map(|ec| api_surface::DetectedEventCall {
342                caller: ec.3.clone(),
343                channel: Some(ec.0.clone()),
344                direction: "subscribe".to_string(),
345                protocol: ec.2.clone(),
346                file_path: String::new(),
347                line: 0,
348            })
349            .collect();
350
351        let event_matches = api_surface::match_event_channels(&producers, &consumers);
352        let now = chrono::Utc::now();
353        for (producer, consumer, channel, protocol, confidence) in &event_matches {
354            // Only create cross-namespace edges (different callers imply different namespaces in practice)
355            if producer == consumer {
356                continue;
357            }
358            let edge = Edge {
359                id: format!("event:{producer}->{consumer}:{protocol}:{channel}"),
360                src: format!("sym:{producer}"),
361                dst: format!("sym:{consumer}"),
362                relationship: RelationshipType::PublishesTo,
363                weight: confidence * 0.6,
364                properties: {
365                    let mut p = HashMap::new();
366                    p.insert(
367                        "channel".to_string(),
368                        serde_json::Value::String(channel.clone()),
369                    );
370                    p.insert(
371                        "protocol".to_string(),
372                        serde_json::Value::String(protocol.clone()),
373                    );
374                    p
375                },
376                created_at: now,
377                valid_from: Some(now),
378                valid_to: None,
379            };
380            if self.storage.insert_graph_edge(&edge).is_ok() {
381                if let Ok(mut graph) = self.lock_graph() {
382                    let _ = graph.add_edge(edge);
383                }
384                result.event_edges_matched += 1;
385            }
386        }
387
388        Ok(result)
389    }
390
391    /// Get all stored endpoints across all namespaces, paired with their namespace.
392    fn get_all_stored_endpoints_with_ns(&self) -> Vec<(api_surface::DetectedEndpoint, String)> {
393        let namespaces = self.storage.list_namespaces().unwrap_or_default();
394        let mut all = Vec::new();
395        for ns in &namespaces {
396            if let Ok(eps) = self.get_detected_endpoints(ns) {
397                for ep in eps {
398                    all.push((ep, ns.clone()));
399                }
400            }
401        }
402        all
403    }
404
405    /// Persist a cross-repo edge into the graph_edges table and in-memory graph.
406    fn persist_cross_repo_edge(&self, edge: &CrossRepoEdge) -> Result<(), CodememError> {
407        let now = chrono::Utc::now();
408        let relationship = match edge.relationship.as_str() {
409            "Calls" => RelationshipType::Calls,
410            "Imports" => RelationshipType::Imports,
411            "Inherits" => RelationshipType::Inherits,
412            "Implements" => RelationshipType::Implements,
413            "DependsOn" => RelationshipType::DependsOn,
414            _ => RelationshipType::RelatesTo,
415        };
416
417        let graph_edge = Edge {
418            id: edge.id.clone(),
419            src: edge.source.clone(),
420            dst: edge.target.clone(),
421            relationship,
422            weight: edge.confidence.min(1.0) * 0.7,
423            valid_from: Some(now),
424            valid_to: None,
425            properties: {
426                let mut props = HashMap::new();
427                props.insert(
428                    "src_namespace".to_string(),
429                    serde_json::Value::String(edge.source_namespace.clone()),
430                );
431                props.insert(
432                    "dst_namespace".to_string(),
433                    serde_json::Value::String(edge.target_namespace.clone()),
434                );
435                props.insert("cross_namespace".to_string(), serde_json::Value::Bool(true));
436                props.insert("confidence".to_string(), serde_json::json!(edge.confidence));
437                props
438            },
439            created_at: now,
440        };
441
442        self.storage.insert_graph_edge(&graph_edge)?;
443        let mut graph = self.lock_graph()?;
444        let _ = graph.add_edge(graph_edge);
445        Ok(())
446    }
447
448    // ── Query helpers for tool_get_cross_repo ────────────────────────────
449
450    /// Get all cross-namespace edges touching a given namespace.
451    pub fn get_cross_namespace_edges(&self, namespace: &str) -> Result<Vec<Edge>, CodememError> {
452        self.storage
453            .graph_edges_for_namespace_with_cross(namespace, true)
454    }
455
456    /// Count unresolved refs for a namespace.
457    pub fn count_unresolved_refs(&self, namespace: &str) -> Result<usize, CodememError> {
458        self.storage.count_unresolved_refs(namespace)
459    }
460
461    /// List registered packages for a namespace.
462    pub fn get_registered_packages(
463        &self,
464        namespace: &str,
465    ) -> Result<Vec<RegisteredPackage>, CodememError> {
466        let tuples = self
467            .storage
468            .list_registered_packages_for_namespace(namespace)?;
469        Ok(tuples
470            .into_iter()
471            .map(|(name, ns, manifest)| RegisteredPackage {
472                package_name: name,
473                namespace: ns,
474                version: String::new(),
475                manifest,
476            })
477            .collect())
478    }
479
480    /// List detected API endpoints for a namespace.
481    pub fn get_detected_endpoints(
482        &self,
483        namespace: &str,
484    ) -> Result<Vec<api_surface::DetectedEndpoint>, CodememError> {
485        let tuples = self.storage.list_api_endpoints(namespace)?;
486        Ok(tuples
487            .into_iter()
488            .map(
489                |(method, path, handler, _ns)| api_surface::DetectedEndpoint {
490                    id: format!("ep:{namespace}:{method}:{path}"),
491                    method: if method == "ANY" { None } else { Some(method) },
492                    path,
493                    handler,
494                    file_path: String::new(),
495                    line: 0,
496                },
497            )
498            .collect())
499    }
500}