spokes/
algorithms.rs

1//! Algorithms for use with networks.
2
3use std::{
4    cmp::Reverse,
5    collections::{HashMap, HashSet, VecDeque},
6    fmt::Debug,
7    hash::Hash,
8    io::Write,
9    mem,
10    ops::Add,
11};
12
13use num::{Unsigned, Zero};
14use priority_queue::PriorityQueue;
15
16use crate::{arc_storage::ArcStorage, search::Direction, ArcInfo, Network};
17
18mod preflowpush;
19pub use self::preflowpush::preflow_push;
20
/// Error returned when the underlying [`Network`] contains a cycle, so no ordering of its nodes
/// can respect every arc.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Cyclic;

impl std::fmt::Display for Cyclic {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("the network contains a cycle")
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and similar.
impl std::error::Error for Cyclic {}
24
/// Error returned when the underlying [`Network`] contains a cycle of negative total weight,
/// so shortest distances are unbounded.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct NegativeCycle;

impl std::fmt::Display for NegativeCycle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("the network contains a negative cycle")
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and similar.
impl std::error::Error for NegativeCycle {}
28
29/// Determine a [Topological Sorting](https://en.wikipedia.org/wiki/Topological_sorting) of a
30/// network.
31///
32/// # Errors
33/// If the network is cyclic, then the is no valid Topological Sorting, so an `Err(Cyclic)` is
34/// returned.
35///
36/// # Example
37///
38/// The following shows how to generate a topological sorting for a network where one is possible.
39///
40/// ```rust
41/// use spokes::{Network, algorithms::{topological_sorting, valid_topological_sorting}, ArcStorage};
42///
43/// let mut network: Network<usize, (), (), _> = Network::new();
44/// network.add_nodes((0..6).map(|i| (i, ())));
45/// network.add_arcs([
46///     (5, 0),
47///     (5, 2),
48///     (4, 0),
49///     (4, 1),
50///     (2, 3),
51///     (3, 1),
52/// ]);
53///
54///
55/// let sorting = topological_sorting(&network);
56/// assert!(sorting.is_ok());
57/// assert!(valid_topological_sorting(&network, &(sorting.unwrap())));
58/// ```
59/// Here is a case where the network is cyclic and, therefore, has no topological sorting.
60///
61/// ```rust
62/// use spokes::{Network, algorithms::{topological_sorting, Cyclic}, ArcStorage};
63///
64/// let mut network: Network<usize, (), (), _> = Network::new();
65/// network.add_nodes((0..3).map(|i| (i, ())));
66/// network.add_arcs([
67///     (0, 1),
68///     (1, 2),
69///     (2, 1),
70/// ]);
71///
72/// assert_eq!(topological_sorting(&network), Err(Cyclic));
73/// ```
74pub fn topological_sorting<I, NA, AA, AS>(
75    network: &Network<I, NA, AA, AS>,
76) -> Result<Vec<I>, Cyclic>
77where
78    AS: ArcStorage<I, AA>,
79    I: Hash + Eq + Copy + std::fmt::Debug,
80{
81    // determine indegree for all nodes
82    let mut indegrees: HashMap<I, usize> = network
83        .iter_nodes()
84        .filter_map(|(node, _)| {
85            let indegree = network.reverse_arcs(node).count();
86            if indegree == 0 {
87                None
88            } else {
89                Some((*node, indegree))
90            }
91        })
92        .collect();
93
94    let mut indegree_zero: Vec<I> = network
95        .iter_nodes()
96        .filter_map(|(node, _)| {
97            if indegrees.contains_key(node) {
98                None
99            } else {
100                Some(node)
101            }
102        })
103        .copied()
104        .collect();
105
106    let mut sorting = Vec::with_capacity(network.n_nodes());
107
108    while !indegree_zero.is_empty() {
109        let this_degree = mem::take(&mut indegree_zero);
110
111        for node in this_degree {
112            for neigh in network.forward_arcs(&node) {
113                *indegrees.get_mut(&neigh.head).expect("Should be present") -= 1;
114                if *indegrees.get(&neigh.head).expect("Should be present") == 0 {
115                    indegree_zero.push(neigh.head);
116                    indegrees.remove(&neigh.head);
117                }
118            }
119            sorting.push(node);
120        }
121    }
122
123    if indegrees.is_empty() {
124        Ok(sorting)
125    } else {
126        Err(Cyclic)
127    }
128}
129
130/// Check if a given topological sorting is valid for a network.
131///
132/// # Example
133/// ```rust
134///
135/// use spokes::{Network, algorithms::valid_topological_sorting, ArcStorage};
136///
137/// let mut network: Network<usize, (), (), _> = Network::new();
138/// network.add_nodes((0..6).map(|i| (i, ())));
139/// network.add_arcs([
140///     (5, 0),
141///     (5, 2),
142///     (4, 0),
143///     (4, 1),
144///     (2, 3),
145///     (3, 1),
146/// ]);
147///
148/// assert!(valid_topological_sorting(&network, &[4, 5, 0, 2, 3, 1]));
149/// ```
150pub fn valid_topological_sorting<I, NA, AA, AS>(
151    network: &Network<I, NA, AA, AS>,
152    ordering: &[I],
153) -> bool
154where
155    AS: ArcStorage<I, AA>,
156    I: Hash + Eq + Copy + std::fmt::Debug,
157{
158    let labels: HashMap<I, usize> = ordering.iter().enumerate().map(|(i, n)| (*n, i)).collect();
159
160    network.arc_iter().all(|arc| {
161        let a = labels.get(&arc.tail);
162        let b = labels.get(&arc.head);
163
164        a.and_then(|a| b.map(|b| a < b)).unwrap_or(false)
165    })
166}
167
/// Representation of finite or infinite distances in a network.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Distance<T> {
    /// A finite distance to the root node.
    Finite(T),
    /// Value for a node that is not connected to the root node.
    Infinite,
}

impl<T> Distance<T> {
    /// Borrow the contained value if the distance is finite, or return `None` for
    /// [`Distance::Infinite`].
    pub fn finite_value(&self) -> Option<&T> {
        match self {
            Distance::Finite(x) => Some(x),
            Distance::Infinite => None,
        }
    }

    /// Map the interior value of the distance, leaving [`Distance::Infinite`] untouched.
    ///
    /// `f` is called at most once, so any `FnOnce` closure is accepted, including one that
    /// moves captured values out.
    pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> Distance<U> {
        match self {
            Distance::Finite(x) => Distance::Finite(f(x)),
            Distance::Infinite => Distance::Infinite,
        }
    }
}
194
195impl<T: Add<Output = T>> Add for Distance<T> {
196    type Output = Self;
197
198    fn add(self, rhs: Self) -> Self::Output {
199        match (self, rhs) {
200            (Distance::Finite(a), Distance::Finite(b)) => Distance::Finite(a + b),
201            _ => Distance::Infinite,
202        }
203    }
204}
205
206impl<T: Zero> Zero for Distance<T> {
207    fn zero() -> Self {
208        Distance::Finite(T::zero())
209    }
210
211    fn is_zero(&self) -> bool {
212        if let Distance::Finite(x) = self {
213            x.is_zero()
214        } else {
215            false
216        }
217    }
218}
219
220impl<T: Add<T, Output = T>> Add<T> for Distance<T> {
221    type Output = Distance<T>;
222
223    fn add(self, rhs: T) -> Self::Output {
224        match self {
225            Distance::Finite(d) => Distance::Finite(d + rhs),
226            Distance::Infinite => Distance::Infinite,
227        }
228    }
229}
230
231impl<T: PartialOrd> PartialOrd for Distance<T> {
232    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
233        match (self, other) {
234            (Distance::Finite(a), Distance::Finite(b)) => a.partial_cmp(b),
235            (Distance::Finite(_), Distance::Infinite) => Some(std::cmp::Ordering::Less),
236            (Distance::Infinite, Distance::Finite(_)) => Some(std::cmp::Ordering::Greater),
237            (Distance::Infinite, Distance::Infinite) => Some(std::cmp::Ordering::Equal),
238        }
239    }
240}
241
242impl<T: Ord + Eq> Ord for Distance<T> {
243    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
244        match (self, other) {
245            (Distance::Finite(a), Distance::Finite(b)) => a.cmp(b),
246            (Distance::Finite(_), Distance::Infinite) => std::cmp::Ordering::Less,
247            (Distance::Infinite, Distance::Finite(_)) => std::cmp::Ordering::Greater,
248            (Distance::Infinite, Distance::Infinite) => std::cmp::Ordering::Equal,
249        }
250    }
251}
252
253impl<T: std::fmt::Display> std::fmt::Display for Distance<T> {
254    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255        match self {
256            Distance::Finite(x) => write!(f, "{}", x),
257            Distance::Infinite => write!(f, "INFINITE"),
258        }
259    }
260}
261
262impl<T: std::ops::Add<Output = T> + Copy> std::ops::Add for &Distance<T> {
263    type Output = Distance<T>;
264
265    fn add(self, rhs: Self) -> Self::Output {
266        if let (Distance::Finite(x), Distance::Finite(y)) = (self, rhs) {
267            Distance::Finite(*x + *y)
268        } else {
269            Distance::Infinite
270        }
271    }
272}
273
274impl<T: std::ops::AddAssign + Copy> std::ops::AddAssign for Distance<T> {
275    fn add_assign(&mut self, rhs: Self) {
276        match (self, rhs) {
277            (Distance::Finite(x), Distance::Finite(y)) => {
278                *x += y;
279            }
280            (s, _) => {
281                *s = Distance::Infinite;
282            }
283        }
284    }
285}
286
287impl<T: std::ops::SubAssign + Copy> std::ops::SubAssign for Distance<T> {
288    fn sub_assign(&mut self, rhs: Self) {
289        match (self, rhs) {
290            (Distance::Finite(x), Distance::Finite(y)) => {
291                *x -= y;
292            }
293            (_, Distance::Infinite) => panic!("X - INFINITE is not well defined."),
294            (s, _) => {
295                *s = Distance::Infinite;
296            }
297        }
298    }
299}
300
301/// Create a network with arcs pointing to the shortest path to `source` and distances stored in
302/// the node's attributes.
303///
304/// # Example
305///
306/// Using the example from Wikipedia's [Dijkstra's Algorithm](https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm)
307/// the following network's shortest path tree is correctly generated.
308///
309/// ```raw
310///       ┌───┐  9    ┌─────┐
311///   ┌── │ 5 │ ───── │  0  │ ─┐
312///   │   └───┘       └─────┘  │
313///   │     │           │      │
314///   │     │           │ 7    │
315///   │     │           │      │
316///   │     │         ┌─────┐  │
317///   │ 2   │    ┌─── │  1  │  │ 14
318///   │     │    │    └─────┘  │
319///   │     │    │      │      │
320///   │     │    │      │ 10   │
321///   │     │    │      │      │
322///   │     │    │    ┌─────┐  │
323///   └─────┼────┼─── │  2  │ ─┘
324///         │    │    └─────┘
325///         │    │      │
326///         │    │ 15   │ 11
327///         │    │      │
328///         │    │    ┌─────┐
329///         │    └─── │  3  │
330///         │         └─────┘
331///         │           │
332///         │           │ 6
333///         │           │
334///         │   9     ┌─────┐
335///         └─────────│  4  │
336///                   └─────┘
337/// ```
338///
339/// ```rust
340/// use spokes::{Network, algorithms::{dijkstra_shortest_path, Distance}, ArcStorage};
341///
342/// let mut network: Network<usize, (), u16> = Network::new();
343/// network.add_nodes((0..6).map(|i| (i, ())));
344/// network.add_arcs([
345///     (0, 1, 7), (1, 0, 7),
346///     (0, 5, 14), (5, 0, 14),
347///     (0, 2, 9), (2, 0, 9),
348///     (1, 2, 10), (2, 1, 10),
349///     (1, 3, 15), (3, 1, 15),
350///     (2, 5, 2), (5, 2, 2),
351///     (2, 3, 11), (3, 2, 11),
352///     (3, 4, 6), (4, 3, 6),
353///     (4, 5, 9), (5, 4, 9),
354/// ]);
355///
356/// let shortest_path_tree = dijkstra_shortest_path(&network, 0);
357///
358/// assert_eq!(shortest_path_tree.node(&4), Some(&Distance::Finite(20)));
359///
360/// let mut expected_network: Network<usize, Distance<u16>, ()> = Network::new();
361///
362/// expected_network.add_nodes([
363///     (0, Distance::Finite(0)),
364///     (1, Distance::Finite(7)),
365///     (2, Distance::Finite(9)),
366///     (3, Distance::Finite(20)),
367///     (4, Distance::Finite(20)),
368///     (5, Distance::Finite(11)),
369/// ]);
370///
371/// expected_network.add_arcs([
372///     (1, 0),
373///     (2, 0),
374///     (3, 2),
375///     (5, 2),
376///     (4, 5),
377/// ]);
378///
379/// assert_eq!(shortest_path_tree, expected_network);
380/// ```
381pub fn dijkstra_shortest_path<I, NA, T, AS>(
382    network: &Network<I, NA, T, AS>,
383    source: I,
384) -> Network<I, Distance<T>, ()>
385where
386    AS: ArcStorage<I, T>,
387    I: Hash + Eq + Copy + std::fmt::Debug,
388    T: Ord + Zero + Copy + Debug + Unsigned,
389{
390    let mut pred = Network::with_capacity(network.n_nodes(), network.m_arcs());
391    pred.add_nodes(network.iter_nodes().map(|(n, _)| (*n, Distance::Infinite)));
392
393    *pred.node_entry(source).or_insert(Distance::Infinite) = Distance::Finite(T::zero());
394
395    let mut next_nodes: PriorityQueue<I, Reverse<T>> = PriorityQueue::new();
396    next_nodes.push(source, Reverse(T::zero()));
397
398    while let Some((position, Reverse(cost))) = next_nodes.pop() {
399        for arc in network.forward_arcs(&position) {
400            let alt = cost + arc.attributes;
401            let altd = Distance::Finite(cost + arc.attributes);
402            let head_dist = pred.node_entry(arc.head).or_insert(Distance::Infinite);
403
404            if *head_dist > altd {
405                *head_dist = altd;
406
407                // Replace predecessor arcs
408                let arcs_to_remove: Vec<(I, I)> = pred
409                    .forward_arcs(&arc.head)
410                    .map(|a| (a.tail, a.head))
411                    .collect();
412                pred.remove_arcs(arcs_to_remove);
413
414                pred.add_arc(ArcInfo::new(arc.head, arc.tail, ()));
415
416                if next_nodes
417                    .change_priority(&arc.head, Reverse(alt))
418                    .is_none()
419                {
420                    next_nodes.push(arc.head, Reverse(alt));
421                }
422            }
423        }
424    }
425
426    pred
427}
428
429/// Bellman Ford algorith for shortest paths.
430///
431/// See [Wikipedia](https://en.wikipedia.org/wiki/Bellman%E2%80%93Ford_algorithm) for details of the algorithm.
432/// The algorithm actually being used is the [Shortest Path Faster Algorithm](https://en.wikipedia.org/wiki/Shortest_Path_Faster_Algorithm) as it works on the same networks as the Bellman-Ford algorithm.
433///
434/// # Errors
435/// If the network contains a negative cycle, then a [`NegativeCycle`] error will be returned.
436///
437/// # Example
438/// ```rust
439/// use spokes::{Network, algorithms::{bellman_ford_shortest_path, Distance}, ArcStorage};
440///
441/// let mut network: Network<usize, (), u16> = Network::new();
442/// network.add_nodes((0..6).map(|i| (i, ())));
443/// network.add_arcs([
444///     (0, 1, 7), (1, 0, 7),
445///     (0, 5, 14), (5, 0, 14),
446///     (0, 2, 9), (2, 0, 9),
447///     (1, 2, 10), (2, 1, 10),
448///     (1, 3, 15), (3, 1, 15),
449///     (2, 5, 2), (5, 2, 2),
450///     (2, 3, 11), (3, 2, 11),
451///     (3, 4, 6), (4, 3, 6),
452///     (4, 5, 9), (5, 4, 9),
453/// ]);
454///
455/// let shortest_path_tree = bellman_ford_shortest_path(&network, 0).unwrap();
456///
457/// assert_eq!(shortest_path_tree.node(&4), Some(&Distance::Finite(20)));
458///
459/// let mut expected_network: Network<usize, Distance<u16>, ()> = Network::new();
460///
461/// expected_network.add_nodes([
462///     (0, Distance::Finite(0)),
463///     (1, Distance::Finite(7)),
464///     (2, Distance::Finite(9)),
465///     (3, Distance::Finite(20)),
466///     (4, Distance::Finite(20)),
467///     (5, Distance::Finite(11)),
468/// ]);
469///
470/// expected_network.add_arcs([
471///     (1, 0),
472///     (2, 0),
473///     (3, 2),
474///     (5, 2),
475///     (4, 5),
476/// ]);
477///
478/// assert_eq!(shortest_path_tree, expected_network);
479/// ```
480pub fn bellman_ford_shortest_path<I, NA, T, AS>(
481    network: &Network<I, NA, T, AS>,
482    source: I,
483) -> Result<Network<I, Distance<T>, ()>, NegativeCycle>
484where
485    AS: ArcStorage<I, T>,
486    I: Hash + Eq + Copy + std::fmt::Debug,
487    T: Ord + Zero + Copy + Debug,
488{
489    let n = network.n_nodes();
490    let mut pred = Network::with_capacity(network.n_nodes(), network.m_arcs());
491    pred.add_nodes(network.iter_nodes().map(|(n, _)| (*n, Distance::Infinite)));
492
493    *pred.node_entry(source).or_insert(Distance::Infinite) = Distance::Finite(T::zero());
494
495    // Queue for next node to visit
496    let mut queue = VecDeque::new();
497
498    // Hashset to track if a node is in the queue
499    let mut in_queue = HashSet::new();
500
501    // Visit counts
502    let mut visit_counts: HashMap<I, usize> = HashMap::new();
503
504    // Add the source to the
505    queue.push_back(source);
506    in_queue.insert(source);
507
508    while let Some(u) = queue.pop_front() {
509        in_queue.remove(&u);
510
511        if let Some(Distance::Finite(du)) = pred.node(&u).copied() {
512            for arc in network.forward_arcs(&u) {
513                let alt = Distance::Finite(du + arc.attributes);
514                let head_dist = pred.node_entry(arc.head).or_insert(Distance::Infinite);
515
516                if *head_dist > alt {
517                    *head_dist = alt;
518
519                    // Replace predecessor arcs
520                    let arcs_to_remove: Vec<(I, I)> = pred
521                        .forward_arcs(&arc.head)
522                        .map(|a| (a.tail, a.head))
523                        .collect();
524                    pred.remove_arcs(arcs_to_remove);
525                    pred.add_arc(ArcInfo::new(arc.head, arc.tail, ()));
526
527                    if !in_queue.contains(&arc.head) {
528                        queue.push_back(arc.head);
529                        in_queue.insert(arc.head);
530
531                        let mut visits = *visit_counts.entry(arc.head).or_default();
532                        visits += 1;
533
534                        if visits > n {
535                            return Err(NegativeCycle);
536                        }
537                    }
538                }
539            }
540        }
541    }
542
543    Ok(pred)
544}
545
546/// Compute the all-pairs shortest paths for a network.
547///
548/// # Example
549/// ```rust
550/// use spokes::{Network, algorithms::{floyd_worshall_all_pairs_minimum_path, Distance}, ArcStorage};
551///
552/// let mut network: Network<usize, (), u16> = Network::new();
553/// network.add_nodes((0..6).map(|i| (i, ())));
554/// network.add_arcs([
555///     (0, 1, 7), (1, 0, 7),
556///     (0, 5, 14), (5, 0, 14),
557///     (0, 2, 9), (2, 0, 9),
558///     (1, 2, 10), (2, 1, 10),
559///     (1, 3, 15), (3, 1, 15),
560///     (2, 5, 2), (5, 2, 2),
561///     (2, 3, 11), (3, 2, 11),
562///     (3, 4, 6), (4, 3, 6),
563///     (4, 5, 9), (5, 4, 9),
564/// ]);
565///
566/// let shortest_path_trees = floyd_worshall_all_pairs_minimum_path(&network);
567///
568/// assert_eq!(shortest_path_trees[&0].node(&4), Some(&Distance::Finite(20)));
569///
570/// let mut expected_network: Network<usize, Distance<u16>, ()> = Network::new();
571///
572/// expected_network.add_nodes([
573///     (0, Distance::Finite(0)),
574///     (1, Distance::Finite(7)),
575///     (2, Distance::Finite(9)),
576///     (3, Distance::Finite(20)),
577///     (4, Distance::Finite(20)),
578///     (5, Distance::Finite(11)),
579/// ]);
580///
581/// expected_network.add_arcs([
582///     (1, 0),
583///     (2, 0),
584///     (3, 2),
585///     (5, 2),
586///     (4, 5),
587/// ]);
588///
589/// assert_eq!(shortest_path_trees[&0], expected_network);
590/// ```
591pub fn floyd_worshall_all_pairs_minimum_path<I, NA, T, AS>(
592    network: &Network<I, NA, T, AS>,
593) -> HashMap<I, Network<I, Distance<T>, ()>>
594where
595    AS: ArcStorage<I, T>,
596    I: Hash + Eq + Copy + std::fmt::Debug,
597    T: Ord + Zero + Copy + Debug,
598{
599    let mut pred = Network::with_capacity(network.n_nodes(), network.n_nodes());
600    pred.add_nodes(network.iter_nodes().map(|(n, _)| (*n, Distance::Infinite)));
601
602    let mut preds: HashMap<I, Network<I, Distance<T>, ()>> = network
603        .iter_nodes()
604        .map(|(&id, _)| {
605            let mut this_pred = pred.clone();
606            *this_pred.node_entry(id).or_insert(Distance::Infinite) = Distance::Finite(T::zero());
607            (id, this_pred)
608        })
609        .collect();
610
611    for arc in network.arc_iter() {
612        *preds
613            .get_mut(&arc.tail)
614            .expect("The key should map to a network")
615            .node_mut(&arc.head)
616            .expect("Node should exist") = Distance::Finite(arc.attributes);
617
618        preds
619            .get_mut(&arc.tail)
620            .expect("The key should map to a network")
621            .add_arc((arc.head, arc.tail));
622    }
623
624    for k in network.iter_nodes().map(|(i, _)| i) {
625        for i in network.iter_nodes().map(|(i, _)| i) {
626            for j in network.iter_nodes().map(|(i, _)| i) {
627                let dij = preds[i].node(j).expect("Node should exist");
628                let alt = preds[i].node(k).expect("Node should exist")
629                    + preds[k].node(j).expect("Node should exist");
630
631                if dij > &alt {
632                    *preds
633                        .get_mut(i)
634                        .expect("Key should map to network")
635                        .node_mut(j)
636                        .expect("Node should exist") = alt;
637
638                    // remove all existing arcs for predecessor graph for this node
639                    let arcs_to_remove: Vec<(I, I)> =
640                        preds[i].forward_arcs(j).map(|a| (a.tail, a.head)).collect();
641                    preds
642                        .get_mut(i)
643                        .expect("Key should map to network")
644                        .remove_arcs(arcs_to_remove);
645
646                    let arc_to_add = preds[k]
647                        .forward_arcs(j)
648                        .next()
649                        .expect("One arc from j should exist")
650                        .clone();
651
652                    preds
653                        .get_mut(i)
654                        .expect("Key should map to network")
655                        .add_arc(arc_to_add);
656                }
657            }
658        }
659    }
660
661    preds
662}
663
664/// Determine the distances to nodes for arcs of unit length and annotate a network with those
665/// values.
666///
667/// # Example
668/// ```rust, ignore
669/// use spokes::{Network, algorithms::{distance_unit_arcs, Distance}, ArcStorage};
670///
671/// let mut network: Network<usize, (), usize> = Network::new();
672///
673/// network.add_nodes((0..=11).map(|i| (i, ())));
674///
675/// network.add_arcs([
676///     (0, 1, 2),
677///     (0, 2, 2),
678///     (0, 3, 2),
679///     (0, 4, 2),
680///     (0, 5, 2),
681///     (1, 6, 1),
682///     (2, 7, 1),
683///     (3, 8, 1),
684///     (4, 9, 1),
685///     (5, 10, 1),
686///     (6, 11, 2),
687///     (7, 11, 2),
688///     (8, 11, 2),
689///     (9, 11, 2),
690///     (10, 11, 2),
691/// ]);
692///
693/// let distances = distance_unit_arcs(&network, &11);
694///
695/// assert_eq!(distances.node(&11), Some(&Distance::Finite(0)));
696/// assert_eq!(distances.node(&6), Some(&Distance::Finite(1)));
697/// assert_eq!(distances.node(&7), Some(&Distance::Finite(1)));
698/// assert_eq!(distances.node(&8), Some(&Distance::Finite(1)));
699/// assert_eq!(distances.node(&9), Some(&Distance::Finite(1)));
700/// assert_eq!(distances.node(&10), Some(&Distance::Finite(1)));
701/// assert_eq!(distances.node(&1), Some(&Distance::Finite(2)));
702/// assert_eq!(distances.node(&2), Some(&Distance::Finite(2)));
703/// assert_eq!(distances.node(&3), Some(&Distance::Finite(2)));
704/// assert_eq!(distances.node(&4), Some(&Distance::Finite(2)));
705/// assert_eq!(distances.node(&5), Some(&Distance::Finite(2)));
706/// assert_eq!(distances.node(&0), Some(&Distance::Finite(3)));
707/// ```
708fn distance_unit_arcs<I, NA, T, AS>(
709    network: &Network<I, NA, T, AS>,
710    root_id: &I,
711) -> Network<I, Distance<usize>, T, AS>
712where
713    AS: ArcStorage<I, T> + Clone,
714    I: Hash + Eq + Copy + Debug,
715    T: Debug,
716{
717    let mut distances: HashMap<I, Distance<usize>> = HashMap::with_capacity(network.n_nodes());
718    distances.insert(*root_id, Distance::Finite(0));
719
720    // TODO: This is currently done in O(m+n) but could just be O(n) instead if we track the node
721    // that precedes each node in the bfs
722
723    for node in network.bfs(root_id, Direction::Reverse).skip(1) {
724        let distance_to_node = network
725            .forward_arcs(node)
726            .map(|arc| distances.get(&arc.head).unwrap_or(&Distance::Infinite))
727            .min()
728            .copied()
729            .unwrap_or(Distance::Infinite)
730            + 1;
731        distances.insert(*node, distance_to_node);
732    }
733
734    let nodes = network
735        .nodes()
736        .iter()
737        .map(|(id, _)| {
738            (
739                *id,
740                distances.get(id).copied().unwrap_or(Distance::Infinite),
741            )
742        })
743        .collect();
744
745    Network::from_parts(nodes, network.arc_store().clone())
746}
747
748/// Determine the maximal flow from the source to sink
749///
750/// # Example
751/// ```rust, ignore
752/// use spokes::{Network, algorithms::{shortest_augmenting_path_flow, Distance}, ArcStorage};
753///
754/// let mut network: Network<usize, (), usize> = Network::new();
755///
756/// network.add_nodes((0..=11).map(|i| (i, ())));
757///
758/// network.add_arcs([
759///     (0, 1, 2),
760///     (0, 2, 2),
761///     (0, 3, 2),
762///     (0, 4, 2),
763///     (0, 5, 2),
764///     (1, 6, 1),
765///     (2, 7, 1),
766///     (3, 8, 1),
767///     (4, 9, 1),
768///     (5, 10, 1),
769///     (6, 11, 2),
770///     (7, 11, 2),
771///     (8, 11, 2),
772///     (9, 11, 2),
773///     (10, 11, 2),
774/// ]);
775///
776/// let max_flow_residual = shortest_augmenting_path_flow(&network, &0, &11);
777///
778/// let max_flow: usize = max_flow_residual.reverse_arcs(&0).map(|a| a.attributes).sum();
779/// dbg!(max_flow);
780/// dbg!(max_flow_residual);
781///
782/// panic!();
783/// ```
784/// # Panics
785/// TODO: Shouldn't happen
786///
787///
788/// XXX: Currently broken! Don't use.
789fn shortest_augmenting_path_flow<I, NA, T, AS>(
790    network: &Network<I, NA, T, AS>,
791    source_id: &I,
792    sink_id: &I,
793) -> Network<I, Distance<usize>, T, AS>
794where
795    AS: ArcStorage<I, T> + Clone + std::fmt::Debug,
796    I: Hash + Eq + Copy + std::fmt::Debug + std::fmt::Display,
797    T: Ord + Unsigned + Zero + Copy + Debug + std::ops::SubAssign + std::ops::AddAssign,
798    NA: Clone,
799{
800    eprintln!("Starting");
801    std::io::stderr().flush().unwrap();
802    let mut flow: Network<I, Distance<usize>, T, AS> = distance_unit_arcs(network, sink_id);
803    let n = Distance::Finite(flow.n_nodes());
804    let mut cur: I = *source_id;
805    let mut pred: HashMap<I, I> = HashMap::new();
806
807    dbg!(&flow);
808
809    while flow
810        .node(source_id)
811        .expect("source should be in the network")
812        < &n
813    {
814        if let Some(arc) = admissible_arc(&cur, &flow).cloned() {
815            // Search for an admissible arc from the current position and advance
816            println!("Advancing from {} to {}", cur, arc.head);
817            pred.insert(arc.head, cur);
818            cur = arc.head;
819
820            if &cur == sink_id {
821                let path = pred_to_path(&pred, sink_id);
822                eprintln!("Augmenting along {:?}", path);
823                augment(&path, network, &mut flow);
824            }
825        } else {
826            // Retreat
827            eprintln!("Retreating from {}", cur);
828            let new_distance = flow
829                .forward_arcs(&cur)
830                .filter_map(|arc| {
831                    if arc.attributes > T::zero() {
832                        Some(*flow.node(&arc.head).expect("Should be present") + 1)
833                    } else {
834                        None
835                    }
836                })
837                .min()
838                .expect("There should be a minimum arc");
839            *flow.node_entry(cur).or_insert(Distance::Infinite) = new_distance;
840
841            if &cur != source_id {
842                cur = *pred.get(&cur).expect("previous map should exist");
843            }
844        }
845    }
846
847    flow
848}
849
/// Rebuild the path that ends at `start` from a predecessor map (`node -> predecessor`).
///
/// The arcs are returned as `(tail, head)` pairs, ordered from the first node of the path up to
/// `start`; this is the orientation `augment` destructures. Walking stops at a node without a
/// predecessor, or as soon as a predecessor repeats (the map contains a cycle), so the
/// function always terminates.
#[inline]
fn pred_to_path<'a, I: Hash + Eq>(pred: &'a HashMap<I, I>, start: &'a I) -> Vec<(&'a I, &'a I)> {
    let mut visited: HashSet<&'a I> = HashSet::new();
    visited.insert(start);

    let mut cur = start;
    let mut path = Vec::new();

    while let Some(prev) = pred.get(cur) {
        // A repeated node means the predecessor map loops; stop instead of spinning forever.
        if !visited.insert(prev) {
            break;
        }
        path.push((prev, cur));
        cur = prev;
    }

    // Collected from `start` backwards; flip into travel order.
    path.reverse();
    path
}
863
864#[inline]
865fn augment<I, T, NA, AS>(
866    path: &[(&I, &I)],
867    network: &Network<I, NA, T, AS>,
868    flow: &mut Network<I, Distance<usize>, T, AS>,
869) where
870    AS: ArcStorage<I, T>,
871    I: Hash + Eq + Copy,
872    T: Ord + Zero + Copy + std::ops::SubAssign + std::ops::AddAssign,
873{
874    let delta = path
875        .iter()
876        .map(|(&tail, &head)| network.arc(tail, head).expect("Arc should exist"))
877        .min()
878        .expect("Minimum should exist");
879
880    for (&tail, &head) in path {
881        let forward_arc = if let Some(arc) = flow.arc_mut(tail, head) {
882            arc
883        } else {
884            flow.add_arc(ArcInfo::new(tail, head, T::zero()));
885            flow.arc_mut(tail, head)
886                .expect("Should exist, as it was just inserted")
887        };
888        *forward_arc -= *delta;
889
890        let reverse_arc = if let Some(arc) = flow.arc_mut(head, tail) {
891            arc
892        } else {
893            flow.add_arc(ArcInfo::new(head, tail, T::zero()));
894            flow.arc_mut(head, tail)
895                .expect("Should exist, as it was just inserted")
896        };
897        *reverse_arc += *delta;
898    }
899}
900
901#[inline]
902fn next_admissible_arc<'a, I, NA, AA, AS, F, T>(
903    network: &'a Network<I, NA, AA, AS>,
904    f: F,
905    node: &'a I,
906) -> Option<&'a ArcInfo<I, AA>>
907where
908    AS: ArcStorage<I, AA>,
909    I: Hash + Eq + Copy,
910    F: Fn(&NA) -> Distance<T>,
911    T: std::ops::Add<Output = T> + num::One + PartialEq + Copy,
912{
913    let this_distance = f(network.node(node).expect("Should exist"));
914    for arc in network.forward_arcs(node) {
915        let head_dist = f(network.node(&arc.head).expect("Should exist"));
916
917        if this_distance == head_dist + T::one() {
918            return Some(arc);
919        }
920    }
921    None
922}
923
924#[inline]
925fn admissible_arc<'a, I, T, AS>(
926    node: &'a I,
927    network: &'a Network<I, Distance<usize>, T, AS>,
928) -> Option<&'a ArcInfo<I, T>>
929where
930    AS: ArcStorage<I, T>,
931    I: Hash + Eq + Copy,
932{
933    let this_distance = network.node(node).unwrap_or(&Distance::Infinite);
934    for arc in network.forward_arcs(node) {
935        let head_dist = network.node(node).unwrap_or(&Distance::Infinite);
936
937        if *this_distance == *head_dist + 1 {
938            return Some(arc);
939        }
940    }
941    None
942}
943
#[cfg(test)]
mod tests {
    mod distance {
        use super::super::Distance;

        #[test]
        fn partial_ord() {
            assert!(Distance::Infinite > Distance::Finite(1_usize));
        }

        #[test]
        fn add() {
            assert_eq!(Distance::Finite(10) + 1, Distance::Finite(11));
            assert_eq!(
                Distance::Finite(10) + Distance::Finite(1),
                Distance::Finite(11)
            );
            assert_eq!(
                Distance::<usize>::Infinite + Distance::Finite(1),
                Distance::Infinite
            );
        }
    }

    mod distance_unit_arcs {
        use crate::{
            algorithms::{distance_unit_arcs, Distance},
            ArcStorage, Network,
        };

        /// `distance_unit_arcs` counts hops *to* the root along arc directions. On the cycle
        /// 0 -> 1 -> 2 -> 3 -> 0, node 3 is therefore one hop from 0 and node 1 is three hops
        /// away.
        #[test]
        fn cycle() {
            let mut network: Network<usize, (), ()> = Network::new();

            network.add_nodes((0..=3).map(|i| (i, ())));

            network.add_arcs([(0, 1), (1, 2), (2, 3), (3, 0)]);

            let distances = distance_unit_arcs(&network, &0);

            assert_eq!(distances.node(&0), Some(&Distance::Finite(0)));
            assert_eq!(distances.node(&1), Some(&Distance::Finite(3)));
            assert_eq!(distances.node(&2), Some(&Distance::Finite(2)));
            assert_eq!(distances.node(&3), Some(&Distance::Finite(1)));
        }
    }

    mod shortest_augmenting_path_flow {
        use crate::{algorithms::shortest_augmenting_path_flow, ArcStorage, Network};

        // Ignored while `shortest_augmenting_path_flow` is still experimental.
        #[test]
        #[ignore]
        fn simple() {
            let mut network: Network<usize, (), usize> = Network::new();

            network.add_nodes((0..=11).map(|i| (i, ())));

            network.add_arcs([
                (0, 1, 2),
                (0, 2, 2),
                (0, 3, 2),
                (0, 4, 2),
                (0, 5, 2),
                (1, 6, 1),
                (2, 7, 1),
                (3, 8, 1),
                (4, 9, 1),
                (5, 10, 1),
                (6, 11, 2),
                (7, 11, 2),
                (8, 11, 2),
                (9, 11, 2),
                (10, 11, 2),
            ]);

            let max_flow_residual = shortest_augmenting_path_flow(&network, &0, &11);

            // Arcs entering the source in the residual network carry the flow that left it.
            let max_flow: usize = max_flow_residual
                .reverse_arcs(&0)
                .map(|a| a.attributes)
                .sum();

            // The five middle arcs (1..=5 -> 6..=10) each have capacity 1 and form a minimum
            // cut.
            assert_eq!(max_flow, 5);
        }
    }
}