use rustc_hash::FxHashMap;
/// One segment of the piece being encoded, stored as an element of an
/// index-linked list that lives inside a `Vec<Node>`.
///
/// `byte_pair_encode` creates one node per input byte and then grows a node
/// by absorbing its successor whenever the two are merged.
#[derive(Debug, Clone, Copy)]
struct Node {
/// Index of the preceding node in the list, or `usize::MAX` when this node
/// is the head.
prev: usize,
/// Index of the following node in the list, or `usize::MAX` when this node
/// is the tail.
next: usize,
/// Encoder rank of the bytes of this node joined with the bytes of `next`.
/// Holds `u32::MAX` when there is no next node or the joined bytes are not
/// a key of the encoder.
rank: u32,
/// Offset into the piece at which this segment begins.
start: usize,
/// Number of bytes of the piece covered by this segment.
len: usize,
}
/// Encodes `piece` into token ranks by repeatedly merging adjacent segments.
///
/// `encoder` maps byte sequences to their rank. The piece is first split into
/// one segment per byte. On every step the adjacent pair whose joined bytes
/// have the lowest rank in `encoder` is merged (the leftmost pair wins a tie),
/// and this repeats until no adjacent pair is a key of `encoder`.
///
/// # Returns
///
/// The rank of every remaining segment, left to right.
///
/// * An empty `piece` yields an empty vector.
/// * A `piece` that is itself a key of `encoder` yields just that rank.
/// * A byte that has no entry in `encoder` contributes nothing to the output.
///
/// A pair whose rank in `encoder` is `u32::MAX` is treated exactly like a
/// pair that is absent, because that value is used internally as the
/// "cannot merge" marker.
pub fn byte_pair_encode(piece: &[u8], encoder: &FxHashMap<Vec<u8>, u32>) -> Vec<u32> {
    // Link value meaning "no neighbour on this side".
    const NIL: usize = usize::MAX;
    // Rank value meaning "this adjacent pair is not a known token".
    const UNMERGEABLE: u32 = u32::MAX;

    if piece.is_empty() {
        return Vec::new();
    }
    // Whole-piece lookup: a hit short-circuits everything, and a miss on a
    // single byte means there is nothing to emit at all.
    match encoder.get(piece) {
        Some(&token) => return vec![token],
        None if piece.len() == 1 => return Vec::new(),
        None => {}
    }

    // One node per byte, chained in input order.
    let last = piece.len() - 1;
    let mut nodes: Vec<Node> = (0..piece.len())
        .map(|i| Node {
            prev: i.checked_sub(1).unwrap_or(NIL),
            next: if i < last { i + 1 } else { NIL },
            rank: UNMERGEABLE,
            start: i,
            len: 1,
        })
        .collect();

    // Rank of the bytes spanned by `left` followed by `right`; the two nodes
    // are adjacent, so the span starts at `left` and covers both lengths.
    let pair_rank = |left: usize, right: usize, list: &[Node]| -> u32 {
        if left == NIL || right == NIL {
            return UNMERGEABLE;
        }
        let begin = list[left].start;
        let end = begin + list[left].len + list[right].len;
        encoder.get(&piece[begin..end]).copied().unwrap_or(UNMERGEABLE)
    };

    // Follows `prev` links back from node 0 until the head is reached.
    let head_of = |list: &[Node]| -> usize {
        let mut at = 0;
        while list[at].prev != NIL {
            at = list[at].prev;
        }
        at
    };

    // Seed every node except the tail with the rank of its pair.
    for i in 0..last {
        let succ = nodes[i].next;
        let seeded = pair_rank(i, succ, &nodes);
        nodes[i].rank = seeded;
    }

    loop {
        // Scan the live list for the lowest rank; `min_by_key` keeps the
        // first of equal minima, i.e. the leftmost pair.
        let head = head_of(&nodes);
        let best = std::iter::successors(Some(head), |&at| {
            let succ = nodes[at].next;
            if succ == NIL {
                None
            } else {
                Some(succ)
            }
        })
        .map(|at| (nodes[at].rank, at))
        .min_by_key(|&(rank, _)| rank);

        let (lowest, target) = match best {
            Some(found) => found,
            None => break,
        };
        if lowest == UNMERGEABLE {
            break;
        }

        // Absorb the successor into `target` and unlink it from the list.
        let absorbed = nodes[target].next;
        let absorbed_len = nodes[absorbed].len;
        let after = nodes[absorbed].next;
        nodes[target].len += absorbed_len;
        nodes[target].next = after;
        if after != NIL {
            nodes[after].prev = target;
        }

        // Only the pairs touching the grown node have changed.
        let before = nodes[target].prev;
        if before != NIL {
            let refreshed = pair_rank(before, target, &nodes);
            nodes[before].rank = refreshed;
        }
        let refreshed = pair_rank(target, after, &nodes);
        nodes[target].rank = refreshed;
    }

    // Emit one rank per surviving segment, falling back to per-byte lookups
    // for a segment that is not itself a key of the encoder.
    let mut tokens = Vec::new();
    let mut at = head_of(&nodes);
    while at != NIL {
        let Node {
            start, len, next, ..
        } = nodes[at];
        let bytes = &piece[start..start + len];
        match encoder.get(bytes) {
            Some(&token) => tokens.push(token),
            None => tokens.extend(
                bytes
                    .iter()
                    .filter_map(|&byte| encoder.get(&[byte][..]).copied()),
            ),
        }
        at = next;
    }
    tokens
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the vocabulary shared by every test: the listed byte strings
    /// are assigned ranks 0 through 5 in the order written.
    fn make_encoder() -> FxHashMap<Vec<u8>, u32> {
        let vocab: [&[u8]; 6] = [b"a", b"b", b"c", b"ab", b"bc", b"abc"];
        vocab
            .iter()
            .zip(0u32..)
            .map(|(bytes, rank)| (bytes.to_vec(), rank))
            .collect()
    }

    /// Encodes `piece` against the shared test vocabulary.
    fn encode(piece: &[u8]) -> Vec<u32> {
        byte_pair_encode(piece, &make_encoder())
    }

    /// A single known byte maps straight to its own rank.
    #[test]
    fn test_single_byte() {
        assert_eq!(encode(b"a"), vec![0]);
    }

    /// Two bytes that form a known pair collapse into one token.
    #[test]
    fn test_simple_merge() {
        assert_eq!(encode(b"ab"), vec![3]);
    }

    /// A three-byte piece that is itself in the vocabulary yields one token.
    #[test]
    fn test_chain_merge() {
        assert_eq!(encode(b"abc"), vec![5]);
    }

    /// Empty input produces no tokens.
    #[test]
    fn test_empty() {
        assert_eq!(encode(b""), Vec::<u32>::new());
    }

    /// Bytes with no known pairing are emitted individually.
    #[test]
    fn test_no_merge_possible() {
        assert_eq!(encode(b"ac"), vec![0, 2]);
    }
}