rustworkx_core/generators/
star_graph.rs

1// Licensed under the Apache License, Version 2.0 (the "License"); you may
2// not use this file except in compliance with the License. You may obtain
3// a copy of the License at
4//
5//     http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10// License for the specific language governing permissions and limitations
11// under the License.
12
13use petgraph::data::{Build, Create};
14use petgraph::visit::{Data, NodeIndexable};
15
16use super::utils::get_num_nodes;
17use super::InvalidInputError;
18
19/// Generate a star graph
20///
21/// Arguments:
22///
23/// * `num_nodes` - The number of nodes to create a star graph for. Either this or
24///   `weights` must be specified. If both this and `weights` are specified, weights
25///   will take priority and this argument will be ignored
26/// * `weights` - A `Vec` of node weight objects.
27/// * `default_node_weight` - A callable that will return the weight to use
28///   for newly created nodes. This is ignored if `weights` is specified.
29/// * `default_edge_weight` - A callable that will return the weight object
30///   to use for newly created edges.
31/// * `inward` - If set `true` the nodes will be directed towards the
32///   center node. This parameter is ignored if `bidirectional` is set to
33///   `true`.
34/// * `bidirectional` - Whether edges are added bidirectionally. If set to
35///   `true` then for any edge `(u, v)` an edge `(v, u)` will also be added.
36///   If the graph is undirected this will result in a parallel edge.
37///
38/// # Example
39/// ```rust
40/// use rustworkx_core::petgraph;
41/// use rustworkx_core::generators::star_graph;
42/// use rustworkx_core::petgraph::visit::EdgeRef;
43///
44/// let g: petgraph::graph::UnGraph<(), ()> = star_graph(
45///     Some(4),
46///     None,
47///     || {()},
48///     || {()},
49///     false,
50///     false
51/// ).unwrap();
52/// assert_eq!(
53///     vec![(0, 1), (0, 2), (0, 3)],
54///     g.edge_references()
55///         .map(|edge| (edge.source().index(), edge.target().index()))
56///         .collect::<Vec<(usize, usize)>>(),
57/// )
58/// ```
59pub fn star_graph<G, T, F, H, M>(
60    num_nodes: Option<usize>,
61    weights: Option<Vec<T>>,
62    mut default_node_weight: F,
63    mut default_edge_weight: H,
64    inward: bool,
65    bidirectional: bool,
66) -> Result<G, InvalidInputError>
67where
68    G: Build + Create + Data<NodeWeight = T, EdgeWeight = M> + NodeIndexable,
69    F: FnMut() -> T,
70    H: FnMut() -> M,
71{
72    if weights.is_none() && num_nodes.is_none() {
73        return Err(InvalidInputError {});
74    }
75    let node_len = get_num_nodes(&num_nodes, &weights);
76    let num_edges = if bidirectional {
77        2 * node_len
78    } else {
79        node_len
80    };
81    let mut graph = G::with_capacity(node_len, num_edges);
82    if node_len == 0 {
83        return Ok(graph);
84    }
85
86    match weights {
87        Some(weights) => {
88            for weight in weights {
89                graph.add_node(weight);
90            }
91        }
92        None => {
93            for _ in 0..node_len {
94                graph.add_node(default_node_weight());
95            }
96        }
97    };
98    let zero_index = graph.from_index(0);
99    for a in 1..node_len {
100        let node = graph.from_index(a);
101        if bidirectional {
102            graph.add_edge(node, zero_index, default_edge_weight());
103            graph.add_edge(zero_index, node, default_edge_weight());
104        } else if inward {
105            graph.add_edge(node, zero_index, default_edge_weight());
106        } else {
107            graph.add_edge(zero_index, node, default_edge_weight());
108        }
109    }
110    Ok(graph)
111}
112
#[cfg(test)]
mod tests {
    use crate::generators::star_graph;
    use crate::generators::InvalidInputError;
    use crate::petgraph;
    use crate::petgraph::visit::EdgeRef;

    /// Collect the `(source, target)` index pairs of `g` in edge-index order.
    fn edge_list<N, E, Ty: petgraph::EdgeType>(
        g: &petgraph::graph::Graph<N, E, Ty>,
    ) -> Vec<(usize, usize)> {
        g.edge_references()
            .map(|edge| (edge.source().index(), edge.target().index()))
            .collect()
    }

    #[test]
    fn test_with_weights() {
        let g: petgraph::graph::UnGraph<usize, ()> =
            star_graph(None, Some(vec![0, 1, 2, 3]), || 4, || (), false, false).unwrap();
        assert_eq!(vec![(0, 1), (0, 2), (0, 3)], edge_list(&g));
        assert_eq!(
            vec![0, 1, 2, 3],
            g.node_weights().copied().collect::<Vec<usize>>(),
        );
    }

    #[test]
    fn test_with_weights_inward() {
        let g: petgraph::graph::UnGraph<usize, ()> =
            star_graph(None, Some(vec![0, 1, 2, 3]), || 4, || (), true, false).unwrap();
        assert_eq!(vec![(1, 0), (2, 0), (3, 0)], edge_list(&g));
        assert_eq!(
            vec![0, 1, 2, 3],
            g.node_weights().copied().collect::<Vec<usize>>(),
        );
    }

    #[test]
    fn test_directed_outward() {
        let g: petgraph::graph::DiGraph<(), ()> =
            star_graph(Some(4), None, || (), || (), false, false).unwrap();
        assert_eq!(vec![(0, 1), (0, 2), (0, 3)], edge_list(&g));
    }

    #[test]
    fn test_directed_inward() {
        let g: petgraph::graph::DiGraph<(), ()> =
            star_graph(Some(4), None, || (), || (), true, false).unwrap();
        assert_eq!(vec![(1, 0), (2, 0), (3, 0)], edge_list(&g));
    }

    #[test]
    fn test_bidirectional() {
        let g: petgraph::graph::DiGraph<(), ()> =
            star_graph(Some(4), None, || (), || (), false, true).unwrap();
        assert_eq!(
            vec![(1, 0), (0, 1), (2, 0), (0, 2), (3, 0), (0, 3)],
            edge_list(&g),
        );
    }

    #[test]
    fn test_bidirectional_ignores_inward() {
        let g: petgraph::graph::DiGraph<(), ()> =
            star_graph(Some(3), None, || (), || (), true, true).unwrap();
        assert_eq!(vec![(1, 0), (0, 1), (2, 0), (0, 2)], edge_list(&g));
    }

    #[test]
    fn test_weights_take_priority_over_num_nodes() {
        let g: petgraph::graph::UnGraph<usize, ()> =
            star_graph(Some(10), Some(vec![7, 8, 9]), || 4, || (), false, false).unwrap();
        assert_eq!(3, g.node_count());
        assert_eq!(
            vec![7, 8, 9],
            g.node_weights().copied().collect::<Vec<usize>>(),
        );
        assert_eq!(vec![(0, 1), (0, 2)], edge_list(&g));
    }

    #[test]
    fn test_default_node_weights() {
        let g: petgraph::graph::UnGraph<usize, ()> =
            star_graph(Some(3), None, || 4, || (), false, false).unwrap();
        assert_eq!(
            vec![4, 4, 4],
            g.node_weights().copied().collect::<Vec<usize>>(),
        );
    }

    #[test]
    fn test_zero_nodes() {
        let g: petgraph::graph::DiGraph<(), ()> =
            star_graph(Some(0), None, || (), || (), false, false).unwrap();
        assert_eq!(0, g.node_count());
        assert_eq!(0, g.edge_count());
    }

    #[test]
    fn test_single_node() {
        let g: petgraph::graph::DiGraph<(), ()> =
            star_graph(Some(1), None, || (), || (), false, true).unwrap();
        assert_eq!(1, g.node_count());
        assert_eq!(0, g.edge_count());
    }

    #[test]
    fn test_error() {
        match star_graph::<petgraph::graph::DiGraph<(), ()>, (), _, _, ()>(
            None,
            None,
            || (),
            || (),
            false,
            false,
        ) {
            Ok(_) => panic!("Returned a non-error"),
            Err(e) => assert_eq!(e, InvalidInputError),
        };
    }
}