1use pyo3::prelude::*;
28
29use lindera::mode::{Mode as LinderaMode, Penalty as LinderaPenalty};
30
31#[pyclass(name = "Mode", from_py_object)]
35#[derive(Debug, Clone, Copy)]
36pub enum PyMode {
37 Normal,
39 Decompose,
41}
42
43#[pymethods]
44impl PyMode {
45 #[new]
46 #[pyo3(signature = (mode_str=None))]
47 pub fn new(mode_str: Option<&str>) -> PyResult<Self> {
48 match mode_str {
49 Some("decompose") | Some("Decompose") => Ok(PyMode::Decompose),
50 Some("normal") | Some("Normal") | None => Ok(PyMode::Normal),
51 Some(s) => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
52 "Invalid mode: {s}. Must be 'normal' or 'decompose'"
53 ))),
54 }
55 }
56
57 fn __str__(&self) -> &str {
58 match self {
59 PyMode::Normal => "normal",
60 PyMode::Decompose => "decompose",
61 }
62 }
63
64 fn __repr__(&self) -> String {
65 format!("Mode.{self:?}")
66 }
67
68 #[getter]
69 pub fn name(&self) -> &str {
70 self.__str__()
71 }
72
73 pub fn is_normal(&self) -> bool {
74 matches!(self, PyMode::Normal)
75 }
76
77 pub fn is_decompose(&self) -> bool {
78 matches!(self, PyMode::Decompose)
79 }
80}
81
82impl From<PyMode> for LinderaMode {
83 fn from(mode: PyMode) -> Self {
84 match mode {
85 PyMode::Normal => LinderaMode::Normal,
86 PyMode::Decompose => LinderaMode::Decompose(LinderaPenalty::default()),
87 }
88 }
89}
90
91impl From<LinderaMode> for PyMode {
92 fn from(mode: LinderaMode) -> Self {
93 match mode {
94 LinderaMode::Normal => PyMode::Normal,
95 LinderaMode::Decompose(_) => PyMode::Decompose,
96 }
97 }
98}
99
100#[pyclass(name = "Penalty", from_py_object)]
116#[derive(Debug, Clone, Copy)]
117pub struct PyPenalty {
118 kanji_penalty_length_threshold: usize,
119 kanji_penalty_length_penalty: i32,
120 other_penalty_length_threshold: usize,
121 other_penalty_length_penalty: i32,
122}
123
124#[pymethods]
125impl PyPenalty {
126 #[new]
127 #[pyo3(signature = (kanji_penalty_length_threshold=None, kanji_penalty_length_penalty=None, other_penalty_length_threshold=None, other_penalty_length_penalty=None))]
128 pub fn new(
129 kanji_penalty_length_threshold: Option<usize>,
130 kanji_penalty_length_penalty: Option<i32>,
131 other_penalty_length_threshold: Option<usize>,
132 other_penalty_length_penalty: Option<i32>,
133 ) -> Self {
134 PyPenalty {
135 kanji_penalty_length_threshold: kanji_penalty_length_threshold.unwrap_or(2),
136 kanji_penalty_length_penalty: kanji_penalty_length_penalty.unwrap_or(3000),
137 other_penalty_length_threshold: other_penalty_length_threshold.unwrap_or(7),
138 other_penalty_length_penalty: other_penalty_length_penalty.unwrap_or(1700),
139 }
140 }
141
142 #[getter]
143 pub fn get_kanji_penalty_length_threshold(&self) -> usize {
144 self.kanji_penalty_length_threshold
145 }
146
147 #[setter]
148 pub fn set_kanji_penalty_length_threshold(&mut self, value: usize) {
149 self.kanji_penalty_length_threshold = value;
150 }
151
152 #[getter]
153 pub fn get_kanji_penalty_length_penalty(&self) -> i32 {
154 self.kanji_penalty_length_penalty
155 }
156
157 #[setter]
158 pub fn set_kanji_penalty_length_penalty(&mut self, value: i32) {
159 self.kanji_penalty_length_penalty = value;
160 }
161
162 #[getter]
163 pub fn get_other_penalty_length_threshold(&self) -> usize {
164 self.other_penalty_length_threshold
165 }
166
167 #[setter]
168 pub fn set_other_penalty_length_threshold(&mut self, value: usize) {
169 self.other_penalty_length_threshold = value;
170 }
171
172 #[getter]
173 pub fn get_other_penalty_length_penalty(&self) -> i32 {
174 self.other_penalty_length_penalty
175 }
176
177 #[setter]
178 pub fn set_other_penalty_length_penalty(&mut self, value: i32) {
179 self.other_penalty_length_penalty = value;
180 }
181
182 fn __str__(&self) -> String {
183 format!(
184 "Penalty(kanji_threshold={}, kanji_penalty={}, other_threshold={}, other_penalty={})",
185 self.kanji_penalty_length_threshold,
186 self.kanji_penalty_length_penalty,
187 self.other_penalty_length_threshold,
188 self.other_penalty_length_penalty
189 )
190 }
191
192 fn __repr__(&self) -> String {
193 self.__str__()
194 }
195}
196
197impl From<PyPenalty> for LinderaPenalty {
198 fn from(penalty: PyPenalty) -> Self {
199 LinderaPenalty {
200 kanji_penalty_length_threshold: penalty.kanji_penalty_length_threshold,
201 kanji_penalty_length_penalty: penalty.kanji_penalty_length_penalty,
202 other_penalty_length_threshold: penalty.other_penalty_length_threshold,
203 other_penalty_length_penalty: penalty.other_penalty_length_penalty,
204 }
205 }
206}
207
208impl From<LinderaPenalty> for PyPenalty {
209 fn from(penalty: LinderaPenalty) -> Self {
210 PyPenalty {
211 kanji_penalty_length_threshold: penalty.kanji_penalty_length_threshold,
212 kanji_penalty_length_penalty: penalty.kanji_penalty_length_penalty,
213 other_penalty_length_threshold: penalty.other_penalty_length_threshold,
214 other_penalty_length_penalty: penalty.other_penalty_length_penalty,
215 }
216 }
217}
218
219pub fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
220 let py = parent_module.py();
221 let m = PyModule::new(py, "mode")?;
222 m.add_class::<PyMode>()?;
223 m.add_class::<PyPenalty>()?;
224 parent_module.add_submodule(&m)?;
225 Ok(())
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use lindera::mode::{Mode as LinderaMode, Penalty as LinderaPenalty};

    /// Field-by-field comparison; `PyPenalty` does not derive `PartialEq`.
    fn assert_penalty_eq(py_penalty: PyPenalty, lindera_penalty: &LinderaPenalty) {
        assert_eq!(
            py_penalty.kanji_penalty_length_threshold,
            lindera_penalty.kanji_penalty_length_threshold
        );
        assert_eq!(
            py_penalty.kanji_penalty_length_penalty,
            lindera_penalty.kanji_penalty_length_penalty
        );
        assert_eq!(
            py_penalty.other_penalty_length_threshold,
            lindera_penalty.other_penalty_length_threshold
        );
        assert_eq!(
            py_penalty.other_penalty_length_penalty,
            lindera_penalty.other_penalty_length_penalty
        );
    }

    #[test]
    fn test_pymode_normal_to_lindera_mode() {
        let py_mode = PyMode::Normal;
        let lindera_mode: LinderaMode = py_mode.into();
        assert!(matches!(lindera_mode, LinderaMode::Normal));
    }

    #[test]
    fn test_pymode_decompose_to_lindera_mode() {
        let py_mode = PyMode::Decompose;
        let lindera_mode: LinderaMode = py_mode.into();
        assert!(matches!(lindera_mode, LinderaMode::Decompose(_)));
    }

    #[test]
    fn test_lindera_mode_normal_to_pymode() {
        let lindera_mode = LinderaMode::Normal;
        let py_mode: PyMode = lindera_mode.into();
        assert!(matches!(py_mode, PyMode::Normal));
    }

    #[test]
    fn test_lindera_mode_decompose_to_pymode() {
        let lindera_mode = LinderaMode::Decompose(LinderaPenalty::default());
        let py_mode: PyMode = lindera_mode.into();
        assert!(matches!(py_mode, PyMode::Decompose));
    }

    #[test]
    fn test_pymode_new_accepts_known_names() {
        assert!(matches!(PyMode::new(None), Ok(PyMode::Normal)));
        assert!(matches!(PyMode::new(Some("normal")), Ok(PyMode::Normal)));
        assert!(matches!(PyMode::new(Some("Normal")), Ok(PyMode::Normal)));
        assert!(matches!(
            PyMode::new(Some("decompose")),
            Ok(PyMode::Decompose)
        ));
        assert!(matches!(
            PyMode::new(Some("Decompose")),
            Ok(PyMode::Decompose)
        ));
    }

    #[test]
    fn test_pymode_new_rejects_unknown_name() {
        // The error is built lazily, so no Python interpreter is needed here.
        assert!(PyMode::new(Some("search")).is_err());
    }

    #[test]
    fn test_pymode_name_repr_and_predicates() {
        assert_eq!(PyMode::Normal.name(), "normal");
        assert_eq!(PyMode::Decompose.name(), "decompose");
        assert_eq!(PyMode::Normal.__repr__(), "Mode.Normal");
        assert_eq!(PyMode::Decompose.__repr__(), "Mode.Decompose");
        assert!(PyMode::Normal.is_normal() && !PyMode::Normal.is_decompose());
        assert!(PyMode::Decompose.is_decompose() && !PyMode::Decompose.is_normal());
    }

    #[test]
    fn test_pypenalty_to_lindera_penalty() {
        let py_penalty = PyPenalty {
            kanji_penalty_length_threshold: 5,
            kanji_penalty_length_penalty: 4000,
            other_penalty_length_threshold: 10,
            other_penalty_length_penalty: 2000,
        };
        let lindera_penalty: LinderaPenalty = py_penalty.into();
        assert_eq!(lindera_penalty.kanji_penalty_length_threshold, 5);
        assert_eq!(lindera_penalty.kanji_penalty_length_penalty, 4000);
        assert_eq!(lindera_penalty.other_penalty_length_threshold, 10);
        assert_eq!(lindera_penalty.other_penalty_length_penalty, 2000);
    }

    #[test]
    fn test_lindera_penalty_to_pypenalty() {
        let lindera_penalty = LinderaPenalty {
            kanji_penalty_length_threshold: 3,
            kanji_penalty_length_penalty: 5000,
            other_penalty_length_threshold: 8,
            other_penalty_length_penalty: 1500,
        };
        let py_penalty: PyPenalty = lindera_penalty.into();
        assert_eq!(py_penalty.kanji_penalty_length_threshold, 3);
        assert_eq!(py_penalty.kanji_penalty_length_penalty, 5000);
        assert_eq!(py_penalty.other_penalty_length_threshold, 8);
        assert_eq!(py_penalty.other_penalty_length_penalty, 1500);
    }

    #[test]
    fn test_pypenalty_to_lindera_penalty_default_values() {
        let py_penalty = PyPenalty {
            kanji_penalty_length_threshold: 2,
            kanji_penalty_length_penalty: 3000,
            other_penalty_length_threshold: 7,
            other_penalty_length_penalty: 1700,
        };
        let lindera_penalty: LinderaPenalty = py_penalty.into();
        let default_penalty = LinderaPenalty::default();
        assert_eq!(
            lindera_penalty.kanji_penalty_length_threshold,
            default_penalty.kanji_penalty_length_threshold
        );
        assert_eq!(
            lindera_penalty.kanji_penalty_length_penalty,
            default_penalty.kanji_penalty_length_penalty
        );
        assert_eq!(
            lindera_penalty.other_penalty_length_threshold,
            default_penalty.other_penalty_length_threshold
        );
        assert_eq!(
            lindera_penalty.other_penalty_length_penalty,
            default_penalty.other_penalty_length_penalty
        );
    }

    #[test]
    fn test_pypenalty_new_uses_lindera_defaults() {
        // Goes through the real constructor that backs `Penalty()` in Python.
        let py_penalty = PyPenalty::new(None, None, None, None);
        assert_penalty_eq(py_penalty, &LinderaPenalty::default());
        // Pin the user-visible defaults too, so an upstream change is noticed.
        assert_eq!(py_penalty.kanji_penalty_length_threshold, 2);
        assert_eq!(py_penalty.kanji_penalty_length_penalty, 3000);
        assert_eq!(py_penalty.other_penalty_length_threshold, 7);
        assert_eq!(py_penalty.other_penalty_length_penalty, 1700);
    }

    #[test]
    fn test_pypenalty_new_overrides_only_given_fields() {
        let defaults = LinderaPenalty::default();
        let py_penalty = PyPenalty::new(Some(5), None, Some(9), Some(-1));
        assert_eq!(py_penalty.kanji_penalty_length_threshold, 5);
        assert_eq!(
            py_penalty.kanji_penalty_length_penalty,
            defaults.kanji_penalty_length_penalty
        );
        assert_eq!(py_penalty.other_penalty_length_threshold, 9);
        assert_eq!(py_penalty.other_penalty_length_penalty, -1);
    }

    #[test]
    fn test_pypenalty_setters_update_getters() {
        let mut py_penalty = PyPenalty::new(None, None, None, None);
        py_penalty.set_kanji_penalty_length_threshold(4);
        py_penalty.set_kanji_penalty_length_penalty(2500);
        py_penalty.set_other_penalty_length_threshold(6);
        py_penalty.set_other_penalty_length_penalty(1800);
        assert_eq!(py_penalty.get_kanji_penalty_length_threshold(), 4);
        assert_eq!(py_penalty.get_kanji_penalty_length_penalty(), 2500);
        assert_eq!(py_penalty.get_other_penalty_length_threshold(), 6);
        assert_eq!(py_penalty.get_other_penalty_length_penalty(), 1800);
    }

    #[test]
    fn test_pypenalty_roundtrip() {
        let original = PyPenalty {
            kanji_penalty_length_threshold: 4,
            kanji_penalty_length_penalty: 2500,
            other_penalty_length_threshold: 6,
            other_penalty_length_penalty: 1800,
        };
        let lindera: LinderaPenalty = original.into();
        let roundtripped: PyPenalty = lindera.into();
        assert_eq!(roundtripped.kanji_penalty_length_threshold, 4);
        assert_eq!(roundtripped.kanji_penalty_length_penalty, 2500);
        assert_eq!(roundtripped.other_penalty_length_threshold, 6);
        assert_eq!(roundtripped.other_penalty_length_penalty, 1800);
    }
}