use crate::{error::TokenizerResult, utils::render_template};
/// One turn of a conversation: a role label and the text said in that role.
///
/// Both fields borrow from the caller, so the value is `Copy` and cheap to pass around.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChatMessage<'a> {
    /// Speaker label, e.g. `"user"`, `"assistant"` or `"system"`.
    pub role: &'a str,
    /// Text of the message.
    pub content: &'a str,
}

impl<'a> ChatMessage<'a> {
    /// Builds a message with an arbitrary `role` label.
    pub fn new(role: &'a str, content: &'a str) -> Self {
        ChatMessage { role, content }
    }

    /// Builds a message whose role is `"user"`.
    pub fn user(content: &'a str) -> Self {
        ChatMessage {
            role: "user",
            content,
        }
    }

    /// Builds a message whose role is `"assistant"`.
    pub fn assistant(content: &'a str) -> Self {
        ChatMessage {
            role: "assistant",
            content,
        }
    }

    /// Builds a message whose role is `"system"`.
    pub fn system(content: &'a str) -> Self {
        ChatMessage {
            role: "system",
            content,
        }
    }
}
/// Built-in chat prompt formats, one per supported model family.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must keep a wildcard
/// arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ChatTemplateKind {
    /// Turns delimited by `<|im_start|>` / `<|im_end|>`.
    ChatML,
    /// Turns introduced by `<|start_header_id|>…<|end_header_id|>` and closed by `<|eot_id|>`.
    Llama3,
    /// User turns wrapped in `[INST] … [/INST]`.
    Mistral,
    /// Turns delimited by `<start_of_turn>` / `<end_of_turn>`.
    Gemma,
    /// Uses the same markers and generation prompt as [`ChatTemplateKind::ChatML`].
    Qwen,
}

impl ChatTemplateKind {
    /// Returns the raw template source associated with this kind.
    ///
    /// The text is returned verbatim; any role rewriting done by
    /// [`ChatTemplateKind::render`] is applied to the messages, not to the template.
    pub fn template(&self) -> &'static str {
        match self {
            Self::ChatML => CHATML_TEMPLATE,
            Self::Llama3 => LLAMA3_TEMPLATE,
            Self::Mistral => MISTRAL_TEMPLATE,
            Self::Gemma => GEMMA_TEMPLATE,
            Self::Qwen => QWEN_TEMPLATE,
        }
    }

    /// Renders `messages` with this kind's template.
    ///
    /// Each message is handed to `render_template` as a `(role, content)` pair, in
    /// order. For [`ChatTemplateKind::Gemma`] the `"assistant"` role is rewritten to
    /// `"model"` first, so rendered history uses the same role label that
    /// [`ChatTemplateKind::generation_prompt`] opens the next turn with.
    pub fn render(&self, messages: &[ChatMessage<'_>]) -> String {
        let pairs: Vec<(&str, &str)> = messages
            .iter()
            .map(|m| (self.template_role(m.role), m.content))
            .collect();
        render_template(self.template(), &pairs)
    }

    /// Maps a caller-facing role label to the label this template family expects.
    ///
    /// Only Gemma differs: its reply turns are labelled `model`, not `assistant`.
    /// Every other (kind, role) combination passes through unchanged.
    fn template_role<'r>(&self, role: &'r str) -> &'r str {
        match (self, role) {
            (Self::Gemma, "assistant") => "model",
            _ => role,
        }
    }

    /// Renders `messages` and appends [`ChatTemplateKind::generation_prompt`].
    pub fn render_with_generation_prompt(&self, messages: &[ChatMessage<'_>]) -> String {
        let mut out = self.render(messages);
        out.push_str(self.generation_prompt());
        out
    }

    /// Returns the text that opens a new reply turn for this kind.
    ///
    /// This is the empty string for [`ChatTemplateKind::Mistral`].
    pub fn generation_prompt(&self) -> &'static str {
        match self {
            Self::ChatML | Self::Qwen => "<|im_start|>assistant\n",
            Self::Llama3 => "<|start_header_id|>assistant<|end_header_id|>\n\n",
            Self::Mistral => "",
            Self::Gemma => "<start_of_turn>model\n",
        }
    }

    /// Renders `messages` and tokenizes the result with `tokenizer`.
    ///
    /// The generation prompt is not appended; this encodes exactly what
    /// [`ChatTemplateKind::render`] returns.
    ///
    /// # Errors
    ///
    /// Returns whatever error `tokenizer.encode` reports for the rendered text.
    pub fn encode(
        &self,
        tokenizer: &crate::OxiTokenizer,
        messages: &[ChatMessage<'_>],
    ) -> TokenizerResult<Vec<u32>> {
        tokenizer.encode(&self.render(messages))
    }

    /// Returns every built-in kind, in declaration order.
    pub fn all() -> &'static [ChatTemplateKind] {
        &[
            Self::ChatML,
            Self::Llama3,
            Self::Mistral,
            Self::Gemma,
            Self::Qwen,
        ]
    }

    /// Guesses the template kind from a model name.
    ///
    /// The name is ASCII-lowercased and searched for known substrings. Checks run in
    /// a fixed order (`llama-3`/`llama3`, `mistral`, `gemma`, `qwen`, `chatml`) and
    /// the first hit wins. Returns `None` when nothing matches.
    pub fn infer_from_name(name: &str) -> Option<Self> {
        let n = name.to_ascii_lowercase();
        if n.contains("llama-3") || n.contains("llama3") {
            Some(Self::Llama3)
        } else if n.contains("mistral") {
            Some(Self::Mistral)
        } else if n.contains("gemma") {
            Some(Self::Gemma)
        } else if n.contains("qwen") {
            Some(Self::Qwen)
        } else if n.contains("chatml") {
            Some(Self::ChatML)
        } else {
            None
        }
    }
}
/// Template source for ChatML-style prompts.
///
/// Inside the message loop each turn is written as
/// `<|im_start|>{{ role }}`, a newline, `{{ content }}<|im_end|>`, and a newline.
const CHATML_TEMPLATE: &str = concat!(
    "{% for message in messages %}",
    "<|im_start|>{{ role }}\n",
    "{{ content }}<|im_end|>\n",
    "{% endfor %}"
);
/// Template source for Llama 3 style prompts.
///
/// Starts with `<|begin_of_text|>`; inside the message loop each turn is a
/// `<|start_header_id|>{{ role }}<|end_header_id|>` header, a blank line, and
/// `{{ content }}<|eot_id|>`.
// Written as one literal; a trailing `\` joins the next line and drops its indent.
const LLAMA3_TEMPLATE: &str = "<|begin_of_text|>\
    {% for message in messages %}\
    <|start_header_id|>{{ role }}<|end_header_id|>\n\n\
    {{ content }}<|eot_id|>\
    {% endfor %}";
/// Template source for Mistral style prompts.
///
/// Inside the message loop, a turn whose role equals `"user"` is written as
/// `<s>[INST] {{ content }} [/INST]`; any other role is written as a space,
/// `{{ content }}`, and `</s>`.
const MISTRAL_TEMPLATE: &str = concat!(
    "{% for message in messages %}",
    "{% if role == \"user\" %}",
    "<s>[INST] {{ content }} [/INST]",
    "{% else %}",
    // The leading space is part of the template text.
    " {{ content }}</s>",
    "{% endif %}",
    "{% endfor %}"
);
/// Template source for Gemma style prompts.
///
/// Inside the message loop each turn is written as `<start_of_turn>{{ role }}`,
/// a newline, `{{ content }}<end_of_turn>`, and a newline.
// Written as one literal; a trailing `\` joins the next line and drops its indent.
const GEMMA_TEMPLATE: &str = "{% for message in messages %}\
    <start_of_turn>{{ role }}\n\
    {{ content }}<end_of_turn>\n\
    {% endfor %}";
/// Template source for Qwen style prompts.
///
/// Inside the message loop each turn is written as
/// `<|im_start|>{{ role }}`, a newline, `{{ content }}<|im_end|>`, and a newline.
const QWEN_TEMPLATE: &str = concat!(
    "{% for message in messages %}",
    "<|im_start|>{{ role }}\n",
    "{{ content }}<|im_end|>\n",
    "{% endfor %}"
);
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders a conversation consisting of a single user message saying "hi".
    fn render_hi(kind: ChatTemplateKind) -> String {
        let conversation = [ChatMessage::user("hi")];
        kind.render(&conversation)
    }

    // Every kind listed by `all()` must map to a non-empty template.
    #[test]
    fn all_kinds_yield_a_template() {
        for k in ChatTemplateKind::all().iter() {
            assert!(!k.template().is_empty(), "template for {k:?} empty");
        }
    }

    #[test]
    fn chatml_renders_basic() {
        let rendered = render_hi(ChatTemplateKind::ChatML);
        assert!(rendered.contains("<|im_start|>user"));
        assert!(rendered.contains("hi"));
        assert!(rendered.contains("<|im_end|>"));
    }

    #[test]
    fn llama3_renders_basic() {
        let rendered = render_hi(ChatTemplateKind::Llama3);
        assert!(rendered.contains("<|begin_of_text|>"));
        assert!(rendered.contains("<|start_header_id|>user<|end_header_id|>"));
        assert!(rendered.contains("<|eot_id|>"));
    }

    #[test]
    fn mistral_renders_basic() {
        let conversation = [ChatMessage::user("hi"), ChatMessage::assistant("there")];
        let rendered = ChatTemplateKind::Mistral.render(&conversation);
        assert!(rendered.contains("[INST] hi [/INST]"));
        assert!(rendered.contains("there"));
    }

    #[test]
    fn gemma_renders_basic() {
        let rendered = render_hi(ChatTemplateKind::Gemma);
        assert!(rendered.contains("<start_of_turn>user"));
        assert!(rendered.contains("<end_of_turn>"));
    }

    #[test]
    fn qwen_renders_basic() {
        let rendered = render_hi(ChatTemplateKind::Qwen);
        assert!(rendered.contains("<|im_start|>user"));
        assert!(rendered.contains("<|im_end|>"));
    }

    #[test]
    fn generation_prompt_chatml() {
        assert!(ChatTemplateKind::ChatML
            .generation_prompt()
            .contains("assistant"));
    }

    #[test]
    fn render_with_generation_prompt() {
        let conversation = [ChatMessage::user("hi")];
        let rendered = ChatTemplateKind::ChatML.render_with_generation_prompt(&conversation);
        assert!(rendered.ends_with("<|im_start|>assistant\n"));
    }

    // Table of (model name, expected kind); checked in the listed order.
    #[test]
    fn infer_from_name_known() {
        let expectations = [
            ("Qwen3-1.7B", ChatTemplateKind::Qwen),
            ("Meta-Llama-3-8B-Instruct", ChatTemplateKind::Llama3),
            ("mistral-7b", ChatTemplateKind::Mistral),
            ("gemma-2b", ChatTemplateKind::Gemma),
        ];
        for &(model_name, expected) in expectations.iter() {
            assert_eq!(
                ChatTemplateKind::infer_from_name(model_name),
                Some(expected)
            );
        }
    }

    #[test]
    fn infer_from_name_unknown() {
        let guess = ChatTemplateKind::infer_from_name("bert-base");
        assert_eq!(guess, None);
    }

    #[test]
    fn encode_works_with_stub() {
        let stub = crate::OxiTokenizer::char_level_stub(256);
        let conversation = [ChatMessage::user("hi")];
        let token_ids = ChatTemplateKind::ChatML
            .encode(&stub, &conversation)
            .expect("encode ok");
        assert!(!token_ids.is_empty());
    }

    #[test]
    fn chat_message_constructors() {
        let built = [
            (ChatMessage::user("x"), "user"),
            (ChatMessage::assistant("y"), "assistant"),
            (ChatMessage::system("z"), "system"),
        ];
        for &(message, expected_role) in built.iter() {
            assert_eq!(message.role, expected_role);
        }
    }
}