use std::collections::VecDeque;
use ahash::HashSet;
use mentedb_core::edge::EdgeType;
use mentedb_core::types::MemoryId;
use crate::csr::CsrGraph;
pub fn find_contradictions(graph: &CsrGraph, id: MemoryId) -> Vec<MemoryId> {
let mut contradictions = HashSet::default();
let mut visited = HashSet::default();
let mut queue = VecDeque::new();
visited.insert(id);
for (neighbor, edge) in graph.outgoing(id) {
if edge.edge_type == EdgeType::Contradicts {
contradictions.insert(neighbor);
visited.insert(neighbor);
}
}
for (neighbor, edge) in graph.incoming(id) {
if edge.edge_type == EdgeType::Contradicts {
contradictions.insert(neighbor);
visited.insert(neighbor);
}
}
queue.push_back(id);
while let Some(node) = queue.pop_front() {
for (supporter, edge) in graph.incoming(node) {
if edge.edge_type == EdgeType::Supports && visited.insert(supporter) {
for (target, e2) in graph.outgoing(supporter) {
if e2.edge_type == EdgeType::Contradicts && target != id {
contradictions.insert(target);
}
}
}
}
for (supported, edge) in graph.outgoing(node) {
if edge.edge_type == EdgeType::Supports && visited.insert(supported) {
for (target, e2) in graph.outgoing(supported) {
if e2.edge_type == EdgeType::Contradicts {
contradictions.insert(target);
}
}
}
}
}
contradictions.into_iter().collect()
}
pub fn find_superseded(graph: &CsrGraph, id: MemoryId) -> Vec<MemoryId> {
let mut result = Vec::new();
let mut visited = HashSet::default();
let mut queue = VecDeque::new();
visited.insert(id);
queue.push_back(id);
while let Some(node) = queue.pop_front() {
for (neighbor, edge) in graph.outgoing(node) {
if edge.edge_type == EdgeType::Supersedes && visited.insert(neighbor) {
result.push(neighbor);
queue.push_back(neighbor);
}
}
}
result
}
pub fn detect_cycles(graph: &CsrGraph, edge_types: &[EdgeType]) -> Vec<Vec<MemoryId>> {
let filter: HashSet<EdgeType> = edge_types.iter().copied().collect();
let mut cycles = Vec::new();
let mut globally_visited = HashSet::default();
for &start_id in graph.node_ids() {
if globally_visited.contains(&start_id) {
continue;
}
let mut stack: Vec<(MemoryId, Vec<MemoryId>)> = vec![(start_id, vec![start_id])];
let mut in_stack = HashSet::default();
in_stack.insert(start_id);
let mut local_visited = HashSet::default();
local_visited.insert(start_id);
while let Some((node, path)) = stack.pop() {
in_stack.clear();
for &p in &path {
in_stack.insert(p);
}
for (neighbor, edge) in graph.outgoing(node) {
if !filter.contains(&edge.edge_type) {
continue;
}
if in_stack.contains(&neighbor) {
if let Some(pos) = path.iter().position(|&n| n == neighbor) {
let cycle: Vec<MemoryId> = path[pos..].to_vec();
if !cycles.iter().any(|c: &Vec<MemoryId>| {
c.len() == cycle.len() && cycle.iter().all(|n| c.contains(n))
}) {
cycles.push(cycle);
}
}
} else if local_visited.insert(neighbor) {
let mut new_path = path.clone();
new_path.push(neighbor);
stack.push((neighbor, new_path));
}
}
globally_visited.insert(node);
}
}
cycles
}
#[cfg(test)]
mod tests {
use super::*;
use mentedb_core::edge::MemoryEdge;
fn make_edge(src: MemoryId, tgt: MemoryId, etype: EdgeType, weight: f32) -> MemoryEdge {
MemoryEdge {
source: src,
target: tgt,
edge_type: etype,
weight,
created_at: 1000,
valid_from: None,
valid_until: None,
label: None,
}
}
#[test]
fn test_direct_contradictions() {
let mut g = CsrGraph::new();
let a = MemoryId::new();
let b = MemoryId::new();
let c = MemoryId::new();
g.add_edge(&make_edge(a, b, EdgeType::Contradicts, 1.0));
g.add_edge(&make_edge(c, a, EdgeType::Contradicts, 1.0));
let contras = find_contradictions(&g, a);
assert!(contras.contains(&b));
assert!(contras.contains(&c));
}
#[test]
fn test_transitive_contradictions() {
let mut g = CsrGraph::new();
let a = MemoryId::new();
let b = MemoryId::new();
let c = MemoryId::new();
g.add_edge(&make_edge(b, a, EdgeType::Supports, 1.0));
g.add_edge(&make_edge(b, c, EdgeType::Contradicts, 1.0));
let contras = find_contradictions(&g, a);
assert!(contras.contains(&c));
}
#[test]
fn test_find_superseded() {
let mut g = CsrGraph::new();
let a = MemoryId::new();
let b = MemoryId::new();
let c = MemoryId::new();
g.add_edge(&make_edge(a, b, EdgeType::Supersedes, 1.0));
g.add_edge(&make_edge(b, c, EdgeType::Supersedes, 1.0));
let superseded = find_superseded(&g, a);
assert_eq!(superseded.len(), 2);
assert!(superseded.contains(&b));
assert!(superseded.contains(&c));
}
#[test]
fn test_detect_cycle() {
let mut g = CsrGraph::new();
let a = MemoryId::new();
let b = MemoryId::new();
let c = MemoryId::new();
g.add_edge(&make_edge(a, b, EdgeType::Caused, 1.0));
g.add_edge(&make_edge(b, c, EdgeType::Caused, 1.0));
g.add_edge(&make_edge(c, a, EdgeType::Caused, 1.0));
let cycles = detect_cycles(&g, &[EdgeType::Caused]);
assert!(!cycles.is_empty());
assert_eq!(cycles[0].len(), 3);
}
#[test]
fn test_no_cycle() {
let mut g = CsrGraph::new();
let a = MemoryId::new();
let b = MemoryId::new();
let c = MemoryId::new();
g.add_edge(&make_edge(a, b, EdgeType::Caused, 1.0));
g.add_edge(&make_edge(b, c, EdgeType::Caused, 1.0));
let cycles = detect_cycles(&g, &[EdgeType::Caused]);
assert!(cycles.is_empty());
}
}