lindera/
metadata.rs

1use std::collections::HashMap;
2
3use pyo3::prelude::*;
4
5use lindera::dictionary::{CompressionAlgorithm, Metadata};
6
7use crate::schema::PySchema;
8
9#[pyclass(name = "CompressionAlgorithm")]
10#[derive(Debug, Clone)]
11pub enum PyCompressionAlgorithm {
12    Deflate,
13    Zlib,
14    Gzip,
15    Raw,
16}
17
18#[pymethods]
19impl PyCompressionAlgorithm {
20    fn __str__(&self) -> &str {
21        match self {
22            PyCompressionAlgorithm::Deflate => "deflate",
23            PyCompressionAlgorithm::Zlib => "zlib",
24            PyCompressionAlgorithm::Gzip => "gzip",
25            PyCompressionAlgorithm::Raw => "raw",
26        }
27    }
28
29    fn __repr__(&self) -> String {
30        format!("CompressionAlgorithm.{self:?}")
31    }
32}
33
34impl From<PyCompressionAlgorithm> for CompressionAlgorithm {
35    fn from(alg: PyCompressionAlgorithm) -> Self {
36        match alg {
37            PyCompressionAlgorithm::Deflate => CompressionAlgorithm::Deflate,
38            PyCompressionAlgorithm::Zlib => CompressionAlgorithm::Zlib,
39            PyCompressionAlgorithm::Gzip => CompressionAlgorithm::Gzip,
40            PyCompressionAlgorithm::Raw => CompressionAlgorithm::Raw,
41        }
42    }
43}
44
45impl From<CompressionAlgorithm> for PyCompressionAlgorithm {
46    fn from(alg: CompressionAlgorithm) -> Self {
47        match alg {
48            CompressionAlgorithm::Deflate => PyCompressionAlgorithm::Deflate,
49            CompressionAlgorithm::Zlib => PyCompressionAlgorithm::Zlib,
50            CompressionAlgorithm::Gzip => PyCompressionAlgorithm::Gzip,
51            CompressionAlgorithm::Raw => PyCompressionAlgorithm::Raw,
52        }
53    }
54}
55
56#[pyclass(name = "Metadata")]
57#[derive(Debug, Clone)]
58pub struct PyMetadata {
59    name: String,
60    encoding: String,
61    compress_algorithm: PyCompressionAlgorithm,
62    default_word_cost: i16,
63    default_left_context_id: u16,
64    default_right_context_id: u16,
65    default_field_value: String,
66    flexible_csv: bool,
67    skip_invalid_cost_or_id: bool,
68    normalize_details: bool,
69    dictionary_schema: PySchema,
70    user_dictionary_schema: PySchema,
71}
72
73#[pymethods]
74impl PyMetadata {
75    #[new]
76    #[pyo3(signature = (name=None, encoding=None, compress_algorithm=None, default_word_cost=None, default_left_context_id=None, default_right_context_id=None, default_field_value=None, flexible_csv=None, skip_invalid_cost_or_id=None, normalize_details=None, dictionary_schema=None, user_dictionary_schema=None))]
77    #[allow(clippy::too_many_arguments)]
78    pub fn new(
79        name: Option<String>,
80        encoding: Option<String>,
81        compress_algorithm: Option<PyCompressionAlgorithm>,
82        default_word_cost: Option<i16>,
83        default_left_context_id: Option<u16>,
84        default_right_context_id: Option<u16>,
85        default_field_value: Option<String>,
86        flexible_csv: Option<bool>,
87        skip_invalid_cost_or_id: Option<bool>,
88        normalize_details: Option<bool>,
89        dictionary_schema: Option<PySchema>,
90        user_dictionary_schema: Option<PySchema>,
91    ) -> Self {
92        PyMetadata {
93            name: name.unwrap_or_else(|| "default".to_string()),
94            encoding: encoding.unwrap_or_else(|| "UTF-8".to_string()),
95            compress_algorithm: compress_algorithm.unwrap_or(PyCompressionAlgorithm::Deflate),
96            default_word_cost: default_word_cost.unwrap_or(-10000),
97            default_left_context_id: default_left_context_id.unwrap_or(1288),
98            default_right_context_id: default_right_context_id.unwrap_or(1288),
99            default_field_value: default_field_value.unwrap_or_else(|| "*".to_string()),
100            flexible_csv: flexible_csv.unwrap_or(false),
101            skip_invalid_cost_or_id: skip_invalid_cost_or_id.unwrap_or(false),
102            normalize_details: normalize_details.unwrap_or(false),
103            dictionary_schema: dictionary_schema.unwrap_or_else(PySchema::create_default),
104            user_dictionary_schema: user_dictionary_schema.unwrap_or_else(|| {
105                PySchema::new(vec![
106                    "surface".to_string(),
107                    "reading".to_string(),
108                    "pronunciation".to_string(),
109                ])
110            }),
111        }
112    }
113
114    #[staticmethod]
115    pub fn create_default() -> Self {
116        PyMetadata::new(
117            None, None, None, None, None, None, None, None, None, None, None, None,
118        )
119    }
120
121    #[staticmethod]
122    pub fn from_json_file(path: &str) -> PyResult<Self> {
123        use std::fs;
124
125        let json_str = fs::read_to_string(path).map_err(|e| {
126            pyo3::exceptions::PyIOError::new_err(format!("Failed to read file: {e}"))
127        })?;
128
129        let metadata: Metadata = serde_json::from_str(&json_str).map_err(|e| {
130            pyo3::exceptions::PyValueError::new_err(format!("Failed to parse JSON: {e}"))
131        })?;
132
133        Ok(metadata.into())
134    }
135
136    #[getter]
137    pub fn name(&self) -> &str {
138        &self.name
139    }
140
141    #[setter]
142    pub fn set_name(&mut self, name: String) {
143        self.name = name;
144    }
145
146    #[getter]
147    pub fn encoding(&self) -> &str {
148        &self.encoding
149    }
150
151    #[setter]
152    pub fn set_encoding(&mut self, encoding: String) {
153        self.encoding = encoding;
154    }
155
156    #[getter]
157    pub fn compress_algorithm(&self) -> PyCompressionAlgorithm {
158        self.compress_algorithm.clone()
159    }
160
161    #[setter]
162    pub fn set_compress_algorithm(&mut self, algorithm: PyCompressionAlgorithm) {
163        self.compress_algorithm = algorithm;
164    }
165
166    #[getter]
167    pub fn default_word_cost(&self) -> i16 {
168        self.default_word_cost
169    }
170
171    #[setter]
172    pub fn set_default_word_cost(&mut self, cost: i16) {
173        self.default_word_cost = cost;
174    }
175
176    #[getter]
177    pub fn default_left_context_id(&self) -> u16 {
178        self.default_left_context_id
179    }
180
181    #[setter]
182    pub fn set_default_left_context_id(&mut self, id: u16) {
183        self.default_left_context_id = id;
184    }
185
186    #[getter]
187    pub fn default_right_context_id(&self) -> u16 {
188        self.default_right_context_id
189    }
190
191    #[setter]
192    pub fn set_default_right_context_id(&mut self, id: u16) {
193        self.default_right_context_id = id;
194    }
195
196    #[getter]
197    pub fn default_field_value(&self) -> &str {
198        &self.default_field_value
199    }
200
201    #[setter]
202    pub fn set_default_field_value(&mut self, value: String) {
203        self.default_field_value = value;
204    }
205
206    #[getter]
207    pub fn flexible_csv(&self) -> bool {
208        self.flexible_csv
209    }
210
211    #[setter]
212    pub fn set_flexible_csv(&mut self, value: bool) {
213        self.flexible_csv = value;
214    }
215
216    #[getter]
217    pub fn skip_invalid_cost_or_id(&self) -> bool {
218        self.skip_invalid_cost_or_id
219    }
220
221    #[setter]
222    pub fn set_skip_invalid_cost_or_id(&mut self, value: bool) {
223        self.skip_invalid_cost_or_id = value;
224    }
225
226    #[getter]
227    pub fn normalize_details(&self) -> bool {
228        self.normalize_details
229    }
230
231    #[setter]
232    pub fn set_normalize_details(&mut self, value: bool) {
233        self.normalize_details = value;
234    }
235
236    #[getter]
237    pub fn dictionary_schema(&self) -> PySchema {
238        self.dictionary_schema.clone()
239    }
240
241    #[setter]
242    pub fn set_dictionary_schema(&mut self, schema: PySchema) {
243        self.dictionary_schema = schema;
244    }
245
246    #[getter]
247    pub fn user_dictionary_schema(&self) -> PySchema {
248        self.user_dictionary_schema.clone()
249    }
250
251    #[setter]
252    pub fn set_user_dictionary_schema(&mut self, schema: PySchema) {
253        self.user_dictionary_schema = schema;
254    }
255
256    pub fn to_dict(&self) -> HashMap<String, String> {
257        let mut dict = HashMap::new();
258        dict.insert("name".to_string(), self.name.clone());
259        dict.insert("encoding".to_string(), self.encoding.clone());
260        dict.insert(
261            "compress_algorithm".to_string(),
262            self.compress_algorithm.__str__().to_string(),
263        );
264        dict.insert(
265            "default_word_cost".to_string(),
266            self.default_word_cost.to_string(),
267        );
268        dict.insert(
269            "default_left_context_id".to_string(),
270            self.default_left_context_id.to_string(),
271        );
272        dict.insert(
273            "default_right_context_id".to_string(),
274            self.default_right_context_id.to_string(),
275        );
276        dict.insert(
277            "default_field_value".to_string(),
278            self.default_field_value.clone(),
279        );
280        dict.insert("flexible_csv".to_string(), self.flexible_csv.to_string());
281        dict.insert(
282            "skip_invalid_cost_or_id".to_string(),
283            self.skip_invalid_cost_or_id.to_string(),
284        );
285        dict.insert(
286            "normalize_details".to_string(),
287            self.normalize_details.to_string(),
288        );
289        dict.insert(
290            "dictionary_schema_fields".to_string(),
291            self.dictionary_schema.fields.join(","),
292        );
293        dict.insert(
294            "user_dictionary_schema_fields".to_string(),
295            self.user_dictionary_schema.fields.join(","),
296        );
297        dict
298    }
299
300    fn __str__(&self) -> String {
301        format!(
302            "Metadata(name='{}', encoding='{}', compress_algorithm='{}')",
303            self.name,
304            self.encoding,
305            self.compress_algorithm.__str__()
306        )
307    }
308
309    fn __repr__(&self) -> String {
310        format!(
311            "Metadata(name='{}', encoding='{}', compress_algorithm={:?}, schema_fields={})",
312            self.name,
313            self.encoding,
314            self.compress_algorithm,
315            self.dictionary_schema.field_count()
316        )
317    }
318}
319
320impl From<PyMetadata> for Metadata {
321    fn from(metadata: PyMetadata) -> Self {
322        Metadata::new(
323            metadata.name,
324            metadata.encoding,
325            metadata.compress_algorithm.into(),
326            metadata.default_word_cost,
327            metadata.default_left_context_id,
328            metadata.default_right_context_id,
329            metadata.default_field_value,
330            metadata.flexible_csv,
331            metadata.skip_invalid_cost_or_id,
332            metadata.normalize_details,
333            metadata.dictionary_schema.into(),
334            metadata.user_dictionary_schema.into(),
335        )
336    }
337}
338
339impl From<Metadata> for PyMetadata {
340    fn from(metadata: Metadata) -> Self {
341        PyMetadata {
342            name: metadata.name,
343            encoding: metadata.encoding,
344            compress_algorithm: metadata.compress_algorithm.into(),
345            default_word_cost: metadata.default_word_cost,
346            default_left_context_id: metadata.default_left_context_id,
347            default_right_context_id: metadata.default_right_context_id,
348            default_field_value: metadata.default_field_value,
349            flexible_csv: metadata.flexible_csv,
350            skip_invalid_cost_or_id: metadata.skip_invalid_cost_or_id,
351            normalize_details: metadata.normalize_details,
352            dictionary_schema: metadata.dictionary_schema.into(),
353            user_dictionary_schema: metadata.user_dictionary_schema.into(),
354        }
355    }
356}