rlx-text 0.2.5

RLX text: tokenizer wrappers, chat templates, and sampling — the public API for downstream language-model applications.
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Chat-template engine for RLX runners.
//!
//! Replaces `LlamaModel::apply_chat_template` (llama-cpp-4) end-to-end. Two
//! sources: an inline Jinja2 string, or `tokenizer.chat_template` (and
//! `tokenizer.ggml.chat_template`) read directly from a GGUF file's
//! metadata. Rendering uses `minijinja`.
//!
//! BOS/EOS strings are looked up via `tokenizer.ggml.bos_token_id` /
//! `eos_token_id` against the `tokenizer.ggml.tokens` array (the GGUF
//! convention).

use anyhow::{Context, Result, anyhow};
use minijinja::{Environment, Error as JinjaError, ErrorKind, Value};
use rlx_gguf::{GgufFile, MetaValue};
use std::path::Path;

/// Load a [`ChatTemplate`] (template text plus BOS/EOS strings) straight
/// from a GGUF file on disk.
///
/// This is the chat-template member of the M3 auto-dispatch family and is a
/// thin alias for [`ChatTemplate::from_gguf`]; any error from that call is
/// returned unchanged. Pair `rlx_models::run::auto_chat_template(path)` with
/// `rlx_models::run::auto_runner(path)`.
pub fn auto_chat_template(path: &Path) -> Result<ChatTemplate> {
    ChatTemplate::from_gguf(path)
}

/// One chat turn. `role` is conventionally one of `system`, `user`,
/// `assistant`, `tool` — but templates can accept anything.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatMessage {
    /// Speaker role; handed to templates as the `role` field of each
    /// `messages` entry.
    pub role: String,
    /// Turn text; handed to templates as the `content` field of each
    /// `messages` entry.
    pub content: String,
}

impl ChatMessage {
    /// Build a message with an arbitrary role. Use this for roles that have
    /// no dedicated constructor (templates are free to accept any role).
    pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            role: role.into(),
            content: content.into(),
        }
    }

    /// A `user` turn.
    pub fn user(content: impl Into<String>) -> Self {
        Self::new("user", content)
    }

    /// A `system` turn.
    pub fn system(content: impl Into<String>) -> Self {
        Self::new("system", content)
    }

    /// An `assistant` turn.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::new("assistant", content)
    }

    /// A `tool` turn (e.g. a tool/function call result).
    pub fn tool(content: impl Into<String>) -> Self {
        Self::new("tool", content)
    }
}

/// Where a [`ChatTemplate`] was loaded from. Useful for diagnostics and
/// for letting a caller round-trip the source string into config.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChatTemplateSource {
    /// Compiled from a string handed to [`ChatTemplate::from_source`].
    Inline,
    /// Read from GGUF metadata; holds the metadata key the template text was
    /// found under (e.g. `tokenizer.chat_template`).
    GgufMetadata(String),
}

/// A compiled Jinja chat template together with the BOS/EOS strings it is
/// rendered with.
pub struct ChatTemplate {
    /// Origin of `source_text`: an inline string or a GGUF metadata key.
    source_kind: ChatTemplateSource,
    /// The raw, uncompiled template text.
    source_text: String,
    /// `minijinja` environment holding the compiled template under
    /// `TEMPLATE_NAME`.
    env: Environment<'static>,
    /// Exposed to the template as `bos_token` (empty string when `None`).
    bos_token: Option<String>,
    /// Exposed to the template as `eos_token` (empty string when `None`).
    eos_token: Option<String>,
}

/// Name the single chat template is registered under: `build_env` adds it,
/// `ChatTemplate::render` looks it up.
const TEMPLATE_NAME: &str = "chat";

/// Build a `minijinja` environment with `source` compiled under
/// `TEMPLATE_NAME`.
///
/// The whitespace model matches Hugging Face `transformers`, which renders
/// chat templates with `trim_blocks` and `lstrip_blocks` enabled. A
/// `raise_exception(msg)` global is registered that aborts rendering with
/// `msg`.
///
/// # Errors
///
/// Fails if `source` does not compile as a template.
fn build_env(source: String) -> Result<Environment<'static>> {
    let mut env = Environment::new();
    // Upstream chat templates (and the copies embedded in GGUF metadata) are
    // authored for Jinja with `trim_blocks=True, lstrip_blocks=True`. With
    // minijinja's defaults (both off), the newline after each `{% ... %}`
    // tag and the indentation before it leak into the rendered prompt.
    env.set_trim_blocks(true);
    env.set_lstrip_blocks(true);
    // HF templates occasionally call `raise_exception(msg)` for invariant
    // checks (e.g. "system must come first"). Wire it to a Jinja error.
    env.add_function(
        "raise_exception",
        |msg: String| -> Result<Value, JinjaError> {
            Err(JinjaError::new(ErrorKind::InvalidOperation, msg))
        },
    );
    env.add_template_owned(TEMPLATE_NAME, source)
        .context("compiling chat template")?;
    Ok(env)
}

impl ChatTemplate {
    /// Compile a chat template from a raw Jinja string.
    ///
    /// The result carries no BOS/EOS strings. Attach them with
    /// [`with_tokens`](Self::with_tokens) when the template references
    /// `bos_token` / `eos_token`.
    ///
    /// # Errors
    ///
    /// Fails if `src` does not compile as a template.
    pub fn from_source(src: impl Into<String>) -> Result<Self> {
        let source_text: String = src.into();
        // The environment takes ownership of its copy; keep ours so
        // `source_text()` can hand the original string back.
        let env = build_env(source_text.clone())?;
        Ok(Self {
            env,
            source_text,
            source_kind: ChatTemplateSource::Inline,
            bos_token: None,
            eos_token: None,
        })
    }

    /// Override BOS/EOS strings (passed to the template as `bos_token` /
    /// `eos_token` Jinja variables). `None` renders as an empty string.
    #[must_use]
    pub fn with_tokens(mut self, bos: Option<String>, eos: Option<String>) -> Self {
        self.bos_token = bos;
        self.eos_token = eos;
        self
    }

    /// Load template + BOS/EOS from a GGUF file. Reads
    /// `tokenizer.chat_template` first, then `tokenizer.ggml.chat_template`.
    ///
    /// # Errors
    ///
    /// Fails if `GgufFile::from_path` cannot load `path`, or for any reason
    /// listed on [`from_gguf_file`](Self::from_gguf_file).
    pub fn from_gguf(path: &Path) -> Result<Self> {
        let raw = GgufFile::from_path(path).with_context(|| format!("opening GGUF {path:?}"))?;
        Self::from_gguf_file(&raw)
    }

    /// Same as [`from_gguf`](Self::from_gguf), but reuses an already-parsed file.
    ///
    /// BOS/EOS strings are looked up via `tokenizer.ggml.bos_token_id` /
    /// `tokenizer.ggml.eos_token_id` in `tokenizer.ggml.tokens`. An id that
    /// is missing or cannot be resolved leaves that token as `None` rather
    /// than failing.
    ///
    /// # Errors
    ///
    /// Fails if neither template key holds a string value, or if the template
    /// does not compile.
    pub fn from_gguf_file(raw: &GgufFile) -> Result<Self> {
        let (key, src) = pick_chat_template_meta(raw).ok_or_else(|| {
            anyhow!("no tokenizer.chat_template or tokenizer.ggml.chat_template in GGUF metadata")
        })?;
        let env = build_env(src.clone())?;
        let bos = resolve_special_token(raw, "tokenizer.ggml.bos_token_id");
        let eos = resolve_special_token(raw, "tokenizer.ggml.eos_token_id");
        Ok(Self {
            env,
            source_text: src,
            source_kind: ChatTemplateSource::GgufMetadata(key.to_owned()),
            bos_token: bos,
            eos_token: eos,
        })
    }

    /// The raw template text this instance was compiled from.
    pub fn source_text(&self) -> &str {
        &self.source_text
    }

    /// Where the template text came from.
    pub fn source_kind(&self) -> &ChatTemplateSource {
        &self.source_kind
    }

    /// The BOS string exposed to the template, if known.
    pub fn bos_token(&self) -> Option<&str> {
        self.bos_token.as_deref()
    }

    /// The EOS string exposed to the template, if known.
    pub fn eos_token(&self) -> Option<&str> {
        self.eos_token.as_deref()
    }

    /// Render the template with the given messages.
    ///
    /// The template sees Jinja variables: `messages` (list of
    /// `{role, content}` maps), `add_generation_prompt` (bool), and
    /// `bos_token` / `eos_token` strings (empty if unknown).
    ///
    /// # Errors
    ///
    /// Fails if rendering fails, including when the template calls
    /// `raise_exception(msg)`.
    pub fn render(&self, messages: &[ChatMessage], add_generation_prompt: bool) -> Result<String> {
        let msgs: Vec<Value> = messages
            .iter()
            .map(|m| {
                Value::from_serialize(serde_json::json!({
                    "role": m.role,
                    "content": m.content,
                }))
            })
            .collect();
        // Borrow the special-token strings instead of cloning them on
        // every render; absent tokens become "".
        let ctx = minijinja::context! {
            messages => Value::from(msgs),
            add_generation_prompt => add_generation_prompt,
            bos_token => self.bos_token.as_deref().unwrap_or(""),
            eos_token => self.eos_token.as_deref().unwrap_or(""),
        };
        let tmpl = self
            .env
            .get_template(TEMPLATE_NAME)
            .expect("template registered in build_env");
        tmpl.render(ctx).context("rendering chat template")
    }
}

/// Find the chat-template string in GGUF metadata, paired with the key it
/// was read from.
///
/// `tokenizer.chat_template` takes precedence over
/// `tokenizer.ggml.chat_template`. A key whose value is not a string is
/// skipped as if absent. Returns `None` when neither key yields a string.
fn pick_chat_template_meta(raw: &GgufFile) -> Option<(&'static str, String)> {
    const CANDIDATE_KEYS: [&str; 2] = ["tokenizer.chat_template", "tokenizer.ggml.chat_template"];

    CANDIDATE_KEYS
        .iter()
        .copied()
        .find_map(|key| match raw.metadata.get(key) {
            Some(MetaValue::String(text)) => Some((key, text.clone())),
            _ => None,
        })
}

/// Turn the special-token id stored under `id_key` into its string form by
/// indexing the `tokenizer.ggml.tokens` array.
///
/// Returns `None` in any of these cases:
/// - `id_key` is absent, or `MetaValue::as_u32` yields nothing for it;
/// - the token table is absent or not an array;
/// - the id is out of range;
/// - the entry at that index is not a string.
fn resolve_special_token(raw: &GgufFile, id_key: &str) -> Option<String> {
    let token_id = raw.metadata.get(id_key).and_then(MetaValue::as_u32)?;
    let Some(MetaValue::Array(vocab)) = raw.metadata.get("tokenizer.ggml.tokens") else {
        return None;
    };
    match vocab.get(token_id as usize) {
        Some(MetaValue::String(text)) => Some(text.clone()),
        _ => None,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Minimal Qwen / ChatML-style template — same shape as Qwen3's, simplified
    // enough that test failures point at our rendering plumbing, not at
    // upstream Jinja quirks. Whitespace-trim markers are intentionally
    // avoided so the literal `\n` inside the template survives.
    const QWEN_TEMPLATE: &str = "{% for m in messages %}<|im_start|>{{ m.role }}\n{{ m.content }}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}";

    // Minimal Llama-3-style template using bos_token + headers.
    const LLAMA3_TEMPLATE: &str = "{% for m in messages %}{% if loop.first %}{{ bos_token }}{% endif %}<|start_header_id|>{{ m.role }}<|end_header_id|>\n\n{{ m.content }}<|eot_id|>{% endfor %}{% if add_generation_prompt %}<|start_header_id|>assistant<|end_header_id|>\n\n{% endif %}";

    // Minimal Gemma-style template.
    const GEMMA_TEMPLATE: &str = "{% for m in messages %}{% set role = 'user' if m.role == 'system' else m.role %}<start_of_turn>{{ role }}\n{{ m.content }}<end_of_turn>\n{% endfor %}{% if add_generation_prompt %}<start_of_turn>model\n{% endif %}";

    fn sample_conv() -> Vec<ChatMessage> {
        vec![ChatMessage::system("be concise"), ChatMessage::user("hi")]
    }

    #[test]
    fn qwen_template_renders_with_generation_prompt() {
        let t = ChatTemplate::from_source(QWEN_TEMPLATE).unwrap();
        let out = t.render(&sample_conv(), true).unwrap();
        let expected = "<|im_start|>system\nbe concise<|im_end|>\n\
                        <|im_start|>user\nhi<|im_end|>\n\
                        <|im_start|>assistant\n";
        assert_eq!(out, expected);
    }

    #[test]
    fn qwen_template_omits_generation_prompt_when_disabled() {
        let t = ChatTemplate::from_source(QWEN_TEMPLATE).unwrap();
        let out = t.render(&sample_conv(), false).unwrap();
        assert!(out.ends_with("<|im_end|>\n"));
        assert!(!out.contains("<|im_start|>assistant\n"));
    }

    #[test]
    fn llama3_template_uses_bos_token() {
        let t = ChatTemplate::from_source(LLAMA3_TEMPLATE)
            .unwrap()
            .with_tokens(Some("<|begin_of_text|>".into()), Some("<|eot_id|>".into()));
        let out = t.render(&sample_conv(), true).unwrap();
        let expected = "<|begin_of_text|>\
                        <|start_header_id|>system<|end_header_id|>\n\nbe concise<|eot_id|>\
                        <|start_header_id|>user<|end_header_id|>\n\nhi<|eot_id|>\
                        <|start_header_id|>assistant<|end_header_id|>\n\n";
        assert_eq!(out, expected);
        assert_eq!(t.bos_token(), Some("<|begin_of_text|>"));
        assert_eq!(t.eos_token(), Some("<|eot_id|>"));
    }

    #[test]
    fn missing_bos_token_renders_as_empty_string() {
        let t = ChatTemplate::from_source(LLAMA3_TEMPLATE).unwrap();
        assert_eq!(t.bos_token(), None);
        let out = t.render(&sample_conv(), false).unwrap();
        assert!(out.starts_with("<|start_header_id|>system<|end_header_id|>"));
    }

    #[test]
    fn gemma_template_rewrites_system_to_user() {
        let t = ChatTemplate::from_source(GEMMA_TEMPLATE).unwrap();
        let out = t.render(&sample_conv(), true).unwrap();
        let expected = "<start_of_turn>user\nbe concise<end_of_turn>\n\
                        <start_of_turn>user\nhi<end_of_turn>\n\
                        <start_of_turn>model\n";
        assert_eq!(out, expected);
    }

    #[test]
    fn raise_exception_propagates_as_error() {
        let t = ChatTemplate::from_source("{{ raise_exception('nope') }}").unwrap();
        let err = t.render(&[], false).unwrap_err();
        assert!(format!("{err:#}").contains("nope"));
    }

    #[test]
    fn malformed_template_is_rejected() {
        assert!(ChatTemplate::from_source("{% for m in messages %}unterminated").is_err());
    }

    /// Deletes the wrapped path on drop, so the temp GGUF is cleaned up even
    /// when an assertion in the test panics.
    struct TempFile(std::path::PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover temp file must not turn a test
            // failure into a double panic.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Builds a minimal GGUF in a temp file with a chat_template + token
    /// table, then verifies BOS/EOS resolve and rendering works.
    #[test]
    fn from_gguf_reads_template_and_special_tokens() {
        // We build a v3 GGUF with four metadata keys:
        //   tokenizer.chat_template      (String)
        //   tokenizer.ggml.tokens        (Array of String)
        //   tokenizer.ggml.bos_token_id  (U32)
        //   tokenizer.ggml.eos_token_id  (U32)
        // and one tiny f32 tensor so the file passes the loader.
        let mut buf: Vec<u8> = Vec::new();
        buf.extend_from_slice(&rlx_gguf::GGUF_MAGIC.to_le_bytes());
        buf.extend_from_slice(&3u32.to_le_bytes());
        buf.extend_from_slice(&1u64.to_le_bytes()); // tensor count
        buf.extend_from_slice(&4u64.to_le_bytes()); // kv count

        let write_string_kv = |buf: &mut Vec<u8>, k: &str, v: &str| {
            buf.extend_from_slice(&(k.len() as u64).to_le_bytes());
            buf.extend_from_slice(k.as_bytes());
            buf.extend_from_slice(&8u32.to_le_bytes());
            buf.extend_from_slice(&(v.len() as u64).to_le_bytes());
            buf.extend_from_slice(v.as_bytes());
        };
        let write_u32_kv = |buf: &mut Vec<u8>, k: &str, v: u32| {
            buf.extend_from_slice(&(k.len() as u64).to_le_bytes());
            buf.extend_from_slice(k.as_bytes());
            buf.extend_from_slice(&4u32.to_le_bytes());
            buf.extend_from_slice(&v.to_le_bytes());
        };
        let write_string_array_kv = |buf: &mut Vec<u8>, k: &str, items: &[&str]| {
            buf.extend_from_slice(&(k.len() as u64).to_le_bytes());
            buf.extend_from_slice(k.as_bytes());
            // type = Array(9)
            buf.extend_from_slice(&9u32.to_le_bytes());
            // element type = String(8)
            buf.extend_from_slice(&8u32.to_le_bytes());
            // length (u64)
            buf.extend_from_slice(&(items.len() as u64).to_le_bytes());
            for s in items {
                buf.extend_from_slice(&(s.len() as u64).to_le_bytes());
                buf.extend_from_slice(s.as_bytes());
            }
        };

        write_string_kv(&mut buf, "tokenizer.chat_template", QWEN_TEMPLATE);
        write_string_array_kv(
            &mut buf,
            "tokenizer.ggml.tokens",
            &["<pad>", "<bos>", "<eos>", "hi"],
        );
        write_u32_kv(&mut buf, "tokenizer.ggml.bos_token_id", 1);
        write_u32_kv(&mut buf, "tokenizer.ggml.eos_token_id", 2);

        // tiny f32 tensor
        let name = "w";
        buf.extend_from_slice(&(name.len() as u64).to_le_bytes());
        buf.extend_from_slice(name.as_bytes());
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(&4u64.to_le_bytes());
        buf.extend_from_slice(&(rlx_gguf::GgmlType::F32 as u32).to_le_bytes());
        buf.extend_from_slice(&0u64.to_le_bytes());
        while !buf
            .len()
            .is_multiple_of(rlx_gguf::DEFAULT_ALIGNMENT as usize)
        {
            buf.push(0);
        }
        for _ in 0..4 {
            buf.extend_from_slice(&1.0f32.to_le_bytes());
        }

        // A PID suffix keeps concurrent test processes sharing one temp dir
        // from overwriting or deleting each other's file.
        let tmp = TempFile(std::env::temp_dir().join(format!(
            "rlx_chat_template_from_gguf_{}.gguf",
            std::process::id()
        )));
        std::fs::write(&tmp.0, &buf).unwrap();

        let t = ChatTemplate::from_gguf(&tmp.0).expect("from_gguf");
        assert_eq!(t.bos_token(), Some("<bos>"));
        assert_eq!(t.eos_token(), Some("<eos>"));
        assert_eq!(t.source_text(), QWEN_TEMPLATE);
        let out = t.render(&sample_conv(), true).unwrap();
        assert!(out.contains("<|im_start|>assistant\n"));
        match t.source_kind() {
            ChatTemplateSource::GgufMetadata(k) => assert_eq!(k, "tokenizer.chat_template"),
            other => panic!("unexpected source: {other:?}"),
        }
    }
}