Skip to main content

llama_cpp_bindings/model/
llama_chat_template.rs

1use std::ffi::{CStr, CString};
2use std::str::Utf8Error;
3
/// A chat template held as an owned, NUL-terminated C string.
///
/// Equality, ordering, hashing and cloning are all derived from the wrapped
/// [`CString`]. The inner value is public, so callers can also build or take
/// apart a template directly without going through [`LlamaChatTemplate::new`].
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct LlamaChatTemplate(pub CString);

impl LlamaChatTemplate {
    /// Creates a template by copying `template` into a new [`CString`].
    ///
    /// # Errors
    /// Returns an error if the template string contains null bytes.
    pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
        CString::new(template).map(Self)
    }

    /// Borrows the template as a [`CStr`] without copying.
    #[must_use]
    pub fn as_c_str(&self) -> &CStr {
        self.0.as_c_str()
    }

    /// Borrows the template as a `&str`.
    ///
    /// Values built with [`LlamaChatTemplate::new`] start from a `&str`, but the
    /// inner `CString` is public and may hold arbitrary bytes, hence the check.
    ///
    /// # Errors
    /// Returns an error if the template is not valid UTF-8.
    pub fn to_str(&self) -> Result<&str, Utf8Error> {
        self.as_c_str().to_str()
    }

    /// Copies the template into an owned [`String`].
    ///
    /// # Errors
    /// Returns an error if the template is not valid UTF-8.
    pub fn to_string(&self) -> Result<String, Utf8Error> {
        let text = self.to_str()?;
        Ok(text.to_owned())
    }
}
31
32impl std::fmt::Debug for LlamaChatTemplate {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        self.0.fmt(f)
35    }
36}
37
#[cfg(test)]
mod tests {
    use super::LlamaChatTemplate;

    /// A plain template name can be read back unchanged as `&str`.
    #[test]
    fn valid_template_creation() {
        let chatml = LlamaChatTemplate::new("chatml").unwrap();

        assert_eq!(chatml.to_str().unwrap(), "chatml");
    }

    /// An interior NUL cannot be represented in a C string, so construction fails.
    #[test]
    fn null_byte_returns_error() {
        assert!(LlamaChatTemplate::new("null\0byte").is_err());
    }

    /// The `Debug` rendering includes the template text.
    #[test]
    fn debug_formatting() {
        let chatml = LlamaChatTemplate::new("chatml").unwrap();
        let rendered = format!("{chatml:?}");

        assert!(rendered.contains("chatml"));
    }

    /// `to_string` yields an owned copy equal to the original text.
    #[test]
    fn to_string_returns_owned_string() {
        let llama3 = LlamaChatTemplate::new("llama3").unwrap();
        let copied: String = llama3.to_string().unwrap();

        assert_eq!(copied, "llama3");
    }

    /// The borrowed `CStr` view carries the same text as the template.
    #[test]
    fn as_c_str_returns_valid_cstr() {
        let sample = LlamaChatTemplate::new("test").unwrap();
        let borrowed = sample.as_c_str();

        assert_eq!(borrowed.to_str().unwrap(), "test");
    }
}