mlx_lm_utils/
tokenizer.rs

1// Args:
2//     conversation (Union[list[dict[str, str]], list[list[dict[str, str]]]]): A list of dicts
3//         with "role" and "content" keys, representing the chat history so far.
4//     tools (`list[Union[Dict, Callable]]`, *optional*):
5//         A list of tools (callable functions) that will be accessible to the model. If the template does not
6//         support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
7//         giving the name, description and argument types for the tool. See our
8//         [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
9//         for more information.
10//     documents (`list[dict[str, str]]`, *optional*):
11//         A list of dicts representing documents that will be accessible to the model if it is performing RAG
12//         (retrieval-augmented generation). If the template does not support RAG, this argument will have no
13//         effect. We recommend that each document should be a dict containing "title" and "text" keys. Please
14//         see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG)
15//         for examples of passing documents with chat templates.
16//     chat_template (`str`, *optional*):
17//         A Jinja template to use for this conversion. It is usually not necessary to pass anything to this
18//         argument, as the model's template will be used by default.
19//     add_generation_prompt (bool, *optional*):
20//         If this is set, a prompt with the token(s) that indicate
21//         the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
22//         Note that this argument will be passed to the chat template, and so it must be supported in the
23//         template for this argument to have any effect.
24//     continue_final_message (bool, *optional*):
25//         If this is set, the chat will be formatted so that the final
26//         message in the chat is open-ended, without any EOS tokens. The model will continue this message
27//         rather than starting a new one. This allows you to "prefill" part of
28//         the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
29//     tokenize (`bool`, defaults to `True`):
30//         Whether to tokenize the output. If `False`, the output will be a string.
31//     padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
32//             Select a strategy to pad the returned sequences (according to the model's padding side and padding
33//             index) among:
34
35//         - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
36//             sequence if provided).
37//         - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
38//             acceptable input length for the model if that argument is not provided.
39//         - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
40//             lengths).
41//     truncation (`bool`, defaults to `False`):
42//         Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`.
43//     max_length (`int`, *optional*):
44//         Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If
45//         not specified, the tokenizer's `max_length` attribute will be used as a default.
46//     return_tensors (`str` or [`~utils.TensorType`], *optional*):
47//         If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable
48//         values are:
49//         - `'tf'`: Return TensorFlow `tf.Tensor` objects.
50//         - `'pt'`: Return PyTorch `torch.Tensor` objects.
51//         - `'np'`: Return NumPy `np.ndarray` objects.
52//         - `'jax'`: Return JAX `jnp.ndarray` objects.
53//     return_dict (`bool`, defaults to `False`):
54//         Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
//     tokenizer_kwargs (`dict[str, Any]`, *optional*): Additional kwargs to pass to the tokenizer.
56//     return_assistant_tokens_mask (`bool`, defaults to `False`):
57//         Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
58//         the mask will contain 1. For user and system tokens, the mask will contain 0.
59//         This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
60//     **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
61
62// Returns:
63//     `Union[list[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This
64//     output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is
65//     set, will return a dict of tokenizer outputs instead.
66// """
67
68use std::{
69    collections::HashMap,
70    fs::read_to_string,
71    ops::{Deref, DerefMut},
72    path::Path,
73    str::FromStr,
74};
75
76use minijinja::{context, Environment, Template};
77use serde::Serialize;
78use tokenizers::Encoding;
79
80use crate::error::Error;
81
/// Wrapper around [`tokenizers::Tokenizer`] and [`minijinja::Environment`]
/// providing more utilities.
pub struct Tokenizer {
    // The wrapped Hugging Face tokenizer; exposed directly via Deref/DerefMut.
    inner: tokenizers::Tokenizer,
    // Jinja environment that compiles and caches chat templates,
    // keyed by model id (or an explicit template id).
    env: Environment<'static>,
}
88
89impl FromStr for Tokenizer {
90    type Err = tokenizers::Error;
91
92    fn from_str(s: &str) -> Result<Self, Self::Err> {
93        tokenizers::Tokenizer::from_str(s).map(Self::from_tokenizer)
94    }
95}
96
97impl Tokenizer {
98    pub fn from_tokenizer(tokenizer: tokenizers::Tokenizer) -> Self {
99        let mut env = Environment::new();
100        env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
101        Self {
102            inner: tokenizer,
103            env,
104        }
105    }
106
107    pub fn from_file(file: impl AsRef<Path>) -> tokenizers::Result<Self> {
108        tokenizers::Tokenizer::from_file(file).map(Self::from_tokenizer)
109    }
110
111    pub fn from_bytes(bytes: impl AsRef<[u8]>) -> tokenizers::Result<Self> {
112        tokenizers::Tokenizer::from_bytes(bytes).map(Self::from_tokenizer)
113    }
114
115    pub fn apply_chat_template<'a, I, R, T>(
116        &'a mut self,
117        model_template: String,
118        args: ApplyChatTemplateArgs<'a, I, R, T>,
119    ) -> Result<Vec<String>, Error>
120    where
121        I: IntoIterator<Item = Chat<'a, R, T>>,
122        R: Serialize + 'a,
123        T: Serialize + ToString + 'a,
124    {
125        apply_chat_template(&mut self.env, model_template, args)
126    }
127
128    pub fn apply_chat_template_and_encode<'a, I, R, T>(
129        &mut self,
130        model_template: String,
131        args: ApplyChatTemplateArgs<'a, I, R, T>,
132    ) -> Result<Vec<Encoding>, Error>
133    where
134        I: IntoIterator<Item = Chat<'a, R, T>>,
135        R: Serialize + 'a,
136        T: Serialize + ToString + 'a,
137    {
138        let Self { inner, env } = self;
139
140        let rendered_chats = apply_chat_template(env, model_template, args)?;
141        inner
142            .encode_batch(rendered_chats, false)
143            .map_err(Into::into)
144    }
145}
146
// Expose the wrapped tokenizer's read-only API directly on `Tokenizer`.
impl Deref for Tokenizer {
    type Target = tokenizers::Tokenizer;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
154
// Expose the wrapped tokenizer's mutating API directly on `Tokenizer`.
impl DerefMut for Tokenizer {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
160
/// Chat participant role, serialized in lowercase ("user" / "assistant")
/// as chat templates expect.
// NOTE(review): no `System` (or `Tool`) variant yet, so system prompts
// cannot be expressed with this enum — confirm whether that is intended.
// Callers can substitute their own role type via the `R` type parameter.
#[derive(Debug, Clone, Copy, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
167
168#[derive(Debug, Clone, Serialize)]
169pub enum Content {
170    String(String),
171    Map(HashMap<String, String>),
172}
173
/// A single chat message: a role (e.g. [`Role`]) plus its content.
#[derive(Debug, Clone, Serialize)]
pub struct Conversation<R, T> {
    /// Message author, serialized under the "role" key.
    pub role: R,
    /// Message body, serialized under the "content" key.
    pub content: T,
}
179
/// One conversation: an ordered list of messages, borrowed or owned.
///
/// `untagged` so both variants serialize as a plain JSON array of messages.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub enum Chat<'a, R, T> {
    /// Borrowed slice of messages.
    Borrowed(&'a [Conversation<R, T>]),
    /// Owned vector of messages.
    Owned(Vec<Conversation<R, T>>),
}
186
187impl<R, T> Deref for Chat<'_, R, T> {
188    type Target = [Conversation<R, T>];
189
190    fn deref(&self) -> &Self::Target {
191        match self {
192            Chat::Borrowed(conversations) => conversations,
193            Chat::Owned(conversations) => conversations,
194        }
195    }
196}
197
// Owning conversion: wrap the vector as `Chat::Owned`.
impl<R, T> From<Vec<Conversation<R, T>>> for Chat<'_, R, T> {
    fn from(value: Vec<Conversation<R, T>>) -> Self {
        Chat::Owned(value)
    }
}
203
// Borrowing conversion: wrap the slice as `Chat::Borrowed`.
impl<'a, R, T> From<&'a [Conversation<R, T>]> for Chat<'a, R, T> {
    fn from(value: &'a [Conversation<R, T>]) -> Self {
        Chat::Borrowed(value)
    }
}
209
/// A retrieval (RAG) document made available to templates that support it;
/// mirrors the "title"/"text" dict shape recommended by the HF docs above.
#[derive(Debug, Clone, Serialize)]
pub struct Document {
    /// Document title.
    pub title: String,
    /// Document body text.
    pub text: String,
}
215
/// Padding strategy for batched encodings, mirroring the HF `padding`
/// argument ("longest" / "max_length") documented at the top of this file.
///
/// NOTE(review): declared but not yet consumed by `apply_chat_template*`;
/// confirm it is wired up when tokenization options are implemented.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Padding {
    /// Pad to the longest sequence in the batch.
    Longest,
    /// Pad to a fixed maximum length.
    MaxLength,
}
220
/// Truncation strategy, mirroring the HF `truncation`/`max_length`
/// arguments documented at the top of this file.
///
/// NOTE(review): declared but not yet consumed by `apply_chat_template*`;
/// confirm it is wired up when tokenization options are implemented.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Truncation {
    /// Truncate sequences to at most this many tokens.
    MaxLength(usize),
}
224
/// Arguments for [`apply_chat_template`], mirroring a subset of the Hugging
/// Face `apply_chat_template` keyword arguments documented at the top of
/// this file.
#[derive(Default)]
pub struct ApplyChatTemplateArgs<'a, I, R = Role, T = String>
where
    I: IntoIterator<Item = Chat<'a, R, T>>,
    R: Serialize + 'a,
    T: Serialize + ToString + 'a,
{
    // pub conversations: &'a [Conversation<R, T>],
    /// The chats to render, one [`Chat`] per conversation.
    pub conversations: I,
    // pub tools: Option<Box<dyn FnOnce()>>, // TODO
    /// Optional RAG documents exposed to templates that support them.
    pub documents: Option<&'a [Document]>,
    /// Key under which the model's template is cached in the environment.
    pub model_id: &'a str,
    /// Explicit template id; when set it takes precedence over `model_id`.
    pub chat_template_id: Option<&'a str>,
    /// Append the assistant-start token(s) after the chat. Defaults to false.
    pub add_generation_prompt: Option<bool>,
    /// Leave the final message open-ended so the model continues it.
    /// Defaults to false.
    // NOTE(review): the HF reference rejects this combined with
    // `add_generation_prompt`; that incompatibility is not enforced here —
    // confirm whether it should be.
    pub continue_final_message: Option<bool>,
}
241
242pub fn load_model_chat_template_from_str(content: &str) -> std::io::Result<Option<String>> {
243    serde_json::from_str::<serde_json::Value>(content)
244        .map(|value| {
245            value
246                .get("chat_template")
247                .and_then(|value| value.as_str())
248                .map(ToString::to_string)
249        })
250        .map_err(Into::into)
251}
252
253pub fn load_model_chat_template_from_file(
254    file: impl AsRef<Path>,
255) -> std::io::Result<Option<String>> {
256    let content = read_to_string(file)?;
257    load_model_chat_template_from_str(&content)
258}
259
260// chat_template = self.get_chat_template(chat_template, tools)
261
262// if isinstance(conversation, (list, tuple)) and (
263//     isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
264// ):
265//     conversations = conversation
266//     is_batched = True
267// else:
268//     conversations = [conversation]
269//     is_batched = False
270
271// if continue_final_message:
272//     if add_generation_prompt:
273//         raise ValueError(
274//             "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
275//         )
276//     if return_assistant_tokens_mask:
277//         raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
278
279// template_kwargs = {**self.special_tokens_map, **kwargs}  # kwargs overwrite special tokens if both are present
280// rendered_chat, generation_indices = render_jinja_template(
281//     conversations=conversations,
282//     tools=tools,
283//     documents=documents,
284//     chat_template=chat_template,
285//     return_assistant_tokens_mask=return_assistant_tokens_mask,
286//     continue_final_message=continue_final_message,
287//     add_generation_prompt=add_generation_prompt,
288//     **template_kwargs,
289// )
290
291// if not is_batched:
292//     rendered_chat = rendered_chat[0]
293
294// if tokenize:
295//     out = self(
296//         rendered_chat,
297//         padding=padding,
298//         truncation=truncation,
299//         max_length=max_length,
300//         add_special_tokens=False,
301//         return_tensors=return_tensors,
302//         **tokenizer_kwargs,
303//     )
304//     if return_dict:
305//         if return_assistant_tokens_mask:
306//             assistant_masks = []
307//             if is_batched or return_tensors:
308//                 input_ids = out["input_ids"]
309//             else:
310//                 input_ids = [out["input_ids"]]
311//             for i in range(len(input_ids)):
312//                 current_mask = [0] * len(input_ids[i])
313//                 for assistant_start_char, assistant_end_char in generation_indices[i]:
314//                     start_token = out.char_to_token(i, assistant_start_char)
315//                     end_token = out.char_to_token(i, assistant_end_char - 1)
316//                     if start_token is None:
317//                         # start_token is out of bounds maybe due to truncation.
318//                         break
319//                     for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])):
320//                         current_mask[token_id] = 1
321//                 assistant_masks.append(current_mask)
322
323//             if not is_batched and not return_tensors:
324//                 assistant_masks = assistant_masks[0]
325
326//             out["assistant_masks"] = assistant_masks
327
328//             if return_tensors:
329//                 out.convert_to_tensors(tensor_type=return_tensors)
330
331//         return out
332//     else:
333//         return out["input_ids"]
334// else:
335//     return rendered_chat
336
337// def render_jinja_template(
338//     conversations: list[list[dict[str, str]]],
339//     tools: Optional[list[Union[dict, Callable]]] = None,
340//     documents: Optional[list[dict[str, str]]] = None,
341//     chat_template: Optional[str] = None,
342//     return_assistant_tokens_mask: Optional[bool] = False,
343//     continue_final_message: Optional[bool] = False,
344//     add_generation_prompt: Optional[bool] = False,
345//     **kwargs,
346// ) -> str:
347//     if return_assistant_tokens_mask and not re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template):
348//         logger.warning_once(
349//             "return_assistant_tokens_mask==True but chat template does not contain `{% generation %}` keyword."
350//         )
351
352//     # Compilation function uses a cache to avoid recompiling the same template
353//     compiled_template = _compile_jinja_template(chat_template)
354
355//     # We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
356//     if tools is not None:
357//         tool_schemas = []
358//         for tool in tools:
359//             if isinstance(tool, dict):
360//                 tool_schemas.append(tool)
361//             elif isfunction(tool):
362//                 tool_schemas.append(get_json_schema(tool))
363//             else:
364//                 raise ValueError(
365//                     "Tools should either be a JSON schema, or a callable function with type hints "
366//                     "and a docstring suitable for auto-conversion to a schema."
367//                 )
368//     else:
369//         tool_schemas = None
370
371//     if documents is not None:
372//         for document in documents:
373//             if not isinstance(document, dict):
374//                 raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")
375
376//     rendered = []
377//     all_generation_indices = []
378//     for chat in conversations:
379//         if hasattr(chat, "messages"):
380//             # Indicates it's a Conversation object
381//             chat = chat.messages
382//         if return_assistant_tokens_mask:
383//             rendered_chat, generation_indices = _render_with_assistant_indices(
384//                 compiled_template=compiled_template,
385//                 messages=chat,
386//                 tools=tool_schemas,
387//                 documents=documents,
388//                 add_generation_prompt=add_generation_prompt,
389//                 **kwargs,
390//             )
391//             all_generation_indices.append(generation_indices)
392//         else:
393//             rendered_chat = compiled_template.render(
394//                 messages=chat,
395//                 tools=tool_schemas,
396//                 documents=documents,
397//                 add_generation_prompt=add_generation_prompt,
398//                 **kwargs,
399//             )
400//         if continue_final_message:
401//             final_message = chat[-1]["content"]
402//             if isinstance(final_message, (list, tuple)):
403//                 for content_block in reversed(final_message):
404//                     if "text" in content_block:
405//                         # Pick the last text block in the message (the first one we hit while iterating in reverse)
406//                         final_message = content_block["text"]
407//                         break
408//                 else:
409//                     raise ValueError(
410//                         "continue_final_message is set but we could not find any text to continuein the final message!"
411//                     )
412//             if final_message.strip() not in rendered_chat:
413//                 raise ValueError(
414//                     "continue_final_message is set but the final message does not appear in the chat after "
415//                     "applying the chat template! This can happen if the chat template deletes portions of "
416//                     "the final message. Please verify the chat template and final message in your chat to "
417//                     "ensure they are compatible."
418//                 )
419//             final_msg_loc = rendered_chat.rindex(final_message.strip())
420//             if rendered_chat[final_msg_loc : final_msg_loc + len(final_message.lstrip())] == final_message:
421//                 # The template preserves spacing or the message doesn't have trailing spacing, so things are simple
422//                 rendered_chat = rendered_chat[: final_msg_loc + len(final_message.lstrip())]
423//             else:
424//                 # The message has trailing spacing that was trimmed, so we must be more cautious
425//                 rendered_chat = rendered_chat[: final_msg_loc + len(final_message.strip())]
426//         rendered.append(rendered_chat)
427
428//     return rendered, all_generation_indices
429
430pub fn apply_chat_template<'a, I, R, T>(
431    env: &mut Environment<'static>,
432    model_template: String,
433    args: ApplyChatTemplateArgs<'a, I, R, T>,
434) -> Result<Vec<String>, Error>
435where
436    I: IntoIterator<Item = Chat<'a, R, T>>,
437    R: Serialize + 'a,
438    T: Serialize + ToString + 'a,
439{
440    let ApplyChatTemplateArgs {
441        conversations,
442        // tools,
443        documents,
444        model_id,
445        chat_template_id,
446        add_generation_prompt,
447        continue_final_message,
448    } = args;
449
450    let add_generation_prompt = add_generation_prompt.unwrap_or(false);
451    let continue_final_message = continue_final_message.unwrap_or(false);
452
453    let template = match chat_template_id {
454        Some(chat_template_id) => env.get_template(chat_template_id)?,
455        None => match env.get_template(model_id) {
456            Ok(template) => template,
457            Err(_) => {
458                env.add_template_owned(model_id.to_owned(), model_template)?;
459                env.get_template(model_id)
460                    .expect("Newly added template must be present")
461            }
462        },
463    };
464
465    // TODO: handle tool
466
467    // TODO: allow return_generation_indices
468
469    render_jinja_tempalte(
470        template,
471        conversations,
472        documents,
473        Some(add_generation_prompt),
474        Some(continue_final_message),
475    )
476}
477
478// TODO: render with assistant indices
479fn render_jinja_tempalte<'a, R, T>(
480    template: Template,
481    conversations: impl IntoIterator<Item = Chat<'a, R, T>>,
482    documents: Option<&'a [Document]>,
483    add_generation_prompt: Option<bool>,
484    continue_final_message: Option<bool>,
485) -> Result<Vec<String>, Error>
486where
487    R: Serialize + 'a,
488    T: Serialize + ToString + 'a,
489{
490    let add_generation_prompt = add_generation_prompt.unwrap_or(false);
491    let continue_final_message = continue_final_message.unwrap_or(false);
492
493    // TODO: what does checking for "messages" key do in the python code?
494    let mut rendered = Vec::new();
495    for chat in conversations {
496        let mut rendered_chat = template.render(context! {
497            messages => chat,
498            documents => documents,
499            add_generation_prompt => add_generation_prompt,
500        })?;
501
502        if continue_final_message {
503            let Some(final_message) = chat.last().map(|chat| &chat.content) else {
504                continue;
505            };
506
507            let final_message_str = final_message.to_string();
508
509            if !rendered_chat.contains(final_message_str.trim()) {
510                return Err(Error::FinalMsgNotInChat);
511            }
512
513            let final_msg_loc = rendered_chat.rfind(&final_message_str.trim()).unwrap();
514            let final_msg_len = final_message_str.trim_start().len();
515            rendered_chat = if rendered_chat[final_msg_loc..final_msg_loc + final_msg_len]
516                == final_message_str
517            {
518                // The template preserves spacing or the message doesn't have trailing spacing, so things are simple
519                rendered_chat[..final_msg_loc + final_msg_len].to_string()
520            } else {
521                // The message has trailing spacing that was trimmed, so we must be more cautious
522                rendered_chat[..final_msg_loc + final_message_str.trim().len()].to_string()
523            };
524        }
525        rendered.push(rendered_chat);
526    }
527
528    Ok(rendered)
529}
530
#[cfg(test)]
mod tests {
    use minijinja::Environment;
    use std::path::Path;

    use crate::tokenizer::{
        apply_chat_template, load_model_chat_template_from_file, ApplyChatTemplateArgs,
        Conversation, Role,
    };

    // let model_id = "mlx-community/Qwen3-4B-bf16".to_string();
    /// Directory holding a locally cached copy of the
    /// "mlx-community/Qwen3-4B-bf16" tokenizer files
    /// (`tokenizer.json` and `tokenizer_config.json`).
    const CACHED_TEST_MODEL_DIR: &str = "../cache/Qwen3-4B-bf16";

    // TODO: how to test this in CI? the model files might be too large
    #[test]
    fn test_load_chat_template_from_file() {
        let file = Path::new(CACHED_TEST_MODEL_DIR).join("tokenizer_config.json");
        let chat_template = load_model_chat_template_from_file(file).unwrap().unwrap();
        assert!(!chat_template.is_empty());
    }

    #[test]
    fn test_apply_chat_template() {
        let file = Path::new(CACHED_TEST_MODEL_DIR).join("tokenizer_config.json");
        let model_chat_template = load_model_chat_template_from_file(file).unwrap().unwrap();
        assert!(!model_chat_template.is_empty());

        let model_id = "mlx-community/Qwen3-4B-bf16".to_string();
        let conversations = vec![Conversation {
            role: Role::User,
            content: "hello",
        }];
        let args = ApplyChatTemplateArgs {
            conversations: [conversations.into()],
            documents: None,
            model_id: &model_id,
            chat_template_id: None,
            add_generation_prompt: None,
            continue_final_message: None,
        };

        let mut env = Environment::new();
        env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);

        let rendered_chat = apply_chat_template(&mut env, model_chat_template, args).unwrap();
        // One chat in, one rendered string out.
        assert_eq!(rendered_chat.len(), 1);
        println!("{:?}", rendered_chat);
    }

    #[test]
    fn test_tokenizer_apply_chat_template() {
        let tokenizer_file = Path::new(CACHED_TEST_MODEL_DIR).join("tokenizer.json");
        let tokenizer_config_file = Path::new(CACHED_TEST_MODEL_DIR).join("tokenizer_config.json");

        let model_id = "mlx-community/Qwen3-4B-bf16".to_string();

        let conversations = vec![Conversation {
            role: Role::User,
            content: "hello",
        }];

        let mut tokenizer = super::Tokenizer::from_file(tokenizer_file).unwrap();

        let model_chat_template = load_model_chat_template_from_file(tokenizer_config_file)
            .unwrap()
            .unwrap();
        assert!(!model_chat_template.is_empty());

        let args = ApplyChatTemplateArgs {
            conversations: [conversations.into()],
            documents: None,
            model_id: &model_id,
            chat_template_id: None,
            add_generation_prompt: None,
            continue_final_message: None,
        };

        let rendered_chat = tokenizer
            .apply_chat_template(model_chat_template, args)
            .unwrap();
        assert_eq!(rendered_chat.len(), 1);
        println!("{:?}", rendered_chat);
    }

    #[test]
    fn test_tokenizer_apply_chat_template_and_encode() {
        let tokenizer_file = Path::new(CACHED_TEST_MODEL_DIR).join("tokenizer.json");
        let tokenizer_config_file = Path::new(CACHED_TEST_MODEL_DIR).join("tokenizer_config.json");

        let model_id = "mlx-community/Qwen3-4B-bf16".to_string();

        let conversations = vec![Conversation {
            role: Role::User,
            content: "hello",
        }];
        let mut tokenizer = super::Tokenizer::from_file(tokenizer_file).unwrap();

        let model_chat_template = load_model_chat_template_from_file(tokenizer_config_file)
            .unwrap()
            .unwrap();
        assert!(!model_chat_template.is_empty());

        let args = ApplyChatTemplateArgs {
            conversations: [conversations.into()],
            documents: None,
            model_id: &model_id,
            chat_template_id: None,
            add_generation_prompt: None,
            continue_final_message: None,
        };

        let encodings = tokenizer
            .apply_chat_template_and_encode(model_chat_template, args)
            .unwrap();
        assert!(!encodings.is_empty());
        // Collect before printing: `Debug` on the iterator adapter itself
        // would print the adapter struct, not the token ids.
        let ids: Vec<u32> = encodings
            .iter()
            .flat_map(|encoding| encoding.get_ids().iter().copied())
            .collect();
        assert!(!ids.is_empty());
        println!("{:?}", ids);
    }
}