use std::borrow::Cow;
use std::collections::BTreeMap;
use std::fs::File;
use std::io::Write;
use std::io::{self, Read};
use std::path::{Path, PathBuf};
use std::str::FromStr;
use anyhow::anyhow;
use byteorder::{LittleEndian, WriteBytesExt};
use csv::StringRecord;
use daachorse::DoubleArrayAhoCorasickBuilder;
use derive_builder::Builder;
use encoding_rs::{Encoding, UTF_8};
use encoding_rs_io::DecodeReaderBytesBuilder;
use glob::glob;
use log::debug;
use crate::LinderaResult;
use crate::dictionary::schema::Schema;
use crate::error::LinderaErrorKind;
use crate::util::write_data;
use crate::viterbi::WordEntry;
/// Builds the prefix-dictionary artifacts (dict.words, dict.wordsidx, dict.da,
/// dict.vals) from a directory of CSV lexicon files.
///
/// The derive_builder attributes generate a companion
/// `PrefixDictionaryBuilderOptions` type whose `builder()` method constructs
/// this struct using the per-field defaults declared below.
#[derive(Builder)]
#[builder(name = PrefixDictionaryBuilderOptions)]
#[builder(build_fn(name = "builder"))]
pub struct PrefixDictionaryBuilder {
// Allow CSV rows with a variable number of fields (forwarded to csv::ReaderBuilder::flexible).
#[builder(default = "true")]
flexible_csv: bool,
// Encoding label used to decode the CSV files (resolved via encoding_rs, e.g. "UTF-8").
#[builder(default = "\"UTF-8\".into()", setter(into))]
encoding: Cow<'static, str>,
// When true, surface forms and detail fields are passed through `normalize`.
#[builder(default = "false")]
normalize_details: bool,
// When true, rows whose cost/context-id fields fail to parse are skipped
// instead of aborting the build with an error.
#[builder(default = "false")]
skip_invalid_cost_or_id: bool,
// Field layout used to look up column indices by name ("surface", "cost", ...).
#[builder(default = "Schema::default()")]
schema: Schema,
}
impl PrefixDictionaryBuilder {
pub fn new(schema: Schema) -> Self {
Self {
flexible_csv: true,
encoding: "UTF-8".into(),
normalize_details: false,
skip_invalid_cost_or_id: false,
schema,
}
}
/// Builds the prefix dictionary: loads and sorts the CSV rows from
/// `input_dir`, groups them into word entries, and writes the dictionary
/// files into `output_dir`.
pub fn build(&self, input_dir: &Path, output_dir: &Path) -> LinderaResult<()> {
    let rows = self.load_csv_data(input_dir)?;
    let entries = self.build_word_entry_map(&rows)?;
    self.write_dictionary_files(output_dir, &rows, &entries)
}
/// Reads every `*.csv` file under `input_dir` and returns the records
/// sorted by their first column (the surface form), normalized first when
/// `normalize_details` is enabled.
fn load_csv_data(&self, input_dir: &Path) -> LinderaResult<Vec<StringRecord>> {
    let filenames = self.collect_csv_files(input_dir)?;
    let encoding = self.get_encoding()?;
    let mut rows = self.read_csv_files(&filenames, encoding)?;
    // Sort key must match the keys used later in build_word_entry_map.
    match self.normalize_details {
        true => rows.sort_by_key(|row| normalize(&row[0])),
        false => rows.sort_by(|a, b| a[0].cmp(&b[0])),
    }
    Ok(rows)
}
/// Globs `input_dir/*.csv` and returns the matching paths, re-rooted under
/// `input_dir`. Fails if the directory path is not valid UTF-8 or a glob
/// entry cannot be read.
fn collect_csv_files(&self, input_dir: &Path) -> LinderaResult<Vec<PathBuf>> {
    // Guard clause: glob patterns are strings, so the path must be UTF-8.
    let dir_str = match input_dir.to_str() {
        Some(path) => path,
        None => {
            return Err(LinderaErrorKind::Io
                .with_error(anyhow::anyhow!("Failed to convert path to &str."))
                .add_context(format!(
                    "Input directory path contains invalid characters: {input_dir:?}"
                )));
        }
    };
    let pattern = format!("{dir_str}/*.csv");
    let entries = glob(&pattern).map_err(|err| {
        LinderaErrorKind::Io
            .with_error(anyhow::anyhow!(err))
            .add_context(format!("Failed to glob CSV files with pattern: {pattern}"))
    })?;
    let mut filenames: Vec<PathBuf> = Vec::new();
    for entry in entries {
        let path = entry.map_err(|err| {
            LinderaErrorKind::Content
                .with_error(anyhow!(err))
                .add_context(format!(
                    "Failed to process glob entry with pattern: {pattern}"
                ))
        })?;
        let filename = path.file_name().ok_or_else(|| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!("failed to get filename"))
                .add_context(format!("Invalid filename in path: {path:?}"))
        })?;
        filenames.push(input_dir.join(filename));
    }
    Ok(filenames)
}
/// Resolves the configured encoding label (e.g. "UTF-8") to an
/// `encoding_rs::Encoding`, rejecting labels that would map to the
/// replacement encoding.
fn get_encoding(&self) -> LinderaResult<&'static Encoding> {
    Encoding::for_label_no_replacement(self.encoding.as_bytes()).ok_or_else(|| {
        LinderaErrorKind::Decode
            .with_error(anyhow!("Invalid encoding: {}", self.encoding))
            .add_context("Failed to get encoding for CSV file reading")
    })
}
/// Reads all records from the given CSV files, transcoding from `encoding`
/// to UTF-8 on the fly unless the files are already UTF-8.
fn read_csv_files(
    &self,
    filenames: &[PathBuf],
    encoding: &'static Encoding,
) -> LinderaResult<Vec<StringRecord>> {
    let mut rows: Vec<StringRecord> = Vec::new();
    for filename in filenames {
        debug!("reading {filename:?}");
        let file = File::open(filename).map_err(|err| {
            LinderaErrorKind::Io
                .with_error(anyhow::anyhow!(err))
                .add_context(format!("Failed to open CSV file: {filename:?}"))
        })?;
        // UTF-8 input needs no transcoding; anything else goes through a
        // decoding adapter.
        let reader: Box<dyn Read> = if encoding == UTF_8 {
            Box::new(file)
        } else {
            let decoder = DecodeReaderBytesBuilder::new()
                .encoding(Some(encoding))
                .build(file);
            Box::new(decoder)
        };
        let mut csv_reader = csv::ReaderBuilder::new()
            .has_headers(false)
            .flexible(self.flexible_csv)
            .from_reader(reader);
        for record in csv_reader.records() {
            let row = record.map_err(|err| {
                LinderaErrorKind::Content
                    .with_error(anyhow!(err))
                    .add_context(format!("Failed to parse CSV record in file: {filename:?}"))
            })?;
            rows.push(row);
        }
    }
    Ok(rows)
}
/// Groups the sorted CSV rows into a map from (optionally normalized)
/// surface form to the `WordEntry` list for that surface.
///
/// Rows whose cost/context-id or surface fields are absent are skipped;
/// invalid numeric values either propagate as errors or — when
/// `skip_invalid_cost_or_id` is set — are treated as absent by the
/// `parse_*` helpers. The row's index becomes its `WordId`.
fn build_word_entry_map(
    &self,
    rows: &[StringRecord],
) -> LinderaResult<BTreeMap<String, Vec<WordEntry>>> {
    let mut word_entry_map: BTreeMap<String, Vec<WordEntry>> = BTreeMap::new();
    for (row_id, row) in rows.iter().enumerate() {
        // Parse all three numeric fields up front so parse errors are
        // reported even when a different field is missing (same order as
        // before: cost, left id, right id).
        let word_cost = self.parse_word_cost(row)?;
        let left_id = self.parse_left_id(row)?;
        let right_id = self.parse_right_id(row)?;
        // Destructure instead of the previous `is_none()` checks followed
        // by `.unwrap()` calls — no panic path, same skip semantics.
        let (word_cost, left_id, right_id) = match (word_cost, left_id, right_id) {
            (Some(cost), Some(left), Some(right)) => (cost, left, right),
            _ => continue,
        };
        let surface = match self.get_field_value(row, "surface")? {
            Some(surface) => surface,
            None => continue,
        };
        let key = if self.normalize_details {
            normalize(&surface)
        } else {
            surface
        };
        word_entry_map.entry(key).or_default().push(WordEntry {
            word_id: crate::viterbi::WordId::new(crate::viterbi::LexType::System, row_id as u32),
            word_cost,
            left_id,
            right_id,
        });
    }
    Ok(word_entry_map)
}
/// Looks up `field_name` in the schema and returns the trimmed value of
/// the corresponding column, or `None` when the field is unknown, the row
/// is too short, or the value is empty after trimming.
fn get_field_value(
    &self,
    row: &StringRecord,
    field_name: &str,
) -> LinderaResult<Option<String>> {
    let value = self
        .schema
        .get_field_index(field_name)
        .and_then(|index| row.get(index))
        .map(str::trim)
        .filter(|v| !v.is_empty())
        .map(str::to_string);
    Ok(value)
}
/// Parses the "cost" field as `i16`. Returns `None` when the field is
/// missing, or when it is invalid and `skip_invalid_cost_or_id` is set;
/// otherwise an invalid value is an error.
fn parse_word_cost(&self, row: &StringRecord) -> LinderaResult<Option<i16>> {
    let s = match self.get_field_value(row, "cost")? {
        Some(s) => s,
        None => return Ok(None),
    };
    match s.parse::<i16>() {
        Ok(cost) => Ok(Some(cost)),
        Err(_) if self.skip_invalid_cost_or_id => Ok(None),
        Err(_) => {
            Err(LinderaErrorKind::Content.with_error(anyhow!("Invalid cost value: {s}")))
        }
    }
}
/// Parses the "left_context_id" field as `u16`. Returns `None` when the
/// field is missing, or when it is invalid and `skip_invalid_cost_or_id`
/// is set; otherwise an invalid value is an error.
fn parse_left_id(&self, row: &StringRecord) -> LinderaResult<Option<u16>> {
    let s = match self.get_field_value(row, "left_context_id")? {
        Some(s) => s,
        None => return Ok(None),
    };
    match s.parse::<u16>() {
        Ok(id) => Ok(Some(id)),
        Err(_) if self.skip_invalid_cost_or_id => Ok(None),
        Err(_) => {
            Err(LinderaErrorKind::Content.with_error(anyhow!("Invalid left context ID: {s}")))
        }
    }
}
/// Parses the "right_context_id" field as `u16`. Returns `None` when the
/// field is missing, or when it is invalid and `skip_invalid_cost_or_id`
/// is set; otherwise an invalid value is an error.
fn parse_right_id(&self, row: &StringRecord) -> LinderaResult<Option<u16>> {
    let s = match self.get_field_value(row, "right_context_id")? {
        Some(s) => s,
        None => return Ok(None),
    };
    match s.parse::<u16>() {
        Ok(id) => Ok(Some(id)),
        Err(_) if self.skip_invalid_cost_or_id => Ok(None),
        Err(_) => {
            Err(LinderaErrorKind::Content.with_error(anyhow!("Invalid right context ID: {s}")))
        }
    }
}
/// Writes every dictionary artifact into `output_dir`: the word details
/// (dict.words / dict.wordsidx), the double-array index (dict.da), and the
/// serialized word entries (dict.vals).
fn write_dictionary_files(
    &self,
    output_dir: &Path,
    rows: &[StringRecord],
    word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
) -> LinderaResult<()> {
    self.write_words_files(output_dir, rows)?;
    self.write_double_array_file(output_dir, word_entry_map)?;
    self.write_values_file(output_dir, word_entry_map)
}
fn write_words_files(&self, output_dir: &Path, rows: &[StringRecord]) -> LinderaResult<()> {
let mut dict_words_buffer = Vec::new();
let mut dict_wordsidx_buffer = Vec::new();
for row in rows.iter() {
let offset = dict_words_buffer.len();
dict_wordsidx_buffer
.write_u32::<LittleEndian>(offset as u32)
.map_err(|err| {
LinderaErrorKind::Io
.with_error(anyhow::anyhow!(err))
.add_context("Failed to write word index offset to dict.wordsidx buffer")
})?;
let joined_details = if self.normalize_details {
row.iter()
.skip(4)
.map(normalize)
.collect::<Vec<String>>()
.join("\0")
} else {
row.iter().skip(4).collect::<Vec<&str>>().join("\0")
};
let joined_details_len = u32::try_from(joined_details.len()).map_err(|err| {
LinderaErrorKind::Serialize
.with_error(anyhow::anyhow!(err))
.add_context(format!(
"Word details length too large: {} bytes",
joined_details.len()
))
})?;
dict_words_buffer
.write_u32::<LittleEndian>(joined_details_len)
.map_err(|err| {
LinderaErrorKind::Serialize
.with_error(anyhow::anyhow!(err))
.add_context("Failed to write word details length to dict.words buffer")
})?;
dict_words_buffer
.write_all(joined_details.as_bytes())
.map_err(|err| {
LinderaErrorKind::Serialize
.with_error(anyhow::anyhow!(err))
.add_context("Failed to write word details to dict.words buffer")
})?;
}
let dict_words_path = output_dir.join(Path::new("dict.words"));
let mut dict_words_writer =
io::BufWriter::new(File::create(&dict_words_path).map_err(|err| {
LinderaErrorKind::Io
.with_error(anyhow::anyhow!(err))
.add_context(format!(
"Failed to create dict.words file: {dict_words_path:?}"
))
})?);
write_data(&dict_words_buffer, &mut dict_words_writer)?;
dict_words_writer.flush().map_err(|err| {
LinderaErrorKind::Io
.with_error(anyhow::anyhow!(err))
.add_context(format!(
"Failed to flush dict.words file: {dict_words_path:?}"
))
})?;
let dict_wordsidx_path = output_dir.join(Path::new("dict.wordsidx"));
let mut dict_wordsidx_writer =
io::BufWriter::new(File::create(&dict_wordsidx_path).map_err(|err| {
LinderaErrorKind::Io
.with_error(anyhow::anyhow!(err))
.add_context(format!(
"Failed to create dict.wordsidx file: {dict_wordsidx_path:?}"
))
})?);
write_data(&dict_wordsidx_buffer, &mut dict_wordsidx_writer)?;
dict_wordsidx_writer.flush().map_err(|err| {
LinderaErrorKind::Io
.with_error(anyhow::anyhow!(err))
.add_context(format!(
"Failed to flush dict.wordsidx file: {dict_wordsidx_path:?}"
))
})?;
Ok(())
}
fn write_double_array_file(
&self,
output_dir: &Path,
word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
) -> LinderaResult<()> {
let mut id = 0u32;
let mut keyset: Vec<(&[u8], u32)> = vec![];
for (key, word_entries) in word_entry_map {
let len = word_entries.len() as u32;
let val = (id << 8) | len; keyset.push((key.as_bytes(), val));
id += len;
}
let keyset_len = keyset.len();
let dict_da = DoubleArrayAhoCorasickBuilder::new()
.build_with_values(keyset)
.map_err(|err| {
LinderaErrorKind::Build
.with_error(anyhow::anyhow!(err))
.add_context(format!(
"Failed to build DoubleArray with {} keys for prefix dictionary",
keyset_len
))
})?;
let dict_da_buffer = dict_da.serialize();
let dict_da_path = output_dir.join(Path::new("dict.da"));
let mut dict_da_writer =
io::BufWriter::new(File::create(&dict_da_path).map_err(|err| {
LinderaErrorKind::Io
.with_error(anyhow::anyhow!(err))
.add_context(format!("Failed to create dict.da file: {dict_da_path:?}"))
})?);
write_data(&dict_da_buffer, &mut dict_da_writer)?;
Ok(())
}
/// Serializes every `WordEntry` (in map key order) into dict.vals; the
/// packed values written to dict.da refer to entries by their position in
/// this same order.
fn write_values_file(
    &self,
    output_dir: &Path,
    word_entry_map: &BTreeMap<String, Vec<WordEntry>>,
) -> LinderaResult<()> {
    let mut dict_vals_buffer: Vec<u8> = Vec::new();
    for word_entry in word_entry_map.values().flatten() {
        word_entry.serialize(&mut dict_vals_buffer).map_err(|err| {
            LinderaErrorKind::Serialize
                .with_error(anyhow::anyhow!(err))
                .add_context(format!(
                    "Failed to serialize word entry (id: {})",
                    word_entry.word_id.id
                ))
        })?;
    }
    let dict_vals_path = output_dir.join("dict.vals");
    let vals_file = File::create(&dict_vals_path).map_err(|err| {
        LinderaErrorKind::Io
            .with_error(anyhow::anyhow!(err))
            .add_context(format!(
                "Failed to create dict.vals file: {dict_vals_path:?}"
            ))
    })?;
    let mut dict_vals_writer = io::BufWriter::new(vals_file);
    write_data(&dict_vals_buffer, &mut dict_vals_writer)?;
    // Flush explicitly so flush errors are reported rather than swallowed
    // by BufWriter's Drop.
    dict_vals_writer.flush().map_err(|err| {
        LinderaErrorKind::Io
            .with_error(anyhow::anyhow!(err))
            .add_context(format!(
                "Failed to flush dict.vals file: {dict_vals_path:?}"
            ))
    })?;
    Ok(())
}
}
/// Normalizes dictionary text: maps U+2015 HORIZONTAL BAR ('―') to
/// U+2014 EM DASH ('—') and ASCII '~' to U+301C WAVE DASH ('〜').
fn normalize(text: &str) -> String {
    // Fix: `str::replace` already allocates and returns a new String, so the
    // previous `text.to_string()` was a redundant extra allocation per call.
    text.replace('―', "—").replace('~', "〜")
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dictionary::schema::Schema;
use csv::StringRecord;
// `new` applies the documented defaults for every option flag.
#[test]
fn test_new_with_schema() {
let schema = Schema::default();
let builder = PrefixDictionaryBuilder::new(schema.clone());
assert!(builder.flexible_csv);
assert_eq!(builder.encoding, "UTF-8");
assert!(!builder.normalize_details);
assert!(!builder.skip_invalid_cost_or_id);
}
// An empty (or whitespace-only) field reads back as None.
#[test]
fn test_get_common_field_value_empty() {
let schema = Schema::default();
let builder = PrefixDictionaryBuilder::new(schema);
let record = StringRecord::from(vec![
"", "123", "456", "789", ]);
let surface = builder.get_field_value(&record, "surface").unwrap();
assert_eq!(surface, None);
}
// A field whose schema index lies past the end of the record yields None
// rather than panicking.
#[test]
fn test_get_common_field_value_out_of_bounds() {
let schema = Schema::default();
let builder = PrefixDictionaryBuilder::new(schema);
let record = StringRecord::from(vec![
"surface_form", ]);
let left_id = builder.get_field_value(&record, "left_context_id").unwrap();
assert_eq!(left_id, None);
}
// With the default schema the cost lives in the fourth column.
#[test]
fn test_parse_word_cost() {
let schema = Schema::default();
let builder = PrefixDictionaryBuilder::new(schema);
let record = StringRecord::from(vec![
"surface_form", "123", "456", "789", ]);
let cost = builder.parse_word_cost(&record).unwrap();
assert_eq!(cost, Some(789));
}
// A non-numeric cost is an error under the default (strict) settings.
#[test]
fn test_parse_word_cost_invalid() {
let schema = Schema::default();
let builder = PrefixDictionaryBuilder::new(schema);
let record = StringRecord::from(vec![
"surface_form", "123", "456", "invalid", ]);
let result = builder.parse_word_cost(&record);
assert!(result.is_err());
}
// With skip_invalid_cost_or_id set, a non-numeric cost becomes None.
#[test]
fn test_parse_word_cost_skip_invalid() {
let schema = Schema::default();
let mut builder = PrefixDictionaryBuilder::new(schema);
builder.skip_invalid_cost_or_id = true;
let record = StringRecord::from(vec![
"surface_form", "123", "456", "invalid", ]);
let cost = builder.parse_word_cost(&record).unwrap();
assert_eq!(cost, None);
}
// Left context id comes from the second column under the default schema.
#[test]
fn test_parse_left_id() {
let schema = Schema::default();
let builder = PrefixDictionaryBuilder::new(schema);
let record = StringRecord::from(vec![
"surface_form", "123", "456", "789", ]);
let left_id = builder.parse_left_id(&record).unwrap();
assert_eq!(left_id, Some(123));
}
// Right context id comes from the third column under the default schema.
#[test]
fn test_parse_right_id() {
let schema = Schema::default();
let builder = PrefixDictionaryBuilder::new(schema);
let record = StringRecord::from(vec![
"surface_form", "123", "456", "789", ]);
let right_id = builder.parse_right_id(&record).unwrap();
assert_eq!(right_id, Some(456));
}
// normalize maps '―' (HORIZONTAL BAR) to '—' (EM DASH) and '~' to '〜'
// (WAVE DASH); other text passes through unchanged.
#[test]
fn test_normalize_function() {
assert_eq!(normalize("test―text"), "test—text");
assert_eq!(normalize("test~text"), "test〜text");
assert_eq!(normalize("test―text~more"), "test—text〜more");
assert_eq!(normalize("normal text"), "normal text");
}
// The default "UTF-8" label resolves to the UTF-8 encoding.
#[test]
fn test_get_encoding() {
let schema = Schema::default();
let builder = PrefixDictionaryBuilder::new(schema);
let encoding = builder.get_encoding().unwrap();
assert_eq!(encoding.name(), "UTF-8");
}
// An unknown encoding label is rejected with an error.
#[test]
fn test_get_encoding_invalid() {
let schema = Schema::default();
let mut builder = PrefixDictionaryBuilder::new(schema);
builder.encoding = "INVALID-ENCODING".into();
let result = builder.get_encoding();
assert!(result.is_err());
}
// Each named schema field maps to its expected column; short records give
// None for fields past the end.
#[test]
fn test_get_common_field_value() {
let schema = Schema::default();
let builder = PrefixDictionaryBuilder::new(schema);
let record = StringRecord::from(vec![
"word", "123", "456", "789", "名詞", ]);
assert_eq!(
builder.get_field_value(&record, "surface").unwrap(),
Some("word".to_string())
);
assert_eq!(
builder.get_field_value(&record, "left_context_id").unwrap(),
Some("123".to_string())
);
assert_eq!(
builder
.get_field_value(&record, "right_context_id")
.unwrap(),
Some("456".to_string())
);
assert_eq!(
builder.get_field_value(&record, "cost").unwrap(),
Some("789".to_string())
);
let short_record = StringRecord::from(vec!["word", "123"]);
assert_eq!(
builder.get_field_value(&short_record, "cost").unwrap(),
None
);
}
}