Skip to main content

infigraph_core/refactor/
mod.rs

1use std::collections::HashMap;
2
3use anyhow::Result;
4use kuzu::Connection;
5use rayon::prelude::*;
6
7use crate::embed;
8use crate::graph::GraphQuery;
9
/// The kind of refactoring a [`Recommendation`] proposes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Category {
    /// Break an oversized file into smaller modules.
    SplitFile,
    /// Move part of a long, complex body into its own function.
    ExtractFunction,
    /// Consolidate near-identical functions found in different files.
    MergeDuplicates,
    /// Delete functions that have no callers.
    RemoveDeadCode,
    /// Reduce a function's fan-in or fan-out.
    ReduceCoupling,
    /// Reduce branching in a complex but comparatively short function.
    SimplifyLogic,
}

impl std::fmt::Display for Category {
    /// Writes the category as a `snake_case` label such as `split_file`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Category::SplitFile => "split_file",
            Category::ExtractFunction => "extract_function",
            Category::MergeDuplicates => "merge_duplicates",
            Category::RemoveDeadCode => "remove_dead_code",
            Category::ReduceCoupling => "reduce_coupling",
            Category::SimplifyLogic => "simplify_logic",
        };
        f.write_str(label)
    }
}
32
/// Expected benefit of applying a recommendation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Impact {
    High,
    Medium,
    Low,
}

impl Impact {
    /// Weight used when ranking recommendations; more impact means a larger
    /// weight (`Low` = 1, `Medium` = 2, `High` = 3).
    fn score(&self) -> u32 {
        match *self {
            Impact::Low => 1,
            Impact::Medium => 2,
            Impact::High => 3,
        }
    }
}

impl std::fmt::Display for Impact {
    /// Writes the level as a lower-case word: `high`, `medium` or `low`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match *self {
            Impact::Low => "low",
            Impact::Medium => "medium",
            Impact::High => "high",
        })
    }
}
59
/// Estimated cost of carrying out a recommendation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Effort {
    High,
    Medium,
    Low,
}

impl Effort {
    /// Weight used when ranking recommendations. It is inverted relative to the
    /// effort level so that cheaper work ranks higher
    /// (`Low` = 3, `Medium` = 2, `High` = 1).
    fn score(&self) -> u32 {
        match *self {
            Effort::Low => 3,
            Effort::Medium => 2,
            Effort::High => 1,
        }
    }
}

impl std::fmt::Display for Effort {
    /// Writes the level as a lower-case word: `high`, `medium` or `low`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let word = match *self {
            Effort::Low => "low",
            Effort::Medium => "medium",
            Effort::High => "high",
        };
        f.write_str(word)
    }
}
86
87#[derive(Debug, Clone)]
88pub struct Recommendation {
89    pub category: Category,
90    pub target: String,
91    pub impact: Impact,
92    pub effort: Effort,
93    pub rationale: String,
94}
95
96impl Recommendation {
97    fn priority(&self) -> u32 {
98        self.impact.score() * self.effort.score()
99    }
100}
101
/// Selects which analyses [`analyze`] runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Focus {
    /// Run every analysis, including dead-code detection.
    All,
    /// Only the cyclomatic-complexity analysis.
    Complexity,
    /// Only the embedding-based duplicate detection.
    Duplication,
    /// Only the fan-in / fan-out analysis.
    Coupling,
    /// Only the file-size analysis.
    Size,
}

impl Focus {
    /// Parses a focus name, ignoring case.
    ///
    /// Recognized names are `complexity`, `duplication` (alias `clones`),
    /// `coupling` and `size`. Anything else, including the empty string, yields
    /// [`Focus::All`].
    pub fn parse(s: &str) -> Self {
        const NAMED: [(&str, Focus); 5] = [
            ("complexity", Focus::Complexity),
            ("duplication", Focus::Duplication),
            ("clones", Focus::Duplication),
            ("coupling", Focus::Coupling),
            ("size", Focus::Size),
        ];

        let key = s.to_lowercase();
        NAMED
            .iter()
            .find(|(name, _)| *name == key)
            .map_or(Focus::All, |&(_, focus)| focus)
    }
}
122
/// A symbol row loaded from the graph, reduced to what the analyses need.
///
/// Numeric fields fall back to `0` when the stored value is missing or cannot
/// be parsed (see `load_symbols`).
struct SymbolInfo {
    // --- identity ---
    /// Graph identifier of the symbol (`s.id`).
    id: String,
    /// Symbol kind string, e.g. `"Function"`, `"Method"` or `"Class"`.
    kind: String,
    /// Name as declared in source.
    name: String,

    // --- location ---
    /// Path of the declaring file.
    file: String,
    /// First line of the symbol.
    start_line: u32,
    /// Last line of the symbol.
    end_line: u32,

    // --- metrics ---
    /// Cyclomatic complexity.
    complexity: u32,
}
132
133pub fn analyze(
134    conn: &Connection,
135    embeddings_path: Option<&std::path::Path>,
136    target: Option<&str>,
137    focus: Focus,
138    limit: usize,
139) -> Result<Vec<Recommendation>> {
140    let gq = GraphQuery::new(conn);
141    let symbols = load_symbols(&gq, target)?;
142
143    if symbols.is_empty() {
144        return Ok(vec![]);
145    }
146
147    let mut recommendations = Vec::new();
148
149    let run_all = focus == Focus::All;
150
151    if run_all || focus == Focus::Size {
152        analyze_file_sizes(&symbols, &mut recommendations);
153    }
154
155    if run_all || focus == Focus::Complexity {
156        analyze_complexity(&symbols, &mut recommendations);
157    }
158
159    if run_all || focus == Focus::Coupling {
160        analyze_coupling(&gq, &symbols, &mut recommendations)?;
161    }
162
163    if run_all || focus == Focus::Duplication {
164        analyze_duplication(&gq, &symbols, embeddings_path, &mut recommendations)?;
165    }
166
167    if run_all {
168        analyze_dead_code(&gq, &symbols, &mut recommendations)?;
169    }
170
171    recommendations.sort_by_key(|r| std::cmp::Reverse(r.priority()));
172    recommendations.truncate(limit);
173
174    Ok(recommendations)
175}
176
177fn load_symbols(gq: &GraphQuery, target: Option<&str>) -> Result<Vec<SymbolInfo>> {
178    let query = if let Some(t) = target {
179        format!(
180            "MATCH (s:Symbol) WHERE s.file CONTAINS '{}' RETURN s.id, s.name, s.kind, s.file, s.complexity, s.start_line, s.end_line ORDER BY s.file, s.start_line",
181            t.replace('\'', "\\'")
182        )
183    } else {
184        "MATCH (s:Symbol) WHERE s.kind IN ['Function', 'Method', 'Class', 'Struct', 'Interface', 'Test'] RETURN s.id, s.name, s.kind, s.file, s.complexity, s.start_line, s.end_line ORDER BY s.file, s.start_line".to_string()
185    };
186
187    let rows = gq.raw_query(&query)?;
188    Ok(rows
189        .into_iter()
190        .map(|r| SymbolInfo {
191            id: r[0].clone(),
192            name: r[1].clone(),
193            kind: r[2].clone(),
194            file: r[3].clone(),
195            complexity: r.get(4).and_then(|v| v.parse().ok()).unwrap_or(0),
196            start_line: r.get(5).and_then(|v| v.parse().ok()).unwrap_or(0),
197            end_line: r.get(6).and_then(|v| v.parse().ok()).unwrap_or(0),
198        })
199        .collect())
200}
201
202fn analyze_file_sizes(symbols: &[SymbolInfo], recs: &mut Vec<Recommendation>) {
203    let mut file_stats: HashMap<&str, (usize, u32)> = HashMap::new();
204
205    for sym in symbols {
206        let entry = file_stats.entry(sym.file.as_str()).or_insert((0, 0));
207        entry.0 += 1;
208        if sym.end_line > entry.1 {
209            entry.1 = sym.end_line;
210        }
211    }
212
213    for (file, (symbol_count, max_line)) in &file_stats {
214        if *max_line > 1000 || *symbol_count > 40 {
215            let impact = if *max_line > 2000 || *symbol_count > 80 {
216                Impact::High
217            } else {
218                Impact::Medium
219            };
220            let effort = if *symbol_count > 60 {
221                Effort::High
222            } else {
223                Effort::Medium
224            };
225            recs.push(Recommendation {
226                category: Category::SplitFile,
227                target: file.to_string(),
228                impact,
229                effort,
230                rationale: format!(
231                    "{} lines, {} symbols. Consider splitting into focused modules.",
232                    max_line, symbol_count
233                ),
234            });
235        }
236    }
237}
238
239fn analyze_complexity(symbols: &[SymbolInfo], recs: &mut Vec<Recommendation>) {
240    let threshold = 15u32;
241    let mut hotspots: Vec<&SymbolInfo> = symbols
242        .iter()
243        .filter(|s| {
244            s.complexity >= threshold
245                && (s.kind == "Function" || s.kind == "Method" || s.kind == "Test")
246        })
247        .collect();
248
249    hotspots.sort_by_key(|s| std::cmp::Reverse(s.complexity));
250
251    for sym in hotspots.iter().take(10) {
252        let loc = sym.end_line.saturating_sub(sym.start_line);
253        let (impact, effort) = if sym.complexity > 30 {
254            (Impact::High, Effort::High)
255        } else if sym.complexity > 20 {
256            (Impact::High, Effort::Medium)
257        } else {
258            (Impact::Medium, Effort::Medium)
259        };
260
261        let category = if loc > 80 {
262            Category::ExtractFunction
263        } else {
264            Category::SimplifyLogic
265        };
266
267        recs.push(Recommendation {
268            category,
269            target: format!("{} ({}:{})", sym.name, sym.file, sym.start_line),
270            impact,
271            effort,
272            rationale: format!(
273                "Cyclomatic complexity {}. {} lines. Break into smaller functions or simplify branching.",
274                sym.complexity, loc
275            ),
276        });
277    }
278}
279
280fn analyze_coupling(
281    gq: &GraphQuery,
282    symbols: &[SymbolInfo],
283    recs: &mut Vec<Recommendation>,
284) -> Result<()> {
285    let callable_ids: Vec<&str> = symbols
286        .iter()
287        .filter(|s| s.kind == "Function" || s.kind == "Method")
288        .map(|s| s.id.as_str())
289        .collect();
290
291    if callable_ids.is_empty() {
292        return Ok(());
293    }
294
295    let fan_out_query = "MATCH (s:Symbol)-[:CALLS]->(t:Symbol) WHERE s.kind IN ['Function', 'Method'] RETURN s.id, count(DISTINCT t) ORDER BY count(DISTINCT t) DESC";
296    let fan_out_rows = gq.raw_query(fan_out_query)?;
297
298    let fan_in_query = "MATCH (s:Symbol)<-[:CALLS]-(t:Symbol) WHERE s.kind IN ['Function', 'Method'] RETURN s.id, count(DISTINCT t) ORDER BY count(DISTINCT t) DESC";
299    let fan_in_rows = gq.raw_query(fan_in_query)?;
300
301    let sym_lookup: HashMap<&str, &SymbolInfo> =
302        symbols.iter().map(|s| (s.id.as_str(), s)).collect();
303
304    for row in fan_out_rows.iter().take(20) {
305        let count: u32 = row.get(1).and_then(|v| v.parse().ok()).unwrap_or(0);
306        if count < 15 {
307            continue;
308        }
309        let id = &row[0];
310        if let Some(sym) = sym_lookup.get(id.as_str()) {
311            let impact = if count > 25 {
312                Impact::High
313            } else {
314                Impact::Medium
315            };
316            recs.push(Recommendation {
317                category: Category::ReduceCoupling,
318                target: format!("{} ({}:{})", sym.name, sym.file, sym.start_line),
319                impact,
320                effort: Effort::Medium,
321                rationale: format!(
322                    "Fan-out of {} — calls {} distinct functions. High coupling makes changes risky.",
323                    count, count
324                ),
325            });
326        }
327    }
328
329    for row in fan_in_rows.iter().take(20) {
330        let count: u32 = row.get(1).and_then(|v| v.parse().ok()).unwrap_or(0);
331        if count < 20 {
332            continue;
333        }
334        let id = &row[0];
335        if let Some(sym) = sym_lookup.get(id.as_str()) {
336            recs.push(Recommendation {
337                category: Category::ReduceCoupling,
338                target: format!("{} ({}:{})", sym.name, sym.file, sym.start_line),
339                impact: Impact::High,
340                effort: Effort::High,
341                rationale: format!(
342                    "Fan-in of {} — {} callers depend on this. Changes have wide blast radius. Consider interface extraction.",
343                    count, count
344                ),
345            });
346        }
347    }
348
349    Ok(())
350}
351
352fn analyze_duplication(
353    _gq: &GraphQuery,
354    symbols: &[SymbolInfo],
355    embeddings_path: Option<&std::path::Path>,
356    recs: &mut Vec<Recommendation>,
357) -> Result<()> {
358    let callables: Vec<&SymbolInfo> = symbols
359        .iter()
360        .filter(|s| s.kind == "Function" || s.kind == "Method")
361        .collect();
362
363    if callables.len() < 2 {
364        return Ok(());
365    }
366
367    let embedder = embed::best_embedder();
368
369    let cached: HashMap<String, Vec<f32>> = if let Some(path) = embeddings_path {
370        if path.exists() {
371            embed::load_embeddings_cached(path)?.into_iter().collect()
372        } else {
373            HashMap::new()
374        }
375    } else {
376        HashMap::new()
377    };
378
379    let sym_vecs: Vec<(&SymbolInfo, Vec<f32>)> = callables
380        .par_iter()
381        .map(|sym| {
382            let emb = cached
383                .get(&sym.id)
384                .cloned()
385                .unwrap_or_else(|| embedder.embed(&sym.name).unwrap_or_default());
386            (*sym, emb)
387        })
388        .filter(|(_, emb)| !emb.is_empty())
389        .collect();
390
391    let threshold = 0.90f32;
392    let n = sym_vecs.len();
393    let mut pairs: Vec<(f32, usize, usize)> = Vec::new();
394
395    for i in 0..n {
396        for j in (i + 1)..n {
397            if sym_vecs[i].0.file == sym_vecs[j].0.file {
398                continue;
399            }
400            let sim = embed::cosine_similarity(&sym_vecs[i].1, &sym_vecs[j].1);
401            if sim >= threshold {
402                pairs.push((sim, i, j));
403            }
404        }
405    }
406
407    pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
408
409    for (sim, i, j) in pairs.iter().take(5) {
410        let (impact, effort) = if *sim > 0.95 {
411            (Impact::High, Effort::Low)
412        } else {
413            (Impact::Medium, Effort::Low)
414        };
415        recs.push(Recommendation {
416            category: Category::MergeDuplicates,
417            target: format!("{} ↔ {}", sym_vecs[*i].0.name, sym_vecs[*j].0.name),
418            impact,
419            effort,
420            rationale: format!(
421                "{:.0}% similar. {} ({}) and {} ({}). Extract shared logic.",
422                sim * 100.0,
423                sym_vecs[*i].0.name,
424                sym_vecs[*i].0.file,
425                sym_vecs[*j].0.name,
426                sym_vecs[*j].0.file,
427            ),
428        });
429    }
430
431    Ok(())
432}
433
434fn analyze_dead_code(
435    gq: &GraphQuery,
436    symbols: &[SymbolInfo],
437    recs: &mut Vec<Recommendation>,
438) -> Result<()> {
439    let rows = gq.raw_query(
440        "MATCH (s:Symbol) WHERE s.kind IN ['Function', 'Method'] AND NOT EXISTS { MATCH ()-[:CALLS]->(s) } RETURN s.id, s.name, s.file",
441    )?;
442
443    let entry_points = ["main", "__init__", "setUp", "tearDown", "setup", "teardown"];
444    let target_files: HashMap<&str, bool> =
445        symbols.iter().map(|s| (s.file.as_str(), true)).collect();
446
447    let dead: Vec<&Vec<String>> = rows
448        .iter()
449        .filter(|row| {
450            !entry_points.contains(&row[1].as_str())
451                && !row[1].starts_with("test_")
452                && !row[1].starts_with("Test")
453                && target_files.contains_key(row[2].as_str())
454        })
455        .collect();
456
457    if dead.is_empty() {
458        return Ok(());
459    }
460
461    let mut by_file: HashMap<&str, Vec<&str>> = HashMap::new();
462    for row in &dead {
463        by_file
464            .entry(row[2].as_str())
465            .or_default()
466            .push(row[1].as_str());
467    }
468
469    for (file, names) in &by_file {
470        if names.len() >= 3 {
471            recs.push(Recommendation {
472                category: Category::RemoveDeadCode,
473                target: file.to_string(),
474                impact: Impact::Low,
475                effort: Effort::Low,
476                rationale: format!(
477                    "{} unreachable functions: {}. Safe to remove (verify no dynamic dispatch).",
478                    names.len(),
479                    names.iter().take(5).cloned().collect::<Vec<_>>().join(", ")
480                ),
481            });
482        } else {
483            for name in names {
484                recs.push(Recommendation {
485                    category: Category::RemoveDeadCode,
486                    target: format!("{} ({})", name, file),
487                    impact: Impact::Low,
488                    effort: Effort::Low,
489                    rationale: "Zero callers. Safe to remove (verify no dynamic dispatch)."
490                        .to_string(),
491                });
492            }
493        }
494    }
495
496    Ok(())
497}
498
499pub fn format_recommendations(recs: &[Recommendation], target: Option<&str>) -> String {
500    if recs.is_empty() {
501        return format!(
502            "No refactoring recommendations for {}.",
503            target.unwrap_or("project")
504        );
505    }
506
507    let mut out = format!(
508        "Refactoring Analysis: {}\n{} recommendations, sorted by impact/effort ratio\n\n",
509        target.unwrap_or("project"),
510        recs.len()
511    );
512
513    let mut current_impact = None;
514
515    for (i, rec) in recs.iter().enumerate() {
516        let impact_label = format!("{} IMPACT", rec.impact).to_uppercase();
517        if current_impact.as_ref() != Some(&impact_label) {
518            if i > 0 {
519                out.push('\n');
520            }
521            out.push_str(&impact_label);
522            out.push_str(":\n");
523            current_impact = Some(impact_label);
524        }
525
526        out.push_str(&format!(
527            "{}. [{}] {}\n   Rationale: {}\n   Effort: {} | Impact: {}\n\n",
528            i + 1,
529            rec.category,
530            rec.target,
531            rec.rationale,
532            rec.effort,
533            rec.impact,
534        ));
535    }
536
537    out
538}