/// Errors reported by [`soft_shortest_path_edge_marginals`] for invalid
/// parameters, malformed graphs, or a sink that cannot be reached.
#[derive(thiserror::Error, Debug, Clone, PartialEq)]
pub enum Error {
    /// `gamma` was zero, negative, NaN, or infinite.
    #[error("gamma must be positive and finite, got {0}")]
    InvalidGamma(f64),
    /// The node count was below 2, so a distinct source (node `0`) and
    /// sink (node `n - 1`) cannot exist.
    #[error("graph must have at least 2 nodes, got {0}")]
    TooFewNodes(usize),
    /// Edge `edge_idx` references a node index `>= n`.
    #[error("edge endpoint out of bounds: edge {edge_idx} has ({from}->{to}) for n={n}")]
    EdgeOutOfBounds {
        edge_idx: usize,
        from: usize,
        to: usize,
        n: usize,
    },
    /// Edge `edge_idx` has a NaN or infinite cost.
    #[error("non-finite edge cost at edge {edge_idx}: cost={cost}")]
    NonFiniteCost {
        edge_idx: usize,
        cost: f64,
    },
    /// Edge `edge_idx` does not satisfy `from < to` (this includes
    /// self-loops), so the node numbering is not a topological order.
    #[error("expected DAG/topological order with from < to; edge {edge_idx} has ({from}->{to})")]
    NotDagOrder {
        edge_idx: usize,
        from: usize,
        to: usize,
    },
    /// The soft distance from node `0` to node `n - 1` is not finite.
    #[error("no path exists from source to sink")]
    NoPath,
}
/// Result alias whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// A directed, weighted edge `from -> to`.
///
/// [`soft_shortest_path_edge_marginals`] requires `from < to < n` and a
/// finite `cost`; validation does not reject negative costs.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Edge {
    /// Index of the tail node.
    pub from: usize,
    /// Index of the head node.
    pub to: usize,
    /// Cost added to a path when it traverses this edge.
    pub cost: f64,
}
/// Numerically stable `ln(sum(exp(x)))` over `xs`.
///
/// Every term is shifted by the maximum before exponentiating, so large
/// inputs do not overflow. NaN entries are ignored when picking the maximum.
///
/// Edge cases:
/// - Returns `-inf` for an empty slice or when every non-NaN entry is `-inf`.
/// - Returns `+inf` when any entry is `+inf`.
fn log_sum_exp(xs: &[f64]) -> f64 {
    // `f64::max` skips NaN operands, and the accumulator starts at -inf.
    let m = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if m == f64::NEG_INFINITY {
        return f64::NEG_INFINITY;
    }
    if m == f64::INFINITY {
        // The +inf term dominates the sum. Shifting by `m` would compute
        // inf - inf = NaN, and the old `!is_finite` guard wrongly returned -inf here.
        return f64::INFINITY;
    }
    // The maximal term contributes exp(0) = 1, so the sum is >= 1 and ln is safe.
    let s: f64 = xs.iter().map(|&x| (x - m).exp()).sum();
    m + s.ln()
}
/// Smoothed minimum of `candidates` at temperature `gamma`, i.e.
/// `-gamma * ln(sum(exp(-a / gamma)))`.
///
/// `scratch` is a caller-owned buffer reused to hold the scaled terms, so
/// repeated calls can avoid reallocating. Its previous contents are discarded.
/// When the log-sum-exp of the scaled terms is `-inf` (for example when
/// `candidates` is empty), the result is `+inf`.
fn softmin_gamma(gamma: f64, candidates: &[f64], scratch: &mut Vec<f64>) -> f64 {
    scratch.clear();
    scratch.extend(candidates.iter().map(|&cost| -cost / gamma));
    match log_sum_exp(scratch) {
        // No mass at all: the soft-min of nothing is +inf.
        lse if lse == f64::NEG_INFINITY => f64::INFINITY,
        lse => -gamma * lse,
    }
}
fn validate(n: usize, edges: &[Edge]) -> Result<()> {
if n < 2 {
return Err(Error::TooFewNodes(n));
}
for (k, e) in edges.iter().enumerate() {
if e.from >= n || e.to >= n {
return Err(Error::EdgeOutOfBounds {
edge_idx: k,
from: e.from,
to: e.to,
n,
});
}
if e.from >= e.to {
return Err(Error::NotDagOrder {
edge_idx: k,
from: e.from,
to: e.to,
});
}
if !e.cost.is_finite() {
return Err(Error::NonFiniteCost {
edge_idx: k,
cost: e.cost,
});
}
}
Ok(())
}
/// Soft shortest-path value from node `0` (source) to node `n - 1` (sink),
/// together with per-edge marginal probabilities.
///
/// The soft-min at temperature `gamma` replaces `min` in the Bellman
/// recursion. Every edge must satisfy `from < to`, so node numbering is a
/// topological order. A forward sweep over `1..n` and a backward sweep over
/// `n - 2..=0` therefore suffice.
///
/// Returns `(value, p)`:
/// - `value` is the soft distance from source to sink.
/// - `p[k] = exp(-(fwd[from] + cost + bwd[to] - value) / gamma)` is the
///   marginal of edge `k`. On the diamond test graph this equals the softmax
///   weight of the paths through the edge.
///
/// Edges whose tail is unreachable from the source, or whose head cannot
/// reach the sink, get `0.0`. Results are subject to floating-point rounding.
///
/// # Errors
/// - [`Error::InvalidGamma`] if `gamma` is not positive and finite.
/// - [`Error::TooFewNodes`], [`Error::EdgeOutOfBounds`],
///   [`Error::NotDagOrder`], or [`Error::NonFiniteCost`] from input validation.
/// - [`Error::NoPath`] if the soft distance to the sink is not finite.
pub fn soft_shortest_path_edge_marginals(
    n: usize,
    edges: &[Edge],
    gamma: f64,
) -> Result<(f64, Vec<f64>)> {
    if gamma <= 0.0 || !gamma.is_finite() {
        return Err(Error::InvalidGamma(gamma));
    }
    validate(n, edges)?;
    let (incoming, outgoing) = edge_adjacency(n, edges);

    // Forward pass: soft distance from the source to every node.
    let fwd = soft_distance_sweep(edges, &incoming, 0, 1..n, |e| e.from, gamma);
    let value = fwd[n - 1];
    if !value.is_finite() {
        return Err(Error::NoPath);
    }

    // Backward pass: soft distance from every node to the sink.
    let bwd = soft_distance_sweep(edges, &outgoing, n - 1, (0..n - 1).rev(), |e| e.to, gamma);

    let p = edges
        .iter()
        .map(|e| {
            let a = fwd[e.from];
            let b = bwd[e.to];
            if a.is_finite() && b.is_finite() {
                let z = -((a + e.cost + b - value) / gamma);
                // exp(z) for z below about -745 is at most the smallest
                // positive subnormal f64; flush those to an exact zero.
                if z < -745.0 {
                    0.0
                } else {
                    z.exp()
                }
            } else {
                0.0
            }
        })
        .collect();
    Ok((value, p))
}

/// Builds per-node lists of incoming and outgoing edge indices, each in
/// ascending edge-index order. Endpoints must already be validated `< n`.
fn edge_adjacency(n: usize, edges: &[Edge]) -> (Vec<Vec<usize>>, Vec<Vec<usize>>) {
    let mut incoming: Vec<Vec<usize>> = vec![Vec::new(); n];
    let mut outgoing: Vec<Vec<usize>> = vec![Vec::new(); n];
    for (k, e) in edges.iter().enumerate() {
        incoming[e.to].push(k);
        outgoing[e.from].push(k);
    }
    (incoming, outgoing)
}

/// One soft-min dynamic-programming sweep.
///
/// `dist[root]` starts at `0` and every other entry at `+inf`. Each node `v`
/// yielded by `order` is set to the soft-min, over edges `e` in `adj[v]`, of
/// `dist[upstream(e)] + e.cost`. Edges whose upstream distance is not finite
/// are skipped, and if none remain `dist[v]` stays `+inf`.
///
/// `order` must visit each node only after all of its upstream neighbours.
fn soft_distance_sweep(
    edges: &[Edge],
    adj: &[Vec<usize>],
    root: usize,
    order: impl Iterator<Item = usize>,
    upstream: impl Fn(&Edge) -> usize,
    gamma: f64,
) -> Vec<f64> {
    let mut dist = vec![f64::INFINITY; adj.len()];
    dist[root] = 0.0;
    let mut cands = Vec::new();
    let mut scratch = Vec::new();
    for v in order {
        cands.clear();
        cands.extend(adj[v].iter().filter_map(|&k| {
            let e = &edges[k];
            let d = dist[upstream(e)];
            d.is_finite().then_some(d + e.cost)
        }));
        dist[v] = if cands.is_empty() {
            f64::INFINITY
        } else {
            softmin_gamma(gamma, &cands, &mut scratch)
        };
    }
    dist
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Shorthand for building an [`Edge`] in fixtures.
    fn edge(from: usize, to: usize, cost: f64) -> Edge {
        Edge { from, to, cost }
    }

    /// Two disjoint source-to-sink paths: the marginals must equal the softmax
    /// weights of the paths, and the value must be `-gamma * ln(sum)`.
    #[test]
    fn diamond_graph_matches_softmax_over_path_costs() {
        let n = 4;
        let a = 1.0 + 2.0;
        let b = 3.0 + 4.0;
        let edges = [
            Edge {
                from: 0,
                to: 1,
                cost: 1.0,
            },
            Edge {
                from: 1,
                to: 3,
                cost: 2.0,
            },
            Edge {
                from: 0,
                to: 2,
                cost: 3.0,
            },
            Edge {
                from: 2,
                to: 3,
                cost: 4.0,
            },
        ];
        let gamma = 0.5;
        let (v, p) = soft_shortest_path_edge_marginals(n, &edges, gamma).unwrap();
        let pa = (-a / gamma).exp();
        let pb = (-b / gamma).exp();
        let z = pa + pb;
        let p_path_a = pa / z;
        let p_path_b = pb / z;
        assert!(
            (p[0] - p_path_a).abs() < 1e-9,
            "p0={} pa={}",
            p[0],
            p_path_a
        );
        assert!(
            (p[1] - p_path_a).abs() < 1e-9,
            "p1={} pa={}",
            p[1],
            p_path_a
        );
        assert!(
            (p[2] - p_path_b).abs() < 1e-9,
            "p2={} pb={}",
            p[2],
            p_path_b
        );
        assert!(
            (p[3] - p_path_b).abs() < 1e-9,
            "p3={} pb={}",
            p[3],
            p_path_b
        );
        let v_expected = -gamma * (pa + pb).ln();
        assert!(
            (v - v_expected).abs() < 1e-9,
            "v={} v_expected={}",
            v,
            v_expected
        );
    }

    /// With a single path every edge carries all of the probability mass and
    /// the value is the plain path cost.
    #[test]
    fn single_chain_puts_all_mass_on_every_edge() {
        let edges = [edge(0, 1, 2.0), edge(1, 2, 3.0)];
        let (v, p) = soft_shortest_path_edge_marginals(3, &edges, 0.7).unwrap();
        assert!((v - 5.0).abs() < 1e-12, "v={}", v);
        for &pe in &p {
            assert!((pe - 1.0).abs() < 1e-12, "pe={}", pe);
        }
    }

    /// Edges into a dead end (edge 2) or out of a node unreachable from the
    /// source (edge 3) lie on no source-to-sink path and get exactly zero mass.
    #[test]
    fn edges_off_every_source_sink_path_get_zero_mass() {
        let edges = [
            edge(0, 1, 1.0),
            edge(1, 4, 1.0),
            edge(0, 2, 1.0),
            edge(3, 4, 1.0),
        ];
        let (v, p) = soft_shortest_path_edge_marginals(5, &edges, 1.0).unwrap();
        assert!((v - 2.0).abs() < 1e-12, "v={}", v);
        assert!((p[0] - 1.0).abs() < 1e-12, "p0={}", p[0]);
        assert!((p[1] - 1.0).abs() < 1e-12, "p1={}", p[1]);
        assert_eq!(p[2], 0.0);
        assert_eq!(p[3], 0.0);
    }

    #[test]
    fn rejects_invalid_gamma() {
        let edges = [edge(0, 1, 1.0)];
        for gamma in [0.0, -1.0, f64::INFINITY, f64::NAN] {
            let err = soft_shortest_path_edge_marginals(2, &edges, gamma).unwrap_err();
            // NaN != NaN, so match on the variant rather than comparing payloads.
            assert!(
                matches!(err, Error::InvalidGamma(_)),
                "gamma={} err={:?}",
                gamma,
                err
            );
        }
    }

    #[test]
    fn reports_structural_errors() {
        assert_eq!(
            soft_shortest_path_edge_marginals(1, &[], 1.0).unwrap_err(),
            Error::TooFewNodes(1)
        );
        assert_eq!(
            soft_shortest_path_edge_marginals(2, &[edge(0, 2, 1.0)], 1.0).unwrap_err(),
            Error::EdgeOutOfBounds {
                edge_idx: 0,
                from: 0,
                to: 2,
                n: 2,
            }
        );
        assert_eq!(
            soft_shortest_path_edge_marginals(3, &[edge(0, 1, 1.0), edge(2, 1, 1.0)], 1.0)
                .unwrap_err(),
            Error::NotDagOrder {
                edge_idx: 1,
                from: 2,
                to: 1,
            }
        );
        assert_eq!(
            soft_shortest_path_edge_marginals(2, &[edge(0, 1, f64::INFINITY)], 1.0).unwrap_err(),
            Error::NonFiniteCost {
                edge_idx: 0,
                cost: f64::INFINITY,
            }
        );
        assert_eq!(
            soft_shortest_path_edge_marginals(3, &[edge(0, 1, 1.0)], 1.0).unwrap_err(),
            Error::NoPath
        );
    }

    proptest! {
        // Marginals stay in [0, 1], the source cut carries total mass 1, and
        // mass is conserved through each interior node of the diamond.
        #[test]
        fn edge_marginals_are_probabilities_on_diamond(
            c01 in 0.0f64..10.0,
            c13 in 0.0f64..10.0,
            c02 in 0.0f64..10.0,
            c23 in 0.0f64..10.0,
            gamma in 0.05f64..5.0
        ) {
            let n = 4;
            let edges = [
                Edge { from: 0, to: 1, cost: c01 },
                Edge { from: 1, to: 3, cost: c13 },
                Edge { from: 0, to: 2, cost: c02 },
                Edge { from: 2, to: 3, cost: c23 },
            ];
            let (_v, p) = soft_shortest_path_edge_marginals(n, &edges, gamma).unwrap();
            for &pe in &p {
                prop_assert!((-1e-12..=1.0 + 1e-12).contains(&pe));
            }
            let s = p[0] + p[2];
            prop_assert!((s - 1.0).abs() < 1e-10, "s={}", s);
            prop_assert!((p[0] - p[1]).abs() < 1e-10, "p0={} p1={}", p[0], p[1]);
            prop_assert!((p[2] - p[3]).abs() < 1e-10, "p2={} p3={}", p[2], p[3]);
        }
    }
}