use std::collections::HashMap;
use magnus::prelude::*;
use magnus::{Error, Ruby, function, method};
use lindera::dictionary::Metadata;
use crate::schema::RbSchema;
/// Dictionary metadata exposed to Ruby as `Lindera::Metadata`.
///
/// Holds the same values that are handed to `lindera::dictionary::Metadata::new`
/// when this wrapper is converted into the lindera type, and is itself built from
/// a lindera `Metadata` when loading from a JSON file.
#[magnus::wrap(class = "Lindera::Metadata", free_immediately, size)]
#[derive(Debug, Clone)]
pub struct RbMetadata {
/// Dictionary name; `new` falls back to `"default"` when none is given.
name: String,
/// Encoding label; `new` falls back to `"UTF-8"` when none is given.
encoding: String,
/// Default word cost; `new` falls back to `-10000`.
default_word_cost: i16,
/// Default left context id; `new` falls back to `1288`.
default_left_context_id: u16,
/// Default right context id; `new` falls back to `1288`.
default_right_context_id: u16,
/// Default field value; `new` falls back to `"*"`.
default_field_value: String,
/// Flag forwarded unchanged to lindera's `Metadata`; `new` falls back to `false`.
flexible_csv: bool,
/// Flag forwarded unchanged to lindera's `Metadata`; `new` falls back to `false`.
skip_invalid_cost_or_id: bool,
/// Flag forwarded unchanged to lindera's `Metadata`; `new` falls back to `false`.
normalize_details: bool,
/// Schema of the main dictionary; `new` always uses `RbSchema::create_default_internal()`.
dictionary_schema: RbSchema,
/// Schema of the user dictionary; `new` always uses `surface`, `reading`, `pronunciation`.
user_dictionary_schema: RbSchema,
}
impl RbMetadata {
#[allow(clippy::too_many_arguments)]
fn new(
name: Option<String>,
encoding: Option<String>,
default_word_cost: Option<i16>,
default_left_context_id: Option<u16>,
default_right_context_id: Option<u16>,
default_field_value: Option<String>,
flexible_csv: Option<bool>,
skip_invalid_cost_or_id: Option<bool>,
normalize_details: Option<bool>,
) -> Self {
RbMetadata {
name: name.unwrap_or_else(|| "default".to_string()),
encoding: encoding.unwrap_or_else(|| "UTF-8".to_string()),
default_word_cost: default_word_cost.unwrap_or(-10000),
default_left_context_id: default_left_context_id.unwrap_or(1288),
default_right_context_id: default_right_context_id.unwrap_or(1288),
default_field_value: default_field_value.unwrap_or_else(|| "*".to_string()),
flexible_csv: flexible_csv.unwrap_or(false),
skip_invalid_cost_or_id: skip_invalid_cost_or_id.unwrap_or(false),
normalize_details: normalize_details.unwrap_or(false),
dictionary_schema: RbSchema::create_default_internal(),
user_dictionary_schema: RbSchema::new_internal(vec![
"surface".to_string(),
"reading".to_string(),
"pronunciation".to_string(),
]),
}
}
fn create_default() -> Self {
RbMetadata::new(None, None, None, None, None, None, None, None, None)
}
fn from_json_file(path: String) -> Result<Self, Error> {
let ruby = Ruby::get().expect("Ruby runtime not initialized");
let json_str = std::fs::read_to_string(&path).map_err(|e| {
Error::new(
ruby.exception_io_error(),
format!("Failed to read file: {e}"),
)
})?;
let metadata: Metadata = serde_json::from_str(&json_str).map_err(|e| {
Error::new(
ruby.exception_arg_error(),
format!("Failed to parse JSON: {e}"),
)
})?;
Ok(metadata.into())
}
fn name(&self) -> String {
self.name.clone()
}
fn encoding(&self) -> String {
self.encoding.clone()
}
fn default_word_cost(&self) -> i16 {
self.default_word_cost
}
fn default_left_context_id(&self) -> u16 {
self.default_left_context_id
}
fn default_right_context_id(&self) -> u16 {
self.default_right_context_id
}
fn default_field_value(&self) -> String {
self.default_field_value.clone()
}
fn flexible_csv(&self) -> bool {
self.flexible_csv
}
fn skip_invalid_cost_or_id(&self) -> bool {
self.skip_invalid_cost_or_id
}
fn normalize_details(&self) -> bool {
self.normalize_details
}
fn to_hash(&self) -> HashMap<String, String> {
let mut dict = HashMap::new();
dict.insert("name".to_string(), self.name.clone());
dict.insert("encoding".to_string(), self.encoding.clone());
dict.insert(
"default_word_cost".to_string(),
self.default_word_cost.to_string(),
);
dict.insert(
"default_left_context_id".to_string(),
self.default_left_context_id.to_string(),
);
dict.insert(
"default_right_context_id".to_string(),
self.default_right_context_id.to_string(),
);
dict.insert(
"default_field_value".to_string(),
self.default_field_value.clone(),
);
dict.insert("flexible_csv".to_string(), self.flexible_csv.to_string());
dict.insert(
"skip_invalid_cost_or_id".to_string(),
self.skip_invalid_cost_or_id.to_string(),
);
dict.insert(
"normalize_details".to_string(),
self.normalize_details.to_string(),
);
dict.insert(
"dictionary_schema_fields".to_string(),
self.dictionary_schema.fields.join(","),
);
dict.insert(
"user_dictionary_schema_fields".to_string(),
self.user_dictionary_schema.fields.join(","),
);
dict
}
fn to_s(&self) -> String {
format!(
"Metadata(name='{}', encoding='{}')",
self.name, self.encoding,
)
}
fn inspect(&self) -> String {
format!(
"#<Lindera::Metadata: name='{}', encoding='{}', schema_fields={}>",
self.name,
self.encoding,
self.dictionary_schema.fields.len()
)
}
}
impl From<RbMetadata> for Metadata {
fn from(metadata: RbMetadata) -> Self {
Metadata::new(
metadata.name,
metadata.encoding,
metadata.default_word_cost,
metadata.default_left_context_id,
metadata.default_right_context_id,
metadata.default_field_value,
metadata.flexible_csv,
metadata.skip_invalid_cost_or_id,
metadata.normalize_details,
metadata.dictionary_schema.into(),
metadata.user_dictionary_schema.into(),
)
}
}
impl From<Metadata> for RbMetadata {
fn from(metadata: Metadata) -> Self {
RbMetadata {
name: metadata.name,
encoding: metadata.encoding,
default_word_cost: metadata.default_word_cost,
default_left_context_id: metadata.default_left_context_id,
default_right_context_id: metadata.default_right_context_id,
default_field_value: metadata.default_field_value,
flexible_csv: metadata.flexible_csv,
skip_invalid_cost_or_id: metadata.skip_invalid_cost_or_id,
normalize_details: metadata.normalize_details,
dictionary_schema: metadata.dictionary_schema.into(),
user_dictionary_schema: metadata.user_dictionary_schema.into(),
}
}
}
/// Registers the `Metadata` class under `module` and binds every Ruby-visible
/// method of [`RbMetadata`] to it.
///
/// `module` is presumably the `Lindera` module, given the wrapped class name
/// `Lindera::Metadata` — confirm against the caller.
///
/// # Errors
///
/// Returns the first error reported by magnus while defining the class or any
/// of its methods.
pub fn define(ruby: &Ruby, module: &magnus::RModule) -> Result<(), Error> {
    let class = module.define_class("Metadata", ruby.class_object())?;

    // Class-level constructors. `new` takes all nine arguments positionally.
    class.define_singleton_method("new", function!(RbMetadata::new, 9))?;
    class.define_singleton_method("create_default", function!(RbMetadata::create_default, 0))?;
    class.define_singleton_method("from_json_file", function!(RbMetadata::from_json_file, 1))?;

    // Attribute readers.
    class.define_method("name", method!(RbMetadata::name, 0))?;
    class.define_method("encoding", method!(RbMetadata::encoding, 0))?;
    class.define_method(
        "default_word_cost",
        method!(RbMetadata::default_word_cost, 0),
    )?;
    class.define_method(
        "default_left_context_id",
        method!(RbMetadata::default_left_context_id, 0),
    )?;
    class.define_method(
        "default_right_context_id",
        method!(RbMetadata::default_right_context_id, 0),
    )?;
    class.define_method(
        "default_field_value",
        method!(RbMetadata::default_field_value, 0),
    )?;
    class.define_method("flexible_csv", method!(RbMetadata::flexible_csv, 0))?;
    class.define_method(
        "skip_invalid_cost_or_id",
        method!(RbMetadata::skip_invalid_cost_or_id, 0),
    )?;
    class.define_method(
        "normalize_details",
        method!(RbMetadata::normalize_details, 0),
    )?;

    // Conversions and string representations; `to_h` and `to_hash` share one
    // implementation.
    class.define_method("to_hash", method!(RbMetadata::to_hash, 0))?;
    class.define_method("to_h", method!(RbMetadata::to_hash, 0))?;
    class.define_method("to_s", method!(RbMetadata::to_s, 0))?;
    class.define_method("inspect", method!(RbMetadata::inspect, 0))?;

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns a list of string literals into owned field names.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Every field of the Ruby wrapper must arrive unchanged in lindera's `Metadata`.
    #[test]
    fn test_rb_metadata_to_lindera_metadata() {
        let source = RbMetadata {
            name: String::from("test_dict"),
            encoding: String::from("EUC-JP"),
            default_word_cost: -5000,
            default_left_context_id: 100,
            default_right_context_id: 200,
            default_field_value: String::from("N/A"),
            flexible_csv: true,
            skip_invalid_cost_or_id: true,
            normalize_details: true,
            dictionary_schema: RbSchema::new_internal(owned(&["surface", "cost"])),
            user_dictionary_schema: RbSchema::new_internal(owned(&["surface"])),
        };

        let converted = Metadata::from(source);

        assert_eq!(converted.name, "test_dict");
        assert_eq!(converted.encoding, "EUC-JP");
        assert_eq!(converted.default_word_cost, -5000);
        assert_eq!(converted.default_left_context_id, 100);
        assert_eq!(converted.default_right_context_id, 200);
        assert_eq!(converted.default_field_value, "N/A");
        assert!(converted.flexible_csv);
        assert!(converted.skip_invalid_cost_or_id);
        assert!(converted.normalize_details);

        let dictionary_field_count = converted.dictionary_schema.get_all_fields().len();
        assert_eq!(dictionary_field_count, 2);
        let user_field_count = converted.user_dictionary_schema.get_all_fields().len();
        assert_eq!(user_field_count, 1);
    }

    /// Every field of lindera's `Metadata` must arrive unchanged in the Ruby wrapper.
    #[test]
    fn test_lindera_metadata_to_rb_metadata() {
        let dict_schema = lindera::dictionary::Schema::new(owned(&["surface", "cost"]));
        let user_schema = lindera::dictionary::Schema::new(owned(&["surface", "reading"]));
        let source = Metadata::new(
            String::from("my_dict"),
            String::from("UTF-8"),
            -8000,
            500,
            600,
            String::from("?"),
            false,
            true,
            false,
            dict_schema,
            user_schema,
        );

        let converted = RbMetadata::from(source);

        assert_eq!(converted.name, "my_dict");
        assert_eq!(converted.encoding, "UTF-8");
        assert_eq!(converted.default_word_cost, -8000);
        assert_eq!(converted.default_left_context_id, 500);
        assert_eq!(converted.default_right_context_id, 600);
        assert_eq!(converted.default_field_value, "?");
        assert!(!converted.flexible_csv);
        assert!(converted.skip_invalid_cost_or_id);
        assert!(!converted.normalize_details);
        assert_eq!(converted.dictionary_schema.fields.len(), 2);
        assert_eq!(converted.user_dictionary_schema.fields.len(), 2);
    }

    /// Converting to lindera's type and back must preserve every value.
    #[test]
    fn test_rb_metadata_roundtrip() {
        let original = RbMetadata {
            name: String::from("roundtrip"),
            encoding: String::from("UTF-8"),
            default_word_cost: -10000,
            default_left_context_id: 1288,
            default_right_context_id: 1288,
            default_field_value: String::from("*"),
            flexible_csv: false,
            skip_invalid_cost_or_id: false,
            normalize_details: false,
            dictionary_schema: RbSchema::create_default_internal(),
            user_dictionary_schema: RbSchema::new_internal(owned(&[
                "surface",
                "reading",
                "pronunciation",
            ])),
        };

        let back = RbMetadata::from(Metadata::from(original));

        assert_eq!(back.name, "roundtrip");
        assert_eq!(back.encoding, "UTF-8");
        assert_eq!(back.default_word_cost, -10000);
        assert_eq!(back.default_left_context_id, 1288);
        assert_eq!(back.default_right_context_id, 1288);
        assert_eq!(back.default_field_value, "*");
        assert!(!back.flexible_csv);
        assert!(!back.skip_invalid_cost_or_id);
        assert!(!back.normalize_details);
        assert_eq!(back.dictionary_schema.fields.len(), 13);
        assert_eq!(back.user_dictionary_schema.fields.len(), 3);
    }
}