Skip to main content

infigraph_core/patterns/
mod.rs

1//! Design pattern detection via graph queries.
2//!
3//! Detects common design patterns (Factory, Observer, Singleton, Strategy,
4//! Decorator) by querying the existing call/inheritance graph.
5
6use anyhow::Result;
7use serde::Serialize;
8
9use crate::graph::store::GraphStore;
10use crate::graph::GraphQuery;
11
12// ---------------------------------------------------------------------------
13// Data types
14// ---------------------------------------------------------------------------
15
16/// A single detected design-pattern instance.
17#[derive(Debug, Clone, Serialize)]
18pub struct PatternMatch {
19    /// Pattern name: "Factory", "Observer", "Singleton", "Strategy", "Decorator"
20    pub pattern: String,
21    /// Confidence level: "high", "medium", "low"
22    pub confidence: String,
23    /// Symbols participating in the pattern and their roles.
24    pub participants: Vec<PatternParticipant>,
25    /// Primary file where the pattern is anchored.
26    pub file: String,
27}
28
29/// A symbol participating in a detected pattern.
30#[derive(Debug, Clone, Serialize)]
31pub struct PatternParticipant {
32    /// Role within the pattern (e.g. "Creator", "Product", "Subject").
33    pub role: String,
34    /// Fully-qualified symbol name.
35    pub symbol: String,
36    /// File where the symbol lives.
37    pub file: String,
38}
39
40/// Aggregated pattern-detection report.
41#[derive(Debug, Clone, Serialize)]
42pub struct PatternReport {
43    pub patterns: Vec<PatternMatch>,
44}
45
46// ---------------------------------------------------------------------------
47// Public API
48// ---------------------------------------------------------------------------
49
50/// Run all pattern detectors and return a combined report.
51pub fn detect_all(store: &GraphStore) -> Result<PatternReport> {
52    let conn = store.connection()?;
53    let gq = GraphQuery::new(&conn);
54
55    let mut patterns = Vec::new();
56    patterns.extend(detect_factory(&gq));
57    patterns.extend(detect_singleton(&gq));
58    patterns.extend(detect_observer(&gq));
59    patterns.extend(detect_strategy(&gq));
60    patterns.extend(detect_decorator(&gq));
61
62    Ok(PatternReport { patterns })
63}
64
65/// Run detectors and optionally keep only the given pattern type.
66pub fn detect_filtered(store: &GraphStore, filter: Option<&str>) -> Result<PatternReport> {
67    let mut report = detect_all(store)?;
68    if let Some(name) = filter {
69        let lower = name.to_lowercase();
70        report
71            .patterns
72            .retain(|p| p.pattern.to_lowercase() == lower);
73    }
74    Ok(report)
75}
76
77/// Render report as human-readable text grouped by pattern type.
78pub fn format_report(report: &PatternReport) -> String {
79    if report.patterns.is_empty() {
80        return "No design patterns detected.\n".to_string();
81    }
82
83    let mut out = String::new();
84    let groups = group_by_pattern(&report.patterns);
85
86    for (pattern, matches) in &groups {
87        out.push_str(&format!(
88            "\n=== {} Pattern ({} instance{}) ===\n",
89            pattern,
90            matches.len(),
91            if matches.len() == 1 { "" } else { "s" }
92        ));
93        for (i, m) in matches.iter().enumerate() {
94            out.push_str(&format!("\n  {}. [{}] {}\n", i + 1, m.confidence, m.file));
95            for p in &m.participants {
96                out.push_str(&format!("     {:<14} {} ({})\n", p.role, p.symbol, p.file));
97            }
98        }
99    }
100
101    let total: usize = groups.iter().map(|(_, v)| v.len()).sum();
102    out.push_str(&format!(
103        "\nTotal: {} pattern instance(s) detected.\n",
104        total
105    ));
106    out
107}
108
109/// Render report as pretty-printed JSON.
110pub fn format_json(report: &PatternReport) -> String {
111    serde_json::to_string_pretty(report).unwrap_or_default()
112}
113
114// ---------------------------------------------------------------------------
115// Helpers
116// ---------------------------------------------------------------------------
117
118fn group_by_pattern(matches: &[PatternMatch]) -> Vec<(String, Vec<&PatternMatch>)> {
119    let order = ["Factory", "Singleton", "Observer", "Strategy", "Decorator"];
120    let mut groups: Vec<(String, Vec<&PatternMatch>)> = Vec::new();
121    for name in &order {
122        let items: Vec<&PatternMatch> = matches.iter().filter(|m| m.pattern == *name).collect();
123        if !items.is_empty() {
124            groups.push((name.to_string(), items));
125        }
126    }
127    // Catch any pattern names not in the standard order
128    for m in matches {
129        if !order.contains(&m.pattern.as_str()) {
130            if let Some(g) = groups.iter_mut().find(|(n, _)| *n == m.pattern) {
131                g.1.push(m);
132            } else {
133                groups.push((m.pattern.clone(), vec![m]));
134            }
135        }
136    }
137    groups
138}
139
/// Removes one pair of matching enclosing quotes (`"…"` or `'…'`) from a raw
/// query value.
///
/// Only a balanced outer pair is removed; quote characters that belong to the
/// value itself are kept. Values not wrapped in a matching pair are returned
/// unchanged. For example, `"foldl'"` becomes `foldl'` (the trailing prime is
/// preserved), and an unbalanced `"abc` is left as-is.
fn strip_quotes(s: &str) -> String {
    for quote in ['"', '\''] {
        if let Some(inner) = s
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote))
        {
            return inner.to_string();
        }
    }
    s.to_string()
}
143
144// ---------------------------------------------------------------------------
145// Factory Pattern
146// ---------------------------------------------------------------------------
147// A class/struct with methods that create or return instances of subtypes.
148// High confidence when the method name contains create/build/make/factory.
149
150fn detect_factory(gq: &GraphQuery) -> Vec<PatternMatch> {
151    // Find methods that call constructors/classes that have INHERITS edges
152    let query = "\
153        MATCH (creator:Symbol)-[:CALLS]->(product:Symbol) \
154        WHERE creator.kind = 'Method' \
155        AND product.kind IN ['Class', 'Function'] \
156        AND EXISTS { MATCH (product)-[:INHERITS]->(:Symbol) } \
157        RETURN DISTINCT creator.parent, creator.name, creator.file, product.name, product.file";
158
159    let rows = match gq.raw_query(query) {
160        Ok(r) => r,
161        Err(_) => return Vec::new(),
162    };
163
164    let mut results: Vec<PatternMatch> = Vec::new();
165    let mut seen = std::collections::HashSet::new();
166
167    for row in &rows {
168        if row.len() < 5 {
169            continue;
170        }
171        let creator_parent = strip_quotes(&row[0]);
172        let creator_name = strip_quotes(&row[1]);
173        let creator_file = strip_quotes(&row[2]);
174        let product_name = strip_quotes(&row[3]);
175        let product_file = strip_quotes(&row[4]);
176
177        let key = format!("{}::{}", creator_parent, creator_name);
178        if !seen.insert(key) {
179            continue;
180        }
181
182        let name_lower = creator_name.to_lowercase();
183        let confidence = if name_lower.contains("create")
184            || name_lower.contains("build")
185            || name_lower.contains("make")
186            || name_lower.contains("factory")
187            || name_lower.contains("new_")
188        {
189            "high"
190        } else {
191            "medium"
192        };
193
194        results.push(PatternMatch {
195            pattern: "Factory".to_string(),
196            confidence: confidence.to_string(),
197            participants: vec![
198                PatternParticipant {
199                    role: "Creator".to_string(),
200                    symbol: format!("{}::{}", creator_parent, creator_name),
201                    file: creator_file.clone(),
202                },
203                PatternParticipant {
204                    role: "Product".to_string(),
205                    symbol: product_name,
206                    file: product_file,
207                },
208            ],
209            file: creator_file,
210        });
211    }
212    results
213}
214
215// ---------------------------------------------------------------------------
216// Singleton Pattern
217// ---------------------------------------------------------------------------
218// Classes with a static instance-access method (getInstance, instance, shared,
219// get_instance, etc.).
220
221fn detect_singleton(gq: &GraphQuery) -> Vec<PatternMatch> {
222    let singleton_names = [
223        "getInstance",
224        "instance",
225        "shared",
226        "get_instance",
227        "getDefault",
228        "sharedInstance",
229    ];
230
231    let mut results: Vec<PatternMatch> = Vec::new();
232    let mut seen = std::collections::HashSet::new();
233
234    for accessor in &singleton_names {
235        let query = format!(
236            "MATCH (cls:Symbol), (method:Symbol) \
237             WHERE cls.kind = 'Class' \
238             AND method.kind = 'Method' \
239             AND method.parent = cls.name \
240             AND method.name = '{}' \
241             RETURN DISTINCT cls.name, cls.file, method.name",
242            accessor
243        );
244
245        let rows = match gq.raw_query(&query) {
246            Ok(r) => r,
247            Err(_) => continue,
248        };
249
250        for row in &rows {
251            if row.len() < 3 {
252                continue;
253            }
254            let cls_name = strip_quotes(&row[0]);
255            let cls_file = strip_quotes(&row[1]);
256            let method_name = strip_quotes(&row[2]);
257
258            if !seen.insert(cls_name.clone()) {
259                continue;
260            }
261
262            results.push(PatternMatch {
263                pattern: "Singleton".to_string(),
264                confidence: "high".to_string(),
265                participants: vec![
266                    PatternParticipant {
267                        role: "Singleton".to_string(),
268                        symbol: cls_name.clone(),
269                        file: cls_file.clone(),
270                    },
271                    PatternParticipant {
272                        role: "Accessor".to_string(),
273                        symbol: format!("{}::{}", cls_name, method_name),
274                        file: cls_file.clone(),
275                    },
276                ],
277                file: cls_file,
278            });
279        }
280    }
281    results
282}
283
284// ---------------------------------------------------------------------------
285// Observer Pattern
286// ---------------------------------------------------------------------------
287// Subject with subscribe/register + notify/emit methods.
288
289fn detect_observer(gq: &GraphQuery) -> Vec<PatternMatch> {
290    // Step 1: find methods whose names suggest registration of listeners
291    let register_query = "\
292        MATCH (reg:Symbol) \
293        WHERE reg.kind = 'Method' \
294        AND (reg.name CONTAINS 'register' \
295             OR reg.name CONTAINS 'subscribe' \
296             OR reg.name CONTAINS 'add_listener' \
297             OR reg.name CONTAINS 'addEventListener' \
298             OR reg.name CONTAINS 'addObserver' \
299             OR reg.name CONTAINS 'on_') \
300        RETURN DISTINCT reg.parent, reg.name, reg.file";
301
302    let reg_rows = match gq.raw_query(register_query) {
303        Ok(r) => r,
304        Err(_) => return Vec::new(),
305    };
306
307    if reg_rows.is_empty() {
308        return Vec::new();
309    }
310
311    // Build set of parent classes that have register methods
312    let mut register_parents = std::collections::HashMap::<String, (String, String)>::new();
313    for row in &reg_rows {
314        if row.len() < 3 {
315            continue;
316        }
317        let parent = strip_quotes(&row[0]);
318        let method = strip_quotes(&row[1]);
319        let file = strip_quotes(&row[2]);
320        if !parent.is_empty() {
321            register_parents.entry(parent).or_insert((file, method));
322        }
323    }
324
325    // Step 2: check which of those classes also have notify-like methods
326    let notify_query = "\
327        MATCH (n:Symbol) \
328        WHERE n.kind = 'Method' \
329        AND (n.name CONTAINS 'notify' \
330             OR n.name CONTAINS 'emit' \
331             OR n.name CONTAINS 'publish' \
332             OR n.name CONTAINS 'dispatch' \
333             OR n.name CONTAINS 'fire') \
334        RETURN DISTINCT n.parent, n.name, n.file";
335
336    let notify_rows = match gq.raw_query(notify_query) {
337        Ok(r) => r,
338        Err(_) => return Vec::new(),
339    };
340
341    let mut results: Vec<PatternMatch> = Vec::new();
342    let mut seen = std::collections::HashSet::new();
343
344    for row in &notify_rows {
345        if row.len() < 3 {
346            continue;
347        }
348        let parent = strip_quotes(&row[0]);
349        let notify_name = strip_quotes(&row[1]);
350        let file = strip_quotes(&row[2]);
351
352        if let Some((reg_file, reg_method)) = register_parents.get(&parent) {
353            if !seen.insert(parent.clone()) {
354                continue;
355            }
356            results.push(PatternMatch {
357                pattern: "Observer".to_string(),
358                confidence: "high".to_string(),
359                participants: vec![
360                    PatternParticipant {
361                        role: "Subject".to_string(),
362                        symbol: parent.clone(),
363                        file: reg_file.clone(),
364                    },
365                    PatternParticipant {
366                        role: "Register".to_string(),
367                        symbol: format!("{}::{}", parent, reg_method),
368                        file: reg_file.clone(),
369                    },
370                    PatternParticipant {
371                        role: "Notify".to_string(),
372                        symbol: format!("{}::{}", parent, notify_name),
373                        file,
374                    },
375                ],
376                file: reg_file.clone(),
377            });
378        }
379    }
380    results
381}
382
383// ---------------------------------------------------------------------------
384// Strategy Pattern
385// ---------------------------------------------------------------------------
386// An interface/trait with 3+ classes inheriting from it.
387
388fn detect_strategy(gq: &GraphQuery) -> Vec<PatternMatch> {
389    // Kuzu/lbug may not support WITH + aggregation well, so fetch all
390    // INHERITS edges and aggregate in Rust.
391    let query = "\
392        MATCH (impl:Symbol)-[:INHERITS]->(iface:Symbol) \
393        WHERE iface.kind IN ['Class', 'Interface', 'Trait'] \
394        RETURN iface.name, iface.file, impl.name, impl.file";
395
396    let rows = match gq.raw_query(query) {
397        Ok(r) => r,
398        Err(_) => return Vec::new(),
399    };
400
401    // Group implementations by interface
402    let mut iface_impls: std::collections::HashMap<String, (String, Vec<(String, String)>)> =
403        std::collections::HashMap::new();
404
405    for row in &rows {
406        if row.len() < 4 {
407            continue;
408        }
409        let iface_name = strip_quotes(&row[0]);
410        let iface_file = strip_quotes(&row[1]);
411        let impl_name = strip_quotes(&row[2]);
412        let impl_file = strip_quotes(&row[3]);
413
414        let entry = iface_impls
415            .entry(iface_name)
416            .or_insert_with(|| (iface_file, Vec::new()));
417        entry.1.push((impl_name, impl_file));
418    }
419
420    let mut results: Vec<PatternMatch> = Vec::new();
421
422    for (iface_name, (iface_file, impls)) in &iface_impls {
423        if impls.len() < 3 {
424            continue;
425        }
426
427        let confidence = if impls.len() >= 5 { "high" } else { "medium" };
428
429        let mut participants = vec![PatternParticipant {
430            role: "Strategy".to_string(),
431            symbol: iface_name.clone(),
432            file: iface_file.clone(),
433        }];
434
435        for (impl_name, impl_file) in impls {
436            participants.push(PatternParticipant {
437                role: "ConcreteStrategy".to_string(),
438                symbol: impl_name.clone(),
439                file: impl_file.clone(),
440            });
441        }
442
443        results.push(PatternMatch {
444            pattern: "Strategy".to_string(),
445            confidence: confidence.to_string(),
446            participants,
447            file: iface_file.clone(),
448        });
449    }
450
451    results
452}
453
454// ---------------------------------------------------------------------------
455// Decorator / Wrapper Pattern
456// ---------------------------------------------------------------------------
457// A class that inherits from X AND calls methods on the base type.
458
459fn detect_decorator(gq: &GraphQuery) -> Vec<PatternMatch> {
460    let query = "\
461        MATCH (decorator:Symbol)-[:INHERITS]->(base:Symbol) \
462        WHERE decorator.kind = 'Class' \
463        AND base.kind IN ['Class', 'Interface', 'Trait'] \
464        AND EXISTS { \
465            MATCH (decorator)-[:CALLS]->(base_method:Symbol) \
466            WHERE base_method.parent = base.name \
467        } \
468        RETURN DISTINCT decorator.name, decorator.file, base.name, base.file";
469
470    let rows = match gq.raw_query(query) {
471        Ok(r) => r,
472        Err(_) => return Vec::new(),
473    };
474
475    let mut results: Vec<PatternMatch> = Vec::new();
476    let mut seen = std::collections::HashSet::new();
477
478    for row in &rows {
479        if row.len() < 4 {
480            continue;
481        }
482        let dec_name = strip_quotes(&row[0]);
483        let dec_file = strip_quotes(&row[1]);
484        let base_name = strip_quotes(&row[2]);
485        let base_file = strip_quotes(&row[3]);
486
487        if !seen.insert(format!("{}>{}", dec_name, base_name)) {
488            continue;
489        }
490
491        // Higher confidence if the decorator name suggests wrapping
492        let name_lower = dec_name.to_lowercase();
493        let confidence = if name_lower.contains("decorator")
494            || name_lower.contains("wrapper")
495            || name_lower.contains("proxy")
496            || name_lower.contains("adapter")
497        {
498            "high"
499        } else {
500            "medium"
501        };
502
503        results.push(PatternMatch {
504            pattern: "Decorator".to_string(),
505            confidence: confidence.to_string(),
506            participants: vec![
507                PatternParticipant {
508                    role: "Decorator".to_string(),
509                    symbol: dec_name,
510                    file: dec_file.clone(),
511                },
512                PatternParticipant {
513                    role: "Component".to_string(),
514                    symbol: base_name,
515                    file: base_file,
516                },
517            ],
518            file: dec_file,
519        });
520    }
521    results
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a participant-less match for formatting and grouping tests.
    fn bare_match(pattern: &str, confidence: &str, file: &str) -> PatternMatch {
        PatternMatch {
            pattern: pattern.to_string(),
            confidence: confidence.to_string(),
            participants: vec![],
            file: file.to_string(),
        }
    }

    #[test]
    fn empty_report_formats() {
        let report = PatternReport { patterns: vec![] };
        assert_eq!(format_report(&report), "No design patterns detected.\n");
    }

    #[test]
    fn json_output_is_valid_and_complete() {
        let report = PatternReport {
            patterns: vec![PatternMatch {
                pattern: "Factory".to_string(),
                confidence: "high".to_string(),
                participants: vec![PatternParticipant {
                    role: "Creator".to_string(),
                    symbol: "MyFactory::create".to_string(),
                    file: "src/factory.rs".to_string(),
                }],
                file: "src/factory.rs".to_string(),
            }],
        };
        let json = format_json(&report);
        // Parse the output back so the test checks structure, not just substrings.
        let value: serde_json::Value =
            serde_json::from_str(&json).expect("format_json should emit valid JSON");
        let first = &value["patterns"][0];
        assert_eq!(first["pattern"], "Factory");
        assert_eq!(first["confidence"], "high");
        assert_eq!(first["file"], "src/factory.rs");
        assert_eq!(first["participants"][0]["role"], "Creator");
        assert_eq!(first["participants"][0]["symbol"], "MyFactory::create");
        assert_eq!(first["participants"][0]["file"], "src/factory.rs");
    }

    #[test]
    fn strip_quotes_works() {
        assert_eq!(strip_quotes("\"hello\""), "hello");
        assert_eq!(strip_quotes("'hello'"), "hello");
        assert_eq!(strip_quotes("plain"), "plain");
        assert_eq!(strip_quotes(""), "");
    }

    #[test]
    fn report_groups_by_pattern() {
        let report = PatternReport {
            patterns: vec![
                bare_match("Singleton", "high", "a.py"),
                bare_match("Factory", "medium", "b.py"),
                bare_match("Singleton", "high", "c.py"),
            ],
        };
        let text = format_report(&report);
        // Factory should come before Singleton in the output (canonical order)
        let factory_pos = text.find("Factory Pattern").unwrap();
        let singleton_pos = text.find("Singleton Pattern").unwrap();
        assert!(factory_pos < singleton_pos);
        assert!(text.contains("Total: 3 pattern instance(s)"));
    }

    #[test]
    fn header_pluralizes_and_lists_participants() {
        let mut single = bare_match("Observer", "high", "bus.py");
        single.participants.push(PatternParticipant {
            role: "Subject".to_string(),
            symbol: "EventBus".to_string(),
            file: "bus.py".to_string(),
        });
        let text = format_report(&PatternReport {
            patterns: vec![single],
        });
        assert!(text.contains("=== Observer Pattern (1 instance) ==="));
        assert!(text.contains("\n  1. [high] bus.py\n"));
        // Role is left-aligned in a 14-column field.
        assert!(text.contains("     Subject        EventBus (bus.py)\n"));

        let double = PatternReport {
            patterns: vec![
                bare_match("Observer", "high", "a.py"),
                bare_match("Observer", "medium", "b.py"),
            ],
        };
        assert!(format_report(&double).contains("=== Observer Pattern (2 instances) ==="));
    }

    #[test]
    fn unknown_patterns_follow_canonical_ones() {
        let matches = vec![
            bare_match("Builder", "low", "x.rs"),
            bare_match("Decorator", "medium", "y.rs"),
            bare_match("Builder", "low", "z.rs"),
        ];
        let groups = group_by_pattern(&matches);
        let names: Vec<&str> = groups.iter().map(|(name, _)| name.as_str()).collect();
        assert_eq!(names, ["Decorator", "Builder"]);
        assert_eq!(groups[1].1.len(), 2);
    }
}