use std::collections::{HashMap, HashSet};
use crate::types::{
AcbError, AcbResult, CodeUnit, CodeUnitType, Edge, EdgeType, Language, MAX_EDGES_PER_UNIT,
};
/// In-memory graph of [`CodeUnit`]s connected by directed [`Edge`]s.
///
/// Units live in a dense vector and are addressed by their position, which is the id handed
/// out by `CodeGraph::add_unit`. Each edge is stored exactly once in `edges`; the two
/// adjacency maps hold positions into that vector so lookups by source or by target do not
/// need to scan every edge.
#[derive(Clone, Debug)]
pub struct CodeGraph {
    /// Every unit, indexed by unit id.
    units: Vec<CodeUnit>,
    /// Every edge, in insertion order.
    edges: Vec<Edge>,
    /// Source unit id -> positions in `edges` of that unit's outgoing edges.
    edges_by_source: HashMap<u64, Vec<usize>>,
    /// Target unit id -> positions in `edges` of that unit's incoming edges.
    edges_by_target: HashMap<u64, Vec<usize>>,
    /// Dimension supplied at construction time; stored and reported back.
    dimension: usize,
    /// Set of languages seen among the units added so far.
    languages: HashSet<Language>,
}
impl CodeGraph {
    /// Creates an empty graph that records the given `dimension`.
    ///
    /// The value is only stored here; it is reported back by [`CodeGraph::dimension`] and
    /// [`CodeGraph::stats`].
    pub fn new(dimension: usize) -> Self {
        Self {
            units: Vec::new(),
            edges: Vec::new(),
            edges_by_source: HashMap::new(),
            edges_by_target: HashMap::new(),
            dimension,
            languages: HashSet::new(),
        }
    }

    /// Creates an empty graph using `crate::types::DEFAULT_DIMENSION`.
    pub fn with_default_dimension() -> Self {
        Self::new(crate::types::DEFAULT_DIMENSION)
    }

    /// Appends `unit` to the graph and returns the id assigned to it.
    ///
    /// Ids are dense and sequential: a unit's id is its position in [`CodeGraph::units`].
    /// Any `id` already present on `unit` is overwritten. The unit's language is added to
    /// [`CodeGraph::languages`].
    pub fn add_unit(&mut self, mut unit: CodeUnit) -> u64 {
        let id = self.units.len() as u64;
        unit.id = id;
        self.languages.insert(unit.language);
        self.units.push(unit);
        id
    }

    /// Adds a directed edge between two units that already exist in the graph.
    ///
    /// Checks run in the order listed below, and on any error the graph is left unchanged.
    /// Duplicate edges (same source, target and type) are not rejected.
    ///
    /// # Errors
    ///
    /// - [`AcbError::SelfEdge`] if `edge.source_id == edge.target_id`.
    /// - [`AcbError::UnitNotFound`] if `edge.source_id` is not an existing unit id.
    /// - [`AcbError::InvalidEdgeTarget`] if `edge.target_id` is not an existing unit id.
    /// - [`AcbError::TooManyEdges`] if the source already has `MAX_EDGES_PER_UNIT` outgoing
    ///   edges.
    pub fn add_edge(&mut self, edge: Edge) -> AcbResult<()> {
        if edge.source_id == edge.target_id {
            return Err(AcbError::SelfEdge(edge.source_id));
        }
        let unit_count = self.units.len() as u64;
        if edge.source_id >= unit_count {
            return Err(AcbError::UnitNotFound(edge.source_id));
        }
        if edge.target_id >= unit_count {
            return Err(AcbError::InvalidEdgeTarget(edge.target_id));
        }
        // A source's index list never grows past MAX_EDGES_PER_UNIT (a u32), so this cast
        // cannot truncate.
        let source_edge_count = self
            .edges_by_source
            .get(&edge.source_id)
            .map_or(0, |v| v.len() as u32);
        if source_edge_count >= MAX_EDGES_PER_UNIT {
            return Err(AcbError::TooManyEdges(source_edge_count));
        }
        // Both adjacency maps store this edge's position in `edges`.
        let idx = self.edges.len();
        self.edges_by_source
            .entry(edge.source_id)
            .or_default()
            .push(idx);
        self.edges_by_target
            .entry(edge.target_id)
            .or_default()
            .push(idx);
        self.edges.push(edge);
        Ok(())
    }

    /// Number of units in the graph.
    pub fn unit_count(&self) -> usize {
        self.units.len()
    }

    /// Number of edges in the graph.
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }

    /// The dimension this graph was constructed with.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Every language recorded by [`CodeGraph::add_unit`].
    pub fn languages(&self) -> &HashSet<Language> {
        &self.languages
    }

    /// Returns the unit with the given id, or `None` if no such unit exists.
    pub fn get_unit(&self, id: u64) -> Option<&CodeUnit> {
        Self::unit_index(id).and_then(|i| self.units.get(i))
    }

    /// Returns a mutable reference to the unit with the given id, or `None` if absent.
    ///
    /// The graph does not re-sync derived state after mutation. [`CodeGraph::languages`]
    /// keeps the language recorded at insertion, and lookups stay positional regardless of
    /// any change to `unit.id`.
    pub fn get_unit_mut(&mut self, id: u64) -> Option<&mut CodeUnit> {
        Self::unit_index(id).and_then(move |i| self.units.get_mut(i))
    }

    /// All units, ordered by id.
    pub fn units(&self) -> &[CodeUnit] {
        &self.units
    }

    /// All edges, in insertion order.
    pub fn edges(&self) -> &[Edge] {
        &self.edges
    }

    /// Edges whose source is `source_id`, in insertion order (empty if there are none).
    pub fn edges_from(&self, source_id: u64) -> Vec<&Edge> {
        self.indexed_edges(&self.edges_by_source, source_id)
            .collect()
    }

    /// Edges whose target is `target_id`, in insertion order (empty if there are none).
    pub fn edges_to(&self, target_id: u64) -> Vec<&Edge> {
        self.indexed_edges(&self.edges_by_target, target_id)
            .collect()
    }

    /// Outgoing edges of `source_id` whose type equals `edge_type`.
    pub fn edges_from_of_type(&self, source_id: u64, edge_type: EdgeType) -> Vec<&Edge> {
        self.indexed_edges(&self.edges_by_source, source_id)
            .filter(|e| e.edge_type == edge_type)
            .collect()
    }

    /// Incoming edges of `target_id` whose type equals `edge_type`.
    pub fn edges_to_of_type(&self, target_id: u64, edge_type: EdgeType) -> Vec<&Edge> {
        self.indexed_edges(&self.edges_by_target, target_id)
            .filter(|e| e.edge_type == edge_type)
            .collect()
    }

    /// Units whose name starts with `prefix`, compared case-insensitively.
    ///
    /// Case folding uses `str::to_lowercase` on both sides. An empty prefix matches every
    /// unit.
    pub fn find_units_by_name(&self, prefix: &str) -> Vec<&CodeUnit> {
        let prefix_lower = prefix.to_lowercase();
        self.units
            .iter()
            .filter(|u| u.name.to_lowercase().starts_with(&prefix_lower))
            .collect()
    }

    /// Units whose name is exactly `name` (case-sensitive).
    pub fn find_units_by_exact_name(&self, name: &str) -> Vec<&CodeUnit> {
        self.units.iter().filter(|u| u.name == name).collect()
    }

    /// Units of the given type, ordered by id.
    pub fn find_units_by_type(&self, unit_type: CodeUnitType) -> Vec<&CodeUnit> {
        self.units
            .iter()
            .filter(|u| u.unit_type == unit_type)
            .collect()
    }

    /// Units written in the given language, ordered by id.
    pub fn find_units_by_language(&self, language: Language) -> Vec<&CodeUnit> {
        self.units
            .iter()
            .filter(|u| u.language == language)
            .collect()
    }

    /// Units whose `file_path` equals `path`, using exact path equality with no
    /// canonicalization.
    pub fn find_units_by_path(&self, path: &std::path::Path) -> Vec<&CodeUnit> {
        self.units.iter().filter(|u| u.file_path == path).collect()
    }

    /// Whether an edge `source_id -> target_id` of type `edge_type` exists.
    pub fn has_edge(&self, source_id: u64, target_id: u64, edge_type: EdgeType) -> bool {
        // Walk the index lazily; no intermediate Vec is needed for a yes/no answer.
        self.indexed_edges(&self.edges_by_source, source_id)
            .any(|e| e.target_id == target_id && e.edge_type == edge_type)
    }

    /// Computes unit, edge and per-category counts for the current graph contents.
    ///
    /// Each count map only contains keys that occur at least once.
    pub fn stats(&self) -> GraphStats {
        let mut type_counts: HashMap<CodeUnitType, usize> = HashMap::new();
        let mut edge_type_counts: HashMap<EdgeType, usize> = HashMap::new();
        let mut lang_counts: HashMap<Language, usize> = HashMap::new();
        for unit in &self.units {
            *type_counts.entry(unit.unit_type).or_default() += 1;
            *lang_counts.entry(unit.language).or_default() += 1;
        }
        for edge in &self.edges {
            *edge_type_counts.entry(edge.edge_type).or_default() += 1;
        }
        GraphStats {
            unit_count: self.units.len(),
            edge_count: self.edges.len(),
            dimension: self.dimension,
            type_counts,
            edge_type_counts,
            language_counts: lang_counts,
        }
    }

    /// Converts a unit id into a position in `units` without truncation.
    ///
    /// A plain `as usize` cast wraps ids above `usize::MAX` on 32-bit targets, aliasing them
    /// to unrelated units. Such ids can never be valid, so they map to `None` instead.
    fn unit_index(id: u64) -> Option<usize> {
        <usize as std::convert::TryFrom<u64>>::try_from(id).ok()
    }

    /// Lazily yields the edges listed under `id` in one of the adjacency maps, in insertion
    /// order.
    fn indexed_edges<'a>(
        &'a self,
        index: &'a HashMap<u64, Vec<usize>>,
        id: u64,
    ) -> impl Iterator<Item = &'a Edge> + 'a {
        // Positions are pushed by `add_edge` together with the edge itself, so they are
        // always in bounds for `self.edges`.
        index
            .get(&id)
            .into_iter()
            .flatten()
            .map(move |&i| &self.edges[i])
    }
}
impl Default for CodeGraph {
fn default() -> Self {
Self::with_default_dimension()
}
}
/// Snapshot of a graph's size and composition, as returned by `CodeGraph::stats`.
///
/// When produced by `CodeGraph::stats`, each count map only contains keys that occur at
/// least once. `Default` yields an all-zero snapshot with empty maps. `PartialEq`/`Eq` allow
/// two snapshots to be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GraphStats {
    /// Total number of units.
    pub unit_count: usize,
    /// Total number of edges.
    pub edge_count: usize,
    /// Dimension the graph was constructed with.
    pub dimension: usize,
    /// Number of units per unit type.
    pub type_counts: HashMap<CodeUnitType, usize>,
    /// Number of edges per edge type.
    pub edge_type_counts: HashMap<EdgeType, usize>,
    /// Number of units per language.
    pub language_counts: HashMap<Language, usize>,
}