use fxhash::{FxHashMap, FxHashSet};
use std::ops::ControlFlow;
use uni_common::core::id::{Eid, Vid};
use super::nfa::{NfaStateId, PathMode, PathSelector};
/// One recorded incoming edge in the predecessor pool of `PredecessorDag`.
///
/// All records for the same `(dst, dst_state, depth)` key form a singly
/// linked list threaded through `next`; `add_predecessor` prepends, so the
/// list is ordered newest first.
#[derive(Debug, Clone)]
pub struct PredRec {
/// Vertex the edge was traversed from.
pub src_vid: Vid,
/// NFA state the traversal was in at `src_vid`.
pub src_state: NfaStateId,
/// Edge by which the destination was reached from `src_vid`.
pub eid: Eid,
/// Pool index of the next (older) record for the same key, or `-1` when
/// this is the last record in the list.
pub next: i32,
}
/// Compact store of the predecessor edges discovered during a traversal.
///
/// Every record lives in one `Vec`; records sharing a
/// `(vertex, NFA state, depth)` key are chained into newest-first lists so
/// that paths can be reconstructed by walking backward from a target.
pub struct PredecessorDag {
    /// Retention policy: decides whether deeper records are kept (see
    /// `is_layered`).
    selector: PathSelector,
    /// Smallest depth at which each `(vertex, state)` pair has been reached.
    first_depth: FxHashMap<(Vid, NfaStateId), u32>,
    /// Pool index of the head record for each `(vertex, state, depth)` key.
    pred_head: FxHashMap<(Vid, NfaStateId, u32), i32>,
    /// Backing storage for every recorded predecessor edge.
    pred_pool: Vec<PredRec>,
}
impl PredecessorDag {
    /// Creates an empty DAG whose record-retention policy is governed by
    /// `selector` (see [`PredecessorDag::is_layered`]).
    pub fn new(selector: PathSelector) -> Self {
        Self {
            pred_pool: Vec::new(),
            pred_head: FxHashMap::default(),
            first_depth: FxHashMap::default(),
            selector,
        }
    }

    /// Returns `true` when predecessor records are kept at every depth.
    ///
    /// `PathSelector::All` and `PathSelector::Any` are layered; any other
    /// selector keeps only records at the smallest depth seen for each
    /// `(vertex, state)` pair.
    pub fn is_layered(&self) -> bool {
        matches!(self.selector, PathSelector::All | PathSelector::Any)
    }

    /// Records that `(dst, dst_state)` was reached at `depth` from
    /// `(src, src_state)` over edge `eid`.
    ///
    /// The minimum depth for `(dst, dst_state)` is always updated. For
    /// non-layered selectors the record itself is discarded when `depth` is
    /// greater than that minimum (records tying the minimum are kept).
    /// Lowering the minimum does not evict records already stored at deeper
    /// depths. NOTE(review): this presumably relies on callers inserting in
    /// non-decreasing depth order (BFS); confirm.
    ///
    /// # Panics
    ///
    /// Panics if the new record's pool index does not fit in an `i32`, since
    /// `PredRec::next` and the head table store indices as `i32`.
    pub fn add_predecessor(
        &mut self,
        dst: Vid,
        dst_state: NfaStateId,
        src: Vid,
        src_state: NfaStateId,
        eid: Eid,
        depth: u32,
    ) {
        // Read the selector before `first_depth` is borrowed mutably, so the
        // minimum can be used directly instead of being looked up again.
        let layered = self.is_layered();
        let slot = self.first_depth.entry((dst, dst_state)).or_insert(depth);
        *slot = (*slot).min(depth);
        let min_depth = *slot;
        if !layered && depth > min_depth {
            return;
        }

        let key = (dst, dst_state, depth);
        let next = self.pred_head.get(&key).copied().unwrap_or(-1);
        let new_idx = self.pred_pool.len();
        // A wrapped (negative) index would be read as "end of list" and
        // silently drop every record chained behind it.
        assert!(
            new_idx <= i32::MAX as usize,
            "predecessor pool index exceeds i32::MAX"
        );
        self.pred_pool.push(PredRec {
            src_vid: src,
            src_state,
            eid,
            next,
        });
        self.pred_head.insert(key, new_idx as i32);
    }

    /// Enumerates every stored path from `source` to `target` that ends in
    /// `accepting_state` and has a length in `min_depth..=max_depth`.
    ///
    /// Each path is passed to `yield_path` in forward order: vertices
    /// source-first, then the edges between them. Returning
    /// `ControlFlow::Break` from the callback stops the enumeration. Lengths
    /// are visited in increasing order. Within one length, predecessors are
    /// explored newest-first, matching the pool's list order.
    ///
    /// `mode` filters paths:
    /// - `Walk`: no restriction;
    /// - `Trail`: no edge repeats;
    /// - `Acyclic`: no vertex repeats;
    /// - `Simple`: no vertex repeats, except that the last vertex may equal
    ///   the first when `source == target` (a closed cycle).
    ///
    /// A length-0 path `[source]` is yielded when `0` is in range and
    /// `source == target`; it does not consult the DAG or `accepting_state`.
    #[expect(
        clippy::too_many_arguments,
        reason = "path enumeration requires full traversal context"
    )]
    pub fn enumerate_paths<F>(
        &self,
        source: Vid,
        target: Vid,
        accepting_state: NfaStateId,
        min_depth: u32,
        max_depth: u32,
        mode: &PathMode,
        yield_path: &mut F,
    ) where
        F: FnMut(&[Vid], &[Eid]) -> ControlFlow<()>,
    {
        let tracks_nodes = matches!(mode, PathMode::Acyclic | PathMode::Simple);
        for depth in min_depth..=max_depth {
            if depth == 0 {
                if source == target && yield_path(&[source], &[]).is_break() {
                    return;
                }
                continue;
            }
            if !self
                .pred_head
                .contains_key(&(target, accepting_state, depth))
            {
                continue;
            }
            // The partial path is built backward, so `nodes[0]` is always
            // `target`; `dfs_backward` relies on this.
            let mut nodes = Vec::with_capacity(depth as usize + 1);
            let mut edges = Vec::with_capacity(depth as usize);
            let mut node_set = FxHashSet::default();
            let mut edge_set = FxHashSet::default();
            nodes.push(target);
            if tracks_nodes {
                node_set.insert(target);
            }
            let flow = self.dfs_backward(
                source,
                target,
                accepting_state,
                depth,
                &mut nodes,
                &mut edges,
                &mut node_set,
                &mut edge_set,
                mode,
                yield_path,
            );
            if flow.is_break() {
                return;
            }
        }
    }

    /// Returns `true` if at least one path within `min_depth..=max_depth`
    /// from `source` to `target` (ending in `accepting_state`) repeats no
    /// edge. Enumeration stops at the first such path.
    pub fn has_trail_valid_path(
        &self,
        source: Vid,
        target: Vid,
        accepting_state: NfaStateId,
        min_depth: u32,
        max_depth: u32,
    ) -> bool {
        let mut found = false;
        self.enumerate_paths(
            source,
            target,
            accepting_state,
            min_depth,
            max_depth,
            &PathMode::Trail,
            &mut |_nodes, _edges| {
                found = true;
                ControlFlow::Break(())
            },
        );
        found
    }

    /// Follows predecessor lists backward from `(current_vid, current_state)`
    /// for exactly `remaining_depth` hops, yielding every path that lands on
    /// `source`.
    ///
    /// `nodes` and `edges` hold the partial path in reverse order (target
    /// first). `node_set` and `edge_set` hold the vertices and edges that
    /// `mode` forbids revisiting. All four are restored to their entry state
    /// before this function returns.
    #[expect(
        clippy::too_many_arguments,
        reason = "recursive DFS carries full path state"
    )]
    fn dfs_backward<F>(
        &self,
        source: Vid,
        current_vid: Vid,
        current_state: NfaStateId,
        remaining_depth: u32,
        nodes: &mut Vec<Vid>,
        edges: &mut Vec<Eid>,
        node_set: &mut FxHashSet<Vid>,
        edge_set: &mut FxHashSet<Eid>,
        mode: &PathMode,
        yield_path: &mut F,
    ) -> ControlFlow<()>
    where
        F: FnMut(&[Vid], &[Eid]) -> ControlFlow<()>,
    {
        if remaining_depth == 0 {
            if current_vid != source {
                return ControlFlow::Continue(());
            }
            let fwd_nodes: Vec<Vid> = nodes.iter().rev().copied().collect();
            let fwd_edges: Vec<Eid> = edges.iter().rev().copied().collect();
            return yield_path(&fwd_nodes, &fwd_edges);
        }
        let key = (current_vid, current_state, remaining_depth);
        let Some(&head) = self.pred_head.get(&key) else {
            return ControlFlow::Continue(());
        };
        // `nodes[0]` is the target, so this is `source == target`: only then
        // may a SIMPLE path revisit the source, and only as its final hop.
        let closes_cycle = nodes.first() == Some(&source);

        let mut idx = head;
        while idx >= 0 {
            let pred = &self.pred_pool[idx as usize];
            idx = pred.next;

            let skip = match mode {
                PathMode::Walk => false,
                PathMode::Trail => edge_set.contains(&pred.eid),
                PathMode::Acyclic => node_set.contains(&pred.src_vid),
                PathMode::Simple => {
                    node_set.contains(&pred.src_vid)
                        && !(closes_cycle && remaining_depth == 1 && pred.src_vid == source)
                }
            };
            if skip {
                continue;
            }

            nodes.push(pred.src_vid);
            edges.push(pred.eid);
            // Remember whether *this* frame added the entry. The SIMPLE
            // cycle-closing hop re-visits the already-marked target, and
            // removing it afterwards would let sibling branches reuse it.
            let inserted_edge = matches!(mode, PathMode::Trail) && edge_set.insert(pred.eid);
            let inserted_node = matches!(mode, PathMode::Acyclic | PathMode::Simple)
                && node_set.insert(pred.src_vid);

            let result = self.dfs_backward(
                source,
                pred.src_vid,
                pred.src_state,
                remaining_depth - 1,
                nodes,
                edges,
                node_set,
                edge_set,
                mode,
                yield_path,
            );

            nodes.pop();
            edges.pop();
            if inserted_edge {
                edge_set.remove(&pred.eid);
            }
            if inserted_node {
                node_set.remove(&pred.src_vid);
            }
            if result.is_break() {
                return ControlFlow::Break(());
            }
        }
        ControlFlow::Continue(())
    }

    /// Number of predecessor records currently stored.
    pub fn pool_len(&self) -> usize {
        self.pred_pool.len()
    }

    /// Smallest depth at which `(vid, state)` was recorded as a destination,
    /// or `None` if it never was.
    pub fn first_depth_of(&self, vid: Vid, state: NfaStateId) -> Option<u32> {
        self.first_depth.get(&(vid, state)).copied()
    }
}
// Unit tests for `PredecessorDag`: record storage and retention per selector,
// and path enumeration under each `PathMode`.
#[cfg(test)]
mod tests {
use super::*;
// Shorthand constructors for vertex and edge ids.
fn vid(n: u64) -> Vid {
Vid::new(n)
}
fn eid(n: u64) -> Eid {
Eid::new(n)
}
// Runs `enumerate_paths` and collects every yielded (vertices, edges) pair;
// the callback never breaks, so all paths in the depth range are returned.
fn collect_paths(
dag: &PredecessorDag,
source: Vid,
target: Vid,
accepting_state: NfaStateId,
min_depth: u32,
max_depth: u32,
mode: &PathMode,
) -> Vec<(Vec<Vid>, Vec<Eid>)> {
let mut paths = Vec::new();
dag.enumerate_paths(
source,
target,
accepting_state,
min_depth,
max_depth,
mode,
&mut |nodes, edges| {
paths.push((nodes.to_vec(), edges.to_vec()));
ControlFlow::Continue(())
},
);
paths
}
// A single record lands in the pool under its (dst, dst_state, depth) key.
#[test]
fn test_pred_dag_add_single() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(2), 1, vid(1), 0, eid(10), 1);
assert_eq!(dag.pool_len(), 1);
assert!(dag.pred_head.contains_key(&(vid(2), 1, 1)));
}
// Two hops at consecutive depths produce two distinct keys.
#[test]
fn test_pred_dag_add_chain() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(1), 1, vid(0), 0, eid(10), 1);
dag.add_predecessor(vid(2), 2, vid(1), 1, eid(11), 2);
assert_eq!(dag.pool_len(), 2);
assert!(dag.pred_head.contains_key(&(vid(1), 1, 1)));
assert!(dag.pred_head.contains_key(&(vid(2), 2, 2)));
}
// Two records for the same key are chained: the head record's `next`
// points at the earlier one.
#[test]
fn test_pred_dag_multiple_preds() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(2), 1, vid(0), 0, eid(10), 1);
dag.add_predecessor(vid(2), 1, vid(1), 0, eid(11), 1);
assert_eq!(dag.pool_len(), 2);
let head = dag.pred_head[&(vid(2), 1, 1)];
assert!(head >= 0);
let first = &dag.pred_pool[head as usize];
assert!(first.next >= 0); }
// `first_depth_of` tracks the minimum depth seen, regardless of insertion order.
#[test]
fn test_pred_dag_first_depth() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(2), 1, vid(0), 0, eid(10), 3);
assert_eq!(dag.first_depth_of(vid(2), 1), Some(3));
dag.add_predecessor(vid(2), 1, vid(1), 0, eid(11), 2);
assert_eq!(dag.first_depth_of(vid(2), 1), Some(2));
dag.add_predecessor(vid(2), 1, vid(3), 0, eid(12), 5);
assert_eq!(dag.first_depth_of(vid(2), 1), Some(2));
}
// A layered selector (`All`) keeps records at every depth.
#[test]
fn test_pred_dag_layered_stores_all() {
let mut dag = PredecessorDag::new(PathSelector::All);
assert!(dag.is_layered());
dag.add_predecessor(vid(2), 1, vid(0), 0, eid(10), 2);
dag.add_predecessor(vid(2), 1, vid(1), 0, eid(11), 3);
assert_eq!(dag.pool_len(), 2);
assert!(dag.pred_head.contains_key(&(vid(2), 1, 2)));
assert!(dag.pred_head.contains_key(&(vid(2), 1, 3)));
}
// A non-layered selector (`AnyShortest`) drops records deeper than the first
// depth but keeps records that tie it.
#[test]
fn test_pred_dag_shortest_skips() {
let mut dag = PredecessorDag::new(PathSelector::AnyShortest);
assert!(!dag.is_layered());
dag.add_predecessor(vid(2), 1, vid(0), 0, eid(10), 2);
assert_eq!(dag.pool_len(), 1);
dag.add_predecessor(vid(2), 1, vid(1), 0, eid(11), 3);
assert_eq!(dag.pool_len(), 1);
dag.add_predecessor(vid(2), 1, vid(3), 0, eid(12), 2);
assert_eq!(dag.pool_len(), 2);
}
// The same increasing-depth insertions keep 3 records when layered, 1 otherwise.
#[test]
fn test_pred_dag_selector_switch() {
let build = |selector: PathSelector| -> usize {
let mut dag = PredecessorDag::new(selector);
dag.add_predecessor(vid(2), 1, vid(0), 0, eid(10), 2);
dag.add_predecessor(vid(2), 1, vid(1), 0, eid(11), 3);
dag.add_predecessor(vid(2), 1, vid(3), 0, eid(12), 4);
dag.pool_len()
};
assert_eq!(build(PathSelector::All), 3); assert_eq!(build(PathSelector::AnyShortest), 1); }
// A two-hop chain yields exactly one source-first path.
#[test]
fn test_pred_dag_linear_walk() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(1), 1, vid(0), 0, eid(10), 1);
dag.add_predecessor(vid(2), 2, vid(1), 1, eid(11), 2);
let paths = collect_paths(&dag, vid(0), vid(2), 2, 2, 2, &PathMode::Walk);
assert_eq!(paths.len(), 1);
assert_eq!(paths[0].0, vec![vid(0), vid(1), vid(2)]);
assert_eq!(paths[0].1, vec![eid(10), eid(11)]);
}
// Two parallel branches (0->1->3, 0->2->3) yield two paths.
#[test]
fn test_pred_dag_diamond_walk() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(1), 1, vid(0), 0, eid(10), 1);
dag.add_predecessor(vid(2), 1, vid(0), 0, eid(11), 1);
dag.add_predecessor(vid(3), 2, vid(1), 1, eid(12), 2);
dag.add_predecessor(vid(3), 2, vid(2), 1, eid(13), 2);
let paths = collect_paths(&dag, vid(0), vid(3), 2, 2, 2, &PathMode::Walk);
assert_eq!(paths.len(), 2);
let mut sorted: Vec<_> = paths.iter().map(|(n, _)| n.clone()).collect();
sorted.sort();
assert!(sorted.contains(&vec![vid(0), vid(1), vid(3)]));
assert!(sorted.contains(&vec![vid(0), vid(2), vid(3)]));
}
// Paths of different lengths to the same vertex are found per (state, depth).
#[test]
fn test_pred_dag_multiple_depths() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(2), 1, vid(0), 0, eid(10), 1);
dag.add_predecessor(vid(1), 1, vid(0), 0, eid(11), 1);
dag.add_predecessor(vid(2), 2, vid(1), 1, eid(12), 2);
let paths1 = collect_paths(&dag, vid(0), vid(2), 1, 1, 1, &PathMode::Walk);
assert_eq!(paths1.len(), 1);
assert_eq!(paths1[0].0, vec![vid(0), vid(2)]);
let paths2 = collect_paths(&dag, vid(0), vid(2), 2, 2, 2, &PathMode::Walk);
assert_eq!(paths2.len(), 1);
assert_eq!(paths2[0].0, vec![vid(0), vid(1), vid(2)]);
assert_eq!(paths1.len() + paths2.len(), 2);
}
// Three intermediate vertices between 0 and 4 yield three paths.
#[test]
fn test_pred_dag_fan_out() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(1), 1, vid(0), 0, eid(10), 1);
dag.add_predecessor(vid(2), 1, vid(0), 0, eid(11), 1);
dag.add_predecessor(vid(3), 1, vid(0), 0, eid(12), 1);
dag.add_predecessor(vid(4), 2, vid(1), 1, eid(13), 2);
dag.add_predecessor(vid(4), 2, vid(2), 1, eid(14), 2);
dag.add_predecessor(vid(4), 2, vid(3), 1, eid(15), 2);
let paths = collect_paths(&dag, vid(0), vid(4), 2, 2, 2, &PathMode::Walk);
assert_eq!(paths.len(), 3);
}
// Walk accepts reusing edge 1; Trail rejects the same path.
#[test]
fn test_pred_dag_trail_no_repeat() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(1), 1, vid(0), 0, eid(1), 1);
dag.add_predecessor(vid(0), 2, vid(1), 1, eid(2), 2);
dag.add_predecessor(vid(1), 3, vid(0), 2, eid(1), 3);
let walk_paths = collect_paths(&dag, vid(0), vid(1), 3, 3, 3, &PathMode::Walk);
assert_eq!(walk_paths.len(), 1);
assert_eq!(walk_paths[0].1, vec![eid(1), eid(2), eid(1)]);
let trail_paths = collect_paths(&dag, vid(0), vid(1), 3, 3, 3, &PathMode::Trail);
assert_eq!(trail_paths.len(), 0);
}
// Trail permits revisiting vertex 1 as long as every edge is distinct.
#[test]
fn test_pred_dag_trail_allows_node_repeat() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(1), 1, vid(0), 0, eid(1), 1);
dag.add_predecessor(vid(2), 2, vid(1), 1, eid(2), 2);
dag.add_predecessor(vid(1), 3, vid(2), 2, eid(3), 3);
let paths = collect_paths(&dag, vid(0), vid(1), 3, 3, 3, &PathMode::Trail);
assert_eq!(paths.len(), 1);
assert_eq!(paths[0].0, vec![vid(0), vid(1), vid(2), vid(1)]);
assert_eq!(paths[0].1, vec![eid(1), eid(2), eid(3)]);
}
// Both diamond branches use distinct edges, so both are trails.
#[test]
fn test_pred_dag_trail_diamond() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(1), 1, vid(0), 0, eid(10), 1);
dag.add_predecessor(vid(2), 1, vid(0), 0, eid(11), 1);
dag.add_predecessor(vid(3), 2, vid(1), 1, eid(12), 2);
dag.add_predecessor(vid(3), 2, vid(2), 1, eid(13), 2);
let paths = collect_paths(&dag, vid(0), vid(3), 2, 2, 2, &PathMode::Trail);
assert_eq!(paths.len(), 2);
}
// A 2-hop cycle back to the source uses distinct edges and is a valid trail.
#[test]
fn test_pred_dag_trail_cycle_2_hop() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(1), 1, vid(0), 0, eid(1), 1);
dag.add_predecessor(vid(0), 2, vid(1), 1, eid(2), 2);
let paths = collect_paths(&dag, vid(0), vid(0), 2, 2, 2, &PathMode::Trail);
assert_eq!(paths.len(), 1);
assert_eq!(paths[0].0, vec![vid(0), vid(1), vid(0)]);
assert_eq!(paths[0].1, vec![eid(1), eid(2)]);
}
// Acyclic rejects the 3-cycle returning to the source; Walk accepts it.
#[test]
fn test_pred_dag_acyclic_filter() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(1), 1, vid(0), 0, eid(1), 1);
dag.add_predecessor(vid(2), 2, vid(1), 1, eid(2), 2);
dag.add_predecessor(vid(0), 3, vid(2), 2, eid(3), 3);
let walk_paths = collect_paths(&dag, vid(0), vid(0), 3, 3, 3, &PathMode::Walk);
assert_eq!(walk_paths.len(), 1);
let acyclic_paths = collect_paths(&dag, vid(0), vid(0), 3, 3, 3, &PathMode::Acyclic);
assert_eq!(acyclic_paths.len(), 0);
}
// Neither diamond branch repeats a vertex, so both are acyclic.
#[test]
fn test_pred_dag_acyclic_diamond() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(1), 1, vid(0), 0, eid(10), 1);
dag.add_predecessor(vid(2), 1, vid(0), 0, eid(11), 1);
dag.add_predecessor(vid(3), 2, vid(1), 1, eid(12), 2);
dag.add_predecessor(vid(3), 2, vid(2), 1, eid(13), 2);
let paths = collect_paths(&dag, vid(0), vid(3), 2, 2, 2, &PathMode::Acyclic);
assert_eq!(paths.len(), 2);
}
// A simple chain has a trail-valid path.
#[test]
fn test_has_trail_valid_true() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(1), 1, vid(0), 0, eid(10), 1);
dag.add_predecessor(vid(2), 2, vid(1), 1, eid(11), 2);
assert!(dag.has_trail_valid_path(vid(0), vid(2), 2, 2, 2));
}
// The only length-3 path reuses edge 1, so no trail exists.
#[test]
fn test_has_trail_valid_false() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(1), 1, vid(0), 0, eid(1), 1);
dag.add_predecessor(vid(0), 2, vid(1), 1, eid(2), 2);
dag.add_predecessor(vid(1), 3, vid(0), 2, eid(1), 3);
assert!(!dag.has_trail_valid_path(vid(0), vid(1), 3, 3, 3));
}
// One valid trail among several candidate paths is enough.
#[test]
fn test_has_trail_valid_one_of_many() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(1), 1, vid(0), 0, eid(1), 1);
dag.add_predecessor(vid(2), 2, vid(1), 1, eid(2), 2);
dag.add_predecessor(vid(3), 1, vid(0), 0, eid(3), 1);
dag.add_predecessor(vid(2), 2, vid(3), 1, eid(4), 2);
assert!(dag.has_trail_valid_path(vid(0), vid(2), 2, 2, 2));
}
// Ten paths exist, but returning Break from the callback after the third
// stops enumeration, so the callback runs exactly three times.
#[test]
fn test_pred_dag_early_stop() {
let mut dag = PredecessorDag::new(PathSelector::All);
for i in 1..=10u64 {
dag.add_predecessor(Vid::new(i), 1, vid(0), 0, Eid::new(i), 1);
dag.add_predecessor(vid(99), 2, Vid::new(i), 1, Eid::new(100 + i), 2);
}
let mut count = 0;
dag.enumerate_paths(
vid(0),
vid(99),
2,
2,
2,
&PathMode::Walk,
&mut |_nodes, _edges| {
count += 1;
if count >= 3 {
ControlFlow::Break(())
} else {
ControlFlow::Continue(())
}
},
);
assert_eq!(count, 3); }
// An empty DAG with depths 1..=5 yields nothing.
#[test]
fn test_pred_dag_empty_enumerate() {
let dag = PredecessorDag::new(PathSelector::All);
let paths = collect_paths(&dag, vid(0), vid(1), 0, 1, 5, &PathMode::Walk);
assert!(paths.is_empty());
}
// Depth 0 with source == target yields the single-vertex path, even on an empty DAG.
#[test]
fn test_pred_dag_zero_length() {
let dag = PredecessorDag::new(PathSelector::All);
let paths = collect_paths(&dag, vid(5), vid(5), 0, 0, 0, &PathMode::Walk);
assert_eq!(paths.len(), 1);
assert_eq!(paths[0].0, vec![vid(5)]);
assert!(paths[0].1.is_empty());
}
// Vertices and edges are reported in forward (source-first) order.
#[test]
fn test_pred_dag_path_order() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(1), 1, vid(0), 0, eid(10), 1);
dag.add_predecessor(vid(2), 2, vid(1), 1, eid(20), 2);
dag.add_predecessor(vid(3), 3, vid(2), 2, eid(30), 3);
let paths = collect_paths(&dag, vid(0), vid(3), 3, 3, 3, &PathMode::Walk);
assert_eq!(paths.len(), 1);
assert_eq!(paths[0].0, vec![vid(0), vid(1), vid(2), vid(3)]);
assert_eq!(paths[0].1, vec![eid(10), eid(20), eid(30)]);
}
// A path has one fewer edge than vertices, and edge ids keep forward order.
#[test]
fn test_pred_dag_eid_in_path() {
let mut dag = PredecessorDag::new(PathSelector::All);
dag.add_predecessor(vid(1), 1, vid(0), 0, eid(100), 1);
dag.add_predecessor(vid(2), 2, vid(1), 1, eid(200), 2);
dag.add_predecessor(vid(3), 3, vid(2), 2, eid(300), 3);
let paths = collect_paths(&dag, vid(0), vid(3), 3, 3, 3, &PathMode::Walk);
assert_eq!(paths.len(), 1);
let (nodes, edges) = &paths[0];
assert_eq!(nodes.len(), edges.len() + 1);
assert_eq!(edges[0], eid(100)); assert_eq!(edges[1], eid(200)); assert_eq!(edges[2], eid(300)); }
}