dynamo_llm/preprocessor/prompt/template/oai.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use super::*;

use minijinja::{context, value::Value};

use crate::protocols::openai::{
    chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest,
};
use tracing;

use crate::preprocessor::prompt::{PromptInput, TextInput, TokenInput};

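/// Normalizes OpenAI tool definitions before template rendering: when a tool's
/// `parameters` object is empty or missing `type`/`properties`, inserts
/// `"type": "object"` and/or `"properties": {}` so chat templates that expect a
/// full JSON schema render correctly.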
fn may_be_fix_tool_schema(tools: serde_json::Value) -> Option<Value> {
    // No need to validate or enforce other schema checks here: the basic named-function schema is already validated when the request is created.
    // OpenAI allows empty `parameters` at the request level, but chat templates expect a full schema, so enforce it here.
    // Whenever `parameters` is empty, insert "type": "object" and "properties": {}.
    let mut updated_tools = Vec::new();
    if let Some(arr) = tools.as_array() {
        for tool in arr {
            let mut tool = tool.clone();
            if let Some(function) = tool.get_mut("function")
                && let Some(parameters) = function.get_mut("parameters")
            {
                // Only operate if parameters is an object
                if parameters.is_object() {
                    let mut needs_type = false;
                    let mut needs_properties = false;
                    let is_empty = parameters
                        .as_object()
                        .map(|o| o.is_empty())
                        .unwrap_or(false);

                    // If empty, we need to insert both
                    if is_empty {
                        needs_type = true;
                        needs_properties = true;
                    } else {
                        // If not empty, check whether type/properties are missing
                        if let Some(obj) = parameters.as_object() {
                            if !obj.contains_key("type") {
                                needs_type = true;
                            }
                            if !obj.contains_key("properties") {
                                needs_properties = true;
                            }
                        }
                    }

                    if (needs_type || needs_properties)
                        && let Some(obj) = parameters.as_object_mut()
                    {
                        if needs_type {
                            obj.insert(
                                "type".to_string(),
                                serde_json::Value::String("object".to_string()),
                            );
                        }
                        if needs_properties {
                            obj.insert(
                                "properties".to_string(),
                                serde_json::Value::Object(Default::default()),
                            );
                        }
                    }
                }
            }
            updated_tools.push(tool);
        }
    }
    Some(Value::from_serialize(&updated_tools))
}

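/// Flattens message `content` arrays that contain only text parts into a single
/// newline-joined string; anything else (string content, mixed parts, non-text
/// parts) is passed through unchanged.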
fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
    // If a message's `content` is provided as a list containing ONLY text parts,
    // concatenate them into a string to match chat template expectations.
    // Mixed content types are left for chat templates to handle.

    let Some(arr) = messages.as_array() else {
        return Value::from_serialize(&messages);
    };

    let updated_messages: Vec<_> = arr
        .iter()
        .map(|msg| {
            match msg.get("content") {
                Some(serde_json::Value::Array(content_array)) => {
                    let is_text_only_array = !content_array.is_empty()
                        && content_array.iter().all(|part| {
                            part.get("type")
                                .and_then(|type_field| type_field.as_str())
                                .map(|type_str| type_str == "text")
                                .unwrap_or(false)
                        });

                    if is_text_only_array {
                        let mut modified_msg = msg.clone();
                        if let Some(msg_object) = modified_msg.as_object_mut() {
                            let text_parts: Vec<&str> = content_array
                                .iter()
                                .filter_map(|part| part.get("text")?.as_str())
                                .collect();
                            let concatenated_text = text_parts.join("\n");

                            msg_object.insert(
                                "content".to_string(),
                                serde_json::Value::String(concatenated_text),
                            );
                        }
                        modified_msg // Concatenated string content
                    } else {
                        msg.clone() // Mixed content or non-text-only
                    }
                }
                _ => msg.clone(), // String content or missing content - return unchanged
            }
        })
        .collect();

    Value::from_serialize(&updated_messages)
}

impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
    fn model(&self) -> String {
        self.inner.model.clone()
    }

    fn messages(&self) -> Value {
        let messages_json = serde_json::to_value(&self.inner.messages).unwrap();

        // Only rewrite when at least one message carries array-form content.
        let needs_fixing = if let Some(arr) = messages_json.as_array() {
            arr.iter()
                .any(|msg| msg.get("content").and_then(|c| c.as_array()).is_some())
        } else {
            false
        };

        if needs_fixing {
            may_be_fix_msg_content(messages_json)
        } else {
            Value::from_serialize(&messages_json)
        }
    }

    fn tools(&self) -> Option<Value> {
        if self.inner.tools.is_none() {
            // ISSUE: {%- if tools is iterable and tools | length > 0 %}
            // In expressions like the one above, minijinja evaluates both sides of the
            // `and` rather than short-circuiting, so `tools | length` is computed even
            // when no tools were supplied. Returning an empty array here is safe: the
            // template still sees length = 0.
            Some(Value::from_serialize(Vec::<serde_json::Value>::new()))
        } else {
            // Try to fix the tool schema if it is missing type and properties
            Some(may_be_fix_tool_schema(
                serde_json::to_value(&self.inner.tools).unwrap(),
            )?)
        }
    }

    fn tool_choice(&self) -> Option<Value> {
        if self.inner.tool_choice.is_none() {
            None
        } else {
            Some(Value::from_serialize(&self.inner.tool_choice))
        }
    }

    fn should_add_generation_prompt(&self) -> bool {
        if let Some(last) = self.inner.messages.last() {
            matches!(
                last,
                dynamo_async_openai::types::ChatCompletionRequestMessage::User(_)
            )
        } else {
            true
        }
    }

    fn extract_text(&self) -> Option<TextInput> {
        Some(TextInput::Single(String::new()))
    }

    fn chat_template_args(&self) -> Option<&std::collections::HashMap<String, serde_json::Value>> {
        self.chat_template_args.as_ref()
    }
}

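// Completions requests reuse the chat-like interface: the raw prompt is wrapped in a
// single user message so the same chat-template rendering path can handle it.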
impl OAIChatLikeRequest for NvCreateCompletionRequest {
    fn model(&self) -> String {
        self.inner.model.clone()
    }

    fn messages(&self) -> minijinja::value::Value {
        let message = dynamo_async_openai::types::ChatCompletionRequestMessage::User(
            dynamo_async_openai::types::ChatCompletionRequestUserMessage {
                content: dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
                    crate::protocols::openai::completions::prompt_to_string(&self.inner.prompt),
                ),
                name: None,
            },
        );

        minijinja::value::Value::from_serialize(vec![message])
    }

    fn should_add_generation_prompt(&self) -> bool {
        true
    }

    fn prompt_input_type(&self) -> PromptInput {
        // Only the variant matters here; the actual payload is retrieved via
        // `extract_tokens` / `extract_text` below.
        match &self.inner.prompt {
            dynamo_async_openai::types::Prompt::IntegerArray(_) => {
                PromptInput::Tokens(TokenInput::Single(vec![]))
            }
            dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(_) => {
                PromptInput::Tokens(TokenInput::Batch(vec![]))
            }
            dynamo_async_openai::types::Prompt::String(_) => {
                PromptInput::Text(TextInput::Single(String::new()))
            }
            dynamo_async_openai::types::Prompt::StringArray(_) => {
                PromptInput::Text(TextInput::Batch(vec![]))
            }
        }
    }

    fn extract_tokens(&self) -> Option<TokenInput> {
        match &self.inner.prompt {
            dynamo_async_openai::types::Prompt::IntegerArray(tokens) => {
                Some(TokenInput::Single(tokens.clone()))
            }
            dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(arrays) => {
                Some(TokenInput::Batch(arrays.clone()))
            }
            _ => None,
        }
    }

    fn extract_text(&self) -> Option<TextInput> {
        match &self.inner.prompt {
            dynamo_async_openai::types::Prompt::String(text) => {
                Some(TextInput::Single(text.to_string()))
            }
            dynamo_async_openai::types::Prompt::StringArray(texts) => {
                Some(TextInput::Batch(texts.to_vec()))
            }
            _ => None,
        }
    }
}

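// Prompt formatting backed by the chat template carried in the model's HuggingFace
// `tokenizer_config.json`.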
impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
    fn supports_add_generation_prompt(&self) -> bool {
        self.supports_add_generation_prompt
    }

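    /// Renders the final prompt string with minijinja, selecting the `tool_use`
    /// template when the request carries a non-empty tool list and the `default`
    /// template otherwise.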
    fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String> {
        let mixins = Value::from_dyn_object(self.mixins.clone());

        let tools = req.tools();
        // has_tools should be true if tools is a non-empty array
        let has_tools = tools.as_ref().and_then(|v| v.len()).is_some_and(|l| l > 0);
        let add_generation_prompt = req.should_add_generation_prompt();

        tracing::trace!(
            "Rendering prompt with tools: {:?}, add_generation_prompt: {}",
            has_tools,
            add_generation_prompt
        );

        let ctx = context! {
            messages => req.messages(),
            tools => tools,
            bos_token => self.config.bos_tok(),
            eos_token => self.config.eos_tok(),
            unk_token => self.config.unk_tok(),
            add_generation_prompt => add_generation_prompt,
            ..mixins
        };

        // Merge any additional args into the context last so they take precedence
        let ctx = if let Some(args) = req.chat_template_args() {
            let extra = Value::from_serialize(args);
            context! { ..ctx, ..extra }
        } else {
            ctx
        };

        let tmpl: minijinja::Template<'_, '_> = if has_tools {
            self.env.get_template("tool_use")?
        } else {
            self.env.get_template("default")?
        };
        Ok(tmpl.render(&ctx)?)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

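    /// Tests that empty `parameters` gets both `"type": "object"` and
    /// `"properties": {}` inserted.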
    #[test]
    fn test_may_be_fix_tool_schema_missing_type_and_properties() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "parameters": {},
                        "strict": null
                    }
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let tools = serde_json::to_value(request.tools()).unwrap();

        assert_eq!(tools[0]["function"]["parameters"]["type"], "object");
        assert_eq!(
            tools[0]["function"]["parameters"]["properties"],
            serde_json::Value::Object(Default::default())
        );
    }

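    /// Tests that a missing `type` is inserted while existing `properties` are
    /// preserved unchanged.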
    #[test]
    fn test_may_be_fix_tool_schema_missing_type() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "parameters": {
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "City and state, e.g., 'San Francisco, CA'"
                                }
                            }
                        },
                        "strict": null
                    }
                }
            ]
        }"#;
        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();

        let tools = serde_json::to_value(request.tools()).unwrap();

        assert_eq!(tools[0]["function"]["parameters"]["type"], "object");

        let mut expected_properties = serde_json::Map::new();
        let mut location = serde_json::Map::new();
        location.insert(
            "type".to_string(),
            serde_json::Value::String("string".to_string()),
        );
        location.insert(
            "description".to_string(),
            serde_json::Value::String("City and state, e.g., 'San Francisco, CA'".to_string()),
        );
        expected_properties.insert("location".to_string(), serde_json::Value::Object(location));

        assert_eq!(
            tools[0]["function"]["parameters"]["properties"],
            serde_json::Value::Object(expected_properties)
        );
    }

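    /// Tests that a missing `properties` field is inserted when only `type` is present.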
    #[test]
    fn test_may_be_fix_tool_schema_missing_properties() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "parameters": {"type": "object"},
                        "strict": null
                    }
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let tools = serde_json::to_value(request.tools()).unwrap();

        assert_eq!(
            tools[0]["function"]["parameters"]["properties"],
            serde_json::Value::Object(Default::default())
        );
        assert_eq!(tools[0]["function"]["parameters"]["type"], "object");
    }

    /// Tests that content arrays (containing only text parts) are correctly concatenated.
    #[test]
    fn test_may_be_fix_msg_content_user_multipart() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "part 1"},
                        {"type": "text", "text": "part 2"}
                    ]
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let messages = serde_json::to_value(request.messages()).unwrap();

        // Verify: text-only array is concatenated into a single string
        assert_eq!(
            messages[0]["content"],
            serde_json::Value::String("part 1\npart 2".to_string())
        );
    }

    /// Tests that the function correctly handles a conversation
    /// with multiple roles and mixed message types.
    #[test]
    fn test_may_be_fix_msg_content_mixed_messages() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant"
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Hello"},
                        {"type": "text", "text": "World"}
                    ]
                },
                {
                    "role": "assistant",
                    "content": "Hi there!"
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Another"},
                        {"type": "text", "text": "multi-part"},
                        {"type": "text", "text": "message"}
                    ]
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let messages = serde_json::to_value(request.messages()).unwrap();

        // Verify: System message with string content remains unchanged
        assert_eq!(
            messages[0]["content"],
            serde_json::Value::String("You are a helpful assistant".to_string())
        );

        // Verify: User message with text-only array is concatenated
        assert_eq!(
            messages[1]["content"],
            serde_json::Value::String("Hello\nWorld".to_string())
        );

        // Verify: Assistant message with string content remains unchanged
        assert_eq!(
            messages[2]["content"],
            serde_json::Value::String("Hi there!".to_string())
        );

        // Verify: Second user message with text-only array is concatenated
        assert_eq!(
            messages[3]["content"],
            serde_json::Value::String("Another\nmulti-part\nmessage".to_string())
        );
    }

    /// Tests that empty content arrays remain unchanged.
    #[test]
    fn test_may_be_fix_msg_content_empty_array() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": []
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let messages = serde_json::to_value(request.messages()).unwrap();

        // Verify: Empty arrays are preserved as-is
        assert!(messages[0]["content"].is_array());
        assert_eq!(messages[0]["content"].as_array().unwrap().len(), 0);
    }

    /// Tests that messages with simple string content remain unchanged.
    #[test]
    fn test_may_be_fix_msg_content_single_text() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": "Simple text message"
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let messages = serde_json::to_value(request.messages()).unwrap();

        // Verify: String content is not modified
        assert_eq!(
            messages[0]["content"],
            serde_json::Value::String("Simple text message".to_string())
        );
    }

    /// Tests that content arrays with mixed types (text + non-text) remain as arrays.
    #[test]
    fn test_may_be_fix_msg_content_mixed_types() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Check this image:"},
                        {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
                        {"type": "text", "text": "What do you see?"}
                    ]
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let messages = serde_json::to_value(request.messages()).unwrap();

        // Verify: Mixed content types are preserved as array for template handling
        assert!(messages[0]["content"].is_array());
        let content_array = messages[0]["content"].as_array().unwrap();
        assert_eq!(content_array.len(), 3);
        assert_eq!(content_array[0]["type"], "text");
        assert_eq!(content_array[1]["type"], "image_url");
        assert_eq!(content_array[2]["type"], "text");
    }

    /// Tests that content arrays containing only non-text types remain as arrays.
    #[test]
    fn test_may_be_fix_msg_content_non_text_only() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": "https://example.com/image1.jpg"}},
                        {"type": "image_url", "image_url": {"url": "https://example.com/image2.jpg"}}
                    ]
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let messages = serde_json::to_value(request.messages()).unwrap();

        // Verify: Non-text content arrays are preserved for template handling
        assert!(messages[0]["content"].is_array());
        let content_array = messages[0]["content"].as_array().unwrap();
        assert_eq!(content_array.len(), 2);
        assert_eq!(content_array[0]["type"], "image_url");
        assert_eq!(content_array[1]["type"], "image_url");
    }

    /// Tests mixed content type scenarios.
    #[test]
    fn test_may_be_fix_msg_content_multiple_content_types() {
        // Scenario 1: Multiple different content types (text + image + audio)
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Listen to this:"},
                        {"type": "audio_url", "audio_url": {"url": "https://example.com/audio.mp3"}},
                        {"type": "text", "text": "And look at:"},
                        {"type": "image_url", "image_url": {"url": "https://example.com/img.jpg"}},
                        {"type": "text", "text": "What do you think?"}
                    ]
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let messages = serde_json::to_value(request.messages()).unwrap();

        // Mixed types should preserve array structure
        assert!(messages[0]["content"].is_array());
        assert_eq!(messages[0]["content"].as_array().unwrap().len(), 5);

        // Scenario 2: Unknown/future content types mixed with text
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Check this:"},
                        {"type": "video_url", "video_url": {"url": "https://example.com/vid.mp4"}},
                        {"type": "text", "text": "Interesting?"}
                    ]
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let messages = serde_json::to_value(request.messages()).unwrap();

        // Unknown types mixed with text should preserve array
        assert!(messages[0]["content"].is_array());
        assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
    }
}