use crate::types::TokenId;
/// Maps token ids back to their byte strings.
///
/// All token bytes are stored back to back in `data`; token `i` occupies
/// `data[offsets[i]..offsets[i + 1]]`, so for a decoder built with
/// [`VocabDecoder::new`] `offsets` has exactly one more entry than the
/// vocabulary has tokens.
#[derive(Clone)]
pub struct VocabDecoder {
    /// Concatenated bytes of every token, in token-id order.
    data: Vec<u8>,
    /// Start offset of each token within `data`, followed by a final end offset.
    offsets: Vec<u32>,
}

// Hand-written rather than derived: a derived `Debug` would print every
// token's bytes, which is unhelpful noise in logs and panic messages.
impl std::fmt::Debug for VocabDecoder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VocabDecoder")
            .field("vocab_size", &self.offsets.len().saturating_sub(1))
            .field("data_len", &self.data.len())
            .finish_non_exhaustive()
    }
}
impl VocabDecoder {
pub fn from_parts(data: Vec<u8>, offsets: Vec<u32>) -> Self {
Self { data, offsets }
}
pub fn as_parts(&self) -> (&[u8], &[u32]) {
(&self.data, &self.offsets)
}
pub fn new(token_bytes: Vec<Vec<u8>>) -> Self {
let total_size: usize = token_bytes.iter().map(|t| t.len()).sum();
let mut data = Vec::with_capacity(total_size);
let mut offsets = Vec::with_capacity(token_bytes.len() + 1);
for token in &token_bytes {
offsets.push(data.len() as u32);
data.extend_from_slice(token);
}
offsets.push(data.len() as u32);
Self { data, offsets }
}
#[inline]
pub fn vocab_size(&self) -> usize {
self.offsets.len() - 1
}
#[inline]
pub fn token_to_bytes(&self, token: TokenId) -> &[u8] {
let start = self.offsets[token as usize] as usize;
let end = self.offsets[token as usize + 1] as usize;
&self.data[start..end]
}
#[inline]
pub fn token_len(&self, token: TokenId) -> usize {
let start = self.offsets[token as usize];
let end = self.offsets[token as usize + 1];
(end - start) as usize
}
const PARALLEL_THRESHOLD: usize = 50_000;
pub fn decode(&self, tokens: &[TokenId]) -> Vec<u8> {
if tokens.is_empty() {
return Vec::new();
}
let num_cpus = std::thread::available_parallelism()
.map(|p| p.get())
.unwrap_or(1);
if tokens.len() >= Self::PARALLEL_THRESHOLD && num_cpus > 1 {
self.decode_parallel(tokens, num_cpus)
} else {
self.decode_sequential(tokens)
}
}
fn decode_sequential(&self, tokens: &[TokenId]) -> Vec<u8> {
let total_size: usize = tokens
.iter()
.map(|&t| self.token_len(t))
.sum();
let mut result = Vec::with_capacity(total_size);
for &token in tokens {
result.extend_from_slice(self.token_to_bytes(token));
}
result
}
fn decode_parallel(&self, tokens: &[TokenId], num_threads: usize) -> Vec<u8> {
use std::thread;
let chunk_size = (tokens.len() + num_threads - 1) / num_threads;
let chunks: Vec<&[TokenId]> = tokens.chunks(chunk_size).collect();
let results: Vec<Vec<u8>> = thread::scope(|s| {
let handles: Vec<_> = chunks
.into_iter()
.map(|chunk| {
let data = &self.data;
let offsets = &self.offsets;
s.spawn(move || {
let size: usize = chunk
.iter()
.map(|&t| {
let idx = t as usize;
(offsets[idx + 1] - offsets[idx]) as usize
})
.sum();
let mut result = Vec::with_capacity(size);
for &token in chunk {
let t = token as usize;
let start = offsets[t] as usize;
let end = offsets[t + 1] as usize;
result.extend_from_slice(&data[start..end]);
}
result
})
})
.collect();
handles.into_iter().map(|h| h.join().unwrap()).collect()
});
let total_size: usize = results.iter().map(|v| v.len()).sum();
let mut output = Vec::with_capacity(total_size);
for chunk in results {
output.extend_from_slice(&chunk);
}
output
}
pub fn decode_to_string(&self, tokens: &[TokenId]) -> Option<String> {
String::from_utf8(self.decode(tokens)).ok()
}
pub fn token_bytes(&self) -> Vec<Vec<u8>> {
(0..self.vocab_size())
.map(|i| self.token_to_bytes(i as TokenId).to_vec())
.collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decode_empty() {
        let decoder = VocabDecoder::new(vec![vec![b'a'], vec![b'b']]);
        assert_eq!(decoder.decode(&[]), Vec::<u8>::new());
    }

    #[test]
    fn test_decode_single() {
        let decoder = VocabDecoder::new(vec![vec![b'a'], vec![b'b'], vec![b'c']]);
        assert_eq!(decoder.decode(&[0]), b"a");
        assert_eq!(decoder.decode(&[1]), b"b");
        assert_eq!(decoder.decode(&[2]), b"c");
    }

    #[test]
    fn test_decode_multiple() {
        let decoder =
            VocabDecoder::new(vec![b"Hello".to_vec(), b" ".to_vec(), b"world".to_vec()]);
        assert_eq!(decoder.decode(&[0, 1, 2]), b"Hello world");
    }

    #[test]
    fn test_decode_to_string() {
        let decoder = VocabDecoder::new(vec![vec![b'a'], vec![b'b'], vec![b'c']]);
        assert_eq!(decoder.decode_to_string(&[0, 1, 2]), Some("abc".to_string()));
    }

    #[test]
    fn test_decode_to_string_invalid_utf8() {
        // 0xFF can never appear in well-formed UTF-8.
        let decoder = VocabDecoder::new(vec![vec![0xff], vec![b'a']]);
        assert_eq!(decoder.decode_to_string(&[1, 0]), None);
    }

    #[test]
    fn test_vocab_size() {
        let decoder = VocabDecoder::new(vec![vec![b'a'], vec![b'b'], vec![b'c']]);
        assert_eq!(decoder.vocab_size(), 3);
    }

    #[test]
    fn test_token_to_bytes() {
        let decoder = VocabDecoder::new(vec![b"ab".to_vec(), b"cde".to_vec()]);
        assert_eq!(decoder.token_to_bytes(0), b"ab");
        assert_eq!(decoder.token_to_bytes(1), b"cde");
    }

    #[test]
    fn test_token_len() {
        let decoder = VocabDecoder::new(vec![b"a".to_vec(), b"ab".to_vec(), b"abc".to_vec()]);
        assert_eq!(decoder.token_len(0), 1);
        assert_eq!(decoder.token_len(1), 2);
        assert_eq!(decoder.token_len(2), 3);
    }

    #[test]
    fn test_empty_token_bytes() {
        let decoder = VocabDecoder::new(vec![Vec::new(), b"x".to_vec()]);
        assert_eq!(decoder.vocab_size(), 2);
        assert_eq!(decoder.token_len(0), 0);
        assert_eq!(decoder.token_to_bytes(0), b"");
        assert_eq!(decoder.decode(&[0, 1, 0]), b"x");
    }

    #[test]
    fn test_from_parts_round_trip() {
        let original = VocabDecoder::new(vec![b"ab".to_vec(), Vec::new(), b"c".to_vec()]);
        let (data, offsets) = original.as_parts();
        assert_eq!(data, b"abc");
        assert_eq!(offsets, &[0, 2, 2, 3]);
        let rebuilt = VocabDecoder::from_parts(data.to_vec(), offsets.to_vec());
        assert_eq!(rebuilt.token_bytes(), original.token_bytes());
        assert_eq!(rebuilt.decode(&[2, 1, 0]), b"cab");
    }

    #[test]
    fn test_parallel_decode_matches_sequential() {
        let token_bytes: Vec<Vec<u8>> = (0u8..=255).map(|b| vec![b]).collect();
        let decoder = VocabDecoder::new(token_bytes);
        let tokens: Vec<TokenId> = (0..100_000).map(|i| (i % 256) as TokenId).collect();
        let sequential = decoder.decode_sequential(&tokens);
        let parallel = decoder.decode_parallel(&tokens, 4);
        assert_eq!(sequential.len(), parallel.len(), "Length mismatch");
        assert_eq!(sequential, parallel, "Content mismatch");
    }

    #[test]
    fn test_parallel_decode_chunk_boundaries() {
        // Tokens of different widths so chunk boundaries split uneven byte runs.
        let token_bytes = vec![vec![b'A'], vec![b'B', b'B'], vec![b'C', b'C', b'C']];
        let decoder = VocabDecoder::new(token_bytes);
        let tokens: Vec<TokenId> = (0..60_000).map(|i| (i % 3) as TokenId).collect();
        let sequential = decoder.decode_sequential(&tokens);
        let parallel = decoder.decode_parallel(&tokens, 4);
        assert_eq!(sequential, parallel);
    }

    #[test]
    fn test_parallel_decode_more_threads_than_tokens() {
        let decoder = VocabDecoder::new(vec![b"x".to_vec(), b"yz".to_vec()]);
        let tokens: [TokenId; 3] = [0, 1, 0];
        assert_eq!(
            decoder.decode_parallel(&tokens, 8),
            decoder.decode_sequential(&tokens)
        );
    }

    #[test]
    fn test_decode_above_parallel_threshold_matches_sequential() {
        // Goes through the public entry point so the dispatch logic is exercised.
        let decoder = VocabDecoder::new(vec![b"ab".to_vec(), Vec::new(), b"c".to_vec()]);
        let tokens: Vec<TokenId> = (0..VocabDecoder::PARALLEL_THRESHOLD + 7)
            .map(|i| (i % 3) as TokenId)
            .collect();
        assert_eq!(decoder.decode(&tokens), decoder.decode_sequential(&tokens));
    }
}