dynamo_llm/preprocessor/prompt/
template.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{collections::HashSet, sync::Arc};
5
6use anyhow::{Context, Ok, Result};
7use minijinja::Environment;
8
9use crate::model_card::{ModelDeploymentCard, PromptContextMixin, PromptFormatterArtifact};
10
11mod context;
12mod formatters;
13mod oai;
14mod tokcfg;
15
16use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
17use tokcfg::{ChatTemplate, ChatTemplateValue};
18
19impl PromptFormatter {
20    pub fn from_mdc(mdc: &ModelDeploymentCard) -> Result<PromptFormatter> {
21        match mdc
22            .prompt_formatter
23            .as_ref()
24            .ok_or(anyhow::anyhow!("MDC does not contain a prompt formatter"))?
25        {
26            PromptFormatterArtifact::HfTokenizerConfigJson(checked_file) => {
27                let Some(file) = checked_file.path() else {
28                    anyhow::bail!(
29                        "HfTokenizerConfigJson for {} is a URL, cannot load",
30                        mdc.display_name
31                    );
32                };
33                let contents = std::fs::read_to_string(file)
34                    .with_context(|| format!("fs:read_to_string '{}'", file.display()))?;
35                let mut config: ChatTemplate =
36                    serde_json::from_str(&contents).inspect_err(|err| {
37                        crate::log_json_err(&file.display().to_string(), &contents, err)
38                    })?;
39
40                // Some HF model (i.e. meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8)
41                // stores the chat template in a separate file, we check if the file exists and
42                // put the chat template into config as normalization.
43                // This may also be a custom template provided via CLI flag.
44                if let Some(PromptFormatterArtifact::HfChatTemplate(checked_file)) =
45                    mdc.chat_template_file.as_ref()
46                {
47                    let Some(chat_template_file) = checked_file.path() else {
48                        anyhow::bail!(
49                            "HfChatTemplate for {} is a URL, cannot load",
50                            mdc.display_name
51                        );
52                    };
53                    let chat_template =
54                        std::fs::read_to_string(chat_template_file).with_context(|| {
55                            format!("fs:read_to_string '{}'", chat_template_file.display())
56                        })?;
57                    // clean up the string to remove newlines
58                    let chat_template = chat_template.replace('\n', "");
59                    config.chat_template = Some(ChatTemplateValue(either::Left(chat_template)));
60                }
61                Self::from_parts(
62                    config,
63                    mdc.prompt_context
64                        .clone()
65                        .map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
66                )
67            }
68            PromptFormatterArtifact::HfChatTemplate(_) => Err(anyhow::anyhow!(
69                "prompt_formatter should not have type HfChatTemplate"
70            )),
71            PromptFormatterArtifact::GGUF(gguf_path) => {
72                let config = ChatTemplate::from_gguf(gguf_path)?;
73                Self::from_parts(config, ContextMixins::default())
74            }
75        }
76    }
77
78    pub fn from_parts(config: ChatTemplate, context: ContextMixins) -> Result<PromptFormatter> {
79        let formatter = HfTokenizerConfigJsonFormatter::new(config, context)?;
80        Ok(Self::OAI(Arc::new(formatter)))
81    }
82}
83
84/// Chat Template Jinja Renderer
85///
86/// Manages a Jinja environment with registered templates for chat formatting.
87/// Handles two types of ChatTemplateValue templates:
88///
89/// 1. String template: Registered as the 'default' template
90/// 2. Map template: Contains 'tool_use' and/or 'default' templates
91///    - tool_use: Template for tool-based interactions
92///    - default: Template for standard chat interactions
93///
94///   If the map contains both keys, the `tool_use` template is registered as the `tool_use` template
95///   and the `default` template is registered as the `default` template.
96struct JinjaEnvironment {
97    env: Environment<'static>,
98}
99
100/// Formatter for HuggingFace tokenizer config JSON templates
101///
102/// Implements chat template rendering based on HuggingFace's tokenizer_config.json format.
103/// Supports:
104/// - Tool usage templates
105/// - Generation prompts
106/// - Context mixins for template customization
107#[derive(Debug)]
108struct HfTokenizerConfigJsonFormatter {
109    env: Environment<'static>,
110    config: ChatTemplate,
111    mixins: Arc<ContextMixins>,
112    supports_add_generation_prompt: bool,
113}
114
115// /// OpenAI Standard Prompt Formatter
116// pub trait StandardPromptFormatter {
117//     fn render(&self, context: &impl StandardPromptContext) -> Result<String>;
118// }
119
120// pub trait StandardPromptContext {
121//     fn messages(&self) -> Value;
122//     fn tools(&self) -> Option<Value>;
123// }
124
125#[derive(Debug, Clone, Default)]
126pub struct ContextMixins {
127    context_mixins: HashSet<PromptContextMixin>,
128}