use serde::{Deserialize, Serialize};
use std::fmt;
/// Category tag stored in [`KgNode::kind`].
///
/// The serde derives carry no rename attributes, so the serialized form is
/// whatever serde produces by default for unit variants. The `Display` impl
/// in this file renders a separate lowercase `snake_case` label.
///
// NOTE(review): the precise meaning of each variant is not defined in this
// file; the names suggest source-code entities (files, classes, calls, ...)
// — confirm against the code that produces the nodes.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum KgNodeKind {
    Repository,
    Package,
    Module,
    File,
    Class,
    Interface,
    Function,
    Method,
    Field,
    Import,
    Export,
    CallExpression,
    TestCase,
    Dependency,
}
impl fmt::Display for KgNodeKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
KgNodeKind::Repository => "repository",
KgNodeKind::Package => "package",
KgNodeKind::Module => "module",
KgNodeKind::File => "file",
KgNodeKind::Class => "class",
KgNodeKind::Interface => "interface",
KgNodeKind::Function => "function",
KgNodeKind::Method => "method",
KgNodeKind::Field => "field",
KgNodeKind::Import => "import",
KgNodeKind::Export => "export",
KgNodeKind::CallExpression => "call_expression",
KgNodeKind::TestCase => "test_case",
KgNodeKind::Dependency => "dependency",
};
f.write_str(s)
}
}
/// A single node of a [`KgGraph`].
///
/// The last three fields are marked `#[serde(default)]`, so they may be
/// omitted from serialized input; every other field is required.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct KgNode {
    /// Node identifier. [`KgGraph::merge`] treats two nodes with the same
    /// `id` as the same node.
    pub id: String,
    /// Category of the node.
    pub kind: KgNodeKind,
    /// Name of the node.
    pub name: String,
    /// Qualified name of the node.
    // NOTE(review): the qualification scheme (separator, root) is not
    // visible in this file — confirm with the producer.
    pub qualified_name: String,
    /// Language label, stored as free-form text (the tests use `"rust"`).
    pub language: String,
    /// File the node is associated with (the tests use `"f.rs"`).
    pub file: String,
    /// Start of the node's line range.
    // NOTE(review): presumably 1-based and inclusive — confirm.
    pub start_line: u32,
    /// End of the node's line range.
    pub end_line: u32,
    /// Optional documentation text; `None` when absent from the input.
    #[serde(default)]
    pub doc_comment: Option<String>,
    /// Visibility flag; `false` when absent from the input.
    #[serde(default)]
    pub is_public: bool,
    /// Free-form JSON payload; falls back to `serde_json::Value`'s default
    /// when absent from the input.
    #[serde(default)]
    pub extra: serde_json::Value,
}
/// Relationship label stored in [`KgEdge::kind`].
///
/// It is part of the key [`KgGraph::merge`] uses to deduplicate edges, which
/// is why it derives `Eq` and `Hash`.
///
// NOTE(review): the direction convention of each relationship (e.g. whether
// `Contains` points parent -> child) is not established in this file —
// confirm against the code that emits edges.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum KgEdgeKind {
    Contains,
    Imports,
    Exports,
    Calls,
    Implements,
    Extends,
    References,
    Tests,
    DependsOn,
    GeneratedFrom,
    RuntimeObservationFor,
}
/// A directed, typed edge of a [`KgGraph`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct KgEdge {
    /// Source endpoint.
    // NOTE(review): presumably a `KgNode::id`; nothing in this file checks
    // that the referenced node exists.
    pub from: String,
    /// Target endpoint (same caveat as `from`).
    pub to: String,
    /// Relationship expressed by this edge.
    pub kind: KgEdgeKind,
    /// Edge weight; filled in by `default_weight` (1.0) when the field is
    /// missing from serialized input. Not part of the key used by
    /// [`KgGraph::merge`] to deduplicate edges.
    #[serde(default = "default_weight")]
    pub weight: f32,
}
/// Fallback for `KgEdge::weight`, referenced by name from that field's
/// `#[serde(default = "default_weight")]` attribute.
fn default_weight() -> f32 {
    1.0_f32
}
/// A graph stored as two flat lists: nodes and edges.
///
/// `Default` yields an empty graph. `PartialEq` is derived to match
/// [`KgNode`] and [`KgEdge`], so whole graphs can be compared directly
/// (e.g. in round-trip tests). The comparison is element-wise and therefore
/// sensitive to the order of `nodes` and `edges`.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct KgGraph {
    /// Nodes in insertion order.
    pub nodes: Vec<KgNode>,
    /// Edges in insertion order.
    pub edges: Vec<KgEdge>,
}
impl KgGraph {
pub fn new() -> Self {
Self::default()
}
pub fn node_count(&self) -> usize {
self.nodes.len()
}
pub fn edge_count(&self) -> usize {
self.edges.len()
}
pub fn merge(&mut self, other: KgGraph) {
use std::collections::HashSet;
let mut seen_nodes: HashSet<String> = self.nodes.iter().map(|n| n.id.clone()).collect();
for n in other.nodes {
if seen_nodes.insert(n.id.clone()) {
self.nodes.push(n);
}
}
let mut seen_edges: HashSet<(String, String, KgEdgeKind)> = self
.edges
.iter()
.map(|e| (e.from.clone(), e.to.clone(), e.kind.clone()))
.collect();
for e in other.edges {
let k = (e.from.clone(), e.to.clone(), e.kind.clone());
if seen_edges.insert(k) {
self.edges.push(e);
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Function` node whose id, name and qualified name all equal `id`.
    fn make_node(id: &str) -> KgNode {
        KgNode {
            id: id.to_string(),
            kind: KgNodeKind::Function,
            name: id.to_string(),
            qualified_name: id.to_string(),
            language: String::from("rust"),
            file: String::from("f.rs"),
            start_line: 1,
            end_line: 2,
            doc_comment: None,
            is_public: false,
            extra: serde_json::Value::Null,
        }
    }

    /// Builds a `Calls` edge of weight 1.0 between the two endpoints.
    fn make_edge(from: &str, to: &str) -> KgEdge {
        KgEdge {
            from: from.to_string(),
            to: to.to_string(),
            kind: KgEdgeKind::Calls,
            weight: 1.0,
        }
    }

    #[test]
    fn merge_dedups_nodes_by_id() {
        let mut target = KgGraph::new();
        target.nodes.extend(vec![make_node("a"), make_node("b")]);
        let mut incoming = KgGraph::new();
        incoming.nodes.extend(vec![make_node("b"), make_node("c")]);

        target.merge(incoming);

        // "b" exists on both sides and must be stored only once.
        assert_eq!(target.node_count(), 3);
        for expected in &["a", "b", "c"] {
            assert!(target.nodes.iter().any(|node| node.id == *expected));
        }
    }

    #[test]
    fn merge_dedups_edges_by_endpoints_and_kind() {
        let mut target = KgGraph::new();
        target.edges.push(make_edge("x", "y"));
        let mut incoming = KgGraph::new();
        incoming
            .edges
            .extend(vec![make_edge("x", "y"), make_edge("y", "z")]);

        target.merge(incoming);

        assert_eq!(target.edge_count(), 2);
    }

    #[test]
    fn node_kind_display_is_snake_case() {
        let cases = vec![
            (KgNodeKind::Repository, "repository"),
            (KgNodeKind::TestCase, "test_case"),
            (KgNodeKind::CallExpression, "call_expression"),
            (KgNodeKind::Function, "function"),
        ];
        for (kind, expected) in cases {
            assert_eq!(kind.to_string(), expected);
        }
    }

    #[test]
    fn graph_round_trips_through_json() {
        let mut graph = KgGraph::new();
        graph.nodes.push(make_node("a"));
        // A self-loop is enough to exercise edge (de)serialization.
        graph.edges.push(make_edge("a", "a"));

        let json = serde_json::to_string(&graph).unwrap();
        let restored: KgGraph = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.node_count(), 1);
        assert_eq!(restored.edge_count(), 1);
    }
}