llm_chain/prompt/string_template/mod.rs
1mod tera;
2
3mod error;
4pub use error::StringTemplateError;
5use error::StringTemplateErrorImpl;
6use std::fmt;
7mod io;
8
9use serde::{Deserialize, Serialize};
10
11use crate::Parameters;
12
13/// A template for a prompt. This is a string that can be formatted with a set of parameters.
14///
15/// # Examples
16/// **Using the default key**
17/// ```
18/// use llm_chain::prompt::StringTemplate;
19/// use llm_chain::Parameters;
20/// let template: StringTemplate = "Hello {{ text }}!".into();
21/// let parameters: Parameters = "World".into();
22/// assert_eq!(template.format(¶meters).unwrap(), "Hello World!");
23/// ```
24/// **Using a custom key**
25/// ```
26/// use llm_chain::prompt::StringTemplate;
27/// use llm_chain::Parameters;
28/// let template: StringTemplate = "Hello {{ name }}!".into();
29/// let parameters: Parameters = vec![("name", "World")].into();
30/// assert_eq!(template.format(¶meters).unwrap(), "Hello World!");
31/// ```
32/// ## Tera
33/// ```rust
34/// use llm_chain::prompt::StringTemplate;
35/// use llm_chain::Parameters;
36/// let template: StringTemplate = StringTemplate::tera("Hello {{name}}!");
37/// let parameters: Parameters = vec![("name", "World")].into();
38/// assert_eq!(template.format(¶meters).unwrap(), "Hello World!");
39/// ```
40#[derive(Clone, Debug, Serialize, Deserialize)]
41#[serde(transparent)]
42pub struct StringTemplate(StringTemplateImpl);
43
44impl From<StringTemplateImpl> for StringTemplate {
45 fn from(template: StringTemplateImpl) -> Self {
46 Self(template)
47 }
48}
49
50impl StringTemplate {
51 /// Format the template with the given parameters.
52 pub fn format(&self, parameters: &Parameters) -> Result<String, error::StringTemplateError> {
53 self.0.format(parameters).map_err(|e| e.into())
54 }
55 /// Creates a non-dynmamic prompt template, useful for untrusted inputs.
56 pub fn static_string<K: Into<String>>(template: K) -> StringTemplate {
57 StringTemplateImpl::static_string(template.into()).into()
58 }
59
60 /// Creates a prompt template that uses the Tera templating engine.
61 /// This is only available if the `tera` feature is enabled, which it is by default.
62 /// # Examples
63 ///
64 /// ```rust
65 /// use llm_chain::prompt::StringTemplate;
66 /// use llm_chain::Parameters;
67 /// let template = StringTemplate::tera("Hello {{name}}!");
68 /// let parameters: Parameters = vec![("name", "World")].into();
69 /// assert_eq!(template.format(¶meters).unwrap(), "Hello World!");
70 /// ```
71 pub fn tera<K: Into<String>>(template: K) -> StringTemplate {
72 StringTemplateImpl::tera(template.into()).into()
73 }
74
75 /// Creates a prompt template from a file. The file should be a text file containing the template as a tera template.
76 /// # Examples
77 /// ```no_run
78 /// use llm_chain::prompt::StringTemplate;
79 /// let template = StringTemplate::from_file("template.txt").unwrap();
80 /// ```
81 pub fn from_file<K: AsRef<std::path::Path>>(path: K) -> Result<StringTemplate, std::io::Error> {
82 io::read_prompt_template_file(path)
83 }
84
85 /// Combines two prompt templates into one.
86 /// This is useful for creating a prompt template from multiple sources.
87 /// # Examples
88 /// ```
89 /// use llm_chain::prompt::StringTemplate;
90 /// use llm_chain::Parameters;
91 /// let template1 = StringTemplate::tera("Hello {{name}}");
92 /// let template2 = StringTemplate::tera("!");
93 /// let template3 = StringTemplate::combine(vec![template1, template2]);
94 /// let parameters: Parameters = vec![("name", "World")].into();
95 /// assert_eq!(template3.format(¶meters).unwrap(), "Hello World!");
96 /// ```
97 pub fn combine(parts: Vec<StringTemplate>) -> StringTemplate {
98 let res: Vec<StringTemplateImpl> = parts.into_iter().map(|p| p.0).collect();
99 StringTemplateImpl::combine(res).into()
100 }
101}
102
103impl fmt::Display for StringTemplate {
104 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105 write!(f, "{}", self.0)
106 }
107}
108
109/// The actual implementation of the prompt template. This hides the implementation details from the user.
110#[derive(Clone, Debug, Serialize, Deserialize)]
111enum StringTemplateImpl {
112 Static(String),
113 Tera(String),
114 Combined(Vec<StringTemplateImpl>),
115}
116
117impl StringTemplateImpl {
118 pub fn format(&self, parameters: &Parameters) -> Result<String, StringTemplateErrorImpl> {
119 match self {
120 Self::Static(template) => Ok(template.clone()),
121 Self::Tera(template) => tera::render(template, parameters).map_err(|e| e.into()),
122 Self::Combined(templates) => {
123 let mut result = String::new();
124 for template in templates {
125 let formatted = template.format(parameters)?;
126 result.push_str(&formatted);
127 }
128 Ok(result)
129 }
130 }
131 }
132
133 pub fn static_string(template: String) -> Self {
134 Self::Static(template)
135 }
136
137 pub fn tera(template: String) -> Self {
138 Self::Tera(template)
139 }
140
141 pub fn combine(templates: Vec<Self>) -> Self {
142 Self::Combined(templates)
143 }
144}
145
146impl fmt::Display for StringTemplateImpl {
147 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148 match self {
149 Self::Static(s) => write!(f, "{}", s),
150 Self::Tera(template) => write!(f, "{}", template),
151 Self::Combined(templates) => {
152 for template in templates {
153 write!(f, "{}", template)?;
154 }
155 Ok(())
156 }
157 }
158 }
159}
160
161impl From<&str> for StringTemplate {
162 fn from(template: &str) -> Self {
163 Self::tera(template.to_string())
164 }
165}