Skip to main content

lindera_crf/lattice.rs

1use core::num::NonZeroU32;
2
3use alloc::vec::Vec;
4
5use crate::errors::{Result, RucrfError};
6
/// An edge of a lattice: the index of the node it points to plus a label.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct Edge {
    /// Index of the node this edge points to.
    target: usize,
    /// Label carried by this edge; the `NonZeroU32` type rules out zero.
    pub(crate) label: NonZeroU32,
}

impl Edge {
    /// Builds an edge that points at node `target` and carries `label`.
    #[must_use]
    #[inline(always)]
    pub const fn new(target: usize, label: NonZeroU32) -> Self {
        Edge { label, target }
    }

    /// Index of the node this edge points to.
    #[must_use]
    #[inline(always)]
    pub const fn target(&self) -> usize {
        self.target
    }

    /// The label carried by this edge.
    #[must_use]
    #[inline(always)]
    pub const fn label(&self) -> NonZeroU32 {
        self.label
    }
}
36
37/// Represents a node.
38#[derive(Clone, Default, Debug)]
39pub struct Node {
40    edges: Vec<Edge>,
41}
42
43impl Node {
44    /// Returns a list of edges.
45    ///
46    /// In training, the first edge is treated as a positive example.
47    #[inline(always)]
48    pub fn edges(&self) -> &[Edge] {
49        &self.edges
50    }
51}
52
53/// Represents a lattice.
54pub struct Lattice {
55    nodes: Vec<Node>,
56}
57
58impl Lattice {
59    /// Creates a new lattice.
60    ///
61    /// # Arguments
62    ///
63    /// * `length` - The length of this lattice.
64    ///
65    /// # Errors
66    ///
67    /// `length` must be >= 1.
68    #[inline(always)]
69    pub fn new(length: usize) -> Result<Self> {
70        if length == 0 {
71            return Err(RucrfError::invalid_argument("length must be >= 1"));
72        }
73        let nodes = vec![Node::default(); length + 1];
74        Ok(Self { nodes })
75    }
76
77    /// Adds a new edge.
78    ///
79    /// In training, the first edge of each position is treated as the positive example.
80    ///
81    /// # Errors
82    ///
83    /// `edge.target()` must be >= `pos` and <= `length`.
84    #[inline(always)]
85    pub fn add_edge(&mut self, pos: usize, edge: Edge) -> Result<()> {
86        if edge.target() <= pos {
87            return Err(RucrfError::invalid_argument("edge.target() must be > pos"));
88        }
89        if edge.target() > self.nodes.len() {
90            return Err(RucrfError::invalid_argument(
91                "edge.target() must be <= length",
92            ));
93        }
94        self.nodes[pos].edges.push(edge);
95        Ok(())
96    }
97
98    /// Returns a list of nodes.
99    #[inline(always)]
100    #[must_use]
101    pub fn nodes(&self) -> &[Node] {
102        &self.nodes
103    }
104}