use super::{Symbol, SymbolRelation};
use super::SymbolRelationType;
use dashmap::DashMap;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// One direction of the edge index: symbol id -> relations recorded under that id.
type RelationIndex = Arc<DashMap<String, Vec<SymbolRelation>>>;

/// Concurrent graph of symbols and the typed relationships between them.
///
/// Cloning is cheap: the derived `Clone` copies the `Arc`s, so every clone
/// shares the same underlying maps.
#[derive(Clone)]
pub struct RelationshipGraph {
    /// Forward edges, keyed by the id of the edge's source symbol.
    relationships: RelationIndex,
    /// Inverse edges, keyed by the id of the edge's target symbol.
    reverse_relationships: RelationIndex,
    /// Registered symbols, keyed by their `id`.
    symbols: Arc<DashMap<String, Symbol>>,
}
impl RelationshipGraph {
    /// Creates an empty graph with no symbols and no relationships.
    pub fn new() -> Self {
        Self {
            relationships: Arc::new(DashMap::new()),
            reverse_relationships: Arc::new(DashMap::new()),
            symbols: Arc::new(DashMap::new()),
        }
    }

    /// Registers `symbol` under its `id`, replacing any symbol already stored with that id.
    pub fn add_symbol(&self, symbol: Symbol) {
        self.symbols.insert(symbol.id.clone(), symbol);
    }

    /// Records the directed edge `from_symbol --relation_type--> to_symbol`.
    ///
    /// The edge is appended to the forward index under `from_symbol`. The inverse
    /// edge (typed `relation_type.reverse()`) is appended to the reverse index
    /// under `to_symbol`. Both edges carry the same `confidence`.
    ///
    /// Neither symbol has to be registered. Repeated calls store duplicate edges.
    pub fn add_relationship(
        &self,
        from_symbol: &str,
        to_symbol: &str,
        relation_type: SymbolRelationType,
        confidence: f64,
    ) {
        // Build the inverse edge first so `relation_type` can be moved, not cloned,
        // into the forward edge.
        let reverse_relation = SymbolRelation {
            symbol_id: from_symbol.to_string(),
            relation_type: relation_type.reverse(),
            confidence,
        };
        let relation = SymbolRelation {
            symbol_id: to_symbol.to_string(),
            relation_type,
            confidence,
        };
        self.relationships
            .entry(from_symbol.to_string())
            .or_insert_with(Vec::new)
            .push(relation);
        self.reverse_relationships
            .entry(to_symbol.to_string())
            .or_insert_with(Vec::new)
            .push(reverse_relation);
    }

    /// Returns a snapshot of the forward edges recorded under `symbol_id` (empty if none).
    pub fn get_relationships(&self, symbol_id: &str) -> Vec<SymbolRelation> {
        self.relationships
            .get(symbol_id)
            .map(|relations| relations.value().clone())
            .unwrap_or_default()
    }

    /// Returns the forward edges of `symbol_id` whose type equals `relation_type`.
    pub fn get_relationships_by_type(
        &self,
        symbol_id: &str,
        relation_type: SymbolRelationType,
    ) -> Vec<SymbolRelation> {
        // Filter while the read guard is held so only matching edges are cloned.
        self.relationships
            .get(symbol_id)
            .map(|relations| {
                relations
                    .value()
                    .iter()
                    .filter(|r| r.relation_type == relation_type)
                    .cloned()
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Returns a snapshot of the inverse edges recorded under `symbol_id` (empty if none).
    pub fn get_reverse_relationships(&self, symbol_id: &str) -> Vec<SymbolRelation> {
        self.reverse_relationships
            .get(symbol_id)
            .map(|relations| relations.value().clone())
            .unwrap_or_default()
    }

    /// Keeps the relations of type `relation_type` and resolves each target id to a
    /// registered symbol. Ids that were never added with `add_symbol` are silently skipped.
    fn resolve_relations(
        &self,
        relations: Vec<SymbolRelation>,
        relation_type: SymbolRelationType,
    ) -> Vec<Symbol> {
        relations
            .into_iter()
            .filter(|r| r.relation_type == relation_type)
            .filter_map(|r| self.get_symbol(&r.symbol_id))
            .collect()
    }

    /// Returns the registered symbols that inherit from `symbol_id`.
    pub fn find_children(&self, symbol_id: &str) -> Vec<Symbol> {
        // `Inherits.reverse()` is `Inherits`, so a child's edge appears here,
        // in the parent's reverse index.
        self.resolve_relations(
            self.get_reverse_relationships(symbol_id),
            SymbolRelationType::Inherits,
        )
    }

    /// Returns the registered symbols that `symbol_id` inherits from.
    pub fn find_parents(&self, symbol_id: &str) -> Vec<Symbol> {
        self.resolve_relations(
            self.get_relationships(symbol_id),
            SymbolRelationType::Inherits,
        )
    }

    /// Returns the registered symbols that `symbol_id` calls.
    ///
    /// Only edges added as `Calls` from `symbol_id` are visible here.
    pub fn find_callees(&self, symbol_id: &str) -> Vec<Symbol> {
        self.resolve_relations(self.get_relationships(symbol_id), SymbolRelationType::Calls)
    }

    /// Returns the registered symbols that call `symbol_id`.
    ///
    /// This reads the `CalledBy` entries of the reverse index. Those entries are
    /// produced by edges added as `Calls` towards `symbol_id`.
    pub fn find_callers(&self, symbol_id: &str) -> Vec<Symbol> {
        self.resolve_relations(
            self.get_reverse_relationships(symbol_id),
            SymbolRelationType::CalledBy,
        )
    }

    /// Returns `symbol_id`'s symbol followed by every symbol reachable from it
    /// through `Inherits` edges in either direction. That includes ancestors,
    /// descendants and hence siblings.
    ///
    /// Symbols are visited depth-first, parents before children, each at most
    /// once. The result is empty if `symbol_id` is not registered.
    pub fn find_hierarchy(&self, symbol_id: &str) -> Vec<Symbol> {
        let mut hierarchy = Vec::new();
        let mut visited = HashSet::new();
        self.collect_hierarchy(symbol_id, &mut hierarchy, &mut visited);
        hierarchy
    }

    /// Depth-first helper for [`find_hierarchy`](Self::find_hierarchy).
    /// `visited` guards against inheritance cycles.
    fn collect_hierarchy(
        &self,
        symbol_id: &str,
        hierarchy: &mut Vec<Symbol>,
        visited: &mut HashSet<String>,
    ) {
        // `insert` returns false when the id was already present: one lookup for the cycle check.
        if !visited.insert(symbol_id.to_string()) {
            return;
        }
        // Clone the symbol out so the DashMap shard guard is released before recursing.
        // Holding a `Ref` across the traversal would keep the shard locked for the whole
        // walk, and DashMap warns that re-entering the map while holding a reference may
        // deadlock.
        let symbol = match self.get_symbol(symbol_id) {
            Some(symbol) => symbol,
            None => return,
        };
        hierarchy.push(symbol);
        for parent in self.find_parents(symbol_id) {
            self.collect_hierarchy(&parent.id, hierarchy, visited);
        }
        for child in self.find_children(symbol_id) {
            self.collect_hierarchy(&child.id, hierarchy, visited);
        }
    }

    /// Returns the registered targets of `symbol_id`'s forward edges whose type is
    /// in `relation_types`, in edge-insertion order (duplicates kept).
    pub fn find_related(
        &self,
        symbol_id: &str,
        relation_types: &[SymbolRelationType],
    ) -> Vec<Symbol> {
        self.get_relationships(symbol_id)
            .into_iter()
            .filter(|relation| relation_types.contains(&relation.relation_type))
            .filter_map(|relation| self.get_symbol(&relation.symbol_id))
            .collect()
    }

    /// Returns a clone of the symbol registered under `id`, if any.
    pub fn get_symbol(&self, id: &str) -> Option<Symbol> {
        self.symbols.get(id).map(|s| s.value().clone())
    }

    /// Returns clones of all registered symbols in DashMap iteration order (unspecified).
    pub fn all_symbols(&self) -> Vec<Symbol> {
        self.symbols.iter().map(|e| e.value().clone()).collect()
    }

    /// Removes every relationship and symbol.
    ///
    /// The three maps are cleared one after another. A concurrent reader may
    /// therefore observe a partially cleared graph.
    pub fn clear(&self) {
        self.relationships.clear();
        self.reverse_relationships.clear();
        self.symbols.clear();
    }

    /// Registers all `symbols`, then derives relationships for each one.
    ///
    /// Registration happens first so every symbol is already present when
    /// relationships are analysed.
    pub fn build_from_symbols(&self, symbols: &[Symbol]) {
        for symbol in symbols {
            self.add_symbol(symbol.clone());
        }
        for symbol in symbols {
            self.analyze_symbol_relationships(symbol);
        }
    }

    /// Derives relationships implied by `symbol`'s own metadata.
    ///
    /// When `symbol.parent` is set, this adds a `ContainedIn` edge (confidence 1.0)
    /// to that parent. Functions and methods are then passed to
    /// [`analyze_function_calls`](Self::analyze_function_calls).
    fn analyze_symbol_relationships(&self, symbol: &Symbol) {
        if let Some(parent) = &symbol.parent {
            self.add_relationship(&symbol.id, parent, SymbolRelationType::ContainedIn, 1.0);
        }
        // TODO: derive `Inherits` edges for classes/structs (e.g. from `extends` / `:`
        // in `symbol.signature`); not implemented yet.
        if symbol.kind == super::SymbolKind::Function || symbol.kind == super::SymbolKind::Method {
            self.analyze_function_calls(symbol);
        }
    }

    /// Placeholder for call-graph extraction; currently records nothing.
    fn analyze_function_calls(&self, _symbol: &Symbol) {}
}
impl Default for RelationshipGraph {
fn default() -> Self {
Self::new()
}
}
impl SymbolRelationType {
    /// Returns the type of the inverse edge. When `from -> to` has type `self`,
    /// `to -> from` is recorded with the returned type.
    ///
    /// `Inherits` and `Implements` have no dedicated inverse variant, so they map to
    /// themselves. The reverse index then reads "is inherited / implemented by".
    pub fn reverse(&self) -> SymbolRelationType {
        match self {
            SymbolRelationType::Inherits => SymbolRelationType::Inherits,
            // Must not map to `CalledBy`: that would make implementors of a type
            // show up as its callers in the reverse index.
            SymbolRelationType::Implements => SymbolRelationType::Implements,
            SymbolRelationType::Calls => SymbolRelationType::CalledBy,
            SymbolRelationType::CalledBy => SymbolRelationType::Calls,
            SymbolRelationType::Contains => SymbolRelationType::ContainedIn,
            SymbolRelationType::ContainedIn => SymbolRelationType::Contains,
            SymbolRelationType::References => SymbolRelationType::ReferencedBy,
            SymbolRelationType::ReferencedBy => SymbolRelationType::References,
            SymbolRelationType::Overrides => SymbolRelationType::OverriddenBy,
            SymbolRelationType::OverriddenBy => SymbolRelationType::Overrides,
            SymbolRelationType::Instantiates => SymbolRelationType::InstantiatedBy,
            SymbolRelationType::InstantiatedBy => SymbolRelationType::Instantiates,
            // NOTE(review): these two have no dedicated inverse variant and map to
            // unrelated variants; confirm the intended mapping (perhaps `Contains`).
            SymbolRelationType::ParameterOf => SymbolRelationType::FieldOf,
            SymbolRelationType::ReturnTypeOf => SymbolRelationType::MethodOf,
            SymbolRelationType::FieldOf => SymbolRelationType::Contains,
            SymbolRelationType::MethodOf => SymbolRelationType::Contains,
            SymbolRelationType::Imports => SymbolRelationType::ImportedBy,
            SymbolRelationType::ImportedBy => SymbolRelationType::Imports,
            SymbolRelationType::Exports => SymbolRelationType::ExportedBy,
            SymbolRelationType::ExportedBy => SymbolRelationType::Exports,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a public, undocumented, signature-less symbol in `test.rs`
    /// spanning lines 10..=15 at column 0. Tests override fields via `..`.
    fn sample_symbol(id: &str, name: &str, kind: super::super::SymbolKind) -> Symbol {
        Symbol {
            id: id.to_string(),
            name: name.to_string(),
            kind,
            file_path: "test.rs".to_string(),
            line: 10,
            column: 0,
            end_line: 15,
            signature: String::new(),
            documentation: None,
            visibility: super::super::SymbolVisibility::Public,
            parent: None,
            type_info: None,
            generics: vec![],
            annotations: vec![],
            attributes: vec![],
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_relationship_graph() {
        let graph = RelationshipGraph::new();
        let parent = sample_symbol("parent", "ParentClass", super::super::SymbolKind::Class);
        let child = Symbol {
            line: 20,
            end_line: 25,
            ..sample_symbol("child", "ChildClass", super::super::SymbolKind::Class)
        };

        graph.add_symbol(parent.clone());
        graph.add_symbol(child.clone());
        graph.add_relationship(&child.id, &parent.id, SymbolRelationType::Inherits, 1.0);

        let parents = graph.find_parents(&child.id);
        assert_eq!(parents.len(), 1);
        assert_eq!(parents[0].name, "ParentClass");

        let children = graph.find_children(&parent.id);
        assert_eq!(children.len(), 1);
        assert_eq!(children[0].name, "ChildClass");
    }

    #[test]
    fn test_call_relationships() {
        let graph = RelationshipGraph::new();
        let caller = sample_symbol("caller", "caller_func", super::super::SymbolKind::Function);
        let callee = Symbol {
            line: 20,
            end_line: 25,
            ..sample_symbol("callee", "callee_func", super::super::SymbolKind::Function)
        };

        graph.add_symbol(caller.clone());
        graph.add_symbol(callee.clone());
        graph.add_relationship(&caller.id, &callee.id, SymbolRelationType::Calls, 1.0);

        let callees = graph.find_callees(&caller.id);
        assert_eq!(callees.len(), 1);
        assert_eq!(callees[0].name, "callee_func");

        let callers = graph.find_callers(&callee.id);
        assert_eq!(callers.len(), 1);
        assert_eq!(callers[0].name, "caller_func");
    }
}