use rustc_hash::FxHashSet;
use super::lattice::Lattice;
use super::types::{Edge, Node, NodeId};
use crate::backend::LatticeBackend;
use crate::semiring::Semiring;
/// Computes a topological ordering of `nodes` using Kahn's algorithm.
///
/// In-degrees are seeded from each node's `incoming` list. They are decremented by
/// following each processed node's `outgoing` edge ids through `edges` to their
/// targets. Node ids are used as indices into `nodes` and into the in-degree table.
/// This presumably assumes `nodes[i].id == NodeId(i)`; TODO confirm against the
/// builder.
///
/// # Returns
///
/// * `Some(vec![])` for an empty node slice.
/// * `Some(order)` with `nodes.len()` ids when every node could be released.
/// * `None` when no complete ordering exists. This happens when the graph has a
///   cycle, or when the adjacency data is inconsistent: an outgoing edge id missing
///   from `edges`, an edge target outside `nodes`, or more edges arriving at a node
///   than its `incoming` list records.
pub fn topological_sort<W: Semiring>(nodes: &[Node], edges: &[Edge<W>]) -> Option<Vec<NodeId>> {
    let n = nodes.len();
    if n == 0 {
        return Some(Vec::new());
    }

    // Remaining unprocessed in-edges per node, indexed by node id.
    let mut in_degree: Vec<usize> = nodes.iter().map(|node| node.incoming.len()).collect();

    // Nodes whose in-degree has reached zero. Kahn's algorithm may take any ready
    // node next, so a LIFO `Vec` is sufficient; no FIFO queue is needed.
    let mut ready: Vec<NodeId> = Vec::with_capacity(n);
    ready.extend(
        nodes
            .iter()
            .filter(|node| node.incoming.is_empty())
            .map(|node| node.id),
    );

    let mut order: Vec<NodeId> = Vec::with_capacity(n);
    while let Some(node_id) = ready.pop() {
        order.push(node_id);
        let node = match nodes.get(node_id.0 as usize) {
            Some(node) => node,
            None => continue,
        };
        for &edge_id in &node.outgoing {
            // Read the target straight from `edges` (no per-call copy of all
            // targets). A dangling edge id, an out-of-range target, or an
            // in-degree that would drop below zero all mean the adjacency lists
            // disagree, so no valid ordering can be produced.
            let target = edges.get(edge_id.0 as usize)?.target;
            let degree = in_degree.get_mut(target.0 as usize)?;
            *degree = degree.checked_sub(1)?;
            if *degree == 0 {
                ready.push(target);
            }
        }
    }

    // Nodes never released still have in-edges left: they lie on a cycle or
    // downstream of one.
    if order.len() == n {
        Some(order)
    } else {
        None
    }
}
/// Returns `true` if the directed graph formed by `edges` over `nodes` has no cycle.
///
/// Node ids are treated as indices into `nodes`. Edges whose source or target is
/// not below `nodes.len()` are ignored. A self-loop counts as a cycle. An empty
/// node slice is acyclic.
///
/// Uses a three-color depth-first search. An edge to a node that is still on the
/// current DFS path is a back edge, and a back edge means a cycle. The DFS is
/// iterative, with an explicit stack of `(node, next child index)` frames, so very
/// long lattices (for example one node per input position) cannot overflow the
/// call stack the way a recursive DFS would.
pub fn is_acyclic(nodes: &[Node], edges: &[Edge<impl Semiring>]) -> bool {
    const UNVISITED: u8 = 0;
    const ON_PATH: u8 = 1;
    const FINISHED: u8 = 2;

    let n = nodes.len();
    if n == 0 {
        return true;
    }

    // Adjacency as plain indices; out-of-range endpoints are dropped here once.
    let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
    for edge in edges {
        let src = edge.source.0 as usize;
        let dst = edge.target.0 as usize;
        if src < n && dst < n {
            adj[src].push(dst);
        }
    }

    let mut state: Vec<u8> = vec![UNVISITED; n];
    // Each frame is (node, index of the next child of `node` to examine).
    let mut stack: Vec<(usize, usize)> = Vec::new();
    for root in 0..n {
        if state[root] != UNVISITED {
            continue;
        }
        state[root] = ON_PATH;
        stack.push((root, 0));
        while let Some(frame) = stack.last_mut() {
            let node = frame.0;
            match adj[node].get(frame.1) {
                Some(&child) => {
                    frame.1 += 1;
                    match state[child] {
                        // Back edge to an ancestor on the current path: cycle.
                        ON_PATH => return false,
                        UNVISITED => {
                            state[child] = ON_PATH;
                            stack.push((child, 0));
                        }
                        // Already fully explored; cannot lead back to this path.
                        _ => {}
                    }
                }
                None => {
                    state[node] = FINISHED;
                    stack.pop();
                }
            }
        }
    }
    true
}
/// Counts the distinct paths from `lattice.start()` to `lattice.end()`.
///
/// Parallel edges between the same pair of nodes count as distinct paths. Counts
/// are propagated forward in topological order, so each node's total is final
/// before it is pushed to its successors.
///
/// # Returns
///
/// * `None` if `lattice.topological_order()` returns `None` (presumably a cyclic
///   lattice), or if a path count overflows `usize`.
/// * `Some(0)` if the topological order is empty, or if `end` is not reachable
///   from `start`.
/// * `Some(count)` otherwise.
pub fn count_paths<W: Semiring, B: LatticeBackend>(lattice: &mut Lattice<W, B>) -> Option<usize> {
    // Copy the order out: `topological_order` needs `&mut self`, and the loop
    // below borrows `lattice` again.
    let topo_order = lattice.topological_order()?.to_vec();
    if topo_order.is_empty() {
        return Some(0);
    }

    let mut path_count: Vec<usize> = vec![0; lattice.num_nodes()];
    path_count[lattice.start().0 as usize] = 1;

    for node_id in topo_order {
        let current_count = path_count[node_id.0 as usize];
        // Nodes not reachable from `start` contribute nothing downstream.
        if current_count == 0 {
            continue;
        }
        // `path_count` is a separate local, so the edges can be walked directly
        // without collecting their targets into a temporary Vec per node.
        for edge in lattice.outgoing_edges(node_id) {
            let slot = &mut path_count[edge.target.0 as usize];
            *slot = slot.checked_add(current_count)?;
        }
    }

    Some(path_count[lattice.end().0 as usize])
}
/// Collects every node reachable from `start` by following outgoing edges.
///
/// The returned set always contains `start` itself. The traversal is an iterative
/// depth-first search over an explicit stack, so deep lattices do not grow the
/// call stack.
pub fn reachable_nodes<W: Semiring, B: LatticeBackend>(
    lattice: &Lattice<W, B>,
    start: NodeId,
) -> FxHashSet<NodeId> {
    let mut seen: FxHashSet<NodeId> = FxHashSet::default();
    let mut frontier: Vec<NodeId> = vec![start];

    while let Some(current) = frontier.pop() {
        // A node can be pushed more than once before its first visit; `insert`
        // returning false means it has already been expanded.
        if seen.insert(current) {
            frontier.extend(
                lattice
                    .outgoing_edges(current)
                    .map(|edge| edge.target)
                    .filter(|target| !seen.contains(target)),
            );
        }
    }

    seen
}
pub fn path_exists<W: Semiring, B: LatticeBackend>(
lattice: &Lattice<W, B>,
source: NodeId,
target: NodeId,
) -> bool {
if source == target {
return true;
}
let reachable = reachable_nodes(lattice, source);
reachable.contains(&target)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the lattice graph algorithms, run on small hand-built lattices.

    use super::*;
    use crate::backend::HashMapBackend;
    use crate::lattice::builder::LatticeBuilder;
    use crate::lattice::types::EdgeMetadata;
    use crate::semiring::TropicalWeight;

    /// Builds a chain lattice with one unit-weight correction from `i` to `i + 1`
    /// for each `i` in `0..n`.
    fn linear_lattice(n: usize) -> Lattice<TropicalWeight, HashMapBackend> {
        let backend = HashMapBackend::new();
        let mut builder = LatticeBuilder::new(backend);
        for i in 0..n {
            builder.add_correction(
                i,
                i + 1,
                &format!("word{}", i),
                TropicalWeight::new(1.0),
                EdgeMetadata::default(),
            );
        }
        builder.build(n)
    }

    /// Builds the diamond `0 -> {1, 2} -> 3`: two distinct routes from 0 to 3.
    fn diamond_lattice() -> Lattice<TropicalWeight, HashMapBackend> {
        let backend = HashMapBackend::new();
        let mut builder = LatticeBuilder::new(backend);
        builder.add_correction(0, 1, "a", TropicalWeight::new(1.0), EdgeMetadata::default());
        builder.add_correction(0, 2, "b", TropicalWeight::new(1.0), EdgeMetadata::default());
        builder.add_correction(1, 3, "c", TropicalWeight::new(1.0), EdgeMetadata::default());
        builder.add_correction(2, 3, "d", TropicalWeight::new(1.0), EdgeMetadata::default());
        builder.build(3)
    }

    #[test]
    fn test_topological_sort_linear() {
        let lattice = linear_lattice(3);
        let order = topological_sort(lattice.nodes(), lattice.edges())
            .expect("a linear lattice is acyclic and must have a topological order");
        assert_eq!(order.len(), 4);
        // Consecutive nodes in the order must never go backwards in position.
        for pair in order.windows(2) {
            let pos_i = lattice
                .node(pair[0])
                .expect("every id in the topological order must name a lattice node")
                .position;
            let pos_j = lattice
                .node(pair[1])
                .expect("every id in the topological order must name a lattice node")
                .position;
            assert!(pos_i <= pos_j);
        }
    }

    #[test]
    fn test_topological_sort_diamond() {
        let lattice = diamond_lattice();
        let order = topological_sort(lattice.nodes(), lattice.edges())
            .expect("a diamond lattice is acyclic and must have a topological order");
        assert_eq!(order.len(), 4);
        let start_pos = order
            .iter()
            .position(|&n| n == lattice.start())
            .expect("start node must appear in the topological order");
        let end_pos = order
            .iter()
            .position(|&n| n == lattice.end())
            .expect("end node must appear in the topological order");
        assert_eq!(start_pos, 0);
        assert_eq!(end_pos, 3);
    }

    #[test]
    fn test_topological_sort_empty() {
        let empty_edges: &[Edge<TropicalWeight>] = &[];
        let order = topological_sort(&[], empty_edges);
        assert_eq!(order, Some(vec![]));
    }

    #[test]
    fn test_is_acyclic_linear() {
        let lattice = linear_lattice(3);
        assert!(is_acyclic(lattice.nodes(), lattice.edges()));
    }

    #[test]
    fn test_is_acyclic_diamond() {
        let lattice = diamond_lattice();
        assert!(is_acyclic(lattice.nodes(), lattice.edges()));
    }

    #[test]
    fn test_is_acyclic_empty() {
        let empty_edges: &[Edge<TropicalWeight>] = &[];
        assert!(is_acyclic(&[], empty_edges));
    }

    #[test]
    fn test_count_paths_linear() {
        let mut lattice = linear_lattice(3);
        assert_eq!(count_paths(&mut lattice), Some(1));
    }

    #[test]
    fn test_count_paths_linear_longer() {
        let mut lattice = linear_lattice(10);
        assert_eq!(count_paths(&mut lattice), Some(1));
    }

    #[test]
    fn test_count_paths_diamond() {
        let mut lattice = diamond_lattice();
        assert_eq!(count_paths(&mut lattice), Some(2));
    }

    #[test]
    fn test_count_paths_multi_edge() {
        let backend = HashMapBackend::new();
        let mut builder: LatticeBuilder<TropicalWeight, _> = LatticeBuilder::new(backend);
        builder.add_correction(0, 1, "a", TropicalWeight::new(1.0), EdgeMetadata::default());
        builder.add_correction(0, 1, "b", TropicalWeight::new(1.0), EdgeMetadata::default());
        builder.add_correction(1, 2, "c", TropicalWeight::new(1.0), EdgeMetadata::default());
        builder.add_correction(1, 2, "d", TropicalWeight::new(1.0), EdgeMetadata::default());
        builder.add_correction(1, 2, "e", TropicalWeight::new(1.0), EdgeMetadata::default());
        let mut lattice = builder.build(2);
        // Parallel edges are distinct paths: 2 choices for 0->1 times 3 for 1->2.
        assert_eq!(count_paths(&mut lattice), Some(6));
    }

    #[test]
    fn test_count_paths_empty() {
        let backend = HashMapBackend::new();
        let builder: LatticeBuilder<TropicalWeight, _> = LatticeBuilder::new(backend);
        let mut lattice = builder.build(0);
        // A lattice built with no corrections is expected to yield exactly one
        // trivial path (presumably start and end coincide; verify with the builder).
        assert_eq!(count_paths(&mut lattice), Some(1));
    }

    #[test]
    fn test_reachable_nodes() {
        let lattice = diamond_lattice();
        let reachable = reachable_nodes(&lattice, lattice.start());
        assert_eq!(reachable.len(), 4);
        assert!(reachable.contains(&lattice.start()));
        assert!(reachable.contains(&lattice.end()));
    }

    #[test]
    fn test_reachable_nodes_linear() {
        let lattice = linear_lattice(3);
        let reachable = reachable_nodes(&lattice, lattice.start());
        assert_eq!(reachable.len(), 4);
        assert!(reachable.contains(&lattice.end()));
    }

    #[test]
    fn test_reachable_nodes_from_end() {
        let lattice = diamond_lattice();
        let reachable = reachable_nodes(&lattice, lattice.end());
        assert_eq!(reachable.len(), 1);
        assert!(reachable.contains(&lattice.end()));
    }

    #[test]
    fn test_path_exists() {
        let lattice = diamond_lattice();
        assert!(path_exists(&lattice, lattice.start(), lattice.end()));
        assert!(path_exists(&lattice, lattice.start(), lattice.start()));
        assert!(!path_exists(&lattice, lattice.end(), lattice.start()));
    }

    #[test]
    fn test_path_exists_linear() {
        let lattice = linear_lattice(3);
        assert!(path_exists(&lattice, lattice.start(), lattice.end()));
        assert!(!path_exists(&lattice, lattice.end(), lattice.start()));
    }
}