dynamo_llm/preprocessor/prompt/template.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use std::{collections::HashSet, sync::Arc};
17
18use anyhow::{Context, Ok, Result};
19use minijinja::Environment;
20
21use crate::model_card::model::{ModelDeploymentCard, PromptContextMixin, PromptFormatterArtifact};
22
23mod context;
24mod formatters;
25mod oai;
26mod tokcfg;
27
28use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
29use tokcfg::ChatTemplate;
30
31impl PromptFormatter {
32    pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<PromptFormatter> {
33        match mdc
34            .prompt_formatter
35            .ok_or(anyhow::anyhow!("MDC does not contain a prompt formatter"))?
36        {
37            PromptFormatterArtifact::HfTokenizerConfigJson(file) => {
38                let content = std::fs::read_to_string(&file)
39                    .with_context(|| format!("fs:read_to_string '{file}'"))?;
40                let config: ChatTemplate = serde_json::from_str(&content)?;
41                Self::from_parts(
42                    config,
43                    mdc.prompt_context
44                        .map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
45                )
46            }
47            PromptFormatterArtifact::GGUF(gguf_path) => {
48                let config = ChatTemplate::from_gguf(&gguf_path)?;
49                Self::from_parts(config, ContextMixins::default())
50            }
51        }
52    }
53
54    pub fn from_parts(config: ChatTemplate, context: ContextMixins) -> Result<PromptFormatter> {
55        let formatter = HfTokenizerConfigJsonFormatter::new(config, context)?;
56        Ok(Self::OAI(Arc::new(formatter)))
57    }
58}
59
60/// Chat Template Jinja Renderer
61///
62/// Manages a Jinja environment with registered templates for chat formatting.
63/// Handles two types of ChatTemplateValue templates:
64///
65/// 1. String template: Registered as the 'default' template
66/// 2. Map template: Contains 'tool_use' and/or 'default' templates
67///    - tool_use: Template for tool-based interactions
68///    - default: Template for standard chat interactions
69///
70///   If the map contains both keys, the `tool_use` template is registered as the `tool_use` template
71///   and the `default` template is registered as the `default` template.
72struct JinjaEnvironment {
73    env: Environment<'static>,
74}
75
76/// Formatter for HuggingFace tokenizer config JSON templates
77///
78/// Implements chat template rendering based on HuggingFace's tokenizer_config.json format.
79/// Supports:
80/// - Tool usage templates
81/// - Generation prompts
82/// - Context mixins for template customization
83#[derive(Debug)]
84struct HfTokenizerConfigJsonFormatter {
85    env: Environment<'static>,
86    config: ChatTemplate,
87    mixins: Arc<ContextMixins>,
88    supports_add_generation_prompt: bool,
89}
90
// /// OpenAI Standard Prompt Formatter
// pub trait StandardPromptFormatter {
//     fn render(&self, context: &impl StandardPromptContext) -> Result<String>;
// }

// pub trait StandardPromptContext {
//     fn messages(&self) -> Value;
//     fn tools(&self) -> Option<Value>;
// }
100
101#[derive(Debug, Clone, Default)]
102pub struct ContextMixins {
103    context_mixins: HashSet<PromptContextMixin>,
104}