alith_models/local_model/
chat_template.rs

1use std::collections::HashMap;
2
3use super::metadata::tokenizer::TokenizerMetadata;
4use anyhow::Context;
5use serde::Deserialize;
6
/// A chat template string plus the special tokens required to render it.
///
/// Built either from a local `tokenizer_config.json` (`from_local_path`) or
/// from GGUF tokenizer metadata (`from_gguf_tokenizer`). `Debug` is
/// implemented manually later in this file so the long template string is not
/// printed in full.
#[derive(Deserialize, Clone, PartialEq)]
pub struct LLMChatTemplate {
    /// Raw chat template used by `alith_prompt::apply_chat_template`.
    pub chat_template: String,
    /// Beginning-of-sequence token, if the tokenizer defines one.
    pub bos_token: Option<String>,
    /// End-of-sequence token; required.
    pub eos_token: String,
    /// Unknown-token string, if the tokenizer defines one.
    pub unk_token: Option<String>,
    /// Text the template emits immediately before assistant content.
    /// Populated by `set_generation_prefix`; typically absent from the JSON
    /// config, in which case serde leaves it `None`.
    pub base_generation_prefix: Option<String>,
}
15
16impl LLMChatTemplate {
17    pub fn from_local_path(
18        tokenizer_config_local_path: &std::path::PathBuf,
19    ) -> crate::Result<Self> {
20        let file = std::fs::File::open(tokenizer_config_local_path)?;
21        let reader = std::io::BufReader::new(file);
22        let mut chat_template: LLMChatTemplate = serde_json::from_reader(reader)?;
23        chat_template.set_generation_prefix()?;
24        Ok(chat_template)
25    }
26
27    pub fn from_gguf_tokenizer(tokenizer: &TokenizerMetadata) -> crate::Result<Self> {
28        let chat_template = if let Some(chat_template) = &tokenizer.chat_template {
29            chat_template
30        } else {
31            anyhow::bail!("chat_template not found.");
32        };
33        let ggml = if let Some(ggml) = &tokenizer.ggml {
34            ggml
35        } else {
36            anyhow::bail!("GGML tokenizer model not found.");
37        };
38
39        let bos_token = ggml
40            .tokens
41            .get(ggml.bos_token_id as usize)
42            .map(ToString::to_string)
43            .with_context(|| format!("Token not found for ID: {}", ggml.bos_token_id))?;
44
45        let eos_token = ggml
46            .tokens
47            .get(ggml.eos_token_id as usize)
48            .map(ToString::to_string)
49            .with_context(|| format!("Token not found for ID: {}", ggml.eos_token_id))?;
50
51        let unk_token = if let Some(unk_token_id) = ggml.unknown_token_id {
52            Some(
53                ggml.tokens
54                    .get(unk_token_id as usize)
55                    .map(ToString::to_string)
56                    .with_context(|| format!("Token not found for ID: {}", unk_token_id))?,
57            )
58        } else {
59            None
60        };
61
62        let mut chat_template = LLMChatTemplate {
63            chat_template: chat_template.to_owned(),
64            bos_token: Some(bos_token),
65            eos_token,
66            unk_token,
67            base_generation_prefix: None,
68        };
69        chat_template.set_generation_prefix()?;
70        Ok(chat_template)
71    }
72
73    fn set_generation_prefix(&mut self) -> crate::Result<()> {
74        let user_message_1 = HashMap::from([
75            ("role".to_string(), "user".to_string()),
76            ("content".to_string(), "test_user_message_1".to_string()),
77        ]);
78        let assistant_message_1 = HashMap::from([
79            ("role".to_string(), "assistant".to_string()),
80            (
81                "content".to_string(),
82                "test_assistant_message_1".to_string(),
83            ),
84        ]);
85
86        let message_1 = alith_prompt::apply_chat_template(
87            &[user_message_1.clone()],
88            &self.chat_template,
89            self.bos_token.as_deref(),
90            &self.eos_token,
91            self.unk_token.as_deref(),
92        );
93        let message_1 = message_1
94            .trim_end_matches(self.eos_token.as_str())
95            .to_owned();
96        let message_2 = alith_prompt::apply_chat_template(
97            &[user_message_1, assistant_message_1],
98            &self.chat_template,
99            self.bos_token.as_deref(),
100            &self.eos_token,
101            self.unk_token.as_deref(),
102        );
103
104        // Find the point where the outputs start to differ
105        let diff_index = message_1
106            .chars()
107            .zip(message_2.chars())
108            .position(|(a, b)| a != b)
109            .unwrap_or(message_1.len());
110
111        // Extract the differing part
112        let diff_part = &message_2[diff_index..];
113
114        // Find the start of the assistant content
115        if let Some(content_index) = diff_part.find("test_assistant_message_1") {
116            // The prefix is everything before the content
117            self.base_generation_prefix = Some(
118                diff_part[..content_index]
119                    .trim_start_matches(self.eos_token.as_str())
120                    .to_string(),
121            );
122        } else {
123            crate::bail!("Error finding base_generation_prefix");
124        }
125        Ok(())
126    }
127}
128
129impl std::fmt::Debug for LLMChatTemplate {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        let mut debug_struct = f.debug_struct("LLMChatTemplate");
132        debug_struct.field("chat_template", &"string too long to print nicely");
133        debug_struct.field("bos_token", &self.bos_token);
134        debug_struct.field("eos_token", &self.eos_token);
135        debug_struct.field("unk_token", &self.unk_token);
136        debug_struct.finish()
137    }
138}