rustworkx_core/generators/
star_graph.rs1use petgraph::data::{Build, Create};
14use petgraph::visit::{Data, NodeIndexable};
15
16use super::utils::get_num_nodes;
17use super::InvalidInputError;
18
19pub fn star_graph<G, T, F, H, M>(
60 num_nodes: Option<usize>,
61 weights: Option<Vec<T>>,
62 mut default_node_weight: F,
63 mut default_edge_weight: H,
64 inward: bool,
65 bidirectional: bool,
66) -> Result<G, InvalidInputError>
67where
68 G: Build + Create + Data<NodeWeight = T, EdgeWeight = M> + NodeIndexable,
69 F: FnMut() -> T,
70 H: FnMut() -> M,
71{
72 if weights.is_none() && num_nodes.is_none() {
73 return Err(InvalidInputError {});
74 }
75 let node_len = get_num_nodes(&num_nodes, &weights);
76 let num_edges = if bidirectional {
77 2 * node_len
78 } else {
79 node_len
80 };
81 let mut graph = G::with_capacity(node_len, num_edges);
82 if node_len == 0 {
83 return Ok(graph);
84 }
85
86 match weights {
87 Some(weights) => {
88 for weight in weights {
89 graph.add_node(weight);
90 }
91 }
92 None => {
93 for _ in 0..node_len {
94 graph.add_node(default_node_weight());
95 }
96 }
97 };
98 let zero_index = graph.from_index(0);
99 for a in 1..node_len {
100 let node = graph.from_index(a);
101 if bidirectional {
102 graph.add_edge(node, zero_index, default_edge_weight());
103 graph.add_edge(zero_index, node, default_edge_weight());
104 } else if inward {
105 graph.add_edge(node, zero_index, default_edge_weight());
106 } else {
107 graph.add_edge(zero_index, node, default_edge_weight());
108 }
109 }
110 Ok(graph)
111}
112
#[cfg(test)]
mod tests {
    use crate::generators::star_graph;
    use crate::generators::InvalidInputError;
    use crate::petgraph;
    use crate::petgraph::visit::EdgeRef;

    /// Collects the `(source, target)` node indices of every edge in `graph`,
    /// in the order the graph reports them.
    fn endpoints<N, E, Ty>(graph: &petgraph::graph::Graph<N, E, Ty>) -> Vec<(usize, usize)>
    where
        Ty: petgraph::EdgeType,
    {
        graph
            .edge_references()
            .map(|edge| (edge.source().index(), edge.target().index()))
            .collect()
    }

    // Explicit weights, default direction: edges run from the hub outward.
    #[test]
    fn test_with_weights() {
        let weights = vec![0, 1, 2, 3];
        let graph: petgraph::graph::UnGraph<usize, ()> =
            star_graph(None, Some(weights.clone()), || 4, || (), false, false).unwrap();
        assert_eq!(vec![(0, 1), (0, 2), (0, 3)], endpoints(&graph));
        let stored: Vec<usize> = graph.node_weights().copied().collect();
        assert_eq!(weights, stored);
    }

    // Explicit weights with `inward` set: edges run from each outer node to the hub.
    #[test]
    fn test_with_weights_inward() {
        let weights = vec![0, 1, 2, 3];
        let graph: petgraph::graph::UnGraph<usize, ()> =
            star_graph(None, Some(weights.clone()), || 4, || (), true, false).unwrap();
        assert_eq!(vec![(1, 0), (2, 0), (3, 0)], endpoints(&graph));
        let stored: Vec<usize> = graph.node_weights().copied().collect();
        assert_eq!(weights, stored);
    }

    // `bidirectional` adds the inward edge first, then the outward one.
    #[test]
    fn test_bidirectional() {
        let graph: petgraph::graph::DiGraph<(), ()> =
            star_graph(Some(4), None, || (), || (), false, true).unwrap();
        assert_eq!(
            vec![(1, 0), (0, 1), (2, 0), (0, 2), (3, 0), (0, 3)],
            endpoints(&graph),
        );
    }

    // Supplying neither a node count nor weights must be rejected.
    #[test]
    fn test_error() {
        let outcome = star_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
            None,
            None,
            || (),
            || (),
            false,
            false,
        );
        match outcome {
            Err(error) => assert_eq!(error, InvalidInputError),
            Ok(_) => panic!("Returned a non-error"),
        };
    }
}