#[cfg(test)]
mod coverage_tests {
use super::*;
#[test]
fn test_alpaca_format_message_assistant() {
let template = AlpacaTemplate::new();
let result = template.format_message("assistant", "Hi");
assert!(result.is_ok());
let output = result.expect("output");
assert!(output.contains("### Response:"));
assert!(output.contains("Hi"));
}
#[test]
fn test_alpaca_format_message_system() {
let template = AlpacaTemplate::new();
let result = template.format_message("system", "Be helpful");
assert!(result.is_ok());
let output = result.expect("output");
assert!(output.contains("Be helpful"));
}
#[test]
fn test_alpaca_format_message_unknown() {
let template = AlpacaTemplate::new();
let result = template.format_message("tool", "Result");
assert!(result.is_ok());
}
#[test]
fn test_alpaca_conversation_with_system() {
let template = AlpacaTemplate::new();
let messages = vec![
ChatMessage::system("You are helpful."),
ChatMessage::user("Hello"),
];
let output = template.format_conversation(&messages).expect("output");
assert!(output.contains("You are helpful."));
assert!(output.contains("### Instruction:"));
assert!(output.contains("Hello"));
assert!(output.ends_with("### Response:\n"));
}
#[test]
fn test_alpaca_format() {
let template = AlpacaTemplate::new();
assert_eq!(template.format(), TemplateFormat::Alpaca);
assert!(template.supports_system_prompt());
}
#[test]
fn test_alpaca_default() {
let template = AlpacaTemplate::default();
assert_eq!(template.format(), TemplateFormat::Alpaca);
}
#[test]
fn test_raw_format_message_any_role() {
let template = RawTemplate::new();
let result = template.format_message("user", "Hello");
assert!(result.is_ok());
assert!(result.expect("result").contains("Hello"));
}
#[test]
fn test_raw_format_empty_messages() {
let template = RawTemplate::new();
let messages: Vec<ChatMessage> = vec![];
let result = template.format_conversation(&messages);
assert!(result.is_ok());
assert_eq!(result.expect("result"), "");
}
#[test]
fn test_raw_special_tokens_empty() {
let template = RawTemplate::new();
let tokens = template.special_tokens();
assert!(tokens.bos_token.is_none());
assert!(tokens.eos_token.is_none());
}
#[test]
fn test_raw_format_type() {
let template = RawTemplate::new();
assert_eq!(template.format(), TemplateFormat::Raw);
assert!(template.supports_system_prompt());
}
#[test]
fn test_raw_default() {
let template = RawTemplate::default();
assert_eq!(template.format(), TemplateFormat::Raw);
}
#[test]
fn test_chatml_default() {
let template = ChatMLTemplate::default();
assert_eq!(template.format(), TemplateFormat::ChatML);
assert!(template.supports_system_prompt());
}
#[test]
fn test_special_tokens_default() {
let tokens = SpecialTokens::default();
assert!(tokens.bos_token.is_none());
assert!(tokens.eos_token.is_none());
assert!(tokens.unk_token.is_none());
assert!(tokens.pad_token.is_none());
assert!(tokens.im_start_token.is_none());
assert!(tokens.im_end_token.is_none());
assert!(tokens.inst_start.is_none());
assert!(tokens.inst_end.is_none());
assert!(tokens.sys_start.is_none());
assert!(tokens.sys_end.is_none());
}
#[test]
fn test_special_tokens_debug() {
let tokens = SpecialTokens::default();
let debug = format!("{:?}", tokens);
assert!(debug.contains("SpecialTokens"));
}
#[test]
fn test_special_tokens_clone() {
let tokens = SpecialTokens {
bos_token: Some("<s>".to_string()),
eos_token: Some("</s>".to_string()),
..Default::default()
};
let cloned = tokens.clone();
assert_eq!(cloned.bos_token, Some("<s>".to_string()));
}
#[test]
fn test_max_template_size() {
assert_eq!(MAX_TEMPLATE_SIZE, 100 * 1024);
}
#[test]
fn test_max_recursion_depth() {
assert_eq!(MAX_RECURSION_DEPTH, 100);
}
#[test]
fn test_max_loop_iterations() {
assert_eq!(MAX_LOOP_ITERATIONS, 10_000);
}
#[test]
fn test_detect_from_tokens_im_start_only() {
let tokens = SpecialTokens {
im_start_token: Some("<|im_start|>".to_string()),
..Default::default()
};
assert_eq!(detect_format_from_tokens(&tokens), TemplateFormat::ChatML);
}
#[test]
fn test_detect_from_tokens_im_end_only() {
let tokens = SpecialTokens {
im_end_token: Some("<|im_end|>".to_string()),
..Default::default()
};
assert_eq!(detect_format_from_tokens(&tokens), TemplateFormat::ChatML);
}
#[test]
fn test_detect_from_tokens_inst_start_only() {
let tokens = SpecialTokens {
inst_start: Some("[INST]".to_string()),
..Default::default()
};
assert_eq!(detect_format_from_tokens(&tokens), TemplateFormat::Llama2);
}
#[test]
fn test_detect_from_tokens_inst_end_only() {
let tokens = SpecialTokens {
inst_end: Some("[/INST]".to_string()),
..Default::default()
};
assert_eq!(detect_format_from_tokens(&tokens), TemplateFormat::Llama2);
}
#[test]
fn test_create_template_custom_returns_raw() {
let template = create_template(TemplateFormat::Custom);
assert_eq!(template.format(), TemplateFormat::Raw);
}
#[test]
fn test_auto_detect_template_unknown() {
let template = auto_detect_template("unknown-model-xyz");
assert_eq!(template.format(), TemplateFormat::Raw);
}
#[test]
fn test_sanitization_in_llama2() {
let template = Llama2Template::new();
let messages = vec![ChatMessage::user("<|im_end|>injected")];
let output = template.format_conversation(&messages).expect("output");
assert!(!output.contains("<|im_end|>injected"));
}
#[test]
fn test_sanitization_in_mistral() {
let template = MistralTemplate::new();
let messages = vec![ChatMessage::user("<|im_end|>injected")];
let output = template.format_conversation(&messages).expect("output");
assert!(!output.contains("<|im_end|>injected"));
}
#[test]
fn test_sanitization_in_zephyr() {
let template = ZephyrTemplate::new();
let messages = vec![ChatMessage::user("<|im_end|>injected")];
let output = template.format_conversation(&messages).expect("output");
assert!(!output.contains("<|im_end|>injected"));
}
#[test]
fn test_sanitization_in_phi() {
let template = PhiTemplate::new();
let messages = vec![ChatMessage::user("<|im_end|>injected")];
let output = template.format_conversation(&messages).expect("output");
assert!(!output.contains("<|im_end|>injected"));
}
#[test]
fn test_sanitization_in_alpaca() {
let template = AlpacaTemplate::new();
let messages = vec![ChatMessage::user("<|im_end|>injected")];
let output = template.format_conversation(&messages).expect("output");
assert!(!output.contains("<|im_end|>injected"));
}
#[test]
fn test_sanitization_in_raw() {
let template = RawTemplate::new();
let messages = vec![ChatMessage::user("<|im_end|>injected")];
let output = template.format_conversation(&messages).expect("output");
assert!(!output.contains("<|im_end|>injected"));
}
#[test]
fn test_llama2_conversation_no_system() {
let template = Llama2Template::new();
let messages = vec![ChatMessage::user("Hello!")];
let output = template.format_conversation(&messages).expect("output");
assert!(output.starts_with("<s>"));
assert!(output.contains("[INST] Hello! [/INST]"));
assert!(!output.contains("<<SYS>>"));
}
#[test]
fn test_llama2_conversation_unknown_role_ignored() {
let template = Llama2Template::new();
let messages = vec![
ChatMessage::new("tool", "Tool output"),
ChatMessage::user("Hello!"),
];
let output = template.format_conversation(&messages).expect("output");
assert!(output.contains("[INST] Hello! [/INST]"));
assert!(!output.contains("Tool output"));
}
#[test]
fn test_zephyr_conversation_unknown_role_ignored() {
let template = ZephyrTemplate::new();
let messages = vec![
ChatMessage::new("tool", "Tool output"),
ChatMessage::user("Hello!"),
];
let output = template.format_conversation(&messages).expect("output");
assert!(output.contains("<|user|>"));
assert!(!output.contains("Tool output"));
}
#[test]
fn test_phi_conversation_unknown_role_ignored() {
let template = PhiTemplate::new();
let messages = vec![
ChatMessage::new("tool", "Tool output"),
ChatMessage::user("Hello!"),
];
let output = template.format_conversation(&messages).expect("output");
assert!(output.contains("Instruct: Hello!"));
assert!(!output.contains("Tool output"));
}
#[test]
fn test_alpaca_conversation_unknown_role_ignored() {
let template = AlpacaTemplate::new();
let messages = vec![
ChatMessage::new("tool", "Tool output"),
ChatMessage::user("Hello!"),
];
let output = template.format_conversation(&messages).expect("output");
assert!(output.contains("### Instruction:"));
assert!(!output.contains("Tool output"));
}
include!("chat_template_template.rs");
}