use std::borrow::Cow;
/// Index of a node inside a lattice's node arena.
pub type NodeId = u32;

/// Id of the beginning-of-sentence node; it is the first node placed in the arena.
pub const BOS_NODE_ID: NodeId = 0;

/// Sentinel meaning "no node", used e.g. for an unset predecessor back-pointer.
pub const INVALID_NODE_ID: NodeId = NodeId::MAX;

/// Connection-context id assigned to both sides of the BOS node.
pub const BOS_CONTEXT_ID: u16 = 0;

/// Connection-context id assigned to both sides of the EOS node.
pub const EOS_CONTEXT_ID: u16 = 0;
/// Kind of a lattice node.
///
/// Discriminants are spelled out explicitly; they match the implicit declaration-order values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum NodeType {
    /// Beginning-of-sentence sentinel.
    Bos = 0,
    /// End-of-sentence sentinel.
    Eos = 1,
    /// Regular entry; the default kind (presumably a system-dictionary hit — confirm).
    #[default]
    Known = 2,
    /// Node for text not covered by a known entry (presumably unknown-word handling — confirm).
    Unknown = 3,
    /// Node originating from user-supplied entries (presumably a user dictionary — confirm).
    User = 4,
}
/// One candidate morpheme (or a BOS/EOS sentinel) in a lattice.
///
/// Positions refer to the lattice's whitespace-stripped text. The span is the half-open range
/// `start_pos..end_pos` in characters and `start_byte..end_byte` in bytes.
#[derive(Debug, Clone)]
pub struct Node {
    /// Index of this node in the lattice's node arena.
    pub id: NodeId,
    /// Surface string covered by the node.
    pub surface: Cow<'static, str>,
    /// First character of the span (inclusive).
    pub start_pos: usize,
    /// End of the span in characters (exclusive).
    pub end_pos: usize,
    /// Byte offset corresponding to `start_pos`.
    pub start_byte: usize,
    /// Byte offset corresponding to `end_pos`.
    pub end_byte: usize,
    /// Left connection-context id (presumably an index into a connection-cost table — confirm).
    pub left_id: u16,
    /// Right connection-context id.
    pub right_id: u16,
    /// Cost of the word itself, as supplied when the node was built.
    pub word_cost: i32,
    /// Accumulated path cost. BOS starts at 0 and other nodes at `i32::MAX`; presumably it is
    /// filled in by a Viterbi-style pass elsewhere.
    pub total_cost: i32,
    /// Predecessor on the best path, or [`INVALID_NODE_ID`] when none has been recorded.
    pub prev_node_id: NodeId,
    /// Kind of node (sentinel, known, unknown, user).
    pub node_type: NodeType,
    /// Comma-separated feature string, e.g. `"BOS/EOS,*,*,*,*,*,*,*"` for the sentinels.
    pub feature: Cow<'static, str>,
    /// Whether whitespace preceded this node in the original input.
    pub has_space_before: bool,
}

impl Node {
    /// Creates the beginning-of-sentence sentinel.
    ///
    /// It has id [`BOS_NODE_ID`], an empty span at position 0, and zero word and total cost.
    #[must_use]
    pub const fn bos() -> Self {
        Self {
            id: BOS_NODE_ID,
            surface: Cow::Borrowed("BOS"),
            start_pos: 0,
            end_pos: 0,
            start_byte: 0,
            end_byte: 0,
            left_id: BOS_CONTEXT_ID,
            right_id: BOS_CONTEXT_ID,
            word_cost: 0,
            total_cost: 0,
            prev_node_id: INVALID_NODE_ID,
            node_type: NodeType::Bos,
            feature: Cow::Borrowed("BOS/EOS,*,*,*,*,*,*,*"),
            has_space_before: false,
        }
    }

    /// Creates the end-of-sentence sentinel with the given `id`.
    ///
    /// The node has an empty span located at `char_len` chars / `byte_len` bytes, i.e. the end
    /// of the text. Its `total_cost` starts at `i32::MAX` and it has no predecessor.
    #[must_use]
    pub const fn eos(id: NodeId, char_len: usize, byte_len: usize) -> Self {
        Self {
            id,
            surface: Cow::Borrowed("EOS"),
            start_pos: char_len,
            end_pos: char_len,
            start_byte: byte_len,
            end_byte: byte_len,
            left_id: EOS_CONTEXT_ID,
            right_id: EOS_CONTEXT_ID,
            word_cost: 0,
            total_cost: i32::MAX,
            prev_node_id: INVALID_NODE_ID,
            node_type: NodeType::Eos,
            feature: Cow::Borrowed("BOS/EOS,*,*,*,*,*,*,*"),
            has_space_before: false,
        }
    }

    /// Returns `true` for the beginning-of-sentence sentinel.
    #[inline]
    #[must_use]
    pub const fn is_bos(&self) -> bool {
        matches!(self.node_type, NodeType::Bos)
    }

    /// Returns `true` for the end-of-sentence sentinel.
    #[inline]
    #[must_use]
    pub const fn is_eos(&self) -> bool {
        matches!(self.node_type, NodeType::Eos)
    }

    /// Number of characters the node spans.
    ///
    /// The fields are public, so a hand-built node may have `end_pos < start_pos`. In that
    /// case the result is 0 rather than an arithmetic underflow.
    #[inline]
    #[must_use]
    pub const fn char_len(&self) -> usize {
        self.end_pos.saturating_sub(self.start_pos)
    }

    /// Number of bytes the node spans; 0 if `end_byte < start_byte`.
    #[inline]
    #[must_use]
    pub const fn byte_len(&self) -> usize {
        self.end_byte.saturating_sub(self.start_byte)
    }
}
/// Builder describing a node to be inserted with `Lattice::add_node`.
///
/// `new` fixes the surface and character span. Everything else starts neutral and can be
/// overridden with the chained setters:
/// - ids and cost are 0;
/// - `NodeType::Known`;
/// - empty feature;
/// - no preceding space.
#[derive(Debug, Clone)]
pub struct NodeBuilder {
    surface: String,
    start_pos: usize,
    end_pos: usize,
    // `Lattice::add_node` recomputes byte offsets from `start_pos`/`end_pos`, so these two
    // values are currently not consulted there.
    start_byte: usize,
    end_byte: usize,
    left_id: u16,
    right_id: u16,
    word_cost: i32,
    node_type: NodeType,
    feature: String,
    has_space_before: bool,
}

impl NodeBuilder {
    /// Starts a builder for `surface` covering characters `start_pos..end_pos`.
    #[must_use]
    pub fn new(surface: &str, start_pos: usize, end_pos: usize) -> Self {
        Self {
            surface: String::from(surface),
            start_pos,
            end_pos,
            start_byte: 0,
            end_byte: 0,
            left_id: 0,
            right_id: 0,
            word_cost: 0,
            node_type: NodeType::Known,
            feature: String::new(),
            has_space_before: false,
        }
    }

    /// Stores explicit byte offsets for the span.
    #[must_use]
    pub const fn byte_positions(self, start: usize, end: usize) -> Self {
        let mut next = self;
        next.start_byte = start;
        next.end_byte = end;
        next
    }

    /// Sets the left connection-context id.
    #[must_use]
    pub const fn left_id(self, id: u16) -> Self {
        let mut next = self;
        next.left_id = id;
        next
    }

    /// Sets the right connection-context id.
    #[must_use]
    pub const fn right_id(self, id: u16) -> Self {
        let mut next = self;
        next.right_id = id;
        next
    }

    /// Sets the word cost.
    #[must_use]
    pub const fn word_cost(self, cost: i32) -> Self {
        let mut next = self;
        next.word_cost = cost;
        next
    }

    /// Sets the node kind.
    #[must_use]
    pub const fn node_type(self, node_type: NodeType) -> Self {
        let mut next = self;
        next.node_type = node_type;
        next
    }

    /// Replaces the feature string with a copy of `feature`.
    #[must_use]
    pub fn feature(self, feature: &str) -> Self {
        let mut next = self;
        next.feature = String::from(feature);
        next
    }

    /// Forces the "whitespace before" flag.
    ///
    /// `Lattice::add_node` ORs this with its own whitespace detection, so passing `false`
    /// cannot clear a detected space.
    #[must_use]
    pub const fn has_space_before(self, value: bool) -> Self {
        let mut next = self;
        next.has_space_before = value;
        next
    }

    /// Identity finisher: returns the builder unchanged, ready to hand to `add_node`.
    #[must_use]
    pub const fn build(self) -> Self {
        self
    }
}
/// Table mapping character indices of a UTF-8 string to byte offsets.
///
/// Entry `i` is the byte offset where character `i` starts. One extra trailing entry holds
/// the total byte length, so the table always has `char_count() + 1` entries.
#[derive(Debug, Clone)]
pub struct CharPositions {
    char_to_byte: Vec<usize>,
    total_bytes: usize,
}

impl CharPositions {
    /// Builds the table for `text`.
    #[must_use]
    pub fn new(text: &str) -> Self {
        let total_bytes = text.len();
        let char_to_byte: Vec<usize> = text
            .char_indices()
            .map(|(offset, _)| offset)
            .chain(std::iter::once(total_bytes))
            .collect();
        Self {
            char_to_byte,
            total_bytes,
        }
    }

    /// Byte offset where character `char_pos` starts.
    ///
    /// Positions past the end map to the total byte length.
    #[inline]
    #[must_use]
    pub fn char_to_byte(&self, char_pos: usize) -> usize {
        match self.char_to_byte.get(char_pos) {
            Some(&offset) => offset,
            None => self.total_bytes,
        }
    }

    /// Number of characters in the indexed text.
    #[inline]
    #[must_use]
    pub fn char_count(&self) -> usize {
        // The table has one sentinel entry beyond the last character.
        self.char_to_byte.len().saturating_sub(1)
    }

    /// Character index whose start is exactly `byte_pos`.
    ///
    /// Offsets that are not a character start map to `char_count()`. This includes offsets
    /// inside a multi-byte character and offsets past the end.
    #[inline]
    #[must_use]
    pub fn byte_to_char(&self, byte_pos: usize) -> usize {
        match self.char_to_byte.binary_search(&byte_pos) {
            Ok(index) => index,
            Err(_) => self.char_count(),
        }
    }

    /// Total length of the indexed text in bytes.
    #[inline]
    #[must_use]
    pub const fn byte_count(&self) -> usize {
        self.total_bytes
    }
}
/// Sorted set of positions, in whitespace-stripped character coordinates, that had whitespace
/// immediately before them in the original text.
#[derive(Debug, Clone, Default)]
pub struct SpacePositions {
    positions: Vec<usize>,
}

impl SpacePositions {
    /// Scans `text` and records the stripped index of every non-whitespace character that
    /// directly follows whitespace.
    ///
    /// Leading whitespace therefore marks position 0; trailing whitespace marks nothing.
    #[must_use]
    pub fn new(text: &str) -> Self {
        let mut positions = Vec::new();
        let mut kept_chars = 0usize;
        let mut after_space = false;
        for ch in text.chars() {
            let is_space = ch.is_whitespace();
            if !is_space {
                if after_space {
                    positions.push(kept_chars);
                }
                kept_chars += 1;
            }
            after_space = is_space;
        }
        // `kept_chars` only grows, so `positions` is strictly increasing and can be
        // binary-searched.
        Self { positions }
    }

    /// Returns `true` if whitespace preceded stripped character `char_pos`.
    #[inline]
    #[must_use]
    pub fn has_space_before(&self, char_pos: usize) -> bool {
        self.positions.binary_search(&char_pos).is_ok()
    }
}
/// Word lattice over a piece of input text.
///
/// Whitespace is stripped before analysis. Character and byte positions of every node refer
/// to `text()` (the stripped text). `original_text()` keeps the input verbatim, and removed
/// whitespace is remembered only as "space before character i" flags.
///
/// Nodes live in a flat arena indexed by [`NodeId`]. Slot 0 is always the BOS node and slot 1
/// the EOS node. `starts_at[i]` / `ends_at[i]` list the ids of nodes whose span starts / ends
/// at character position `i`.
#[derive(Debug)]
pub struct Lattice {
    /// Input exactly as passed to `new`/`reset`.
    original_text: String,
    /// `original_text` with all whitespace removed; all node positions index into this.
    text: String,
    /// Char-index to byte-offset table for `text`.
    char_positions: CharPositions,
    /// Stripped positions that had whitespace before them in `original_text`.
    space_positions: SpacePositions,
    /// Node arena; index == `NodeId`.
    nodes: Vec<Node>,
    /// `ends_at[i]`: ids of nodes whose span ends at char position `i`.
    ends_at: Vec<Vec<NodeId>>,
    /// `starts_at[i]`: ids of nodes whose span starts at char position `i`.
    starts_at: Vec<Vec<NodeId>>,
    bos_id: NodeId,
    eos_id: NodeId,
}

impl Lattice {
    /// Builds an empty lattice for `text`.
    ///
    /// The lattice holds only BOS (id 0, listed as ending at position 0) and EOS (id 1,
    /// listed as starting at the end of the stripped text).
    #[must_use]
    pub fn new(text: &str) -> Self {
        let original_text = text.to_string();
        let text_no_space: String = text.chars().filter(|c| !c.is_whitespace()).collect();
        let char_positions = CharPositions::new(&text_no_space);
        let space_positions = SpacePositions::new(text);
        let char_len = char_positions.char_count();
        let byte_len = char_positions.byte_count();
        let bos = Node::bos();
        let bos_id = bos.id;
        // EOS is pushed right after BOS, so it always occupies arena slot 1.
        let eos_id = 1;
        let eos = Node::eos(eos_id, char_len, byte_len);
        let nodes = vec![bos, eos];
        let mut ends_at = vec![Vec::new(); char_len + 1];
        let mut starts_at = vec![Vec::new(); char_len + 1];
        ends_at[0].push(bos_id);
        starts_at[char_len].push(eos_id);
        Self {
            original_text,
            text: text_no_space,
            char_positions,
            space_positions,
            nodes,
            ends_at,
            starts_at,
            bos_id,
            eos_id,
        }
    }

    /// Whitespace-stripped text that all positions refer to.
    #[inline]
    #[must_use]
    pub fn text(&self) -> &str {
        &self.text
    }

    /// Input text exactly as given.
    #[inline]
    #[must_use]
    pub fn original_text(&self) -> &str {
        &self.original_text
    }

    /// Length of the stripped text in characters.
    #[inline]
    #[must_use]
    pub fn char_len(&self) -> usize {
        self.char_positions.char_count()
    }

    /// Character position reached by advancing `byte_len` bytes from character `start_pos`.
    ///
    /// If the resulting byte offset is not a character start, `char_len()` is returned.
    #[inline]
    #[must_use]
    pub fn char_pos_from_start_and_byte_len(&self, start_pos: usize, byte_len: usize) -> usize {
        let start_byte = self.char_positions.char_to_byte(start_pos);
        self.char_positions.byte_to_char(start_byte + byte_len)
    }

    /// Length of the stripped text in bytes.
    #[inline]
    #[must_use]
    pub const fn byte_len(&self) -> usize {
        self.char_positions.byte_count()
    }

    /// Number of nodes in the arena, including BOS and EOS.
    #[inline]
    #[must_use]
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// The beginning-of-sentence node.
    #[inline]
    #[must_use]
    pub fn bos(&self) -> &Node {
        &self.nodes[self.bos_id as usize]
    }

    /// The end-of-sentence node.
    #[inline]
    #[must_use]
    pub fn eos(&self) -> &Node {
        &self.nodes[self.eos_id as usize]
    }

    /// Mutable access to the end-of-sentence node.
    #[inline]
    pub fn eos_mut(&mut self) -> &mut Node {
        let eos_id = self.eos_id as usize;
        &mut self.nodes[eos_id]
    }

    /// Node with the given id, if it exists.
    #[inline]
    #[must_use]
    pub fn node(&self, id: NodeId) -> Option<&Node> {
        self.nodes.get(id as usize)
    }

    /// Mutable node with the given id, if it exists.
    #[inline]
    pub fn node_mut(&mut self, id: NodeId) -> Option<&mut Node> {
        self.nodes.get_mut(id as usize)
    }

    /// All nodes in arena (id) order.
    #[inline]
    pub fn nodes(&self) -> impl Iterator<Item = &Node> {
        self.nodes.iter()
    }

    /// Nodes whose span ends at character `pos`; empty when `pos` is out of range.
    #[inline]
    pub fn nodes_ending_at(&self, pos: usize) -> impl Iterator<Item = &Node> {
        self.ends_at
            .get(pos)
            .into_iter()
            .flatten()
            .filter_map(|&id| self.nodes.get(id as usize))
    }

    /// Nodes whose span starts at character `pos`; empty when `pos` is out of range.
    #[inline]
    pub fn nodes_starting_at(&self, pos: usize) -> impl Iterator<Item = &Node> {
        self.starts_at
            .get(pos)
            .into_iter()
            .flatten()
            .filter_map(|&id| self.nodes.get(id as usize))
    }

    /// Appends a node built from `builder` and returns its id.
    ///
    /// Byte offsets are derived from the builder's character span; the builder's own byte
    /// offsets are not used. `has_space_before` is the builder's flag OR the lattice's
    /// whitespace detection at `start_pos`. Positions beyond the text are stored on the node
    /// but not indexed in `starts_at` / `ends_at`.
    // The arena would need more than `u32::MAX` nodes for the id cast to truncate.
    #[allow(clippy::cast_possible_truncation)]
    pub fn add_node(&mut self, builder: NodeBuilder) -> NodeId {
        let id = self.nodes.len() as NodeId;
        let start_byte = self.char_positions.char_to_byte(builder.start_pos);
        let end_byte = self.char_positions.char_to_byte(builder.end_pos);
        let has_space_before =
            builder.has_space_before || self.space_positions.has_space_before(builder.start_pos);
        let node = Node {
            id,
            surface: Cow::Owned(builder.surface),
            start_pos: builder.start_pos,
            end_pos: builder.end_pos,
            start_byte,
            end_byte,
            left_id: builder.left_id,
            right_id: builder.right_id,
            word_cost: builder.word_cost,
            total_cost: i32::MAX,
            prev_node_id: INVALID_NODE_ID,
            node_type: builder.node_type,
            feature: Cow::Owned(builder.feature),
            has_space_before,
        };
        if builder.start_pos < self.starts_at.len() {
            self.starts_at[builder.start_pos].push(id);
        }
        if builder.end_pos < self.ends_at.len() {
            self.ends_at[builder.end_pos].push(id);
        }
        self.nodes.push(node);
        id
    }

    /// Slice of `text()` covering characters `start..end`.
    ///
    /// Positions past the end are clamped to the end of the text. An inverted range
    /// (`start > end`) yields `""` instead of panicking.
    #[must_use]
    pub fn substring(&self, start: usize, end: usize) -> &str {
        let start_byte = self.char_positions.char_to_byte(start);
        let end_byte = self.char_positions.char_to_byte(end);
        // Both offsets come from the char-boundary table (or the total length), so `get` can
        // only fail when the range is inverted.
        self.text.get(start_byte..end_byte).unwrap_or("")
    }

    /// Returns `true` if whitespace preceded stripped character `char_pos` in the input.
    #[inline]
    #[must_use]
    pub fn has_space_at(&self, char_pos: usize) -> bool {
        self.space_positions.has_space_before(char_pos)
    }

    /// Removes every node except BOS and EOS, keeping the current text.
    ///
    /// Also rebuilds the position indices and resets EOS's `total_cost` / `prev_node_id`.
    pub fn clear(&mut self) {
        self.nodes.truncate(2);
        for v in &mut self.ends_at {
            v.clear();
        }
        for v in &mut self.starts_at {
            v.clear();
        }
        if !self.ends_at.is_empty() {
            self.ends_at[0].push(self.bos_id);
        }
        let char_len = self.char_len();
        if char_len < self.starts_at.len() {
            self.starts_at[char_len].push(self.eos_id);
        }
        if let Some(eos) = self.nodes.get_mut(self.eos_id as usize) {
            eos.total_cost = i32::MAX;
            eos.prev_node_id = INVALID_NODE_ID;
        }
    }

    /// Re-targets the lattice at a new `text`.
    ///
    /// Existing string and index allocations are reused where possible. Afterwards the lattice
    /// is in the same state `Lattice::new(text)` would produce.
    pub fn reset(&mut self, text: &str) {
        self.original_text.clear();
        self.original_text.push_str(text);
        self.text.clear();
        self.text
            .extend(text.chars().filter(|c| !c.is_whitespace()));
        self.char_positions = CharPositions::new(&self.text);
        self.space_positions = SpacePositions::new(text);
        let char_len = self.char_positions.char_count();
        let byte_len = self.char_positions.byte_count();
        let new_len = char_len + 1;
        // Keep the per-position vectors (and their allocations) that are still in range, drop
        // the surplus, empty the survivors, and append fresh ones if the new text is longer.
        for index in [&mut self.ends_at, &mut self.starts_at] {
            index.truncate(new_len);
            index.iter_mut().for_each(Vec::clear);
            index.resize_with(new_len, Vec::new);
        }
        self.nodes.truncate(2);
        if let Some(eos) = self.nodes.get_mut(self.eos_id as usize) {
            eos.start_pos = char_len;
            eos.end_pos = char_len;
            eos.start_byte = byte_len;
            eos.end_byte = byte_len;
            eos.total_cost = i32::MAX;
            eos.prev_node_id = INVALID_NODE_ID;
        }
        self.ends_at[0].push(self.bos_id);
        self.starts_at[char_len].push(self.eos_id);
    }

    /// Follows `prev_node_id` back-pointers from EOS and returns the nodes in text order.
    ///
    /// BOS and EOS are excluded from the result. The walk stops at [`INVALID_NODE_ID`] or at
    /// an id outside the arena.
    #[must_use]
    pub fn best_path(&self) -> Vec<&Node> {
        let mut path = Vec::new();
        let mut current_id = self.eos_id;
        // A well-formed back-pointer chain visits each node at most once, so it never needs
        // more than `nodes.len()` steps. The bound turns a corrupted (cyclic) chain into a
        // truncated result instead of an infinite loop.
        for _ in 0..self.nodes.len() {
            if current_id == INVALID_NODE_ID {
                break;
            }
            let node = match self.nodes.get(current_id as usize) {
                Some(node) => node,
                None => break,
            };
            if node.node_type != NodeType::Bos && node.node_type != NodeType::Eos {
                path.push(node);
            }
            current_id = node.prev_node_id;
        }
        path.reverse();
        path
    }

    /// Debug rendering listing, per end position, the nodes that end there.
    #[cfg(test)]
    #[must_use]
    #[allow(clippy::format_push_string, clippy::uninlined_format_args)]
    pub fn visualize(&self) -> String {
        let mut output = String::new();
        output.push_str(&format!("Lattice for: \"{}\"\n", self.text));
        output.push_str(&format!("Nodes: {}\n", self.node_count()));
        for pos in 0..=self.char_len() {
            let ending: Vec<_> = self.nodes_ending_at(pos).collect();
            if !ending.is_empty() {
                output.push_str(&format!("\nPosition {}: ", pos));
                for node in ending {
                    output.push_str(&format!(
                        "[{}: {} ({}-{})]",
                        node.id, node.surface, node.start_pos, node.end_pos
                    ));
                }
            }
        }
        output
    }
}
#[derive(Debug, Clone, Default)]
pub struct LatticeStats {
pub total_nodes: usize,
pub known_nodes: usize,
pub unknown_nodes: usize,
pub user_nodes: usize,
pub char_length: usize,
}
impl Lattice {
#[must_use]
pub fn stats(&self) -> LatticeStats {
let mut stats = LatticeStats {
total_nodes: self.nodes.len(),
char_length: self.char_len(),
..Default::default()
};
for node in &self.nodes {
match node.node_type {
NodeType::Known => stats.known_nodes += 1,
NodeType::Unknown => stats.unknown_nodes += 1,
NodeType::User => stats.user_nodes += 1,
_ => {}
}
}
stats
}
#[must_use]
pub fn memory_usage(&self) -> usize {
let text_bytes = self.text.len() + self.original_text.len();
let nodes_bytes = self.nodes.capacity() * std::mem::size_of::<Node>();
let index_bytes = self.starts_at.capacity() * std::mem::size_of::<Vec<u32>>()
+ self.ends_at.capacity() * std::mem::size_of::<Vec<u32>>()
+ self
.starts_at
.iter()
.map(|v| v.capacity() * 4)
.sum::<usize>()
+ self.ends_at.iter().map(|v| v.capacity() * 4).sum::<usize>();
let pos_bytes = (self.char_positions.char_count() + 1) * std::mem::size_of::<usize>();
let space_bytes = self.char_len() * std::mem::size_of::<usize>() / 10;
let node_strings: usize = self
.nodes
.iter()
.map(|n| n.surface.len() + n.feature.len())
.sum();
text_bytes + nodes_bytes + index_bytes + pos_bytes + space_bytes + node_strings
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_lattice_creation() {
        let lattice = Lattice::new("안녕하세요");
        assert_eq!(lattice.text(), "안녕하세요");
        assert_eq!(lattice.char_len(), 5);
        // A fresh lattice holds only the BOS and EOS sentinels.
        assert_eq!(lattice.node_count(), 2);
    }

    #[test]
    fn test_lattice_with_spaces() {
        let lattice = Lattice::new("안녕 하세요");
        assert_eq!(lattice.text(), "안녕하세요");
        assert_eq!(lattice.original_text(), "안녕 하세요");
        assert_eq!(lattice.char_len(), 5);
        assert!(!lattice.has_space_at(0));
        assert!(!lattice.has_space_at(1));
        // '하' is stripped index 2 and was preceded by a space in the input.
        assert!(lattice.has_space_at(2));
    }

    #[test]
    fn test_add_node() {
        let mut lattice = Lattice::new("안녕하세요");
        let builder = NodeBuilder::new("안녕", 0, 2)
            .left_id(100)
            .right_id(100)
            .word_cost(1000)
            .feature("NNG,*,F,안녕,*,*,*,*");
        let node_id = lattice.add_node(builder);
        // Ids 0 and 1 belong to BOS and EOS, so the first user node gets id 2.
        assert_eq!(node_id, 2);
        assert_eq!(lattice.node_count(), 3);

        let node = lattice.node(node_id).unwrap();
        assert_eq!(node.surface.as_ref(), "안녕");
        assert_eq!(node.start_pos, 0);
        assert_eq!(node.end_pos, 2);
        assert_eq!(node.left_id, 100);
        assert_eq!(node.word_cost, 1000);
    }

    #[test]
    fn test_nodes_at_position() {
        let mut lattice = Lattice::new("안녕하세요");
        for (surface, start, end) in [("안녕", 0, 2), ("안", 0, 1), ("녕하", 1, 3)] {
            lattice.add_node(NodeBuilder::new(surface, start, end));
        }
        assert_eq!(lattice.nodes_starting_at(0).count(), 2);
        // Only "안녕" ends at 2; "녕하" ends at 3.
        assert_eq!(lattice.nodes_ending_at(2).count(), 1);
    }

    #[test]
    fn test_char_positions() {
        let positions = CharPositions::new("한글test");
        assert_eq!(positions.char_count(), 6);
        // Hangul syllables take 3 bytes each, ASCII letters 1.
        assert_eq!(positions.char_to_byte(0), 0);
        assert_eq!(positions.char_to_byte(1), 3);
        assert_eq!(positions.char_to_byte(2), 6);
        assert_eq!(positions.char_to_byte(3), 7);
    }

    #[test]
    fn test_substring() {
        let lattice = Lattice::new("안녕하세요");
        assert_eq!(lattice.substring(0, 2), "안녕");
        assert_eq!(lattice.substring(2, 5), "하세요");
        assert_eq!(lattice.substring(0, 5), "안녕하세요");
    }

    #[test]
    fn test_bos_eos() {
        let lattice = Lattice::new("테스트");

        let bos = lattice.bos();
        assert!(bos.is_bos());
        assert_eq!(bos.id, BOS_NODE_ID);

        let eos = lattice.eos();
        assert!(eos.is_eos());
        assert_eq!(eos.start_pos, 3);
    }

    #[test]
    fn test_lattice_reset() {
        let mut lattice = Lattice::new("안녕");
        lattice.add_node(NodeBuilder::new("안녕", 0, 2));
        assert_eq!(lattice.node_count(), 3);

        lattice.reset("하세요");
        assert_eq!(lattice.text(), "하세요");
        assert_eq!(lattice.char_len(), 3);
        // Resetting drops every node except BOS and EOS.
        assert_eq!(lattice.node_count(), 2);
    }

    #[test]
    fn test_space_before_detection() {
        let mut lattice = Lattice::new("아버지가 방에");

        let after_space = lattice.add_node(NodeBuilder::new("방에", 4, 6));
        assert!(lattice.node(after_space).unwrap().has_space_before);

        let at_start = lattice.add_node(NodeBuilder::new("아버지가", 0, 4));
        assert!(!lattice.node(at_start).unwrap().has_space_before);
    }
}