use std::borrow::Cow;
#[cfg(feature = "zstd")]
use std::io::{Read, Write as IoWrite};
use std::path::Path;
use yada::{builder::DoubleArrayBuilder, DoubleArray};
use crate::error::{DictError, Result};
/// Read-only Double-Array Trie mapping byte-string keys to `u32` values.
///
/// The serialized trie image is either borrowed from the caller ([`Trie::new`])
/// or owned by the trie ([`Trie::from_vec`], [`Trie::from_file`],
/// [`Trie::from_compressed_file`]).
pub struct Trie<'a> {
    da: DoubleArray<Cow<'a, [u8]>>,
}

impl<'a> Trie<'a> {
    /// Wraps a borrowed serialized trie image (e.g. the output of `TrieBuilder::build`).
    ///
    /// This wrapper performs no validation of `bytes`.
    #[must_use]
    pub fn new(bytes: &'a [u8]) -> Self {
        Self {
            da: DoubleArray::new(Cow::Borrowed(bytes)),
        }
    }

    /// Wraps an owned serialized trie image, producing a trie with no borrowed data.
    #[must_use]
    pub fn from_vec(bytes: Vec<u8>) -> Trie<'static> {
        Trie {
            da: DoubleArray::new(Cow::Owned(bytes)),
        }
    }

    /// Reads an uncompressed serialized trie image from `path` into memory.
    ///
    /// # Errors
    ///
    /// Returns [`DictError::Io`] if the file cannot be read.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Trie<'static>> {
        let bytes = std::fs::read(path.as_ref()).map_err(DictError::Io)?;
        Ok(Self::from_vec(bytes))
    }

    /// Reads a zstd-compressed serialized trie image from `path` and decompresses it into memory.
    ///
    /// # Errors
    ///
    /// Returns [`DictError::Io`] if the file cannot be opened or decompression fails.
    #[cfg(feature = "zstd")]
    pub fn from_compressed_file<P: AsRef<Path>>(path: P) -> Result<Trie<'static>> {
        let file = std::fs::File::open(path.as_ref()).map_err(DictError::Io)?;
        let mut decoder = zstd::Decoder::new(file).map_err(DictError::Io)?;
        let mut bytes = Vec::new();
        decoder.read_to_end(&mut bytes).map_err(DictError::Io)?;
        Ok(Self::from_vec(bytes))
    }

    /// Stub used when the `zstd` feature is disabled.
    ///
    /// # Errors
    ///
    /// Always returns [`DictError::Format`] explaining that the `zstd` feature is required.
    #[cfg(not(feature = "zstd"))]
    pub fn from_compressed_file<P: AsRef<Path>>(_path: P) -> Result<Trie<'static>> {
        Err(DictError::Format(
            "zstd feature is not enabled. Use uncompressed files or enable the 'zstd' feature."
                .to_string(),
        ))
    }

    /// Returns the value stored for `key` if it is present as a complete key.
    #[must_use]
    pub fn exact_match(&self, key: &str) -> Option<u32> {
        self.da.exact_match_search(key.as_bytes())
    }

    /// Byte-key variant of [`Trie::exact_match`].
    #[must_use]
    pub fn exact_match_bytes(&self, key: &[u8]) -> Option<u32> {
        self.da.exact_match_search(key)
    }

    /// Yields `(value, byte_len)` for every stored key that is a prefix of `text`.
    pub fn common_prefix_search<'b>(
        &'b self,
        text: &'b str,
    ) -> impl Iterator<Item = (u32, usize)> + 'b {
        self.da.common_prefix_search(text.as_bytes())
    }

    /// Byte-key variant of [`Trie::common_prefix_search`].
    pub fn common_prefix_search_bytes<'b>(
        &'b self,
        key: &'b [u8],
    ) -> impl Iterator<Item = (u32, usize)> + 'b {
        self.da.common_prefix_search(key)
    }

    /// Finds every stored key that is a prefix of `text[start_byte..]`.
    ///
    /// Each result is `(value, end_byte)`, where `end_byte` is an absolute byte
    /// offset into `text` (one past the last byte of the match).
    ///
    /// Returns an empty `Vec` when `start_byte` is at or past the end of `text`,
    /// or when it does not fall on a UTF-8 character boundary.
    #[must_use]
    pub fn common_prefix_search_at(&self, text: &str, start_byte: usize) -> Vec<(u32, usize)> {
        // `str::get` returns `None` both for out-of-range offsets and for offsets
        // inside a multi-byte character; plain `&text[start_byte..]` would panic on
        // the latter.
        match text.get(start_byte..) {
            Some(suffix) if !suffix.is_empty() => self
                .da
                .common_prefix_search(suffix.as_bytes())
                .map(|(value, len)| (value, start_byte + len))
                .collect(),
            _ => Vec::new(),
        }
    }
}
/// Builds serialized Double-Array Trie images that [`Trie`] can load.
pub struct TrieBuilder;

impl TrieBuilder {
    /// Builds a trie image from `(key, value)` pairs with UTF-8 string keys.
    ///
    /// `entries` must already be sorted by key bytes; use
    /// [`TrieBuilder::build_unsorted`] for arbitrary order.
    ///
    /// # Errors
    ///
    /// Returns [`DictError::Format`] if `entries` is empty, is not sorted by key
    /// bytes, or the underlying Double-Array builder fails.
    pub fn build(entries: &[(&str, u32)]) -> Result<Vec<u8>> {
        let keyset: Vec<(&[u8], u32)> = entries
            .iter()
            .map(|&(key, value)| (key.as_bytes(), value))
            .collect();
        Self::build_bytes(&keyset)
    }

    /// Builds a trie image from `(key, value)` pairs with raw byte keys.
    ///
    /// `entries` must already be sorted by key bytes (duplicates are passed through
    /// to the underlying builder unchanged).
    ///
    /// # Errors
    ///
    /// Returns [`DictError::Format`] if `entries` is empty, is not sorted by key
    /// bytes, or the underlying Double-Array builder fails.
    pub fn build_bytes(entries: &[(&[u8], u32)]) -> Result<Vec<u8>> {
        if entries.is_empty() {
            return Err(DictError::Format(
                "Cannot build Trie from empty entries".to_string(),
            ));
        }
        // Reject unsorted input up front with an actionable message instead of
        // relying on the builder's generic failure.
        Self::ensure_sorted(entries)?;
        DoubleArrayBuilder::build(entries)
            .ok_or_else(|| DictError::Format("Failed to build Double-Array Trie".to_string()))
    }

    /// Checks that keys are in non-decreasing byte order, reporting the first violation.
    fn ensure_sorted(entries: &[(&[u8], u32)]) -> Result<()> {
        match entries.windows(2).position(|pair| pair[0].0 > pair[1].0) {
            Some(pos) => Err(DictError::Format(format!(
                "Trie entries must be sorted by key bytes: entry {} sorts before entry {}; \
                 use TrieBuilder::build_unsorted for unsorted input",
                pos + 1,
                pos
            ))),
            None => Ok(()),
        }
    }

    /// Sorts `entries` in place by key bytes, then builds a trie image from them.
    ///
    /// The caller's slice is left sorted. The sort is stable, so entries with equal
    /// keys keep their original relative order.
    ///
    /// # Errors
    ///
    /// Same as [`TrieBuilder::build`], except that ordering errors cannot occur.
    pub fn build_unsorted(entries: &mut [(&str, u32)]) -> Result<Vec<u8>> {
        entries.sort_by(|a, b| a.0.as_bytes().cmp(b.0.as_bytes()));
        Self::build(entries)
    }

    /// Writes a serialized trie image to `path` uncompressed.
    ///
    /// # Errors
    ///
    /// Returns [`DictError::Io`] if the file cannot be written.
    pub fn save_to_file<P: AsRef<Path>>(bytes: &[u8], path: P) -> Result<()> {
        std::fs::write(path.as_ref(), bytes).map_err(DictError::Io)
    }

    /// Writes a serialized trie image to `path` compressed with zstd at `level`.
    ///
    /// # Errors
    ///
    /// Returns [`DictError::Io`] if the file cannot be created, or if compression or
    /// writing fails.
    #[cfg(feature = "zstd")]
    pub fn save_to_compressed_file<P: AsRef<Path>>(
        bytes: &[u8],
        path: P,
        level: i32,
    ) -> Result<()> {
        let file = std::fs::File::create(path.as_ref()).map_err(DictError::Io)?;
        let mut encoder = zstd::Encoder::new(file, level).map_err(DictError::Io)?;
        encoder.write_all(bytes).map_err(DictError::Io)?;
        // `finish` writes the end of the zstd frame; without it the output would be incomplete.
        encoder.finish().map_err(DictError::Io)?;
        Ok(())
    }

    /// Stub used when the `zstd` feature is disabled.
    ///
    /// # Errors
    ///
    /// Always returns [`DictError::Format`] explaining that the `zstd` feature is required.
    #[cfg(not(feature = "zstd"))]
    pub fn save_to_compressed_file<P: AsRef<Path>>(
        _bytes: &[u8],
        _path: P,
        _level: i32,
    ) -> Result<()> {
        Err(DictError::Format(
            "zstd feature is not enabled. Use uncompressed files or enable the 'zstd' feature."
                .to_string(),
        ))
    }
}
/// Trie value for a matched key; `DictionarySearcher` uses it as an index into its entry slice.
///
/// `Hash` and `Ord` are derived so indices can key hash or ordered maps and can be sorted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct EntryIndex(pub u32);

impl EntryIndex {
    /// Wraps a raw trie value.
    #[must_use]
    pub const fn new(index: u32) -> Self {
        Self(index)
    }

    /// Returns the raw trie value.
    #[must_use]
    pub const fn value(&self) -> u32 {
        self.0
    }
}

impl From<u32> for EntryIndex {
    fn from(value: u32) -> Self {
        Self(value)
    }
}

impl From<EntryIndex> for u32 {
    fn from(index: EntryIndex) -> Self {
        index.0
    }
}
/// A dictionary key found as a prefix of some text, together with its byte span.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrefixMatch {
    /// Trie value of the matched key.
    pub index: EntryIndex,
    /// Length of the matched key in bytes.
    pub byte_length: usize,
    /// Byte offset at which the match begins.
    pub start_byte: usize,
    /// Exclusive byte offset at which the match ends (`start_byte + byte_length`).
    pub end_byte: usize,
}

impl PrefixMatch {
    /// Describes a match of `byte_length` bytes beginning at `start_byte`.
    #[must_use]
    pub const fn new(index: u32, byte_length: usize, start_byte: usize) -> Self {
        let end_byte = start_byte + byte_length;
        Self {
            index: EntryIndex(index),
            byte_length,
            start_byte,
            end_byte,
        }
    }
}
/// Pairs a [`Trie`] with the entry table that its stored values index into.
pub struct DictionarySearcher<'a, E> {
    trie: &'a Trie<'a>,
    entries: &'a [E],
}

impl<'a, E> DictionarySearcher<'a, E> {
    /// Creates a searcher whose trie values are indices into `entries`.
    #[must_use]
    pub const fn new(trie: &'a Trie<'a>, entries: &'a [E]) -> Self {
        Self { trie, entries }
    }

    /// Returns the entry for `key` when it is stored as a complete key.
    ///
    /// Returns `None` if the key is absent, or if its trie value has no
    /// corresponding slot in `entries`.
    #[must_use]
    pub fn exact_match(&self, key: &str) -> Option<&E> {
        let index = self.trie.exact_match(key)?;
        self.entries.get(index as usize)
    }

    /// Returns every entry whose key is a prefix of `text`, with spans starting at byte 0.
    #[must_use]
    pub fn common_prefix_search(&self, text: &str) -> Vec<(&E, PrefixMatch)> {
        self.collect_matches(text, 0)
    }

    /// Returns every entry whose key is a prefix of `text[start_byte..]`.
    ///
    /// The spans in the returned [`PrefixMatch`]es are absolute offsets into `text`.
    /// Returns an empty `Vec` when `start_byte` is at or past the end of `text`, or
    /// when it does not fall on a UTF-8 character boundary.
    #[must_use]
    pub fn common_prefix_search_at(&self, text: &str, start_byte: usize) -> Vec<(&E, PrefixMatch)> {
        // `str::get` rejects both out-of-range offsets and offsets inside a multi-byte
        // character, so an invalid offset yields no matches instead of a panic.
        match text.get(start_byte..) {
            Some(suffix) if !suffix.is_empty() => self.collect_matches(suffix, start_byte),
            _ => Vec::new(),
        }
    }

    /// Resolves trie prefix matches of `text` to entries, offsetting spans by `start_byte`.
    fn collect_matches(&self, text: &str, start_byte: usize) -> Vec<(&E, PrefixMatch)> {
        self.trie
            .common_prefix_search(text)
            .filter_map(|(index, byte_len)| {
                // Trie values without a matching entry are skipped rather than indexed.
                let entry = self.entries.get(index as usize)?;
                Some((entry, PrefixMatch::new(index, byte_len, start_byte)))
            })
            .collect()
    }
}
#[cfg(test)]
// In tests an unwrap/expect failure is simply a failed test, so these lints do not apply.
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Serializes a trie over a small Korean vocabulary, listed in key-byte order,
    /// whose values are 0..=5.
    fn build_test_trie() -> Vec<u8> {
        let vocabulary: [(&str, u32); 6] = [
            ("가", 0),
            ("가다", 1),
            ("가방", 2),
            ("가방에", 3),
            ("나", 4),
            ("나다", 5),
        ];
        TrieBuilder::build(&vocabulary).unwrap()
    }

    #[test]
    fn test_exact_match() {
        let bytes = build_test_trie();
        let trie = Trie::new(&bytes);
        let cases: [(&str, Option<u32>); 6] = [
            ("가", Some(0)),
            ("가다", Some(1)),
            ("가방", Some(2)),
            ("가방에", Some(3)),
            ("나", Some(4)),
            ("없음", None),
        ];
        for &(key, expected) in &cases {
            assert_eq!(trie.exact_match(key), expected, "lookup of {:?}", key);
        }
    }

    #[test]
    fn test_common_prefix_search() {
        let bytes = build_test_trie();
        let trie = Trie::new(&bytes);
        let values: Vec<u32> = trie
            .common_prefix_search("가방에서")
            .map(|(value, _)| value)
            .collect();
        assert_eq!(values.len(), 3);
        for expected in &[0u32, 2, 3] {
            assert!(values.contains(expected), "missing value {}", expected);
        }
    }

    #[test]
    fn test_common_prefix_search_at() {
        let bytes = build_test_trie();
        let trie = Trie::new(&bytes);
        // Starting right after "나" leaves "가다", which matches "가" and "가다".
        let matches = trie.common_prefix_search_at("나가다", "나".len());
        assert_eq!(matches.len(), 2);
    }

    #[test]
    fn test_build_unsorted() {
        let mut shuffled = vec![("가방", 2u32), ("가", 0), ("가다", 1)];
        let bytes = TrieBuilder::build_unsorted(&mut shuffled).unwrap();
        let trie = Trie::new(&bytes);
        for &(key, value) in &[("가", 0u32), ("가다", 1), ("가방", 2)] {
            assert_eq!(trie.exact_match(key), Some(value));
        }
    }

    #[test]
    fn test_from_vec() {
        let trie = Trie::from_vec(build_test_trie());
        assert_eq!(trie.exact_match("가"), Some(0));
    }

    #[test]
    fn test_entry_index() {
        let wrapped = EntryIndex::new(42);
        assert_eq!(wrapped.value(), 42);
        assert_eq!(u32::from(wrapped), 42);
        let converted: EntryIndex = 100u32.into();
        assert_eq!(converted.value(), 100);
    }

    #[test]
    fn test_prefix_match() {
        let span = PrefixMatch::new(5, 6, 10);
        assert_eq!(span.index.value(), 5);
        assert_eq!(span.byte_length, 6);
        assert_eq!(span.start_byte, 10);
        assert_eq!(span.end_byte, 16);
    }

    #[test]
    fn test_dictionary_searcher() {
        let bytes = build_test_trie();
        let trie = Trie::new(&bytes);
        let entries = vec![
            "가-entry",
            "가다-entry",
            "가방-entry",
            "가방에-entry",
            "나-entry",
            "나다-entry",
        ];
        let searcher = DictionarySearcher::new(&trie, &entries);

        assert_eq!(searcher.exact_match("가다"), Some(&"가다-entry"));
        assert_eq!(searcher.exact_match("없음"), None);

        let found: Vec<&str> = searcher
            .common_prefix_search("가방에서")
            .iter()
            .map(|(entry, _)| **entry)
            .collect();
        assert_eq!(found.len(), 3);
        for expected in &["가-entry", "가방-entry", "가방에-entry"] {
            assert!(found.contains(expected), "missing entry {:?}", expected);
        }
    }

    #[test]
    fn test_korean_morphemes() {
        let mut morphemes = vec![
            ("아버지", 0u32),
            ("아버지가", 1),
            ("가", 2),
            ("가방", 3),
            ("가방에", 4),
            ("방", 5),
            ("방에", 6),
            ("에", 7),
        ];
        let bytes = TrieBuilder::build_unsorted(&mut morphemes).expect("should build trie");
        let trie = Trie::new(&bytes);
        let text = "아버지가방에";

        let from_start: Vec<(u32, usize)> = trie.common_prefix_search(text).collect();
        for expected in &[0u32, 1] {
            assert!(from_start.iter().any(|&(value, _)| value == *expected));
        }

        // Byte 9 is just past "아버지" (three 3-byte Hangul syllables).
        let from_ga = trie.common_prefix_search_at(text, 9);
        for expected in &[2u32, 3, 4] {
            assert!(from_ga.iter().any(|&(value, _)| value == *expected));
        }
    }

    #[test]
    fn test_empty_trie() {
        let no_entries: Vec<(&str, u32)> = Vec::new();
        assert!(TrieBuilder::build(&no_entries).is_err());
    }

    #[test]
    fn test_single_entry() {
        let bytes = TrieBuilder::build(&[("테스트", 42u32)]).unwrap();
        let trie = Trie::new(&bytes);
        assert_eq!(trie.exact_match("테스트"), Some(42));
        // Neither a proper prefix nor an extension of the stored key is an exact match.
        assert_eq!(trie.exact_match("테스"), None);
        assert_eq!(trie.exact_match("테스트입니다"), None);
    }
}