Skip to main content

codemem_engine/persistence/
cross_repo.rs

//! Cross-repo persistence: register packages, store unresolved refs,
//! run forward/backward linking, persist cross-namespace edges, and
//! detect API endpoints.
4
5use super::CrossRepoPersistResult;
6use crate::index::api_surface;
7use crate::index::linker::{self, CrossRepoEdge, PendingRef, RegisteredPackage};
8use crate::index::manifest::ManifestResult;
9use crate::index::resolver::UnresolvedRef;
10use crate::index::symbol::{Reference, Symbol};
11use codemem_core::{CodememError, Edge, GraphBackend, RelationshipType};
12use std::collections::HashMap;
13
14impl super::super::CodememEngine {
15    /// Persist cross-repo linking data after `persist_index_results`.
16    ///
17    /// This method runs 3 phases:
18    /// 1. Register packages + store unresolved refs
19    /// 2. Forward/backward cross-repo linking
20    /// 3. API endpoint + client call detection
21    pub fn persist_cross_repo_data(
22        &self,
23        manifests: &ManifestResult,
24        unresolved: &[UnresolvedRef],
25        symbols: &[Symbol],
26        references: &[Reference],
27        namespace: &str,
28    ) -> Result<CrossRepoPersistResult, CodememError> {
29        let mut result = CrossRepoPersistResult::default();
30
31        // 1. Register packages from manifests into the package registry.
32        let packages = linker::extract_packages(manifests, namespace);
33        for pkg in &packages {
34            if let Err(e) = self.storage.upsert_package_registry(
35                &pkg.package_name,
36                &pkg.namespace,
37                &pkg.version,
38                &pkg.manifest,
39            ) {
40                tracing::warn!("Failed to register package {}: {e}", pkg.package_name);
41            } else {
42                result.packages_registered += 1;
43            }
44        }
45
46        // 2. Store unresolved refs for future backward linking by other namespaces.
47        {
48            let batch: Vec<codemem_core::UnresolvedRefData> = unresolved
49                .iter()
50                .map(|uref| codemem_core::UnresolvedRefData {
51                    source_qualified_name: uref.source_node.clone(),
52                    target_name: uref.target_name.clone(),
53                    namespace: namespace.to_string(),
54                    file_path: uref.file_path.clone(),
55                    line: uref.line,
56                    ref_kind: uref.ref_kind.clone(),
57                    package_hint: uref.package_hint.clone(),
58                })
59                .collect();
60            match self.storage.store_unresolved_refs_batch(&batch) {
61                Ok(count) => result.unresolved_refs_stored = count,
62                Err(e) => tracing::warn!("Failed to store unresolved refs batch: {e}"),
63            }
64        }
65
66        // 3. Load existing registered packages and pending refs from storage.
67        let all_registry: Vec<RegisteredPackage> = self
68            .storage
69            .list_registered_packages()
70            .unwrap_or_default()
71            .into_iter()
72            .map(|(name, ns, manifest)| RegisteredPackage {
73                package_name: name,
74                namespace: ns,
75                version: String::new(),
76                manifest,
77            })
78            .collect();
79
80        let package_names: Vec<String> = packages.iter().map(|p| p.package_name.clone()).collect();
81
82        // Convert resolver UnresolvedRef -> linker PendingRef for this namespace.
83        let this_ns_pending: Vec<PendingRef> = unresolved
84            .iter()
85            .map(|uref| PendingRef {
86                id: format!("uref:{namespace}:{}:{}", uref.source_node, uref.target_name),
87                namespace: namespace.to_string(),
88                source_node: uref.source_node.clone(),
89                target_name: uref.target_name.clone(),
90                package_hint: uref.package_hint.clone(),
91                ref_kind: uref.ref_kind.clone(),
92                file_path: Some(uref.file_path.clone()),
93                line: Some(uref.line),
94            })
95            .collect();
96
97        // Load ALL pending refs from storage (for backward linking).
98        let all_pending: Vec<PendingRef> = self
99            .storage
100            .list_pending_unresolved_refs()
101            .unwrap_or_default()
102            .into_iter()
103            .map(|r| PendingRef {
104                id: r.id,
105                namespace: r.namespace,
106                source_node: r.source_node,
107                target_name: r.target_name,
108                package_hint: r.package_hint,
109                ref_kind: r.ref_kind,
110                file_path: Some(r.file_path),
111                line: Some(r.line),
112            })
113            .collect();
114
115        // 4. Forward link: resolve our unresolved refs against other namespaces.
116        //    Pre-build namespace→SymbolMatch index to avoid O(N*M) scans.
117        let ns_symbol_index: HashMap<String, Vec<linker::SymbolMatch>> = {
118            let graph = self.lock_graph()?;
119            let mut index: HashMap<String, Vec<linker::SymbolMatch>> = HashMap::new();
120            for n in graph.get_all_nodes() {
121                if !n.id.starts_with("sym:") {
122                    continue;
123                }
124                let Some(ref ns) = n.namespace else {
125                    continue;
126                };
127                let vis_str = n
128                    .payload
129                    .get("visibility")
130                    .and_then(|v| v.as_str())
131                    .unwrap_or("private");
132                let visibility = match vis_str {
133                    "public" | "Public" => crate::index::symbol::Visibility::Public,
134                    "crate" | "Crate" => crate::index::symbol::Visibility::Crate,
135                    "protected" | "Protected" => crate::index::symbol::Visibility::Protected,
136                    _ => crate::index::symbol::Visibility::Private,
137                };
138                let kind = n
139                    .payload
140                    .get("symbol_kind")
141                    .and_then(|v| v.as_str())
142                    .unwrap_or("unknown")
143                    .to_string();
144                index
145                    .entry(ns.clone())
146                    .or_default()
147                    .push(linker::SymbolMatch {
148                        qualified_name: n.label.clone(),
149                        visibility,
150                        kind,
151                    });
152            }
153            index
154        };
155
156        let resolve_fn = |target_ns: &str, target_name: &str| -> Vec<linker::SymbolMatch> {
157            let Some(symbols) = ns_symbol_index.get(target_ns) else {
158                return Vec::new();
159            };
160            symbols
161                .iter()
162                .filter(|s| {
163                    let label = &s.qualified_name;
164                    // Exact match
165                    if label == target_name {
166                        return true;
167                    }
168                    // Suffix match with separator check (. or ::)
169                    if label.ends_with(target_name) {
170                        let prefix = &label[..label.len() - target_name.len()];
171                        return prefix.ends_with('.') || prefix.ends_with("::");
172                    }
173                    false
174                })
175                .cloned()
176                .collect()
177        };
178
179        let forward_result =
180            linker::forward_link(namespace, &this_ns_pending, &all_registry, &resolve_fn);
181        for edge in &forward_result.forward_edges {
182            if let Err(e) = self.persist_cross_repo_edge(edge) {
183                tracing::warn!("Failed to persist forward edge: {e}");
184            } else {
185                result.forward_edges_created += 1;
186            }
187        }
188
189        // 5. Backward link: resolve other namespaces' pending refs against our symbols.
190        let backward_result =
191            linker::backward_link(namespace, &package_names, &all_pending, symbols);
192        for edge in &backward_result.backward_edges {
193            if let Err(e) = self.persist_cross_repo_edge(edge) {
194                tracing::warn!("Failed to persist backward edge: {e}");
195            } else {
196                result.backward_edges_created += 1;
197            }
198        }
199
200        // 5b. Clean up resolved refs so they don't accumulate.
201        let all_resolved: Vec<&str> = forward_result
202            .resolved_ref_ids
203            .iter()
204            .chain(backward_result.resolved_ref_ids.iter())
205            .map(|s| s.as_str())
206            .collect();
207        for ref_id in &all_resolved {
208            if let Err(e) = self.storage.delete_unresolved_ref(ref_id) {
209                tracing::warn!("Failed to delete resolved ref {ref_id}: {e}");
210            }
211        }
212
213        // ── Phase 3: API Surface ────────────────────────────────────────────
214        // 6. Detect API endpoints and client calls.
215        let endpoints = api_surface::detect_endpoints(symbols, namespace);
216        result.endpoints_detected = endpoints.len();
217        for ep in &endpoints {
218            if let Err(e) = self.storage.store_api_endpoint(
219                ep.method.as_deref().unwrap_or("ANY"),
220                &ep.path,
221                &ep.handler,
222                namespace,
223            ) {
224                tracing::warn!(
225                    "Failed to store endpoint {} {}: {e}",
226                    ep.method.as_deref().unwrap_or("ANY"),
227                    ep.path
228                );
229            }
230        }
231
232        let client_calls = api_surface::detect_client_calls(references);
233        result.client_calls_detected = client_calls.len();
234        for call in &client_calls {
235            if let Err(e) = self.storage.store_api_client_call(
236                &call.client_library,
237                call.method.as_deref(),
238                &call.caller,
239                namespace,
240            ) {
241                tracing::warn!(
242                    "Failed to store client call to {}: {e}",
243                    call.client_library
244                );
245            }
246        }
247
248        Ok(result)
249    }
250
251    /// Persist a cross-repo edge into the graph_edges table and in-memory graph.
252    fn persist_cross_repo_edge(&self, edge: &CrossRepoEdge) -> Result<(), CodememError> {
253        let now = chrono::Utc::now();
254        let relationship = match edge.relationship.as_str() {
255            "Calls" => RelationshipType::Calls,
256            "Imports" => RelationshipType::Imports,
257            "Inherits" => RelationshipType::Inherits,
258            "Implements" => RelationshipType::Implements,
259            "DependsOn" => RelationshipType::DependsOn,
260            _ => RelationshipType::RelatesTo,
261        };
262
263        let graph_edge = Edge {
264            id: edge.id.clone(),
265            src: edge.source.clone(),
266            dst: edge.target.clone(),
267            relationship,
268            weight: edge.confidence.min(1.0) * 0.7,
269            valid_from: Some(now),
270            valid_to: None,
271            properties: {
272                let mut props = HashMap::new();
273                props.insert(
274                    "src_namespace".to_string(),
275                    serde_json::Value::String(edge.source_namespace.clone()),
276                );
277                props.insert(
278                    "dst_namespace".to_string(),
279                    serde_json::Value::String(edge.target_namespace.clone()),
280                );
281                props.insert("cross_namespace".to_string(), serde_json::Value::Bool(true));
282                props.insert("confidence".to_string(), serde_json::json!(edge.confidence));
283                props
284            },
285            created_at: now,
286        };
287
288        self.storage.insert_graph_edge(&graph_edge)?;
289        let mut graph = self.lock_graph()?;
290        let _ = graph.add_edge(graph_edge);
291        Ok(())
292    }
293
294    // ── Query helpers for tool_get_cross_repo ────────────────────────────
295
296    /// Get all cross-namespace edges touching a given namespace.
297    pub fn get_cross_namespace_edges(&self, namespace: &str) -> Result<Vec<Edge>, CodememError> {
298        self.storage
299            .graph_edges_for_namespace_with_cross(namespace, true)
300    }
301
302    /// Count unresolved refs for a namespace.
303    pub fn count_unresolved_refs(&self, namespace: &str) -> Result<usize, CodememError> {
304        self.storage.count_unresolved_refs(namespace)
305    }
306
307    /// List registered packages for a namespace.
308    pub fn get_registered_packages(
309        &self,
310        namespace: &str,
311    ) -> Result<Vec<RegisteredPackage>, CodememError> {
312        let tuples = self
313            .storage
314            .list_registered_packages_for_namespace(namespace)?;
315        Ok(tuples
316            .into_iter()
317            .map(|(name, ns, manifest)| RegisteredPackage {
318                package_name: name,
319                namespace: ns,
320                version: String::new(),
321                manifest,
322            })
323            .collect())
324    }
325
326    /// List detected API endpoints for a namespace.
327    pub fn get_detected_endpoints(
328        &self,
329        namespace: &str,
330    ) -> Result<Vec<api_surface::DetectedEndpoint>, CodememError> {
331        let tuples = self.storage.list_api_endpoints(namespace)?;
332        Ok(tuples
333            .into_iter()
334            .map(
335                |(method, path, handler, _ns)| api_surface::DetectedEndpoint {
336                    id: format!("ep:{namespace}:{method}:{path}"),
337                    method: if method == "ANY" { None } else { Some(method) },
338                    path,
339                    handler,
340                    file_path: String::new(),
341                    line: 0,
342                },
343            )
344            .collect())
345    }
346}