dynamo_llm/preprocessor/prompt/template.rs
1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
use std::{collections::HashSet, sync::Arc};

use anyhow::{Context, Ok, Result};
use minijinja::Environment;

use crate::model_card::model::{ModelDeploymentCard, PromptContextMixin, PromptFormatterArtifact};
22
23mod context;
24mod formatters;
25mod oai;
26mod tokcfg;
27
28use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
29use tokcfg::ChatTemplate;
30
31impl PromptFormatter {
32 pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<PromptFormatter> {
33 match mdc
34 .prompt_formatter
35 .ok_or(anyhow::anyhow!("MDC does not contain a prompt formatter"))?
36 {
37 PromptFormatterArtifact::HfTokenizerConfigJson(file) => {
38 let content = std::fs::read_to_string(file)?;
39 let config: ChatTemplate = serde_json::from_str(&content)?;
40 Self::from_parts(
41 config,
42 mdc.prompt_context
43 .map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
44 )
45 }
46 PromptFormatterArtifact::GGUF(gguf_path) => {
47 let config = ChatTemplate::from_gguf(&gguf_path)?;
48 Self::from_parts(config, ContextMixins::default())
49 }
50 }
51 }
52
53 pub fn from_parts(config: ChatTemplate, context: ContextMixins) -> Result<PromptFormatter> {
54 let formatter = HfTokenizerConfigJsonFormatter::new(config, context)?;
55 Ok(Self::OAI(Arc::new(formatter)))
56 }
57}
58
59/// Chat Template Jinja Renderer
60///
61/// Manages a Jinja environment with registered templates for chat formatting.
62/// Handles two types of ChatTemplateValue templates:
63///
64/// 1. String template: Registered as the 'default' template
65/// 2. Map template: Contains 'tool_use' and/or 'default' templates
66/// - tool_use: Template for tool-based interactions
67/// - default: Template for standard chat interactions
68///
69/// If the map contains both keys, the `tool_use` template is registered as the `tool_use` template
70/// and the `default` template is registered as the `default` template.
71struct JinjaEnvironment {
72 env: Environment<'static>,
73}
74
75/// Formatter for HuggingFace tokenizer config JSON templates
76///
77/// Implements chat template rendering based on HuggingFace's tokenizer_config.json format.
78/// Supports:
79/// - Tool usage templates
80/// - Generation prompts
81/// - Context mixins for template customization
82#[derive(Debug)]
83struct HfTokenizerConfigJsonFormatter {
84 env: Environment<'static>,
85 config: ChatTemplate,
86 mixins: Arc<ContextMixins>,
87 supports_add_generation_prompt: bool,
88}
89
90// /// OpenAI Standard Prompt Formatter
91// pub trait StandardPromptFormatter {
92// fn render(&self, context: &impl StandardPromptContext) -> Result<String>;
93// }
94
95// pub trait StandardPromptContext {
96// fn messages(&self) -> Value;
97// fn tools(&self) -> Option<Value>;
98// }
99
100#[derive(Debug, Clone, Default)]
101pub struct ContextMixins {
102 context_mixins: HashSet<PromptContextMixin>,
103}