use std::collections::{HashMap, HashSet};
use std::fmt;
use daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder};
use crate::json_structs::AddedTokenConfig;
/// Matcher for a tokenizer's added tokens: literal strings that are split out of the input
/// verbatim, each as a single token id.
///
/// Built with [`AddedTokens::from_configs`]. Matching is leftmost-longest: at the earliest
/// position where any added token starts, the longest token starting there wins.
pub struct AddedTokens {
    /// Aho-Corasick automaton over every token content; each pattern's value is its id.
    daac: DoubleArrayAhoCorasick<u32>,
    /// Distinct first bytes of all token contents, ascending. When there are one to three of
    /// them, `split` uses `memchr` to jump between candidate positions instead of a full scan.
    start_bytes: Vec<u8>,
    /// Byte length of the longest token content; bounds the window searched per candidate.
    max_token_len: usize,
    /// Token id -> token content (the last config wins when an id repeats).
    id_to_content: HashMap<u32, String>,
    /// Token content -> token id.
    content_to_id: HashMap<String, u32>,
    /// Ids of tokens whose config was marked `special`.
    special_ids: HashSet<u32>,
}

/// One piece of the input produced by [`AddedTokens::split`].
#[derive(Debug, PartialEq, Eq)]
pub enum Segment<'a> {
    /// An added token that matched; carries its token id.
    Token(u32),
    /// A non-empty run of input text between matches, borrowed from the input.
    Text(&'a str),
}

/// Borrowed view of one added token, yielded by [`AddedTokens::iter`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct AddedTokenInfo<'a> {
    /// Token id.
    pub id: u32,
    /// Literal text the token matches.
    pub content: &'a str,
    /// Whether the token's config was marked `special`.
    pub special: bool,
}

impl AddedTokens {
    /// Builds a matcher from added-token configuration entries.
    ///
    /// Returns `Ok(None)` when `configs` is empty, since there is nothing to match. When
    /// several entries share an id, the last one determines that id's content.
    ///
    /// # Errors
    ///
    /// Returns the automaton builder's error, formatted into a `String`, when the token
    /// contents cannot be compiled into an Aho-Corasick automaton.
    pub fn from_configs(configs: &[AddedTokenConfig]) -> Result<Option<Self>, String> {
        if configs.is_empty() {
            return Ok(None);
        }

        let patterns: Vec<(&str, u32)> = configs
            .iter()
            .map(|c| (c.content.as_str(), c.id))
            .collect();
        let daac = DoubleArrayAhoCorasickBuilder::new()
            .match_kind(daachorse::MatchKind::LeftmostLongest)
            .build_with_values(patterns)
            .map_err(|e| format!("error building added-tokens DAAC: {e}"))?;

        // One pass fills every lookup table. Nothing here is sized by the token ids
        // themselves, so arbitrarily large ids (up to u32::MAX) cost nothing extra.
        let mut id_to_content = HashMap::with_capacity(configs.len());
        let mut content_to_id = HashMap::with_capacity(configs.len());
        let mut special_ids = HashSet::new();
        let mut start_set = [false; 256];
        let mut max_token_len = 0;
        for c in configs {
            id_to_content.insert(c.id, c.content.clone());
            content_to_id.insert(c.content.clone(), c.id);
            if c.special {
                special_ids.insert(c.id);
            }
            if let Some(&b) = c.content.as_bytes().first() {
                start_set[usize::from(b)] = true;
            }
            max_token_len = max_token_len.max(c.content.len());
        }
        let start_bytes: Vec<u8> = (0..=u8::MAX)
            .filter(|&b| start_set[usize::from(b)])
            .collect();

        Ok(Some(Self {
            daac,
            start_bytes,
            max_token_len,
            id_to_content,
            content_to_id,
            special_ids,
        }))
    }

    /// Returns the content of the added token with `id`, if there is one.
    pub fn id_to_token(&self, id: u32) -> Option<&str> {
        self.id_to_content.get(&id).map(String::as_str)
    }

    /// Returns the id of the added token whose content is exactly `token`, if there is one.
    pub fn token_to_id(&self, token: &str) -> Option<u32> {
        self.content_to_id.get(token).copied()
    }

    /// Returns `true` if `id` belongs to an added token marked `special`.
    pub fn is_special(&self, id: u32) -> bool {
        self.special_ids.contains(&id)
    }

    /// Number of distinct added-token ids.
    pub fn len(&self) -> usize {
        self.id_to_content.len()
    }

    /// Returns `true` if no added tokens are registered.
    pub fn is_empty(&self) -> bool {
        self.id_to_content.is_empty()
    }

    /// Iterates over every added token. The order is unspecified (hash-map order).
    pub fn iter(&self) -> impl Iterator<Item = AddedTokenInfo<'_>> {
        self.id_to_content
            .iter()
            .map(|(&id, content)| AddedTokenInfo {
                id,
                content: content.as_str(),
                special: self.special_ids.contains(&id),
            })
    }

    /// Splits `input` into added-token matches and the text between them.
    ///
    /// Matches are leftmost-longest and non-overlapping. All input outside the matches is
    /// returned, in order, as non-empty `Text` segments. An empty input yields no segments.
    pub fn split<'a>(&self, input: &'a str) -> Vec<Segment<'a>> {
        let haystack = input.as_bytes();
        match self.start_bytes.as_slice() {
            &[b0] => self.split_prefilter(input, memchr::memchr_iter(b0, haystack)),
            &[b0, b1] => self.split_prefilter(input, memchr::memchr2_iter(b0, b1, haystack)),
            &[b0, b1, b2] => {
                self.split_prefilter(input, memchr::memchr3_iter(b0, b1, b2, haystack))
            }
            // memchr's searchers handle at most three needle bytes; otherwise let the
            // automaton scan the whole input.
            _ => self.split_full_scan(input),
        }
    }

    /// `split` driven by a prefilter. `candidates` must yield, in ascending order, every byte
    /// offset of `input` that holds one of `self.start_bytes`. Only those offsets are tried as
    /// match starts.
    fn split_prefilter<'a>(
        &self,
        input: &'a str,
        candidates: impl Iterator<Item = usize>,
    ) -> Vec<Segment<'a>> {
        let mut segments = Vec::new();
        let mut prev_end = 0;
        for pos in candidates {
            // Skip candidates swallowed by the previous match.
            if pos < prev_end {
                continue;
            }
            // Any token starting at `pos` fits in `max_token_len` bytes. Widen the window to a
            // char boundary so the slice below cannot panic. `pos` itself is a boundary:
            // start bytes are first bytes of UTF-8 strings, never continuation bytes.
            let mut window_end = (pos + self.max_token_len).min(input.len());
            while window_end < input.len() && !input.is_char_boundary(window_end) {
                window_end += 1;
            }
            let window = &input[pos..window_end];
            // With leftmost-longest matching, if any token starts at offset 0 of the window,
            // the first match reported is the longest such token.
            if let Some(m) = self.daac.leftmost_find_iter(window).next()
                && m.start() == 0
            {
                if pos > prev_end {
                    segments.push(Segment::Text(&input[prev_end..pos]));
                }
                segments.push(Segment::Token(m.value()));
                prev_end = pos + m.end();
            }
        }
        // With no match at all, `prev_end` is still 0, so this also covers the
        // "whole input is text" case.
        if prev_end < input.len() {
            segments.push(Segment::Text(&input[prev_end..]));
        }
        segments
    }

    /// `split` by running the automaton over the entire input.
    fn split_full_scan<'a>(&self, input: &'a str) -> Vec<Segment<'a>> {
        let mut segments = Vec::new();
        let mut prev_end = 0;
        for m in self.daac.leftmost_find_iter(input) {
            if m.start() > prev_end {
                segments.push(Segment::Text(&input[prev_end..m.start()]));
            }
            segments.push(Segment::Token(m.value()));
            prev_end = m.end();
        }
        if prev_end < input.len() {
            segments.push(Segment::Text(&input[prev_end..]));
        }
        segments
    }
}

impl fmt::Debug for AddedTokens {
    /// Prints only a token count; the automaton and lookup tables are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Ids whose (last-written) content is non-empty.
        let count = self
            .id_to_content
            .values()
            .filter(|content| !content.is_empty())
            .count();
        f.debug_struct("AddedTokens")
            .field("count", &count)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-special added-token config with every optional flag off.
    fn make_config(id: u32, content: &str) -> AddedTokenConfig {
        AddedTokenConfig {
            id,
            content: content.to_string(),
            single_word: false,
            lstrip: false,
            rstrip: false,
            normalized: false,
            special: false,
        }
    }

    /// Builds a matcher from non-empty configs, panicking if construction fails.
    fn build(configs: &[AddedTokenConfig]) -> AddedTokens {
        AddedTokens::from_configs(configs)
            .expect("added-token automaton should build")
            .expect("non-empty configs should produce a matcher")
    }

    #[test]
    fn empty_configs() {
        let result = AddedTokens::from_configs(&[]).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn no_match() {
        let at = build(&[make_config(100, "<special>")]);
        assert_eq!(at.split("hello world"), vec![Segment::Text("hello world")]);
    }

    #[test]
    fn single_match_at_start() {
        let at = build(&[make_config(100, "<s>")]);
        assert_eq!(
            at.split("<s>hello"),
            vec![Segment::Token(100), Segment::Text("hello")]
        );
    }

    #[test]
    fn single_match_at_end() {
        let at = build(&[make_config(100, "</s>")]);
        assert_eq!(
            at.split("hello</s>"),
            vec![Segment::Text("hello"), Segment::Token(100)]
        );
    }

    #[test]
    fn match_in_middle() {
        let at = build(&[make_config(42, "<sep>")]);
        assert_eq!(
            at.split("hello<sep>world"),
            vec![
                Segment::Text("hello"),
                Segment::Token(42),
                Segment::Text("world"),
            ]
        );
    }

    #[test]
    fn multiple_matches() {
        let at = build(&[make_config(1, "<a>"), make_config(2, "<b>")]);
        assert_eq!(
            at.split("x<a>y<b>z"),
            vec![
                Segment::Text("x"),
                Segment::Token(1),
                Segment::Text("y"),
                Segment::Token(2),
                Segment::Text("z"),
            ]
        );
    }

    #[test]
    fn adjacent_matches() {
        let at = build(&[make_config(1, "<a>"), make_config(2, "<b>")]);
        assert_eq!(
            at.split("<a><b>"),
            vec![Segment::Token(1), Segment::Token(2)]
        );
    }

    #[test]
    fn longest_match_wins() {
        let at = build(&[make_config(1, "<file>"), make_config(2, "<filename>")]);
        assert_eq!(
            at.split("a<filename>b"),
            vec![Segment::Text("a"), Segment::Token(2), Segment::Text("b")]
        );
    }

    #[test]
    fn entire_input_is_added_token() {
        let at = build(&[make_config(99, "hello")]);
        assert_eq!(at.split("hello"), vec![Segment::Token(99)]);
    }

    #[test]
    fn empty_input() {
        let at = build(&[make_config(1, "<s>")]);
        assert!(at.split("").is_empty());
    }

    #[test]
    fn token_to_id_finds_added_token() {
        let at = build(&[make_config(42, "<special>")]);
        assert_eq!(at.token_to_id("<special>"), Some(42));
    }

    #[test]
    fn token_to_id_returns_none_for_unknown() {
        let at = build(&[make_config(1, "<known>")]);
        assert_eq!(at.token_to_id("<unknown>"), None);
    }

    #[test]
    fn token_to_id_and_id_to_token_are_inverses() {
        let configs = vec![
            make_config(10, "<bos>"),
            make_config(11, "<eos>"),
            make_config(12, "<pad>"),
        ];
        let at = build(&configs);
        for cfg in &configs {
            let id = at.token_to_id(&cfg.content).unwrap();
            assert_eq!(id, cfg.id);
            assert_eq!(at.id_to_token(id), Some(cfg.content.as_str()));
        }
    }

    #[test]
    fn unicode_token_content() {
        let at = build(&[
            make_config(1, "▁"),
            make_config(2, "Ġ"),
            make_config(3, "日本語"),
        ]);
        assert_eq!(
            at.split("▁hello"),
            vec![Segment::Token(1), Segment::Text("hello")]
        );
        assert_eq!(
            at.split("Ġworld"),
            vec![Segment::Token(2), Segment::Text("world")]
        );
        assert_eq!(
            at.split("日本語text"),
            vec![Segment::Token(3), Segment::Text("text")]
        );
        assert_eq!(at.token_to_id("▁"), Some(1));
        assert_eq!(at.token_to_id("Ġ"), Some(2));
        assert_eq!(at.token_to_id("日本語"), Some(3));
    }

    #[test]
    fn emoji_token_content() {
        let at = build(&[make_config(7, "🌍")]);
        assert_eq!(
            at.split("hello 🌍 world"),
            vec![
                Segment::Text("hello "),
                Segment::Token(7),
                Segment::Text(" world"),
            ]
        );
    }

    #[test]
    fn is_special_only_for_marked_tokens() {
        let mut special = make_config(1, "<bos>");
        special.special = true;
        let non_special = make_config(2, "<extra>");
        let at = build(&[special, non_special]);
        assert!(at.is_special(1));
        assert!(!at.is_special(2));
        assert!(!at.is_special(99));
    }

    #[test]
    fn iter_exposes_id_content_and_special_flag() {
        let mut special = make_config(1, "<bos>");
        special.special = true;
        let plain = make_config(2, "<extra>");
        let at = build(&[special, plain]);
        let mut entries: Vec<_> = at.iter().collect();
        entries.sort_by_key(|entry| entry.id);
        assert_eq!(
            entries,
            vec![
                AddedTokenInfo {
                    id: 1,
                    content: "<bos>",
                    special: true,
                },
                AddedTokenInfo {
                    id: 2,
                    content: "<extra>",
                    special: false,
                },
            ]
        );
    }

    #[test]
    fn len_returns_token_count() {
        let at = build(&[
            make_config(1, "<a>"),
            make_config(2, "<b>"),
            make_config(3, "<c>"),
        ]);
        assert_eq!(at.len(), 3);
        assert!(!at.is_empty());
    }

    #[test]
    fn three_tokens_with_shared_start_byte() {
        let at = build(&[
            make_config(1, "<"),
            make_config(2, "<s>"),
            make_config(3, "<sep>"),
        ]);
        assert_eq!(
            at.split("x<sep>y<s>z<"),
            vec![
                Segment::Text("x"),
                Segment::Token(3),
                Segment::Text("y"),
                Segment::Token(2),
                Segment::Text("z"),
                Segment::Token(1),
            ]
        );
    }

    #[test]
    fn four_distinct_start_bytes_uses_full_scan() {
        let at = build(&[
            make_config(1, "<bos>"),
            make_config(2, "[SEP]"),
            make_config(3, "{pad}"),
            make_config(4, "|mask|"),
        ]);
        assert_eq!(
            at.split("<bos>[SEP]{pad}|mask|"),
            vec![
                Segment::Token(1),
                Segment::Token(2),
                Segment::Token(3),
                Segment::Token(4),
            ]
        );
    }

    #[test]
    fn token_surrounded_by_text() {
        let at = build(&[make_config(5, "<mid>")]);
        assert_eq!(
            at.split("prefix <mid> suffix"),
            vec![
                Segment::Text("prefix "),
                Segment::Token(5),
                Segment::Text(" suffix"),
            ]
        );
    }

    #[test]
    fn repeated_same_token() {
        let at = build(&[make_config(9, "<r>")]);
        assert_eq!(
            at.split("<r><r><r>"),
            vec![Segment::Token(9), Segment::Token(9), Segment::Token(9)]
        );
    }

    #[test]
    fn start_byte_without_match_stays_text() {
        // '<' is a prefilter candidate but does not start "<sep>" at its first occurrence.
        let at = build(&[make_config(4, "<sep>")]);
        assert_eq!(
            at.split("a < b <sep> c"),
            vec![
                Segment::Text("a < b "),
                Segment::Token(4),
                Segment::Text(" c"),
            ]
        );
    }

    #[test]
    fn window_widened_to_char_boundary() {
        // The 5-byte window starting at '<' ends inside '本'; it must be widened, not sliced.
        let at = build(&[make_config(1, "<tok>")]);
        assert_eq!(at.split("<日本語"), vec![Segment::Text("<日本語")]);
        assert_eq!(
            at.split("<日本語<tok>"),
            vec![Segment::Text("<日本語"), Segment::Token(1)]
        );
    }

    #[test]
    fn prefilter_agrees_with_full_scan() {
        let at = build(&[
            make_config(1, "<"),
            make_config(2, "<s>"),
            make_config(3, "<sep>"),
        ]);
        for input in ["", "plain", "x<sep>y<s>z<", "<<s<se<sep", "<s><sep><<", "日<本s>"] {
            assert_eq!(at.split(input), at.split_full_scan(input), "input: {input:?}");
        }
    }

    #[test]
    fn debug_reports_token_count() {
        let at = build(&[make_config(1, "<a>"), make_config(2, "<b>")]);
        assert_eq!(format!("{at:?}"), "AddedTokens { count: 2 }");
    }
}