//! `llama_chat_template`: a `CString`-backed chat-template wrapper for the llama.cpp model bindings.

1use std::ffi::{CStr, CString};
2use std::str::Utf8Error;
3
4/// A performance-friendly wrapper around [`super::LlamaModel::chat_template`].
5///
6/// This is fed into [`super::LlamaModel::apply_chat_template`] to convert a list of messages into
7/// an LLM prompt. Internally the template is stored as a `CString` to avoid round-trip conversions
8/// within the FFI.
9#[derive(Eq, PartialEq, Clone, PartialOrd, Ord, Hash)]
10pub struct LlamaChatTemplate(pub CString);
11
12impl LlamaChatTemplate {
13    /// Create a new template from a string. This can either be the name of a llama.cpp [chat template](https://github.com/ggerganov/llama.cpp/blob/8a8c4ceb6050bd9392609114ca56ae6d26f5b8f5/src/llama-chat.cpp#L27-L61)
14    /// like "chatml" or "llama3" or an actual Jinja template for llama.cpp to interpret.
15    ///
16    /// # Errors
17    /// Returns an error if the template string contains null bytes.
18    pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
19        Ok(Self(CString::new(template)?))
20    }
21
22    /// Accesses the template as a c string reference.
23    #[must_use]
24    pub fn as_c_str(&self) -> &CStr {
25        &self.0
26    }
27
28    /// Attempts to convert the `CString` into a Rust str reference.
29    ///
30    /// # Errors
31    /// Returns an error if the template is not valid UTF-8.
32    pub fn to_str(&self) -> Result<&str, Utf8Error> {
33        self.0.to_str()
34    }
35
36    /// Convenience method to create an owned String.
37    ///
38    /// # Errors
39    /// Returns an error if the template is not valid UTF-8.
40    pub fn to_string(&self) -> Result<String, Utf8Error> {
41        self.to_str().map(str::to_string)
42    }
43}
44
45impl std::fmt::Debug for LlamaChatTemplate {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        self.0.fmt(f)
48    }
49}
50
#[cfg(test)]
mod tests {
    use super::LlamaChatTemplate;

    #[test]
    fn valid_template_creation() {
        // A built-in template name round-trips through the CString unchanged.
        let tpl = LlamaChatTemplate::new("chatml").expect("no interior null bytes");
        assert_eq!(tpl.to_str().expect("valid UTF-8"), "chatml");
    }

    #[test]
    fn null_byte_returns_error() {
        // Interior null bytes cannot be stored in a CString, so construction must fail.
        assert!(LlamaChatTemplate::new("null\0byte").is_err());
    }

    #[test]
    fn debug_formatting() {
        let rendered = format!("{:?}", LlamaChatTemplate::new("chatml").unwrap());
        assert!(rendered.contains("chatml"));
    }

    #[test]
    fn to_string_returns_owned_string() {
        let owned: String = LlamaChatTemplate::new("llama3").unwrap().to_string().unwrap();
        assert_eq!(owned, "llama3");
    }

    #[test]
    fn as_c_str_returns_valid_cstr() {
        let c_owned = LlamaChatTemplate::new("test").unwrap().as_c_str().to_owned();
        assert_eq!(c_owned.to_str().unwrap(), "test");
    }
}