agent_chain_core/prompts/
string.rs

1//! Base string prompt template.
2//!
3//! This module provides the base string prompt template and formatting utilities,
4//! mirroring `langchain_core.prompts.string` in Python.
5
6use std::collections::{HashMap, HashSet};
7
8use crate::error::{Error, Result};
9use crate::utils::formatting::{FORMATTER, FormattingError};
10use crate::utils::mustache::{MustacheValue, render as mustache_render};
11
/// Template format types.
///
/// The canonical string form of each variant is `"f-string"`, `"mustache"`
/// or `"jinja2"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum PromptTemplateFormat {
    /// F-string format using `{variable}` syntax (the default).
    #[default]
    FString,
    /// Mustache format using `{{variable}}` syntax.
    Mustache,
    /// Jinja2-style format using `{{ variable }}` syntax.
    ///
    /// Only simple `{{ name }}` substitution is implemented in this module;
    /// Jinja2 statements, filters and expressions are not evaluated.
    Jinja2,
}
23
24impl std::str::FromStr for PromptTemplateFormat {
25    type Err = Error;
26
27    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
28        match s {
29            "f-string" | "fstring" | "f_string" => Ok(Self::FString),
30            "mustache" => Ok(Self::Mustache),
31            "jinja2" => Ok(Self::Jinja2),
32            _ => Err(Error::InvalidConfig(format!(
33                "Invalid template format: {}. Expected one of: f-string, mustache, jinja2",
34                s
35            ))),
36        }
37    }
38}
39
40impl PromptTemplateFormat {
41    /// Convert to a string representation.
42    pub fn as_str(&self) -> &'static str {
43        match self {
44            Self::FString => "f-string",
45            Self::Mustache => "mustache",
46            Self::Jinja2 => "jinja2",
47        }
48    }
49}
50
51impl std::fmt::Display for PromptTemplateFormat {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        write!(f, "{}", self.as_str())
54    }
55}
56
57impl serde::Serialize for PromptTemplateFormat {
58    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
59    where
60        S: serde::Serializer,
61    {
62        serializer.serialize_str(self.as_str())
63    }
64}
65
66impl<'de> serde::Deserialize<'de> for PromptTemplateFormat {
67    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
68    where
69        D: serde::Deserializer<'de>,
70    {
71        use std::str::FromStr;
72        let s = String::deserialize(deserializer)?;
73        Self::from_str(&s).map_err(serde::de::Error::custom)
74    }
75}
76
77/// Format a template using jinja2.
78///
79/// **Security warning**: Jinja2 templates can execute arbitrary code.
80/// Never use jinja2 templates from untrusted sources.
81///
82/// # Arguments
83///
84/// * `template` - The template string.
85/// * `kwargs` - The keyword arguments to substitute.
86///
87/// # Returns
88///
89/// The formatted string, or an error if formatting fails.
90pub fn jinja2_formatter(template: &str, kwargs: &HashMap<String, String>) -> Result<String> {
91    // Note: In Rust, we don't have a direct jinja2 equivalent.
92    // We'll implement a basic version using string replacement for now.
93    // For full jinja2 support, consider using the `minijinja` crate.
94    let mut result = template.to_string();
95
96    for (key, value) in kwargs {
97        // Replace {{ key }} patterns
98        let pattern = format!("{{{{ {} }}}}", key);
99        result = result.replace(&pattern, value);
100
101        // Also replace {{key}} without spaces
102        let pattern_no_space = format!("{{{{{}}}}}", key);
103        result = result.replace(&pattern_no_space, value);
104    }
105
106    Ok(result)
107}
108
109/// Format a template using mustache.
110///
111/// # Arguments
112///
113/// * `template` - The template string.
114/// * `kwargs` - The keyword arguments to substitute.
115///
116/// # Returns
117///
118/// The formatted string, or an error if formatting fails.
119pub fn mustache_formatter(template: &str, kwargs: &HashMap<String, String>) -> Result<String> {
120    let mut data = HashMap::new();
121    for (key, value) in kwargs {
122        data.insert(key.clone(), MustacheValue::String(value.clone()));
123    }
124
125    mustache_render(template, &MustacheValue::Map(data), None)
126        .map_err(|e| Error::Other(format!("Mustache error: {}", e)))
127}
128
129/// Validate that input variables match the template for jinja2.
130///
131/// Issues a warning if missing or extra variables are found.
132///
133/// # Arguments
134///
135/// * `template` - The template string.
136/// * `input_variables` - The input variables to validate.
137pub fn validate_jinja2(template: &str, input_variables: &[String]) -> Result<()> {
138    let template_vars = get_jinja2_variables(template);
139    let input_set: HashSet<_> = input_variables.iter().cloned().collect();
140
141    let missing: Vec<_> = template_vars.difference(&input_set).collect();
142    let extra: Vec<_> = input_set.difference(&template_vars).collect();
143
144    if !missing.is_empty() || !extra.is_empty() {
145        let mut warning = String::new();
146        if !missing.is_empty() {
147            warning.push_str(&format!("Missing variables: {:?} ", missing));
148        }
149        if !extra.is_empty() {
150            warning.push_str(&format!("Extra variables: {:?}", extra));
151        }
152        eprintln!("Warning: {}", warning.trim());
153    }
154
155    Ok(())
156}
157
/// Get the names of the top-level variables used in a jinja2 template.
///
/// This is a lightweight scan, not a Jinja2 parser. Only `{{ … }}`
/// expressions are inspected; names used solely inside `{% … %}` statements
/// are not reported.
///
/// For each expression it takes the leading identifier, after skipping
/// whitespace (including newlines) and the `-`/`+` whitespace-control
/// markers. The identifier ends at the first character that cannot be part
/// of it, so `{{ user.name }}`, `{{ items[0] }}` and `{{ name|upper }}` yield
/// `user`, `items` and `name`.
///
/// Expressions that begin with a literal (a number or a quoted string) or
/// with one of Jinja2's constants (`true`, `false`, `none`, lower- or
/// title-case) contribute nothing.
fn get_jinja2_variables(template: &str) -> HashSet<String> {
    const CONSTANTS: [&str; 6] = ["true", "false", "none", "True", "False", "None"];

    let mut variables = HashSet::new();
    let mut rest = template;

    while let Some(open) = rest.find("{{") {
        rest = &rest[open + 2..];

        let expr = rest.trim_start_matches(|c: char| c.is_whitespace() || c == '-' || c == '+');
        let name_len = expr
            .find(|c: char| !(c.is_alphanumeric() || c == '_'))
            .unwrap_or(expr.len());
        let name = &expr[..name_len];

        // Identifiers cannot start with a digit; this also drops numeric literals.
        let is_identifier = name
            .chars()
            .next()
            .is_some_and(|c| c.is_alphabetic() || c == '_');
        if is_identifier && !CONSTANTS.contains(&name) {
            variables.insert(name.to_string());
        }
    }

    variables
}
190
/// Get the top-level variables from a mustache template.
///
/// Mirrors `mustache_template_vars` in Python's `langchain_core.prompts.string`:
///
/// * variable (`{{name}}`), triple (`{{{name}}}`) and ampersand (`{{&name}}`)
///   tags contribute their key;
/// * section (`{{#name}}`) and inverted-section (`{{^name}}`) tags also
///   contribute their key, because the caller must supply the value the
///   section tests;
/// * tags nested inside a section contribute nothing, since they are
///   resolved against the section's context;
/// * comments (`{{! … }}`), partials (`{{> … }}`), set-delimiter tags
///   (`{{= … =}}`) and the implicit iterator `{{.}}` contribute nothing.
///
/// For dotted keys like `{{person.name}}`, only the top-level key
/// (`person`) is returned.
///
/// Custom delimiters are not tracked: after a `{{= … =}}` tag, later tags are
/// still recognised only in `{{ … }}` form.
pub fn mustache_template_vars(template: &str) -> HashSet<String> {
    let mut variables = HashSet::new();
    // Unsigned + saturating: a stray closing tag must not drive the depth
    // negative and hide every variable that follows it.
    let mut section_depth: usize = 0;
    let mut rest = template;

    while let Some(open) = rest.find("{{") {
        let after_open = &rest[open + 2..];

        // A single sigil character right after `{{` selects the tag kind.
        let (sigil, body) = match after_open.chars().next() {
            Some(c @ ('#' | '^' | '/' | '!' | '>' | '=' | '{' | '&')) => {
                (Some(c), &after_open[1..])
            }
            _ => (None, after_open),
        };

        // The tag body runs to the first `}`; an unterminated tag runs to the end.
        let body_end = body.find('}').unwrap_or(body.len());
        let key = body[..body_end].trim();
        rest = &body[body_end..];

        let names_a_variable = match sigil {
            Some('/') => {
                section_depth = section_depth.saturating_sub(1);
                false
            }
            Some('!' | '>' | '=') => false,
            _ => true,
        };

        if names_a_variable && section_depth == 0 && key != "." {
            if let Some(top_level) = key.split('.').next().filter(|k| !k.is_empty()) {
                variables.insert(top_level.to_string());
            }
        }

        if matches!(sigil, Some('#' | '^')) {
            section_depth += 1;
        }
    }

    variables
}
293
294/// Check that template string is valid.
295///
296/// # Arguments
297///
298/// * `template` - The template string.
299/// * `template_format` - The template format.
300/// * `input_variables` - The input variables.
301///
302/// # Returns
303///
304/// Ok(()) if valid, or an error if invalid.
305pub fn check_valid_template(
306    template: &str,
307    template_format: PromptTemplateFormat,
308    input_variables: &[String],
309) -> Result<()> {
310    match template_format {
311        PromptTemplateFormat::FString => FORMATTER
312            .validate_input_variables(template, input_variables)
313            .map_err(|e| match e {
314                FormattingError::MissingKey(key) => Error::InvalidConfig(format!(
315                    "Invalid prompt schema; missing input parameter: {}",
316                    key
317                )),
318                FormattingError::InvalidFormat(msg) => {
319                    Error::InvalidConfig(format!("Invalid format string: {}", msg))
320                }
321            }),
322        PromptTemplateFormat::Jinja2 => validate_jinja2(template, input_variables),
323        PromptTemplateFormat::Mustache => {
324            // Mustache templates cannot be validated in the same way
325            Ok(())
326        }
327    }
328}
329
330/// Get the variables from the template.
331///
332/// # Arguments
333///
334/// * `template` - The template string.
335/// * `template_format` - The template format.
336///
337/// # Returns
338///
339/// A sorted list of variable names from the template.
340pub fn get_template_variables(
341    template: &str,
342    template_format: PromptTemplateFormat,
343) -> Result<Vec<String>> {
344    let variables: HashSet<String> = match template_format {
345        PromptTemplateFormat::FString => {
346            let placeholders = FORMATTER.extract_placeholders(template);
347            // Validate that variables don't contain dots, brackets, or are all digits
348            for var in &placeholders {
349                if var.contains('.') || var.contains('[') || var.contains(']') {
350                    return Err(Error::InvalidConfig(format!(
351                        "Invalid variable name '{}' in f-string template. \
352                         Variable names cannot contain attribute access (.) or indexing ([]).",
353                        var
354                    )));
355                }
356                if var.chars().all(|c| c.is_ascii_digit()) {
357                    return Err(Error::InvalidConfig(format!(
358                        "Invalid variable name '{}' in f-string template. \
359                         Variable names cannot be all digits as they are interpreted as positional arguments.",
360                        var
361                    )));
362                }
363            }
364            placeholders
365        }
366        PromptTemplateFormat::Jinja2 => get_jinja2_variables(template),
367        PromptTemplateFormat::Mustache => mustache_template_vars(template),
368    };
369
370    let mut vars: Vec<_> = variables.into_iter().collect();
371    vars.sort();
372    Ok(vars)
373}
374
375/// Format a template string with the given format and kwargs.
376pub fn format_template(
377    template: &str,
378    template_format: PromptTemplateFormat,
379    kwargs: &HashMap<String, String>,
380) -> Result<String> {
381    match template_format {
382        PromptTemplateFormat::FString => FORMATTER.format(template, kwargs).map_err(|e| match e {
383            FormattingError::MissingKey(key) => {
384                Error::InvalidConfig(format!("Missing key in format string: {}", key))
385            }
386            FormattingError::InvalidFormat(msg) => {
387                Error::InvalidConfig(format!("Invalid format string: {}", msg))
388            }
389        }),
390        PromptTemplateFormat::Mustache => mustache_formatter(template, kwargs),
391        PromptTemplateFormat::Jinja2 => jinja2_formatter(template, kwargs),
392    }
393}
394
395/// Trait for string prompt templates.
396///
397/// String prompt templates format to a string (as opposed to a list of messages).
398pub trait StringPromptTemplate: Send + Sync {
399    /// Get the input variables for this template.
400    fn input_variables(&self) -> &[String];
401
402    /// Get the optional variables for this template.
403    fn optional_variables(&self) -> &[String] {
404        &[]
405    }
406
407    /// Get partial variables for this template.
408    fn partial_variables(&self) -> &HashMap<String, String> {
409        static EMPTY: std::sync::LazyLock<HashMap<String, String>> =
410            std::sync::LazyLock::new(HashMap::new);
411        &EMPTY
412    }
413
414    /// Get the template format.
415    fn template_format(&self) -> PromptTemplateFormat {
416        PromptTemplateFormat::FString
417    }
418
419    /// Format the prompt with the inputs.
420    ///
421    /// # Arguments
422    ///
423    /// * `kwargs` - The keyword arguments to format the template with.
424    ///
425    /// # Returns
426    ///
427    /// A formatted string, or an error if formatting fails.
428    fn format(&self, kwargs: &HashMap<String, String>) -> Result<String>;
429
430    /// Async format the prompt with the inputs.
431    ///
432    /// Default implementation calls the sync version.
433    fn aformat(
434        &self,
435        kwargs: &HashMap<String, String>,
436    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<String>> + Send + '_>> {
437        let result = self.format(kwargs);
438        Box::pin(async move { result })
439    }
440
441    /// Get a pretty representation of the prompt.
442    fn pretty_repr(&self, html: bool) -> String;
443
444    /// Print a pretty representation of the prompt.
445    fn pretty_print(&self) {
446        println!("{}", self.pretty_repr(false));
447    }
448}
449
/// Check whether `child` is a prefix of `parent`.
///
/// Despite the name, which is kept for parity with the Python
/// `langchain_core` API, this is a prefix test, not a general subsequence
/// test. It returns `false` when either slice is empty or when `child` is
/// longer than `parent`.
// NOTE(review): presumably unused within this crate; the allow keeps builds
// warning-free if this module is not publicly re-exported.
#[allow(dead_code)]
pub fn is_subsequence<T: PartialEq>(child: &[T], parent: &[T]) -> bool {
    !child.is_empty() && !parent.is_empty() && parent.starts_with(child)
}
464
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Build a kwargs map with the single binding `name -> "World"`.
    fn world_kwargs() -> HashMap<String, String> {
        HashMap::from([("name".to_string(), "World".to_string())])
    }

    #[test]
    fn test_template_format_from_str() {
        let cases = [
            ("f-string", PromptTemplateFormat::FString),
            ("mustache", PromptTemplateFormat::Mustache),
            ("jinja2", PromptTemplateFormat::Jinja2),
        ];
        for (text, expected) in cases {
            assert_eq!(PromptTemplateFormat::from_str(text).unwrap(), expected);
        }
    }

    #[test]
    fn test_get_template_variables_fstring() {
        let vars = get_template_variables(
            "Hello, {name}! You are {age} years old.",
            PromptTemplateFormat::FString,
        )
        .unwrap();
        assert_eq!(vars.len(), 2);
        for expected in ["name", "age"] {
            assert!(vars.iter().any(|v| v == expected), "missing {}", expected);
        }
    }

    #[test]
    fn test_get_template_variables_mustache() {
        let vars = get_template_variables(
            "Hello, {{name}}! You are {{age}} years old.",
            PromptTemplateFormat::Mustache,
        )
        .unwrap();
        assert_eq!(vars.len(), 2);
        for expected in ["name", "age"] {
            assert!(vars.iter().any(|v| v == expected), "missing {}", expected);
        }
    }

    #[test]
    fn test_format_template_fstring() {
        let rendered =
            format_template("Hello, {name}!", PromptTemplateFormat::FString, &world_kwargs())
                .unwrap();
        assert_eq!(rendered, "Hello, World!");
    }

    #[test]
    fn test_format_template_mustache() {
        let rendered =
            format_template("Hello, {{name}}!", PromptTemplateFormat::Mustache, &world_kwargs())
                .unwrap();
        assert_eq!(rendered, "Hello, World!");
    }

    #[test]
    fn test_invalid_fstring_variable() {
        let outcome = get_template_variables("Hello {obj.attr}", PromptTemplateFormat::FString);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_is_subsequence() {
        assert!(is_subsequence(&[1, 2], &[1, 2, 3]));
        assert!(!is_subsequence(&[1, 3], &[1, 2, 3]));
        assert!(!is_subsequence(&[1, 2, 3, 4], &[1, 2, 3]));
    }
}