use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::error::FoldError;
/// Lightweight reference to an anchor node stored in an `AnchorGraph`.
///
/// Equality and hashing are derived, so they compare all three fields
/// (`id`, `kind` and `stable_id`), not just `id`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct AnchorRef {
    /// Identifier used for node lookups and as the endpoint value of edges.
    pub id: Uuid,
    /// Free-form category label (the tests use values such as `"paper"` and
    /// `"record"`); the graph and traversal code in this module never read it.
    pub kind: String,
    /// Optional external identifier (the tests use a DOI-style string).
    /// NOTE(review): presumably stable across runs/imports — confirm with producers.
    pub stable_id: Option<String>,
}
/// Directed, labelled graph of anchors kept as plain node and edge lists.
///
/// Nothing is indexed or validated: the same node may be added more than
/// once, and edges may reference ids that are not present in `nodes`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AnchorGraph {
    /// Registered anchors, in insertion order.
    pub nodes: Vec<AnchorRef>,
    /// Directed edges stored as `(from, to, relation)`, in insertion order.
    pub edges: Vec<(Uuid, Uuid, String)>,
}
impl AnchorGraph {
    /// Creates a graph with no nodes and no edges.
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            edges: Vec::new(),
        }
    }

    /// Appends `anchor` to the node list.
    ///
    /// No de-duplication happens: re-adding an existing id stores a second
    /// entry, and `find_node` keeps returning the earlier one.
    pub fn add_node(&mut self, anchor: AnchorRef) {
        self.nodes.push(anchor);
    }

    /// Appends a directed `from -> to` edge labelled with `relation`.
    ///
    /// Endpoints are not checked against the node list.
    pub fn add_edge(&mut self, from: Uuid, to: Uuid, relation: impl Into<String>) {
        let label: String = relation.into();
        self.edges.push((from, to, label));
    }

    /// Returns the first registered node whose id equals `id`, if any.
    ///
    /// This is a linear scan over `nodes`.
    pub fn find_node(&self, id: Uuid) -> Option<&AnchorRef> {
        self.nodes.iter().find(|candidate| candidate.id == id)
    }

    /// Iterates `(target, relation)` for every edge leaving `from`, in edge
    /// insertion order (duplicate edges are yielded repeatedly).
    pub fn outgoing(&self, from: Uuid) -> impl Iterator<Item = (Uuid, &str)> {
        self.edges.iter().filter_map(move |(source, target, relation)| {
            if *source == from {
                Some((*target, relation.as_str()))
            } else {
                None
            }
        })
    }

    /// Iterates `(source, relation)` for every edge arriving at `to`, in edge
    /// insertion order (duplicate edges are yielded repeatedly).
    pub fn incoming(&self, to: Uuid) -> impl Iterator<Item = (Uuid, &str)> {
        self.edges.iter().filter_map(move |(source, target, relation)| {
            if *target == to {
                Some((*source, relation.as_str()))
            } else {
                None
            }
        })
    }
}
/// Traversal strategy over an `AnchorGraph`.
///
/// NOTE(review): the trait itself fixes only the signatures; edge direction,
/// result ordering and error cases are defined by each implementor
/// (`BfsAnchor` in this module is the reference implementation).
pub trait Anchor {
    /// Collects anchors related to `start`, searching at most `max_depth` hops.
    ///
    /// `BfsAnchor` follows outgoing edges, excludes `start` from the result,
    /// and returns `FoldError::AnchorNotFound` when `start` is not in `graph`.
    fn trace(
        &self,
        graph: &AnchorGraph,
        start: &AnchorRef,
        max_depth: usize,
    ) -> Result<Vec<AnchorRef>, FoldError>;
    /// Collects anchors paired with an `f32` credit weight relative to
    /// `outcome`, searching at most `max_depth` hops.
    ///
    /// `BfsAnchor` follows incoming edges, halves the weight at every hop,
    /// excludes `outcome` from the result, and returns
    /// `FoldError::AnchorNotFound` when `outcome` is not in `graph`.
    fn credit(
        &self,
        graph: &AnchorGraph,
        outcome: &AnchorRef,
        max_depth: usize,
    ) -> Result<Vec<(AnchorRef, f32)>, FoldError>;
}
/// Stateless breadth-first implementation of `Anchor`.
///
/// `trace` walks outgoing edges; `credit` walks incoming edges and halves the
/// credit weight at every hop away from the outcome.
#[derive(Debug, Clone, Copy, Default)]
pub struct BfsAnchor;
/// Multiplier applied to a contributor's credit for every hop it sits further
/// upstream of the outcome in `BfsAnchor::credit`.
const CREDIT_DECAY: f32 = 0.5;

impl Anchor for BfsAnchor {
    /// Breadth-first walk along *outgoing* edges starting at `start`.
    ///
    /// Returns every distinct registered node reachable from `start` within
    /// `max_depth` hops, in BFS order (nearer nodes first, ties broken by edge
    /// insertion order). `start` itself is never included, `max_depth == 0`
    /// yields an empty vector, and each node is visited at most once, so
    /// cycles terminate.
    ///
    /// Edge endpoints that are not registered in `graph.nodes` are neither
    /// reported nor expanded.
    ///
    /// Node and edge indexes are built once per call, so the walk costs
    /// O(nodes + edges) rather than rescanning both lists for every step.
    ///
    /// # Errors
    ///
    /// Returns `FoldError::AnchorNotFound` if `start.id` is not a node of
    /// `graph`.
    fn trace(
        &self,
        graph: &AnchorGraph,
        start: &AnchorRef,
        max_depth: usize,
    ) -> Result<Vec<AnchorRef>, FoldError> {
        let nodes = index_nodes(graph);
        if !nodes.contains_key(&start.id) {
            return Err(FoldError::AnchorNotFound(start.id.to_string()));
        }
        let successors = index_edges(graph, EdgeDirection::Outgoing);

        let mut visited = std::collections::HashSet::new();
        let mut result = Vec::new();
        let mut queue = std::collections::VecDeque::new();
        visited.insert(start.id);
        queue.push_back((start.id, 0usize));
        while let Some((current_id, depth)) = queue.pop_front() {
            // Unregistered endpoints fall through here: not reported, not expanded.
            if let Some(node) = nodes.get(&current_id).copied() {
                if current_id != start.id {
                    result.push(node.clone());
                }
                if depth < max_depth {
                    for &next_id in successors.get(&current_id).into_iter().flatten() {
                        if visited.insert(next_id) {
                            queue.push_back((next_id, depth + 1));
                        }
                    }
                }
            }
        }
        Ok(result)
    }

    /// Breadth-first walk along *incoming* edges, crediting everything
    /// upstream of `outcome`.
    ///
    /// The outcome starts at weight `1.0`. Each hop upstream multiplies the
    /// weight by `CREDIT_DECAY` (0.5), so a node `k` hops away gets `0.5^k`.
    /// A node is credited once, via the first (shortest) path BFS reaches it
    /// by; additional paths add nothing. `outcome` itself is not included,
    /// and `max_depth == 0` yields an empty vector.
    ///
    /// NOTE(review): unlike `trace`, predecessors that are not registered in
    /// `graph.nodes` are still expanded (credit flows *through* them) even
    /// though they are not reported. Confirm which behavior is intended.
    ///
    /// Node and edge indexes are built once per call, so the walk costs
    /// O(nodes + edges).
    ///
    /// # Errors
    ///
    /// Returns `FoldError::AnchorNotFound` if `outcome.id` is not a node of
    /// `graph`.
    fn credit(
        &self,
        graph: &AnchorGraph,
        outcome: &AnchorRef,
        max_depth: usize,
    ) -> Result<Vec<(AnchorRef, f32)>, FoldError> {
        let nodes = index_nodes(graph);
        if !nodes.contains_key(&outcome.id) {
            return Err(FoldError::AnchorNotFound(outcome.id.to_string()));
        }
        let predecessors = index_edges(graph, EdgeDirection::Incoming);

        let mut visited = std::collections::HashSet::new();
        let mut result = Vec::new();
        let mut queue = std::collections::VecDeque::new();
        visited.insert(outcome.id);
        queue.push_back((outcome.id, 0usize, 1.0f32));
        while let Some((current_id, depth, weight)) = queue.pop_front() {
            if current_id != outcome.id {
                if let Some(node) = nodes.get(&current_id).copied() {
                    result.push((node.clone(), weight));
                }
            }
            // Expansion is independent of registration (see NOTE above).
            if depth < max_depth {
                for &pred_id in predecessors.get(&current_id).into_iter().flatten() {
                    if visited.insert(pred_id) {
                        queue.push_back((pred_id, depth + 1, weight * CREDIT_DECAY));
                    }
                }
            }
        }
        Ok(result)
    }
}

/// Which end of an edge `index_edges` groups by.
#[derive(Clone, Copy)]
enum EdgeDirection {
    /// Key by `from`; neighbours are `to` (same pairs as `AnchorGraph::outgoing`).
    Outgoing,
    /// Key by `to`; neighbours are `from` (same pairs as `AnchorGraph::incoming`).
    Incoming,
}

/// Maps each node id to the *first* entry in `graph.nodes` with that id,
/// matching the first-match semantics of `AnchorGraph::find_node`.
fn index_nodes(graph: &AnchorGraph) -> std::collections::HashMap<Uuid, &AnchorRef> {
    let mut index = std::collections::HashMap::with_capacity(graph.nodes.len());
    for node in &graph.nodes {
        // `or_insert` keeps the earliest entry when ids repeat.
        index.entry(node.id).or_insert(node);
    }
    index
}

/// Builds an adjacency list in a single pass over `graph.edges`.
///
/// Neighbour lists keep edge insertion order and keep duplicate edges, so
/// iterating a list yields the same ids, in the same order, as
/// `AnchorGraph::outgoing` / `AnchorGraph::incoming` would for that key.
fn index_edges(
    graph: &AnchorGraph,
    direction: EdgeDirection,
) -> std::collections::HashMap<Uuid, Vec<Uuid>> {
    let mut adjacency = std::collections::HashMap::<Uuid, Vec<Uuid>>::new();
    for (from, to, _relation) in &graph.edges {
        let (key, neighbour) = match direction {
            EdgeDirection::Outgoing => (*from, *to),
            EdgeDirection::Incoming => (*to, *from),
        };
        adjacency.entry(key).or_default().push(neighbour);
    }
    adjacency
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an anchor with a deterministic id and no stable id.
    fn make_ref(id: u128, kind: &str) -> AnchorRef {
        AnchorRef {
            id: Uuid::from_u128(id),
            kind: kind.to_string(),
            stable_id: None,
        }
    }

    #[test]
    fn test_anchor_ref_fields() {
        let r = AnchorRef {
            id: Uuid::new_v4(),
            kind: "paper".into(),
            stable_id: Some("doi:10.1234/x".into()),
        };
        assert_eq!(r.kind, "paper");
        assert!(r.stable_id.is_some());
    }

    #[test]
    fn test_anchor_graph_add_and_find() {
        let mut graph = AnchorGraph::new();
        let a = make_ref(1, "record");
        let b = make_ref(2, "source");
        graph.add_node(a.clone());
        graph.add_node(b.clone());
        graph.add_edge(a.id, b.id, "derives_from");
        assert!(graph.find_node(a.id).is_some());
        assert!(graph.find_node(Uuid::nil()).is_none());

        // Edge accessors expose both directions of the same edge.
        let out: Vec<(Uuid, &str)> = graph.outgoing(a.id).collect();
        assert_eq!(out, vec![(b.id, "derives_from")]);
        let inc: Vec<(Uuid, &str)> = graph.incoming(b.id).collect();
        assert_eq!(inc, vec![(a.id, "derives_from")]);
        assert_eq!(graph.outgoing(b.id).count(), 0);
        assert_eq!(graph.incoming(a.id).count(), 0);
    }

    #[test]
    fn test_bfs_anchor_trace_not_found() {
        let graph = AnchorGraph::new();
        let unknown = make_ref(99, "unknown");
        let err = BfsAnchor.trace(&graph, &unknown, 5).unwrap_err();
        assert!(matches!(err, FoldError::AnchorNotFound(_)));
    }

    #[test]
    fn test_bfs_anchor_trace_chain() {
        let mut graph = AnchorGraph::new();
        let a = make_ref(1, "record");
        let b = make_ref(2, "source");
        let c = make_ref(3, "paper");
        graph.add_node(a.clone());
        graph.add_node(b.clone());
        graph.add_node(c.clone());
        graph.add_edge(a.id, b.id, "derives_from");
        graph.add_edge(b.id, c.id, "uses");
        let chain = BfsAnchor.trace(&graph, &a, 5).unwrap();
        // BFS order: nearest first, start excluded.
        let ids: Vec<Uuid> = chain.iter().map(|r| r.id).collect();
        assert_eq!(ids, vec![b.id, c.id]);
    }

    #[test]
    fn test_bfs_anchor_trace_max_depth() {
        let mut graph = AnchorGraph::new();
        let nodes: Vec<AnchorRef> = (1..=5).map(|i| make_ref(i, "node")).collect();
        for n in &nodes {
            graph.add_node(n.clone());
        }
        for i in 0..4 {
            graph.add_edge(nodes[i].id, nodes[i + 1].id, "next");
        }
        let chain = BfsAnchor.trace(&graph, &nodes[0], 1).unwrap();
        assert_eq!(chain.len(), 1);
        assert_eq!(chain[0].id, nodes[1].id);
    }

    #[test]
    fn test_bfs_anchor_trace_zero_depth_is_empty() {
        let mut graph = AnchorGraph::new();
        let a = make_ref(1, "record");
        let b = make_ref(2, "record");
        graph.add_node(a.clone());
        graph.add_node(b.clone());
        graph.add_edge(a.id, b.id, "next");
        assert!(BfsAnchor.trace(&graph, &a, 0).unwrap().is_empty());
        assert!(BfsAnchor.credit(&graph, &b, 0).unwrap().is_empty());
    }

    #[test]
    fn test_bfs_anchor_trace_cycle_terminates() {
        let mut graph = AnchorGraph::new();
        let a = make_ref(1, "record");
        let b = make_ref(2, "record");
        let c = make_ref(3, "record");
        graph.add_node(a.clone());
        graph.add_node(b.clone());
        graph.add_node(c.clone());
        graph.add_edge(a.id, b.id, "next");
        graph.add_edge(b.id, a.id, "back");
        graph.add_edge(b.id, c.id, "next");
        let chain = BfsAnchor.trace(&graph, &a, 10).unwrap();
        // The back-edge to `a` must not re-add the start or loop forever.
        let ids: Vec<Uuid> = chain.iter().map(|r| r.id).collect();
        assert_eq!(ids, vec![b.id, c.id]);
    }

    #[test]
    fn test_bfs_anchor_credit_not_found() {
        let graph = AnchorGraph::new();
        let unknown = make_ref(99, "unknown");
        let err = BfsAnchor.credit(&graph, &unknown, 5).unwrap_err();
        assert!(matches!(err, FoldError::AnchorNotFound(_)));
    }

    #[test]
    fn test_bfs_anchor_credit_basic() {
        let mut graph = AnchorGraph::new();
        let source = make_ref(1, "paper");
        let intermediate = make_ref(2, "record");
        let outcome = make_ref(3, "composition");
        graph.add_node(source.clone());
        graph.add_node(intermediate.clone());
        graph.add_node(outcome.clone());
        graph.add_edge(source.id, intermediate.id, "uses");
        graph.add_edge(intermediate.id, outcome.id, "derives_from");
        let credits = BfsAnchor.credit(&graph, &outcome, 5).unwrap();
        // Weight halves per hop upstream; the outcome itself is excluded.
        assert_eq!(
            credits,
            vec![(intermediate.clone(), 0.5f32), (source.clone(), 0.25f32)]
        );
    }

    #[test]
    fn test_bfs_anchor_credit_max_depth() {
        let mut graph = AnchorGraph::new();
        let source = make_ref(1, "paper");
        let intermediate = make_ref(2, "record");
        let outcome = make_ref(3, "composition");
        graph.add_node(source.clone());
        graph.add_node(intermediate.clone());
        graph.add_node(outcome.clone());
        graph.add_edge(source.id, intermediate.id, "uses");
        graph.add_edge(intermediate.id, outcome.id, "derives_from");
        let credits = BfsAnchor.credit(&graph, &outcome, 1).unwrap();
        assert_eq!(credits, vec![(intermediate.clone(), 0.5f32)]);
    }
}