lindera_dictionary/
mode.rs1use std::str::FromStr;
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::{LinderaError, LinderaErrorKind};
6use crate::viterbi::Edge;
7
8#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
9pub struct Penalty {
10 pub kanji_penalty_length_threshold: usize,
11 pub kanji_penalty_length_penalty: i32,
12 pub other_penalty_length_threshold: usize,
13 pub other_penalty_length_penalty: i32,
14}
15
16impl Default for Penalty {
17 fn default() -> Self {
18 Penalty {
19 kanji_penalty_length_threshold: 2,
20 kanji_penalty_length_penalty: 3000,
21 other_penalty_length_threshold: 7,
22 other_penalty_length_penalty: 1700,
23 }
24 }
25}
26
27impl Penalty {
28 #[inline]
29 pub fn penalty(&self, edge: &Edge) -> i32 {
30 let num_chars = edge.num_chars();
31 if num_chars <= self.kanji_penalty_length_threshold {
32 return 0;
33 }
34 if edge.kanji_only {
35 ((num_chars - self.kanji_penalty_length_threshold) as i32)
36 * self.kanji_penalty_length_penalty
37 } else if num_chars > self.other_penalty_length_threshold {
38 ((num_chars - self.other_penalty_length_threshold) as i32)
39 * self.other_penalty_length_penalty
40 } else {
41 0
42 }
43 }
44}
45
46#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
47pub enum Mode {
48 #[serde(rename = "normal")]
49 Normal,
50 #[serde(rename = "decompose")]
51 Decompose(Penalty),
52}
53
54impl Mode {
55 #[inline]
56 pub fn is_search(&self) -> bool {
57 match self {
58 Mode::Normal => false,
59 Mode::Decompose(_penalty) => true,
60 }
61 }
62
63 #[inline]
64 pub fn penalty_cost(&self, edge: &Edge) -> i32 {
65 match self {
66 Mode::Normal => 0i32,
67 Mode::Decompose(penalty) => penalty.penalty(edge),
68 }
69 }
70
71 pub fn as_str(&self) -> &str {
72 match self {
73 Mode::Normal => "normal",
74 Mode::Decompose(_penalty) => "decompose",
75 }
76 }
77}
78
79impl FromStr for Mode {
80 type Err = LinderaError;
81 fn from_str(mode: &str) -> Result<Mode, Self::Err> {
82 match mode {
83 "normal" => Ok(Mode::Normal),
84 "decompose" => Ok(Mode::Decompose(Penalty::default())),
85 _ => Err(LinderaErrorKind::Mode.with_error(anyhow::anyhow!("Invalid mode: {mode}"))),
86 }
87 }
88}