use crate::error::{LlamaError, Result};
use crate::model::LlamaModel;
use crate::token::LlamaToken;
/// Two-way conversion between text and [`LlamaToken`] sequences.
pub trait Tokenizer {
    /// Turns `text` into a vector of tokens.
    ///
    /// `add_bos` and `special` are passed through to the implementation.
    /// NOTE(review): presumably `add_bos` requests a leading beginning-of-sequence
    /// token and `special` enables special-token handling — confirm against the
    /// implementing type.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    fn encode(
        &self,
        text: &str,
        add_bos: bool,
        special: bool,
    ) -> Result<Vec<LlamaToken>>;

    /// Turns `tokens` back into a `String`.
    ///
    /// NOTE(review): `special` presumably controls whether special tokens are
    /// rendered in the output — confirm against the implementing type.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    fn decode(
        &self,
        tokens: &[LlamaToken],
        special: bool,
    ) -> Result<String>;
}
/// A tokenizer view over a borrowed [`LlamaModel`].
///
/// The struct owns nothing; it only holds a shared reference, so it cannot
/// outlive the model it was created from.
#[derive(Debug)]
pub struct LlamaTokenizer<'a> {
    /// The model this tokenizer borrows for `'a`.
    model: &'a LlamaModel,
}
impl<'a> LlamaTokenizer<'a> {
#[must_use]
pub const fn new(model: &'a LlamaModel) -> Self {
Self { model }
}
}
impl Tokenizer for LlamaTokenizer<'_> {
    /// Forwards to the borrowed model's `tokenize`, passing every argument through unchanged.
    fn encode(&self, text: &str, add_bos: bool, special: bool) -> Result<Vec<LlamaToken>> {
        let model = self.model;
        model.tokenize(text, add_bos, special)
    }

    /// Forwards to the borrowed model's `detokenize`, passing every argument through unchanged.
    fn decode(&self, tokens: &[LlamaToken], special: bool) -> Result<String> {
        let model = self.model;
        model.detokenize(tokens, special)
    }
}
/// The set of FIM (fill-in-the-middle) marker tokens of a model.
///
/// A negative id in `prefix`, `suffix` or `middle` is treated as "not
/// available" by [`FimTokens::is_supported`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FimTokens {
    /// Token marking the FIM prefix section.
    pub prefix: LlamaToken,
    /// Token marking the FIM suffix section.
    pub suffix: LlamaToken,
    /// Token marking the FIM middle section.
    pub middle: LlamaToken,
    /// End-of-text token, or `None` when there is none.
    pub eot: Option<LlamaToken>,
}
impl FimTokens {
    /// Reports whether the prefix, suffix and middle ids are all non-negative.
    ///
    /// The `eot` field plays no part in this check.
    #[must_use]
    pub fn is_supported(&self) -> bool {
        let required_ids = [self.prefix.0, self.suffix.0, self.middle.0];
        required_ids.iter().all(|&id| id >= 0)
    }

    /// Formats `prefix` and `suffix` into a textual FIM prompt.
    ///
    /// The result has the shape `"{prefix} <FIM_SUF> {suffix} <FIM_MID>"`.
    ///
    /// NOTE(review): the markers are fixed literals rather than being derived
    /// from the token ids stored in `self`, and no marker precedes `prefix` —
    /// confirm this is the intended template.
    ///
    /// # Errors
    ///
    /// Returns [`LlamaError::Batch`] when [`FimTokens::is_supported`] is `false`.
    pub fn build_prompt(&self, prefix: &str, suffix: &str) -> Result<String> {
        if self.is_supported() {
            Ok(format!("{prefix} <FIM_SUF> {suffix} <FIM_MID>"))
        } else {
            Err(LlamaError::Batch("model does not support FIM".into()))
        }
    }
}
impl LlamaModel {
    /// Queries the model's vocabulary for its FIM marker tokens.
    ///
    /// Returns `None` when the prefix, suffix or middle id reported by the
    /// native library is negative. The end-of-text id is optional: a negative
    /// value is stored as `eot: None`.
    #[must_use]
    pub fn fim_tokens(&self) -> Option<FimTokens> {
        use llama_crab_sys as sys;

        // SAFETY (applies to all four calls): each call only passes the handle
        // returned by `self.vocab()`. NOTE(review): assumes that handle is valid
        // for as long as `self` is borrowed — confirm against `LlamaModel::vocab`.
        let prefix_id = unsafe { sys::llama_token_fim_pre(self.vocab()) };
        let suffix_id = unsafe { sys::llama_token_fim_suf(self.vocab()) };
        let middle_id = unsafe { sys::llama_token_fim_mid(self.vocab()) };
        let eot_id = unsafe { sys::llama_token_eot(self.vocab()) };

        let required_present = prefix_id >= 0 && suffix_id >= 0 && middle_id >= 0;
        required_present.then(|| FimTokens {
            prefix: LlamaToken(prefix_id),
            suffix: LlamaToken(suffix_id),
            middle: LlamaToken(middle_id),
            eot: (eot_id >= 0).then_some(LlamaToken(eot_id)),
        })
    }

    /// Builds a FIM prompt for `prefix` and `suffix` using this model's FIM tokens.
    ///
    /// # Errors
    ///
    /// Returns [`LlamaError::Batch`] when [`LlamaModel::fim_tokens`] yields
    /// `None`, and otherwise whatever [`FimTokens::build_prompt`] returns.
    pub fn fim_prompt(&self, prefix: &str, suffix: &str) -> Result<String> {
        self.fim_tokens()
            .ok_or_else(|| LlamaError::Batch("model does not support FIM".into()))
            .and_then(|tokens| tokens.build_prompt(prefix, suffix))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Id used by these tests to stand for a token the model does not provide.
    const MISSING: LlamaToken = LlamaToken(-1);

    /// Assembles a `FimTokens` with the given markers and no end-of-text token.
    fn fim(prefix: LlamaToken, suffix: LlamaToken, middle: LlamaToken) -> FimTokens {
        FimTokens {
            prefix,
            suffix,
            middle,
            eot: None,
        }
    }

    /// Negative ids are reported as unsupported; non-negative ids as supported.
    #[test]
    fn fim_tokens_unsupported() {
        let absent = fim(MISSING, MISSING, MISSING);
        assert!(!absent.is_supported());

        let present = fim(LlamaToken(100), LlamaToken(101), LlamaToken(102));
        assert!(present.is_supported());
    }

    /// Building a prompt from unsupported tokens must fail.
    #[test]
    fn fim_build_prompt_unsupported() {
        let absent = fim(MISSING, MISSING, MISSING);
        let outcome = absent.build_prompt("a", "b");
        assert!(outcome.is_err());
    }
}