// alith_models/local_model/chat_template.rs

use std::collections::HashMap;

use super::metadata::tokenizer::TokenizerMetadata;
use anyhow::Context;
use serde::Deserialize;
/// A model's chat template plus the special tokens needed to render it.
///
/// Populated either from a `tokenizer_config.json` file or from GGUF
/// tokenizer metadata (see the constructors on the `impl` block), and
/// consumed via `alith_prompt::apply_chat_template`.
#[derive(Deserialize, Clone, PartialEq)]
pub struct LLMChatTemplate {
    /// The raw chat template string (Jinja-style, as shipped by the model).
    pub chat_template: String,
    /// Beginning-of-sequence token, if the tokenizer defines one.
    pub bos_token: Option<String>,
    /// End-of-sequence token (required).
    pub eos_token: String,
    /// Unknown-token placeholder, if the tokenizer defines one.
    pub unk_token: Option<String>,
    /// Text the template inserts before the assistant's reply content;
    /// derived at load time by `set_generation_prefix`, not deserialized
    /// from the config (absent fields default to `None` via `Option`).
    pub base_generation_prefix: Option<String>,
}
15
16impl LLMChatTemplate {
17 pub fn from_local_path(
18 tokenizer_config_local_path: &std::path::PathBuf,
19 ) -> crate::Result<Self> {
20 let file = std::fs::File::open(tokenizer_config_local_path)?;
21 let reader = std::io::BufReader::new(file);
22 let mut chat_template: LLMChatTemplate = serde_json::from_reader(reader)?;
23 chat_template.set_generation_prefix()?;
24 Ok(chat_template)
25 }
26
27 pub fn from_gguf_tokenizer(tokenizer: &TokenizerMetadata) -> crate::Result<Self> {
28 let chat_template = if let Some(chat_template) = &tokenizer.chat_template {
29 chat_template
30 } else {
31 anyhow::bail!("chat_template not found.");
32 };
33 let ggml = if let Some(ggml) = &tokenizer.ggml {
34 ggml
35 } else {
36 anyhow::bail!("GGML tokenizer model not found.");
37 };
38
39 let bos_token = ggml
40 .tokens
41 .get(ggml.bos_token_id as usize)
42 .map(ToString::to_string)
43 .with_context(|| format!("Token not found for ID: {}", ggml.bos_token_id))?;
44
45 let eos_token = ggml
46 .tokens
47 .get(ggml.eos_token_id as usize)
48 .map(ToString::to_string)
49 .with_context(|| format!("Token not found for ID: {}", ggml.eos_token_id))?;
50
51 let unk_token = if let Some(unk_token_id) = ggml.unknown_token_id {
52 Some(
53 ggml.tokens
54 .get(unk_token_id as usize)
55 .map(ToString::to_string)
56 .with_context(|| format!("Token not found for ID: {}", unk_token_id))?,
57 )
58 } else {
59 None
60 };
61
62 let mut chat_template = LLMChatTemplate {
63 chat_template: chat_template.to_owned(),
64 bos_token: Some(bos_token),
65 eos_token,
66 unk_token,
67 base_generation_prefix: None,
68 };
69 chat_template.set_generation_prefix()?;
70 Ok(chat_template)
71 }
72
73 fn set_generation_prefix(&mut self) -> crate::Result<()> {
74 let user_message_1 = HashMap::from([
75 ("role".to_string(), "user".to_string()),
76 ("content".to_string(), "test_user_message_1".to_string()),
77 ]);
78 let assistant_message_1 = HashMap::from([
79 ("role".to_string(), "assistant".to_string()),
80 (
81 "content".to_string(),
82 "test_assistant_message_1".to_string(),
83 ),
84 ]);
85
86 let message_1 = alith_prompt::apply_chat_template(
87 &[user_message_1.clone()],
88 &self.chat_template,
89 self.bos_token.as_deref(),
90 &self.eos_token,
91 self.unk_token.as_deref(),
92 );
93 let message_1 = message_1
94 .trim_end_matches(self.eos_token.as_str())
95 .to_owned();
96 let message_2 = alith_prompt::apply_chat_template(
97 &[user_message_1, assistant_message_1],
98 &self.chat_template,
99 self.bos_token.as_deref(),
100 &self.eos_token,
101 self.unk_token.as_deref(),
102 );
103
104 let diff_index = message_1
106 .chars()
107 .zip(message_2.chars())
108 .position(|(a, b)| a != b)
109 .unwrap_or(message_1.len());
110
111 let diff_part = &message_2[diff_index..];
113
114 if let Some(content_index) = diff_part.find("test_assistant_message_1") {
116 self.base_generation_prefix = Some(
118 diff_part[..content_index]
119 .trim_start_matches(self.eos_token.as_str())
120 .to_string(),
121 );
122 } else {
123 crate::bail!("Error finding base_generation_prefix");
124 }
125 Ok(())
126 }
127}
128
129impl std::fmt::Debug for LLMChatTemplate {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 let mut debug_struct = f.debug_struct("LLMChatTemplate");
132 debug_struct.field("chat_template", &"string too long to print nicely");
133 debug_struct.field("bos_token", &self.bos_token);
134 debug_struct.field("eos_token", &self.eos_token);
135 debug_struct.field("unk_token", &self.unk_token);
136 debug_struct.finish()
137 }
138}