use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
use crate::Triple;
use crate::storage::Hexastore;
/// Outcome of a successful path search.
///
/// Derives `PartialEq` so results can be compared directly (for example with
/// `assert_eq!`). `Eq` is not derivable because `cost` is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct PathResult {
    /// Node ids along the path, in order from the start node to the end node.
    pub path: Vec<u64>,
    /// Accumulated cost of the path.
    pub cost: f64,
    /// Number of edges traversed, i.e. `path.len() - 1`.
    pub hops: usize,
}
/// Priority-queue entry for Dijkstra's algorithm: a node together with the path
/// cost that was known for it when the entry was pushed.
///
/// Equality and ordering look at `cost` only. The ordering is reversed so that
/// `BinaryHeap`, which is a max-heap, pops the cheapest entry first.
#[derive(Debug, Clone)]
struct DijkstraEntry {
    node_id: u64,
    cost: f64,
}

impl PartialEq for DijkstraEntry {
    /// Defined through `cmp` so that `==` always agrees with the ordering, as the
    /// `Eq`/`Ord` contracts require (a plain `f64 ==` is not reflexive for NaN).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DijkstraEntry {}

impl PartialOrd for DijkstraEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DijkstraEntry {
    /// Reverse comparison on `cost` (lower cost == greater priority).
    ///
    /// `f64::total_cmp` gives a genuine total order, so NaN costs cannot corrupt
    /// the heap; a positive NaN ranks as more expensive than every other cost.
    fn cmp(&self, other: &Self) -> Ordering {
        other.cost.total_cmp(&self.cost)
    }
}
pub fn bfs_shortest_path(
store: &dyn Hexastore,
start_id: u64,
end_id: u64,
predicate_id: Option<u64>,
max_hops: usize,
bidirectional: bool,
) -> Option<PathResult> {
if start_id == end_id {
return Some(PathResult {
path: vec![start_id],
cost: 0.0,
hops: 0,
});
}
let mut visited: HashSet<u64> = HashSet::new();
let mut queue: VecDeque<(u64, Vec<u64>)> = VecDeque::new();
visited.insert(start_id);
queue.push_back((start_id, vec![start_id]));
while let Some((current, path)) = queue.pop_front() {
if path.len() > max_hops {
continue;
}
let neighbors = get_neighbors(store, current, predicate_id, bidirectional);
for neighbor_id in neighbors {
if neighbor_id == end_id {
let mut final_path = path.clone();
final_path.push(neighbor_id);
return Some(PathResult {
hops: final_path.len() - 1,
cost: (final_path.len() - 1) as f64,
path: final_path,
});
}
if !visited.contains(&neighbor_id) {
visited.insert(neighbor_id);
let mut new_path = path.clone();
new_path.push(neighbor_id);
queue.push_back((neighbor_id, new_path));
}
}
}
None
}
/// Finds a minimum-cost path from `start_id` to `end_id` with Dijkstra's
/// algorithm, following outgoing edges only (subject to object).
///
/// * `predicate_id` - when `Some`, only edges with that predicate are followed.
/// * `weight_fn` - called as `weight_fn(subject_id, predicate_id, object_id)`
///   to obtain the cost of an edge. Weights are expected to be non-negative,
///   as Dijkstra's algorithm requires.
/// * `max_hops` - a node is not expanded once the hop count recorded for it has
///   reached this value.
///
/// Returns `None` if `end_id` is not reached. When `start_id == end_id` a
/// single-node path with zero cost is returned.
///
/// Note: the hop count kept for a node belongs to its currently cheapest path,
/// so the hop limit is applied along cheapest paths only. A costlier route with
/// fewer hops is not explored as an alternative.
pub fn dijkstra_shortest_path<F>(
    store: &dyn Hexastore,
    start_id: u64,
    end_id: u64,
    predicate_id: Option<u64>,
    weight_fn: F,
    max_hops: usize,
) -> Option<PathResult>
where
    F: Fn(u64, u64, u64) -> f64,
{
    if start_id == end_id {
        return Some(PathResult {
            path: vec![start_id],
            cost: 0.0,
            hops: 0,
        });
    }

    let mut distances: HashMap<u64, f64> = HashMap::new();
    let mut predecessors: HashMap<u64, u64> = HashMap::new();
    let mut hops_count: HashMap<u64, usize> = HashMap::new();
    let mut heap: BinaryHeap<DijkstraEntry> = BinaryHeap::new();

    distances.insert(start_id, 0.0);
    hops_count.insert(start_id, 0);
    heap.push(DijkstraEntry {
        node_id: start_id,
        cost: 0.0,
    });

    while let Some(DijkstraEntry { node_id, cost }) = heap.pop() {
        // Skip stale entries: a cheaper path to this node was recorded after
        // this entry had been pushed.
        if distances.get(&node_id).is_some_and(|&best| cost > best) {
            continue;
        }

        // Check for the target before applying the hop limit, so that a target
        // reached at exactly `max_hops` is returned immediately instead of only
        // after the rest of the heap has been drained.
        if node_id == end_id {
            return Some(reconstruct_path(&predecessors, start_id, end_id, cost));
        }

        let current_hops = hops_count.get(&node_id).copied().unwrap_or(0);
        if current_hops >= max_hops {
            continue;
        }

        for triple in get_outgoing_triples(store, node_id, predicate_id) {
            let neighbor_id = triple.object_id;
            let edge_weight = weight_fn(triple.subject_id, triple.predicate_id, triple.object_id);
            let new_cost = cost + edge_weight;

            let is_better = distances
                .get(&neighbor_id)
                .is_none_or(|&current| new_cost < current);
            if is_better {
                distances.insert(neighbor_id, new_cost);
                predecessors.insert(neighbor_id, node_id);
                hops_count.insert(neighbor_id, current_hops + 1);
                heap.push(DijkstraEntry {
                    node_id: neighbor_id,
                    cost: new_cost,
                });
            }
        }
    }

    // Every distance recorded for `end_id` also pushed a heap entry with that
    // cost, which the loop above would have popped and returned. Reaching this
    // point therefore means the target was never discovered.
    None
}
/// Lists the ids of nodes adjacent to `node_id`.
///
/// The objects of all outgoing triples come first. When `bidirectional` is set,
/// the subjects of all incoming triples follow. `predicate_id`, if given,
/// restricts both queries. Duplicate ids are not removed.
fn get_neighbors(
    store: &dyn Hexastore,
    node_id: u64,
    predicate_id: Option<u64>,
    bidirectional: bool,
) -> Vec<u64> {
    let mut adjacent: Vec<u64> = store
        .query(Some(node_id), predicate_id, None)
        .map(|edge| edge.object_id)
        .collect();

    if bidirectional {
        adjacent.extend(
            store
                .query(None, predicate_id, Some(node_id))
                .map(|edge| edge.subject_id),
        );
    }

    adjacent
}
/// Collects every triple whose subject is `node_id`, optionally restricted to
/// `predicate_id`, into a vector.
fn get_outgoing_triples(
    store: &dyn Hexastore,
    node_id: u64,
    predicate_id: Option<u64>,
) -> Vec<Triple> {
    let outgoing = store.query(Some(node_id), predicate_id, None);
    outgoing.collect::<Vec<Triple>>()
}
/// Rebuilds the path ending at `end_id` by following `predecessors` backwards
/// until `start_id` is reached, then returns it in start-to-end order together
/// with the supplied `cost`.
///
/// If the predecessor chain stops before `start_id` is reached, the partial
/// path collected so far is returned. The walk is also capped at
/// `predecessors.len()` steps, the longest possible cycle-free chain, so a
/// cyclic predecessor map cannot make this loop forever.
fn reconstruct_path(
    predecessors: &HashMap<u64, u64>,
    start_id: u64,
    end_id: u64,
    cost: f64,
) -> PathResult {
    let mut path = vec![end_id];
    let mut current = end_id;

    // Each step of a valid chain consumes a distinct key of `predecessors`, so
    // a longer walk can only mean the map contains a cycle.
    while current != start_id && path.len() <= predecessors.len() {
        let Some(&prev) = predecessors.get(&current) else {
            break;
        };
        path.push(prev);
        current = prev;
    }

    path.reverse();
    PathResult {
        hops: path.len() - 1,
        cost,
        path,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::memory::MemoryHexastore;

    /// Builds the shared fixture graph:
    ///
    /// ```text
    /// 1 --10--> 2 --10--> 3
    /// 1 --11--> 4 --10--> 3
    /// ```
    fn create_test_graph() -> MemoryHexastore {
        let mut store = MemoryHexastore::new();
        for (subject, predicate, object) in [(1, 10, 2), (2, 10, 3), (1, 11, 4), (4, 10, 3)] {
            store
                .insert(&Triple::new(subject, predicate, object))
                .unwrap();
        }
        store
    }

    #[test]
    fn test_bfs_direct_path() {
        let store = create_test_graph();

        let found = bfs_shortest_path(&store, 1, 2, None, 10, false)
            .expect("nodes 1 and 2 are directly connected");

        assert_eq!(found.path, vec![1, 2]);
        assert_eq!(found.hops, 1);
    }

    #[test]
    fn test_bfs_two_hop_path() {
        let store = create_test_graph();

        let found = bfs_shortest_path(&store, 1, 3, None, 10, false)
            .expect("node 3 is reachable from node 1");

        assert_eq!(found.hops, 2);
        // Either intermediate node (2 or 4) is acceptable, so only the
        // endpoints and the length are pinned.
        assert_eq!(found.path.len(), 3);
        assert_eq!(found.path[0], 1);
        assert_eq!(found.path[2], 3);
    }

    #[test]
    fn test_bfs_no_path() {
        let store = create_test_graph();

        // Edges are directed, so nothing leads from 3 back to 1.
        let outcome = bfs_shortest_path(&store, 3, 1, None, 10, false);

        assert!(outcome.is_none());
    }

    #[test]
    fn test_bfs_max_hops_limit() {
        let store = create_test_graph();

        // Node 3 is two hops away; a limit of one hop must not find it.
        let outcome = bfs_shortest_path(&store, 1, 3, None, 1, false);

        assert!(outcome.is_none());
    }

    #[test]
    fn test_bfs_same_node() {
        let store = create_test_graph();

        let found = bfs_shortest_path(&store, 1, 1, None, 10, false)
            .expect("a node always reaches itself");

        assert_eq!(found.path, vec![1]);
        assert_eq!(found.hops, 0);
    }

    #[test]
    fn test_dijkstra_uniform_weights() {
        let store = create_test_graph();

        let found = dijkstra_shortest_path(&store, 1, 3, None, |_, _, _| 1.0, 10)
            .expect("node 3 is reachable from node 1");

        assert_eq!(found.cost, 2.0);
        assert_eq!(found.hops, 2);
    }
}