// dynamo_llm/preprocessor/prompt/template/oai.rs
use super::*;

use minijinja::{context, value::Value};

use crate::protocols::openai::{
    chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest,
};
use tracing;

use crate::preprocessor::prompt::{PromptInput, TextInput, TokenInput};

15fn may_be_fix_tool_schema(tools: serde_json::Value) -> Option<Value> {
16 let mut updated_tools = Vec::new();
20 if let Some(arr) = tools.as_array() {
21 for tool in arr {
22 let mut tool = tool.clone();
23 if let Some(function) = tool.get_mut("function")
24 && let Some(parameters) = function.get_mut("parameters")
25 {
26 if parameters.is_object() {
28 let mut needs_type = false;
29 let mut needs_properties = false;
30 let is_empty = parameters
31 .as_object()
32 .map(|o| o.is_empty())
33 .unwrap_or(false);
34
35 if is_empty {
37 needs_type = true;
38 needs_properties = true;
39 } else {
40 if let Some(obj) = parameters.as_object() {
42 if !obj.contains_key("type") {
43 needs_type = true;
44 }
45 if !obj.contains_key("properties") {
46 needs_properties = true;
47 }
48 }
49 }
50
51 if (needs_type || needs_properties)
52 && let Some(obj) = parameters.as_object_mut()
53 {
54 if needs_type {
55 obj.insert(
56 "type".to_string(),
57 serde_json::Value::String("object".to_string()),
58 );
59 }
60 if needs_properties {
61 obj.insert(
62 "properties".to_string(),
63 serde_json::Value::Object(Default::default()),
64 );
65 }
66 }
67 }
68 }
69 updated_tools.push(tool);
70 }
71 }
72 Some(Value::from_serialize(&updated_tools))
73}
74
75fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
76 let Some(arr) = messages.as_array() else {
81 return Value::from_serialize(&messages);
82 };
83
84 let updated_messages: Vec<_> = arr
85 .iter()
86 .map(|msg| {
87 match msg.get("content") {
88 Some(serde_json::Value::Array(content_array)) => {
89 let is_text_only_array = !content_array.is_empty()
90 && content_array.iter().all(|part| {
91 part.get("type")
92 .and_then(|type_field| type_field.as_str())
93 .map(|type_str| type_str == "text")
94 .unwrap_or(false)
95 });
96
97 if is_text_only_array {
98 let mut modified_msg = msg.clone();
99 if let Some(msg_object) = modified_msg.as_object_mut() {
100 let text_parts: Vec<&str> = content_array
101 .iter()
102 .filter_map(|part| part.get("text")?.as_str())
103 .collect();
104 let concatenated_text = text_parts.join("\n");
105
106 msg_object.insert(
107 "content".to_string(),
108 serde_json::Value::String(concatenated_text),
109 );
110 }
111 modified_msg } else {
113 msg.clone() }
115 }
116 _ => msg.clone(), }
118 })
119 .collect();
120
121 Value::from_serialize(&updated_messages)
122}
123
124impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
125 fn model(&self) -> String {
126 self.inner.model.clone()
127 }
128
129 fn messages(&self) -> Value {
130 let messages_json = serde_json::to_value(&self.inner.messages).unwrap();
131
132 let needs_fixing = if let Some(arr) = messages_json.as_array() {
133 arr.iter()
134 .any(|msg| msg.get("content").and_then(|c| c.as_array()).is_some())
135 } else {
136 false
137 };
138
139 if needs_fixing {
140 may_be_fix_msg_content(messages_json)
141 } else {
142 Value::from_serialize(&messages_json)
143 }
144 }
145
146 fn tools(&self) -> Option<Value> {
147 if self.inner.tools.is_none() {
148 Some(Value::from_serialize(Vec::<serde_json::Value>::new()))
153 } else {
154 Some(may_be_fix_tool_schema(
156 serde_json::to_value(&self.inner.tools).unwrap(),
157 )?)
158 }
159 }
160
161 fn tool_choice(&self) -> Option<Value> {
162 if self.inner.tool_choice.is_none() {
163 None
164 } else {
165 Some(Value::from_serialize(&self.inner.tool_choice))
166 }
167 }
168
169 fn should_add_generation_prompt(&self) -> bool {
170 if let Some(last) = self.inner.messages.last() {
171 matches!(
172 last,
173 dynamo_async_openai::types::ChatCompletionRequestMessage::User(_)
174 )
175 } else {
176 true
177 }
178 }
179
180 fn extract_text(&self) -> Option<TextInput> {
181 Some(TextInput::Single(String::new()))
182 }
183
184 fn chat_template_args(&self) -> Option<&std::collections::HashMap<String, serde_json::Value>> {
185 self.chat_template_args.as_ref()
186 }
187}
188
189impl OAIChatLikeRequest for NvCreateCompletionRequest {
190 fn model(&self) -> String {
191 self.inner.model.clone()
192 }
193 fn messages(&self) -> minijinja::value::Value {
194 let message = dynamo_async_openai::types::ChatCompletionRequestMessage::User(
195 dynamo_async_openai::types::ChatCompletionRequestUserMessage {
196 content: dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
197 crate::protocols::openai::completions::prompt_to_string(&self.inner.prompt),
198 ),
199 name: None,
200 },
201 );
202
203 minijinja::value::Value::from_serialize(vec![message])
204 }
205
206 fn should_add_generation_prompt(&self) -> bool {
207 true
208 }
209
210 fn prompt_input_type(&self) -> PromptInput {
211 match &self.inner.prompt {
212 dynamo_async_openai::types::Prompt::IntegerArray(_) => {
213 PromptInput::Tokens(TokenInput::Single(vec![]))
214 }
215 dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(_) => {
216 PromptInput::Tokens(TokenInput::Batch(vec![]))
217 }
218 dynamo_async_openai::types::Prompt::String(_) => {
219 PromptInput::Text(TextInput::Single(String::new()))
220 }
221 dynamo_async_openai::types::Prompt::StringArray(_) => {
222 PromptInput::Text(TextInput::Batch(vec![]))
223 }
224 }
225 }
226
227 fn extract_tokens(&self) -> Option<TokenInput> {
228 match &self.inner.prompt {
229 dynamo_async_openai::types::Prompt::IntegerArray(tokens) => {
230 Some(TokenInput::Single(tokens.clone()))
231 }
232 dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(arrays) => {
233 Some(TokenInput::Batch(arrays.clone()))
234 }
235 _ => None,
236 }
237 }
238
239 fn extract_text(&self) -> Option<TextInput> {
240 match &self.inner.prompt {
241 dynamo_async_openai::types::Prompt::String(text) => {
242 Some(TextInput::Single(text.to_string()))
243 }
244 dynamo_async_openai::types::Prompt::StringArray(texts) => {
245 Some(TextInput::Batch(texts.to_vec()))
246 }
247 _ => None,
248 }
249 }
250}
251
impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
    fn supports_add_generation_prompt(&self) -> bool {
        self.supports_add_generation_prompt
    }

    /// Renders the chat template for `req`, selecting the `tool_use` template
    /// when the request carries a non-empty tool list and `default` otherwise.
    fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String> {
        let mixins = Value::from_dyn_object(self.mixins.clone());

        let tools = req.tools();
        // minijinja's `Value::len` returns `Option<usize>`; an absent tool
        // list or a non-sequence value counts as "no tools".
        let has_tools = tools.as_ref().and_then(|v| v.len()).is_some_and(|l| l > 0);
        let add_generation_prompt = req.should_add_generation_prompt();

        tracing::trace!(
            "Rendering prompt with tools: {:?}, add_generation_prompt: {}",
            has_tools,
            add_generation_prompt
        );

        // Base render context: explicit keys plus the formatter-level mixin
        // values merged in via `..mixins`.
        let ctx = context! {
            messages => req.messages(),
            tools => tools,
            bos_token => self.config.bos_tok(),
            eos_token => self.config.eos_tok(),
            unk_token => self.config.unk_tok(),
            add_generation_prompt => add_generation_prompt,
            ..mixins
        };

        // Merge per-request chat_template_args with the base context.
        // NOTE(review): which side wins on key collision follows minijinja's
        // `context!` merge order — confirm before relying on overrides.
        let ctx = if let Some(args) = req.chat_template_args() {
            let extra = Value::from_serialize(args);
            context! { ..ctx, ..extra }
        } else {
            ctx
        };

        let tmpl: minijinja::Template<'_, '_> = if has_tools {
            self.env.get_template("tool_use")?
        } else {
            self.env.get_template("default")?
        };
        Ok(tmpl.render(&ctx)?)
    }
}
297
#[cfg(test)]
mod tests {
    use super::*;

    // Empty `parameters` object: both "type" and "properties" are injected.
    #[test]
    fn test_may_be_fix_tool_schema_missing_type_and_properties() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "parameters": {},
                        "strict": null
                    }
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let tools = serde_json::to_value(request.tools()).unwrap();

        assert!(tools[0]["function"]["parameters"]["type"] == "object");
        assert!(
            tools[0]["function"]["parameters"]["properties"]
                == serde_json::Value::Object(Default::default())
        );
    }

    // "properties" present but "type" missing: only "type" is injected and
    // the existing properties are preserved verbatim.
    #[test]
    fn test_may_be_fix_tool_schema_missing_type() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "parameters": {
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "City and state, e.g., 'San Francisco, CA'"
                                }
                            }
                        },
                        "strict": null
                    }
                }
            ]
        }"#;
        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();

        let tools = serde_json::to_value(request.tools()).unwrap();

        assert_eq!(tools[0]["function"]["parameters"]["type"], "object");

        // Rebuild the expected "properties" map by hand to compare exactly.
        let mut expected_properties = serde_json::Map::new();
        let mut location = serde_json::Map::new();
        location.insert(
            "type".to_string(),
            serde_json::Value::String("string".to_string()),
        );
        location.insert(
            "description".to_string(),
            serde_json::Value::String("City and state, e.g., 'San Francisco, CA'".to_string()),
        );
        expected_properties.insert("location".to_string(), serde_json::Value::Object(location));

        assert_eq!(
            tools[0]["function"]["parameters"]["properties"],
            serde_json::Value::Object(expected_properties)
        );
    }

    // "type" present but "properties" missing: an empty "properties" map is
    // injected and the existing "type" is kept.
    #[test]
    fn test_may_be_fix_tool_schema_missing_properties() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "parameters": {"type": "object"},
                        "strict": null
                    }
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let tools = serde_json::to_value(request.tools()).unwrap();

        assert_eq!(
            tools[0]["function"]["parameters"]["properties"],
            serde_json::Value::Object(Default::default())
        );
        assert_eq!(tools[0]["function"]["parameters"]["type"], "object");
    }

    // Text-only multipart content is joined with "\n" into a single string.
    #[test]
    fn test_may_be_fix_msg_content_user_multipart() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "part 1"},
                        {"type": "text", "text": "part 2"}
                    ]
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let messages = serde_json::to_value(request.messages()).unwrap();

        assert_eq!(
            messages[0]["content"],
            serde_json::Value::String("part 1\npart 2".to_string())
        );
    }

    // Mix of plain-string and multipart messages: string messages pass
    // through unchanged, multipart text-only messages get flattened.
    #[test]
    fn test_may_be_fix_msg_content_mixed_messages() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant"
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Hello"},
                        {"type": "text", "text": "World"}
                    ]
                },
                {
                    "role": "assistant",
                    "content": "Hi there!"
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Another"},
                        {"type": "text", "text": "multi-part"},
                        {"type": "text", "text": "message"}
                    ]
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let messages = serde_json::to_value(request.messages()).unwrap();

        assert_eq!(
            messages[0]["content"],
            serde_json::Value::String("You are a helpful assistant".to_string())
        );

        assert_eq!(
            messages[1]["content"],
            serde_json::Value::String("Hello\nWorld".to_string())
        );

        assert_eq!(
            messages[2]["content"],
            serde_json::Value::String("Hi there!".to_string())
        );

        assert_eq!(
            messages[3]["content"],
            serde_json::Value::String("Another\nmulti-part\nmessage".to_string())
        );
    }

    // An empty content array is NOT flattened — it stays an array.
    #[test]
    fn test_may_be_fix_msg_content_empty_array() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": []
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let messages = serde_json::to_value(request.messages()).unwrap();

        assert!(messages[0]["content"].is_array());
        assert_eq!(messages[0]["content"].as_array().unwrap().len(), 0);
    }

    // Plain string content is passed through untouched.
    #[test]
    fn test_may_be_fix_msg_content_single_text() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": "Simple text message"
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let messages = serde_json::to_value(request.messages()).unwrap();

        assert_eq!(
            messages[0]["content"],
            serde_json::Value::String("Simple text message".to_string())
        );
    }

    // Text parts mixed with an image part: array is preserved as-is.
    #[test]
    fn test_may_be_fix_msg_content_mixed_types() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Check this image:"},
                        {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
                        {"type": "text", "text": "What do you see?"}
                    ]
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let messages = serde_json::to_value(request.messages()).unwrap();

        assert!(messages[0]["content"].is_array());
        let content_array = messages[0]["content"].as_array().unwrap();
        assert_eq!(content_array.len(), 3);
        assert_eq!(content_array[0]["type"], "text");
        assert_eq!(content_array[1]["type"], "image_url");
        assert_eq!(content_array[2]["type"], "text");
    }

    // Content with no text parts at all: array is preserved as-is.
    #[test]
    fn test_may_be_fix_msg_content_non_text_only() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": "https://example.com/image1.jpg"}},
                        {"type": "image_url", "image_url": {"url": "https://example.com/image2.jpg"}}
                    ]
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let messages = serde_json::to_value(request.messages()).unwrap();

        assert!(messages[0]["content"].is_array());
        let content_array = messages[0]["content"].as_array().unwrap();
        assert_eq!(content_array.len(), 2);
        assert_eq!(content_array[0]["type"], "image_url");
        assert_eq!(content_array[1]["type"], "image_url");
    }

    // Audio/video parts alongside text: both requests keep their arrays
    // intact (no flattening when any non-text part is present).
    #[test]
    fn test_may_be_fix_msg_content_multiple_content_types() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Listen to this:"},
                        {"type": "audio_url", "audio_url": {"url": "https://example.com/audio.mp3"}},
                        {"type": "text", "text": "And look at:"},
                        {"type": "image_url", "image_url": {"url": "https://example.com/img.jpg"}},
                        {"type": "text", "text": "What do you think?"}
                    ]
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let messages = serde_json::to_value(request.messages()).unwrap();

        assert!(messages[0]["content"].is_array());
        assert_eq!(messages[0]["content"].as_array().unwrap().len(), 5);

        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Check this:"},
                        {"type": "video_url", "video_url": {"url": "https://example.com/vid.mp4"}},
                        {"type": "text", "text": "Interesting?"}
                    ]
                }
            ]
        }"#;

        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        let messages = serde_json::to_value(request.messages()).unwrap();

        assert!(messages[0]["content"].is_array());
        assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
    }
}