infigraph-core 1.5.3

AST-powered code analysis framework — parser, graph, diff, and analysis engine
Documentation
use std::collections::HashMap;

use anyhow::Result;

use crate::model::{FileExtraction, RelationKind};

use super::{escape, shortest_id};

/// Symbol kinds that are accepted as the target of an inheritance relation when
/// candidates from other files are considered.
const TYPE_KINDS: &[&str] = &["Class", "Interface", "Struct", "Trait", "Enum"];

/// Resolves cross-file `Inherits` relations and writes them as `INHERITS` edges.
///
/// For every `Inherits` relation in `extractions` whose target name (the last `::`
/// segment of `target_id`) is not declared in the same file, the name is looked up in
/// `symbol_map`, which maps a symbol name to `(id, file, kind)` candidates. Only
/// candidates from other files whose kind is listed in `TYPE_KINDS` are considered:
///
/// * exactly one candidate: it is used directly;
/// * several candidates: `shortest_id` is asked first for a candidate whose lowercased
///   file stem matches one of the file's imports, then for one of kind `"Interface"`;
///   if neither yields a result the shortest id wins (ties broken lexicographically).
///
/// Source ids that are not known symbol ids are split at their last `::` into
/// `(file, name)` and rewritten to the ids registered for that pair. Pairs whose source
/// or target is still unknown are dropped, and identical pairs are written only once.
///
/// The edges are bulk-loaded with `COPY INHERITS FROM` a temporary parquet file; if that
/// query fails, they are inserted in chunks through `UNWIND` instead.
///
/// Returns the number of relations that were resolved to a target (counted before the
/// source-id rewrite, validation and de-duplication), or `0` when nothing is written.
///
/// # Errors
///
/// Returns an error if the temporary parquet file cannot be written. Failures of the
/// `COPY` / `UNWIND` queries are logged to stderr and not propagated.
///
/// Caller must hold WriteLock.
pub(crate) fn resolve_inherits(
    conn: &kuzu::Connection<'_>,
    extractions: &[FileExtraction],
    symbol_map: &HashMap<String, Vec<(String, String, String)>>,
) -> Result<usize> {
    let mut resolved_pairs: Vec<(String, String)> = Vec::new();

    for ext in extractions {
        let local_symbols: std::collections::HashSet<&str> =
            ext.symbols.iter().map(|s| s.name.as_str()).collect();

        // Lowercased last segment of every import target, split on `/`, `\` and `.`.
        // NOTE(review): for targets that carry a file extension this yields the
        // extension rather than the module name — confirm the import id format.
        let imported_stems: std::collections::HashSet<String> = ext
            .relations
            .iter()
            .filter(|r| r.kind == RelationKind::Imports)
            .map(|r| {
                let raw = r
                    .target_id
                    .rsplit(['/', '\\', '.'])
                    .next()
                    .unwrap_or(&r.target_id);
                raw.to_lowercase()
            })
            .collect();

        for rel in &ext.relations {
            if rel.kind != RelationKind::Inherits {
                continue;
            }

            let target_name = rel.target_id.rsplit("::").next().unwrap_or(&rel.target_id);

            // A base type declared in the same file needs no cross-file resolution.
            if local_symbols.contains(target_name) {
                continue;
            }

            if let Some(candidates) = symbol_map.get(target_name) {
                let cross_file: Vec<_> = candidates
                    .iter()
                    .filter(|(_, f, kind)| *f != ext.file && TYPE_KINDS.contains(&kind.as_str()))
                    .collect();

                let resolved_id = match cross_file.as_slice() {
                    [] => None,
                    [only] => Some(only.0.clone()),
                    _ => {
                        // Ambiguous name: prefer an imported file, then an interface,
                        // then fall back to a deterministic shortest-id choice.
                        shortest_id(cross_file.iter().copied(), |(_, f, _)| {
                            let stem = std::path::Path::new(f)
                                .file_stem()
                                .and_then(|s| s.to_str())
                                .map(|s| s.to_lowercase())
                                .unwrap_or_default();
                            imported_stems.contains(&stem)
                        })
                        .or_else(|| {
                            shortest_id(cross_file.iter().copied(), |(_, _, k)| k == "Interface")
                        })
                        .or_else(|| {
                            cross_file
                                .iter()
                                .min_by(|(a, _, _), (b, _, _)| {
                                    a.len().cmp(&b.len()).then_with(|| a.cmp(b))
                                })
                                .map(|(id, _, _)| id.clone())
                        })
                    }
                };

                if let Some(target_id) = resolved_id {
                    resolved_pairs.push((rel.source_id.clone(), target_id));
                }
            }
        }
    }

    if resolved_pairs.is_empty() {
        return Ok(0);
    }

    let count = resolved_pairs.len();

    // Every symbol id known either from the global map or from this batch.
    let mut known_ids: std::collections::HashSet<&str> = symbol_map
        .values()
        .flat_map(|v| v.iter().map(|(id, _, _)| id.as_str()))
        .collect();
    for ext in extractions {
        for sym in &ext.symbols {
            known_ids.insert(&sym.id);
        }
    }

    // (file, name) -> ids, used to rewrite source ids that are not real symbol ids.
    // A symbol present in both inputs is registered twice here; the resulting
    // duplicate edges are removed when the final edge list is built below.
    let mut file_name_to_ids: HashMap<(String, String), Vec<String>> = HashMap::new();
    for ext in extractions {
        for sym in &ext.symbols {
            file_name_to_ids
                .entry((ext.file.clone(), sym.name.clone()))
                .or_default()
                .push(sym.id.clone());
        }
    }
    for candidates in symbol_map.values() {
        for (id, file, _) in candidates {
            let name = id.rsplit("::").next().unwrap_or(id);
            file_name_to_ids
                .entry((file.clone(), name.to_string()))
                .or_default()
                .push(id.clone());
        }
    }

    let mut fixed_pairs: Vec<(String, String)> = Vec::with_capacity(resolved_pairs.len());
    for (src, tgt) in &resolved_pairs {
        // Only unknown sources of the form `<file>::<name>` are candidates for a rewrite.
        let replacement_ids = if known_ids.contains(src.as_str()) {
            None
        } else {
            src.rfind("::").and_then(|sep| {
                file_name_to_ids.get(&(src[..sep].to_string(), src[sep + 2..].to_string()))
            })
        };
        match replacement_ids {
            Some(ids) => fixed_pairs.extend(
                ids.iter()
                    .filter(|id| known_ids.contains(id.as_str()))
                    .map(|id| (id.clone(), tgt.clone())),
            ),
            None => fixed_pairs.push((src.clone(), tgt.clone())),
        }
    }

    // Keep only edges whose endpoints both exist, and emit each edge once so that the
    // double registration above cannot create duplicate INHERITS edges.
    let mut seen: std::collections::HashSet<(&str, &str)> = std::collections::HashSet::new();
    let refs: Vec<(&str, &str)> = fixed_pairs
        .iter()
        .map(|(src, tgt)| (src.as_str(), tgt.as_str()))
        .filter(|(src, tgt)| known_ids.contains(*src) && known_ids.contains(*tgt))
        .filter(|pair| seen.insert(*pair))
        .collect();

    if refs.is_empty() {
        return Ok(0);
    }

    // The process id keeps concurrent processes from clobbering each other's file in
    // the shared temp directory.
    let pq_path = std::env::temp_dir().join(format!(
        "infigraph_resolve_inherits_{}.parquet",
        std::process::id()
    ));
    let written = crate::graph::parquet_loader::write_edge_parquet(&pq_path, &refs);
    if written.is_err() {
        // Do not leave a partially written file behind when bailing out.
        let _ = std::fs::remove_file(&pq_path);
    }
    written?;

    let copy_result = conn.query(&format!(
        "COPY INHERITS FROM '{}'",
        pq_path.to_string_lossy().replace('\\', "/")
    ));
    if let Err(e) = copy_result {
        eprintln!("[resolve] COPY INHERITS FROM parquet failed ({e}), falling back to UNWIND");
        const CHUNK_SIZE: usize = 500;
        for chunk in refs.chunks(CHUNK_SIZE) {
            let pair_list: Vec<String> = chunk
                .iter()
                .map(|(a, b)| format!("{{a: '{}', b: '{}'}}", escape(a), escape(b)))
                .collect();
            // Best effort: a failing chunk is reported but does not abort the rest.
            if let Err(err) = conn.query(&format!(
                "UNWIND [{}] AS p MATCH (a:Symbol), (b:Symbol) WHERE a.id = p.a AND b.id = p.b CREATE (a)-[:INHERITS]->(b)",
                pair_list.join(", ")
            )) {
                eprintln!("[resolve] UNWIND INHERITS fallback chunk failed ({err})");
            }
        }
    }
    let _ = std::fs::remove_file(&pq_path);

    Ok(count)
}