Skip to main content

srcgraph_metrics/
association_rules.rs

1//! Apriori-style frequent itemset mining + association-rule generation over
2//! per-class `callSequences` blobs.
3//!
4//! Each class node carries a `callSequences` JSON blob with shape
5//!
6//!   `{"sequences": [{"method": str, "calls": [str], …}, …]}`
7//!
8//! Every sequence with `len(calls) >= 2` becomes one transaction (the
9//! `frozenset` of its `calls`). We then:
10//!
11//! 1. Mine frequent itemsets by Apriori (bottom-up by k), capped at `k = 5` to
12//!    match the Python reference.
13//! 2. Generate rules `A → C` for every non-empty proper subset `A` of each
14//!    frequent itemset, scoring with `confidence = sup(A∪C)/sup(A)` and
15//!    `lift = confidence / (sup(C)/n)`.
//! 3. Classify each rule as `invariant` (conf ≥ 0.99), `strong` (≥ 0.85), or
//!    `moderate` (anything lower that still passes the `min_confidence`
//!    filter, i.e. ≥ 0.5 when `min_confidence` is 0.5).
18//!
19//! Mirrors `analysis/association_rules.py` in the visiting tool.
20//!
21//! See `DESIGN.md` (Phase 3) at the workspace root.
22
23use petgraph::Graph;
24use serde::{Deserialize, Serialize};
25use std::collections::{BTreeSet, HashMap};
26
27use srcgraph_core::{ClassNode, EdgeKind};
28
29/// One call-sequence record parsed from a class's `callSequences` blob.
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct CallSequence {
32    pub method: String,
33    pub calls: Vec<String>,
34}
35
36/// A frequent itemset and its support count across the transaction list.
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct FrequentItemset {
39    /// Items in deterministic (sorted) order.
40    pub items: Vec<String>,
41    pub support: usize,
42}
43
44/// An association rule `antecedent → consequent` with support, confidence,
45/// and lift.
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct AssociationRule {
48    pub antecedent: Vec<String>,
49    pub consequent: Vec<String>,
50    pub support: usize,
51    /// `support / |transactions|`, rounded to 4 decimals.
52    pub support_pct: f64,
53    /// `sup(A∪C) / sup(A)`, rounded to 4 decimals.
54    pub confidence: f64,
55    /// `confidence / (sup(C) / n)`, rounded to 4 decimals. `0.0` when `sup(C)`
56    /// is unknown (consequent wasn't a frequent itemset).
57    pub lift: f64,
58    /// `"invariant"` (conf ≥ 0.99) / `"strong"` (≥ 0.85) / `"moderate"` (≥ 0.5).
59    pub classification: String,
60}
61
62/// Whole-graph association-rule readout.
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct AssociationAnalysis {
65    pub transactions: usize,
66    pub rules: Vec<AssociationRule>,
67    pub num_rules: usize,
68    pub invariants: usize,
69    pub strong: usize,
70    pub moderate: usize,
71    pub itemsets: usize,
72}
73
74/// Parse the `{"sequences": [...]}` blob attached to a single class.
75///
76/// Returns `None` if the blob isn't an object with a `sequences` array; missing
77/// `method`/`calls` default to empty.
78pub fn parse_call_sequences(blob: &serde_json::Value) -> Option<Vec<CallSequence>> {
79    let seqs = blob.get("sequences")?.as_array()?;
80    let mut out = Vec::with_capacity(seqs.len());
81    for s in seqs {
82        let Some(obj) = s.as_object() else { continue };
83        let method = obj.get("method").and_then(|v| v.as_str()).unwrap_or("").to_owned();
84        let calls = obj
85            .get("calls")
86            .and_then(|v| v.as_array())
87            .map(|arr| {
88                arr.iter()
89                    .filter_map(|c| c.as_str().map(|s| s.to_owned()))
90                    .collect()
91            })
92            .unwrap_or_default();
93        out.push(CallSequence { method, calls });
94    }
95    Some(out)
96}
97
98/// Mine frequent itemsets by Apriori. Caps `k` at 5 (matches Python ref) to
99/// keep candidate generation tractable on large unique-item alphabets.
100///
101/// Result is sorted by support descending, then by item-list ascending for
102/// determinism among ties.
103pub fn mine_frequent_itemsets(
104    transactions: &[BTreeSet<String>],
105    min_support: usize,
106) -> Vec<FrequentItemset> {
107    if transactions.is_empty() {
108        return Vec::new();
109    }
110
111    // Frequent 1-itemsets via direct counting.
112    let mut item_counts: HashMap<&str, usize> = HashMap::new();
113    for txn in transactions {
114        for item in txn {
115            *item_counts.entry(item.as_str()).or_insert(0) += 1;
116        }
117    }
118    let mut freq_items: Vec<String> = item_counts
119        .iter()
120        .filter(|(_, &c)| c >= min_support)
121        .map(|(k, _)| (*k).to_owned())
122        .collect();
123    freq_items.sort();
124
125    let mut frequent: HashMap<BTreeSet<String>, usize> = HashMap::new();
126    for item in &freq_items {
127        let mut s = BTreeSet::new();
128        s.insert(item.clone());
129        let c = item_counts[item.as_str()];
130        frequent.insert(s, c);
131    }
132
133    // Iterate k = 2..=5; generate candidates from all k-combinations of
134    // frequent 1-items, then keep those whose (k-1)-subsets are all frequent
135    // (the Apriori property) and whose support meets the threshold.
136    let mut k = 2usize;
137    let mut prev_has_freq = !freq_items.is_empty();
138    while prev_has_freq && k <= 5 {
139        let mut new_count = 0usize;
140        for combo in combinations(&freq_items, k) {
141            let cand: BTreeSet<String> = combo.iter().cloned().collect();
142            // All (k-1)-subsets frequent?
143            let all_sub_freq = cand.iter().all(|item| {
144                let mut sub = cand.clone();
145                sub.remove(item);
146                frequent.contains_key(&sub)
147            });
148            if !all_sub_freq {
149                continue;
150            }
151            let support = transactions
152                .iter()
153                .filter(|txn| cand.iter().all(|x| txn.contains(x)))
154                .count();
155            if support >= min_support {
156                frequent.insert(cand, support);
157                new_count += 1;
158            }
159        }
160        prev_has_freq = new_count > 0;
161        k += 1;
162    }
163
164    let mut out: Vec<FrequentItemset> = frequent
165        .into_iter()
166        .map(|(set, support)| FrequentItemset {
167            items: set.into_iter().collect(),
168            support,
169        })
170        .collect();
171    out.sort_by(|a, b| b.support.cmp(&a.support).then_with(|| a.items.cmp(&b.items)));
172    out
173}
174
175/// Generate association rules from frequent itemsets.
176///
177/// Returns rules sorted by confidence desc, then support desc, then antecedent
178/// ascending for deterministic ordering among ties.
179pub fn generate_rules(
180    itemsets: &[FrequentItemset],
181    transactions: &[BTreeSet<String>],
182    min_confidence: f64,
183) -> Vec<AssociationRule> {
184    let n_txns = transactions.len();
185    if n_txns == 0 || itemsets.is_empty() {
186        return Vec::new();
187    }
188
189    // Support lookup keyed on the itemset as BTreeSet.
190    let support_map: HashMap<BTreeSet<String>, usize> = itemsets
191        .iter()
192        .map(|fi| (fi.items.iter().cloned().collect(), fi.support))
193        .collect();
194
195    let mut rules: Vec<AssociationRule> = Vec::new();
196
197    for fi in itemsets {
198        if fi.items.len() < 2 {
199            continue;
200        }
201        let support = fi.support;
202        // Enumerate every non-empty proper subset as antecedent.
203        for i in 1..fi.items.len() {
204            for ant_vec in combinations(&fi.items, i) {
205                let antecedent: BTreeSet<String> = ant_vec.iter().cloned().collect();
206                let full: BTreeSet<String> = fi.items.iter().cloned().collect();
207                let consequent: BTreeSet<String> = full.difference(&antecedent).cloned().collect();
208
209                let Some(&ant_support) = support_map.get(&antecedent) else {
210                    continue;
211                };
212                if ant_support == 0 {
213                    continue;
214                }
215                let confidence = support as f64 / ant_support as f64;
216                if confidence < min_confidence {
217                    continue;
218                }
219                let cons_support = support_map.get(&consequent).copied().unwrap_or(0);
220                let lift = if cons_support > 0 {
221                    confidence / (cons_support as f64 / n_txns as f64)
222                } else {
223                    0.0
224                };
225                let conf_r = round4(confidence);
226                rules.push(AssociationRule {
227                    antecedent: antecedent.into_iter().collect(),
228                    consequent: consequent.into_iter().collect(),
229                    support,
230                    support_pct: round4(support as f64 / n_txns as f64),
231                    confidence: conf_r,
232                    lift: round4(lift),
233                    classification: classify_rule(conf_r).to_owned(),
234                });
235            }
236        }
237    }
238
239    rules.sort_by(|a, b| {
240        b.confidence
241            .partial_cmp(&a.confidence)
242            .unwrap_or(std::cmp::Ordering::Equal)
243            .then_with(|| b.support.cmp(&a.support))
244            .then_with(|| a.antecedent.cmp(&b.antecedent))
245            .then_with(|| a.consequent.cmp(&b.consequent))
246    });
247    rules
248}
249
/// Map a confidence value to its rule label.
///
/// Returns `"invariant"` for `confidence >= 0.99`, `"strong"` for
/// `confidence >= 0.85`, and `"moderate"` for every other value; no lower
/// bound is applied here.
pub fn classify_rule(confidence: f64) -> &'static str {
    match confidence {
        c if c >= 0.99 => "invariant",
        c if c >= 0.85 => "strong",
        _ => "moderate",
    }
}
261
262/// Walk every node's `callSequences` blob and run end-to-end mining.
263///
264/// `min_support` is taken literally; the Python reference uses
265/// `max(2, len(transactions) / 10)` adaptively — callers wanting that pattern
266/// should compute it themselves (transactions count is reported back in the
267/// result so a two-pass adaptive run is possible).
268pub fn compute_association_analysis<N, E>(
269    graph: &Graph<N, E>,
270    min_support: usize,
271    min_confidence: f64,
272) -> AssociationAnalysis
273where
274    N: ClassNode,
275    E: EdgeKind,
276{
277    let mut transactions: Vec<BTreeSet<String>> = Vec::new();
278
279    for nx in graph.node_indices() {
280        let node = &graph[nx];
281        let Some(blob) = node.call_sequences() else {
282            continue;
283        };
284        // The blob can be either the {sequences: [...]} object or a JSON string
285        // when stored verbatim from GraphML — handle both, matching clone_detection.
286        let parsed = if let Some(s) = blob.as_str() {
287            serde_json::from_str::<serde_json::Value>(s)
288                .ok()
289                .and_then(|v| parse_call_sequences(&v))
290        } else {
291            parse_call_sequences(blob)
292        };
293        let Some(seqs) = parsed else {
294            continue;
295        };
296        for seq in seqs {
297            if seq.calls.len() >= 2 {
298                transactions.push(seq.calls.into_iter().collect());
299            }
300        }
301    }
302
303    let itemsets = mine_frequent_itemsets(&transactions, min_support);
304    let rules = generate_rules(&itemsets, &transactions, min_confidence);
305    let invariants = rules.iter().filter(|r| r.classification == "invariant").count();
306    let strong = rules.iter().filter(|r| r.classification == "strong").count();
307    let moderate = rules.iter().filter(|r| r.classification == "moderate").count();
308
309    AssociationAnalysis {
310        transactions: transactions.len(),
311        num_rules: rules.len(),
312        invariants,
313        strong,
314        moderate,
315        itemsets: itemsets.len(),
316        rules,
317    }
318}
319
320// ── helpers ──────────────────────────────────────────────────────────────────
321
/// Round `x` to four decimal places (scale by 10 000, round to the nearest
/// integer, scale back).
fn round4(x: f64) -> f64 {
    const SCALE: f64 = 10_000.0;
    let scaled = x * SCALE;
    scaled.round() / SCALE
}
325
/// All `k`-combinations of `items`, in lexicographic index order.
///
/// Each combination keeps the relative order of `items`. Returns an empty
/// vector when `k == 0` or `k > items.len()`.
fn combinations<T: Clone>(items: &[T], k: usize) -> Vec<Vec<T>> {
    let n = items.len();
    if k == 0 || k > n {
        return Vec::new();
    }

    let mut out: Vec<Vec<T>> = Vec::new();
    // `idx` is the current combination as strictly increasing indices.
    let mut idx: Vec<usize> = (0..k).collect();
    loop {
        out.push(idx.iter().map(|&i| items[i].clone()).collect());

        // Slot `p` may hold at most `n - k + p`. Advance the rightmost slot
        // still below its limit; if none is, the last combination was emitted.
        let Some(pivot) = (0..k).rfind(|&p| idx[p] < n - k + p) else {
            return out;
        };
        idx[pivot] += 1;
        // Reset everything to the right to the smallest valid continuation.
        for p in (pivot + 1)..k {
            idx[p] = idx[p - 1] + 1;
        }
    }
}
359
#[cfg(test)]
mod tests {
    use super::*;
    use petgraph::Graph;
    use serde_json::json;
    use srcgraph_core::{OwnedClassNode, OwnedGraph};

    /// Build a minimal class node whose only interesting field is the
    /// optional `call_sequences` blob.
    fn class(id: &str, seqs: Option<serde_json::Value>) -> OwnedClassNode {
        OwnedClassNode {
            id: id.to_owned(),
            name: id.to_owned(),
            namespace: "test".to_owned(),
            line_count: 10,
            method_count: 1,
            halstead_eta1: 0,
            halstead_eta2: 0,
            halstead_n1: 0,
            halstead_n2: 0,
            method_connectivity: None,
            method_fingerprints: None,
            method_tokens: None,
            call_sequences: seqs,
            cyclomatic_complexity: None,
            path_conditions: None,
            invariants: None,
            error_messages: None,
            magic_numbers: None,
            dead_code: None,
            tenant_branches: None,
            state_transitions: None,
        }
    }

    /// Four transactions: `Validate` and `Notify` appear in all of them,
    /// `Save` in two, `Delete` and `Charge` in one each.
    fn simple_transactions() -> Vec<BTreeSet<String>> {
        vec![
            ["Validate", "Save", "Notify"].into_iter().map(String::from).collect(),
            ["Validate", "Save", "Notify"].into_iter().map(String::from).collect(),
            ["Validate", "Delete", "Notify"].into_iter().map(String::from).collect(),
            ["Validate", "Charge", "Notify"].into_iter().map(String::from).collect(),
        ]
    }

    // ── parse_call_sequences ─────────────────────────────────────────────

    #[test]
    fn parse_basic_sequences() {
        let blob = json!({"sequences": [
            {"method": "CreateOrder", "calls": ["Validate", "Save"]},
            {"method": "DeleteOrder", "calls": ["Validate", "Delete"]},
        ]});
        let seqs = parse_call_sequences(&blob).expect("parses");
        assert_eq!(seqs.len(), 2);
        assert_eq!(seqs[0].method, "CreateOrder");
        assert_eq!(seqs[0].calls, vec!["Validate", "Save"]);
        assert_eq!(seqs[1].method, "DeleteOrder");
        assert_eq!(seqs[1].calls, vec!["Validate", "Delete"]);
    }

    #[test]
    fn parse_missing_key_yields_none() {
        let blob = json!({"other": []});
        assert!(parse_call_sequences(&blob).is_none());
    }

    #[test]
    fn parse_non_array_sequences_yields_none() {
        assert!(parse_call_sequences(&json!({"sequences": "oops"})).is_none());
        assert!(parse_call_sequences(&json!([1, 2])).is_none());
    }

    #[test]
    fn parse_empty_sequences() {
        let blob = json!({"sequences": []});
        assert_eq!(parse_call_sequences(&blob).unwrap().len(), 0);
    }

    #[test]
    fn parse_skips_malformed_entries_and_defaults_fields() {
        let blob = json!({"sequences": [
            42,
            {"calls": ["A", 7, "B"]},
            {"method": "M", "calls": "nope"},
        ]});
        let seqs = parse_call_sequences(&blob).expect("parses");
        // The bare number is dropped; the two objects survive.
        assert_eq!(seqs.len(), 2);
        assert_eq!(seqs[0].method, "");
        assert_eq!(seqs[0].calls, vec!["A", "B"]);
        assert_eq!(seqs[1].method, "M");
        assert!(seqs[1].calls.is_empty());
    }

    // ── combinations helper ──────────────────────────────────────────────

    #[test]
    fn combinations_basic() {
        let items: Vec<&str> = vec!["a", "b", "c", "d"];
        let cs = combinations(&items, 2);
        // Pin the documented lexicographic index order, not just membership.
        assert_eq!(
            cs,
            vec![
                vec!["a", "b"],
                vec!["a", "c"],
                vec!["a", "d"],
                vec!["b", "c"],
                vec!["b", "d"],
                vec!["c", "d"],
            ]
        );
    }

    #[test]
    fn combinations_full_length() {
        let items: Vec<&str> = vec!["a", "b", "c"];
        assert_eq!(combinations(&items, 3), vec![vec!["a", "b", "c"]]);
    }

    #[test]
    fn combinations_k_too_big() {
        let items: Vec<&str> = vec!["a", "b"];
        assert!(combinations(&items, 3).is_empty());
    }

    #[test]
    fn combinations_k_zero() {
        let items: Vec<&str> = vec!["a"];
        assert!(combinations(&items, 0).is_empty());
    }

    // ── mine_frequent_itemsets ───────────────────────────────────────────

    #[test]
    fn mine_finds_singleton_in_all_txns() {
        let txns = simple_transactions();
        let isets = mine_frequent_itemsets(&txns, 2);
        let v = isets.iter().find(|fi| fi.items == vec!["Validate"]);
        assert!(v.is_some(), "expected Validate as frequent 1-itemset");
        assert_eq!(v.unwrap().support, 4);
    }

    #[test]
    fn mine_finds_pair() {
        let txns = simple_transactions();
        let isets = mine_frequent_itemsets(&txns, 2);
        let vn = isets
            .iter()
            .find(|fi| fi.items == vec!["Notify", "Validate"]);
        assert!(vn.is_some(), "expected {{Validate, Notify}} pair");
        assert_eq!(vn.unwrap().support, 4);
    }

    #[test]
    fn mine_finds_triple() {
        let txns = simple_transactions();
        let isets = mine_frequent_itemsets(&txns, 2);
        let triple = isets
            .iter()
            .find(|fi| fi.items == vec!["Notify", "Save", "Validate"]);
        assert!(triple.is_some(), "expected {{Notify, Save, Validate}} triple");
        assert_eq!(triple.unwrap().support, 2);
    }

    #[test]
    fn mine_respects_min_support() {
        let txns = simple_transactions();
        let high = mine_frequent_itemsets(&txns, 4);
        let low = mine_frequent_itemsets(&txns, 2);
        assert!(high.len() <= low.len());
        for fi in &high {
            assert!(fi.support >= 4);
        }
        // At min_support=4 only items appearing in every txn survive:
        // Validate, Notify, and their pair — in tie-break (item-list) order.
        let items: Vec<Vec<String>> = high.iter().map(|fi| fi.items.clone()).collect();
        assert_eq!(
            items,
            vec![vec!["Notify"], vec!["Notify", "Validate"], vec!["Validate"]]
        );
    }

    #[test]
    fn mine_empty_transactions() {
        assert!(mine_frequent_itemsets(&[], 1).is_empty());
    }

    #[test]
    fn mine_sorted_by_support_desc() {
        let txns = simple_transactions();
        let isets = mine_frequent_itemsets(&txns, 2);
        let supports: Vec<usize> = isets.iter().map(|fi| fi.support).collect();
        let mut sorted = supports.clone();
        sorted.sort_by(|a, b| b.cmp(a));
        assert_eq!(supports, sorted);
    }

    // ── generate_rules ───────────────────────────────────────────────────

    #[test]
    fn generate_rules_produces_high_confidence_pair() {
        let txns = simple_transactions();
        let isets = mine_frequent_itemsets(&txns, 2);
        let rules = generate_rules(&isets, &txns, 0.5);
        assert!(!rules.is_empty());
        // Every txn contains both Validate and Notify, so Validate → Notify is conf 1.0.
        let vn = rules.iter().find(|r| {
            r.antecedent == vec!["Validate"] && r.consequent == vec!["Notify"]
        });
        assert!(vn.is_some());
        assert!(vn.unwrap().confidence >= 0.99);
    }

    #[test]
    fn generate_rules_scores_support_confidence_and_lift() {
        let txns = simple_transactions();
        let isets = mine_frequent_itemsets(&txns, 2);
        let rules = generate_rules(&isets, &txns, 0.5);
        // {Save, Validate} has support 2, Validate 4, Save 2, n = 4:
        // confidence = 2/4, lift = 0.5 / (2/4).
        let rule = rules
            .iter()
            .find(|r| r.antecedent == vec!["Validate"] && r.consequent == vec!["Save"])
            .expect("Validate → Save rule");
        assert_eq!(rule.support, 2);
        assert_eq!(rule.support_pct, 0.5);
        assert_eq!(rule.confidence, 0.5);
        assert_eq!(rule.lift, 1.0);
        assert_eq!(rule.classification, "moderate");
    }

    #[test]
    fn generate_rules_respects_min_confidence() {
        let txns = simple_transactions();
        let isets = mine_frequent_itemsets(&txns, 2);
        let rules = generate_rules(&isets, &txns, 0.9);
        assert!(!rules.is_empty());
        assert!(rules.iter().all(|r| r.confidence >= 0.9));
        // Validate → Save has confidence 0.5 and must be filtered out.
        assert!(!rules
            .iter()
            .any(|r| r.antecedent == vec!["Validate"] && r.consequent == vec!["Save"]));
    }

    #[test]
    fn generate_rules_sorted_by_confidence_desc() {
        let txns = simple_transactions();
        let isets = mine_frequent_itemsets(&txns, 2);
        let rules = generate_rules(&isets, &txns, 0.5);
        for w in rules.windows(2) {
            assert!(w[0].confidence >= w[1].confidence);
        }
    }

    #[test]
    fn generate_rules_classification_applied() {
        let txns = simple_transactions();
        let isets = mine_frequent_itemsets(&txns, 2);
        let rules = generate_rules(&isets, &txns, 0.5);
        for r in &rules {
            assert!(["invariant", "strong", "moderate"].contains(&r.classification.as_str()));
            assert_eq!(r.classification, classify_rule(r.confidence));
        }
    }

    #[test]
    fn generate_rules_empty_inputs() {
        assert!(generate_rules(&[], &[], 0.5).is_empty());
        let txns = simple_transactions();
        assert!(generate_rules(&[], &txns, 0.5).is_empty());
    }

    #[test]
    fn classify_thresholds() {
        assert_eq!(classify_rule(1.0), "invariant");
        assert_eq!(classify_rule(0.99), "invariant");
        // Just below each boundary falls into the next class down.
        assert_eq!(classify_rule(0.9899), "strong");
        assert_eq!(classify_rule(0.85), "strong");
        assert_eq!(classify_rule(0.8499), "moderate");
        assert_eq!(classify_rule(0.5), "moderate");
        // No lower bound is applied by the classifier itself.
        assert_eq!(classify_rule(0.49), "moderate");
    }

    // ── compute_association_analysis ─────────────────────────────────────

    #[test]
    fn compute_walks_graph_and_counts() {
        let blob = json!({"sequences": [
            {"method": "CreateOrder", "calls": ["Validate", "Save", "Notify"]},
            {"method": "UpdateOrder", "calls": ["Validate", "Save", "Notify"]},
            {"method": "DeleteOrder", "calls": ["Validate", "Delete", "Notify"]},
            {"method": "Payment", "calls": ["Validate", "Charge", "Notify", "Log"]},
        ]});
        let mut g: OwnedGraph = Graph::new();
        g.add_node(class("OrderSvc", Some(blob)));
        // Node without sequences — silently skipped.
        g.add_node(class("Plain", None));

        let r = compute_association_analysis(&g, 2, 0.5);
        assert_eq!(r.transactions, 4);
        assert!(r.num_rules > 0);
        assert!(r.itemsets > 0);
        // Validate → Notify is an invariant (every txn has both).
        assert!(r.invariants >= 1);
        // The summary counters must agree with the rule list.
        assert_eq!(r.num_rules, r.rules.len());
        assert_eq!(r.invariants + r.strong + r.moderate, r.num_rules);
    }

    #[test]
    fn compute_accepts_string_encoded_blob() {
        let inner = json!({"sequences": [
            {"method": "A", "calls": ["X", "Y"]},
            {"method": "B", "calls": ["X", "Y"]},
        ]});
        let mut g: OwnedGraph = Graph::new();
        g.add_node(class("A", Some(serde_json::Value::String(inner.to_string()))));
        let r = compute_association_analysis(&g, 2, 0.5);
        assert_eq!(r.transactions, 2);
    }

    #[test]
    fn compute_skips_unparseable_string_blob() {
        let mut g: OwnedGraph = Graph::new();
        g.add_node(class("A", Some(serde_json::Value::String("not json".to_owned()))));
        let r = compute_association_analysis(&g, 2, 0.5);
        assert_eq!(r.transactions, 0);
        assert_eq!(r.num_rules, 0);
    }

    #[test]
    fn compute_empty_graph_zero_rules() {
        let g: OwnedGraph = Graph::new();
        let r = compute_association_analysis(&g, 2, 0.5);
        assert_eq!(r.transactions, 0);
        assert_eq!(r.num_rules, 0);
        assert_eq!(r.itemsets, 0);
    }

    #[test]
    fn compute_skips_short_sequences() {
        // calls of length < 2 are skipped (matches Python).
        let blob = json!({"sequences": [
            {"method": "A", "calls": ["only"]},
            {"method": "B", "calls": []},
        ]});
        let mut g: OwnedGraph = Graph::new();
        g.add_node(class("X", Some(blob)));
        let r = compute_association_analysis(&g, 1, 0.5);
        assert_eq!(r.transactions, 0);
    }
}