use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
use super::location::LocationId;
/// Cost of transferring data across one directed route edge.
///
/// `time_per_gb` is the transfer-time component (validated by `EdgeCost::new` to be
/// finite and non-negative); `priority` is an additional penalty where a lower value makes
/// the edge cheaper.
#[derive(Clone, Copy, Debug)]
pub struct EdgeCost {
    /// Transfer time per gigabyte.
    pub time_per_gb: f64,
    /// Priority penalty; lower values are preferred by route selection.
    pub priority: u32,
}
impl EdgeCost {
pub fn new(time_per_gb: f64, priority: u32) -> Result<Self, super::error::DomainError> {
if !time_per_gb.is_finite() || time_per_gb < 0.0 {
return Err(super::error::DomainError::Validation {
field: "time_per_gb".to_string(),
reason: format!("must be finite and non-negative, got {time_per_gb}"),
});
}
Ok(Self {
time_per_gb,
priority,
})
}
fn scalar(&self) -> f64 {
self.time_per_gb + (self.priority as f64) * 0.001
}
}
impl Default for EdgeCost {
fn default() -> Self {
Self {
time_per_gb: 1.0,
priority: 100,
}
}
}
/// Directed graph of transfer routes between locations, each edge carrying an
/// [`EdgeCost`].
#[derive(Clone, Debug, Default)]
pub struct RouteGraph {
    /// Adjacency map: source location -> (destination location -> edge cost).
    adj: HashMap<LocationId, HashMap<LocationId, EdgeCost>>,
}
impl RouteGraph {
    /// Creates an empty route graph.
    pub fn new() -> Self {
        Self::default()
    }

    /// Test helper: adds a `src -> dest` edge with [`EdgeCost::default`].
    #[cfg(test)]
    pub fn add(&mut self, src: LocationId, dest: LocationId) {
        self.add_with_cost(src, dest, EdgeCost::default());
    }

    /// Adds a directed `src -> dest` edge with the given cost.
    ///
    /// An existing edge between the same pair has its cost replaced. Self-loops
    /// (`src == dest`) are ignored.
    pub fn add_with_cost(&mut self, src: LocationId, dest: LocationId, cost: EdgeCost) {
        if src != dest {
            self.adj.entry(src).or_default().insert(dest, cost);
        }
    }

    /// Test helper: removes the `src -> dest` edge if present.
    ///
    /// `src` is dropped from the adjacency map once it has no outgoing edges left, so the
    /// map never holds empty inner maps.
    #[cfg(test)]
    pub fn remove(&mut self, src: &LocationId, dest: &LocationId) {
        if let Some(dests) = self.adj.get_mut(src) {
            dests.remove(dest);
            if dests.is_empty() {
                self.adj.remove(src);
            }
        }
    }

    /// Test helper: whether a direct `src -> dest` edge exists.
    #[cfg(test)]
    pub fn has(&self, src: &LocationId, dest: &LocationId) -> bool {
        self.adj
            .get(src)
            .is_some_and(|dests| dests.contains_key(dest))
    }

    /// Test helper: the cost of the direct `src -> dest` edge, if any.
    // Not every test build calls this helper; allow it to exist without a dead-code warning.
    #[cfg(test)]
    #[allow(dead_code)]
    pub fn edge_cost(&self, src: &LocationId, dest: &LocationId) -> Option<&EdgeCost> {
        self.adj.get(src).and_then(|dests| dests.get(dest))
    }

    /// Iterates over the locations reachable from `origin` by exactly one edge.
    ///
    /// Yields nothing when `origin` has no outgoing edges. Iteration order follows
    /// `HashMap` order and is unspecified.
    pub fn direct_from(&self, origin: &LocationId) -> impl Iterator<Item = &LocationId> {
        self.adj
            .get(origin)
            .into_iter()
            .flat_map(|dests| dests.keys())
    }

    /// Test helper: every location reachable from `origin` by one or more edges,
    /// excluding `origin` itself.
    #[cfg(test)]
    pub fn reachable_from(&self, origin: &LocationId) -> HashSet<LocationId> {
        self.bfs_from(origin).into_iter().collect()
    }

    /// Every location that is the destination of at least one edge.
    pub fn all_destinations(&self) -> HashSet<LocationId> {
        self.adj
            .values()
            .flat_map(|dests| dests.keys().cloned())
            .collect()
    }

    /// Locations reachable from `origin` (excluding `origin`) in breadth-first order, so a
    /// location never precedes one that is fewer hops away from `origin`.
    pub fn destinations_ordered_from(&self, origin: &LocationId) -> Vec<LocationId> {
        self.bfs_from(origin)
    }

    /// Edges that deliver data from `origin` to every reachable location in
    /// `required_dests` along shortest paths (by [`EdgeCost`] scalar cost).
    ///
    /// Equivalent to [`Self::optimal_tree_multi_source`] with `origin` as the only source:
    /// shared path prefixes are emitted once, `origin` and unreachable destinations are
    /// skipped, and edges are ordered so each edge's source is `origin` or the destination
    /// of an earlier edge.
    pub fn optimal_tree(
        &self,
        origin: &LocationId,
        required_dests: &HashSet<LocationId>,
    ) -> Vec<(LocationId, LocationId)> {
        if required_dests.is_empty() {
            return Vec::new();
        }
        let sources = HashSet::from([origin.clone()]);
        self.optimal_tree_multi_source(&sources, required_dests)
    }

    /// Edges that deliver data from any of `sources` to every reachable location in
    /// `required_dests` along shortest paths (by [`EdgeCost`] scalar cost).
    ///
    /// Each destination is served from whichever source reaches it most cheaply.
    /// Destinations already in `sources` and unreachable destinations are skipped, and
    /// shared path prefixes are emitted once. Edges are ordered breadth-first from the
    /// sources with siblings sorted by ascending distance, so every edge's source is in
    /// `sources` or is the destination of an earlier edge.
    ///
    /// Returns an empty vector if either set is empty.
    pub fn optimal_tree_multi_source(
        &self,
        sources: &HashSet<LocationId>,
        required_dests: &HashSet<LocationId>,
    ) -> Vec<(LocationId, LocationId)> {
        if sources.is_empty() || required_dests.is_empty() {
            return Vec::new();
        }
        let actual_dests: HashSet<LocationId> = required_dests
            .iter()
            .filter(|dest| !sources.contains(*dest))
            .cloned()
            .collect();
        if actual_dests.is_empty() {
            return Vec::new();
        }
        let (dist, prev) = self.dijkstra_multi_source(sources);
        Self::ordered_tree_edges(sources, &actual_dests, &dist, &prev)
    }

    /// Collects the shortest-path edges leading to `dests` (from the `prev` map) and orders
    /// them breadth-first from `sources`, siblings sorted by ascending `dist`.
    fn ordered_tree_edges(
        sources: &HashSet<LocationId>,
        dests: &HashSet<LocationId>,
        dist: &HashMap<LocationId, f64>,
        prev: &HashMap<LocationId, LocationId>,
    ) -> Vec<(LocationId, LocationId)> {
        // children[p] = nodes whose shortest-path predecessor is p and that lie on the way
        // to some requested destination.
        let mut children: HashMap<LocationId, Vec<LocationId>> = HashMap::new();
        let mut on_tree: HashSet<LocationId> = HashSet::new();
        for dest in dests {
            let mut current = dest.clone();
            while let Some(predecessor) = prev.get(&current) {
                // Once `current` is on the tree, the rest of its path back to a source is
                // too; stopping here avoids re-walking shared prefixes.
                if !on_tree.insert(current.clone()) {
                    break;
                }
                children
                    .entry(predecessor.clone())
                    .or_default()
                    .push(current.clone());
                current = predecessor.clone();
            }
        }

        let cost_of = |node: &LocationId| dist.get(node).copied().unwrap_or(f64::INFINITY);
        let mut ordered = Vec::with_capacity(on_tree.len());
        let mut visited: HashSet<LocationId> = sources.clone();
        let mut queue: VecDeque<LocationId> = sources.iter().cloned().collect();
        while let Some(node) = queue.pop_front() {
            let Some(mut kids) = children.remove(&node) else {
                continue;
            };
            kids.sort_by(|a, b| cost_of(a).total_cmp(&cost_of(b)));
            for child in kids {
                if visited.insert(child.clone()) {
                    ordered.push((node.clone(), child.clone()));
                    queue.push_back(child);
                }
            }
        }
        ordered
    }

    /// Multi-source Dijkstra over [`EdgeCost`] scalar costs.
    ///
    /// Every node in `sources` starts at distance 0. Returns `(dist, prev)`: `dist` maps
    /// each reachable node (sources included) to its distance from the nearest source, and
    /// `prev` maps each reached non-source node to its predecessor on a shortest path.
    /// Assumes non-negative edge costs, which [`EdgeCost::new`] enforces.
    fn dijkstra_multi_source(
        &self,
        sources: &HashSet<LocationId>,
    ) -> (HashMap<LocationId, f64>, HashMap<LocationId, LocationId>) {
        let mut dist: HashMap<LocationId, f64> = HashMap::new();
        let mut prev: HashMap<LocationId, LocationId> = HashMap::new();
        let mut heap = BinaryHeap::with_capacity(sources.len());
        for source in sources {
            dist.insert(source.clone(), 0.0);
            heap.push(DijkstraEntry {
                cost: 0.0,
                node: source.clone(),
            });
        }
        while let Some(DijkstraEntry { cost, node }) = heap.pop() {
            // Skip stale entries superseded by a cheaper path pushed later.
            if dist.get(&node).is_some_and(|&best| cost > best) {
                continue;
            }
            let Some(neighbors) = self.adj.get(&node) else {
                continue;
            };
            for (next, edge_cost) in neighbors {
                let next_cost = cost + edge_cost.scalar();
                if dist.get(next).is_none_or(|&best| next_cost < best) {
                    dist.insert(next.clone(), next_cost);
                    prev.insert(next.clone(), node.clone());
                    heap.push(DijkstraEntry {
                        cost: next_cost,
                        node: next.clone(),
                    });
                }
            }
        }
        (dist, prev)
    }

    /// Breadth-first traversal from `origin`, returning every reached location except
    /// `origin` itself in visit order.
    fn bfs_from(&self, origin: &LocationId) -> Vec<LocationId> {
        // Marking `origin` visited up front keeps it out of the result even on cycles.
        let mut visited: HashSet<&LocationId> = HashSet::from([origin]);
        let mut queue: VecDeque<&LocationId> = VecDeque::from([origin]);
        let mut order = Vec::new();
        while let Some(node) = queue.pop_front() {
            for next in self.direct_from(node) {
                if visited.insert(next) {
                    order.push(next.clone());
                    queue.push_back(next);
                }
            }
        }
        order
    }

    /// Every edge as a `(src, dest)` pair, in unspecified order.
    pub fn all_edges(&self) -> Vec<(LocationId, LocationId)> {
        self.adj
            .iter()
            .flat_map(|(src, dests)| dests.keys().map(move |dest| (src.clone(), dest.clone())))
            .collect()
    }

    /// Test helper: total number of directed edges.
    #[cfg(test)]
    pub fn edge_count(&self) -> usize {
        self.adj.values().map(HashMap::len).sum()
    }
}
impl super::plan::Topology for RouteGraph {
fn optimal_tree(
&self,
origin: &LocationId,
required_dests: &HashSet<LocationId>,
) -> Vec<(LocationId, LocationId)> {
self.optimal_tree(origin, required_dests)
}
fn optimal_tree_multi_source(
&self,
sources: &HashSet<LocationId>,
required_dests: &HashSet<LocationId>,
) -> Vec<(LocationId, LocationId)> {
self.optimal_tree_multi_source(sources, required_dests)
}
}
/// Priority-queue entry for Dijkstra's algorithm.
///
/// `BinaryHeap` is a max-heap, so the ordering is reversed on `cost` to make the cheapest
/// entry pop first. Only `cost` participates in ordering and equality.
#[derive(Debug, Clone)]
struct DijkstraEntry {
    cost: f64,
    node: LocationId,
}

// Equality is derived from `Ord` so the two always agree, as the `Ord` contract requires.
// Previously `eq` also compared `node`, while `cmp` ignored it.
impl PartialEq for DijkstraEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DijkstraEntry {}

impl PartialOrd for DijkstraEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DijkstraEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed (other vs self) so the smallest cost is the heap maximum. `total_cmp`
        // is a genuine total order even for NaN; `partial_cmp(..).unwrap_or(Equal)` made
        // NaN "equal" to everything, breaking transitivity and the heap invariant.
        other.cost.total_cmp(&self.cost)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn loc(s: &str) -> LocationId {
        LocationId::new(s).unwrap()
    }

    #[test]
    fn add_and_has() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        assert!(g.has(&loc("local"), &loc("cloud")));
        assert!(!g.has(&loc("cloud"), &loc("local")));
    }

    #[test]
    fn self_loop_ignored() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("local"));
        assert_eq!(g.edge_count(), 0);
    }

    #[test]
    fn remove_edge() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        g.remove(&loc("local"), &loc("cloud"));
        assert!(!g.has(&loc("local"), &loc("cloud")));
        assert_eq!(g.edge_count(), 0);
    }

    #[test]
    fn remove_missing_edge_is_noop() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        g.remove(&loc("cloud"), &loc("local"));
        g.remove(&loc("local"), &loc("pod"));
        assert!(g.has(&loc("local"), &loc("cloud")));
        assert_eq!(g.edge_count(), 1);
    }

    #[test]
    fn add_with_cost_overwrites_existing_edge() {
        let mut g = RouteGraph::new();
        g.add_with_cost(loc("local"), loc("cloud"), EdgeCost::new(5.0, 10).unwrap());
        g.add_with_cost(loc("local"), loc("cloud"), EdgeCost::new(2.0, 20).unwrap());
        assert_eq!(g.edge_count(), 1);
        let cost = g.edge_cost(&loc("local"), &loc("cloud")).unwrap();
        assert_eq!(cost.time_per_gb, 2.0);
        assert_eq!(cost.priority, 20);
    }

    #[test]
    fn edge_cost_rejects_invalid_time() {
        assert!(EdgeCost::new(-1.0, 0).is_err());
        assert!(EdgeCost::new(f64::NAN, 0).is_err());
        assert!(EdgeCost::new(f64::INFINITY, 0).is_err());
        assert!(EdgeCost::new(0.0, 0).is_ok());
    }

    #[test]
    fn edge_cost_default_values() {
        let cost = EdgeCost::default();
        assert_eq!(cost.time_per_gb, 1.0);
        assert_eq!(cost.priority, 100);
    }

    #[test]
    fn direct_from() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        g.add(loc("local"), loc("pod"));
        g.add(loc("cloud"), loc("local"));
        let direct: HashSet<LocationId> = g.direct_from(&loc("local")).cloned().collect();
        assert_eq!(direct.len(), 2);
        assert!(direct.contains(&loc("cloud")));
        assert!(direct.contains(&loc("pod")));
    }

    #[test]
    fn reachable_direct_only() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        let r = g.reachable_from(&loc("local"));
        assert_eq!(r, HashSet::from([loc("cloud")]));
    }

    #[test]
    fn reachable_multi_hop() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        g.add(loc("cloud"), loc("pod"));
        let r = g.reachable_from(&loc("local"));
        assert_eq!(r, HashSet::from([loc("cloud"), loc("pod")]));
    }

    #[test]
    fn reachable_excludes_origin() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        g.add(loc("cloud"), loc("local"));
        let r = g.reachable_from(&loc("local"));
        assert_eq!(r, HashSet::from([loc("cloud")]));
        assert!(!r.contains(&loc("local")));
    }

    #[test]
    fn reachable_isolated_node() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        g.add(loc("pod"), loc("cloud"));
        let r = g.reachable_from(&loc("local"));
        assert_eq!(r, HashSet::from([loc("cloud")]));
        assert!(!r.contains(&loc("pod")));
    }

    #[test]
    fn reachable_diamond() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        g.add(loc("local"), loc("nas"));
        g.add(loc("cloud"), loc("pod"));
        g.add(loc("nas"), loc("pod"));
        let r = g.reachable_from(&loc("local"));
        assert_eq!(r, HashSet::from([loc("cloud"), loc("nas"), loc("pod")]));
    }

    #[test]
    fn reachable_empty_graph() {
        let g = RouteGraph::new();
        let r = g.reachable_from(&loc("local"));
        assert!(r.is_empty());
    }

    #[test]
    fn all_destinations() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        g.add(loc("pod"), loc("cloud"));
        let dests = g.all_destinations();
        assert_eq!(dests, HashSet::from([loc("cloud")]));
    }

    #[test]
    fn destinations_ordered_chain() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        g.add(loc("cloud"), loc("pod"));
        let ordered = g.destinations_ordered_from(&loc("local"));
        assert_eq!(ordered.len(), 2);
        assert_eq!(ordered[0], loc("cloud"));
        assert_eq!(ordered[1], loc("pod"));
    }

    #[test]
    fn destinations_ordered_diamond() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        g.add(loc("local"), loc("nas"));
        g.add(loc("cloud"), loc("pod"));
        g.add(loc("nas"), loc("pod"));
        let ordered = g.destinations_ordered_from(&loc("local"));
        assert_eq!(ordered.len(), 3);
        assert_eq!(ordered[2], loc("pod"));
        assert!(ordered[..2].contains(&loc("cloud")));
        assert!(ordered[..2].contains(&loc("nas")));
    }

    #[test]
    fn destinations_ordered_empty_graph() {
        let g = RouteGraph::new();
        let ordered = g.destinations_ordered_from(&loc("local"));
        assert!(ordered.is_empty());
    }

    #[test]
    fn optimal_tree_chain_prefers_single_path() {
        let mut g = RouteGraph::new();
        g.add_with_cost(loc("local"), loc("pod"), EdgeCost::new(1.0, 10).unwrap());
        g.add_with_cost(loc("pod"), loc("cloud"), EdgeCost::new(2.0, 10).unwrap());
        g.add_with_cost(loc("local"), loc("cloud"), EdgeCost::new(10.0, 10).unwrap());
        let required = HashSet::from([loc("pod"), loc("cloud")]);
        let tree = g.optimal_tree(&loc("local"), &required);
        assert_eq!(tree.len(), 2);
        assert_eq!(tree[0], (loc("local"), loc("pod")));
        assert_eq!(tree[1], (loc("pod"), loc("cloud")));
    }

    #[test]
    fn optimal_tree_direct_cheaper() {
        let mut g = RouteGraph::new();
        g.add_with_cost(loc("local"), loc("pod"), EdgeCost::new(10.0, 10).unwrap());
        g.add_with_cost(loc("pod"), loc("cloud"), EdgeCost::new(10.0, 10).unwrap());
        g.add_with_cost(loc("local"), loc("cloud"), EdgeCost::new(1.0, 10).unwrap());
        let required = HashSet::from([loc("pod"), loc("cloud")]);
        let tree = g.optimal_tree(&loc("local"), &required);
        assert_eq!(tree.len(), 2);
        let edges: HashSet<_> = tree.into_iter().collect();
        assert!(edges.contains(&(loc("local"), loc("pod"))));
        assert!(edges.contains(&(loc("local"), loc("cloud"))));
    }

    #[test]
    fn optimal_tree_single_dest() {
        let mut g = RouteGraph::new();
        g.add_with_cost(loc("local"), loc("cloud"), EdgeCost::new(5.0, 10).unwrap());
        let required = HashSet::from([loc("cloud")]);
        let tree = g.optimal_tree(&loc("local"), &required);
        assert_eq!(tree.len(), 1);
        assert_eq!(tree[0], (loc("local"), loc("cloud")));
    }

    #[test]
    fn optimal_tree_empty_dests() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        let tree = g.optimal_tree(&loc("local"), &HashSet::new());
        assert!(tree.is_empty());
    }

    #[test]
    fn optimal_tree_unreachable_dest_skipped() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        let required = HashSet::from([loc("cloud"), loc("pod")]);
        let tree = g.optimal_tree(&loc("local"), &required);
        assert_eq!(tree.len(), 1);
        assert_eq!(tree[0], (loc("local"), loc("cloud")));
    }

    #[test]
    fn optimal_tree_dependency_order() {
        let mut g = RouteGraph::new();
        g.add_with_cost(loc("local"), loc("pod"), EdgeCost::new(1.0, 10).unwrap());
        g.add_with_cost(loc("pod"), loc("cloud"), EdgeCost::new(1.0, 10).unwrap());
        let required = HashSet::from([loc("pod"), loc("cloud")]);
        let tree = g.optimal_tree(&loc("local"), &required);
        assert_eq!(tree.len(), 2);
        assert_eq!(tree[0], (loc("local"), loc("pod")));
        assert_eq!(tree[1], (loc("pod"), loc("cloud")));
    }

    #[test]
    fn optimal_tree_diamond_deduplicates() {
        let mut g = RouteGraph::new();
        g.add_with_cost(loc("local"), loc("cloud"), EdgeCost::new(1.0, 10).unwrap());
        g.add_with_cost(loc("local"), loc("nas"), EdgeCost::new(1.0, 10).unwrap());
        g.add_with_cost(loc("cloud"), loc("pod"), EdgeCost::new(1.0, 10).unwrap());
        g.add_with_cost(loc("nas"), loc("pod"), EdgeCost::new(5.0, 10).unwrap());
        let required = HashSet::from([loc("cloud"), loc("nas"), loc("pod")]);
        let tree = g.optimal_tree(&loc("local"), &required);
        assert_eq!(tree.len(), 3);
        let edges: HashSet<_> = tree.into_iter().collect();
        assert!(edges.contains(&(loc("local"), loc("cloud"))));
        assert!(edges.contains(&(loc("local"), loc("nas"))));
        assert!(edges.contains(&(loc("cloud"), loc("pod"))));
        assert!(!edges.contains(&(loc("nas"), loc("pod"))));
    }

    #[test]
    fn optimal_tree_handles_cycles() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        g.add(loc("cloud"), loc("local"));
        g.add(loc("cloud"), loc("pod"));
        let required = HashSet::from([loc("pod")]);
        let tree = g.optimal_tree(&loc("local"), &required);
        assert_eq!(
            tree,
            vec![(loc("local"), loc("cloud")), (loc("cloud"), loc("pod"))]
        );
    }

    #[test]
    fn optimal_tree_origin_in_required_is_ignored() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        let required = HashSet::from([loc("local"), loc("cloud")]);
        let tree = g.optimal_tree(&loc("local"), &required);
        assert_eq!(tree, vec![(loc("local"), loc("cloud"))]);
    }

    #[test]
    fn optimal_tree_priority_breaks_time_ties() {
        let mut g = RouteGraph::new();
        g.add_with_cost(loc("local"), loc("cloud"), EdgeCost::new(1.0, 10).unwrap());
        g.add_with_cost(loc("cloud"), loc("pod"), EdgeCost::new(1.0, 10).unwrap());
        g.add_with_cost(loc("local"), loc("nas"), EdgeCost::new(1.0, 50).unwrap());
        g.add_with_cost(loc("nas"), loc("pod"), EdgeCost::new(1.0, 50).unwrap());
        let required = HashSet::from([loc("pod")]);
        let tree = g.optimal_tree(&loc("local"), &required);
        assert_eq!(
            tree,
            vec![(loc("local"), loc("cloud")), (loc("cloud"), loc("pod"))]
        );
    }

    #[test]
    fn multi_source_picks_cheaper_relay() {
        let mut g = RouteGraph::new();
        g.add_with_cost(loc("local"), loc("pod"), EdgeCost::new(1.0, 10).unwrap());
        g.add_with_cost(loc("pod"), loc("cloud"), EdgeCost::new(2.0, 10).unwrap());
        g.add_with_cost(loc("local"), loc("cloud"), EdgeCost::new(5.0, 10).unwrap());
        g.add_with_cost(loc("cloud"), loc("local"), EdgeCost::new(5.0, 10).unwrap());
        g.add_with_cost(loc("cloud"), loc("pod"), EdgeCost::new(2.0, 10).unwrap());
        let sources = HashSet::from([loc("local"), loc("pod")]);
        let targets = HashSet::from([loc("cloud")]);
        let tree = g.optimal_tree_multi_source(&sources, &targets);
        assert_eq!(tree.len(), 1);
        assert_eq!(tree[0], (loc("pod"), loc("cloud")));
    }

    #[test]
    fn multi_source_single_source_fallback() {
        let mut g = RouteGraph::new();
        g.add_with_cost(loc("local"), loc("pod"), EdgeCost::new(1.0, 10).unwrap());
        g.add_with_cost(loc("pod"), loc("cloud"), EdgeCost::new(2.0, 10).unwrap());
        let sources = HashSet::from([loc("local")]);
        let targets = HashSet::from([loc("pod"), loc("cloud")]);
        let tree = g.optimal_tree_multi_source(&sources, &targets);
        assert_eq!(tree.len(), 2);
        assert_eq!(tree[0], (loc("local"), loc("pod")));
        assert_eq!(tree[1], (loc("pod"), loc("cloud")));
    }

    #[test]
    fn multi_source_empty_sources() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        let tree = g.optimal_tree_multi_source(&HashSet::new(), &HashSet::from([loc("cloud")]));
        assert!(tree.is_empty());
    }

    #[test]
    fn multi_source_empty_targets() {
        let mut g = RouteGraph::new();
        g.add(loc("local"), loc("cloud"));
        let tree = g.optimal_tree_multi_source(&HashSet::from([loc("local")]), &HashSet::new());
        assert!(tree.is_empty());
    }

    #[test]
    fn multi_source_target_already_in_sources() {
        let mut g = RouteGraph::new();
        g.add_with_cost(loc("local"), loc("cloud"), EdgeCost::new(5.0, 10).unwrap());
        let sources = HashSet::from([loc("local"), loc("cloud")]);
        let targets = HashSet::from([loc("cloud")]);
        let tree = g.optimal_tree_multi_source(&sources, &targets);
        assert!(
            tree.is_empty(),
            "target already has data, no transfer needed"
        );
    }

    #[test]
    fn multi_source_multiple_targets_different_best_sources() {
        let mut g = RouteGraph::new();
        g.add_with_cost(loc("local"), loc("nas"), EdgeCost::new(1.0, 10).unwrap());
        g.add_with_cost(loc("pod"), loc("cloud"), EdgeCost::new(2.0, 10).unwrap());
        g.add_with_cost(loc("local"), loc("cloud"), EdgeCost::new(10.0, 10).unwrap());
        g.add_with_cost(loc("pod"), loc("nas"), EdgeCost::new(10.0, 10).unwrap());
        let sources = HashSet::from([loc("local"), loc("pod")]);
        let targets = HashSet::from([loc("nas"), loc("cloud")]);
        let tree = g.optimal_tree_multi_source(&sources, &targets);
        assert_eq!(tree.len(), 2);
        let edges: HashSet<_> = tree.into_iter().collect();
        assert!(edges.contains(&(loc("local"), loc("nas"))));
        assert!(edges.contains(&(loc("pod"), loc("cloud"))));
    }

    #[test]
    fn multi_source_shared_path_emitted_once() {
        let mut g = RouteGraph::new();
        g.add_with_cost(loc("local"), loc("cloud"), EdgeCost::new(1.0, 10).unwrap());
        g.add_with_cost(loc("cloud"), loc("pod"), EdgeCost::new(1.0, 10).unwrap());
        g.add_with_cost(loc("cloud"), loc("nas"), EdgeCost::new(2.0, 10).unwrap());
        let sources = HashSet::from([loc("local")]);
        let targets = HashSet::from([loc("pod"), loc("nas")]);
        let tree = g.optimal_tree_multi_source(&sources, &targets);
        assert_eq!(
            tree,
            vec![
                (loc("local"), loc("cloud")),
                (loc("cloud"), loc("pod")),
                (loc("cloud"), loc("nas")),
            ]
        );
    }
}