dynamo_llm/preprocessor/prompt/template.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{collections::HashSet, sync::Arc};

use anyhow::{Ok, Result};
use minijinja::Environment;

use crate::model_card::model::{ModelDeploymentCard, PromptContextMixin, PromptFormatterArtifact};

mod context;
mod formatters;
mod oai;
mod tokcfg;

use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
use tokcfg::ChatTemplate;

impl PromptFormatter {
    /// Builds a prompt formatter from the prompt formatter artifact referenced
    /// by a [`ModelDeploymentCard`].
    pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<PromptFormatter> {
        match mdc
            .prompt_formatter
            .ok_or(anyhow::anyhow!("MDC does not contain a prompt formatter"))?
        {
            // HuggingFace tokenizer_config.json: parse the chat template from JSON.
            PromptFormatterArtifact::HfTokenizerConfigJson(file) => {
                let content = std::fs::read_to_string(file)?;
                let config: ChatTemplate = serde_json::from_str(&content)?;
                Self::from_parts(
                    config,
                    mdc.prompt_context
                        .map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
                )
            }
            // GGUF model file: read the chat template from its metadata.
            PromptFormatterArtifact::GGUF(gguf_path) => {
                let config = ChatTemplate::from_gguf(&gguf_path)?;
                Self::from_parts(config, ContextMixins::default())
            }
        }
    }

    /// Builds a prompt formatter from an already-parsed chat template and context mixins.
    pub fn from_parts(config: ChatTemplate, context: ContextMixins) -> Result<PromptFormatter> {
        let formatter = HfTokenizerConfigJsonFormatter::new(config, context)?;
        Ok(Self::OAI(Arc::new(formatter)))
    }
}
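
// A hedged usage sketch of the constructors above. `mdc` is assumed to be a
// `ModelDeploymentCard` obtained elsewhere; the match assumes `OAI` is the only
// variant of interest, as produced by `from_parts` above.
//
// match PromptFormatter::from_mdc(mdc).await? {
//     PromptFormatter::OAI(oai_formatter) => {
//         // Render OAI chat-like requests into prompt strings with `oai_formatter`.
//     }
// }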

/// Chat Template Jinja Renderer
///
/// Manages a Jinja environment with registered templates for chat formatting.
/// Handles two types of ChatTemplateValue templates:
///
/// 1. String template: registered as the `default` template
/// 2. Map template: contains `tool_use` and/or `default` templates
///    - `tool_use`: template for tool-based interactions
///    - `default`: template for standard chat interactions
///
/// If the map contains both keys, each template is registered under its
/// respective name (`tool_use` and `default`).
struct JinjaEnvironment {
    env: Environment<'static>,
}
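
// A minimal sketch (not part of this module's API) of the registration
// behaviour described above, using minijinja's `add_template`. The template
// strings are illustrative placeholders, not the real templates carried by
// `ChatTemplate`.
//
// fn register_templates_sketch(env: &mut Environment<'static>) -> Result<()> {
//     // A string-valued chat template is registered once, under "default".
//     env.add_template("default", "{% for m in messages %}{{ m.content }}{% endfor %}")?;
//     // A map-valued chat template may additionally provide a "tool_use" entry.
//     env.add_template("tool_use", "{{ tools }}")?;
//     Ok(())
// }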

/// Formatter for HuggingFace tokenizer config JSON templates
///
/// Implements chat template rendering based on HuggingFace's tokenizer_config.json format.
/// Supports:
/// - Tool usage templates
/// - Generation prompts
/// - Context mixins for template customization
#[derive(Debug)]
struct HfTokenizerConfigJsonFormatter {
    env: Environment<'static>,
    config: ChatTemplate,
    mixins: Arc<ContextMixins>,
    supports_add_generation_prompt: bool,
}
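
// For reference, a trimmed sketch of the tokenizer_config.json input that
// `from_mdc` parses for this formatter; real configs carry many more fields,
// and the field names and template string here are purely illustrative.
//
// let content = r#"{
//     "bos_token": "<s>",
//     "eos_token": "</s>",
//     "chat_template": "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}"
// }"#;
// let config: ChatTemplate = serde_json::from_str(content)?;
// let formatter = PromptFormatter::from_parts(config, ContextMixins::default())?;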

// /// OpenAI Standard Prompt Formatter
// pub trait StandardPromptFormatter {
//     fn render(&self, context: &impl StandardPromptContext) -> Result<String>;
// }

// pub trait StandardPromptContext {
//     fn messages(&self) -> Value;
//     fn tools(&self) -> Option<Value>;
// }

#[derive(Debug, Clone, Default)]
pub struct ContextMixins {
    context_mixins: HashSet<PromptContextMixin>,
}