1use std::collections::hash_map::Entry;
2use std::collections::HashMap;
3
4use derivative::Derivative;
5
6use crate::node::{Node, Vec3};
7use crate::union_find::UnionFind;
8
9#[derive(Derivative)]
10#[derivative(Clone, PartialEq, Eq, Hash)]
11pub struct Edge {
12 index: usize,
13 pub source: usize,
14 pub destination: usize,
15 #[derivative(PartialEq = "ignore")]
16 #[derivative(Hash = "ignore")]
17 pub weight: f32,
18}
19
20impl Edge {
21 pub fn from(index: usize, source: usize, destination: usize, weight: f32) -> Edge {
22 return Edge {
23 index,
24 source,
25 destination,
26 weight,
27 };
28 }
29}
30
31pub struct Graph {
32 pub edges_lookup: HashMap<usize, Edge>,
33 pub nodes_lookup: HashMap<usize, Node>,
34 pub node_position_lookup: Option<HashMap<usize, Vec3>>,
35 pub edges: Vec<Edge>,
36 pub node_count: usize,
37}
38
39impl Graph {
40 pub fn from(edges: Vec<Edge>) -> Graph {
41 let mut nodes: HashMap<usize, Node> = HashMap::new();
42 let edge_map = edges.iter().map(|edge| {
43 match nodes.entry(edge.source) {
44 Entry::Vacant(entry) => { entry.insert(Node::from(edge.source, vec![edge.clone()])); }
45 Entry::Occupied(mut entry) => { entry.get_mut().edges.push(edge.clone()); }
46 }
47
48 match nodes.entry(edge.destination) {
49 Entry::Vacant(entry) => { entry.insert(Node::from(edge.destination, vec![])); }
50 Entry::Occupied(mut _entry) => {}
51 }
52
53 return (edge.index, edge.clone());
54 }).collect();
55
56 let node_size: usize = nodes.keys().len();
57
58 Graph {
59 nodes_lookup: nodes,
60 edges_lookup: edge_map,
61 node_position_lookup: None,
62 edges,
63 node_count: node_size,
64 }
65 }
66
67 pub fn from_adjacency_matrix(matrix: &[&[f32]]) -> Graph {
68 let mut vec: Vec<Edge> = Vec::new();
69 for (row, array) in matrix.iter().enumerate() {
70 for (col, weight) in array.iter().enumerate() {
71 if !weight.eq(&(0.0 as f32)) {
72 vec.push(Edge::from(row * array.len() + col, row, col, weight.clone()));
73 }
74 }
75 }
76
77 return Graph::from(vec);
78 }
79
80 pub fn sorted_by_weight_asc(&self) -> Vec<Edge> {
81 let mut sorted_edges = self.edges.clone();
82 sorted_edges.sort_by(|edge1, edge2|
83 edge1.weight.total_cmp(&edge2.weight));
84 return sorted_edges;
85 }
86
87 pub fn offer_positions(&mut self, node_positions: HashMap<usize, Vec3>) {
88 self.node_position_lookup = Some(node_positions);
89 }
90
91 pub fn verify_positions(&self) {
92 return match &self.node_position_lookup {
93 None => panic!("You must offer node positions to the graph before using this\
94 heuristic. Make sure to provide a Vec3 for every node id."),
95 _ => {}
96 };
97 }
98
99 pub fn position_is_set(&self) -> bool {
100 return self.node_position_lookup.is_some();
101 }
102
103 pub fn get_position(&self, node_id: &usize) -> &Vec3 {
104 match &self.node_position_lookup {
105 None => panic!("You must offer node positions to the graph before using this heuristic."),
106 Some(positions) => {
107 return match positions.get(node_id) {
108 None => panic!("Node position missing for given node id: {node_id}"),
109 Some(position) => position
110 };
111 }
112 };
113 }
114}
115
116pub fn minimum_spanning(graph: Graph) -> Graph {
117 let edges = graph.sorted_by_weight_asc();
118 let mut union_find = UnionFind::from(graph.node_count);
119 let mut min_edges = Vec::new();
120
121 for edge in edges {
122 if !union_find.connected(edge.source, edge.destination) {
123 union_find.unify(edge.source, edge.destination);
124 min_edges.push(edge);
125 }
126 }
127
128 return Graph::from(min_edges);
129}
130
131
#[test]
fn mst_should_return_graph() {
    // A lone edge is already its own minimum spanning tree.
    let only_edge = Edge::from(0, 0, 1, 0.5);
    let tree = minimum_spanning(Graph::from(vec![only_edge]));

    assert_eq!(1, tree.edges_lookup.len());
    assert_eq!(2, tree.nodes_lookup.len());
}
141
#[test]
fn mst_should_return_graph_with_source_node_having_one_edge() {
    let tree = minimum_spanning(Graph::from(vec![Edge::from(0, 0, 1, 0.5)]));

    // Only the source node owns the edge; the destination exists without edges.
    let source = tree.nodes_lookup.get(&0).unwrap();
    assert_eq!(1, source.edges.len());
    assert!(tree.nodes_lookup.contains_key(&0));
    assert!(tree.nodes_lookup.contains_key(&1));
}
153
#[test]
fn mst_should_return_minimum_spanning_tree() {
    let edges = vec![
        Edge::from(0, 1, 2, 0.0),
        Edge::from(1, 2, 3, 0.1428571429),
        Edge::from(2, 1, 0, 0.2857142857),
        Edge::from(3, 3, 4, 0.2857142857),
        Edge::from(4, 1, 3, 0.4285714286),
        Edge::from(5, 0, 3, 0.8571428571),
        Edge::from(6, 0, 4, 1.0),
    ];

    let min_graph = minimum_spanning(Graph::from(edges));

    // Kruskal keeps 1-2, 2-3, 1-0 and 3-4: one edge fewer than the five nodes.
    assert_eq!(4, min_graph.edges.len());

    let total_cost: f32 = min_graph.edges.iter().map(|edge| edge.weight).sum();

    // An f32 sum depends on rounding, so compare with a tolerance rather than
    // relying on bit-exact equality.
    let expected: f32 = 0.714_285_7;
    assert!(
        (total_cost - expected).abs() < 1e-5,
        "total cost {total_cost} differs from expected {expected}"
    );
}
175
#[test]
fn edge_from_should_construct_edge() {
    let Edge {
        index,
        source,
        destination,
        weight,
    } = Edge::from(0, 2, 3, 0.5);

    assert_eq!(0, index);
    assert_eq!(2, source);
    assert_eq!(3, destination);
    assert_eq!(0.5, weight);
}
185
#[test]
fn sorted_by_weight_asc_should_return_sorted_vec() {
    // Supplied deliberately out of weight order.
    let graph = Graph::from(vec![
        Edge::from(0, 2, 3, 0.5),
        Edge::from(1, 2, 3, 0.2),
        Edge::from(2, 2, 3, 0.3),
        Edge::from(3, 2, 3, 0.7),
    ]);

    let weights: Vec<f32> = graph
        .sorted_by_weight_asc()
        .iter()
        .map(|edge| edge.weight)
        .collect();

    assert_eq!(vec![0.2, 0.3, 0.5, 0.7], weights);
}
201
#[test]
fn create_graph_from_adjacency_matrix() {
    // Symmetric 9-node matrix: 14 undirected connections, each stored once
    // per direction.
    let matrix: &[&[f32]] = &[
        &[0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 8.0, 0.0],
        &[4.0, 0.0, 8.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0],
        &[0.0, 8.0, 0.0, 7.0, 0.0, 4.0, 0.0, 0.0, 2.0],
        &[0.0, 0.0, 7.0, 0.0, 9.0, 14.0, 0.0, 0.0, 0.0],
        &[0.0, 0.0, 0.0, 9.0, 0.0, 10.0, 0.0, 0.0, 0.0],
        &[0.0, 0.0, 4.0, 14.0, 10.0, 0.0, 2.0, 0.0, 0.0],
        &[0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 1.0, 6.0],
        &[8.0, 11.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 7.0],
        &[0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 6.0, 7.0, 0.0],
    ];

    let graph = Graph::from_adjacency_matrix(matrix);
    assert_eq!(28, graph.edges.len());

    let first = graph.nodes_lookup.get(&0).unwrap();
    assert_eq!(2, first.edges.len());

    let last = graph.nodes_lookup.get(&8).unwrap();
    assert_eq!(3, last.edges.len());
    // Row 8's first non-zero cell is column 2, weight 2.0.
    assert_eq!(2.0, last.edges[0].weight);
}
223
#[test]
fn create_initial_graph_should_not_have_node_positions() {
    let graph = Graph::from(vec![Edge::from(0, 2, 3, 0.5)]);

    // Positions are only present after `offer_positions`.
    assert!(matches!(graph.node_position_lookup, None));
}
231
#[test]
fn offer_node_positions_should_set_node_positions() {
    let edge = Edge::from(0, 2, 3, 0.5);
    let mut graph = Graph::from(vec![edge.clone()]);

    let node_positions: HashMap<usize, Vec3> = HashMap::from([
        (edge.source, Vec3::from(0.3, 0.2, 0.0)),
        (edge.destination, Vec3::from(0.1, 0.5, 0.0)),
    ]);
    graph.offer_positions(node_positions);

    assert!(graph.node_position_lookup.is_some());

    let lookup = graph.node_position_lookup.unwrap();
    assert_eq!(0.3, lookup[&edge.source].x);
    assert_eq!(0.1, lookup[&edge.destination].x);
}