1use regex::Regex;
21use serde::{Deserialize, Serialize};
22use std::sync::OnceLock;
23
24use crate::byte_encoder::METASPACE;
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
30#[serde(tag = "op", rename_all = "snake_case")]
31pub enum PreTokOp {
32 LiteralsCi { patterns: Vec<String> },
34 Literals { patterns: Vec<String> },
37 Letters {
42 #[serde(default, skip_serializing_if = "Option::is_none")]
43 lead_other: Option<bool>,
44 #[serde(default, skip_serializing_if = "Option::is_none")]
45 lead_space: Option<bool>,
46 },
47 Numbers {
50 #[serde(default, skip_serializing_if = "Option::is_none")]
51 max_run: Option<u32>,
52 #[serde(default, skip_serializing_if = "Option::is_none")]
53 lead_space: Option<bool>,
54 },
55 PunctRun {
58 #[serde(default, skip_serializing_if = "Option::is_none")]
59 lead_space: Option<bool>,
60 #[serde(default, skip_serializing_if = "Option::is_none")]
61 trailing_newlines: Option<bool>,
62 #[serde(default, skip_serializing_if = "Option::is_none")]
66 trailing_chars: Option<String>,
67 },
68 LettersCased {
73 kind: CasedKind,
74 #[serde(default, skip_serializing_if = "Option::is_none")]
75 lead_other: Option<bool>,
76 #[serde(default, skip_serializing_if = "Option::is_none")]
77 trailing_ci: Option<Vec<String>>,
78 },
79 NewlineBlock {},
81 TrailingWs {},
83 WsRun {},
85 MetaspaceSplit {
87 #[serde(default, skip_serializing_if = "Option::is_none")]
88 prefix_first: Option<bool>,
89 },
90}
91
92#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
94#[serde(rename_all = "snake_case")]
95pub enum CasedKind {
96 Title,
98 Upper,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct PreTokProgram {
107 pub version: u32,
108 pub ops: Vec<PreTokOp>,
109}
110
111fn re_letter() -> &'static Regex {
114 static R: OnceLock<Regex> = OnceLock::new();
115 R.get_or_init(|| Regex::new(r"\p{L}").unwrap())
116}
117fn re_number() -> &'static Regex {
118 static R: OnceLock<Regex> = OnceLock::new();
119 R.get_or_init(|| Regex::new(r"\p{N}").unwrap())
120}
121fn re_ws() -> &'static Regex {
122 static R: OnceLock<Regex> = OnceLock::new();
123 R.get_or_init(|| Regex::new(r"\s").unwrap())
124}
125fn re_letter_upper() -> &'static Regex {
126 static R: OnceLock<Regex> = OnceLock::new();
127 R.get_or_init(|| Regex::new(r"[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]").unwrap())
128}
129fn re_letter_lower() -> &'static Regex {
130 static R: OnceLock<Regex> = OnceLock::new();
131 R.get_or_init(|| Regex::new(r"[\p{Ll}\p{Lm}\p{Lo}\p{M}]").unwrap())
132}
133
134fn is_letter(cp: char) -> bool {
135 let mut buf = [0u8; 4];
136 re_letter().is_match(cp.encode_utf8(&mut buf))
137}
138fn is_number(cp: char) -> bool {
139 let mut buf = [0u8; 4];
140 re_number().is_match(cp.encode_utf8(&mut buf))
141}
142fn is_ws(cp: char) -> bool {
143 let mut buf = [0u8; 4];
144 re_ws().is_match(cp.encode_utf8(&mut buf))
145}
146fn is_letter_upper(cp: char) -> bool {
147 let mut buf = [0u8; 4];
148 re_letter_upper().is_match(cp.encode_utf8(&mut buf))
149}
150fn is_letter_lower(cp: char) -> bool {
151 let mut buf = [0u8; 4];
152 re_letter_lower().is_match(cp.encode_utf8(&mut buf))
153}
154
/// Returns the byte length of the longest entry of `patterns` that occurs in
/// `text` at byte offset `i`, treating ASCII letters case-insensitively.
///
/// Only ASCII letters are folded; every other byte (including the bytes of
/// non-ASCII characters) must match exactly. Returns `0` when no pattern
/// matches; empty patterns therefore never contribute a match.
///
/// # Panics
///
/// Panics if `i` is past the end of `text` or not on a character boundary.
fn match_literals_ci(patterns: &[String], text: &str, i: usize) -> usize {
    let rest = text[i..].as_bytes();
    patterns
        .iter()
        .map(|pattern| pattern.as_bytes())
        // `get` yields `None` when the pattern is longer than what is left,
        // which doubles as the "not enough input" check.
        .filter(|pattern| {
            rest.get(..pattern.len())
                .is_some_and(|head| head.eq_ignore_ascii_case(pattern))
        })
        .map(|pattern| pattern.len())
        .max()
        .unwrap_or(0)
}
187
/// Returns the byte length of the longest entry of `patterns` that `text`
/// starts with at byte offset `i`, or `0` when none does.
///
/// The comparison is exact (byte-for-byte). Panics if `i` is past the end of
/// `text` or not on a character boundary.
fn match_literals(patterns: &[String], text: &str, i: usize) -> usize {
    let tail = &text[i..];
    patterns
        .iter()
        .filter(|pattern| tail.starts_with(pattern.as_str()))
        .map(|pattern| pattern.len())
        .max()
        .unwrap_or(0)
}
206
207fn match_letters(lead_other: bool, lead_space: bool, text: &str, i: usize) -> usize {
208 let rest = &text[i..];
209 let mut chars = rest.char_indices().peekable();
210 let mut p = 0usize;
211 if lead_other {
212 if let Some(&(_off, c)) = chars.peek() {
214 if c != '\r' && c != '\n' && !is_letter(c) && !is_number(c) {
215 p = c.len_utf8();
216 chars.next();
217 }
218 }
219 } else if lead_space {
220 if let Some(&(_off, c)) = chars.peek() {
222 if c == ' ' {
223 p = c.len_utf8();
224 chars.next();
225 }
226 }
227 }
228 let run_start = p;
230 while let Some(&(_off, c)) = chars.peek() {
231 if !is_letter(c) {
232 break;
233 }
234 p += c.len_utf8();
235 chars.next();
236 }
237 if p == run_start {
238 0
239 } else {
240 p
241 }
242}
243
244fn match_numbers(max_run: u32, lead_space: bool, text: &str, i: usize) -> usize {
245 let max = if max_run == 0 { u32::MAX } else { max_run };
246 let mut p = 0usize;
247 let bytes = text.as_bytes();
248 if lead_space && i + p < bytes.len() && bytes[i + p] == b' ' {
249 p += 1;
250 }
251 let run_start = p;
252 let mut count = 0u32;
253 for c in text[i + p..].chars() {
254 if count >= max || !is_number(c) {
255 break;
256 }
257 p += c.len_utf8();
258 count += 1;
259 }
260 if p == run_start { 0 } else { p }
261}
262
263fn match_punct_run(
264 lead_space: bool,
265 trailing_newlines: bool,
266 trailing_chars: Option<&str>,
267 text: &str,
268 i: usize,
269) -> usize {
270 let bytes = text.as_bytes();
271 let mut p = i;
272 if lead_space && p < bytes.len() && bytes[p] == b' ' {
273 p += 1;
274 }
275 let run_start = p;
277 for c in text[p..].chars() {
278 if is_ws(c) || is_letter(c) || is_number(c) {
279 break;
280 }
281 p += c.len_utf8();
282 }
283 if p == run_start {
284 return 0;
285 }
286 if let Some(chars) = trailing_chars {
289 loop {
290 let Some(c) = text[p..].chars().next() else { break };
291 if !chars.contains(c) {
292 break;
293 }
294 p += c.len_utf8();
295 }
296 } else if trailing_newlines {
297 while p < bytes.len() && (bytes[p] == b'\n' || bytes[p] == b'\r') {
298 p += 1;
299 }
300 }
301 p - i
302}
303
304fn match_letters_cased(
305 kind: CasedKind,
306 lead_other: bool,
307 trailing_ci: Option<&[String]>,
308 text: &str,
309 i: usize,
310) -> usize {
311 let mut p = i;
312 if lead_other {
313 if let Some(c) = text[p..].chars().next() {
314 if c != '\r' && c != '\n' && !is_letter(c) && !is_number(c) {
315 p += c.len_utf8();
316 }
317 }
318 }
319
320 let mut checkpoints: Vec<usize> = vec![p];
324 while let Some(c) = text[p..].chars().next() {
325 if !is_letter_upper(c) {
326 break;
327 }
328 p += c.len_utf8();
329 checkpoints.push(p);
330 }
331
332 let (min_prefix, min_suffix): (usize, usize) = match kind {
333 CasedKind::Upper => (1, 0),
334 CasedKind::Title => (0, 1),
335 };
336
337 for k in (0..checkpoints.len()).rev() {
339 if k < min_prefix {
340 break;
341 }
342 let mut q = checkpoints[k];
343 let mut suffix_count = 0usize;
344 while let Some(c) = text[q..].chars().next() {
345 if !is_letter_lower(c) {
346 break;
347 }
348 q += c.len_utf8();
349 suffix_count += 1;
350 }
351 if suffix_count < min_suffix {
352 continue;
353 }
354
355 if let Some(patterns) = trailing_ci {
357 let rest = &text[q..];
358 let rest_bytes = rest.as_bytes();
359 let mut best = 0usize;
360 for pat in patterns {
361 if pat.len() <= best || rest.len() < pat.len() {
362 continue;
363 }
364 let p_bytes = pat.as_bytes();
365 let mut ok = true;
366 for k in 0..pat.len() {
367 let a = rest_bytes[k];
368 let b = p_bytes[k];
369 if a == b {
370 continue;
371 }
372 if a.is_ascii_uppercase() && a + 32 == b {
373 continue;
374 }
375 if a.is_ascii_lowercase() && a - 32 == b {
376 continue;
377 }
378 ok = false;
379 break;
380 }
381 if ok {
382 best = pat.len();
383 }
384 }
385 q += best;
386 }
387
388 return q - i;
389 }
390 0
391}
392
393fn match_newline_block(text: &str, i: usize) -> usize {
394 let mut p = 0usize;
397 for c in text[i..].chars() {
398 if !is_ws(c) {
399 break;
400 }
401 p += c.len_utf8();
402 }
403 let bytes = text.as_bytes();
404 let mut first_nl: Option<usize> = None;
406 for q in i..(i + p) {
407 if bytes[q] == b'\n' || bytes[q] == b'\r' {
408 first_nl = Some(q);
409 break;
410 }
411 }
412 let Some(first_nl) = first_nl else { return 0 };
413 let mut q = i + p;
415 while q > first_nl {
416 let c = bytes[q - 1];
417 if c == b'\n' || c == b'\r' {
418 break;
419 }
420 q -= 1;
421 }
422 q - i
423}
424
425fn match_trailing_ws(text: &str, i: usize) -> usize {
426 let mut p = i;
429 for c in text[i..].chars() {
430 if !is_ws(c) {
431 break;
432 }
433 p += c.len_utf8();
434 }
435 if p == i {
436 return 0;
437 }
438 if p == text.len() {
439 return p - i;
440 }
441 let mut q = i;
443 let mut last_start = i;
444 while q < p {
445 last_start = q;
446 let c = text[q..].chars().next().unwrap();
447 q += c.len_utf8();
448 }
449 last_start - i
450}
451
452fn match_ws_run(text: &str, i: usize) -> usize {
453 let mut p = 0usize;
454 for c in text[i..].chars() {
455 if !is_ws(c) {
456 break;
457 }
458 p += c.len_utf8();
459 }
460 p
461}
462
463pub fn run_pretok_program(program: &PreTokProgram, text: &str) -> Vec<String> {
468 if program.ops.len() == 1 {
470 if let PreTokOp::MetaspaceSplit { prefix_first } = &program.ops[0] {
471 return run_metaspace(prefix_first.unwrap_or(false), text);
472 }
473 }
474
475 let mut out: Vec<String> = Vec::new();
476 let bytes = text.as_bytes();
477 let n = bytes.len();
478 let mut i = 0usize;
479 'outer: while i < n {
480 for op in &program.ops {
481 let span = match op {
482 PreTokOp::LiteralsCi { patterns } => match_literals_ci(patterns, text, i),
483 PreTokOp::Literals { patterns } => match_literals(patterns, text, i),
484 PreTokOp::Letters {
485 lead_other,
486 lead_space,
487 } => match_letters(
488 lead_other.unwrap_or(false),
489 lead_space.unwrap_or(false),
490 text,
491 i,
492 ),
493 PreTokOp::Numbers {
494 max_run,
495 lead_space,
496 } => match_numbers(
497 max_run.unwrap_or(0),
498 lead_space.unwrap_or(false),
499 text,
500 i,
501 ),
502 PreTokOp::PunctRun {
503 lead_space,
504 trailing_newlines,
505 trailing_chars,
506 } => match_punct_run(
507 lead_space.unwrap_or(false),
508 trailing_newlines.unwrap_or(false),
509 trailing_chars.as_deref(),
510 text,
511 i,
512 ),
513 PreTokOp::LettersCased {
514 kind,
515 lead_other,
516 trailing_ci,
517 } => match_letters_cased(
518 *kind,
519 lead_other.unwrap_or(false),
520 trailing_ci.as_deref(),
521 text,
522 i,
523 ),
524 PreTokOp::NewlineBlock {} => match_newline_block(text, i),
525 PreTokOp::TrailingWs {} => match_trailing_ws(text, i),
526 PreTokOp::WsRun {} => match_ws_run(text, i),
527 PreTokOp::MetaspaceSplit { .. } => 0, };
529 if span > 0 {
530 out.push(text[i..i + span].to_string());
531 i += span;
532 continue 'outer;
533 }
534 }
535 let c = text[i..].chars().next().unwrap();
537 out.push(c.to_string());
538 i += c.len_utf8();
539 }
540 out
541}
542
543fn run_metaspace(prefix_first: bool, text: &str) -> Vec<String> {
544 let mut out: Vec<String> = Vec::new();
545 let mut buf = String::new();
546 let mut prev_horiz_ws = false;
549 for c in text.chars() {
550 if c == ' ' || c == '\t' {
551 if !prev_horiz_ws {
552 buf.push(' ');
553 prev_horiz_ws = true;
554 }
555 } else {
556 buf.push(c);
557 prev_horiz_ws = false;
558 }
559 }
560 let mut is_first = true;
561 let mut piece = String::new();
562 for c in buf.chars() {
563 if c.is_whitespace() {
564 if !piece.is_empty() {
565 if prefix_first && is_first {
566 out.push(std::mem::take(&mut piece));
567 } else {
568 let mut s = String::with_capacity(piece.len() + 3);
569 s.push(METASPACE);
570 s.push_str(&piece);
571 out.push(s);
572 piece.clear();
573 }
574 is_first = false;
575 }
576 if c == ' ' {
577 is_first = false;
578 }
579 } else {
580 piece.push(c);
581 }
582 }
583 if !piece.is_empty() {
584 if prefix_first && is_first {
585 out.push(piece);
586 } else {
587 let mut s = String::with_capacity(piece.len() + 3);
588 s.push(METASPACE);
589 s.push_str(&piece);
590 out.push(s);
591 }
592 }
593 out
594}
595
#[cfg(test)]
mod tests {
    use super::*;

    /// Contraction suffixes matched case-insensitively by the fixture.
    const CONTRACTIONS: [&str; 7] = ["'s", "'t", "'re", "'ve", "'m", "'ll", "'d"];

    /// Fixture program: contractions, letters with a lead character,
    /// uncapped numbers, punctuation runs, then the three whitespace rules.
    fn qwen_program() -> PreTokProgram {
        let ops = vec![
            PreTokOp::LiteralsCi {
                patterns: CONTRACTIONS.iter().map(|s| s.to_string()).collect(),
            },
            PreTokOp::Letters {
                lead_other: Some(true),
                lead_space: None,
            },
            PreTokOp::Numbers {
                max_run: None,
                lead_space: None,
            },
            PreTokOp::PunctRun {
                lead_space: Some(true),
                trailing_newlines: Some(true),
                trailing_chars: None,
            },
            PreTokOp::NewlineBlock {},
            PreTokOp::TrailingWs {},
            PreTokOp::WsRun {},
        ];
        PreTokProgram { version: 1, ops }
    }

    /// Runs the fixture program over `text`.
    fn split(text: &str) -> Vec<String> {
        run_pretok_program(&qwen_program(), text)
    }

    #[test]
    fn qwen_program_splits_basic_text() {
        let pieces = split("Hello, world!");
        assert_eq!(pieces, vec!["Hello", ",", " world", "!"]);
    }

    #[test]
    fn qwen_program_handles_contractions() {
        let pieces = split("it's");
        assert_eq!(pieces, vec!["it", "'s"]);
    }

    #[test]
    fn qwen_program_unbounded_digits() {
        // With no `max_run`, the whole digit run stays in one piece; the space
        // in front of it is emitted on its own by the whitespace rule.
        let pieces = split("abc 12345 def");
        assert_eq!(pieces, vec!["abc", " ", "12345", " def"]);
    }
}