1use std::collections::HashMap;
2
3use pyo3::prelude::*;
4
5use lindera::dictionary::{CompressionAlgorithm, Metadata};
6
7use crate::schema::PySchema;
8
9#[pyclass(name = "CompressionAlgorithm")]
10#[derive(Debug, Clone)]
11pub enum PyCompressionAlgorithm {
12 Deflate,
13 Zlib,
14 Gzip,
15 Raw,
16}
17
18#[pymethods]
19impl PyCompressionAlgorithm {
20 fn __str__(&self) -> &str {
21 match self {
22 PyCompressionAlgorithm::Deflate => "deflate",
23 PyCompressionAlgorithm::Zlib => "zlib",
24 PyCompressionAlgorithm::Gzip => "gzip",
25 PyCompressionAlgorithm::Raw => "raw",
26 }
27 }
28
29 fn __repr__(&self) -> String {
30 format!("CompressionAlgorithm.{self:?}")
31 }
32}
33
34impl From<PyCompressionAlgorithm> for CompressionAlgorithm {
35 fn from(alg: PyCompressionAlgorithm) -> Self {
36 match alg {
37 PyCompressionAlgorithm::Deflate => CompressionAlgorithm::Deflate,
38 PyCompressionAlgorithm::Zlib => CompressionAlgorithm::Zlib,
39 PyCompressionAlgorithm::Gzip => CompressionAlgorithm::Gzip,
40 PyCompressionAlgorithm::Raw => CompressionAlgorithm::Raw,
41 }
42 }
43}
44
45impl From<CompressionAlgorithm> for PyCompressionAlgorithm {
46 fn from(alg: CompressionAlgorithm) -> Self {
47 match alg {
48 CompressionAlgorithm::Deflate => PyCompressionAlgorithm::Deflate,
49 CompressionAlgorithm::Zlib => PyCompressionAlgorithm::Zlib,
50 CompressionAlgorithm::Gzip => PyCompressionAlgorithm::Gzip,
51 CompressionAlgorithm::Raw => PyCompressionAlgorithm::Raw,
52 }
53 }
54}
55
56#[pyclass(name = "Metadata")]
57#[derive(Debug, Clone)]
58pub struct PyMetadata {
59 name: String,
60 encoding: String,
61 compress_algorithm: PyCompressionAlgorithm,
62 default_word_cost: i16,
63 default_left_context_id: u16,
64 default_right_context_id: u16,
65 default_field_value: String,
66 flexible_csv: bool,
67 skip_invalid_cost_or_id: bool,
68 normalize_details: bool,
69 dictionary_schema: PySchema,
70 user_dictionary_schema: PySchema,
71}
72
73#[pymethods]
74impl PyMetadata {
75 #[new]
76 #[pyo3(signature = (name=None, encoding=None, compress_algorithm=None, default_word_cost=None, default_left_context_id=None, default_right_context_id=None, default_field_value=None, flexible_csv=None, skip_invalid_cost_or_id=None, normalize_details=None, dictionary_schema=None, user_dictionary_schema=None))]
77 #[allow(clippy::too_many_arguments)]
78 pub fn new(
79 name: Option<String>,
80 encoding: Option<String>,
81 compress_algorithm: Option<PyCompressionAlgorithm>,
82 default_word_cost: Option<i16>,
83 default_left_context_id: Option<u16>,
84 default_right_context_id: Option<u16>,
85 default_field_value: Option<String>,
86 flexible_csv: Option<bool>,
87 skip_invalid_cost_or_id: Option<bool>,
88 normalize_details: Option<bool>,
89 dictionary_schema: Option<PySchema>,
90 user_dictionary_schema: Option<PySchema>,
91 ) -> Self {
92 PyMetadata {
93 name: name.unwrap_or_else(|| "default".to_string()),
94 encoding: encoding.unwrap_or_else(|| "UTF-8".to_string()),
95 compress_algorithm: compress_algorithm.unwrap_or(PyCompressionAlgorithm::Deflate),
96 default_word_cost: default_word_cost.unwrap_or(-10000),
97 default_left_context_id: default_left_context_id.unwrap_or(1288),
98 default_right_context_id: default_right_context_id.unwrap_or(1288),
99 default_field_value: default_field_value.unwrap_or_else(|| "*".to_string()),
100 flexible_csv: flexible_csv.unwrap_or(false),
101 skip_invalid_cost_or_id: skip_invalid_cost_or_id.unwrap_or(false),
102 normalize_details: normalize_details.unwrap_or(false),
103 dictionary_schema: dictionary_schema.unwrap_or_else(PySchema::create_default),
104 user_dictionary_schema: user_dictionary_schema.unwrap_or_else(|| {
105 PySchema::new(vec![
106 "surface".to_string(),
107 "reading".to_string(),
108 "pronunciation".to_string(),
109 ])
110 }),
111 }
112 }
113
114 #[staticmethod]
115 pub fn create_default() -> Self {
116 PyMetadata::new(
117 None, None, None, None, None, None, None, None, None, None, None, None,
118 )
119 }
120
121 #[staticmethod]
122 pub fn from_json_file(path: &str) -> PyResult<Self> {
123 use std::fs;
124
125 let json_str = fs::read_to_string(path).map_err(|e| {
126 pyo3::exceptions::PyIOError::new_err(format!("Failed to read file: {e}"))
127 })?;
128
129 let metadata: Metadata = serde_json::from_str(&json_str).map_err(|e| {
130 pyo3::exceptions::PyValueError::new_err(format!("Failed to parse JSON: {e}"))
131 })?;
132
133 Ok(metadata.into())
134 }
135
136 #[getter]
137 pub fn name(&self) -> &str {
138 &self.name
139 }
140
141 #[setter]
142 pub fn set_name(&mut self, name: String) {
143 self.name = name;
144 }
145
146 #[getter]
147 pub fn encoding(&self) -> &str {
148 &self.encoding
149 }
150
151 #[setter]
152 pub fn set_encoding(&mut self, encoding: String) {
153 self.encoding = encoding;
154 }
155
156 #[getter]
157 pub fn compress_algorithm(&self) -> PyCompressionAlgorithm {
158 self.compress_algorithm.clone()
159 }
160
161 #[setter]
162 pub fn set_compress_algorithm(&mut self, algorithm: PyCompressionAlgorithm) {
163 self.compress_algorithm = algorithm;
164 }
165
166 #[getter]
167 pub fn default_word_cost(&self) -> i16 {
168 self.default_word_cost
169 }
170
171 #[setter]
172 pub fn set_default_word_cost(&mut self, cost: i16) {
173 self.default_word_cost = cost;
174 }
175
176 #[getter]
177 pub fn default_left_context_id(&self) -> u16 {
178 self.default_left_context_id
179 }
180
181 #[setter]
182 pub fn set_default_left_context_id(&mut self, id: u16) {
183 self.default_left_context_id = id;
184 }
185
186 #[getter]
187 pub fn default_right_context_id(&self) -> u16 {
188 self.default_right_context_id
189 }
190
191 #[setter]
192 pub fn set_default_right_context_id(&mut self, id: u16) {
193 self.default_right_context_id = id;
194 }
195
196 #[getter]
197 pub fn default_field_value(&self) -> &str {
198 &self.default_field_value
199 }
200
201 #[setter]
202 pub fn set_default_field_value(&mut self, value: String) {
203 self.default_field_value = value;
204 }
205
206 #[getter]
207 pub fn flexible_csv(&self) -> bool {
208 self.flexible_csv
209 }
210
211 #[setter]
212 pub fn set_flexible_csv(&mut self, value: bool) {
213 self.flexible_csv = value;
214 }
215
216 #[getter]
217 pub fn skip_invalid_cost_or_id(&self) -> bool {
218 self.skip_invalid_cost_or_id
219 }
220
221 #[setter]
222 pub fn set_skip_invalid_cost_or_id(&mut self, value: bool) {
223 self.skip_invalid_cost_or_id = value;
224 }
225
226 #[getter]
227 pub fn normalize_details(&self) -> bool {
228 self.normalize_details
229 }
230
231 #[setter]
232 pub fn set_normalize_details(&mut self, value: bool) {
233 self.normalize_details = value;
234 }
235
236 #[getter]
237 pub fn dictionary_schema(&self) -> PySchema {
238 self.dictionary_schema.clone()
239 }
240
241 #[setter]
242 pub fn set_dictionary_schema(&mut self, schema: PySchema) {
243 self.dictionary_schema = schema;
244 }
245
246 #[getter]
247 pub fn user_dictionary_schema(&self) -> PySchema {
248 self.user_dictionary_schema.clone()
249 }
250
251 #[setter]
252 pub fn set_user_dictionary_schema(&mut self, schema: PySchema) {
253 self.user_dictionary_schema = schema;
254 }
255
256 pub fn to_dict(&self) -> HashMap<String, String> {
257 let mut dict = HashMap::new();
258 dict.insert("name".to_string(), self.name.clone());
259 dict.insert("encoding".to_string(), self.encoding.clone());
260 dict.insert(
261 "compress_algorithm".to_string(),
262 self.compress_algorithm.__str__().to_string(),
263 );
264 dict.insert(
265 "default_word_cost".to_string(),
266 self.default_word_cost.to_string(),
267 );
268 dict.insert(
269 "default_left_context_id".to_string(),
270 self.default_left_context_id.to_string(),
271 );
272 dict.insert(
273 "default_right_context_id".to_string(),
274 self.default_right_context_id.to_string(),
275 );
276 dict.insert(
277 "default_field_value".to_string(),
278 self.default_field_value.clone(),
279 );
280 dict.insert("flexible_csv".to_string(), self.flexible_csv.to_string());
281 dict.insert(
282 "skip_invalid_cost_or_id".to_string(),
283 self.skip_invalid_cost_or_id.to_string(),
284 );
285 dict.insert(
286 "normalize_details".to_string(),
287 self.normalize_details.to_string(),
288 );
289 dict.insert(
290 "dictionary_schema_fields".to_string(),
291 self.dictionary_schema.fields.join(","),
292 );
293 dict.insert(
294 "user_dictionary_schema_fields".to_string(),
295 self.user_dictionary_schema.fields.join(","),
296 );
297 dict
298 }
299
300 fn __str__(&self) -> String {
301 format!(
302 "Metadata(name='{}', encoding='{}', compress_algorithm='{}')",
303 self.name,
304 self.encoding,
305 self.compress_algorithm.__str__()
306 )
307 }
308
309 fn __repr__(&self) -> String {
310 format!(
311 "Metadata(name='{}', encoding='{}', compress_algorithm={:?}, schema_fields={})",
312 self.name,
313 self.encoding,
314 self.compress_algorithm,
315 self.dictionary_schema.field_count()
316 )
317 }
318}
319
320impl From<PyMetadata> for Metadata {
321 fn from(metadata: PyMetadata) -> Self {
322 Metadata::new(
323 metadata.name,
324 metadata.encoding,
325 metadata.compress_algorithm.into(),
326 metadata.default_word_cost,
327 metadata.default_left_context_id,
328 metadata.default_right_context_id,
329 metadata.default_field_value,
330 metadata.flexible_csv,
331 metadata.skip_invalid_cost_or_id,
332 metadata.normalize_details,
333 metadata.dictionary_schema.into(),
334 metadata.user_dictionary_schema.into(),
335 )
336 }
337}
338
339impl From<Metadata> for PyMetadata {
340 fn from(metadata: Metadata) -> Self {
341 PyMetadata {
342 name: metadata.name,
343 encoding: metadata.encoding,
344 compress_algorithm: metadata.compress_algorithm.into(),
345 default_word_cost: metadata.default_word_cost,
346 default_left_context_id: metadata.default_left_context_id,
347 default_right_context_id: metadata.default_right_context_id,
348 default_field_value: metadata.default_field_value,
349 flexible_csv: metadata.flexible_csv,
350 skip_invalid_cost_or_id: metadata.skip_invalid_cost_or_id,
351 normalize_details: metadata.normalize_details,
352 dictionary_schema: metadata.dictionary_schema.into(),
353 user_dictionary_schema: metadata.user_dictionary_schema.into(),
354 }
355 }
356}