1use std::collections::HashMap;
23
24use pyo3::prelude::*;
25
26use lindera::dictionary::Metadata;
27
28use crate::schema::PySchema;
29
#[pyclass(name = "Metadata", from_py_object)]
/// Python-facing counterpart of `lindera::dictionary::Metadata`.
///
/// Exposed to Python under the class name `Metadata`. Values of this type
/// convert to and from the native `Metadata` through the `From`
/// implementations in this module; every field is carried over one-to-one.
#[derive(Debug, Clone)]
pub struct PyMetadata {
    // Dictionary name.
    name: String,
    // Name of the character encoding (e.g. "UTF-8").
    encoding: String,
    // Default word cost.
    default_word_cost: i16,
    // Default left context id.
    default_left_context_id: u16,
    // Default right context id.
    default_right_context_id: u16,
    // Default field value.
    default_field_value: String,
    // `flexible_csv` flag, passed through unchanged to the native metadata.
    flexible_csv: bool,
    // `skip_invalid_cost_or_id` flag, passed through unchanged to the native metadata.
    skip_invalid_cost_or_id: bool,
    // `normalize_details` flag, passed through unchanged to the native metadata.
    normalize_details: bool,
    // Schema of the main dictionary.
    dictionary_schema: PySchema,
    // Schema of the user dictionary.
    user_dictionary_schema: PySchema,
}
62
63#[pymethods]
64impl PyMetadata {
65 #[new]
66 #[pyo3(signature = (name=None, encoding=None, default_word_cost=None, default_left_context_id=None, default_right_context_id=None, default_field_value=None, flexible_csv=None, skip_invalid_cost_or_id=None, normalize_details=None, dictionary_schema=None, user_dictionary_schema=None))]
67 #[allow(clippy::too_many_arguments)]
68 pub fn new(
69 name: Option<String>,
70 encoding: Option<String>,
71 default_word_cost: Option<i16>,
72 default_left_context_id: Option<u16>,
73 default_right_context_id: Option<u16>,
74 default_field_value: Option<String>,
75 flexible_csv: Option<bool>,
76 skip_invalid_cost_or_id: Option<bool>,
77 normalize_details: Option<bool>,
78 dictionary_schema: Option<PySchema>,
79 user_dictionary_schema: Option<PySchema>,
80 ) -> Self {
81 PyMetadata {
82 name: name.unwrap_or_else(|| "default".to_string()),
83 encoding: encoding.unwrap_or_else(|| "UTF-8".to_string()),
84 default_word_cost: default_word_cost.unwrap_or(-10000),
85 default_left_context_id: default_left_context_id.unwrap_or(1288),
86 default_right_context_id: default_right_context_id.unwrap_or(1288),
87 default_field_value: default_field_value.unwrap_or_else(|| "*".to_string()),
88 flexible_csv: flexible_csv.unwrap_or(false),
89 skip_invalid_cost_or_id: skip_invalid_cost_or_id.unwrap_or(false),
90 normalize_details: normalize_details.unwrap_or(false),
91 dictionary_schema: dictionary_schema.unwrap_or_else(PySchema::create_default),
92 user_dictionary_schema: user_dictionary_schema.unwrap_or_else(|| {
93 PySchema::new(vec![
94 "surface".to_string(),
95 "reading".to_string(),
96 "pronunciation".to_string(),
97 ])
98 }),
99 }
100 }
101
102 #[staticmethod]
103 pub fn create_default() -> Self {
104 PyMetadata::new(
105 None, None, None, None, None, None, None, None, None, None, None,
106 )
107 }
108
109 #[staticmethod]
110 pub fn from_json_file(path: &str) -> PyResult<Self> {
111 use std::fs;
112
113 let json_str = fs::read_to_string(path).map_err(|e| {
114 pyo3::exceptions::PyIOError::new_err(format!("Failed to read file: {e}"))
115 })?;
116
117 let metadata: Metadata = serde_json::from_str(&json_str).map_err(|e| {
118 pyo3::exceptions::PyValueError::new_err(format!("Failed to parse JSON: {e}"))
119 })?;
120
121 Ok(metadata.into())
122 }
123
124 #[getter]
125 pub fn name(&self) -> &str {
126 &self.name
127 }
128
129 #[setter]
130 pub fn set_name(&mut self, name: String) {
131 self.name = name;
132 }
133
134 #[getter]
135 pub fn encoding(&self) -> &str {
136 &self.encoding
137 }
138
139 #[setter]
140 pub fn set_encoding(&mut self, encoding: String) {
141 self.encoding = encoding;
142 }
143
144 #[getter]
145 pub fn default_word_cost(&self) -> i16 {
146 self.default_word_cost
147 }
148
149 #[setter]
150 pub fn set_default_word_cost(&mut self, cost: i16) {
151 self.default_word_cost = cost;
152 }
153
154 #[getter]
155 pub fn default_left_context_id(&self) -> u16 {
156 self.default_left_context_id
157 }
158
159 #[setter]
160 pub fn set_default_left_context_id(&mut self, id: u16) {
161 self.default_left_context_id = id;
162 }
163
164 #[getter]
165 pub fn default_right_context_id(&self) -> u16 {
166 self.default_right_context_id
167 }
168
169 #[setter]
170 pub fn set_default_right_context_id(&mut self, id: u16) {
171 self.default_right_context_id = id;
172 }
173
174 #[getter]
175 pub fn default_field_value(&self) -> &str {
176 &self.default_field_value
177 }
178
179 #[setter]
180 pub fn set_default_field_value(&mut self, value: String) {
181 self.default_field_value = value;
182 }
183
184 #[getter]
185 pub fn flexible_csv(&self) -> bool {
186 self.flexible_csv
187 }
188
189 #[setter]
190 pub fn set_flexible_csv(&mut self, value: bool) {
191 self.flexible_csv = value;
192 }
193
194 #[getter]
195 pub fn skip_invalid_cost_or_id(&self) -> bool {
196 self.skip_invalid_cost_or_id
197 }
198
199 #[setter]
200 pub fn set_skip_invalid_cost_or_id(&mut self, value: bool) {
201 self.skip_invalid_cost_or_id = value;
202 }
203
204 #[getter]
205 pub fn normalize_details(&self) -> bool {
206 self.normalize_details
207 }
208
209 #[setter]
210 pub fn set_normalize_details(&mut self, value: bool) {
211 self.normalize_details = value;
212 }
213
214 #[getter]
215 pub fn dictionary_schema(&self) -> PySchema {
216 self.dictionary_schema.clone()
217 }
218
219 #[setter]
220 pub fn set_dictionary_schema(&mut self, schema: PySchema) {
221 self.dictionary_schema = schema;
222 }
223
224 #[getter]
225 pub fn user_dictionary_schema(&self) -> PySchema {
226 self.user_dictionary_schema.clone()
227 }
228
229 #[setter]
230 pub fn set_user_dictionary_schema(&mut self, schema: PySchema) {
231 self.user_dictionary_schema = schema;
232 }
233
234 pub fn to_dict(&self) -> HashMap<String, String> {
235 let mut dict = HashMap::new();
236 dict.insert("name".to_string(), self.name.clone());
237 dict.insert("encoding".to_string(), self.encoding.clone());
238 dict.insert(
239 "default_word_cost".to_string(),
240 self.default_word_cost.to_string(),
241 );
242 dict.insert(
243 "default_left_context_id".to_string(),
244 self.default_left_context_id.to_string(),
245 );
246 dict.insert(
247 "default_right_context_id".to_string(),
248 self.default_right_context_id.to_string(),
249 );
250 dict.insert(
251 "default_field_value".to_string(),
252 self.default_field_value.clone(),
253 );
254 dict.insert("flexible_csv".to_string(), self.flexible_csv.to_string());
255 dict.insert(
256 "skip_invalid_cost_or_id".to_string(),
257 self.skip_invalid_cost_or_id.to_string(),
258 );
259 dict.insert(
260 "normalize_details".to_string(),
261 self.normalize_details.to_string(),
262 );
263 dict.insert(
264 "dictionary_schema_fields".to_string(),
265 self.dictionary_schema.fields.join(","),
266 );
267 dict.insert(
268 "user_dictionary_schema_fields".to_string(),
269 self.user_dictionary_schema.fields.join(","),
270 );
271 dict
272 }
273
274 fn __str__(&self) -> String {
275 format!(
276 "Metadata(name='{}', encoding='{}')",
277 self.name, self.encoding,
278 )
279 }
280
281 fn __repr__(&self) -> String {
282 format!(
283 "Metadata(name='{}', encoding='{}', schema_fields={})",
284 self.name,
285 self.encoding,
286 self.dictionary_schema.field_count()
287 )
288 }
289}
290
291impl From<PyMetadata> for Metadata {
292 fn from(metadata: PyMetadata) -> Self {
293 Metadata::new(
294 metadata.name,
295 metadata.encoding,
296 metadata.default_word_cost,
297 metadata.default_left_context_id,
298 metadata.default_right_context_id,
299 metadata.default_field_value,
300 metadata.flexible_csv,
301 metadata.skip_invalid_cost_or_id,
302 metadata.normalize_details,
303 metadata.dictionary_schema.into(),
304 metadata.user_dictionary_schema.into(),
305 )
306 }
307}
308
309impl From<Metadata> for PyMetadata {
310 fn from(metadata: Metadata) -> Self {
311 PyMetadata {
312 name: metadata.name,
313 encoding: metadata.encoding,
314 default_word_cost: metadata.default_word_cost,
315 default_left_context_id: metadata.default_left_context_id,
316 default_right_context_id: metadata.default_right_context_id,
317 default_field_value: metadata.default_field_value,
318 flexible_csv: metadata.flexible_csv,
319 skip_invalid_cost_or_id: metadata.skip_invalid_cost_or_id,
320 normalize_details: metadata.normalize_details,
321 dictionary_schema: metadata.dictionary_schema.into(),
322 user_dictionary_schema: metadata.user_dictionary_schema.into(),
323 }
324 }
325}
326
327pub fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
328 let py = parent_module.py();
329 let m = PyModule::new(py, "metadata")?;
330 m.add_class::<PyMetadata>()?;
331 parent_module.add_submodule(&m)?;
332 Ok(())
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use lindera::dictionary::{Metadata, Schema};

    /// Every explicitly supplied value must survive the conversion from the
    /// wrapper to the native type.
    #[test]
    fn test_pymetadata_to_metadata() {
        let wrapper = PyMetadata::new(
            Some(String::from("test_dict")),
            Some(String::from("EUC-JP")),
            Some(-5000),
            Some(100),
            Some(200),
            Some(String::from("N/A")),
            Some(true),
            Some(true),
            Some(true),
            None,
            None,
        );

        let native = Metadata::from(wrapper);

        assert_eq!(native.name, "test_dict");
        assert_eq!(native.encoding, "EUC-JP");
        assert_eq!(native.default_word_cost, -5000);
        assert_eq!(native.default_left_context_id, 100);
        assert_eq!(native.default_right_context_id, 200);
        assert_eq!(native.default_field_value, "N/A");
        assert!(native.flexible_csv);
        assert!(native.skip_invalid_cost_or_id);
        assert!(native.normalize_details);
    }

    /// Native metadata, including both schemas, must be carried over into
    /// the wrapper unchanged.
    #[test]
    fn test_metadata_to_pymetadata() {
        let main_schema = Schema::new(vec![
            String::from("surface"),
            String::from("left_context_id"),
            String::from("right_context_id"),
            String::from("cost"),
        ]);
        let user_schema = Schema::new(vec![String::from("surface"), String::from("reading")]);
        let native = Metadata::new(
            String::from("my_dict"),
            String::from("UTF-8"),
            -10000,
            1288,
            1288,
            String::from("*"),
            false,
            false,
            false,
            main_schema,
            user_schema,
        );

        let wrapper = PyMetadata::from(native);

        assert_eq!(wrapper.name, "my_dict");
        assert_eq!(wrapper.encoding, "UTF-8");
        assert_eq!(wrapper.default_word_cost, -10000);
        assert_eq!(wrapper.default_left_context_id, 1288);
        assert_eq!(wrapper.default_right_context_id, 1288);
        assert_eq!(wrapper.default_field_value, "*");
        assert!(!wrapper.flexible_csv);
        assert!(!wrapper.skip_invalid_cost_or_id);
        assert!(!wrapper.normalize_details);
        assert_eq!(wrapper.dictionary_schema.fields.len(), 4);
        assert_eq!(wrapper.user_dictionary_schema.fields.len(), 2);
    }

    /// `create_default` must yield the documented fallback for every field.
    #[test]
    fn test_pymetadata_default_values() {
        let defaults = PyMetadata::create_default();

        assert_eq!(defaults.name, "default");
        assert_eq!(defaults.encoding, "UTF-8");
        assert_eq!(defaults.default_word_cost, -10000);
        assert_eq!(defaults.default_left_context_id, 1288);
        assert_eq!(defaults.default_right_context_id, 1288);
        assert_eq!(defaults.default_field_value, "*");
        assert!(!defaults.flexible_csv);
        assert!(!defaults.skip_invalid_cost_or_id);
        assert!(!defaults.normalize_details);
        assert_eq!(defaults.dictionary_schema.field_count(), 13);
        assert_eq!(defaults.user_dictionary_schema.fields.len(), 3);
    }

    /// Wrapper -> native -> wrapper must give back the values we started with.
    #[test]
    fn test_pymetadata_roundtrip() {
        let original = PyMetadata::new(
            Some(String::from("roundtrip")),
            Some(String::from("UTF-8")),
            Some(-8000),
            Some(500),
            Some(600),
            Some(String::from("?")),
            Some(true),
            Some(false),
            Some(true),
            None,
            None,
        );

        let native = Metadata::from(original);
        let restored = PyMetadata::from(native);

        assert_eq!(restored.name, "roundtrip");
        assert_eq!(restored.encoding, "UTF-8");
        assert_eq!(restored.default_word_cost, -8000);
        assert_eq!(restored.default_left_context_id, 500);
        assert_eq!(restored.default_right_context_id, 600);
        assert_eq!(restored.default_field_value, "?");
        assert!(restored.flexible_csv);
        assert!(!restored.skip_invalid_cost_or_id);
        assert!(restored.normalize_details);
    }
}