use super::types::{Edge, EdgeId, Node, NodeId};
use crate::backend::{LatticeBackend, VocabId};
use crate::semiring::Semiring;
/// A lattice of nodes connected by labelled edges, with a designated start
/// node and end node.
///
/// Each edge is parameterised over the semiring type `W` (the tests use
/// `TropicalWeight` as the edge weight) and carries a `VocabId` label that is
/// resolved to text through the backend `B`.
///
/// After construction no method changes the node or edge contents (only
/// `shrink_to_fit` touches them mutably, and it only releases capacity). That
/// is what makes memoising the topological order sound.
#[derive(Clone, Debug)]
pub struct Lattice<W: Semiring, B: LatticeBackend> {
    /// All nodes, indexed by `NodeId.0`.
    nodes: Vec<Node>,
    /// All edges, indexed by `EdgeId.0`.
    edges: Vec<Edge<W>>,
    /// Id of the start node.
    start: NodeId,
    /// Id of the end node.
    end: NodeId,
    /// Vocabulary backend used to resolve edge labels to words.
    backend: B,
    /// Memoised result of `algorithms::topological_sort`.
    ///
    /// The outer `None` means "not computed yet". `Some(result)` holds the
    /// helper's result verbatim, including `Some(None)` when it found no order.
    topo_order: Option<Option<Vec<NodeId>>>,
}

impl<W: Semiring, B: LatticeBackend> Lattice<W, B> {
    /// Assembles a lattice from pre-built parts.
    ///
    /// No validation is performed here. Dangling node or edge ids inside the
    /// parts are tolerated by the accessors, which either return `None` or
    /// skip the unresolved ids.
    pub(crate) fn new(
        nodes: Vec<Node>,
        edges: Vec<Edge<W>>,
        start: NodeId,
        end: NodeId,
        backend: B,
    ) -> Self {
        Self {
            nodes,
            edges,
            start,
            end,
            backend,
            topo_order: None,
        }
    }

    /// Returns the id of the start node.
    #[inline]
    pub fn start(&self) -> NodeId {
        self.start
    }

    /// Returns the id of the end node.
    #[inline]
    pub fn end(&self) -> NodeId {
        self.end
    }

    /// Returns the number of nodes.
    #[inline]
    pub fn num_nodes(&self) -> usize {
        self.nodes.len()
    }

    /// Returns the number of edges.
    #[inline]
    pub fn num_edges(&self) -> usize {
        self.edges.len()
    }

    /// Returns `true` if the lattice has no edges.
    ///
    /// A lattice that contains nodes but no edges (for example, a lone start
    /// node) is considered empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.edges.is_empty()
    }

    /// Returns the node with the given id, or `None` if the id is out of range.
    #[inline]
    pub fn node(&self, id: NodeId) -> Option<&Node> {
        self.nodes.get(id.0 as usize)
    }

    /// Returns the edge with the given id, or `None` if the id is out of range.
    #[inline]
    pub fn edge(&self, id: EdgeId) -> Option<&Edge<W>> {
        self.edges.get(id.0 as usize)
    }

    /// Returns all nodes as a slice, in id order.
    #[inline]
    pub fn nodes(&self) -> &[Node] {
        &self.nodes
    }

    /// Returns all edges as a slice, in id order.
    #[inline]
    pub fn edges(&self) -> &[Edge<W>] {
        &self.edges
    }

    /// Iterates over the edges in `node`'s outgoing adjacency list.
    ///
    /// Yields nothing if `node` is out of range. Edge ids that do not resolve
    /// to an edge are skipped.
    pub fn outgoing_edges(&self, node: NodeId) -> impl Iterator<Item = &Edge<W>> {
        self.nodes
            .get(node.0 as usize)
            .into_iter()
            .flat_map(|n| n.outgoing.iter())
            .filter_map(|&eid| self.edges.get(eid.0 as usize))
    }

    /// Iterates over the edges in `node`'s incoming adjacency list.
    ///
    /// Yields nothing if `node` is out of range. Edge ids that do not resolve
    /// to an edge are skipped.
    pub fn incoming_edges(&self, node: NodeId) -> impl Iterator<Item = &Edge<W>> {
        self.nodes
            .get(node.0 as usize)
            .into_iter()
            .flat_map(|n| n.incoming.iter())
            .filter_map(|&eid| self.edges.get(eid.0 as usize))
    }

    /// Resolves a vocabulary id to its text through the backend.
    ///
    /// Returns `None` when the backend lookup does.
    #[inline]
    pub fn word(&self, id: VocabId) -> Option<&str> {
        self.backend.lookup(id)
    }

    /// Resolves `edge`'s label to its text. Equivalent to `self.word(edge.label)`.
    #[inline]
    pub fn edge_word(&self, edge: &Edge<W>) -> Option<&str> {
        self.word(edge.label)
    }

    /// Returns a shared reference to the vocabulary backend.
    #[inline]
    pub fn backend(&self) -> &B {
        &self.backend
    }

    /// Returns a mutable reference to the vocabulary backend.
    ///
    /// Only the backend is exposed, so this cannot invalidate the cached
    /// topological order.
    #[inline]
    pub fn backend_mut(&mut self) -> &mut B {
        &mut self.backend
    }

    /// Returns a topological ordering of the node ids.
    ///
    /// Returns `None` when `algorithms::topological_sort` finds no order
    /// (presumably because the lattice has a cycle — confirm against that
    /// helper).
    ///
    /// The helper runs at most once per lattice. Its result is cached, and a
    /// `None` result is cached too, so a lattice without an order is not
    /// re-sorted on every call.
    pub fn topological_order(&mut self) -> Option<&[NodeId]> {
        if self.topo_order.is_none() {
            self.topo_order = Some(super::algorithms::topological_sort(&self.nodes, &self.edges));
        }
        self.topo_order
            .as_ref()
            .and_then(|order| order.as_deref())
    }

    /// Reports whether the lattice is acyclic, as decided by
    /// `algorithms::is_acyclic`. This neither reads nor fills the order cache.
    pub fn is_acyclic(&self) -> bool {
        super::algorithms::is_acyclic(&self.nodes, &self.edges)
    }

    /// Counts paths using `algorithms::count_paths`.
    ///
    /// Takes `&mut self` because the helper receives a mutable reference
    /// (presumably so it can use the cached topological order — TODO
    /// confirm). A `None` result is passed through unchanged from the helper.
    pub fn path_count(&mut self) -> Option<usize> {
        super::algorithms::count_paths(self)
    }

    /// Releases excess capacity without changing any observable structure.
    ///
    /// This covers the node and edge vectors, every node's adjacency lists,
    /// and the cached topological order if one has been computed.
    pub fn shrink_to_fit(&mut self) {
        self.nodes.shrink_to_fit();
        self.edges.shrink_to_fit();
        for node in &mut self.nodes {
            node.outgoing.shrink_to_fit();
            node.incoming.shrink_to_fit();
        }
        if let Some(Some(order)) = &mut self.topo_order {
            order.shrink_to_fit();
        }
    }

    /// Iterates over every node id `0..num_nodes()` in ascending order.
    ///
    /// # Panics
    ///
    /// Panics if the node count does not fit in `u32`. Such ids cannot be
    /// represented by `NodeId`, and a truncating cast would silently skip
    /// nodes.
    pub fn node_ids(&self) -> impl Iterator<Item = NodeId> + '_ {
        (0..Self::id_upper_bound(self.nodes.len())).map(NodeId)
    }

    /// Iterates over every edge id `0..num_edges()` in ascending order.
    ///
    /// # Panics
    ///
    /// Panics if the edge count does not fit in `u32`. Such ids cannot be
    /// represented by `EdgeId`, and a truncating cast would silently skip
    /// edges.
    pub fn edge_ids(&self) -> impl Iterator<Item = EdgeId> + '_ {
        (0..Self::id_upper_bound(self.edges.len())).map(EdgeId)
    }

    /// Converts a collection length into the exclusive upper bound of a `u32`
    /// id range, panicking instead of truncating when it does not fit.
    fn id_upper_bound(len: usize) -> u32 {
        <u32 as std::convert::TryFrom<usize>>::try_from(len)
            .expect("lattice element count exceeds the u32 id space")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::HashMapBackend;
    use crate::lattice::builder::LatticeBuilder;
    use crate::lattice::types::EdgeMetadata;
    use crate::semiring::TropicalWeight;

    /// Builds the lattice `0 -> 1 -> 2` with two alternative edges per span.
    ///
    /// Nodes 0 to 1 carry "the" (correction) and "teh" (original). Nodes 1 to
    /// 2 carry "quick" (correction) and "quik" (original).
    fn sample_lattice() -> Lattice<TropicalWeight, HashMapBackend> {
        let backend = HashMapBackend::new();
        let mut builder = LatticeBuilder::new(backend);
        builder.add_correction(
            0,
            1,
            "the",
            TropicalWeight::new(0.5),
            EdgeMetadata::correction(1),
        );
        builder.add_correction(
            0,
            1,
            "teh",
            TropicalWeight::new(0.0),
            EdgeMetadata::original(),
        );
        builder.add_correction(
            1,
            2,
            "quick",
            TropicalWeight::new(0.5),
            EdgeMetadata::correction(1),
        );
        builder.add_correction(
            1,
            2,
            "quik",
            TropicalWeight::new(0.0),
            EdgeMetadata::original(),
        );
        builder.build(2)
    }

    /// Returns the index of `id` within `order`, panicking with a descriptive
    /// message if the node is missing.
    fn position_of(order: &[NodeId], id: NodeId) -> usize {
        order
            .iter()
            .position(|&n| n == id)
            .unwrap_or_else(|| panic!("{:?} missing from topological order", id))
    }

    #[test]
    fn test_lattice_structure() {
        let lattice = sample_lattice();
        assert_eq!(lattice.num_nodes(), 3);
        assert_eq!(lattice.num_edges(), 4);
        assert_eq!(lattice.start(), NodeId::new(0));
        assert_eq!(lattice.end(), NodeId::new(2));
    }

    #[test]
    fn test_node_access() {
        let lattice = sample_lattice();
        let start = lattice.node(NodeId::new(0)).expect("start node exists");
        assert_eq!(start.out_degree(), 2);
        assert_eq!(start.in_degree(), 0);
        let middle = lattice.node(NodeId::new(1)).expect("middle node exists");
        assert_eq!(middle.out_degree(), 2);
        assert_eq!(middle.in_degree(), 2);
        let end = lattice.node(NodeId::new(2)).expect("end node exists");
        assert_eq!(end.out_degree(), 0);
        assert_eq!(end.in_degree(), 2);
    }

    #[test]
    fn test_out_of_range_node() {
        let lattice = sample_lattice();
        let missing = NodeId::new(3);
        assert!(lattice.node(missing).is_none());
        assert_eq!(lattice.outgoing_edges(missing).count(), 0);
        assert_eq!(lattice.incoming_edges(missing).count(), 0);
    }

    #[test]
    fn test_word_lookup() {
        let lattice = sample_lattice();
        for edge in lattice.edges() {
            assert!(lattice.word(edge.label).is_some());
        }
    }

    #[test]
    fn test_edge_word_matches_word() {
        let lattice = sample_lattice();
        for edge in lattice.edges() {
            assert_eq!(lattice.edge_word(edge), lattice.word(edge.label));
        }
    }

    #[test]
    fn test_outgoing_edges() {
        let lattice = sample_lattice();
        let edges: Vec<_> = lattice.outgoing_edges(NodeId::new(0)).collect();
        assert_eq!(edges.len(), 2);
        let words: Vec<_> = edges.iter().filter_map(|e| lattice.word(e.label)).collect();
        assert!(words.contains(&"the"));
        assert!(words.contains(&"teh"));
    }

    #[test]
    fn test_incoming_edges() {
        let lattice = sample_lattice();
        assert_eq!(lattice.incoming_edges(NodeId::new(2)).count(), 2);
        let words: Vec<_> = lattice
            .incoming_edges(NodeId::new(2))
            .filter_map(|e| lattice.word(e.label))
            .collect();
        assert!(words.contains(&"quick"));
        assert!(words.contains(&"quik"));
        assert_eq!(lattice.incoming_edges(NodeId::new(0)).count(), 0);
    }

    #[test]
    fn test_is_acyclic() {
        let lattice = sample_lattice();
        assert!(lattice.is_acyclic());
    }

    #[test]
    fn test_topological_order() {
        let mut lattice = sample_lattice();
        let order = lattice.topological_order().expect("acyclic lattice");
        assert_eq!(order.len(), 3);
        let start_pos = position_of(order, NodeId::new(0));
        let middle_pos = position_of(order, NodeId::new(1));
        let end_pos = position_of(order, NodeId::new(2));
        assert!(start_pos < middle_pos);
        assert!(middle_pos < end_pos);
    }

    #[test]
    fn test_topological_order_is_stable_across_calls() {
        let mut lattice = sample_lattice();
        let first = lattice.topological_order().map(|o| o.to_vec());
        let second = lattice.topological_order().map(|o| o.to_vec());
        assert!(first.is_some());
        assert_eq!(first, second);
    }

    #[test]
    fn test_path_count() {
        let mut lattice = sample_lattice();
        assert_eq!(lattice.path_count(), Some(4));
    }

    #[test]
    fn test_shrink_to_fit_preserves_structure() {
        let mut lattice = sample_lattice();
        let order_before = lattice.topological_order().map(|o| o.to_vec());
        lattice.shrink_to_fit();
        assert_eq!(lattice.num_nodes(), 3);
        assert_eq!(lattice.num_edges(), 4);
        assert_eq!(lattice.topological_order().map(|o| o.to_vec()), order_before);
        assert_eq!(lattice.path_count(), Some(4));
    }

    #[test]
    fn test_empty_lattice() {
        let backend = HashMapBackend::new();
        let builder: LatticeBuilder<TropicalWeight, _> = LatticeBuilder::new(backend);
        let lattice = builder.build(0);
        // No edges were added, but the built lattice still contains one node.
        assert!(lattice.is_empty());
        assert_eq!(lattice.num_nodes(), 1);
    }

    #[test]
    fn test_node_ids_iterator() {
        let lattice = sample_lattice();
        let ids: Vec<_> = lattice.node_ids().collect();
        assert_eq!(ids, vec![NodeId::new(0), NodeId::new(1), NodeId::new(2)]);
    }

    #[test]
    fn test_edge_ids_iterator() {
        let lattice = sample_lattice();
        let ids: Vec<_> = lattice.edge_ids().collect();
        assert_eq!(ids.len(), 4);
        assert!(ids.iter().all(|&id| lattice.edge(id).is_some()));
    }
}