use std::str::FromStr;
use log::warn;
use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
use serde::{Deserialize, Serialize};
use crate::LinderaResult;
use crate::dictionary::character_definition::CategoryId;
use crate::error::LinderaErrorKind;
use crate::viterbi::WordEntry;
/// Unknown-word dictionary, built by `parse_unk` from comma-separated entry lines
/// and (de)serialized with serde and rkyv.
#[derive(Serialize, Deserialize, Clone, Archive, RkyvSerialize, RkyvDeserialize)]
pub struct UnknownDictionary {
    /// Indexed by `CategoryId`: for each category, the ids of the entries whose
    /// surface equals that category's name.
    pub category_references: Vec<Vec<u32>>,
    /// One `WordEntry` per parsed entry, indexed by entry id.
    pub costs: Vec<WordEntry>,
    /// Byte offset into `words_data` of each entry's details record, indexed by entry id.
    pub words_idx_data: Vec<u32>,
    /// Concatenated details records: a little-endian `u32` byte length followed by
    /// that many UTF-8 bytes of `'\0'`-separated fields.
    pub words_data: Vec<u8>,
}
impl UnknownDictionary {
    /// Deserializes an [`UnknownDictionary`] from its rkyv-encoded bytes.
    ///
    /// The bytes are first copied into a 16-byte-aligned buffer, because
    /// `unknown_data` carries no alignment guarantee and rkyv reads the archive in place.
    ///
    /// # Errors
    ///
    /// Returns a `LinderaErrorKind::Deserialize` error if the bytes are not a valid
    /// archived `UnknownDictionary`.
    pub fn load(unknown_data: &[u8]) -> LinderaResult<UnknownDictionary> {
        let mut aligned = rkyv::util::AlignedVec::<16>::new();
        aligned.extend_from_slice(unknown_data);
        rkyv::from_bytes::<UnknownDictionary, rkyv::rancor::Error>(&aligned).map_err(|err| {
            LinderaErrorKind::Deserialize.with_error(anyhow::anyhow!(err.to_string()))
        })
    }

    /// Returns the cost entry for `word_id`.
    ///
    /// # Panics
    ///
    /// Panics if `word_id` is not a valid index into `costs`.
    pub fn word_entry(&self, word_id: u32) -> WordEntry {
        self.costs[word_id as usize]
    }

    /// Returns the entry ids registered for `category_id`.
    ///
    /// # Panics
    ///
    /// Panics if `category_id` is not a valid index into `category_references`.
    pub fn lookup_word_ids(&self, category_id: CategoryId) -> &[u32] {
        &self.category_references[category_id.0]
    }

    /// Returns the detail fields stored for `word_id`, split on `'\0'`.
    ///
    /// Returns `None` in three cases:
    /// - `word_id` has no offset in `words_idx_data`;
    /// - the record at that offset is truncated;
    /// - the payload is not valid UTF-8.
    ///
    /// An entry stored without details yields `Some(vec![""])`.
    pub fn word_details(&self, word_id: u32) -> Option<Vec<&str>> {
        let offset = *self.words_idx_data.get(word_id as usize)? as usize;
        // Checked slicing: no `offset + 4 + len` arithmetic that could overflow
        // `usize` on 32-bit targets when the stored length is corrupt.
        let record = self.words_data.get(offset..)?;
        let len = u32::from_le_bytes(record.get(..4)?.try_into().ok()?) as usize;
        let payload = record.get(4..)?.get(..len)?;
        let text = std::str::from_utf8(payload).ok()?;
        Some(text.split('\0').collect())
    }

    /// Emits unknown-word candidates that start at char index `start_pos`.
    ///
    /// `callback` receives one [`UnkWord`] per length, starting at 1. The candidate
    /// count is capped at 1 when `has_matched` is true, and otherwise at
    /// `min(max_grouping_len.unwrap_or(10), 3)`. Candidates are never emitted past
    /// the end of `sentence`. Positions are counted in `char`s, not bytes.
    ///
    /// Every candidate carries the same [`WordIdx`]: the character-class code of the
    /// char at `start_pos` (see `classify_char_type`).
    ///
    /// NOTE(review): `self` is unused and the index is a char-class code, not an
    /// entry id into `costs`. Confirm that callers expect this.
    pub fn gen_unk_words<F>(
        &self,
        sentence: &str,
        start_pos: usize,
        has_matched: bool,
        max_grouping_len: Option<usize>,
        mut callback: F,
    ) where
        F: FnMut(UnkWord),
    {
        let max_len = if has_matched {
            1
        } else {
            max_grouping_len.unwrap_or(10).min(3)
        };
        // Walk only the chars we need instead of collecting the whole sentence
        // into a `Vec<char>` on every call.
        let mut rest = sentence.chars().skip(start_pos);
        let Some(first_char) = rest.next() else {
            return;
        };
        // Number of candidate lengths that still fit inside the sentence.
        let lengths = max_len.min(1 + rest.take(max_len.saturating_sub(1)).count());
        let word_idx = WordIdx::new(classify_char_type(first_char) as u32);
        for length in 1..=lengths {
            callback(UnkWord {
                word_idx,
                end_char: start_pos + length,
            });
        }
    }

    /// Returns the unknown-word index for the char at `start` if `feature` is compatible.
    ///
    /// `feature` is compatible when it begins with `"名詞,"` followed by the label
    /// that `get_type_name` gives for that char's class. In that case the result is
    /// `Some(WordIdx)` holding the class code.
    ///
    /// Returns `None` when `start` is past the end of `sentence` (in `char`s) or
    /// when the feature does not match. `_end` is currently ignored.
    pub fn compatible_unk_index(
        &self,
        sentence: &str,
        start: usize,
        _end: usize,
        feature: &str,
    ) -> Option<WordIdx> {
        let char_type = classify_char_type(sentence.chars().nth(start)?);
        // Equivalent to `feature.starts_with(&format!("名詞,{name}"))`, without the allocation.
        let compatible = feature
            .strip_prefix("名詞,")
            .is_some_and(|rest| rest.starts_with(get_type_name(char_type)));
        compatible.then(|| WordIdx::new(char_type as u32))
    }
}
/// A candidate unknown word produced by `UnknownDictionary::gen_unk_words`.
#[derive(Debug, Clone)]
pub struct UnkWord {
    /// Identifies the unknown-word kind. As produced by `gen_unk_words`, this is the
    /// character-class code of the candidate's first char.
    pub word_idx: WordIdx,
    /// Exclusive end position of the candidate, counted in `char`s.
    pub end_char: usize,
}
impl UnkWord {
    /// Returns the index that identifies this candidate's unknown-word kind.
    pub fn word_idx(&self) -> WordIdx {
        let Self { word_idx, .. } = self;
        *word_idx
    }

    /// Returns the exclusive end position of this candidate, counted in `char`s.
    pub fn end_char(&self) -> usize {
        let Self { end_char, .. } = self;
        *end_char
    }
}
/// Thin wrapper around a raw unknown-word index.
#[derive(Debug, Clone, Copy)]
pub struct WordIdx {
    /// The raw index value. Everywhere in this file it is built from a
    /// character-class code returned by `classify_char_type`.
    pub word_id: u32,
}
impl WordIdx {
pub fn new(word_id: u32) -> Self {
Self { word_id }
}
}
/// Maps `ch` to a character-class code.
///
/// The codes are: 5 = ASCII digit, 4 = ASCII letter, 3 = kanji, 2 = katakana,
/// 1 = hiragana, 0 = anything else. Classes are tested in that order and the
/// first match wins.
fn classify_char_type(ch: char) -> usize {
    match ch {
        c if c.is_ascii_digit() => 5,
        c if c.is_ascii_alphabetic() => 4,
        c if is_kanji(c) => 3,
        c if is_katakana(c) => 2,
        c if is_hiragana(c) => 1,
        _ => 0,
    }
}
/// Returns the part-of-speech sub-label associated with a character-class code.
///
/// Code 4 (ASCII letter) maps to "固有名詞" and code 5 (ASCII digit) maps to "数".
/// Every other code maps to "一般".
fn get_type_name(char_type: usize) -> &'static str {
    match char_type {
        4 => "固有名詞",
        5 => "数",
        // Hiragana (1), katakana (2), kanji (3) and unknown codes share one label.
        _ => "一般",
    }
}
/// Returns `true` for hiragana letters in U+3041..=U+3096.
fn is_hiragana(ch: char) -> bool {
    ('\u{3041}'..='\u{3096}').contains(&ch)
}
/// Returns `true` for katakana letters in U+30A1..=U+30FA, or for Katakana
/// Phonetic Extensions in U+31F0..=U+31FF.
fn is_katakana(ch: char) -> bool {
    // U+30A1..=U+30F6 and U+30F7..=U+30FA are adjacent, so they form one range.
    let main_block = '\u{30A1}'..='\u{30FA}';
    let phonetic_extensions = '\u{31F0}'..='\u{31FF}';
    main_block.contains(&ch) || phonetic_extensions.contains(&ch)
}
/// Returns `true` if `ch` lies in one of two blocks:
/// - CJK Unified Ideographs Extension A (U+3400..=U+4DBF);
/// - the basic CJK Unified Ideographs block (U+4E00..=U+9FFF).
fn is_kanji(ch: char) -> bool {
    // The basic block ends at U+9FFF. The previous U+9FAF bound left ideographs
    // assigned in later Unicode versions (U+9FB0..=U+9FFF) classified as "other".
    matches!(ch, '\u{3400}'..='\u{4DBF}' | '\u{4E00}'..='\u{9FFF}')
}
/// One parsed line of the unknown-word definition input.
///
/// Only the first four comma-separated fields are kept here; any remaining fields
/// are stored separately as details by `parse_unk`.
#[derive(Debug)]
pub struct UnknownDictionaryEntry {
    /// First field. It is compared against category names by `make_category_references`.
    pub surface: String,
    /// Second field, parsed as `u32` (left id).
    pub left_id: u32,
    /// Third field, parsed as `u32` (right id).
    pub right_id: u32,
    /// Fourth field, parsed as `i32` (word cost).
    pub word_cost: i32,
}
/// Parses one comma-split line into an [`UnknownDictionaryEntry`].
///
/// The first four fields are `surface, left_id, right_id, word_cost`. Any further
/// fields are ignored here.
///
/// # Errors
///
/// * `LinderaErrorKind::Content` if `fields.len() != expected_fields_len`, or if
///   fewer than the four mandatory fields are present.
/// * `LinderaErrorKind::Parse` if an id or the cost is not a valid integer.
fn parse_dictionary_entry(
    fields: &[&str],
    expected_fields_len: usize,
) -> LinderaResult<UnknownDictionaryEntry> {
    if fields.len() != expected_fields_len {
        return Err(LinderaErrorKind::Content.with_error(anyhow::anyhow!(
            "Invalid number of fields. Expect {}, got {}",
            expected_fields_len,
            fields.len()
        )));
    }
    // Callers may pass `fields.len()` as the expectation (as `parse_unk` does),
    // which makes the check above vacuous. Destructure explicitly so a short line
    // becomes an error instead of an index-out-of-bounds panic.
    let &[surface, left_id, right_id, word_cost, ..] = fields else {
        return Err(LinderaErrorKind::Content.with_error(anyhow::anyhow!(
            "Invalid number of fields. Expect at least 4, got {}",
            fields.len()
        )));
    };
    let left_id = u32::from_str(left_id)
        .map_err(|err| LinderaErrorKind::Parse.with_error(anyhow::anyhow!(err)))?;
    let right_id = u32::from_str(right_id)
        .map_err(|err| LinderaErrorKind::Parse.with_error(anyhow::anyhow!(err)))?;
    let word_cost = i32::from_str(word_cost)
        .map_err(|err| LinderaErrorKind::Parse.with_error(anyhow::anyhow!(err)))?;
    Ok(UnknownDictionaryEntry {
        surface: surface.to_string(),
        left_id,
        right_id,
        word_cost,
    })
}
/// Returns, in ascending order, the positions in `entries` whose `surface`
/// equals `target_surface`.
fn get_entry_id_matching_surface(
    entries: &[UnknownDictionaryEntry],
    target_surface: &str,
) -> Vec<u32> {
    let mut matching_ids = Vec::new();
    for (index, entry) in entries.iter().enumerate() {
        if entry.surface.as_str() == target_surface {
            matching_ids.push(index as u32);
        }
    }
    matching_ids
}
/// For each category name, in order, collects the ids of the entries whose surface
/// equals that name.
fn make_category_references(
    categories: &[String],
    entries: &[UnknownDictionaryEntry],
) -> Vec<Vec<u32>> {
    let mut references = Vec::with_capacity(categories.len());
    for category in categories {
        references.push(get_entry_id_matching_surface(entries, category));
    }
    references
}
/// Converts parsed entries into `WordEntry` cost records, using each entry's
/// position as its unknown-word id.
///
/// A warning is logged in these cases:
/// - an entry's left and right ids differ;
/// - an id does not fit in `u16` (it is still truncated, as before);
/// - the cost does not fit in `i16` (it is clamped to the `i16` range instead of
///   wrapping to a value of the wrong sign).
fn make_costs_array(entries: &[UnknownDictionaryEntry]) -> Vec<WordEntry> {
    entries
        .iter()
        .enumerate()
        .map(|(i, e)| {
            if e.left_id != e.right_id {
                warn!("left id and right id are not same: {e:?}");
            }
            let left_id = u16::try_from(e.left_id).unwrap_or_else(|_| {
                warn!("left id does not fit in u16 and will be truncated: {e:?}");
                e.left_id as u16
            });
            let right_id = u16::try_from(e.right_id).unwrap_or_else(|_| {
                warn!("right id does not fit in u16 and will be truncated: {e:?}");
                e.right_id as u16
            });
            let word_cost = i16::try_from(e.word_cost).unwrap_or_else(|_| {
                warn!("word cost is out of i16 range and will be clamped: {e:?}");
                e.word_cost.clamp(i32::from(i16::MIN), i32::from(i16::MAX)) as i16
            });
            WordEntry {
                word_id: crate::viterbi::WordId::new(crate::viterbi::LexType::Unknown, i as u32),
                left_id,
                right_id,
                word_cost,
            }
        })
        .collect()
}
/// Builds an [`UnknownDictionary`] from comma-separated entry lines.
///
/// Each non-blank line has the form `surface,left_id,right_id,word_cost[,details...]`.
/// Entry ids are assigned in line order; blank lines are skipped and get no id.
///
/// Fields after the fourth are joined with `'\0'` and appended to `words_data` as a
/// length-prefixed record. The record's starting offset is pushed onto
/// `words_idx_data`.
///
/// # Errors
///
/// Propagates any error from parsing an individual line.
pub fn parse_unk(categories: &[String], file_content: &str) -> LinderaResult<UnknownDictionary> {
    let mut unknown_dict_entries = Vec::new();
    let mut words_idx_data = Vec::new();
    let mut words_data: Vec<u8> = Vec::new();
    // A blank line would reach the entry parser as `[""]`. Skip such lines instead,
    // which keeps entry ids dense.
    for line in file_content.lines().filter(|line| !line.trim().is_empty()) {
        let fields: Vec<&str> = line.split(',').collect();
        unknown_dict_entries.push(parse_dictionary_entry(&fields, fields.len())?);

        // Record where this entry's details start, then append the record:
        // a little-endian u32 byte length followed by the `\0`-joined extra fields.
        words_idx_data.push(words_data.len() as u32);
        let details = fields
            .get(4..)
            .map(|extra| extra.join("\0"))
            .unwrap_or_default();
        words_data.extend_from_slice(&(details.len() as u32).to_le_bytes());
        words_data.extend_from_slice(details.as_bytes());
    }
    let category_references = make_category_references(categories, &unknown_dict_entries);
    let costs = make_costs_array(&unknown_dict_entries);
    Ok(UnknownDictionary {
        category_references,
        costs,
        words_idx_data,
        words_data,
    })
}
impl ArchivedUnknownDictionary {
pub fn word_entry(&self, word_id: u32) -> WordEntry {
let archived_entry = &self.costs[word_id as usize];
rkyv::deserialize::<WordEntry, rkyv::rancor::Error>(archived_entry).unwrap()
}
pub fn lookup_word_ids(&self, category_id: CategoryId) -> &[rkyv::rend::u32_le] {
self.category_references[category_id.0].as_slice()
}
pub fn word_details(&self, word_id: u32) -> Option<Vec<&str>> {
let idx = word_id as usize;
if idx >= self.words_idx_data.len() {
return None;
}
let offset = u32::from(self.words_idx_data[idx]) as usize;
if offset + 4 > self.words_data.len() {
return None;
}
let len_bytes: [u8; 4] = self.words_data[offset..offset + 4].try_into().ok()?;
let len = u32::from_le_bytes(len_bytes) as usize;
if offset + 4 + len > self.words_data.len() {
return None;
}
let text = std::str::from_utf8(&self.words_data[offset + 4..offset + 4 + len]).ok()?;
Some(text.split('\0').collect())
}
}