
batuta/serve/templates.rs

//! Chat Template Engine
//!
//! Unified prompt templating for different model formats.
//! Implements the Toyota Way "Standardized Work" principle.
//!
//! ## Supported Formats
//!
//! - Llama 2: `[INST] {prompt} [/INST]`
//! - Mistral: `[INST] {prompt} [/INST]`
//! - ChatML: `<|im_start|>user\n{prompt}<|im_end|>`
//! - Alpaca: `### Instruction:\n{prompt}\n### Response:`
//! - Vicuna: `USER: {prompt}\nASSISTANT:`
//! - Raw: No formatting (pass-through)
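//!
//! ## Example
//!
//! A minimal usage sketch; the `batuta::serve::templates` path is assumed
//! from this file's location:
//!
//! ```
//! use batuta::serve::templates::{ChatMessage, ChatTemplateEngine, TemplateFormat};
//!
//! let engine = ChatTemplateEngine::new(TemplateFormat::ChatML);
//! let prompt = engine.apply(&[
//!     ChatMessage::system("Be concise."),
//!     ChatMessage::user("Hi!"),
//! ]);
//! assert!(prompt.ends_with("<|im_start|>assistant\n"));
//! ```
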
use serde::{Deserialize, Serialize};
use std::fmt;

// ============================================================================
// SERVE-TPL-001: Chat Template Types
// ============================================================================

/// Chat message role
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}

impl fmt::Display for Role {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::System => write!(f, "system"),
            Self::User => write!(f, "user"),
            Self::Assistant => write!(f, "assistant"),
        }
    }
}

/// A chat message with role and content
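///
/// A minimal construction sketch; the crate path is assumed from this
/// file's location:
///
/// ```
/// use batuta::serve::templates::{ChatMessage, Role};
///
/// let msg = ChatMessage::user("Hi!");
/// assert_eq!(msg.role, Role::User);
/// assert_eq!(msg.content, "Hi!");
/// ```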
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChatMessage {
    pub role: Role,
    pub content: String,
}

impl ChatMessage {
    fn with_role(role: Role, content: impl Into<String>) -> Self {
        Self { role, content: content.into() }
    }

    pub fn system(content: impl Into<String>) -> Self {
        Self::with_role(Role::System, content)
    }

    pub fn user(content: impl Into<String>) -> Self {
        Self::with_role(Role::User, content)
    }

    pub fn assistant(content: impl Into<String>) -> Self {
        Self::with_role(Role::Assistant, content)
    }
}

/// Chat template format
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum TemplateFormat {
    /// Llama 2 format: `[INST] {prompt} [/INST]`
    Llama2,
    /// Mistral format: `[INST] {prompt} [/INST]`
    Mistral,
    /// ChatML format: `<|im_start|>role\n{content}<|im_end|>`
    ChatML,
    /// Alpaca format: `### Instruction:\n{prompt}\n### Response:`
    Alpaca,
    /// Vicuna format: `USER: {prompt}\nASSISTANT:`
    Vicuna,
    /// Raw pass-through (no formatting)
    #[default]
    Raw,
}

impl TemplateFormat {
    /// Detect template format from model name
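    ///
    /// Matching is case-insensitive and substring-based; unrecognized names
    /// fall back to [`TemplateFormat::Raw`]. A doc-test sketch (crate path
    /// assumed):
    ///
    /// ```
    /// use batuta::serve::templates::TemplateFormat;
    ///
    /// assert_eq!(TemplateFormat::from_model_name("mixtral-8x7b"), TemplateFormat::Mistral);
    /// assert_eq!(TemplateFormat::from_model_name("unknown-model"), TemplateFormat::Raw);
    /// ```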
    #[must_use]
    pub fn from_model_name(name: &str) -> Self {
        let lower = name.to_lowercase();
        if lower.contains("llama-2") || lower.contains("llama2") {
            Self::Llama2
        } else if lower.contains("mistral") || lower.contains("mixtral") {
            Self::Mistral
        } else if lower.contains("chatml") || lower.contains("openhermes") {
            Self::ChatML
        } else if lower.contains("alpaca") {
            Self::Alpaca
        } else if lower.contains("vicuna") {
            Self::Vicuna
        } else {
            Self::Raw
        }
    }
}

// ============================================================================
// SERVE-TPL-002: Chat Template Engine
// ============================================================================

/// Chat template engine for formatting prompts
#[derive(Debug, Clone)]
pub struct ChatTemplateEngine {
    format: TemplateFormat,
    bos_token: Option<String>,
    eos_token: Option<String>,
}

impl ChatTemplateEngine {
    /// Create a new template engine with the specified format
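    ///
    /// Llama 2 and Mistral get `<s>`/`</s>` BOS/EOS tokens by default; other
    /// formats use none. A short sketch (crate path assumed):
    ///
    /// ```
    /// use batuta::serve::templates::{ChatTemplateEngine, TemplateFormat};
    ///
    /// let engine = ChatTemplateEngine::new(TemplateFormat::Llama2);
    /// assert!(engine.apply_prompt("Hi!").starts_with("<s>"));
    /// ```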
    #[must_use]
    pub fn new(format: TemplateFormat) -> Self {
        let (bos_token, eos_token) = match format {
            TemplateFormat::Llama2 | TemplateFormat::Mistral => {
                (Some("<s>".to_string()), Some("</s>".to_string()))
            }
            _ => (None, None),
        };
        Self { format, bos_token, eos_token }
    }

    /// Create engine from model name (auto-detect format)
    #[must_use]
    pub fn from_model(model_name: &str) -> Self {
        Self::new(TemplateFormat::from_model_name(model_name))
    }

    /// Get the template format
    #[must_use]
    pub fn format(&self) -> TemplateFormat {
        self.format
    }

    /// Apply template to a list of chat messages
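    ///
    /// Renders the full conversation in one pass; formats that expect a
    /// generation cue (e.g. ChatML's trailing `<|im_start|>assistant\n`)
    /// append it here. A hedged sketch (crate path assumed):
    ///
    /// ```
    /// use batuta::serve::templates::{ChatMessage, ChatTemplateEngine, TemplateFormat};
    ///
    /// let engine = ChatTemplateEngine::new(TemplateFormat::Vicuna);
    /// let out = engine.apply(&[ChatMessage::user("Hi!")]);
    /// assert_eq!(out, "USER: Hi!\nASSISTANT:");
    /// ```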
    #[must_use]
    pub fn apply(&self, messages: &[ChatMessage]) -> String {
        match self.format {
            TemplateFormat::Llama2 => self.apply_llama2(messages),
            TemplateFormat::Mistral => self.apply_mistral(messages),
            TemplateFormat::ChatML => self.apply_chatml(messages),
            TemplateFormat::Alpaca => self.apply_alpaca(messages),
            TemplateFormat::Vicuna => self.apply_vicuna(messages),
            TemplateFormat::Raw => self.apply_raw(messages),
        }
    }

    /// Apply template to a simple prompt string
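    ///
    /// Equivalent to `apply(&[ChatMessage::user(prompt)])`:
    ///
    /// ```
    /// use batuta::serve::templates::{ChatTemplateEngine, TemplateFormat};
    ///
    /// let engine = ChatTemplateEngine::new(TemplateFormat::Alpaca);
    /// let out = engine.apply_prompt("What is 2+2?");
    /// assert!(out.starts_with("### Instruction:"));
    /// ```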
    #[must_use]
    pub fn apply_prompt(&self, prompt: &str) -> String {
        self.apply(&[ChatMessage::user(prompt)])
    }

    /// Push BOS token to result if configured
    fn push_bos(&self, result: &mut String) {
        if let Some(ref bos) = self.bos_token {
            result.push_str(bos);
        }
    }

    /// Push EOS token to result if configured
    fn push_eos(&self, result: &mut String) {
        if let Some(ref eos) = self.eos_token {
            result.push_str(eos);
        }
    }

    // Llama 2 format
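    //
    // Expected shape for a system + user turn (derived from the logic below):
    //   <s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST]
    // Assistant turns append as ` {content}</s>`.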
    fn apply_llama2(&self, messages: &[ChatMessage]) -> String {
        let mut result = String::new();
        self.push_bos(&mut result);

        let mut system_prompt = None;
        for msg in messages {
            match msg.role {
                Role::System => {
                    system_prompt = Some(&msg.content);
                }
                Role::User => {
                    result.push_str("[INST] ");
                    if let Some(sys) = system_prompt.take() {
                        result.push_str("<<SYS>>\n");
                        result.push_str(sys);
                        result.push_str("\n<</SYS>>\n\n");
                    }
                    result.push_str(&msg.content);
                    result.push_str(" [/INST]");
                }
                Role::Assistant => {
                    result.push(' ');
                    result.push_str(&msg.content);
                    self.push_eos(&mut result);
                }
            }
        }

        result
    }

    // Mistral format (similar to Llama 2 but without <<SYS>> tags)
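    //
    // Expected shape for a system + user turn (derived from the logic below):
    //   <s>[INST] {system}\n\n{user} [/INST]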
    fn apply_mistral(&self, messages: &[ChatMessage]) -> String {
        let mut result = String::new();
        self.push_bos(&mut result);

        // Tracks whether an [INST] block is currently open, so system content
        // is prepended to the first user turn while every later user turn
        // still opens its own block. (Checking `result.contains("[INST]")`
        // here would suppress the opening tag on all turns after the first.)
        let mut in_inst = false;
        for msg in messages {
            match msg.role {
                Role::System => {
                    // Mistral prepends system content to the first user message
                    result.push_str("[INST] ");
                    result.push_str(&msg.content);
                    result.push_str("\n\n");
                    in_inst = true;
                }
                Role::User => {
                    if !in_inst {
                        result.push_str("[INST] ");
                    }
                    result.push_str(&msg.content);
                    result.push_str(" [/INST]");
                    in_inst = false;
                }
                Role::Assistant => {
                    result.push_str(&msg.content);
                    self.push_eos(&mut result);
                }
            }
        }

        result
    }

    // ChatML format
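    //
    // Each turn renders as <|im_start|>{role}\n{content}<|im_end|>\n, and a
    // trailing <|im_start|>assistant\n generation cue is always appended.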
    fn apply_chatml(&self, messages: &[ChatMessage]) -> String {
        let mut result = String::new();

        for msg in messages {
            result.push_str("<|im_start|>");
            result.push_str(&msg.role.to_string());
            result.push('\n');
            result.push_str(&msg.content);
            result.push_str("<|im_end|>\n");
        }

        // Add assistant prompt
        result.push_str("<|im_start|>assistant\n");

        result
    }

    // Alpaca format
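    //
    // System content is emitted verbatim as a preamble; each user turn renders
    // as ### Instruction:\n{content}\n\n### Response:\n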
    fn apply_alpaca(&self, messages: &[ChatMessage]) -> String {
        let mut result = String::new();

        for msg in messages {
            match msg.role {
                Role::System => {
                    result.push_str(&msg.content);
                    result.push_str("\n\n");
                }
                Role::User => {
                    result.push_str("### Instruction:\n");
                    result.push_str(&msg.content);
                    result.push_str("\n\n### Response:\n");
                }
                Role::Assistant => {
                    result.push_str(&msg.content);
                    result.push('\n');
                }
            }
        }

        result
    }

    // Vicuna format
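    //
    // System content is emitted verbatim; user turns render as
    // USER: {content}\nASSISTANT: and assistant replies as ` {content}\n`.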
    fn apply_vicuna(&self, messages: &[ChatMessage]) -> String {
        let mut result = String::new();

        for msg in messages {
            match msg.role {
                Role::System => {
                    result.push_str(&msg.content);
                    result.push_str("\n\n");
                }
                Role::User => {
                    result.push_str("USER: ");
                    result.push_str(&msg.content);
                    result.push_str("\nASSISTANT:");
                }
                Role::Assistant => {
                    result.push(' ');
                    result.push_str(&msg.content);
                    result.push('\n');
                }
            }
        }

        result
    }

    // Raw format (no transformation)
    fn apply_raw(&self, messages: &[ChatMessage]) -> String {
        messages.iter().map(|m| m.content.as_str()).collect::<Vec<_>>().join("\n")
    }
}

impl Default for ChatTemplateEngine {
    fn default() -> Self {
        Self::new(TemplateFormat::Raw)
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
#[allow(non_snake_case)]
mod tests {
    use super::*;

    // ========================================================================
    // Test Helpers
    // ========================================================================

    /// Assert that a model name maps to the expected template format.
    fn assert_format_detected(model_name: &str, expected: TemplateFormat) {
        assert_eq!(
            TemplateFormat::from_model_name(model_name),
            expected,
            "model name {model_name:?} should map to {expected:?}"
        );
    }

    /// Assert that a `ChatMessage` constructor produced the expected role and content.
    fn assert_message(msg: &ChatMessage, expected_role: Role, expected_content: &str) {
        assert_eq!(msg.role, expected_role);
        assert_eq!(msg.content, expected_content);
    }

    /// Convenience: render a single prompt through the given format.
    fn render_prompt(format: TemplateFormat, prompt: &str) -> String {
        ChatTemplateEngine::new(format).apply_prompt(prompt)
    }

    /// Build a standard multi-turn conversation for reuse across format tests.
    fn multiturn_messages() -> Vec<ChatMessage> {
        vec![
            ChatMessage::user("Hi!"),
            ChatMessage::assistant("Hello!"),
            ChatMessage::user("How are you?"),
        ]
    }

    // ========================================================================
    // SERVE-TPL-001: Role and ChatMessage Tests
    // ========================================================================

    #[test]
    fn test_SERVE_TPL_001_role_display() {
        assert_eq!(format!("{}", Role::System), "system");
        assert_eq!(format!("{}", Role::User), "user");
        assert_eq!(format!("{}", Role::Assistant), "assistant");
    }

    #[test]
    fn test_SERVE_TPL_001_chat_message_system() {
        let msg = ChatMessage::system("You are a helpful assistant.");
        assert_message(&msg, Role::System, "You are a helpful assistant.");
    }

    #[test]
    fn test_SERVE_TPL_001_chat_message_user() {
        let msg = ChatMessage::user("Hello!");
        assert_message(&msg, Role::User, "Hello!");
    }

    #[test]
    fn test_SERVE_TPL_001_chat_message_assistant() {
        let msg = ChatMessage::assistant("Hi there!");
        assert_message(&msg, Role::Assistant, "Hi there!");
    }

    // ========================================================================
    // SERVE-TPL-002: Template Format Detection Tests
    // ========================================================================

    #[test]
    fn test_SERVE_TPL_002_detect_llama2() {
        assert_format_detected("meta-llama/Llama-2-7b", TemplateFormat::Llama2);
        assert_format_detected("llama2-13b", TemplateFormat::Llama2);
    }

    #[test]
    fn test_SERVE_TPL_002_detect_mistral() {
        assert_format_detected("mistralai/Mistral-7B", TemplateFormat::Mistral);
        assert_format_detected("mixtral-8x7b", TemplateFormat::Mistral);
    }

    #[test]
    fn test_SERVE_TPL_002_detect_chatml() {
        assert_format_detected("OpenHermes-2.5", TemplateFormat::ChatML);
        assert_format_detected("chatml-model", TemplateFormat::ChatML);
    }

    #[test]
    fn test_SERVE_TPL_002_detect_alpaca() {
        assert_format_detected("alpaca-7b", TemplateFormat::Alpaca);
    }

    #[test]
    fn test_SERVE_TPL_002_detect_vicuna() {
        assert_format_detected("vicuna-13b", TemplateFormat::Vicuna);
    }

    #[test]
    fn test_SERVE_TPL_002_detect_raw_fallback() {
        assert_format_detected("unknown-model", TemplateFormat::Raw);
    }

    // ========================================================================
    // SERVE-TPL-003: Llama 2 Template Tests
    // ========================================================================

    #[test]
    fn test_SERVE_TPL_003_llama2_simple() {
        let result = render_prompt(TemplateFormat::Llama2, "Hello!");
        assert!(result.contains("[INST]"));
        assert!(result.contains("[/INST]"));
        assert!(result.contains("Hello!"));
    }

    #[test]
    fn test_SERVE_TPL_003_llama2_with_system() {
        let engine = ChatTemplateEngine::new(TemplateFormat::Llama2);
        let messages = vec![ChatMessage::system("You are helpful."), ChatMessage::user("Hi!")];
        let result = engine.apply(&messages);
        assert!(result.contains("<<SYS>>"));
        assert!(result.contains("You are helpful."));
        assert!(result.contains("<</SYS>>"));
        assert!(result.contains("Hi!"));
    }

    #[test]
    fn test_SERVE_TPL_003_llama2_bos_token() {
        let result = render_prompt(TemplateFormat::Llama2, "Test");
        assert!(result.starts_with("<s>"));
    }

    // ========================================================================
    // SERVE-TPL-004: Mistral Template Tests
    // ========================================================================

    #[test]
    fn test_SERVE_TPL_004_mistral_simple() {
        let result = render_prompt(TemplateFormat::Mistral, "Hello!");
        assert!(result.contains("[INST]"));
        assert!(result.contains("[/INST]"));
    }

    #[test]
    fn test_SERVE_TPL_004_mistral_no_sys_tags() {
        let engine = ChatTemplateEngine::new(TemplateFormat::Mistral);
        let messages = vec![ChatMessage::system("Be helpful."), ChatMessage::user("Hi!")];
        let result = engine.apply(&messages);
        // Mistral doesn't use <<SYS>> tags
        assert!(!result.contains("<<SYS>>"));
    }

    // ========================================================================
    // SERVE-TPL-005: ChatML Template Tests
    // ========================================================================

    #[test]
    fn test_SERVE_TPL_005_chatml_simple() {
        let result = render_prompt(TemplateFormat::ChatML, "Hello!");
        assert!(result.contains("<|im_start|>user"));
        assert!(result.contains("<|im_end|>"));
        assert!(result.contains("<|im_start|>assistant"));
    }

    #[test]
    fn test_SERVE_TPL_005_chatml_with_system() {
        let engine = ChatTemplateEngine::new(TemplateFormat::ChatML);
        let messages = vec![ChatMessage::system("You are an AI."), ChatMessage::user("Hi!")];
        let result = engine.apply(&messages);
        assert!(result.contains("<|im_start|>system"));
        assert!(result.contains("You are an AI."));
    }

    // ========================================================================
    // SERVE-TPL-006: Alpaca Template Tests
    // ========================================================================

    #[test]
    fn test_SERVE_TPL_006_alpaca_simple() {
        let result = render_prompt(TemplateFormat::Alpaca, "What is 2+2?");
        assert!(result.contains("### Instruction:"));
        assert!(result.contains("### Response:"));
        assert!(result.contains("What is 2+2?"));
    }

    // ========================================================================
    // SERVE-TPL-007: Vicuna Template Tests
    // ========================================================================

    #[test]
    fn test_SERVE_TPL_007_vicuna_simple() {
        let result = render_prompt(TemplateFormat::Vicuna, "Hello!");
        assert!(result.contains("USER:"));
        assert!(result.contains("ASSISTANT:"));
    }

    // ========================================================================
    // SERVE-TPL-008: Raw Template Tests
    // ========================================================================

    #[test]
    fn test_SERVE_TPL_008_raw_passthrough() {
        let result = render_prompt(TemplateFormat::Raw, "Hello!");
        assert_eq!(result, "Hello!");
    }

    #[test]
    fn test_SERVE_TPL_008_raw_multiple_messages() {
        let engine = ChatTemplateEngine::new(TemplateFormat::Raw);
        let messages = vec![ChatMessage::user("A"), ChatMessage::user("B")];
        let result = engine.apply(&messages);
        assert_eq!(result, "A\nB");
    }

    // ========================================================================
    // SERVE-TPL-009: Engine Factory Tests
    // ========================================================================

    #[test]
    fn test_SERVE_TPL_009_from_model() {
        let engine = ChatTemplateEngine::from_model("meta-llama/Llama-2-7b-chat");
        assert_eq!(engine.format(), TemplateFormat::Llama2);
    }

    #[test]
    fn test_SERVE_TPL_009_default() {
        let engine = ChatTemplateEngine::default();
        assert_eq!(engine.format(), TemplateFormat::Raw);
    }

    // ========================================================================
    // SERVE-TPL-010: Multi-turn Conversation Tests
    // ========================================================================

    #[test]
    fn test_SERVE_TPL_010_llama2_multiturn() {
        let engine = ChatTemplateEngine::new(TemplateFormat::Llama2);
        let result = engine.apply(&multiturn_messages());
        // Should have multiple [INST] blocks
        assert!(result.matches("[INST]").count() >= 2);
    }

    #[test]
    fn test_SERVE_TPL_010_chatml_multiturn() {
        let engine = ChatTemplateEngine::new(TemplateFormat::ChatML);
        let result = engine.apply(&multiturn_messages());
        // Should have multiple im_start tags
        assert!(result.matches("<|im_start|>").count() >= 3);
    }

    // ========================================================================
    // SERVE-TPL-011: Multi-turn Vicuna Template Tests
    // ========================================================================

    #[test]
    fn test_SERVE_TPL_011_vicuna_with_system() {
        let engine = ChatTemplateEngine::new(TemplateFormat::Vicuna);
        let messages = vec![ChatMessage::system("You are helpful."), ChatMessage::user("Hi!")];
        let result = engine.apply(&messages);
        assert!(result.contains("You are helpful."));
        assert!(result.contains("USER: Hi!"));
        assert!(result.contains("ASSISTANT:"));
    }

    #[test]
    fn test_SERVE_TPL_011_vicuna_with_assistant_response() {
        let engine = ChatTemplateEngine::new(TemplateFormat::Vicuna);
        let messages = vec![ChatMessage::user("Hi!"), ChatMessage::assistant("Hello there!")];
        let result = engine.apply(&messages);
        assert!(result.contains("USER: Hi!"));
        assert!(result.contains(" Hello there!"));
    }

    #[test]
    fn test_SERVE_TPL_011_vicuna_multiturn() {
        let engine = ChatTemplateEngine::new(TemplateFormat::Vicuna);
        let result = engine.apply(&multiturn_messages());
        // Two user messages
        assert_eq!(result.matches("USER:").count(), 2);
        // One assistant response
        assert!(result.contains(" Hello!"));
    }

    #[test]
    fn test_SERVE_TPL_011_vicuna_system_and_assistant() {
        let engine = ChatTemplateEngine::new(TemplateFormat::Vicuna);
        let messages = vec![
            ChatMessage::system("Be concise."),
            ChatMessage::user("What is 2+2?"),
            ChatMessage::assistant("4"),
            ChatMessage::user("And 3+3?"),
        ];
        let result = engine.apply(&messages);
        assert!(result.contains("Be concise."));
        assert!(result.contains("USER: What is 2+2?"));
        assert!(result.contains(" 4\n"));
        assert!(result.contains("USER: And 3+3?"));
    }

    // ========================================================================
    // SERVE-TPL-012: Multi-turn Alpaca Template Tests
    // ========================================================================

    #[test]
    fn test_SERVE_TPL_012_alpaca_with_system() {
        let engine = ChatTemplateEngine::new(TemplateFormat::Alpaca);
        let messages =
            vec![ChatMessage::system("You are a tutor."), ChatMessage::user("Explain gravity.")];
        let result = engine.apply(&messages);
        assert!(result.contains("You are a tutor."));
        assert!(result.contains("### Instruction:"));
        assert!(result.contains("Explain gravity."));
        assert!(result.contains("### Response:"));
    }

    #[test]
    fn test_SERVE_TPL_012_alpaca_with_assistant_response() {
        let engine = ChatTemplateEngine::new(TemplateFormat::Alpaca);
        let messages = vec![
            ChatMessage::user("What is AI?"),
            ChatMessage::assistant("Artificial Intelligence."),
        ];
        let result = engine.apply(&messages);
        assert!(result.contains("### Instruction:"));
        assert!(result.contains("What is AI?"));
        assert!(result.contains("Artificial Intelligence.\n"));
    }

    #[test]
    fn test_SERVE_TPL_012_alpaca_multiturn() {
        let engine = ChatTemplateEngine::new(TemplateFormat::Alpaca);
        let result = engine.apply(&multiturn_messages());
        // Two instructions
        assert_eq!(result.matches("### Instruction:").count(), 2);
        // One assistant response
        assert!(result.contains("Hello!\n"));
    }

    #[test]
    fn test_SERVE_TPL_012_alpaca_system_and_multiturn() {
        let engine = ChatTemplateEngine::new(TemplateFormat::Alpaca);
        let messages = vec![
            ChatMessage::system("Be brief."),
            ChatMessage::user("Define ML."),
            ChatMessage::assistant("Machine Learning."),
            ChatMessage::user("Define AI."),
        ];
        let result = engine.apply(&messages);
        assert!(result.contains("Be brief.\n\n"));
        assert!(result.contains("### Instruction:\nDefine ML."));
        assert!(result.contains("Machine Learning.\n"));
        assert!(result.contains("### Instruction:\nDefine AI."));
    }

    // ========================================================================
    // SERVE-TPL-013: Multi-turn Mistral Template Tests
    // ========================================================================

    #[test]
    fn test_SERVE_TPL_013_mistral_multiturn() {
        let engine = ChatTemplateEngine::new(TemplateFormat::Mistral);
        let result = engine.apply(&multiturn_messages());
        // Verify BOS token
        assert!(result.starts_with("<s>"));
        // First user message
        assert!(result.contains("[INST] Hi! [/INST]"));
        // Assistant response followed by EOS
        assert!(result.contains("Hello!</s>"));
        // Second user turn opens its own [INST] block
        assert!(result.contains("[INST] How are you? [/INST]"));
    }

    #[test]
    fn test_SERVE_TPL_013_mistral_with_system_and_assistant() {
        let engine = ChatTemplateEngine::new(TemplateFormat::Mistral);
        let messages = vec![
            ChatMessage::system("You are an expert."),
            ChatMessage::user("Explain ML."),
            ChatMessage::assistant("Machine Learning is..."),
            ChatMessage::user("More detail."),
        ];
        let result = engine.apply(&messages);
        assert!(result.contains("[INST] You are an expert."));
        assert!(result.contains("Explain ML. [/INST]"));
        assert!(result.contains("Machine Learning is...</s>"));
        assert!(result.contains("More detail. [/INST]"));
    }

    #[test]
    fn test_SERVE_TPL_013_mistral_system_prepends_to_first_inst() {
        let engine = ChatTemplateEngine::new(TemplateFormat::Mistral);
        let messages = vec![ChatMessage::system("Be helpful."), ChatMessage::user("Hi!")];
        let result = engine.apply(&messages);
        // System message should come first inside [INST]
        assert!(result.contains("[INST] Be helpful."));
        assert!(result.contains("Hi! [/INST]"));
    }

    // ========================================================================
    // SERVE-TPL-014: Llama2 Multi-turn with Assistant
    // ========================================================================

    #[test]
    fn test_SERVE_TPL_014_llama2_multiturn_with_assistant() {
        let engine = ChatTemplateEngine::new(TemplateFormat::Llama2);
        let messages = vec![
            ChatMessage::system("You are an AI."),
            ChatMessage::user("Hello!"),
            ChatMessage::assistant("Hi!"),
            ChatMessage::user("How are you?"),
        ];
        let result = engine.apply(&messages);
        assert!(result.starts_with("<s>"));
        assert!(result.contains("<<SYS>>"));
        assert!(result.contains("You are an AI."));
        assert!(result.contains("<</SYS>>"));
        assert!(result.contains(" Hi!</s>"));
        assert!(result.contains("[INST] How are you? [/INST]"));
    }

    // ========================================================================
    // SERVE-TPL-015: ChatML Multi-turn with System and Assistant
    // ========================================================================

    #[test]
    fn test_SERVE_TPL_015_chatml_system_and_multiturn() {
        let engine = ChatTemplateEngine::new(TemplateFormat::ChatML);
        let messages = vec![
            ChatMessage::system("Be concise."),
            ChatMessage::user("Hi!"),
            ChatMessage::assistant("Hello!"),
            ChatMessage::user("Bye!"),
        ];
        let result = engine.apply(&messages);
        assert!(result.contains("<|im_start|>system\nBe concise.<|im_end|>"));
        assert!(result.contains("<|im_start|>user\nHi!<|im_end|>"));
        assert!(result.contains("<|im_start|>assistant\nHello!<|im_end|>"));
        assert!(result.contains("<|im_start|>user\nBye!<|im_end|>"));
        // Trailing assistant prompt
        assert!(result.ends_with("<|im_start|>assistant\n"));
    }
}