//! llm_models 0.0.3: Load and download LLM models, metadata, and tokenizers.

use anyhow::Context;

use super::metadata::tokenizer::TokenizerMetadata;

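/// A chat template extracted from GGUF tokenizer metadata, together with the
/// special tokens needed to render prompts with it.
///
/// `base_generation_prefix` holds the text the template emits immediately
/// before the assistant's reply (for example, an assistant-turn header); it is
/// derived automatically by [`LlmChatTemplate::from_gguf_tokenizer`].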
#[derive(Clone, serde::Serialize, serde::Deserialize, PartialEq)]
pub struct LlmChatTemplate {
    pub chat_template: String,
    pub bos_token: Option<String>,
    pub eos_token: String,
    pub unk_token: Option<String>,
    pub base_generation_prefix: Option<String>,
}

impl LlmChatTemplate {
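    /// Builds a chat template from GGUF tokenizer metadata.
    ///
    /// Fails if the metadata lacks a `chat_template` or a GGML tokenizer
    /// model, or if any required special-token ID does not resolve to a token.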
    pub fn from_gguf_tokenizer(tokenizer: &TokenizerMetadata) -> crate::Result<Self> {
        let chat_template = tokenizer
            .chat_template
            .as_ref()
            .context("chat_template not found.")?;
        let ggml = tokenizer
            .ggml
            .as_ref()
            .context("GGML tokenizer model not found.")?;

        let bos_token = ggml
            .tokens
            .get(ggml.bos_token_id as usize)
            .map(ToString::to_string)
            .with_context(|| format!("Token not found for ID: {}", ggml.bos_token_id))?;

        let eos_token = ggml
            .tokens
            .get(ggml.eos_token_id as usize)
            .map(ToString::to_string)
            .with_context(|| format!("Token not found for ID: {}", ggml.eos_token_id))?;

        let unk_token = ggml
            .unknown_token_id
            .map(|unk_token_id| {
                ggml.tokens
                    .get(unk_token_id as usize)
                    .map(ToString::to_string)
                    .with_context(|| format!("Token not found for ID: {}", unk_token_id))
            })
            .transpose()?;

        let mut template = LlmChatTemplate {
            chat_template: chat_template.to_owned(),
            bos_token: Some(bos_token),
            eos_token,
            unk_token,
            base_generation_prefix: None,
        };
        template.set_base_generation_prefix()?;
        Ok(template)
    }

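    /// Derives `base_generation_prefix`: the text the template inserts between
    /// the end of the last user message and the start of the assistant's
    /// reply. The template is rendered twice with sentinel messages, once
    /// without and once with an assistant turn, and the prefix is whatever the
    /// longer render adds before the assistant's sentinel content.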
    fn set_base_generation_prefix(&mut self) -> crate::Result<()> {
        let user_message_1 = std::collections::HashMap::from([
            ("role".to_string(), "user".to_string()),
            ("content".to_string(), "test_user_message_1".to_string()),
        ]);
        let assistant_message_1 = std::collections::HashMap::from([
            ("role".to_string(), "assistant".to_string()),
            (
                "content".to_string(),
                "test_assistant_message_1".to_string(),
            ),
        ]);

        let message_1 = llm_prompt::apply_chat_template(
            &vec![user_message_1.clone()],
            &self.chat_template,
            self.bos_token.as_deref(),
            &self.eos_token,
            self.unk_token.as_deref(),
        );
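        // Trim the trailing EOS token so the user-only render stays a common
        // prefix of the two-message render up to the start of the assistant
        // turn.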
        let message_1 = message_1
            .trim_end_matches(self.eos_token.as_str())
            .to_owned();
        let message_2 = llm_prompt::apply_chat_template(
            &vec![user_message_1, assistant_message_1],
            &self.chat_template,
            self.bos_token.as_deref(),
            &self.eos_token,
            self.unk_token.as_deref(),
        );

        // Find the byte offset at which the two renders start to differ.
        // char_indices (rather than chars().position()) yields byte offsets,
        // which keeps the slice below valid even if the template emits
        // multi-byte characters.
        let diff_index = message_1
            .char_indices()
            .zip(message_2.char_indices())
            .find(|((_, a), (_, b))| a != b)
            .map(|((i, _), _)| i)
            .unwrap_or(message_1.len());

        // Extract the part only the two-message render has.
        let diff_part = &message_2[diff_index..];

        // Find the start of the assistant content
        if let Some(content_index) = diff_part.find("test_assistant_message_1") {
            // The prefix is everything before the content
            self.base_generation_prefix = Some(
                diff_part[..content_index]
                    .trim_start_matches(self.eos_token.as_str())
                    .to_string(),
            );
        } else {
            anyhow::bail!("Failed to find assistant content while deriving base_generation_prefix.");
        }
        Ok(())
    }
}

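// Manual `Debug` impl: the raw Jinja chat template is typically long, so it is
// summarized rather than printed verbatim.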
impl std::fmt::Debug for LlmChatTemplate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "LlmChatTemplate:")?;
        writeln!(
            f,
            "chat_template: <{} chars, omitted>",
            self.chat_template.len()
        )?;
        writeln!(f, "bos_token: {:?}", self.bos_token)?;
        writeln!(f, "eos_token: {}", self.eos_token)?;
        writeln!(f, "unk_token: {:?}", self.unk_token)?;
        writeln!(
            f,
            "base_generation_prefix: {:?}",
            self.base_generation_prefix
        )?;
        Ok(())
    }
}
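
// A minimal sketch of exercising `set_base_generation_prefix` directly. It
// assumes `llm_prompt::apply_chat_template` faithfully renders standard Jinja
// chat templates; the ChatML-style template below is illustrative, not taken
// from any particular model.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn derives_generation_prefix_from_chatml_style_template() {
        let mut template = LlmChatTemplate {
            // Hypothetical ChatML-style template used only for this sketch.
            chat_template: "{% for message in messages %}{{ '<|im_start|>' + \
                            message['role'] + '\\n' + message['content'] + \
                            '<|im_end|>' + '\\n' }}{% endfor %}"
                .to_string(),
            bos_token: None,
            eos_token: "<|im_end|>".to_string(),
            unk_token: None,
            base_generation_prefix: None,
        };
        template
            .set_base_generation_prefix()
            .expect("prefix derivation should succeed");
        // Under the assumption above, the derived prefix is the assistant-turn
        // header the template emits before the assistant's content.
        assert_eq!(
            template.base_generation_prefix.as_deref(),
            Some("<|im_start|>assistant\n")
        );
    }
}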