Skip to main content

entrenar/research/
citation_graph.rs

1//! Citation Graph with upstream aggregation (ENT-025)
2//!
3//! Provides citation graph construction and traversal for
4//! aggregating citations from upstream dependencies.
5
6use crate::research::citation::CitationMetadata;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9
10/// A node in the citation graph
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct CitationNode {
13    /// The citation metadata
14    pub metadata: CitationMetadata,
15    /// Whether this is an upstream dependency
16    pub is_upstream: bool,
17    /// Depth from the root artifact
18    pub depth: usize,
19}
20
21impl CitationNode {
22    /// Create a new citation node
23    pub fn new(metadata: CitationMetadata, is_upstream: bool) -> Self {
24        Self { metadata, is_upstream, depth: 0 }
25    }
26
27    /// Set the depth
28    pub fn with_depth(mut self, depth: usize) -> Self {
29        self.depth = depth;
30        self
31    }
32}
33
34/// An edge in the citation graph
35#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
36pub struct CitationEdge {
37    /// Source artifact ID (the one doing the citing)
38    pub from: String,
39    /// Target artifact ID (the one being cited)
40    pub to: String,
41    /// Edge type
42    pub edge_type: EdgeType,
43}
44
45impl CitationEdge {
46    /// Create a new citation edge
47    pub fn new(from: impl Into<String>, to: impl Into<String>) -> Self {
48        Self { from: from.into(), to: to.into(), edge_type: EdgeType::Cites }
49    }
50
51    /// Set the edge type
52    pub fn with_type(mut self, edge_type: EdgeType) -> Self {
53        self.edge_type = edge_type;
54        self
55    }
56}
57
/// Kind of relationship a [`CitationEdge`] represents.
///
/// [`EdgeType::Cites`] is the type assigned by `CitationEdge::new`; the other
/// variants are set explicitly (e.g. via `CitationEdge::with_type`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EdgeType {
    /// Plain citation (the default edge type).
    Cites,
    /// The citing artifact builds upon or extends the cited one.
    Extends,
    /// The citing artifact uses the cited one as a dependency.
    DependsOn,
    /// The citing artifact is derived from the cited one.
    DerivedFrom,
}
70
/// Directed citation graph used to track and aggregate citations.
///
/// Nodes carry citation metadata keyed by artifact ID; edges record which artifact
/// cites which. Edges may refer to IDs that have no entry in `nodes` (the graph
/// does not require a node to exist before an edge mentions it).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CitationGraph {
    /// Nodes keyed by artifact ID.
    pub nodes: HashMap<String, CitationNode>,
    /// Edges in insertion order. The `add_citation*` methods skip exact duplicates,
    /// but the field is public, so duplicates can still be pushed directly.
    pub edges: Vec<CitationEdge>,
}
79
80impl CitationGraph {
81    /// Create a new empty citation graph
82    pub fn new() -> Self {
83        Self { nodes: HashMap::new(), edges: Vec::new() }
84    }
85
86    /// Add a citation node
87    pub fn add_node(&mut self, id: impl Into<String>, node: CitationNode) {
88        self.nodes.insert(id.into(), node);
89    }
90
91    /// Add a citation (creates an edge)
92    pub fn add_citation(&mut self, from: impl Into<String>, to: impl Into<String>) {
93        let edge = CitationEdge::new(from, to);
94        if !self.edges.contains(&edge) {
95            self.edges.push(edge);
96        }
97    }
98
99    /// Add a citation with a specific type
100    pub fn add_citation_typed(
101        &mut self,
102        from: impl Into<String>,
103        to: impl Into<String>,
104        edge_type: EdgeType,
105    ) {
106        let edge = CitationEdge::new(from, to).with_type(edge_type);
107        if !self.edges.contains(&edge) {
108            self.edges.push(edge);
109        }
110    }
111
112    /// Get all citations from a specific artifact
113    pub fn citations_from(&self, artifact_id: &str) -> Vec<&CitationEdge> {
114        self.edges.iter().filter(|e| e.from == artifact_id).collect()
115    }
116
117    /// Get all citations to a specific artifact
118    pub fn citations_to(&self, artifact_id: &str) -> Vec<&CitationEdge> {
119        self.edges.iter().filter(|e| e.to == artifact_id).collect()
120    }
121
122    /// Get upstream citations for an artifact (what it cites)
123    pub fn cite_upstream(&self, artifact_id: &str) -> Vec<&CitationMetadata> {
124        self.citations_from(artifact_id)
125            .iter()
126            .filter_map(|edge| self.nodes.get(&edge.to))
127            .map(|node| &node.metadata)
128            .collect()
129    }
130
131    /// Aggregate all citations transitively (including transitive dependencies)
132    pub fn aggregate_all_citations(&self, root_id: &str) -> Vec<&CitationMetadata> {
133        let mut visited = HashSet::new();
134        let mut result = Vec::new();
135
136        self.aggregate_recursive(root_id, &mut visited, &mut result);
137
138        result
139    }
140
141    /// Recursive helper for citation aggregation
142    fn aggregate_recursive<'a>(
143        &'a self,
144        current_id: &str,
145        visited: &mut HashSet<String>,
146        result: &mut Vec<&'a CitationMetadata>,
147    ) {
148        if visited.contains(current_id) {
149            return;
150        }
151        visited.insert(current_id.to_string());
152
153        for edge in self.citations_from(current_id) {
154            if let Some(node) = self.nodes.get(&edge.to) {
155                if !visited.contains(&edge.to) {
156                    result.push(&node.metadata);
157                    // Recursively get transitive citations
158                    self.aggregate_recursive(&edge.to, visited, result);
159                }
160            }
161        }
162    }
163
164    /// Check for transitive citations (A cites B, B cites C => A transitively cites C)
165    pub fn has_transitive_citation(&self, from: &str, to: &str) -> bool {
166        let mut visited = HashSet::new();
167        self.has_path(from, to, &mut visited)
168    }
169
170    /// Check if there's a path from source to target
171    fn has_path(&self, current: &str, target: &str, visited: &mut HashSet<String>) -> bool {
172        if current == target {
173            return true;
174        }
175        if visited.contains(current) {
176            return false;
177        }
178        visited.insert(current.to_string());
179
180        for edge in self.citations_from(current) {
181            if self.has_path(&edge.to, target, visited) {
182                return true;
183            }
184        }
185
186        false
187    }
188
189    /// Export all citations to BibTeX
190    pub fn to_bibtex_all(&self) -> String {
191        self.nodes.values().map(|node| node.metadata.to_bibtex()).collect::<Vec<_>>().join("\n\n")
192    }
193
194    /// Get the number of nodes
195    pub fn node_count(&self) -> usize {
196        self.nodes.len()
197    }
198
199    /// Get the number of edges
200    pub fn edge_count(&self) -> usize {
201        self.edges.len()
202    }
203
204    /// Get all upstream nodes (is_upstream = true)
205    pub fn upstream_nodes(&self) -> Vec<&CitationNode> {
206        self.nodes.values().filter(|n| n.is_upstream).collect()
207    }
208
209    /// Remove duplicate citations (same from-to pair)
210    pub fn deduplicate(&mut self) {
211        let mut seen = HashSet::new();
212        self.edges.retain(|edge| {
213            let key = (edge.from.clone(), edge.to.clone());
214            seen.insert(key)
215        });
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::research::artifact::{ArtifactType, Author, License, ResearchArtifact};

    /// Build citation metadata for a paper with a single test author.
    fn create_test_citation(id: &str, title: &str, year: u16) -> CitationMetadata {
        let artifact = ResearchArtifact::new(id, title, ArtifactType::Paper, License::CcBy4)
            .with_author(Author::new("Test Author"));
        CitationMetadata::new(artifact, year)
    }

    #[test]
    fn test_add_citation() {
        let mut graph = CitationGraph::new();

        graph.add_citation("paper-a", "paper-b");

        assert_eq!(graph.edge_count(), 1);
        assert_eq!(graph.edges[0].from, "paper-a");
        assert_eq!(graph.edges[0].to, "paper-b");
    }

    #[test]
    fn test_cite_upstream_aggregation() {
        let mut graph = CitationGraph::new();

        // Add nodes
        let citation_b = create_test_citation("paper-b", "Paper B", 2023);
        let citation_c = create_test_citation("paper-c", "Paper C", 2022);

        graph.add_node("paper-b", CitationNode::new(citation_b, true));
        graph.add_node("paper-c", CitationNode::new(citation_c, true));

        // Paper A cites B and C
        graph.add_citation("paper-a", "paper-b");
        graph.add_citation("paper-a", "paper-c");

        let upstream = graph.cite_upstream("paper-a");

        assert_eq!(upstream.len(), 2);
    }

    #[test]
    fn test_transitive_citations() {
        let mut graph = CitationGraph::new();

        // Add nodes
        let citation_b = create_test_citation("paper-b", "Paper B", 2023);
        let citation_c = create_test_citation("paper-c", "Paper C", 2022);
        let citation_d = create_test_citation("paper-d", "Paper D", 2021);

        graph.add_node("paper-b", CitationNode::new(citation_b, true));
        graph.add_node("paper-c", CitationNode::new(citation_c, true));
        graph.add_node("paper-d", CitationNode::new(citation_d, true));

        // A -> B -> C -> D
        graph.add_citation("paper-a", "paper-b");
        graph.add_citation("paper-b", "paper-c");
        graph.add_citation("paper-c", "paper-d");

        // Direct path exists
        assert!(graph.has_transitive_citation("paper-a", "paper-b"));
        assert!(graph.has_transitive_citation("paper-b", "paper-c"));

        // Transitive path exists
        assert!(graph.has_transitive_citation("paper-a", "paper-c"));
        assert!(graph.has_transitive_citation("paper-a", "paper-d"));

        // No reverse path
        assert!(!graph.has_transitive_citation("paper-d", "paper-a"));
    }

    #[test]
    fn test_no_duplicate_citations() {
        let mut graph = CitationGraph::new();

        graph.add_citation("paper-a", "paper-b");
        graph.add_citation("paper-a", "paper-b"); // Duplicate

        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn test_graph_to_bibtex_all() {
        let mut graph = CitationGraph::new();

        let citation_a = create_test_citation("paper-a", "Paper A", 2024);
        let citation_b = create_test_citation("paper-b", "Paper B", 2023);

        graph.add_node("paper-a", CitationNode::new(citation_a, false));
        graph.add_node("paper-b", CitationNode::new(citation_b, true));

        let bibtex = graph.to_bibtex_all();

        assert!(bibtex.contains("Paper A"));
        assert!(bibtex.contains("Paper B"));
        assert!(bibtex.contains("@article{"));
    }

    #[test]
    fn test_aggregate_all_citations() {
        let mut graph = CitationGraph::new();

        // Build a citation chain: A -> B -> C
        let citation_b = create_test_citation("paper-b", "Paper B", 2023);
        let citation_c = create_test_citation("paper-c", "Paper C", 2022);

        graph.add_node("paper-b", CitationNode::new(citation_b, true));
        graph.add_node("paper-c", CitationNode::new(citation_c, true));

        graph.add_citation("paper-a", "paper-b");
        graph.add_citation("paper-b", "paper-c");

        let all_citations = graph.aggregate_all_citations("paper-a");

        // Should get both B and C (transitively)
        assert_eq!(all_citations.len(), 2);
    }

    #[test]
    fn test_aggregate_diamond_no_duplicates() {
        let mut graph = CitationGraph::new();

        let citation_b = create_test_citation("paper-b", "Paper B", 2023);
        let citation_c = create_test_citation("paper-c", "Paper C", 2022);

        graph.add_node("paper-b", CitationNode::new(citation_b, true));
        graph.add_node("paper-c", CitationNode::new(citation_c, true));

        // A -> B, A -> C, B -> C: C is reachable along two different paths.
        graph.add_citation("paper-a", "paper-b");
        graph.add_citation("paper-a", "paper-c");
        graph.add_citation("paper-b", "paper-c");

        let all_citations = graph.aggregate_all_citations("paper-a");

        // C must be reported only once.
        assert_eq!(all_citations.len(), 2);
    }

    #[test]
    fn test_edge_types() {
        let mut graph = CitationGraph::new();

        graph.add_citation_typed("paper-a", "paper-b", EdgeType::Extends);
        graph.add_citation_typed("paper-a", "library-x", EdgeType::DependsOn);

        assert_eq!(graph.edges[0].edge_type, EdgeType::Extends);
        assert_eq!(graph.edges[1].edge_type, EdgeType::DependsOn);
    }

    #[test]
    fn test_typed_edges_same_pair() {
        let mut graph = CitationGraph::new();

        graph.add_citation_typed("paper-a", "paper-b", EdgeType::Cites);
        graph.add_citation_typed("paper-a", "paper-b", EdgeType::Extends);
        graph.add_citation_typed("paper-a", "paper-b", EdgeType::Extends); // Exact duplicate

        // Different types between the same pair are kept; exact duplicates are not.
        assert_eq!(graph.edge_count(), 2);

        // `deduplicate` collapses by (from, to), keeping the first edge.
        graph.deduplicate();
        assert_eq!(graph.edge_count(), 1);
        assert_eq!(graph.edges[0].edge_type, EdgeType::Cites);
    }

    #[test]
    fn test_citations_to() {
        let mut graph = CitationGraph::new();

        graph.add_citation("paper-a", "paper-x");
        graph.add_citation("paper-b", "paper-x");
        graph.add_citation("paper-c", "paper-x");

        let incoming = graph.citations_to("paper-x");
        assert_eq!(incoming.len(), 3);
    }

    #[test]
    fn test_upstream_nodes() {
        let mut graph = CitationGraph::new();

        let citation_a = create_test_citation("paper-a", "Paper A", 2024);
        let citation_b = create_test_citation("paper-b", "Paper B", 2023);
        let citation_c = create_test_citation("paper-c", "Paper C", 2022);

        graph.add_node("paper-a", CitationNode::new(citation_a, false)); // Not upstream
        graph.add_node("paper-b", CitationNode::new(citation_b, true)); // Upstream
        graph.add_node("paper-c", CitationNode::new(citation_c, true)); // Upstream

        let upstream = graph.upstream_nodes();
        assert_eq!(upstream.len(), 2);
    }

    #[test]
    fn test_deduplicate() {
        let mut graph = CitationGraph::new();

        // Manually add duplicate edges
        graph.edges.push(CitationEdge::new("a", "b"));
        graph.edges.push(CitationEdge::new("a", "b"));
        graph.edges.push(CitationEdge::new("a", "c"));

        assert_eq!(graph.edge_count(), 3);

        graph.deduplicate();

        assert_eq!(graph.edge_count(), 2);
    }

    #[test]
    fn test_node_with_depth() {
        let citation = create_test_citation("paper-a", "Paper A", 2024);
        let node = CitationNode::new(citation, true).with_depth(3);

        assert_eq!(node.depth, 3);
        assert!(node.is_upstream);
    }

    #[test]
    fn test_cycle_handling() {
        let mut graph = CitationGraph::new();

        let citation_a = create_test_citation("paper-a", "Paper A", 2024);
        let citation_b = create_test_citation("paper-b", "Paper B", 2023);

        graph.add_node("paper-a", CitationNode::new(citation_a, false));
        graph.add_node("paper-b", CitationNode::new(citation_b, true));

        // Create a cycle: A -> B -> A
        graph.add_citation("paper-a", "paper-b");
        graph.add_citation("paper-b", "paper-a");

        // Should not infinite loop
        let all = graph.aggregate_all_citations("paper-a");
        assert_eq!(all.len(), 1); // Only B, not A again
    }

    #[test]
    fn test_transitive_through_unregistered_node() {
        let mut graph = CitationGraph::new();

        // No nodes registered: reachability depends on edges alone.
        graph.add_citation("paper-a", "paper-x");
        graph.add_citation("paper-x", "paper-c");

        assert!(graph.has_transitive_citation("paper-a", "paper-c"));
        assert!(!graph.has_transitive_citation("paper-c", "paper-a"));
    }

    #[test]
    fn test_transitive_citation_in_cycle() {
        let mut graph = CitationGraph::new();

        // A -> B -> A
        graph.add_citation("paper-a", "paper-b");
        graph.add_citation("paper-b", "paper-a");

        assert!(graph.has_transitive_citation("paper-a", "paper-b"));
        assert!(graph.has_transitive_citation("paper-b", "paper-a"));
        // An artifact on a cycle transitively cites itself.
        assert!(graph.has_transitive_citation("paper-a", "paper-a"));
    }
}