use crate::{Graph, Vertex};
use std::collections::HashMap;
use std::marker::PhantomData;
use crate::error::{GraphError, ErrorKind};
use std::ops::Add;
use std::cmp::min;
/// All-pairs shortest-path table built by [`floid`].
///
/// Row `dist[i]` maps a destination vertex id to the weight of the lightest
/// path found from vertex id `i`; a missing key means no path was found.
#[derive(Debug, Clone)]
pub struct Distance<'a, T> {
    /// One row per source vertex id, keyed by destination vertex id.
    dist: Vec<HashMap<usize, T>>,
    /// Carries the `'a` lifetime parameter without storing borrowed data.
    phantom: PhantomData<&'a T>,
}
impl<'a, T> Distance<'a, T> {
    /// Returns the shortest known distance from `from` to `to`.
    ///
    /// Returns `None` when no path from `from` to `to` was recorded, and also
    /// when `from.id` lies outside the table (instead of panicking on an
    /// out-of-bounds index).
    pub fn get_distance(&self, from: &Vertex<T>, to: &Vertex<T>) -> Option<&T> {
        self.dist.get(from.id).and_then(|row| row.get(&to.id))
    }
}
/// Computes all-pairs shortest paths with the Floyd–Warshall algorithm.
///
/// Row `i` of the returned [`Distance`] maps every destination reachable from
/// vertex `i` to the weight of the lightest path found; unreachable pairs are
/// absent. `T::default()` is used as the zero distance: every vertex starts at
/// that distance from itself.
///
/// Parallel edges between the same pair of vertices are collapsed to the
/// lightest one, and a self-loop only replaces the empty path when it is
/// lighter than `T::default()`.
///
/// # Errors
///
/// Returns `GraphError::Regular(ErrorKind::ExistsCycleNegativeWeightInVertexId(v))`
/// as soon as some vertex `v` is found whose distance to itself drops below
/// `T::default()`, i.e. `v` lies on a negative-weight cycle.
pub fn floid<T>(graph: &Graph<T>) -> Result<Distance<T>, GraphError>
where
    T: Default + Copy + Ord + PartialEq + Add<Output = T>,
{
    let size = graph.size();
    let mut distance = Distance {
        dist: vec![HashMap::new(); size],
        phantom: PhantomData,
    };

    // Seed the table with the empty path (vertex to itself) and the direct edges.
    for (idx, vertex) in graph.adj.iter().enumerate() {
        let row = &mut distance.dist[idx];
        row.insert(idx, T::default());
        for edge in vertex.edges.iter() {
            // Keep the lightest candidate instead of letting the last edge win:
            // parallel edges and positive self-loops must not overwrite a
            // lighter value already in the table.
            row.entry(edge.to)
                .and_modify(|weight| *weight = min(*weight, edge.weight))
                .or_insert(edge.weight);
        }
    }

    // `mid` is the intermediate vertex allowed on paths `from -> mid -> to`;
    // it must be the outermost loop for Floyd–Warshall to be correct.
    for mid in 0..size {
        for from in 0..size {
            // Entries are never removed, and a new `from -> mid` entry could only
            // be created from an existing one, so a row without a path to `mid`
            // cannot be improved in this pass: skip its inner loop entirely.
            if !distance.dist[from].contains_key(&mid) {
                continue;
            }
            for to in 0..size {
                if let (Some(&first), Some(&second)) =
                    (distance.dist[from].get(&mid), distance.dist[mid].get(&to))
                {
                    let candidate = first + second;
                    let weight = distance.dist[from].entry(to).or_insert(candidate);
                    *weight = min(*weight, candidate);
                    // A vertex that can reach itself with negative total weight
                    // lies on a negative cycle; shortest paths are undefined.
                    if from == to && *weight < T::default() {
                        return Err(GraphError::Regular(
                            ErrorKind::ExistsCycleNegativeWeightInVertexId(from),
                        ));
                    }
                }
            }
        }
    }

    Ok(distance)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shortest paths may route through intermediate vertices, and every
    /// vertex is at distance zero from itself.
    #[test]
    fn test_floid() {
        let mut graph = Graph::new(4);
        graph.add_edge(1, 2, 1).unwrap();
        graph.add_edge(1, 3, 6).unwrap();
        graph.add_edge(2, 3, 4).unwrap();
        graph.add_edge(2, 4, 1).unwrap();
        graph.add_edge(4, 3, 1).unwrap();
        let dist = floid(&graph).unwrap();
        assert_eq!(*dist.get_distance(graph.get_vertex(2).unwrap(), graph.get_vertex(4).unwrap()).unwrap(), 1);
        // 1 -> 2 -> 4 -> 3 (weight 3) beats the direct edge 1 -> 3 (weight 6).
        assert_eq!(*dist.get_distance(graph.get_vertex(1).unwrap(), graph.get_vertex(3).unwrap()).unwrap(), 3);
        assert_eq!(*dist.get_distance(graph.get_vertex(3).unwrap(), graph.get_vertex(3).unwrap()).unwrap(), 0);
    }

    /// The cycle 1 -> 2 -> 3 -> 1 has total weight -11 and must be reported.
    #[test]
    fn test_floid_exists_cycle_negative_weight() {
        let mut graph = Graph::new(4);
        graph.add_edge(1, 2, 1).unwrap();
        graph.add_edge(2, 3, -5).unwrap();
        graph.add_edge(3, 1, -7).unwrap();
        // Assert the specific error instead of relying on `#[should_panic]`,
        // which would also pass if building the graph itself panicked.
        assert!(matches!(
            floid(&graph),
            Err(GraphError::Regular(ErrorKind::ExistsCycleNegativeWeightInVertexId(_)))
        ));
    }
}