use crate::eval::{Evaluator, PreparedEvaluator, ResultCount};
use crate::graph::{GraphDecomposition, GraphError, GraphblasVector, LagraphGraph};
use crate::lagraph_sys::LAGraph_Kind;
use crate::lagraph_sys::*;
use crate::rpq::{Endpoint, PathExpr, RpqError, RpqQuery};
use crate::{grb_ok, la_ok};
use rustfst::algorithms::closure::{ClosureType, closure};
use rustfst::algorithms::concat::concat;
use rustfst::algorithms::rm_epsilon::rm_epsilon;
use rustfst::algorithms::union::union;
use rustfst::prelude::*;
use rustfst::semirings::TropicalWeight;
use rustfst::utils::{acceptor, epsilon_machine};
use std::collections::HashMap;
use std::sync::Arc;
/// All NFA transitions that carry one particular edge label, in COO form.
///
/// Entry `i` describes a transition from state `rows[i]` to state `cols[i]`
/// labelled `label`; `rows` and `cols` always have the same length.
#[derive(Debug, Clone)]
pub struct NfaLabelTransitions {
    /// The edge label (as written in the path expression).
    pub label: String,
    /// Source NFA state of each transition.
    pub rows: Vec<GrB_Index>,
    /// Destination NFA state of each transition.
    pub cols: Vec<GrB_Index>,
}
/// An epsilon-free nondeterministic finite automaton over edge labels.
///
/// States are numbered `0..num_states`; transitions are grouped per label so
/// that each group can be turned into one boolean adjacency matrix.
#[derive(Debug, Clone)]
pub struct Nfa {
    /// Number of automaton states (the dimension of every label matrix).
    pub num_states: usize,
    /// Initial states. As built by `Nfa::from_path_expr` this holds at most one state.
    pub start_states: Vec<GrB_Index>,
    /// Accepting states.
    pub final_states: Vec<GrB_Index>,
    /// Transitions, one entry per distinct label.
    pub transitions: Vec<NfaLabelTransitions>,
}
/// Two-way interning table between label strings and numeric FST labels.
///
/// Ids are handed out sequentially starting at 1, presumably so they never
/// collide with `EPS_LABEL` (0 in rustfst), which is treated as "no label".
struct SymbolTable {
    label_to_id: HashMap<String, Label>,
    id_to_label: HashMap<Label, String>,
    next_id: Label,
}

impl SymbolTable {
    /// Creates an empty table; the first interned label receives id 1.
    fn new() -> Self {
        SymbolTable {
            next_id: 1,
            label_to_id: HashMap::new(),
            id_to_label: HashMap::new(),
        }
    }

    /// Returns the id of `label`, assigning the next free id if it is new.
    fn get_or_insert(&mut self, label: &str) -> Label {
        if let Some(&existing) = self.label_to_id.get(label) {
            return existing;
        }
        let fresh = self.next_id;
        self.next_id += 1;
        let owned = label.to_string();
        self.id_to_label.insert(fresh, owned.clone());
        self.label_to_id.insert(owned, fresh);
        fresh
    }

    /// Returns the label string interned under `id`, if any.
    fn get_label(&self, id: Label) -> Option<&str> {
        self.id_to_label.get(&id).map(String::as_str)
    }
}
/// Wraps a failure reported by a rustfst operation into an [`RpqError`].
///
/// The message has the form `"<operation> failed: <error>"`.
fn map_fst_error<E: std::fmt::Display>(operation: &'static str, e: E) -> RpqError {
    let message = format!("{operation} failed: {e}");
    RpqError::UnsupportedPath(message)
}
impl Nfa {
    /// Compiles a regular path expression into an epsilon-free [`Nfa`].
    ///
    /// The expression is first built as a rustfst acceptor. Its labels are
    /// interned through a fresh [`SymbolTable`]. Epsilon arcs are then removed
    /// and the automaton is flattened into per-label transition lists.
    ///
    /// # Errors
    ///
    /// Propagates failures from FST construction, epsilon removal, or NFA
    /// extraction.
    pub fn from_path_expr(path: &PathExpr) -> Result<Self, RpqError> {
        let mut symbols = SymbolTable::new();
        let mut automaton = build_fst(path, &mut symbols)?;
        rm_epsilon(&mut automaton).map_err(|err| map_fst_error("rm_epsilon", err))?;
        extract_nfa(&automaton, &symbols)
    }

    /// Builds one directed boolean adjacency graph per label.
    ///
    /// Each graph is `num_states x num_states`, with a `true` entry for every
    /// transition of that label. The returned vector follows the order of
    /// `self.transitions`.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `LagraphGraph::from_coo`.
    pub fn build_lagraph_matrices(&self) -> Result<Vec<(String, LagraphGraph)>, RpqError> {
        let dimension = self.num_states as GrB_Index;
        self.transitions
            .iter()
            .map(|edges| -> Result<(String, LagraphGraph), RpqError> {
                // Only the sparsity pattern matters, so every stored value is `true`.
                let values = vec![true; edges.rows.len()];
                let graph = LagraphGraph::from_coo(
                    &edges.rows,
                    &edges.cols,
                    &values,
                    dimension,
                    LAGraph_Kind::LAGraph_ADJACENCY_DIRECTED,
                )?;
                Ok((edges.label.clone(), graph))
            })
            .collect()
    }
}
/// Recursively builds a tropical-weight acceptor recognising the language of `path`.
///
/// Labels are interned in `symbols` in left-to-right order. Sub-automata are
/// combined with rustfst's in-place `concat`, `union` and `closure`, so the
/// result may contain epsilon arcs.
///
/// # Errors
///
/// Returns an error if any rustfst construction step fails.
fn build_fst(
    path: &PathExpr,
    symbols: &mut SymbolTable,
) -> Result<VectorFst<TropicalWeight>, RpqError> {
    let fst: VectorFst<TropicalWeight> = match path {
        PathExpr::Label(name) => {
            let id = symbols.get_or_insert(name);
            acceptor(&[id], TropicalWeight::one())
        }
        PathExpr::Sequence(first, second) => {
            let mut head = build_fst(first, symbols)?;
            let tail = build_fst(second, symbols)?;
            concat(&mut head, &tail).map_err(|err| map_fst_error("concat", err))?;
            head
        }
        PathExpr::Alternative(first, second) => {
            let mut either = build_fst(first, symbols)?;
            let other = build_fst(second, symbols)?;
            union(&mut either, &other).map_err(|err| map_fst_error("union", err))?;
            either
        }
        PathExpr::ZeroOrMore(body) => {
            let mut repeated = build_fst(body, symbols)?;
            closure(&mut repeated, ClosureType::ClosureStar);
            repeated
        }
        PathExpr::OneOrMore(body) => {
            let mut repeated = build_fst(body, symbols)?;
            closure(&mut repeated, ClosureType::ClosurePlus);
            repeated
        }
        PathExpr::ZeroOrOne(body) => {
            // `x?` is built as the union of `x` with rustfst's epsilon machine.
            let mut optional = build_fst(body, symbols)?;
            let empty = epsilon_machine::<TropicalWeight, VectorFst<TropicalWeight>>()
                .map_err(|err| map_fst_error("epsilon_machine", err))?;
            union(&mut optional, &empty).map_err(|err| map_fst_error("union", err))?;
            optional
        }
    };
    Ok(fst)
}
fn extract_nfa(fst: &VectorFst<TropicalWeight>, symbols: &SymbolTable) -> Result<Nfa, RpqError> {
let num_states = fst.num_states();
let mut label_transitions: HashMap<String, Vec<(usize, usize)>> = HashMap::new();
for state in fst.states_iter() {
for tr in fst.get_trs(state).unwrap().trs() {
if tr.ilabel == EPS_LABEL {
continue;
}
if let Some(label) = symbols.get_label(tr.ilabel) {
label_transitions
.entry(label.to_string())
.or_default()
.push((state as usize, tr.nextstate as usize));
}
}
}
let start_states: Vec<GrB_Index> = fst
.start()
.map(|s| vec![s as GrB_Index])
.unwrap_or_default();
let final_states: Vec<GrB_Index> = fst
.states_iter()
.filter(|&s| fst.is_final(s).unwrap_or(false))
.map(|s| s as GrB_Index)
.collect();
let transitions: Vec<NfaLabelTransitions> = label_transitions
.into_iter()
.map(|(label, pairs)| {
let mut rows = Vec::with_capacity(pairs.len());
let mut cols = Vec::with_capacity(pairs.len());
for (r, c) in pairs {
rows.push(r as GrB_Index);
cols.push(c as GrB_Index);
}
NfaLabelTransitions { label, rows, cols }
})
.collect();
Ok(Nfa {
num_states,
start_states,
final_states,
transitions,
})
}
/// Outcome of executing a prepared NFA-based regular path query.
#[derive(Debug)]
pub struct NfaRpqResult {
    /// GraphBLAS vector whose stored entries are the data-graph vertices
    /// returned by `LAGraph_RegularPathQuery`. When the query object is a
    /// named vertex, only that vertex is kept.
    pub reachable: GraphblasVector,
}
impl ResultCount for NfaRpqResult {
    /// Number of stored entries in the `reachable` vector.
    fn result_count(&self) -> Result<usize, GraphError> {
        let stored = self.reachable.nvals()?;
        Ok(stored as usize)
    }
}
/// A regular path query compiled and bound to a data graph, ready to execute repeatedly.
///
/// `nfa_graph_ptrs` and `data_graph_ptrs` are raw handles borrowed from
/// `nfa_matrices` and `_data_graphs`. Those owners are kept in the same
/// struct, so the handles remain valid for the struct's lifetime. Both pointer
/// vectors are index-aligned: entry `i` of each refers to the same label.
pub struct PreparedNfaRpq {
    /// The compiled automaton (start/final states are read at execution time).
    nfa: Nfa,
    /// Owned per-label NFA adjacency graphs.
    nfa_matrices: Vec<(String, LagraphGraph)>,
    /// Raw `inner` handles of `nfa_matrices`, in the same order.
    nfa_graph_ptrs: Vec<LAGraph_Graph>,
    /// Keeps the shared data-graph matrices alive while their handles are in use.
    _data_graphs: Vec<Arc<LagraphGraph>>,
    /// Raw `inner` handles of `_data_graphs`, aligned with `nfa_graph_ptrs`.
    data_graph_ptrs: Vec<LAGraph_Graph>,
    /// Vertices from which path search starts.
    source_vertices: Vec<GrB_Index>,
    /// Named query object, if any; used to filter the result.
    destination_vertex: Option<usize>,
    /// Number of vertices in the data graph.
    num_nodes: usize,
}
/// Restricts `reachable` to a single destination vertex when one is given.
///
/// With `None`, `reachable` is returned unchanged. With `Some(target)`, a new
/// boolean vector of size `num_nodes` is returned. It holds `true` at `target`
/// if `target` was among the stored indices of `reachable`, and is empty
/// otherwise. The input vector is dropped in that case.
///
/// # Errors
///
/// Returns an error if reading indices, allocating the new vector, or setting
/// its element fails.
fn filter_reachable_by_destination(
    reachable: GraphblasVector,
    destination_vertex: Option<usize>,
    num_nodes: usize,
) -> Result<GraphblasVector, RpqError> {
    match destination_vertex {
        // No named object: every reachable vertex is part of the answer.
        None => Ok(reachable),
        Some(target) => {
            let target_index = target as GrB_Index;
            let reached = reachable
                .indices()
                .map_err(RpqError::Graph)?
                .contains(&target_index);
            let filtered = GraphblasVector::new_bool(num_nodes as GrB_Index)?;
            if reached {
                // SAFETY: `filtered.inner` is a live vector created just above.
                // `target_index` was found among `reachable`'s indices. This
                // assumes `reachable` has size `num_nodes`, so the index is in
                // range (TODO confirm at the call site).
                unsafe {
                    grb_ok!(GrB_Vector_setElement_BOOL(
                        filtered.inner,
                        true,
                        target_index,
                    ))?
                };
            }
            Ok(filtered)
        }
    }
}
/// Execution of a prepared NFA-based RPQ via `LAGraph_RegularPathQuery`.
impl PreparedEvaluator for PreparedNfaRpq {
    type Result = NfaRpqResult;
    type Error = RpqError;
    /// Runs the regular path query and returns the matching data-graph vertices.
    ///
    /// Calls `LAGraph_RegularPathQuery` with these arguments:
    /// - the per-label NFA graphs and the number of labels;
    /// - the NFA start and final states;
    /// - the index-aligned per-label data graphs;
    /// - the source vertices.
    ///
    /// The returned vector is then passed through
    /// `filter_reachable_by_destination`.
    ///
    /// # Errors
    ///
    /// Returns an error if the LAGraph call fails (as reported by `la_ok!`,
    /// which is defined elsewhere) or if destination filtering fails.
    fn execute(&mut self) -> Result<NfaRpqResult, RpqError> {
        // Out-parameter filled by LAGraph; presumably left NULL on failure.
        let mut reachable: GrB_Vector = std::ptr::null_mut();
        // SAFETY: `nfa_graph_ptrs` and `data_graph_ptrs` hold `inner` handles
        // of graphs owned by `self.nfa_matrices` and `self._data_graphs`, which
        // outlive this call. Both vectors have one entry per element of
        // `nfa_matrices`, and each array pointer is passed alongside its length.
        unsafe {
            la_ok!(LAGraph_RegularPathQuery(
                &mut reachable,
                self.nfa_graph_ptrs.as_mut_ptr(),
                self.nfa_matrices.len(),
                self.nfa.start_states.as_ptr(),
                self.nfa.start_states.len(),
                self.nfa.final_states.as_ptr(),
                self.nfa.final_states.len(),
                self.data_graph_ptrs.as_mut_ptr(),
                self.source_vertices.as_ptr(),
                self.source_vertices.len(),
            ))?
        };
        // Wrap the raw handle so `GraphblasVector` owns it from here on
        // (assumes its Drop frees the vector, since it is defined elsewhere).
        let reachable = filter_reachable_by_destination(
            GraphblasVector { inner: reachable },
            self.destination_vertex,
            self.num_nodes,
        )?;
        Ok(NfaRpqResult { reachable })
    }
}
/// Stateless evaluator that compiles an [`RpqQuery`] path into an NFA and
/// answers it with `LAGraph_RegularPathQuery`.
#[derive(Clone, Copy)]
pub struct NfaRpqEvaluator;
impl Evaluator for NfaRpqEvaluator {
    type Query = RpqQuery;
    type Result = NfaRpqResult;
    type Error = RpqError;
    type Prepared = PreparedNfaRpq;

    /// Compiles `query` against `graph` into a [`PreparedNfaRpq`].
    ///
    /// The preparation proceeds in these steps:
    /// 1. Build the NFA and its per-label matrices.
    /// 2. Resolve the subject and object endpoints to vertex ids.
    /// 3. Fetch the data-graph matrix for each NFA label, in the same order.
    ///
    /// A variable subject searches from every vertex `0..num_nodes`.
    ///
    /// # Errors
    ///
    /// Returns an error if any of these fail: NFA construction, matrix
    /// building, resolving a named endpoint, or fetching a label's data graph.
    fn prepare<G: GraphDecomposition>(
        &self,
        query: &RpqQuery,
        graph: &G,
    ) -> Result<PreparedNfaRpq, RpqError> {
        let nfa = Nfa::from_path_expr(&query.path)?;
        let nfa_matrices = nfa.build_lagraph_matrices()?;
        let subject = resolve_endpoint(&query.subject, graph)?;
        let object = resolve_endpoint(&query.object, graph)?;
        let node_count = graph.num_nodes();

        let source_vertices: Vec<GrB_Index> = if let Some(id) = subject {
            vec![id as GrB_Index]
        } else {
            (0..node_count as GrB_Index).collect()
        };

        let nfa_graph_ptrs: Vec<LAGraph_Graph> =
            nfa_matrices.iter().map(|(_, lg)| lg.inner).collect();

        // Index-aligned with `nfa_matrices`: entry i is the data graph for label i.
        let data_graphs = nfa_matrices
            .iter()
            .map(|(label, _)| graph.get_graph(label))
            .collect::<Result<Vec<_>, _>>()?;
        let data_graph_ptrs: Vec<LAGraph_Graph> =
            data_graphs.iter().map(|lg| lg.inner).collect();

        Ok(PreparedNfaRpq {
            nfa,
            nfa_matrices,
            nfa_graph_ptrs,
            _data_graphs: data_graphs,
            data_graph_ptrs,
            source_vertices,
            destination_vertex: object,
            num_nodes: node_count,
        })
    }
}
/// Maps a query endpoint to a data-graph vertex id.
///
/// A variable endpoint yields `Ok(None)`, meaning unconstrained. A named
/// endpoint yields `Ok(Some(id))`, or `RpqError::VertexNotFound` if the graph
/// has no such vertex.
fn resolve_endpoint<G: GraphDecomposition>(
    term: &Endpoint,
    graph: &G,
) -> Result<Option<usize>, RpqError> {
    match term {
        Endpoint::Named(id) => match graph.get_node_id(id) {
            Some(node) => Ok(Some(node)),
            None => Err(RpqError::VertexNotFound(id.clone())),
        },
        Endpoint::Variable(_) => Ok(None),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds a single-label path expression.
    fn label_expr(s: &str) -> PathExpr {
        PathExpr::Label(s.to_owned())
    }

    /// Names of all labels that have transitions in `nfa`.
    fn transition_labels(nfa: &Nfa) -> Vec<&str> {
        nfa.transitions.iter().map(|t| t.label.as_str()).collect()
    }

    #[test]
    fn test_single_label() {
        let nfa = Nfa::from_path_expr(&label_expr("knows")).unwrap();
        assert!(nfa.num_states >= 2, "NFA should have at least 2 states");
        assert!(!nfa.start_states.is_empty(), "should have start states");
        assert!(!nfa.final_states.is_empty(), "should have final states");
        assert_eq!(nfa.transitions.len(), 1);
        let only = &nfa.transitions[0];
        assert_eq!(only.label, "knows");
        assert!(!only.rows.is_empty());
    }

    #[test]
    fn test_sequence() {
        let expr = PathExpr::Sequence(Box::new(label_expr("a")), Box::new(label_expr("b")));
        let nfa = Nfa::from_path_expr(&expr).unwrap();
        let names = transition_labels(&nfa);
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
    }

    #[test]
    fn test_alternative() {
        let expr = PathExpr::Alternative(Box::new(label_expr("a")), Box::new(label_expr("b")));
        let nfa = Nfa::from_path_expr(&expr).unwrap();
        let names = transition_labels(&nfa);
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
    }

    #[test]
    fn test_zero_or_more() {
        let expr = PathExpr::ZeroOrMore(Box::new(label_expr("knows")));
        let nfa = Nfa::from_path_expr(&expr).unwrap();
        assert!(!nfa.start_states.is_empty());
        assert!(!nfa.final_states.is_empty());
        let starts: HashSet<GrB_Index> = nfa.start_states.iter().copied().collect();
        let finals: HashSet<GrB_Index> = nfa.final_states.iter().copied().collect();
        // Zero repetitions must be accepted, so some start state is also final.
        assert!(
            !starts.is_disjoint(&finals),
            "start and final states should overlap for zero-or-more, start={:?}, final={:?}",
            starts,
            finals
        );
    }
}