Skip to main content

ckg_storage/store/
resolve.rs

1//! Cross-file call resolution and test-edge detection.
2//!
3//! `resolve_cross_file_calls`: rewrites unresolved `Calls.dst` bare-names to
4//! canonical Symbol ids using a three-tier match (exact name → qname suffix →
5//! leaf segment), all requiring a unique match to avoid ambiguous rewrites.
6//!
7//! `detect_test_edges`: emits `Tests` relation edges from `test_*` functions
8//! to their stripped-prefix target symbol when the match is unique.
9
10use std::collections::BTreeMap;
11
12use ckg_core::Result;
13use cozo::{DataValue, ScriptMutability};
14
15use super::map_err;
16use super::Storage;
17
/// Strip **Rust** generic / lifetime parameters from a qname.
///
/// Everything between a `<` and its matching `>` is dropped, including the
/// brackets themselves. Matching is depth-counted, so nested generics
/// collapse as well.
///
/// Examples: `Ctx<'_>::push_symbol` → `Ctx::push_symbol`,
/// `HashMap<K, V>::insert` → `HashMap::insert`,
/// `Wrapper<F: Fn() -> u8>::call` → `Wrapper::call`.
///
/// Edge cases:
/// - A `>` that appears while no `<` is open is kept verbatim.
/// - A `->` inside an open generic list is a return arrow (as in
///   `Fn() -> T` bounds), not a closing bracket, so it does not change the
///   nesting depth.
/// - An unbalanced `<` discards the remainder of the string.
///
/// **Scope:** the syntax handled here is Rust's. A qname that contains no
/// `<` is returned unchanged, so dotted Python / Ruby / JS style paths pass
/// through as-is.
fn strip_generics(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut depth: usize = 0;
    // Previous character, used only to recognise the `->` return arrow.
    let mut prev: Option<char> = None;
    for c in s.chars() {
        match c {
            '<' => depth += 1,
            // `->` inside generics: skip it without closing the bracket.
            '>' if depth > 0 && prev == Some('-') => {}
            '>' if depth > 0 => depth -= 1,
            _ if depth == 0 => out.push(c),
            _ => {}
        }
        prev = Some(c);
    }
    out
}
40
/// Return the last identifier segment of a scoped (`::`) or dotted (`.`)
/// path.
///
/// The text after the final `::` is taken first, then the text after the
/// final `.` within that remainder: `commands::index::run` → `run`,
/// `pkg.module.fn` → `fn`. A string containing neither separator is
/// returned unchanged; a trailing separator yields an empty string.
fn leaf_of(s: &str) -> &str {
    let after_scope = s.rsplit_once("::").map_or(s, |(_, tail)| tail);
    match after_scope.rsplit_once('.') {
        Some((_, tail)) => tail,
        None => after_scope,
    }
}
53
54impl Storage {
55    /// Cross-file resolution pass: for every `Calls` edge whose `dst` is a
56    /// bare name (not an existing symbol id) AND the name matches *exactly
57    /// one* `Symbol.name` in the same repo, rewrite `dst` to that id.
58    ///
59    /// Returns the number of edges rewritten. Confidence on rewritten edges
60    /// stays at 0.5 (still ambiguous compared to in-file resolution which
61    /// gets 1.0 at extract time).
62    pub fn resolve_cross_file_calls(&self) -> Result<usize> {
63        // Step 1: load only unresolved Calls (dst not in Symbol.id) — server-side filter.
64        let calls = self
65            .db
66            .run_script(
67                "?[src, dst, confidence] := *Calls{src, dst, confidence}, \
68                 not *Symbol{id: dst}",
69                BTreeMap::new(),
70                ScriptMutability::Immutable,
71            )
72            .map_err(map_err)?;
73
74        // Step 2: derive distinct needles (full dst + leaf segment).
75        let mut needles: std::collections::HashSet<String> = std::collections::HashSet::new();
76        for r in &calls.rows {
77            if let Some(DataValue::Str(dst)) = r.get(1) {
78                let d = dst.as_str();
79                needles.insert(d.to_string());
80                let leaf = leaf_of(d);
81                if leaf != d {
82                    needles.insert(leaf.to_string());
83                }
84            }
85        }
86
87        // Step 3: targeted Symbol load via bound `needle_set[n]` relation.
88        let mut by_name: std::collections::HashMap<String, Vec<String>> =
89            std::collections::HashMap::new();
90        let mut qnames: Vec<(String, String)> = Vec::new();
91        if !needles.is_empty() {
92            let needle_rows: Vec<DataValue> = needles
93                .iter()
94                .map(|n| DataValue::List(vec![DataValue::from(n.as_str())]))
95                .collect();
96            let mut params = BTreeMap::new();
97            params.insert("needle_rows".into(), DataValue::List(needle_rows));
98            let rows = self
99                .db
100                .run_script(
101                    "needle_set[n] <- $needle_rows\n\
102                     ?[name, qname, id] := *Symbol{name, qname, id}, needle_set[name]",
103                    params,
104                    ScriptMutability::Immutable,
105                )
106                .map_err(map_err)?;
107            qnames.reserve(rows.rows.len());
108            for r in rows.rows {
109                if let (
110                    Some(DataValue::Str(n)),
111                    Some(DataValue::Str(q)),
112                    Some(DataValue::Str(i)),
113                ) = (r.first(), r.get(1), r.get(2))
114                {
115                    by_name
116                        .entry(n.to_string())
117                        .or_default()
118                        .push(i.to_string());
119                    qnames.push((strip_generics(q.as_str()), i.to_string()));
120                }
121            }
122        }
123
124        let mut rewrites: Vec<(String, String, String, f64)> = Vec::new();
125        for r in calls.rows {
126            let (Some(DataValue::Str(src)), Some(DataValue::Str(dst)), Some(DataValue::Num(c))) =
127                (r.first(), r.get(1), r.get(2))
128            else {
129                continue;
130            };
131            let dst_s = dst.to_string();
132
133            // Three-tier resolution, all requiring a UNIQUE match:
134            //   1. Full path matches a Symbol.name verbatim.
135            //   2. Some Symbol.qname ends with the dst path — covers
136            //      `commands::index::run` even when leaf `run` collides.
137            //   3. Leaf segment matches a Symbol.name uniquely.
138            let candidate_id = by_name
139                .get(&dst_s)
140                .filter(|v| v.len() == 1)
141                .map(|v| v[0].clone())
142                .or_else(|| {
143                    // Progressive-suffix qname match. For a path call
144                    // `a::b::c::d`, try suffix matches starting from the
145                    // longest. First suffix yielding a unique candidate wins.
146                    let segments: Vec<&str> =
147                        dst_s.split(['.', ':']).filter(|s| !s.is_empty()).collect();
148                    if segments.len() < 2 {
149                        return None;
150                    }
151                    let max_skip = segments.len().saturating_sub(2);
152                    for skip in 0..=max_skip {
153                        let suffix_segs = &segments[skip..];
154                        let suffix_dot = suffix_segs.join(".");
155                        let suffix_colon = suffix_segs.join("::");
156                        // CR-M-2: pre-compute the four `endswith` needles OUTSIDE
157                        // the qnames filter. Inside, original code allocated
158                        // `format!("::{n}")` per qname — O(unresolved × qnames).
159                        let dot_dot = format!(".{suffix_dot}");
160                        let dot_colon = format!(".{suffix_colon}");
161                        let colon_dot = format!("::{suffix_dot}");
162                        let colon_colon = format!("::{suffix_colon}");
163                        let mut hits: Vec<&String> = qnames
164                            .iter()
165                            .filter(|(q, _)| {
166                                q.as_str() == suffix_dot.as_str()
167                                    || q.as_str() == suffix_colon.as_str()
168                                    || q.ends_with(&dot_dot)
169                                    || q.ends_with(&dot_colon)
170                                    || q.ends_with(&colon_dot)
171                                    || q.ends_with(&colon_colon)
172                            })
173                            .map(|(_, id)| id)
174                            .collect();
175                        hits.sort();
176                        hits.dedup();
177                        if hits.len() == 1 {
178                            return Some(hits[0].clone());
179                        }
180                        // hits.len() > 1 → ambiguous; try narrower suffix (skip+1).
181                    }
182                    None
183                })
184                .or_else(|| {
185                    let leaf = leaf_of(&dst_s);
186                    if leaf == dst_s {
187                        return None;
188                    }
189                    by_name
190                        .get(leaf)
191                        .filter(|v| v.len() == 1)
192                        .map(|v| v[0].clone())
193                });
194
195            if let Some(target) = candidate_id {
196                let conf = match c {
197                    cozo::Num::Float(f) => *f,
198                    cozo::Num::Int(i) => *i as f64,
199                };
200                rewrites.push((src.to_string(), dst_s, target.clone(), conf));
201            }
202        }
203
204        if rewrites.is_empty() {
205            return Ok(0);
206        }
207
208        // Apply: delete old, insert rewritten. Cozo has no UPDATE so we use :rm + :put.
209        // RESOLVE-C3: two separate `{...}` blocks in one imperative program so
210        // both are committed in a single SessionTx.commit_tx() — atomic.
211        // A single `{ stmt1 ; stmt2 }` block with two `?` heads is rejected by
212        // Cozo 0.7.13 with "cannot have multiple definitions since it contains
213        // non-Horn clauses". Two blocks avoid this restriction.
214        let n = rewrites.len();
215        for chunk in rewrites.chunks(500) {
216            let rm_rows: Vec<DataValue> = chunk
217                .iter()
218                .map(|(s, d, _, _)| {
219                    DataValue::List(vec![
220                        DataValue::from(s.as_str()),
221                        DataValue::from(d.as_str()),
222                    ])
223                })
224                .collect();
225            let put_rows: Vec<DataValue> = chunk
226                .iter()
227                .map(|(s, _, new_d, c)| {
228                    DataValue::List(vec![
229                        DataValue::from(s.as_str()),
230                        DataValue::from(new_d.as_str()),
231                        DataValue::from(*c),
232                    ])
233                })
234                .collect();
235            // Pack both row sets into a single params map and execute two
236            // separate blocks so each `?` head is unambiguous.
237            let mut p = BTreeMap::new();
238            p.insert("rm_rows".into(), DataValue::List(rm_rows));
239            p.insert("put_rows".into(), DataValue::List(put_rows));
240            self.db
241                .run_script(
242                    "{ ?[src, dst] <- $rm_rows :rm Calls {src, dst} } \
243                     { ?[src, dst, confidence] <- $put_rows :put Calls {src, dst => confidence} }",
244                    p,
245                    ScriptMutability::Mutable,
246                )
247                .map_err(map_err)?;
248        }
249        Ok(n)
250    }
251
252    /// Detect test functions and emit `Tests` edges to their candidate target.
253    ///
254    /// Rule: any `Symbol` of kind=function/method whose `name` starts with
255    /// `test_` is a test. Strip the prefix to get the bare target name.
256    /// If exactly one other symbol shares that bare name, emit a `Tests` edge.
257    ///
258    /// Returns the number of `Tests` edges written.
259    pub fn detect_test_edges(&self) -> Result<usize> {
260        let rows = self
261            .db
262            .run_script(
263                "?[id, name, kind] := *Symbol{id, name, kind}, \
264                 (kind = \"function\" or kind = \"method\")",
265                BTreeMap::new(),
266                ScriptMutability::Immutable,
267            )
268            .map_err(map_err)?;
269
270        // Build name → ids index; collect tests.
271        let mut by_name: std::collections::HashMap<String, Vec<String>> =
272            std::collections::HashMap::new();
273        let mut tests: Vec<(String, String)> = Vec::new(); // (test_id, target_name)
274        for r in rows.rows {
275            let (Some(DataValue::Str(id)), Some(DataValue::Str(name)), _) =
276                (r.first(), r.get(1), r.get(2))
277            else {
278                continue;
279            };
280            let id_s = id.to_string();
281            let n = name.to_string();
282            by_name.entry(n.clone()).or_default().push(id_s.clone());
283            if let Some(stripped) = n.strip_prefix("test_") {
284                tests.push((id_s, stripped.to_string()));
285            }
286        }
287
288        let mut written: Vec<(String, String)> = Vec::new();
289        for (test_id, target) in tests {
290            match by_name.get(&target) {
291                Some(candidates) if candidates.len() == 1 => {
292                    written.push((test_id, candidates[0].clone()));
293                }
294                Some(candidates) => {
295                    tracing::trace!(
296                        "test edge skipped (ambiguous): {test_id} → {target} matched {} candidates",
297                        candidates.len()
298                    );
299                }
300                None => {
301                    tracing::trace!(
302                        "test edge skipped (no target): {test_id} prefix-stripped to {target} \
303                         but no Symbol of that name exists"
304                    );
305                }
306            }
307        }
308
309        if written.is_empty() {
310            return Ok(0);
311        }
312        let n = written.len();
313        let rows: Vec<DataValue> = written
314            .into_iter()
315            .map(|(s, d)| {
316                DataValue::List(vec![
317                    DataValue::from(s.as_str()),
318                    DataValue::from(d.as_str()),
319                ])
320            })
321            .collect();
322        let mut params = BTreeMap::new();
323        params.insert("rows".into(), DataValue::List(rows));
324        self.db
325            .run_script(
326                "?[src, dst] <- $rows :put Tests {src, dst}",
327                params,
328                ScriptMutability::Mutable,
329            )
330            .map_err(map_err)?;
331        Ok(n)
332    }
333}