Skip to main content

lindera_ruby/
mode.rs

1//! Tokenization modes and penalty configurations.
2//!
3//! This module defines the different tokenization modes available and their
4//! penalty configurations for controlling segmentation behavior.
5
6use magnus::prelude::*;
7use magnus::{Error, Ruby, function, method};
8
9use lindera::mode::{Mode as LinderaMode, Penalty as LinderaPenalty};
10
11/// Tokenization mode.
12///
13/// Determines how text is segmented into tokens.
14#[magnus::wrap(class = "Lindera::Mode", free_immediately, size)]
15#[derive(Debug, Clone, Copy)]
16pub struct RbMode {
17    /// Internal mode variant.
18    inner: RbModeKind,
19}
20
/// The set of modes an `RbMode` can represent.
///
/// Kept private: callers interact with the wrapping `RbMode` instead.
#[derive(Clone, Copy, Debug)]
enum RbModeKind {
    /// Plain tokenization driven by dictionary cost.
    Normal,
    /// Compound words are split further using penalty-based segmentation.
    Decompose,
}
29
30impl RbMode {
31    /// Creates a new `RbMode` from a mode string.
32    ///
33    /// # Arguments
34    ///
35    /// * `mode_str` - Mode string ("normal" or "decompose"). Defaults to "normal" if None.
36    ///
37    /// # Returns
38    ///
39    /// A new `RbMode` instance.
40    fn new(mode_str: Option<String>) -> Result<Self, Error> {
41        let ruby = Ruby::get().expect("Ruby runtime not initialized");
42        let kind = match mode_str.as_deref() {
43            Some("decompose") | Some("Decompose") => RbModeKind::Decompose,
44            Some("normal") | Some("Normal") | None => RbModeKind::Normal,
45            Some(s) => {
46                return Err(Error::new(
47                    ruby.exception_arg_error(),
48                    format!("Invalid mode: {s}. Must be 'normal' or 'decompose'"),
49                ));
50            }
51        };
52        Ok(Self { inner: kind })
53    }
54
55    /// Returns the string representation of the mode.
56    ///
57    /// # Returns
58    ///
59    /// A string slice representing the mode.
60    #[allow(clippy::wrong_self_convention)]
61    fn to_s(&self) -> &str {
62        match self.inner {
63            RbModeKind::Normal => "normal",
64            RbModeKind::Decompose => "decompose",
65        }
66    }
67
68    /// Returns the inspect representation of the mode.
69    ///
70    /// # Returns
71    ///
72    /// A string with the mode inspect representation.
73    fn inspect(&self) -> String {
74        format!("#<Lindera::Mode: {}>", self.to_s())
75    }
76
77    /// Returns the name of the mode.
78    ///
79    /// # Returns
80    ///
81    /// A string slice with the mode name.
82    fn name(&self) -> &str {
83        self.to_s()
84    }
85
86    /// Returns true if the mode is normal.
87    ///
88    /// # Returns
89    ///
90    /// A boolean indicating if this is normal mode.
91    fn is_normal(&self) -> bool {
92        matches!(self.inner, RbModeKind::Normal)
93    }
94
95    /// Returns true if the mode is decompose.
96    ///
97    /// # Returns
98    ///
99    /// A boolean indicating if this is decompose mode.
100    fn is_decompose(&self) -> bool {
101        matches!(self.inner, RbModeKind::Decompose)
102    }
103}
104
105impl From<RbMode> for LinderaMode {
106    fn from(mode: RbMode) -> Self {
107        match mode.inner {
108            RbModeKind::Normal => LinderaMode::Normal,
109            RbModeKind::Decompose => LinderaMode::Decompose(LinderaPenalty::default()),
110        }
111    }
112}
113
114impl From<LinderaMode> for RbMode {
115    fn from(mode: LinderaMode) -> Self {
116        let kind = match mode {
117            LinderaMode::Normal => RbModeKind::Normal,
118            LinderaMode::Decompose(_) => RbModeKind::Decompose,
119        };
120        RbMode { inner: kind }
121    }
122}
123
124/// Penalty configuration for decompose mode.
125///
126/// Controls how aggressively compound words are decomposed based on
127/// character type and length thresholds.
128#[magnus::wrap(class = "Lindera::Penalty", free_immediately, size)]
129#[derive(Debug, Clone, Copy)]
130pub struct RbPenalty {
131    /// Length threshold for kanji penalty.
132    kanji_penalty_length_threshold: usize,
133    /// Penalty value for kanji.
134    kanji_penalty_length_penalty: i32,
135    /// Length threshold for other character penalty.
136    other_penalty_length_threshold: usize,
137    /// Penalty value for other characters.
138    other_penalty_length_penalty: i32,
139}
140
141impl RbPenalty {
142    /// Creates a new `RbPenalty` with optional parameters.
143    ///
144    /// # Arguments
145    ///
146    /// * `kanji_threshold` - Kanji penalty length threshold (default: 2).
147    /// * `kanji_penalty` - Kanji penalty value (default: 3000).
148    /// * `other_threshold` - Other penalty length threshold (default: 7).
149    /// * `other_penalty` - Other penalty value (default: 1700).
150    ///
151    /// # Returns
152    ///
153    /// A new `RbPenalty` instance.
154    fn new(
155        kanji_threshold: Option<usize>,
156        kanji_penalty: Option<i32>,
157        other_threshold: Option<usize>,
158        other_penalty: Option<i32>,
159    ) -> Self {
160        Self {
161            kanji_penalty_length_threshold: kanji_threshold.unwrap_or(2),
162            kanji_penalty_length_penalty: kanji_penalty.unwrap_or(3000),
163            other_penalty_length_threshold: other_threshold.unwrap_or(7),
164            other_penalty_length_penalty: other_penalty.unwrap_or(1700),
165        }
166    }
167
168    /// Returns the kanji penalty length threshold.
169    fn kanji_penalty_length_threshold(&self) -> usize {
170        self.kanji_penalty_length_threshold
171    }
172
173    /// Returns the kanji penalty length penalty.
174    fn kanji_penalty_length_penalty(&self) -> i32 {
175        self.kanji_penalty_length_penalty
176    }
177
178    /// Returns the other penalty length threshold.
179    fn other_penalty_length_threshold(&self) -> usize {
180        self.other_penalty_length_threshold
181    }
182
183    /// Returns the other penalty length penalty.
184    fn other_penalty_length_penalty(&self) -> i32 {
185        self.other_penalty_length_penalty
186    }
187
188    /// Returns a string representation of the penalty configuration.
189    #[allow(clippy::wrong_self_convention)]
190    fn to_s(&self) -> String {
191        format!(
192            "Penalty(kanji_threshold={}, kanji_penalty={}, other_threshold={}, other_penalty={})",
193            self.kanji_penalty_length_threshold,
194            self.kanji_penalty_length_penalty,
195            self.other_penalty_length_threshold,
196            self.other_penalty_length_penalty
197        )
198    }
199
200    /// Returns the inspect representation of the penalty.
201    fn inspect(&self) -> String {
202        format!("#<Lindera::Penalty: {}>", self.to_s())
203    }
204}
205
206impl From<RbPenalty> for LinderaPenalty {
207    fn from(penalty: RbPenalty) -> Self {
208        LinderaPenalty {
209            kanji_penalty_length_threshold: penalty.kanji_penalty_length_threshold,
210            kanji_penalty_length_penalty: penalty.kanji_penalty_length_penalty,
211            other_penalty_length_threshold: penalty.other_penalty_length_threshold,
212            other_penalty_length_penalty: penalty.other_penalty_length_penalty,
213        }
214    }
215}
216
217impl From<LinderaPenalty> for RbPenalty {
218    fn from(penalty: LinderaPenalty) -> Self {
219        RbPenalty {
220            kanji_penalty_length_threshold: penalty.kanji_penalty_length_threshold,
221            kanji_penalty_length_penalty: penalty.kanji_penalty_length_penalty,
222            other_penalty_length_threshold: penalty.other_penalty_length_threshold,
223            other_penalty_length_penalty: penalty.other_penalty_length_penalty,
224        }
225    }
226}
227
228/// Defines Mode and Penalty classes in the given Ruby module.
229///
230/// # Arguments
231///
232/// * `ruby` - Ruby runtime handle.
233/// * `module` - Parent Ruby module.
234///
235/// # Returns
236///
237/// `Ok(())` on success, or a Magnus `Error` on failure.
238pub fn define(ruby: &Ruby, module: &magnus::RModule) -> Result<(), Error> {
239    let mode_class = module.define_class("Mode", ruby.class_object())?;
240    mode_class.define_singleton_method("new", function!(RbMode::new, 1))?;
241    mode_class.define_method("to_s", method!(RbMode::to_s, 0))?;
242    mode_class.define_method("inspect", method!(RbMode::inspect, 0))?;
243    mode_class.define_method("name", method!(RbMode::name, 0))?;
244    mode_class.define_method("normal?", method!(RbMode::is_normal, 0))?;
245    mode_class.define_method("decompose?", method!(RbMode::is_decompose, 0))?;
246
247    let penalty_class = module.define_class("Penalty", ruby.class_object())?;
248    penalty_class.define_singleton_method("new", function!(RbPenalty::new, 4))?;
249    penalty_class.define_method(
250        "kanji_penalty_length_threshold",
251        method!(RbPenalty::kanji_penalty_length_threshold, 0),
252    )?;
253    penalty_class.define_method(
254        "kanji_penalty_length_penalty",
255        method!(RbPenalty::kanji_penalty_length_penalty, 0),
256    )?;
257    penalty_class.define_method(
258        "other_penalty_length_threshold",
259        method!(RbPenalty::other_penalty_length_threshold, 0),
260    )?;
261    penalty_class.define_method(
262        "other_penalty_length_penalty",
263        method!(RbPenalty::other_penalty_length_penalty, 0),
264    )?;
265    penalty_class.define_method("to_s", method!(RbPenalty::to_s, 0))?;
266    penalty_class.define_method("inspect", method!(RbPenalty::inspect, 0))?;
267
268    Ok(())
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two Lindera penalties agree on each of their four fields.
    fn assert_same_penalty(actual: &LinderaPenalty, expected: &LinderaPenalty) {
        assert_eq!(
            actual.kanji_penalty_length_threshold,
            expected.kanji_penalty_length_threshold
        );
        assert_eq!(
            actual.kanji_penalty_length_penalty,
            expected.kanji_penalty_length_penalty
        );
        assert_eq!(
            actual.other_penalty_length_threshold,
            expected.other_penalty_length_threshold
        );
        assert_eq!(
            actual.other_penalty_length_penalty,
            expected.other_penalty_length_penalty
        );
    }

    #[test]
    fn test_rb_mode_normal_to_lindera_mode() {
        let converted = LinderaMode::from(RbMode {
            inner: RbModeKind::Normal,
        });
        assert!(matches!(converted, LinderaMode::Normal));
    }

    #[test]
    fn test_rb_mode_decompose_to_lindera_mode() {
        let converted = LinderaMode::from(RbMode {
            inner: RbModeKind::Decompose,
        });
        assert!(matches!(converted, LinderaMode::Decompose(_)));
        // The conversion is expected to attach Lindera's default penalty.
        if let LinderaMode::Decompose(attached) = converted {
            assert_same_penalty(&attached, &LinderaPenalty::default());
        }
    }

    #[test]
    fn test_lindera_mode_normal_to_rb_mode() {
        let converted = RbMode::from(LinderaMode::Normal);
        assert!(matches!(converted.inner, RbModeKind::Normal));
    }

    #[test]
    fn test_lindera_mode_decompose_to_rb_mode() {
        let converted = RbMode::from(LinderaMode::Decompose(LinderaPenalty::default()));
        assert!(matches!(converted.inner, RbModeKind::Decompose));
    }

    #[test]
    fn test_rb_penalty_to_lindera_penalty() {
        let converted = LinderaPenalty::from(RbPenalty {
            kanji_penalty_length_threshold: 3,
            kanji_penalty_length_penalty: 5000,
            other_penalty_length_threshold: 10,
            other_penalty_length_penalty: 2500,
        });
        assert_eq!(converted.kanji_penalty_length_threshold, 3);
        assert_eq!(converted.kanji_penalty_length_penalty, 5000);
        assert_eq!(converted.other_penalty_length_threshold, 10);
        assert_eq!(converted.other_penalty_length_penalty, 2500);
    }

    #[test]
    fn test_lindera_penalty_to_rb_penalty() {
        let converted = RbPenalty::from(LinderaPenalty {
            kanji_penalty_length_threshold: 4,
            kanji_penalty_length_penalty: 6000,
            other_penalty_length_threshold: 8,
            other_penalty_length_penalty: 1500,
        });
        assert_eq!(converted.kanji_penalty_length_threshold, 4);
        assert_eq!(converted.kanji_penalty_length_penalty, 6000);
        assert_eq!(converted.other_penalty_length_threshold, 8);
        assert_eq!(converted.other_penalty_length_penalty, 1500);
    }

    #[test]
    fn test_rb_penalty_default_values_roundtrip() {
        let expected = LinderaPenalty::default();
        let through_ruby = RbPenalty::from(expected.clone());
        let roundtripped = LinderaPenalty::from(through_ruby);
        assert_same_penalty(&roundtripped, &expected);
    }

    #[test]
    fn test_mode_roundtrip_normal() {
        let back = LinderaMode::from(RbMode::from(LinderaMode::Normal));
        assert!(matches!(back, LinderaMode::Normal));
    }

    #[test]
    fn test_mode_roundtrip_decompose() {
        let original = LinderaMode::Decompose(LinderaPenalty {
            kanji_penalty_length_threshold: 5,
            kanji_penalty_length_penalty: 4000,
            other_penalty_length_threshold: 9,
            other_penalty_length_penalty: 2000,
        });
        let back = LinderaMode::from(RbMode::from(original));
        // RbMode does not retain the penalty, so only the variant is checked.
        assert!(matches!(back, LinderaMode::Decompose(_)));
    }
}
401}