use crate::dict::DictionaryEntry;
use std::collections::HashMap;
use std::sync::RwLock;
use yada::DoubleArray;
use yada::builder::DoubleArrayBuilder;
/// Maps a surface string to its position in the sorted surface list.
type SurfaceIndex = HashMap<String, usize>;
/// A runtime-editable dictionary of extra words, searched by common-prefix lookup.
///
/// Every field sits behind its own `RwLock`, which is why all methods take `&self`.
pub struct OverlayDictionary {
/// Surface string -> every entry registered under that surface.
entries: RwLock<HashMap<String, Vec<OverlayEntry>>>,
/// Double-array trie over the surfaces; `None` until it has been built.
/// Each trie value is an index into `sorted_surfaces`.
trie: RwLock<Option<DoubleArray<Vec<u8>>>>,
/// Surface -> index in `sorted_surfaces`, refreshed on every trie rebuild.
/// NOTE(review): written but never read in the code visible here.
surface_index: RwLock<SurfaceIndex>,
/// Surfaces in sorted order, as they were when the trie was last built.
sorted_surfaces: RwLock<Vec<String>>,
/// `true` when `entries` changed since the trie was last built.
trie_dirty: RwLock<bool>,
}
/// One user-defined entry registered under a surface form in the overlay dictionary.
///
/// The fields are copied into the `DictionaryEntry` values produced by a lookup.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OverlayEntry {
    /// Left context ID.
    pub left_id: u16,
    /// Right context ID.
    pub right_id: u16,
    /// Word cost.
    pub wcost: i16,
    /// Feature string (comma-separated fields such as part of speech and reading).
    pub feature: String,
}

impl OverlayEntry {
    /// Context ID that [`OverlayEntry::new`] assigns to both `left_id` and `right_id`.
    ///
    /// NOTE(review): presumably the ID of a noun context in the system dictionary's
    /// connection table — confirm against the dictionary the overlay is used with.
    pub const DEFAULT_CONTEXT_ID: u16 = 1285;

    /// Creates an entry with the default context ID on both sides.
    ///
    /// Use [`OverlayEntry::with_context`] when specific context IDs are required.
    pub fn new(feature: &str, wcost: i16) -> Self {
        Self::with_context(
            Self::DEFAULT_CONTEXT_ID,
            Self::DEFAULT_CONTEXT_ID,
            wcost,
            feature,
        )
    }

    /// Creates an entry with explicit left and right context IDs.
    pub fn with_context(left_id: u16, right_id: u16, wcost: i16, feature: &str) -> Self {
        Self {
            left_id,
            right_id,
            wcost,
            feature: feature.to_string(),
        }
    }
}
impl Default for OverlayDictionary {
fn default() -> Self {
Self::new()
}
}
impl OverlayDictionary {
pub fn new() -> Self {
Self {
entries: RwLock::new(HashMap::new()),
trie: RwLock::new(None),
surface_index: RwLock::new(HashMap::new()),
sorted_surfaces: RwLock::new(Vec::new()),
trie_dirty: RwLock::new(false),
}
}
pub fn add_word(&self, surface: &str, entry: OverlayEntry) {
{
let mut entries = self.entries.write().unwrap();
entries.entry(surface.to_string()).or_default().push(entry);
}
*self.trie_dirty.write().unwrap() = true;
}
pub fn add_simple(&self, surface: &str, reading: &str, pronunciation: &str, wcost: i16) {
let feature = format!(
"名詞,固有名詞,一般,*,*,*,{},{},{}",
surface, reading, pronunciation
);
self.add_word(surface, OverlayEntry::new(&feature, wcost));
}
pub fn remove_word(&self, surface: &str) -> bool {
let removed = {
let mut entries = self.entries.write().unwrap();
entries.remove(surface).is_some()
};
if removed {
*self.trie_dirty.write().unwrap() = true;
}
removed
}
fn rebuild_trie(&self) {
let entries = self.entries.read().unwrap();
if entries.is_empty() {
*self.trie.write().unwrap() = None;
*self.surface_index.write().unwrap() = HashMap::new();
*self.sorted_surfaces.write().unwrap() = Vec::new();
*self.trie_dirty.write().unwrap() = false;
return;
}
let mut surfaces: Vec<String> = entries.keys().cloned().collect();
surfaces.sort();
let keyset: Vec<(&[u8], u32)> = surfaces
.iter()
.enumerate()
.map(|(i, s)| (s.as_bytes(), i as u32))
.collect();
let new_index: SurfaceIndex = surfaces
.iter()
.enumerate()
.map(|(i, s)| (s.clone(), i))
.collect();
if let Some(da_bytes) = DoubleArrayBuilder::build(&keyset) {
*self.trie.write().unwrap() = Some(DoubleArray::new(da_bytes));
*self.surface_index.write().unwrap() = new_index;
*self.sorted_surfaces.write().unwrap() = surfaces;
}
*self.trie_dirty.write().unwrap() = false;
}
pub fn lookup(&self, key: &str) -> Vec<DictionaryEntry> {
{
let dirty = *self.trie_dirty.read().unwrap();
if dirty {
drop(self.trie_dirty.read());
self.rebuild_trie();
}
}
let entries = self.entries.read().unwrap();
if entries.is_empty() {
return Vec::new();
}
let trie_guard = self.trie.read().unwrap();
let sorted_surfaces = self.sorted_surfaces.read().unwrap();
let mut results = Vec::new();
if let Some(ref trie) = *trie_guard {
let key_bytes = key.as_bytes();
for (value, length) in trie.common_prefix_search(key_bytes) {
if let Some(surface) = sorted_surfaces.get(value as usize) {
if let Some(entry_list) = entries.get(surface) {
for entry in entry_list {
results.push(DictionaryEntry {
length,
word_id: u32::MAX, left_id: entry.left_id,
right_id: entry.right_id,
pos_id: entry.left_id,
wcost: entry.wcost,
feature: entry.feature.clone(),
});
}
}
}
}
} else {
for (surface, surface_entries) in entries.iter() {
if key.starts_with(surface) {
for entry in surface_entries {
results.push(DictionaryEntry {
length: surface.len(),
word_id: u32::MAX, left_id: entry.left_id,
right_id: entry.right_id,
pos_id: entry.left_id,
wcost: entry.wcost,
feature: entry.feature.clone(),
});
}
}
}
}
results
}
pub fn len(&self) -> usize {
let entries = self.entries.read().unwrap();
entries.values().map(Vec::len).sum()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn clear(&self) {
let mut entries = self.entries.write().unwrap();
entries.clear();
*self.trie.write().unwrap() = None;
*self.surface_index.write().unwrap() = HashMap::new();
*self.sorted_surfaces.write().unwrap() = Vec::new();
*self.trie_dirty.write().unwrap() = false;
}
pub fn surfaces(&self) -> Vec<String> {
let entries = self.entries.read().unwrap();
entries.keys().cloned().collect()
}
}
impl std::fmt::Debug for OverlayDictionary {
    /// Writes a compact summary that shows only the number of entries,
    /// not the entries themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.len();
        let mut summary = f.debug_struct("OverlayDictionary");
        summary.field("count", &count);
        summary.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A word added with `add_simple` is found by an exact lookup and carries its
    /// byte length, cost and feature string; unrelated keys match nothing.
    #[test]
    fn test_add_and_lookup() {
        let overlay = OverlayDictionary::new();
        overlay.add_simple(
            "ChatGPT",
            "チャットジーピーティー",
            "チャットジーピーティー",
            5000,
        );
        let results = overlay.lookup("ChatGPT");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].length, "ChatGPT".len());
        assert_eq!(results[0].wcost, 5000);
        assert!(results[0].feature.contains("ChatGPT"));
        assert!(overlay.lookup("Gemini").is_empty());
    }

    /// Removing a word drops it from the count and from later lookups, even when a
    /// trie containing the word had already been built.
    #[test]
    fn test_remove_word() {
        let overlay = OverlayDictionary::new();
        overlay.add_simple("テスト", "テスト", "テスト", 5000);
        assert_eq!(overlay.len(), 1);
        // Force the trie to be built so the removal has to invalidate it.
        assert_eq!(overlay.lookup("テスト").len(), 1);
        assert!(overlay.remove_word("テスト"));
        assert_eq!(overlay.len(), 0);
        assert!(overlay.is_empty());
        assert!(overlay.lookup("テスト").is_empty());
        assert!(!overlay.remove_word("テスト"));
    }

    /// Removing one word leaves the other words searchable.
    #[test]
    fn test_remove_keeps_other_words() {
        let overlay = OverlayDictionary::new();
        overlay.add_simple("東京", "トウキョウ", "トーキョー", 5000);
        overlay.add_simple("京都", "キョウト", "キョート", 5000);
        assert_eq!(overlay.lookup("東京都").len(), 1);
        assert!(overlay.remove_word("東京"));
        assert!(overlay.lookup("東京都").is_empty());
        assert_eq!(overlay.lookup("京都").len(), 1);
    }

    /// A registered surface matches when it is a proper prefix of the key.
    #[test]
    fn test_prefix_lookup() {
        let overlay = OverlayDictionary::new();
        overlay.add_simple("東京", "トウキョウ", "トーキョー", 5000);
        let results = overlay.lookup("東京都");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].length, "東京".len());
    }

    /// Every registered surface that prefixes the key is reported, each with its own length.
    #[test]
    fn test_multiple_entries() {
        let overlay = OverlayDictionary::new();
        overlay.add_simple("日本", "ニホン", "ニホン", 5000);
        overlay.add_simple("日本語", "ニホンゴ", "ニホンゴ", 5000);
        let results = overlay.lookup("日本語学校");
        assert_eq!(results.len(), 2);
        let mut lengths: Vec<usize> = results.iter().map(|r| r.length).collect();
        lengths.sort_unstable();
        assert_eq!(lengths, vec!["日本".len(), "日本語".len()]);
    }

    /// Concurrent insertions from several threads are all recorded and searchable.
    #[test]
    fn test_thread_safety() {
        use std::sync::Arc;
        use std::thread;
        let overlay = Arc::new(OverlayDictionary::new());
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let overlay = Arc::clone(&overlay);
                thread::spawn(move || {
                    let surface = format!("word{}", i);
                    overlay.add_simple(&surface, "ワード", "ワード", 5000);
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(overlay.len(), 10);
        for i in 0..10 {
            assert_eq!(overlay.lookup(&format!("word{}", i)).len(), 1);
        }
    }

    /// Words added after the trie was first built become visible on the next lookup.
    #[test]
    fn test_trie_rebuild() {
        let overlay = OverlayDictionary::new();
        overlay.add_simple("Apple", "アップル", "アップル", 5000);
        // First lookup builds the trie with a single key.
        assert_eq!(overlay.lookup("Apple").len(), 1);
        overlay.add_simple("Banana", "バナナ", "バナナ", 5000);
        overlay.add_simple("Cherry", "チェリー", "チェリー", 5000);
        let results = overlay.lookup("Apple");
        assert_eq!(results.len(), 1);
        let results = overlay.lookup("Banana");
        assert_eq!(results.len(), 1);
        let results = overlay.lookup("Cherry");
        assert_eq!(results.len(), 1);
    }

    /// `clear` empties the dictionary and leaves it usable afterwards.
    #[test]
    fn test_clear() {
        let overlay = OverlayDictionary::new();
        overlay.add_simple("Apple", "アップル", "アップル", 5000);
        assert_eq!(overlay.lookup("Apple").len(), 1);
        overlay.clear();
        assert!(overlay.is_empty());
        assert!(overlay.surfaces().is_empty());
        assert!(overlay.lookup("Apple").is_empty());
        overlay.add_simple("Banana", "バナナ", "バナナ", 5000);
        assert_eq!(overlay.lookup("Banana").len(), 1);
    }
}