use crate::types::TokenId;
/// Describes which special tokens are placed around encoded sequences.
///
/// The variant selects the layout produced by `PostProcessor::process`
/// (one sequence) and `PostProcessor::process_pair` (two sequences).
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum PostProcessor {
    /// Adds nothing; sequences are passed through as-is. This is the
    /// `Default` value.
    #[default]
    None,

    /// Classifier/separator layout: `cls_token` opens the output and
    /// `sep_token` closes every sequence.
    Bert {
        /// Id emitted once, at the very start of the output.
        cls_token: TokenId,
        /// Id emitted after each sequence.
        sep_token: TokenId,
    },

    /// Beginning-of-sequence layout: `bos_token` is emitted before every
    /// sequence and nothing is appended.
    Prefix {
        /// Id emitted in front of each sequence.
        bos_token: TokenId,
    },

    /// Fully explicit layout: arbitrary id lists wrapped around each
    /// sequence, with separate lists for single and paired input.
    Template {
        /// Ids placed before a single sequence.
        single_prefix: Vec<TokenId>,
        /// Ids placed after a single sequence.
        single_suffix: Vec<TokenId>,
        /// Ids placed before the first sequence of a pair.
        pair_a_prefix: Vec<TokenId>,
        /// Ids placed after the first sequence of a pair.
        pair_a_suffix: Vec<TokenId>,
        /// Ids placed before the second sequence of a pair.
        pair_b_prefix: Vec<TokenId>,
        /// Ids placed after the second sequence of a pair.
        pair_b_suffix: Vec<TokenId>,
    },
}
impl PostProcessor {
pub fn bert(cls_token: TokenId, sep_token: TokenId) -> Self {
Self::Bert { cls_token, sep_token }
}
pub fn prefix(bos_token: TokenId) -> Self {
Self::Prefix { bos_token }
}
pub fn is_none(&self) -> bool {
matches!(self, PostProcessor::None)
}
pub fn is_special_token(&self, id: TokenId) -> bool {
match self {
PostProcessor::None => false,
PostProcessor::Bert { cls_token, sep_token } => id == *cls_token || id == *sep_token,
PostProcessor::Prefix { bos_token } => id == *bos_token,
PostProcessor::Template {
single_prefix, single_suffix,
pair_a_prefix, pair_a_suffix,
pair_b_prefix, pair_b_suffix,
} => {
single_prefix.contains(&id) || single_suffix.contains(&id) ||
pair_a_prefix.contains(&id) || pair_a_suffix.contains(&id) ||
pair_b_prefix.contains(&id) || pair_b_suffix.contains(&id)
}
}
}
pub fn process(&self, tokens: &[TokenId]) -> Vec<TokenId> {
match self {
PostProcessor::None => tokens.to_vec(),
PostProcessor::Bert { cls_token, sep_token } => {
let mut result = Vec::with_capacity(tokens.len() + 2);
result.push(*cls_token);
result.extend_from_slice(tokens);
result.push(*sep_token);
result
}
PostProcessor::Prefix { bos_token } => {
let mut result = Vec::with_capacity(tokens.len() + 1);
result.push(*bos_token);
result.extend_from_slice(tokens);
result
}
PostProcessor::Template {
single_prefix,
single_suffix,
..
} => {
let mut result = Vec::with_capacity(
single_prefix.len() + tokens.len() + single_suffix.len()
);
result.extend_from_slice(single_prefix);
result.extend_from_slice(tokens);
result.extend_from_slice(single_suffix);
result
}
}
}
pub fn process_pair(
&self,
tokens_a: &[TokenId],
tokens_b: &[TokenId],
) -> (Vec<TokenId>, Vec<u8>) {
match self {
PostProcessor::None => {
let mut tokens = Vec::with_capacity(tokens_a.len() + tokens_b.len());
tokens.extend_from_slice(tokens_a);
tokens.extend_from_slice(tokens_b);
let mut type_ids = vec![0u8; tokens_a.len()];
type_ids.extend(vec![1u8; tokens_b.len()]);
(tokens, type_ids)
}
PostProcessor::Bert { cls_token, sep_token } => {
let mut tokens = Vec::with_capacity(tokens_a.len() + tokens_b.len() + 3);
tokens.push(*cls_token);
tokens.extend_from_slice(tokens_a);
tokens.push(*sep_token);
tokens.extend_from_slice(tokens_b);
tokens.push(*sep_token);
let mut type_ids = vec![0u8; 1 + tokens_a.len() + 1];
type_ids.extend(vec![1u8; tokens_b.len() + 1]);
(tokens, type_ids)
}
PostProcessor::Prefix { bos_token } => {
let mut tokens = Vec::with_capacity(tokens_a.len() + tokens_b.len() + 2);
tokens.push(*bos_token);
tokens.extend_from_slice(tokens_a);
tokens.push(*bos_token);
tokens.extend_from_slice(tokens_b);
let mut type_ids = vec![0u8; 1 + tokens_a.len()];
type_ids.extend(vec![1u8; 1 + tokens_b.len()]);
(tokens, type_ids)
}
PostProcessor::Template {
pair_a_prefix,
pair_a_suffix,
pair_b_prefix,
pair_b_suffix,
..
} => {
let total_len = pair_a_prefix.len()
+ tokens_a.len()
+ pair_a_suffix.len()
+ pair_b_prefix.len()
+ tokens_b.len()
+ pair_b_suffix.len();
let mut tokens = Vec::with_capacity(total_len);
tokens.extend_from_slice(pair_a_prefix);
tokens.extend_from_slice(tokens_a);
tokens.extend_from_slice(pair_a_suffix);
tokens.extend_from_slice(pair_b_prefix);
tokens.extend_from_slice(tokens_b);
tokens.extend_from_slice(pair_b_suffix);
let type_0_len = pair_a_prefix.len() + tokens_a.len() + pair_a_suffix.len();
let type_1_len = pair_b_prefix.len() + tokens_b.len() + pair_b_suffix.len();
let mut type_ids = vec![0u8; type_0_len];
type_ids.extend(vec![1u8; type_1_len]);
(tokens, type_ids)
}
}
}
pub fn num_special_tokens_single(&self) -> usize {
match self {
PostProcessor::None => 0,
PostProcessor::Bert { .. } => 2, PostProcessor::Prefix { .. } => 1, PostProcessor::Template { single_prefix, single_suffix, .. } => {
single_prefix.len() + single_suffix.len()
}
}
}
pub fn num_special_tokens_pair(&self) -> usize {
match self {
PostProcessor::None => 0,
PostProcessor::Bert { .. } => 3, PostProcessor::Prefix { .. } => 2, PostProcessor::Template {
pair_a_prefix,
pair_a_suffix,
pair_b_prefix,
pair_b_suffix,
..
} => {
pair_a_prefix.len() + pair_a_suffix.len()
+ pair_b_prefix.len() + pair_b_suffix.len()
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Template whose six lists use distinct ids (and one empty list) so each
    // list's contribution to the output can be told apart.
    fn sample_template() -> PostProcessor {
        PostProcessor::Template {
            single_prefix: vec![1],
            single_suffix: vec![2],
            pair_a_prefix: vec![3],
            pair_a_suffix: vec![4],
            pair_b_prefix: vec![],
            pair_b_suffix: vec![5, 6],
        }
    }

    #[test]
    fn test_none_postprocessor() {
        let pp = PostProcessor::None;
        let tokens = vec![1, 2, 3];
        assert_eq!(pp.process(&tokens), vec![1, 2, 3]);
    }

    #[test]
    fn test_none_pair() {
        let pp = PostProcessor::None;
        let (tokens, type_ids) = pp.process_pair(&[1, 2], &[3]);
        assert_eq!(tokens, vec![1, 2, 3]);
        assert_eq!(type_ids, vec![0, 0, 1]);
    }

    #[test]
    fn test_default_is_none() {
        assert!(PostProcessor::default().is_none());
        assert!(!PostProcessor::bert(101, 102).is_none());
        assert!(!PostProcessor::prefix(128000).is_none());
    }

    #[test]
    fn test_bert_single() {
        let pp = PostProcessor::bert(101, 102);
        let tokens = vec![7592];
        assert_eq!(pp.process(&tokens), vec![101, 7592, 102]);
    }

    #[test]
    fn test_bert_pair() {
        let pp = PostProcessor::bert(101, 102);
        let tokens_a = vec![7592];
        let tokens_b = vec![2088];
        let (tokens, type_ids) = pp.process_pair(&tokens_a, &tokens_b);
        assert_eq!(tokens, vec![101, 7592, 102, 2088, 102]);
        assert_eq!(type_ids, vec![0, 0, 0, 1, 1]);
    }

    #[test]
    fn test_bert_empty_input() {
        // Special tokens are emitted even when there is nothing to wrap.
        let pp = PostProcessor::bert(101, 102);
        assert_eq!(pp.process(&[]), vec![101, 102]);
        let (tokens, type_ids) = pp.process_pair(&[], &[]);
        assert_eq!(tokens, vec![101, 102, 102]);
        assert_eq!(type_ids, vec![0, 0, 1]);
    }

    #[test]
    fn test_prefix_single() {
        let pp = PostProcessor::prefix(128000);
        let tokens = vec![9906];
        assert_eq!(pp.process(&tokens), vec![128000, 9906]);
    }

    #[test]
    fn test_prefix_pair() {
        let pp = PostProcessor::prefix(128000);
        let tokens_a = vec![9906];
        let tokens_b = vec![4435];
        let (tokens, type_ids) = pp.process_pair(&tokens_a, &tokens_b);
        assert_eq!(tokens, vec![128000, 9906, 128000, 4435]);
        assert_eq!(type_ids, vec![0, 0, 1, 1]);
    }

    #[test]
    fn test_template_single() {
        let pp = sample_template();
        assert_eq!(pp.process(&[10, 11]), vec![1, 10, 11, 2]);
    }

    #[test]
    fn test_template_pair() {
        let pp = sample_template();
        let (tokens, type_ids) = pp.process_pair(&[10], &[20, 21]);
        assert_eq!(tokens, vec![3, 10, 4, 20, 21, 5, 6]);
        // Prefix/suffix ids carry the type id of the segment they wrap.
        assert_eq!(type_ids, vec![0, 0, 0, 1, 1, 1, 1]);
    }

    #[test]
    fn test_is_special_token() {
        assert!(!PostProcessor::None.is_special_token(0));

        let bert = PostProcessor::bert(101, 102);
        assert!(bert.is_special_token(101));
        assert!(bert.is_special_token(102));
        assert!(!bert.is_special_token(7592));

        let prefix = PostProcessor::prefix(128000);
        assert!(prefix.is_special_token(128000));
        assert!(!prefix.is_special_token(9906));

        // Every template list is consulted, single and pair alike.
        let template = sample_template();
        for id in 1..=6 {
            assert!(template.is_special_token(id));
        }
        assert!(!template.is_special_token(10));
    }

    #[test]
    fn test_num_special_tokens() {
        assert_eq!(PostProcessor::None.num_special_tokens_single(), 0);
        assert_eq!(PostProcessor::None.num_special_tokens_pair(), 0);

        let bert = PostProcessor::bert(101, 102);
        assert_eq!(bert.num_special_tokens_single(), 2);
        assert_eq!(bert.num_special_tokens_pair(), 3);

        let prefix = PostProcessor::prefix(128000);
        assert_eq!(prefix.num_special_tokens_single(), 1);
        assert_eq!(prefix.num_special_tokens_pair(), 2);

        let template = sample_template();
        assert_eq!(template.num_special_tokens_single(), 2);
        assert_eq!(template.num_special_tokens_pair(), 4);
    }
}