1use std::collections::HashMap;
24
25use pyo3::prelude::*;
26
27use lindera::dictionary::{CompressionAlgorithm, Metadata};
28
29use crate::schema::PySchema;
30
31#[pyclass(name = "CompressionAlgorithm")]
35#[derive(Debug, Clone)]
36pub enum PyCompressionAlgorithm {
37 Deflate,
39 Zlib,
41 Gzip,
43 Raw,
45}
46
47#[pymethods]
48impl PyCompressionAlgorithm {
49 fn __str__(&self) -> &str {
50 match self {
51 PyCompressionAlgorithm::Deflate => "deflate",
52 PyCompressionAlgorithm::Zlib => "zlib",
53 PyCompressionAlgorithm::Gzip => "gzip",
54 PyCompressionAlgorithm::Raw => "raw",
55 }
56 }
57
58 fn __repr__(&self) -> String {
59 format!("CompressionAlgorithm.{self:?}")
60 }
61}
62
63impl From<PyCompressionAlgorithm> for CompressionAlgorithm {
64 fn from(alg: PyCompressionAlgorithm) -> Self {
65 match alg {
66 PyCompressionAlgorithm::Deflate => CompressionAlgorithm::Deflate,
67 PyCompressionAlgorithm::Zlib => CompressionAlgorithm::Zlib,
68 PyCompressionAlgorithm::Gzip => CompressionAlgorithm::Gzip,
69 PyCompressionAlgorithm::Raw => CompressionAlgorithm::Raw,
70 }
71 }
72}
73
74impl From<CompressionAlgorithm> for PyCompressionAlgorithm {
75 fn from(alg: CompressionAlgorithm) -> Self {
76 match alg {
77 CompressionAlgorithm::Deflate => PyCompressionAlgorithm::Deflate,
78 CompressionAlgorithm::Zlib => PyCompressionAlgorithm::Zlib,
79 CompressionAlgorithm::Gzip => PyCompressionAlgorithm::Gzip,
80 CompressionAlgorithm::Raw => PyCompressionAlgorithm::Raw,
81 }
82 }
83}
84
85#[pyclass(name = "Metadata")]
104#[derive(Debug, Clone)]
105pub struct PyMetadata {
106 name: String,
107 encoding: String,
108 compress_algorithm: PyCompressionAlgorithm,
109 default_word_cost: i16,
110 default_left_context_id: u16,
111 default_right_context_id: u16,
112 default_field_value: String,
113 flexible_csv: bool,
114 skip_invalid_cost_or_id: bool,
115 normalize_details: bool,
116 dictionary_schema: PySchema,
117 user_dictionary_schema: PySchema,
118}
119
120#[pymethods]
121impl PyMetadata {
122 #[new]
123 #[pyo3(signature = (name=None, encoding=None, compress_algorithm=None, default_word_cost=None, default_left_context_id=None, default_right_context_id=None, default_field_value=None, flexible_csv=None, skip_invalid_cost_or_id=None, normalize_details=None, dictionary_schema=None, user_dictionary_schema=None))]
124 #[allow(clippy::too_many_arguments)]
125 pub fn new(
126 name: Option<String>,
127 encoding: Option<String>,
128 compress_algorithm: Option<PyCompressionAlgorithm>,
129 default_word_cost: Option<i16>,
130 default_left_context_id: Option<u16>,
131 default_right_context_id: Option<u16>,
132 default_field_value: Option<String>,
133 flexible_csv: Option<bool>,
134 skip_invalid_cost_or_id: Option<bool>,
135 normalize_details: Option<bool>,
136 dictionary_schema: Option<PySchema>,
137 user_dictionary_schema: Option<PySchema>,
138 ) -> Self {
139 PyMetadata {
140 name: name.unwrap_or_else(|| "default".to_string()),
141 encoding: encoding.unwrap_or_else(|| "UTF-8".to_string()),
142 compress_algorithm: compress_algorithm.unwrap_or(PyCompressionAlgorithm::Deflate),
143 default_word_cost: default_word_cost.unwrap_or(-10000),
144 default_left_context_id: default_left_context_id.unwrap_or(1288),
145 default_right_context_id: default_right_context_id.unwrap_or(1288),
146 default_field_value: default_field_value.unwrap_or_else(|| "*".to_string()),
147 flexible_csv: flexible_csv.unwrap_or(false),
148 skip_invalid_cost_or_id: skip_invalid_cost_or_id.unwrap_or(false),
149 normalize_details: normalize_details.unwrap_or(false),
150 dictionary_schema: dictionary_schema.unwrap_or_else(PySchema::create_default),
151 user_dictionary_schema: user_dictionary_schema.unwrap_or_else(|| {
152 PySchema::new(vec![
153 "surface".to_string(),
154 "reading".to_string(),
155 "pronunciation".to_string(),
156 ])
157 }),
158 }
159 }
160
161 #[staticmethod]
162 pub fn create_default() -> Self {
163 PyMetadata::new(
164 None, None, None, None, None, None, None, None, None, None, None, None,
165 )
166 }
167
168 #[staticmethod]
169 pub fn from_json_file(path: &str) -> PyResult<Self> {
170 use std::fs;
171
172 let json_str = fs::read_to_string(path).map_err(|e| {
173 pyo3::exceptions::PyIOError::new_err(format!("Failed to read file: {e}"))
174 })?;
175
176 let metadata: Metadata = serde_json::from_str(&json_str).map_err(|e| {
177 pyo3::exceptions::PyValueError::new_err(format!("Failed to parse JSON: {e}"))
178 })?;
179
180 Ok(metadata.into())
181 }
182
183 #[getter]
184 pub fn name(&self) -> &str {
185 &self.name
186 }
187
188 #[setter]
189 pub fn set_name(&mut self, name: String) {
190 self.name = name;
191 }
192
193 #[getter]
194 pub fn encoding(&self) -> &str {
195 &self.encoding
196 }
197
198 #[setter]
199 pub fn set_encoding(&mut self, encoding: String) {
200 self.encoding = encoding;
201 }
202
203 #[getter]
204 pub fn compress_algorithm(&self) -> PyCompressionAlgorithm {
205 self.compress_algorithm.clone()
206 }
207
208 #[setter]
209 pub fn set_compress_algorithm(&mut self, algorithm: PyCompressionAlgorithm) {
210 self.compress_algorithm = algorithm;
211 }
212
213 #[getter]
214 pub fn default_word_cost(&self) -> i16 {
215 self.default_word_cost
216 }
217
218 #[setter]
219 pub fn set_default_word_cost(&mut self, cost: i16) {
220 self.default_word_cost = cost;
221 }
222
223 #[getter]
224 pub fn default_left_context_id(&self) -> u16 {
225 self.default_left_context_id
226 }
227
228 #[setter]
229 pub fn set_default_left_context_id(&mut self, id: u16) {
230 self.default_left_context_id = id;
231 }
232
233 #[getter]
234 pub fn default_right_context_id(&self) -> u16 {
235 self.default_right_context_id
236 }
237
238 #[setter]
239 pub fn set_default_right_context_id(&mut self, id: u16) {
240 self.default_right_context_id = id;
241 }
242
243 #[getter]
244 pub fn default_field_value(&self) -> &str {
245 &self.default_field_value
246 }
247
248 #[setter]
249 pub fn set_default_field_value(&mut self, value: String) {
250 self.default_field_value = value;
251 }
252
253 #[getter]
254 pub fn flexible_csv(&self) -> bool {
255 self.flexible_csv
256 }
257
258 #[setter]
259 pub fn set_flexible_csv(&mut self, value: bool) {
260 self.flexible_csv = value;
261 }
262
263 #[getter]
264 pub fn skip_invalid_cost_or_id(&self) -> bool {
265 self.skip_invalid_cost_or_id
266 }
267
268 #[setter]
269 pub fn set_skip_invalid_cost_or_id(&mut self, value: bool) {
270 self.skip_invalid_cost_or_id = value;
271 }
272
273 #[getter]
274 pub fn normalize_details(&self) -> bool {
275 self.normalize_details
276 }
277
278 #[setter]
279 pub fn set_normalize_details(&mut self, value: bool) {
280 self.normalize_details = value;
281 }
282
283 #[getter]
284 pub fn dictionary_schema(&self) -> PySchema {
285 self.dictionary_schema.clone()
286 }
287
288 #[setter]
289 pub fn set_dictionary_schema(&mut self, schema: PySchema) {
290 self.dictionary_schema = schema;
291 }
292
293 #[getter]
294 pub fn user_dictionary_schema(&self) -> PySchema {
295 self.user_dictionary_schema.clone()
296 }
297
298 #[setter]
299 pub fn set_user_dictionary_schema(&mut self, schema: PySchema) {
300 self.user_dictionary_schema = schema;
301 }
302
303 pub fn to_dict(&self) -> HashMap<String, String> {
304 let mut dict = HashMap::new();
305 dict.insert("name".to_string(), self.name.clone());
306 dict.insert("encoding".to_string(), self.encoding.clone());
307 dict.insert(
308 "compress_algorithm".to_string(),
309 self.compress_algorithm.__str__().to_string(),
310 );
311 dict.insert(
312 "default_word_cost".to_string(),
313 self.default_word_cost.to_string(),
314 );
315 dict.insert(
316 "default_left_context_id".to_string(),
317 self.default_left_context_id.to_string(),
318 );
319 dict.insert(
320 "default_right_context_id".to_string(),
321 self.default_right_context_id.to_string(),
322 );
323 dict.insert(
324 "default_field_value".to_string(),
325 self.default_field_value.clone(),
326 );
327 dict.insert("flexible_csv".to_string(), self.flexible_csv.to_string());
328 dict.insert(
329 "skip_invalid_cost_or_id".to_string(),
330 self.skip_invalid_cost_or_id.to_string(),
331 );
332 dict.insert(
333 "normalize_details".to_string(),
334 self.normalize_details.to_string(),
335 );
336 dict.insert(
337 "dictionary_schema_fields".to_string(),
338 self.dictionary_schema.fields.join(","),
339 );
340 dict.insert(
341 "user_dictionary_schema_fields".to_string(),
342 self.user_dictionary_schema.fields.join(","),
343 );
344 dict
345 }
346
347 fn __str__(&self) -> String {
348 format!(
349 "Metadata(name='{}', encoding='{}', compress_algorithm='{}')",
350 self.name,
351 self.encoding,
352 self.compress_algorithm.__str__()
353 )
354 }
355
356 fn __repr__(&self) -> String {
357 format!(
358 "Metadata(name='{}', encoding='{}', compress_algorithm={:?}, schema_fields={})",
359 self.name,
360 self.encoding,
361 self.compress_algorithm,
362 self.dictionary_schema.field_count()
363 )
364 }
365}
366
367impl From<PyMetadata> for Metadata {
368 fn from(metadata: PyMetadata) -> Self {
369 Metadata::new(
370 metadata.name,
371 metadata.encoding,
372 metadata.compress_algorithm.into(),
373 metadata.default_word_cost,
374 metadata.default_left_context_id,
375 metadata.default_right_context_id,
376 metadata.default_field_value,
377 metadata.flexible_csv,
378 metadata.skip_invalid_cost_or_id,
379 metadata.normalize_details,
380 metadata.dictionary_schema.into(),
381 metadata.user_dictionary_schema.into(),
382 )
383 }
384}
385
386impl From<Metadata> for PyMetadata {
387 fn from(metadata: Metadata) -> Self {
388 PyMetadata {
389 name: metadata.name,
390 encoding: metadata.encoding,
391 compress_algorithm: metadata.compress_algorithm.into(),
392 default_word_cost: metadata.default_word_cost,
393 default_left_context_id: metadata.default_left_context_id,
394 default_right_context_id: metadata.default_right_context_id,
395 default_field_value: metadata.default_field_value,
396 flexible_csv: metadata.flexible_csv,
397 skip_invalid_cost_or_id: metadata.skip_invalid_cost_or_id,
398 normalize_details: metadata.normalize_details,
399 dictionary_schema: metadata.dictionary_schema.into(),
400 user_dictionary_schema: metadata.user_dictionary_schema.into(),
401 }
402 }
403}