sentencepiece_rs/
normalizer.rs1use crate::darts::DoubleArray;
2use crate::proto::{NormalizerSpec, TrainerSpec};
3use crate::util::{SPACE_SYMBOL, first_char_len};
4use crate::{Error, Result};
5
6#[derive(Clone, Debug)]
11pub struct Normalizer {
12 spec: NormalizerSpec,
13 treat_whitespace_as_suffix: bool,
14 charsmap: Option<PrecompiledCharsMap>,
15 user_symbols: Vec<String>,
16}
17
18#[derive(Clone, Debug)]
19struct PrecompiledCharsMap {
20 trie: DoubleArray,
21 normalized: Vec<u8>,
22}
23
24impl Normalizer {
25 pub(crate) fn new(spec: NormalizerSpec, trainer_spec: &TrainerSpec) -> Result<Self> {
26 let charsmap = if spec.precompiled_charsmap.is_empty() {
27 None
28 } else {
29 Some(PrecompiledCharsMap::decode(&spec.precompiled_charsmap)?)
30 };
31
32 Ok(Self {
33 spec,
34 treat_whitespace_as_suffix: trainer_spec.treat_whitespace_as_suffix,
35 charsmap,
36 user_symbols: Vec::new(),
37 })
38 }
39
40 pub(crate) fn new_denormalizer(spec: NormalizerSpec) -> Result<Self> {
41 let charsmap = if spec.precompiled_charsmap.is_empty() {
42 None
43 } else {
44 Some(PrecompiledCharsMap::decode(&spec.precompiled_charsmap)?)
45 };
46
47 Ok(Self {
48 spec,
49 treat_whitespace_as_suffix: false,
50 charsmap,
51 user_symbols: Vec::new(),
52 })
53 }
54
55 pub(crate) fn set_user_symbols(&mut self, mut user_symbols: Vec<String>) {
56 user_symbols.sort_by_key(|symbol| std::cmp::Reverse(symbol.len()));
57 self.user_symbols = user_symbols;
58 }
59
60 pub fn normalize(&self, input: &str) -> Result<String> {
62 if input.is_empty() {
63 return Ok(String::new());
64 }
65
66 let mut cursor = 0;
67 let bytes = input.as_bytes();
68
69 if self.spec.remove_extra_whitespaces {
70 while cursor < bytes.len() {
71 let (normalized, consumed) = self.normalize_prefix(&input[cursor..]);
72 if normalized.as_slice() != b" " {
73 break;
74 }
75 cursor += consumed;
76 }
77 }
78
79 if cursor == bytes.len() {
80 return Ok(String::new());
81 }
82
83 let mut output = Vec::with_capacity((bytes.len() - cursor) * 3);
84 let add_ws = |output: &mut Vec<u8>| {
85 if self.spec.escape_whitespaces {
86 output.extend_from_slice(SPACE_SYMBOL.as_bytes());
87 } else {
88 output.push(b' ');
89 }
90 };
91
92 if !self.treat_whitespace_as_suffix && self.spec.add_dummy_prefix {
93 add_ws(&mut output);
94 }
95
96 let mut is_prev_space = self.spec.remove_extra_whitespaces;
97 while cursor < bytes.len() {
98 let (mut normalized, consumed) = self.normalize_prefix(&input[cursor..]);
99
100 while is_prev_space && normalized.first() == Some(&b' ') {
101 normalized.remove(0);
102 }
103
104 if !normalized.is_empty() {
105 for byte in normalized.iter().copied() {
106 if self.spec.escape_whitespaces && byte == b' ' {
107 output.extend_from_slice(SPACE_SYMBOL.as_bytes());
108 } else {
109 output.push(byte);
110 }
111 }
112 is_prev_space = normalized.last() == Some(&b' ');
113 }
114
115 cursor += consumed;
116 if !self.spec.remove_extra_whitespaces {
117 is_prev_space = false;
118 }
119 }
120
121 if self.spec.remove_extra_whitespaces {
122 let suffix = if self.spec.escape_whitespaces {
123 SPACE_SYMBOL.as_bytes()
124 } else {
125 b" "
126 };
127 while output.ends_with(suffix) {
128 let new_len = output.len() - suffix.len();
129 output.truncate(new_len);
130 }
131 }
132
133 if self.treat_whitespace_as_suffix && self.spec.add_dummy_prefix {
134 add_ws(&mut output);
135 }
136
137 String::from_utf8(output)
138 .map_err(|_| Error::model_parse("normalization produced invalid UTF-8"))
139 }
140
141 pub(crate) fn add_dummy_prefix(&self) -> bool {
142 self.spec.add_dummy_prefix
143 }
144
145 pub(crate) fn remove_extra_whitespaces(&self) -> bool {
146 self.spec.remove_extra_whitespaces
147 }
148
149 fn normalize_prefix(&self, input: &str) -> (Vec<u8>, usize) {
150 if input.is_empty() {
151 return (Vec::new(), 0);
152 }
153
154 if let Some(symbol) = self
155 .user_symbols
156 .iter()
157 .find(|symbol| input.as_bytes().starts_with(symbol.as_bytes()))
158 {
159 return (symbol.as_bytes().to_vec(), symbol.len());
160 }
161
162 if let Some(charsmap) = &self.charsmap
163 && let Some((offset, length)) = charsmap.longest_match(input.as_bytes())
164 && let Some(normalized) = charsmap.normalized_at(offset)
165 {
166 return (normalized.to_vec(), length);
167 }
168
169 let len = first_char_len(input);
170 (input.as_bytes()[..len].to_vec(), len)
171 }
172}
173
174impl PrecompiledCharsMap {
175 fn decode(blob: &[u8]) -> Result<Self> {
176 if blob.len() <= 4 {
177 return Err(Error::model_parse("normalization rule blob is broken"));
178 }
179
180 let trie_blob_size = u32::from_le_bytes([blob[0], blob[1], blob[2], blob[3]]) as usize;
181 if trie_blob_size >= blob.len() {
182 return Err(Error::model_parse(
183 "normalization trie data exceeds the input blob size",
184 ));
185 }
186 if trie_blob_size < 1024 || (trie_blob_size & 0x3ff) != 0 {
187 return Err(Error::model_parse(
188 "normalization trie data size is not divisible by 1024",
189 ));
190 }
191
192 let trie_start = 4;
193 let trie_end = trie_start + trie_blob_size;
194 let normalized = blob[trie_end..].to_vec();
195 if normalized.is_empty() || normalized.last() != Some(&0) {
196 return Err(Error::model_parse(
197 "normalization data block must be null-terminated",
198 ));
199 }
200
201 Ok(Self {
202 trie: DoubleArray::from_le_blob(&blob[trie_start..trie_end])?,
203 normalized,
204 })
205 }
206
207 fn longest_match(&self, input: &[u8]) -> Option<(usize, usize)> {
208 self.trie
209 .common_prefix_search(input)
210 .into_iter()
211 .max_by_key(|(_, length)| *length)
212 .map(|(offset, length)| (offset as usize, length))
213 }
214
215 fn normalized_at(&self, offset: usize) -> Option<&[u8]> {
216 if offset >= self.normalized.len() {
217 return None;
218 }
219 let tail = &self.normalized[offset..];
220 let end = tail.iter().position(|byte| *byte == 0)?;
221 Some(&tail[..end])
222 }
223}