use crate::research::citation::CitationMetadata;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// A node stored in a `CitationGraph`: citation metadata plus two bookkeeping fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CitationNode {
/// Citation metadata for the artifact this node stands for.
pub metadata: CitationMetadata,
/// Caller-supplied flag; `CitationGraph::upstream_nodes` returns the nodes where it is `true`.
pub is_upstream: bool,
/// Set to 0 by `CitationNode::new` and overridden by `CitationNode::with_depth`.
/// NOTE(review): nothing in this file computes it — presumably the distance from
/// some root artifact; confirm against callers.
pub depth: usize,
}
impl CitationNode {
pub fn new(metadata: CitationMetadata, is_upstream: bool) -> Self {
Self { metadata, is_upstream, depth: 0 }
}
pub fn with_depth(mut self, depth: usize) -> Self {
self.depth = depth;
self
}
}
/// A directed, typed edge between two artifact ids.
///
/// Equality and hashing are derived over all three fields, so two edges with the
/// same endpoints but different `edge_type` are distinct values.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CitationEdge {
/// Id of the source artifact; matched by `CitationGraph::citations_from`.
pub from: String,
/// Id of the target artifact; matched by `CitationGraph::citations_to`.
pub to: String,
/// Kind of relationship; `CitationEdge::new` sets it to `EdgeType::Cites`.
pub edge_type: EdgeType,
}
impl CitationEdge {
pub fn new(from: impl Into<String>, to: impl Into<String>) -> Self {
Self { from: from.into(), to: to.into(), edge_type: EdgeType::Cites }
}
pub fn with_type(mut self, edge_type: EdgeType) -> Self {
self.edge_type = edge_type;
self
}
}
/// Kind of relationship a `CitationEdge` records.
///
/// NOTE(review): the variant names are the only description of their meaning
/// available in this file; the notes below restate the names, nothing more.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EdgeType {
/// Default type assigned by `CitationEdge::new` and `CitationGraph::add_citation`.
Cites,
/// Presumably: the source builds on the target — confirm with callers.
Extends,
/// Presumably: the source requires the target (the tests use it for a library id).
DependsOn,
/// Presumably: the source was produced from the target — confirm with callers.
DerivedFrom,
}
/// A citation graph: nodes keyed by id plus a flat list of directed edges.
///
/// Edges refer to artifacts by id string only; nothing here requires an edge's
/// endpoints to have an entry in `nodes`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CitationGraph {
/// Nodes keyed by the id passed to `add_node`.
pub nodes: HashMap<String, CitationNode>,
/// Edges in the order they were pushed.
pub edges: Vec<CitationEdge>,
}
impl CitationGraph {
pub fn new() -> Self {
Self { nodes: HashMap::new(), edges: Vec::new() }
}
pub fn add_node(&mut self, id: impl Into<String>, node: CitationNode) {
self.nodes.insert(id.into(), node);
}
pub fn add_citation(&mut self, from: impl Into<String>, to: impl Into<String>) {
let edge = CitationEdge::new(from, to);
if !self.edges.contains(&edge) {
self.edges.push(edge);
}
}
pub fn add_citation_typed(
&mut self,
from: impl Into<String>,
to: impl Into<String>,
edge_type: EdgeType,
) {
let edge = CitationEdge::new(from, to).with_type(edge_type);
if !self.edges.contains(&edge) {
self.edges.push(edge);
}
}
pub fn citations_from(&self, artifact_id: &str) -> Vec<&CitationEdge> {
self.edges.iter().filter(|e| e.from == artifact_id).collect()
}
pub fn citations_to(&self, artifact_id: &str) -> Vec<&CitationEdge> {
self.edges.iter().filter(|e| e.to == artifact_id).collect()
}
pub fn cite_upstream(&self, artifact_id: &str) -> Vec<&CitationMetadata> {
self.citations_from(artifact_id)
.iter()
.filter_map(|edge| self.nodes.get(&edge.to))
.map(|node| &node.metadata)
.collect()
}
pub fn aggregate_all_citations(&self, root_id: &str) -> Vec<&CitationMetadata> {
let mut visited = HashSet::new();
let mut result = Vec::new();
self.aggregate_recursive(root_id, &mut visited, &mut result);
result
}
fn aggregate_recursive<'a>(
&'a self,
current_id: &str,
visited: &mut HashSet<String>,
result: &mut Vec<&'a CitationMetadata>,
) {
if visited.contains(current_id) {
return;
}
visited.insert(current_id.to_string());
for edge in self.citations_from(current_id) {
if let Some(node) = self.nodes.get(&edge.to) {
if !visited.contains(&edge.to) {
result.push(&node.metadata);
self.aggregate_recursive(&edge.to, visited, result);
}
}
}
}
pub fn has_transitive_citation(&self, from: &str, to: &str) -> bool {
let mut visited = HashSet::new();
self.has_path(from, to, &mut visited)
}
fn has_path(&self, current: &str, target: &str, visited: &mut HashSet<String>) -> bool {
if current == target {
return true;
}
if visited.contains(current) {
return false;
}
visited.insert(current.to_string());
for edge in self.citations_from(current) {
if self.has_path(&edge.to, target, visited) {
return true;
}
}
false
}
pub fn to_bibtex_all(&self) -> String {
self.nodes.values().map(|node| node.metadata.to_bibtex()).collect::<Vec<_>>().join("\n\n")
}
pub fn node_count(&self) -> usize {
self.nodes.len()
}
pub fn edge_count(&self) -> usize {
self.edges.len()
}
pub fn upstream_nodes(&self) -> Vec<&CitationNode> {
self.nodes.values().filter(|n| n.is_upstream).collect()
}
pub fn deduplicate(&mut self) {
let mut seen = HashSet::new();
self.edges.retain(|edge| {
let key = (edge.from.clone(), edge.to.clone());
seen.insert(key)
});
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::research::artifact::{ArtifactType, Author, License, ResearchArtifact};

    /// Builds citation metadata for a CC-BY-4 paper written by "Test Author".
    fn create_test_citation(id: &str, title: &str, year: u16) -> CitationMetadata {
        CitationMetadata::new(
            ResearchArtifact::new(id, title, ArtifactType::Paper, License::CcBy4)
                .with_author(Author::new("Test Author")),
            year,
        )
    }

    /// Creates a test paper and registers it in `graph` under its own id.
    fn register(graph: &mut CitationGraph, id: &str, title: &str, year: u16, is_upstream: bool) {
        let metadata = create_test_citation(id, title, year);
        graph.add_node(id, CitationNode::new(metadata, is_upstream));
    }

    #[test]
    fn test_add_citation() {
        let mut g = CitationGraph::new();
        g.add_citation("paper-a", "paper-b");

        assert_eq!(g.edge_count(), 1);
        let first = &g.edges[0];
        assert_eq!(first.from, "paper-a");
        assert_eq!(first.to, "paper-b");
    }

    #[test]
    fn test_cite_upstream_aggregation() {
        let mut g = CitationGraph::new();
        register(&mut g, "paper-b", "Paper B", 2023, true);
        register(&mut g, "paper-c", "Paper C", 2022, true);
        for target in &["paper-b", "paper-c"] {
            g.add_citation("paper-a", *target);
        }

        assert_eq!(g.cite_upstream("paper-a").len(), 2);
    }

    #[test]
    fn test_transitive_citations() {
        let mut g = CitationGraph::new();
        register(&mut g, "paper-b", "Paper B", 2023, true);
        register(&mut g, "paper-c", "Paper C", 2022, true);
        register(&mut g, "paper-d", "Paper D", 2021, true);
        g.add_citation("paper-a", "paper-b");
        g.add_citation("paper-b", "paper-c");
        g.add_citation("paper-c", "paper-d");

        // Every forward pair along the chain a -> b -> c -> d is reachable.
        let reachable = [
            ("paper-a", "paper-b"),
            ("paper-b", "paper-c"),
            ("paper-a", "paper-c"),
            ("paper-a", "paper-d"),
        ];
        for &(from, to) in &reachable {
            assert!(g.has_transitive_citation(from, to));
        }
        // Edges are directed, so the reverse direction is not.
        assert!(!g.has_transitive_citation("paper-d", "paper-a"));
    }

    #[test]
    fn test_no_duplicate_citations() {
        let mut g = CitationGraph::new();
        for _ in 0..2 {
            g.add_citation("paper-a", "paper-b");
        }

        assert_eq!(g.edge_count(), 1);
    }

    #[test]
    fn test_graph_to_bibtex_all() {
        let mut g = CitationGraph::new();
        register(&mut g, "paper-a", "Paper A", 2024, false);
        register(&mut g, "paper-b", "Paper B", 2023, true);

        let rendered = g.to_bibtex_all();
        for needle in &["Paper A", "Paper B", "@article{"] {
            assert!(rendered.contains(*needle));
        }
    }

    #[test]
    fn test_aggregate_all_citations() {
        let mut g = CitationGraph::new();
        register(&mut g, "paper-b", "Paper B", 2023, true);
        register(&mut g, "paper-c", "Paper C", 2022, true);
        g.add_citation("paper-a", "paper-b");
        g.add_citation("paper-b", "paper-c");

        assert_eq!(g.aggregate_all_citations("paper-a").len(), 2);
    }

    #[test]
    fn test_edge_types() {
        let mut g = CitationGraph::new();
        g.add_citation_typed("paper-a", "paper-b", EdgeType::Extends);
        g.add_citation_typed("paper-a", "library-x", EdgeType::DependsOn);

        let kinds: Vec<EdgeType> = g.edges.iter().map(|e| e.edge_type).collect();
        assert_eq!(kinds, vec![EdgeType::Extends, EdgeType::DependsOn]);
    }

    #[test]
    fn test_citations_to() {
        let mut g = CitationGraph::new();
        for source in &["paper-a", "paper-b", "paper-c"] {
            g.add_citation(*source, "paper-x");
        }

        assert_eq!(g.citations_to("paper-x").len(), 3);
    }

    #[test]
    fn test_upstream_nodes() {
        let mut g = CitationGraph::new();
        register(&mut g, "paper-a", "Paper A", 2024, false);
        register(&mut g, "paper-b", "Paper B", 2023, true);
        register(&mut g, "paper-c", "Paper C", 2022, true);

        assert_eq!(g.upstream_nodes().len(), 2);
    }

    #[test]
    fn test_deduplicate() {
        let mut g = CitationGraph::new();
        // Pushed directly so the duplicate check in `add_citation` is bypassed.
        for target in &["b", "b", "c"] {
            g.edges.push(CitationEdge::new("a", *target));
        }
        assert_eq!(g.edge_count(), 3);

        g.deduplicate();
        assert_eq!(g.edge_count(), 2);
    }

    #[test]
    fn test_node_with_depth() {
        let node =
            CitationNode::new(create_test_citation("paper-a", "Paper A", 2024), true).with_depth(3);

        assert_eq!(node.depth, 3);
        assert!(node.is_upstream);
    }

    #[test]
    fn test_cycle_handling() {
        let mut g = CitationGraph::new();
        register(&mut g, "paper-a", "Paper A", 2024, false);
        register(&mut g, "paper-b", "Paper B", 2023, true);
        g.add_citation("paper-a", "paper-b");
        g.add_citation("paper-b", "paper-a");

        // The root is never reported, so the a <-> b cycle yields only b.
        assert_eq!(g.aggregate_all_citations("paper-a").len(), 1);
    }
}