agent_chain_core/prompts/
string.rs1use std::collections::{HashMap, HashSet};
7
8use crate::error::{Error, Result};
9use crate::utils::formatting::{FORMATTER, FormattingError};
10use crate::utils::mustache::{MustacheValue, render as mustache_render};
11
/// Syntax used to interpret a string prompt template.
///
/// [`PromptTemplateFormat::FString`] is the default. The canonical textual names are
/// `"f-string"`, `"mustache"` and `"jinja2"`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PromptTemplateFormat {
    /// Python-style `{name}` placeholders.
    #[default]
    FString,
    /// Mustache `{{name}}` tags.
    Mustache,
    /// Jinja2-style `{{ name }}` expressions.
    Jinja2,
}
23
24impl std::str::FromStr for PromptTemplateFormat {
25 type Err = Error;
26
27 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
28 match s {
29 "f-string" | "fstring" | "f_string" => Ok(Self::FString),
30 "mustache" => Ok(Self::Mustache),
31 "jinja2" => Ok(Self::Jinja2),
32 _ => Err(Error::InvalidConfig(format!(
33 "Invalid template format: {}. Expected one of: f-string, mustache, jinja2",
34 s
35 ))),
36 }
37 }
38}
39
40impl PromptTemplateFormat {
41 pub fn as_str(&self) -> &'static str {
43 match self {
44 Self::FString => "f-string",
45 Self::Mustache => "mustache",
46 Self::Jinja2 => "jinja2",
47 }
48 }
49}
50
51impl std::fmt::Display for PromptTemplateFormat {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 write!(f, "{}", self.as_str())
54 }
55}
56
57impl serde::Serialize for PromptTemplateFormat {
58 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
59 where
60 S: serde::Serializer,
61 {
62 serializer.serialize_str(self.as_str())
63 }
64}
65
66impl<'de> serde::Deserialize<'de> for PromptTemplateFormat {
67 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
68 where
69 D: serde::Deserializer<'de>,
70 {
71 use std::str::FromStr;
72 let s = String::deserialize(deserializer)?;
73 Self::from_str(&s).map_err(serde::de::Error::custom)
74 }
75}
76
77pub fn jinja2_formatter(template: &str, kwargs: &HashMap<String, String>) -> Result<String> {
91 let mut result = template.to_string();
95
96 for (key, value) in kwargs {
97 let pattern = format!("{{{{ {} }}}}", key);
99 result = result.replace(&pattern, value);
100
101 let pattern_no_space = format!("{{{{{}}}}}", key);
103 result = result.replace(&pattern_no_space, value);
104 }
105
106 Ok(result)
107}
108
109pub fn mustache_formatter(template: &str, kwargs: &HashMap<String, String>) -> Result<String> {
120 let mut data = HashMap::new();
121 for (key, value) in kwargs {
122 data.insert(key.clone(), MustacheValue::String(value.clone()));
123 }
124
125 mustache_render(template, &MustacheValue::Map(data), None)
126 .map_err(|e| Error::Other(format!("Mustache error: {}", e)))
127}
128
129pub fn validate_jinja2(template: &str, input_variables: &[String]) -> Result<()> {
138 let template_vars = get_jinja2_variables(template);
139 let input_set: HashSet<_> = input_variables.iter().cloned().collect();
140
141 let missing: Vec<_> = template_vars.difference(&input_set).collect();
142 let extra: Vec<_> = input_set.difference(&template_vars).collect();
143
144 if !missing.is_empty() || !extra.is_empty() {
145 let mut warning = String::new();
146 if !missing.is_empty() {
147 warning.push_str(&format!("Missing variables: {:?} ", missing));
148 }
149 if !extra.is_empty() {
150 warning.push_str(&format!("Extra variables: {:?}", extra));
151 }
152 eprintln!("Warning: {}", warning.trim());
153 }
154
155 Ok(())
156}
157
/// Collects the root variable names referenced by `{{ ... }}` expressions in a Jinja2 template.
///
/// This is a lightweight scanner, not a Jinja2 parser. After each `{{` it skips leading
/// whitespace and an optional `-` whitespace-control marker (as in `{{- name }}`). It then
/// reads the longest run of identifier characters (alphanumerics and `_`). The run is
/// recorded only if it does not start with a digit and is not one of the Jinja2 constants
/// `true`/`false`/`none` (or their title-case spellings).
///
/// As a result, attribute access, indexing, filters and calls reduce to their root name:
/// `{{ user.email }}`, `{{ items[0] }}` and `{{ name | upper }}` yield `user`, `items` and
/// `name`. String and numeric literals yield nothing. `{% ... %}` statements and
/// `{# ... #}` comments are not inspected.
// NOTE(review): names bound inside the template itself (e.g. the loop variable of
// `{% for item in items %}`, or `loop`) are still reported; excluding them would need a
// real Jinja2 parser.
fn get_jinja2_variables(template: &str) -> HashSet<String> {
    const CONSTANTS: [&str; 6] = ["true", "false", "none", "True", "False", "None"];

    let mut variables = HashSet::new();
    let mut chars = template.chars().peekable();

    while let Some(c) = chars.next() {
        // Only `{{` opens an expression; the second '{' is consumed here.
        if c != '{' || chars.next_if_eq(&'{').is_none() {
            continue;
        }

        // Skip leading whitespace and an optional `-` whitespace-control marker.
        while chars.next_if(|ch| ch.is_whitespace()).is_some() {}
        if chars.next_if_eq(&'-').is_some() {
            while chars.next_if(|ch| ch.is_whitespace()).is_some() {}
        }

        let mut name = String::new();
        while let Some(ch) = chars.next_if(|ch| ch.is_alphanumeric() || *ch == '_') {
            name.push(ch);
        }

        let is_identifier = name.chars().next().is_some_and(|first| !first.is_numeric());
        if is_identifier && !CONSTANTS.contains(&name.as_str()) {
            variables.insert(name);
        }
    }

    variables
}
190
/// Collects the top-level variable names a Mustache template reads from its context.
///
/// A tag starts at `{{` and its key runs up to the next `}`:
///
/// * `{{name}}`, `{{{name}}}` and `{{&name}}` are variable tags;
/// * `{{#name}}` / `{{^name}}` open a section or inverted section, and the section key
///   itself is a variable;
/// * `{{/name}}` closes a section (the name is not checked);
/// * `{{!comment}}`, `{{>partial}}` and `{{=<% %>=}}` (set delimiters) contribute nothing.
///
/// Only keys that appear outside every section are reported. Dotted keys are reduced to
/// their first segment (`user.email` -> `user`), and the implicit iterator `.` is ignored.
/// A closing tag without a matching opener is ignored rather than suppressing the tags
/// that follow it.
// NOTE(review): custom delimiters installed by `{{=...=}}` are not honoured for later tags.
pub fn mustache_template_vars(template: &str) -> HashSet<String> {
    let mut variables = HashSet::new();
    let mut chars = template.chars().peekable();
    // Number of currently open `#` / `^` sections.
    let mut section_depth: usize = 0;

    while let Some(c) = chars.next() {
        // Only `{{` opens a tag; the second '{' is consumed here.
        if c != '{' || chars.next_if_eq(&'{').is_none() {
            continue;
        }

        match chars.peek().copied() {
            Some('#') | Some('^') => {
                chars.next();
                let key = read_mustache_tag_name(&mut chars);
                // The section key is looked up in the enclosing context, so it is a
                // variable whenever the section itself is at top level.
                if section_depth == 0 {
                    insert_mustache_top_level(&mut variables, &key);
                }
                section_depth += 1;
            }
            Some('/') => {
                read_mustache_tag_name(&mut chars);
                // Saturate so a stray closing tag cannot disable detection afterwards.
                section_depth = section_depth.saturating_sub(1);
            }
            Some('!') | Some('>') | Some('=') => {
                read_mustache_tag_name(&mut chars);
            }
            Some('{') | Some('&') => {
                chars.next();
                let key = read_mustache_tag_name(&mut chars);
                if section_depth == 0 {
                    insert_mustache_top_level(&mut variables, &key);
                }
            }
            _ => {
                let key = read_mustache_tag_name(&mut chars);
                if section_depth == 0 {
                    insert_mustache_top_level(&mut variables, &key);
                }
            }
        }
    }

    variables
}

/// Consumes characters up to, but not including, the next `}` and returns them.
fn read_mustache_tag_name(chars: &mut std::iter::Peekable<std::str::Chars<'_>>) -> String {
    let mut name = String::new();
    while let Some(ch) = chars.next_if(|&ch| ch != '}') {
        name.push(ch);
    }
    name
}

/// Records the first dotted segment of a trimmed tag key, ignoring blanks and `.`.
fn insert_mustache_top_level(variables: &mut HashSet<String>, raw_key: &str) {
    let key = raw_key.trim();
    if key.is_empty() || key == "." {
        return;
    }
    if let Some(top) = key.split('.').next().filter(|top| !top.is_empty()) {
        variables.insert(top.to_string());
    }
}
293
294pub fn check_valid_template(
306 template: &str,
307 template_format: PromptTemplateFormat,
308 input_variables: &[String],
309) -> Result<()> {
310 match template_format {
311 PromptTemplateFormat::FString => FORMATTER
312 .validate_input_variables(template, input_variables)
313 .map_err(|e| match e {
314 FormattingError::MissingKey(key) => Error::InvalidConfig(format!(
315 "Invalid prompt schema; missing input parameter: {}",
316 key
317 )),
318 FormattingError::InvalidFormat(msg) => {
319 Error::InvalidConfig(format!("Invalid format string: {}", msg))
320 }
321 }),
322 PromptTemplateFormat::Jinja2 => validate_jinja2(template, input_variables),
323 PromptTemplateFormat::Mustache => {
324 Ok(())
326 }
327 }
328}
329
330pub fn get_template_variables(
341 template: &str,
342 template_format: PromptTemplateFormat,
343) -> Result<Vec<String>> {
344 let variables: HashSet<String> = match template_format {
345 PromptTemplateFormat::FString => {
346 let placeholders = FORMATTER.extract_placeholders(template);
347 for var in &placeholders {
349 if var.contains('.') || var.contains('[') || var.contains(']') {
350 return Err(Error::InvalidConfig(format!(
351 "Invalid variable name '{}' in f-string template. \
352 Variable names cannot contain attribute access (.) or indexing ([]).",
353 var
354 )));
355 }
356 if var.chars().all(|c| c.is_ascii_digit()) {
357 return Err(Error::InvalidConfig(format!(
358 "Invalid variable name '{}' in f-string template. \
359 Variable names cannot be all digits as they are interpreted as positional arguments.",
360 var
361 )));
362 }
363 }
364 placeholders
365 }
366 PromptTemplateFormat::Jinja2 => get_jinja2_variables(template),
367 PromptTemplateFormat::Mustache => mustache_template_vars(template),
368 };
369
370 let mut vars: Vec<_> = variables.into_iter().collect();
371 vars.sort();
372 Ok(vars)
373}
374
375pub fn format_template(
377 template: &str,
378 template_format: PromptTemplateFormat,
379 kwargs: &HashMap<String, String>,
380) -> Result<String> {
381 match template_format {
382 PromptTemplateFormat::FString => FORMATTER.format(template, kwargs).map_err(|e| match e {
383 FormattingError::MissingKey(key) => {
384 Error::InvalidConfig(format!("Missing key in format string: {}", key))
385 }
386 FormattingError::InvalidFormat(msg) => {
387 Error::InvalidConfig(format!("Invalid format string: {}", msg))
388 }
389 }),
390 PromptTemplateFormat::Mustache => mustache_formatter(template, kwargs),
391 PromptTemplateFormat::Jinja2 => jinja2_formatter(template, kwargs),
392 }
393}
394
395pub trait StringPromptTemplate: Send + Sync {
399 fn input_variables(&self) -> &[String];
401
402 fn optional_variables(&self) -> &[String] {
404 &[]
405 }
406
407 fn partial_variables(&self) -> &HashMap<String, String> {
409 static EMPTY: std::sync::LazyLock<HashMap<String, String>> =
410 std::sync::LazyLock::new(HashMap::new);
411 &EMPTY
412 }
413
414 fn template_format(&self) -> PromptTemplateFormat {
416 PromptTemplateFormat::FString
417 }
418
419 fn format(&self, kwargs: &HashMap<String, String>) -> Result<String>;
429
430 fn aformat(
434 &self,
435 kwargs: &HashMap<String, String>,
436 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send + '_>> {
437 let result = self.format(kwargs);
438 Box::pin(async move { result })
439 }
440
441 fn pretty_repr(&self, html: bool) -> String;
443
444 fn pretty_print(&self) {
446 println!("{}", self.pretty_repr(false));
447 }
448}
449
/// Reports whether `child` is a non-empty *prefix* of `parent`.
///
/// Despite the name, this is not a general subsequence test. The elements of `child` must
/// equal the first `child.len()` elements of `parent`, in order. Returns `false` when
/// either slice is empty or `child` is longer than `parent`.
// NOTE(review): the dead_code allowance is kept as-is; its reason is not visible here
// (presumably no in-crate callers yet) — confirm before removing.
#[allow(dead_code)]
pub fn is_subsequence<T: PartialEq>(child: &[T], parent: &[T]) -> bool {
    !child.is_empty() && !parent.is_empty() && parent.starts_with(child)
}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned kwargs map from borrowed key/value pairs.
    fn kwargs_from(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    #[test]
    fn test_template_format_from_str() {
        let expectations = [
            ("f-string", PromptTemplateFormat::FString),
            ("mustache", PromptTemplateFormat::Mustache),
            ("jinja2", PromptTemplateFormat::Jinja2),
        ];
        for (name, expected) in expectations {
            let parsed: PromptTemplateFormat = name.parse().unwrap();
            assert_eq!(parsed, expected);
        }
    }

    #[test]
    fn test_get_template_variables_fstring() {
        let vars = get_template_variables(
            "Hello, {name}! You are {age} years old.",
            PromptTemplateFormat::FString,
        )
        .unwrap();
        // The result is sorted and de-duplicated.
        assert_eq!(vars, ["age", "name"]);
    }

    #[test]
    fn test_get_template_variables_mustache() {
        let vars = get_template_variables(
            "Hello, {{name}}! You are {{age}} years old.",
            PromptTemplateFormat::Mustache,
        )
        .unwrap();
        assert_eq!(vars, ["age", "name"]);
    }

    #[test]
    fn test_format_template_fstring() {
        let kwargs = kwargs_from(&[("name", "World")]);
        let rendered = format_template("Hello, {name}!", PromptTemplateFormat::FString, &kwargs);
        assert_eq!(rendered.unwrap(), "Hello, World!");
    }

    #[test]
    fn test_format_template_mustache() {
        let kwargs = kwargs_from(&[("name", "World")]);
        let rendered =
            format_template("Hello, {{name}}!", PromptTemplateFormat::Mustache, &kwargs);
        assert_eq!(rendered.unwrap(), "Hello, World!");
    }

    #[test]
    fn test_invalid_fstring_variable() {
        let outcome = get_template_variables("Hello {obj.attr}", PromptTemplateFormat::FString);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_is_subsequence() {
        let parent = [1, 2, 3];
        assert!(is_subsequence(&[1, 2], &parent));
        assert!(!is_subsequence(&[1, 3], &parent));
        assert!(!is_subsequence(&[1, 2, 3, 4], &parent));
    }
}