// mlx_lm_utils/tokenizer.rs
1// Args:
2// conversation (Union[list[dict[str, str]], list[list[dict[str, str]]]]): A list of dicts
3// with "role" and "content" keys, representing the chat history so far.
4// tools (`list[Union[Dict, Callable]]`, *optional*):
5// A list of tools (callable functions) that will be accessible to the model. If the template does not
6// support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
7// giving the name, description and argument types for the tool. See our
8// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
9// for more information.
10// documents (`list[dict[str, str]]`, *optional*):
11// A list of dicts representing documents that will be accessible to the model if it is performing RAG
12// (retrieval-augmented generation). If the template does not support RAG, this argument will have no
13// effect. We recommend that each document should be a dict containing "title" and "text" keys. Please
14// see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG)
15// for examples of passing documents with chat templates.
16// chat_template (`str`, *optional*):
17// A Jinja template to use for this conversion. It is usually not necessary to pass anything to this
18// argument, as the model's template will be used by default.
19// add_generation_prompt (bool, *optional*):
20// If this is set, a prompt with the token(s) that indicate
21// the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
22// Note that this argument will be passed to the chat template, and so it must be supported in the
23// template for this argument to have any effect.
24// continue_final_message (bool, *optional*):
25// If this is set, the chat will be formatted so that the final
26// message in the chat is open-ended, without any EOS tokens. The model will continue this message
27// rather than starting a new one. This allows you to "prefill" part of
28// the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
29// tokenize (`bool`, defaults to `True`):
30// Whether to tokenize the output. If `False`, the output will be a string.
31// padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
32// Select a strategy to pad the returned sequences (according to the model's padding side and padding
33// index) among:
34
35// - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
36// sequence if provided).
37// - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
38// acceptable input length for the model if that argument is not provided.
39// - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
40// lengths).
41// truncation (`bool`, defaults to `False`):
42// Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`.
43// max_length (`int`, *optional*):
44// Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If
45// not specified, the tokenizer's `max_length` attribute will be used as a default.
46// return_tensors (`str` or [`~utils.TensorType`], *optional*):
47// If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable
48// values are:
49// - `'tf'`: Return TensorFlow `tf.Tensor` objects.
50// - `'pt'`: Return PyTorch `torch.Tensor` objects.
51// - `'np'`: Return NumPy `np.ndarray` objects.
52// - `'jax'`: Return JAX `jnp.ndarray` objects.
53// return_dict (`bool`, defaults to `False`):
54// Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
55// tokenizer_kwargs (`dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer.
56// return_assistant_tokens_mask (`bool`, defaults to `False`):
57// Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
58// the mask will contain 1. For user and system tokens, the mask will contain 0.
59// This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
60// **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
61
62// Returns:
63// `Union[list[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This
64// output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is
65// set, will return a dict of tokenizer outputs instead.
66// """
67
68use std::{
69 collections::HashMap,
70 fs::read_to_string,
71 ops::{Deref, DerefMut},
72 path::Path,
73 str::FromStr,
74};
75
76use minijinja::{context, Environment, Template};
77use serde::Serialize;
78use tokenizers::Encoding;
79
80use crate::error::Error;
81
/// Wrapper around [`tokenizers::Tokenizer`] and [`minijinja::Environment`]
/// providing more utilities.
///
/// The embedded environment caches compiled chat templates (keyed by model id
/// or an explicit template id), so repeated calls to
/// [`Tokenizer::apply_chat_template`] do not recompile the same template.
pub struct Tokenizer {
    // The underlying Hugging Face tokenizer; exposed via `Deref`/`DerefMut`.
    inner: tokenizers::Tokenizer,
    // Jinja environment used to render chat templates; configured in
    // `from_tokenizer` with minijinja-contrib's Python-compat method callback.
    env: Environment<'static>,
}
88
89impl FromStr for Tokenizer {
90 type Err = tokenizers::Error;
91
92 fn from_str(s: &str) -> Result<Self, Self::Err> {
93 tokenizers::Tokenizer::from_str(s).map(Self::from_tokenizer)
94 }
95}
96
97impl Tokenizer {
98 pub fn from_tokenizer(tokenizer: tokenizers::Tokenizer) -> Self {
99 let mut env = Environment::new();
100 env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
101 Self {
102 inner: tokenizer,
103 env,
104 }
105 }
106
107 pub fn from_file(file: impl AsRef<Path>) -> tokenizers::Result<Self> {
108 tokenizers::Tokenizer::from_file(file).map(Self::from_tokenizer)
109 }
110
111 pub fn from_bytes(bytes: impl AsRef<[u8]>) -> tokenizers::Result<Self> {
112 tokenizers::Tokenizer::from_bytes(bytes).map(Self::from_tokenizer)
113 }
114
115 pub fn apply_chat_template<'a, I, R, T>(
116 &'a mut self,
117 model_template: String,
118 args: ApplyChatTemplateArgs<'a, I, R, T>,
119 ) -> Result<Vec<String>, Error>
120 where
121 I: IntoIterator<Item = Chat<'a, R, T>>,
122 R: Serialize + 'a,
123 T: Serialize + ToString + 'a,
124 {
125 apply_chat_template(&mut self.env, model_template, args)
126 }
127
128 pub fn apply_chat_template_and_encode<'a, I, R, T>(
129 &mut self,
130 model_template: String,
131 args: ApplyChatTemplateArgs<'a, I, R, T>,
132 ) -> Result<Vec<Encoding>, Error>
133 where
134 I: IntoIterator<Item = Chat<'a, R, T>>,
135 R: Serialize + 'a,
136 T: Serialize + ToString + 'a,
137 {
138 let Self { inner, env } = self;
139
140 let rendered_chats = apply_chat_template(env, model_template, args)?;
141 inner
142 .encode_batch(rendered_chats, false)
143 .map_err(Into::into)
144 }
145}
146
impl Deref for Tokenizer {
    type Target = tokenizers::Tokenizer;

    // Expose the full `tokenizers::Tokenizer` API directly on the wrapper
    // (deliberate smart-pointer-style delegation for a thin newtype).
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
154
impl DerefMut for Tokenizer {
    // Mutable counterpart of `Deref`: allows in-place configuration of the
    // inner tokenizer through the wrapper.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
160
161#[derive(Debug, Clone, Copy, Serialize)]
162#[serde(rename_all = "lowercase")]
163pub enum Role {
164 User,
165 Assistant,
166}
167
168#[derive(Debug, Clone, Serialize)]
169pub enum Content {
170 String(String),
171 Map(HashMap<String, String>),
172}
173
/// A single chat turn: a `role` (e.g. [`Role`]) plus its `content`.
///
/// Generic so callers can plug in their own role/content types; both must
/// be `Serialize` so the turn can be passed into the Jinja template context.
#[derive(Debug, Clone, Serialize)]
pub struct Conversation<R, T> {
    pub role: R,
    pub content: T,
}
179
/// A full conversation (ordered list of [`Conversation`] turns), either
/// borrowed or owned. `#[serde(untagged)]` serializes both variants as the
/// plain message list, with no variant-name wrapper.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub enum Chat<'a, R, T> {
    Borrowed(&'a [Conversation<R, T>]),
    Owned(Vec<Conversation<R, T>>),
}
186
187impl<R, T> Deref for Chat<'_, R, T> {
188 type Target = [Conversation<R, T>];
189
190 fn deref(&self) -> &Self::Target {
191 match self {
192 Chat::Borrowed(conversations) => conversations,
193 Chat::Owned(conversations) => conversations,
194 }
195 }
196}
197
198impl<R, T> From<Vec<Conversation<R, T>>> for Chat<'_, R, T> {
199 fn from(value: Vec<Conversation<R, T>>) -> Self {
200 Chat::Owned(value)
201 }
202}
203
204impl<'a, R, T> From<&'a [Conversation<R, T>]> for Chat<'a, R, T> {
205 fn from(value: &'a [Conversation<R, T>]) -> Self {
206 Chat::Borrowed(value)
207 }
208}
209
/// A RAG document made available to the template; mirrors the recommended
/// `{"title": ..., "text": ...}` dict shape described in the doc block at
/// the top of this file.
#[derive(Debug, Clone, Serialize)]
pub struct Document {
    pub title: String,
    pub text: String,
}
215
/// Padding strategy, mirroring the HF `padding` argument documented above.
/// NOTE(review): not yet consumed by `apply_chat_template*` in this file —
/// confirm intended wiring before relying on it.
pub enum Padding {
    // Pad every sequence to the longest one in the batch.
    Longest,
    // Pad to a fixed maximum length.
    MaxLength,
}
220
/// Truncation strategy, mirroring the HF `truncation`/`max_length` arguments
/// documented above.
/// NOTE(review): not yet consumed by `apply_chat_template*` in this file —
/// confirm intended wiring before relying on it.
pub enum Truncation {
    // Truncate sequences to at most this many tokens.
    MaxLength(usize),
}
224
/// Arguments for [`apply_chat_template`], mirroring a subset of the Hugging
/// Face `apply_chat_template` keyword arguments (see the doc block at the
/// top of this file).
#[derive(Default)]
pub struct ApplyChatTemplateArgs<'a, I, R = Role, T = String>
where
    I: IntoIterator<Item = Chat<'a, R, T>>,
    R: Serialize + 'a,
    T: Serialize + ToString + 'a,
{
    /// The batch of chats to render; one rendered string is produced per chat.
    pub conversations: I,
    // TODO: tools support (JSON-schema descriptions of callable tools).
    // pub tools: Option<Box<dyn FnOnce()>>, // TODO
    /// Optional RAG documents exposed to the template.
    pub documents: Option<&'a [Document]>,
    /// Key under which the model's template is cached in the environment.
    pub model_id: &'a str,
    /// If set, use this pre-registered template instead of the model's own.
    pub chat_template_id: Option<&'a str>,
    /// Append the assistant-start token(s) to prompt a reply. Defaults to false.
    pub add_generation_prompt: Option<bool>,
    /// Leave the final message open-ended so the model continues it.
    /// Defaults to false. The upstream Python implementation (commented
    /// below) rejects combining this with `add_generation_prompt`; that
    /// check is not enforced here.
    pub continue_final_message: Option<bool>,
}
241
242pub fn load_model_chat_template_from_str(content: &str) -> std::io::Result<Option<String>> {
243 serde_json::from_str::<serde_json::Value>(content)
244 .map(|value| {
245 value
246 .get("chat_template")
247 .and_then(|value| value.as_str())
248 .map(ToString::to_string)
249 })
250 .map_err(Into::into)
251}
252
253pub fn load_model_chat_template_from_file(
254 file: impl AsRef<Path>,
255) -> std::io::Result<Option<String>> {
256 let content = read_to_string(file)?;
257 load_model_chat_template_from_str(&content)
258}
259
260// chat_template = self.get_chat_template(chat_template, tools)
261
262// if isinstance(conversation, (list, tuple)) and (
263// isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
264// ):
265// conversations = conversation
266// is_batched = True
267// else:
268// conversations = [conversation]
269// is_batched = False
270
271// if continue_final_message:
272// if add_generation_prompt:
273// raise ValueError(
274// "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
275// )
276// if return_assistant_tokens_mask:
277// raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
278
279// template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
280// rendered_chat, generation_indices = render_jinja_template(
281// conversations=conversations,
282// tools=tools,
283// documents=documents,
284// chat_template=chat_template,
285// return_assistant_tokens_mask=return_assistant_tokens_mask,
286// continue_final_message=continue_final_message,
287// add_generation_prompt=add_generation_prompt,
288// **template_kwargs,
289// )
290
291// if not is_batched:
292// rendered_chat = rendered_chat[0]
293
294// if tokenize:
295// out = self(
296// rendered_chat,
297// padding=padding,
298// truncation=truncation,
299// max_length=max_length,
300// add_special_tokens=False,
301// return_tensors=return_tensors,
302// **tokenizer_kwargs,
303// )
304// if return_dict:
305// if return_assistant_tokens_mask:
306// assistant_masks = []
307// if is_batched or return_tensors:
308// input_ids = out["input_ids"]
309// else:
310// input_ids = [out["input_ids"]]
311// for i in range(len(input_ids)):
312// current_mask = [0] * len(input_ids[i])
313// for assistant_start_char, assistant_end_char in generation_indices[i]:
314// start_token = out.char_to_token(i, assistant_start_char)
315// end_token = out.char_to_token(i, assistant_end_char - 1)
316// if start_token is None:
317// # start_token is out of bounds maybe due to truncation.
318// break
319// for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])):
320// current_mask[token_id] = 1
321// assistant_masks.append(current_mask)
322
323// if not is_batched and not return_tensors:
324// assistant_masks = assistant_masks[0]
325
326// out["assistant_masks"] = assistant_masks
327
328// if return_tensors:
329// out.convert_to_tensors(tensor_type=return_tensors)
330
331// return out
332// else:
333// return out["input_ids"]
334// else:
335// return rendered_chat
336
337// def render_jinja_template(
338// conversations: list[list[dict[str, str]]],
339// tools: Optional[list[Union[dict, Callable]]] = None,
340// documents: Optional[list[dict[str, str]]] = None,
341// chat_template: Optional[str] = None,
342// return_assistant_tokens_mask: Optional[bool] = False,
343// continue_final_message: Optional[bool] = False,
344// add_generation_prompt: Optional[bool] = False,
345// **kwargs,
346// ) -> str:
347// if return_assistant_tokens_mask and not re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template):
348// logger.warning_once(
349// "return_assistant_tokens_mask==True but chat template does not contain `{% generation %}` keyword."
350// )
351
352// # Compilation function uses a cache to avoid recompiling the same template
353// compiled_template = _compile_jinja_template(chat_template)
354
355// # We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
356// if tools is not None:
357// tool_schemas = []
358// for tool in tools:
359// if isinstance(tool, dict):
360// tool_schemas.append(tool)
361// elif isfunction(tool):
362// tool_schemas.append(get_json_schema(tool))
363// else:
364// raise ValueError(
365// "Tools should either be a JSON schema, or a callable function with type hints "
366// "and a docstring suitable for auto-conversion to a schema."
367// )
368// else:
369// tool_schemas = None
370
371// if documents is not None:
372// for document in documents:
373// if not isinstance(document, dict):
374// raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")
375
376// rendered = []
377// all_generation_indices = []
378// for chat in conversations:
379// if hasattr(chat, "messages"):
380// # Indicates it's a Conversation object
381// chat = chat.messages
382// if return_assistant_tokens_mask:
383// rendered_chat, generation_indices = _render_with_assistant_indices(
384// compiled_template=compiled_template,
385// messages=chat,
386// tools=tool_schemas,
387// documents=documents,
388// add_generation_prompt=add_generation_prompt,
389// **kwargs,
390// )
391// all_generation_indices.append(generation_indices)
392// else:
393// rendered_chat = compiled_template.render(
394// messages=chat,
395// tools=tool_schemas,
396// documents=documents,
397// add_generation_prompt=add_generation_prompt,
398// **kwargs,
399// )
400// if continue_final_message:
401// final_message = chat[-1]["content"]
402// if isinstance(final_message, (list, tuple)):
403// for content_block in reversed(final_message):
404// if "text" in content_block:
405// # Pick the last text block in the message (the first one we hit while iterating in reverse)
406// final_message = content_block["text"]
407// break
408// else:
409// raise ValueError(
// "continue_final_message is set but we could not find any text to continue in the final message!"
411// )
412// if final_message.strip() not in rendered_chat:
413// raise ValueError(
414// "continue_final_message is set but the final message does not appear in the chat after "
415// "applying the chat template! This can happen if the chat template deletes portions of "
416// "the final message. Please verify the chat template and final message in your chat to "
417// "ensure they are compatible."
418// )
419// final_msg_loc = rendered_chat.rindex(final_message.strip())
420// if rendered_chat[final_msg_loc : final_msg_loc + len(final_message.lstrip())] == final_message:
421// # The template preserves spacing or the message doesn't have trailing spacing, so things are simple
422// rendered_chat = rendered_chat[: final_msg_loc + len(final_message.lstrip())]
423// else:
424// # The message has trailing spacing that was trimmed, so we must be more cautious
425// rendered_chat = rendered_chat[: final_msg_loc + len(final_message.strip())]
426// rendered.append(rendered_chat)
427
428// return rendered, all_generation_indices
429
430pub fn apply_chat_template<'a, I, R, T>(
431 env: &mut Environment<'static>,
432 model_template: String,
433 args: ApplyChatTemplateArgs<'a, I, R, T>,
434) -> Result<Vec<String>, Error>
435where
436 I: IntoIterator<Item = Chat<'a, R, T>>,
437 R: Serialize + 'a,
438 T: Serialize + ToString + 'a,
439{
440 let ApplyChatTemplateArgs {
441 conversations,
442 // tools,
443 documents,
444 model_id,
445 chat_template_id,
446 add_generation_prompt,
447 continue_final_message,
448 } = args;
449
450 let add_generation_prompt = add_generation_prompt.unwrap_or(false);
451 let continue_final_message = continue_final_message.unwrap_or(false);
452
453 let template = match chat_template_id {
454 Some(chat_template_id) => env.get_template(chat_template_id)?,
455 None => match env.get_template(model_id) {
456 Ok(template) => template,
457 Err(_) => {
458 env.add_template_owned(model_id.to_owned(), model_template)?;
459 env.get_template(model_id)
460 .expect("Newly added template must be present")
461 }
462 },
463 };
464
465 // TODO: handle tool
466
467 // TODO: allow return_generation_indices
468
469 render_jinja_tempalte(
470 template,
471 conversations,
472 documents,
473 Some(add_generation_prompt),
474 Some(continue_final_message),
475 )
476}
477
478// TODO: render with assistant indices
479fn render_jinja_tempalte<'a, R, T>(
480 template: Template,
481 conversations: impl IntoIterator<Item = Chat<'a, R, T>>,
482 documents: Option<&'a [Document]>,
483 add_generation_prompt: Option<bool>,
484 continue_final_message: Option<bool>,
485) -> Result<Vec<String>, Error>
486where
487 R: Serialize + 'a,
488 T: Serialize + ToString + 'a,
489{
490 let add_generation_prompt = add_generation_prompt.unwrap_or(false);
491 let continue_final_message = continue_final_message.unwrap_or(false);
492
493 // TODO: what does checking for "messages" key do in the python code?
494 let mut rendered = Vec::new();
495 for chat in conversations {
496 let mut rendered_chat = template.render(context! {
497 messages => chat,
498 documents => documents,
499 add_generation_prompt => add_generation_prompt,
500 })?;
501
502 if continue_final_message {
503 let Some(final_message) = chat.last().map(|chat| &chat.content) else {
504 continue;
505 };
506
507 let final_message_str = final_message.to_string();
508
509 if !rendered_chat.contains(final_message_str.trim()) {
510 return Err(Error::FinalMsgNotInChat);
511 }
512
513 let final_msg_loc = rendered_chat.rfind(&final_message_str.trim()).unwrap();
514 let final_msg_len = final_message_str.trim_start().len();
515 rendered_chat = if rendered_chat[final_msg_loc..final_msg_loc + final_msg_len]
516 == final_message_str
517 {
518 // The template preserves spacing or the message doesn't have trailing spacing, so things are simple
519 rendered_chat[..final_msg_loc + final_msg_len].to_string()
520 } else {
521 // The message has trailing spacing that was trimmed, so we must be more cautious
522 rendered_chat[..final_msg_loc + final_message_str.trim().len()].to_string()
523 };
524 }
525 rendered.push(rendered_chat);
526 }
527
528 Ok(rendered)
529}
530
#[cfg(test)]
mod tests {
    use minijinja::Environment;
    use std::path::Path;

    use crate::tokenizer::{
        apply_chat_template, load_model_chat_template_from_file, ApplyChatTemplateArgs,
        Conversation, Role,
    };

    // Local snapshot of "mlx-community/Qwen3-4B-bf16"; the tests below fail
    // if this cache directory is missing.
    const CACHED_TEST_MODEL_DIR: &str = "../cache/Qwen3-4B-bf16";

    // TODO: how to test this in CI? the model files might be too large
    #[test]
    fn test_load_chat_template_from_file() {
        let file = Path::new(CACHED_TEST_MODEL_DIR).join("tokenizer_config.json");
        let chat_template = load_model_chat_template_from_file(file).unwrap().unwrap();
        assert!(!chat_template.is_empty());
    }

    #[test]
    fn test_apply_chat_template() {
        let file = Path::new(CACHED_TEST_MODEL_DIR).join("tokenizer_config.json");
        let model_chat_template = load_model_chat_template_from_file(file).unwrap().unwrap();
        assert!(!model_chat_template.is_empty());

        let model_id = "mlx-community/Qwen3-4B-bf16".to_string();
        let conversations = vec![Conversation {
            role: Role::User,
            content: "hello",
        }];
        let args = ApplyChatTemplateArgs {
            conversations: [conversations.into()],
            documents: None,
            model_id: &model_id,
            chat_template_id: None,
            add_generation_prompt: None,
            continue_final_message: None,
        };

        let mut env = Environment::new();
        env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);

        let rendered_chat = apply_chat_template(&mut env, model_chat_template, args).unwrap();
        // One rendered string per input chat.
        assert_eq!(rendered_chat.len(), 1);
        println!("{:?}", rendered_chat);
    }

    #[test]
    fn test_tokenizer_apply_chat_template() {
        let tokenizer_file = Path::new(CACHED_TEST_MODEL_DIR).join("tokenizer.json");
        let tokenizer_config_file = Path::new(CACHED_TEST_MODEL_DIR).join("tokenizer_config.json");

        let model_id = "mlx-community/Qwen3-4B-bf16".to_string();

        let conversations = vec![Conversation {
            role: Role::User,
            content: "hello",
        }];

        let mut tokenizer = super::Tokenizer::from_file(tokenizer_file).unwrap();

        let model_chat_template = load_model_chat_template_from_file(tokenizer_config_file)
            .unwrap()
            .unwrap();
        assert!(!model_chat_template.is_empty());

        let args = ApplyChatTemplateArgs {
            conversations: [conversations.into()],
            documents: None,
            model_id: &model_id,
            chat_template_id: None,
            add_generation_prompt: None,
            continue_final_message: None,
        };

        let rendered_chat = tokenizer
            .apply_chat_template(model_chat_template, args)
            .unwrap();
        assert_eq!(rendered_chat.len(), 1);
        println!("{:?}", rendered_chat);
    }

    #[test]
    fn test_tokenizer_apply_chat_template_and_encode() {
        let tokenizer_file = Path::new(CACHED_TEST_MODEL_DIR).join("tokenizer.json");
        let tokenizer_config_file = Path::new(CACHED_TEST_MODEL_DIR).join("tokenizer_config.json");

        let model_id = "mlx-community/Qwen3-4B-bf16".to_string();

        let conversations = vec![Conversation {
            role: Role::User,
            content: "hello",
        }];
        let mut tokenizer = super::Tokenizer::from_file(tokenizer_file).unwrap();

        let model_chat_template = load_model_chat_template_from_file(tokenizer_config_file)
            .unwrap()
            .unwrap();
        assert!(!model_chat_template.is_empty());

        let args = ApplyChatTemplateArgs {
            conversations: [conversations.into()],
            documents: None,
            model_id: &model_id,
            chat_template_id: None,
            add_generation_prompt: None,
            continue_final_message: None,
        };

        let encodings = tokenizer
            .apply_chat_template_and_encode(model_chat_template, args)
            .unwrap();
        // BUGFIX: the old `println!("{:?}", encodings.iter().map(..).flatten())`
        // printed the iterator adaptor itself, never the token ids. Collect
        // first so the ids are actually shown, and sanity-check they exist.
        let ids: Vec<u32> = encodings
            .iter()
            .flat_map(|encoding| encoding.get_ids().iter().copied())
            .collect();
        assert!(!ids.is_empty());
        println!("{:?}", ids);
    }
}
645}