1use serde::{Deserialize, Serialize};
14use uuid::Uuid;
15
16use crate::error::FoldError;
17
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
20pub struct AnchorRef {
21 pub id: Uuid,
23 pub kind: String,
25 pub stable_id: Option<String>,
27}
28
29#[derive(Debug, Clone, Default, Serialize, Deserialize)]
33pub struct AnchorGraph {
34 pub nodes: Vec<AnchorRef>,
36 pub edges: Vec<(Uuid, Uuid, String)>,
38}
39
40impl AnchorGraph {
41 pub fn new() -> Self {
43 Self::default()
44 }
45
46 pub fn add_node(&mut self, anchor: AnchorRef) {
48 self.nodes.push(anchor);
49 }
50
51 pub fn add_edge(&mut self, from: Uuid, to: Uuid, relation: impl Into<String>) {
53 self.edges.push((from, to, relation.into()));
54 }
55
56 pub fn find_node(&self, id: Uuid) -> Option<&AnchorRef> {
58 self.nodes.iter().find(|n| n.id == id)
59 }
60
61 pub fn outgoing(&self, from: Uuid) -> impl Iterator<Item = (Uuid, &str)> {
63 self.edges
64 .iter()
65 .filter(move |(f, _, _)| *f == from)
66 .map(|(_, to, rel)| (*to, rel.as_str()))
67 }
68
69 pub fn incoming(&self, to: Uuid) -> impl Iterator<Item = (Uuid, &str)> {
71 self.edges
72 .iter()
73 .filter(move |(_, t, _)| *t == to)
74 .map(|(from, _, rel)| (*from, rel.as_str()))
75 }
76}
77
78pub trait Anchor {
80 fn trace(
82 &self,
83 graph: &AnchorGraph,
84 start: &AnchorRef,
85 max_depth: usize,
86 ) -> Result<Vec<AnchorRef>, FoldError>;
87
88 fn credit(
90 &self,
91 graph: &AnchorGraph,
92 outcome: &AnchorRef,
93 max_depth: usize,
94 ) -> Result<Vec<(AnchorRef, f32)>, FoldError>;
95}
96
/// Stateless [`Anchor`] implementation that explores the graph breadth-first.
///
/// It carries no data, so it can be used directly as a value:
/// `BfsAnchor.trace(&graph, &start, depth)`.
#[derive(Debug, Default, Clone, Copy)]
pub struct BfsAnchor;
103
104impl Anchor for BfsAnchor {
105 fn trace(
106 &self,
107 graph: &AnchorGraph,
108 start: &AnchorRef,
109 max_depth: usize,
110 ) -> Result<Vec<AnchorRef>, FoldError> {
111 if graph.find_node(start.id).is_none() {
112 return Err(FoldError::AnchorNotFound(start.id.to_string()));
113 }
114
115 let mut visited = std::collections::HashSet::new();
116 let mut result = Vec::new();
117 let mut queue = std::collections::VecDeque::new();
118
119 visited.insert(start.id);
120 queue.push_back((start.id, 0usize));
121
122 while let Some((current_id, depth)) = queue.pop_front() {
123 if let Some(node) = graph.find_node(current_id) {
124 if current_id != start.id {
125 result.push(node.clone());
126 }
127
128 if depth < max_depth {
129 for (next_id, _rel) in graph.outgoing(current_id) {
130 if visited.insert(next_id) {
131 queue.push_back((next_id, depth + 1));
132 }
133 }
134 }
135 }
136 }
137
138 Ok(result)
139 }
140
141 fn credit(
142 &self,
143 graph: &AnchorGraph,
144 outcome: &AnchorRef,
145 max_depth: usize,
146 ) -> Result<Vec<(AnchorRef, f32)>, FoldError> {
147 if graph.find_node(outcome.id).is_none() {
148 return Err(FoldError::AnchorNotFound(outcome.id.to_string()));
149 }
150
151 let mut visited = std::collections::HashSet::new();
152 let mut result = Vec::new();
153 let mut queue = std::collections::VecDeque::new();
154
155 visited.insert(outcome.id);
156 queue.push_back((outcome.id, 0usize, 1.0f32));
157
158 while let Some((current_id, depth, weight)) = queue.pop_front() {
159 if current_id != outcome.id {
160 if let Some(node) = graph.find_node(current_id) {
161 result.push((node.clone(), weight));
162 }
163 }
164
165 if depth < max_depth {
166 let predecessors: Vec<(Uuid, f32)> = graph
167 .incoming(current_id)
168 .filter(|(id, _)| visited.insert(*id))
169 .map(|(id, _)| (id, weight * 0.5))
170 .collect();
171
172 for (pred_id, pred_weight) in predecessors {
173 queue.push_back((pred_id, depth + 1, pred_weight));
174 }
175 }
176 }
177
178 Ok(result)
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `AnchorRef` with a deterministic id derived from `id` and no
    /// stable identifier.
    fn make_ref(id: u128, kind: &str) -> AnchorRef {
        let id = Uuid::from_u128(id);
        AnchorRef {
            id,
            kind: kind.to_owned(),
            stable_id: None,
        }
    }

    /// Fields set at construction are readable as-is.
    #[test]
    fn test_anchor_ref_fields() {
        let anchor = AnchorRef {
            id: Uuid::new_v4(),
            kind: String::from("paper"),
            stable_id: Some(String::from("doi:10.1234/x")),
        };
        assert_eq!(anchor.kind, "paper");
        assert!(anchor.stable_id.is_some());
    }

    /// Added nodes can be found by id; unknown ids yield `None`.
    #[test]
    fn test_anchor_graph_add_and_find() {
        let [a, b] = [make_ref(1, "record"), make_ref(2, "source")];

        let mut graph = AnchorGraph::new();
        graph.add_node(a.clone());
        graph.add_node(b.clone());
        graph.add_edge(a.id, b.id, "derives_from");

        assert!(graph.find_node(a.id).is_some());
        assert!(graph.find_node(Uuid::nil()).is_none());
    }

    /// Tracing from a node that is not in the graph is an error.
    #[test]
    fn test_bfs_anchor_trace_not_found() {
        let empty = AnchorGraph::new();
        let missing = make_ref(99, "unknown");

        let err = BfsAnchor.trace(&empty, &missing, 5).unwrap_err();
        assert!(matches!(err, FoldError::AnchorNotFound(_)));
    }

    /// a -> b -> c: tracing from `a` reaches both `b` and `c`.
    #[test]
    fn test_bfs_anchor_trace_chain() {
        let [a, b, c] = [
            make_ref(1, "record"),
            make_ref(2, "source"),
            make_ref(3, "paper"),
        ];

        let mut graph = AnchorGraph::new();
        graph.add_node(a.clone());
        graph.add_node(b.clone());
        graph.add_node(c.clone());
        graph.add_edge(a.id, b.id, "derives_from");
        graph.add_edge(b.id, c.id, "uses");

        let chain = BfsAnchor.trace(&graph, &a, 5).unwrap();
        let ids: Vec<Uuid> = chain.iter().map(|r| r.id).collect();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&b.id));
        assert!(ids.contains(&c.id));
    }

    /// With `max_depth == 1` only the direct successor is returned.
    #[test]
    fn test_bfs_anchor_trace_max_depth() {
        let nodes: Vec<AnchorRef> = (1..=5).map(|i| make_ref(i, "node")).collect();

        let mut graph = AnchorGraph::new();
        for node in &nodes {
            graph.add_node(node.clone());
        }
        // Link consecutive nodes into a single chain 1 -> 2 -> 3 -> 4 -> 5.
        for pair in nodes.windows(2) {
            graph.add_edge(pair[0].id, pair[1].id, "next");
        }

        let chain = BfsAnchor.trace(&graph, &nodes[0], 1).unwrap();
        assert_eq!(chain.len(), 1);
        assert_eq!(chain[0].id, nodes[1].id);
    }

    /// Crediting an outcome that is not in the graph is an error.
    #[test]
    fn test_bfs_anchor_credit_not_found() {
        let empty = AnchorGraph::new();
        let missing = make_ref(99, "unknown");

        let err = BfsAnchor.credit(&empty, &missing, 5).unwrap_err();
        assert!(matches!(err, FoldError::AnchorNotFound(_)));
    }

    /// source -> intermediate -> outcome: the intermediate node receives a
    /// strictly positive credit.
    #[test]
    fn test_bfs_anchor_credit_basic() {
        let source = make_ref(1, "paper");
        let intermediate = make_ref(2, "record");
        let outcome = make_ref(3, "composition");

        let mut graph = AnchorGraph::new();
        graph.add_node(source.clone());
        graph.add_node(intermediate.clone());
        graph.add_node(outcome.clone());
        graph.add_edge(source.id, intermediate.id, "uses");
        graph.add_edge(intermediate.id, outcome.id, "derives_from");

        let credits = BfsAnchor.credit(&graph, &outcome, 5).unwrap();
        assert!(!credits.is_empty());

        let found = credits
            .iter()
            .find(|(anchor, _)| anchor.id == intermediate.id);
        assert!(found.is_some());
        assert!(found.map_or(false, |(_, weight)| *weight > 0.0));
    }
}