1use crate::types::{EdgeWeight, NodeId, NodeWeight};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13#[derive(Clone, Debug, Serialize, Deserialize)]
20pub struct Graph {
21 num_nodes: u32,
23 num_edges: u32,
25 adjacency: Vec<Vec<NodeId>>,
27 node_weights: Option<Vec<NodeWeight>>,
29 edge_weights: Option<HashMap<(NodeId, NodeId), EdgeWeight>>,
31 is_directed: bool,
33}
34
35impl Graph {
36 pub fn new(num_nodes: u32, is_directed: bool) -> Self {
38 Graph {
39 num_nodes,
40 num_edges: 0,
41 adjacency: vec![Vec::new(); num_nodes as usize],
42 node_weights: None,
43 edge_weights: None,
44 is_directed,
45 }
46 }
47
48 #[inline]
50 pub fn num_nodes(&self) -> u32 {
51 self.num_nodes
52 }
53
54 #[inline]
56 pub fn num_edges(&self) -> u32 {
57 self.num_edges
58 }
59
60 #[inline]
62 pub fn is_directed(&self) -> bool {
63 self.is_directed
64 }
65
66 pub fn add_edge(&mut self, from: NodeId, to: NodeId, weight: Option<EdgeWeight>) {
73 assert!(from.0 < self.num_nodes, "Node {} out of bounds", from.0);
74 assert!(to.0 < self.num_nodes, "Node {} out of bounds", to.0);
75
76 if !self.adjacency[from.as_usize()].contains(&to) {
78 self.adjacency[from.as_usize()].push(to);
79 self.num_edges += 1;
80 }
81
82 if let Some(w) = weight {
83 if self.edge_weights.is_none() {
84 self.edge_weights = Some(HashMap::new());
85 }
86 self.edge_weights
87 .as_mut()
88 .unwrap()
89 .insert((from, to), w);
90 }
91
92 if !self.is_directed && from != to {
94 if !self.adjacency[to.as_usize()].contains(&from) {
95 self.adjacency[to.as_usize()].push(from);
96 }
98 if let Some(w) = weight {
99 self.edge_weights
100 .as_mut()
101 .unwrap()
102 .insert((to, from), w);
103 }
104 }
105 }
106
107 #[inline]
109 pub fn neighbors(&self, node: NodeId) -> &[NodeId] {
110 &self.adjacency[node.as_usize()]
111 }
112
113 pub fn edge_weight(&self, from: NodeId, to: NodeId) -> Option<EdgeWeight> {
115 self.edge_weights
116 .as_ref()
117 .and_then(|weights| weights.get(&(from, to)).copied())
118 .or(Some(EdgeWeight::default()))
119 }
120
121 pub fn set_node_weights(&mut self, weights: Vec<NodeWeight>) {
123 assert_eq!(
124 weights.len() as u32, self.num_nodes,
125 "Node weights length must match num_nodes"
126 );
127 self.node_weights = Some(weights);
128 }
129
130 pub fn node_weight(&self, node: NodeId) -> Option<NodeWeight> {
132 self.node_weights
133 .as_ref()
134 .map(|weights| weights[node.as_usize()])
135 }
136
137 pub fn nodes(&self) -> impl Iterator<Item = NodeId> {
139 (0..self.num_nodes).map(NodeId::new)
140 }
141
142 pub fn edges(&self) -> impl Iterator<Item = (NodeId, NodeId)> + '_ {
144 self.adjacency
145 .iter()
146 .enumerate()
147 .flat_map(|(from_idx, neighbors)| {
148 let from = NodeId(from_idx as u32);
149 neighbors.iter().map(move |&to| (from, to))
150 })
151 }
152}
153
154#[derive(Default)]
156pub struct GraphBuilder {
157 num_nodes: u32,
158 is_directed: bool,
159 edges: Vec<(NodeId, NodeId, Option<EdgeWeight>)>,
160}
161
162impl GraphBuilder {
163 pub fn new(num_nodes: u32) -> Self {
165 GraphBuilder {
166 num_nodes,
167 is_directed: false,
168 edges: Vec::new(),
169 }
170 }
171
172 pub fn directed(mut self, directed: bool) -> Self {
174 self.is_directed = directed;
175 self
176 }
177
178 pub fn add_edge(mut self, from: NodeId, to: NodeId) -> Self {
180 self.edges.push((from, to, None));
181 self
182 }
183
184 pub fn add_weighted_edge(mut self, from: NodeId, to: NodeId, weight: f64) -> Self {
186 self.edges.push((from, to, Some(EdgeWeight::new(weight))));
187 self
188 }
189
190 pub fn build(self) -> Graph {
192 let mut graph = Graph::new(self.num_nodes, self.is_directed);
193 for (from, to, weight) in self.edges {
194 graph.add_edge(from, to, weight);
195 }
196 graph
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// A new graph reports its node count, has no edges and keeps its
    /// direction flag.
    #[test]
    fn test_graph_creation() {
        let g = Graph::new(5, true);

        assert_eq!(g.num_nodes(), 5);
        assert_eq!(g.num_edges(), 0);
        assert!(g.is_directed());
    }

    /// Adding one directed edge bumps the count and records the neighbor.
    #[test]
    fn test_add_edge() {
        let (src, dst) = (NodeId(0), NodeId(1));
        let mut g = Graph::new(5, true);

        g.add_edge(src, dst, None);

        assert_eq!(g.num_edges(), 1);
        assert!(g.neighbors(src).contains(&dst));
    }

    /// An undirected edge is counted once but visible from both endpoints.
    #[test]
    fn test_undirected_graph() {
        let (a, b) = (NodeId(0), NodeId(1));
        let mut g = Graph::new(5, false);

        g.add_edge(a, b, None);

        assert_eq!(g.num_edges(), 1);
        assert!(g.neighbors(a).contains(&b));
        assert!(g.neighbors(b).contains(&a));
    }

    /// The builder forwards node count, direction, edges and weights.
    #[test]
    fn test_builder_pattern() {
        let builder = GraphBuilder::new(4)
            .directed(true)
            .add_edge(NodeId(0), NodeId(1))
            .add_weighted_edge(NodeId(1), NodeId(2), 2.5);
        let g = builder.build();

        assert_eq!(g.num_nodes(), 4);
        assert_eq!(g.num_edges(), 2);
        assert_eq!(g.edge_weight(NodeId(1), NodeId(2)), Some(EdgeWeight(2.5)));
    }

    /// Node weights set in bulk are readable per node.
    #[test]
    fn test_node_weights() {
        let weights = vec![NodeWeight(1.0), NodeWeight(2.0), NodeWeight(3.0)];
        let mut g = Graph::new(3, false);

        g.set_node_weights(weights);

        assert_eq!(g.node_weight(NodeId(0)), Some(NodeWeight(1.0)));
        assert_eq!(g.node_weight(NodeId(1)), Some(NodeWeight(2.0)));
    }
}