use std::collections::HashMap;
use async_trait::async_trait;
use serde_json::Value;
use cognis_core::error::{CognisError, Result};
use cognis_core::runnables::base::Runnable;
use cognis_core::runnables::config::RunnableConfig;
/// A text template containing `{name}` placeholders.
///
/// Doubled braces (`{{` and `}}`) are escapes that render as a single literal brace.
#[derive(Debug, Clone)]
pub struct PromptTemplate {
    /// Raw template text, including placeholders and brace escapes.
    pub template: String,
    /// Distinct placeholder names in first-appearance order (as computed by
    /// `PromptTemplate::new`; callers mutating this field directly bypass that).
    pub variables: Vec<String>,
    /// Values pre-bound via `with_partial`; values passed to `format` override them.
    pub partial_variables: HashMap<String, String>,
}
impl PromptTemplate {
    /// Builds a template from `template`, discovering its placeholder names up front.
    ///
    /// The template starts with no partial variables; add them with
    /// [`PromptTemplate::with_partial`].
    pub fn new(template: impl Into<String>) -> Self {
        let source: String = template.into();
        Self {
            // Evaluated before `source` is moved into `template` below.
            variables: extract_variables(&source),
            template: source,
            partial_variables: HashMap::new(),
        }
    }

    /// Pre-binds `key` to `value` and returns the updated template.
    ///
    /// A value for the same key passed to [`PromptTemplate::format`] takes precedence.
    pub fn with_partial(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (name, bound) = (key.into(), value.into());
        self.partial_variables.insert(name, bound);
        self
    }

    /// Returns the placeholder names that still need a value, i.e. every
    /// variable that has not been pre-bound with `with_partial`, in template order.
    pub fn input_variables(&self) -> Vec<&str> {
        let mut required = Vec::with_capacity(self.variables.len());
        for name in &self.variables {
            if !self.partial_variables.contains_key(name) {
                required.push(name.as_str());
            }
        }
        required
    }

    /// Renders the template using the partial variables plus `variables`.
    ///
    /// # Errors
    ///
    /// Propagates the error from the underlying substitution, e.g. when a
    /// placeholder has no value in either the partials or `variables`.
    pub fn format(&self, variables: &HashMap<String, String>) -> Result<String> {
        // Start from the partials, then insert caller values so they win on key clashes.
        let mut combined: HashMap<String, String> = self.partial_variables.clone();
        for (name, value) in variables {
            combined.insert(name.clone(), value.clone());
        }
        format_template_str(&self.template, &combined)
    }
}
#[async_trait]
impl Runnable for PromptTemplate {
    fn name(&self) -> &str {
        "PromptTemplate"
    }

    /// Renders the template from a JSON object of variables.
    ///
    /// Each key of `input` becomes a template variable. JSON strings are used
    /// verbatim; any other JSON value (number, bool, null, array, object) is
    /// rendered through its JSON serialization, e.g. `42` becomes `"42"`.
    ///
    /// # Errors
    ///
    /// - `CognisError::TypeMismatch` if `input` is not a JSON object; `got`
    ///   names the JSON type actually received (`Null`, `Bool`, `Number`,
    ///   `String` or `Array`).
    /// - Any error returned by [`PromptTemplate::format`] (e.g. a missing variable).
    async fn invoke(&self, input: Value, _config: Option<&RunnableConfig>) -> Result<Value> {
        // Report the concrete JSON type received instead of a generic placeholder,
        // so a caller can tell what was passed by mistake.
        let type_mismatch = |got: &'static str| CognisError::TypeMismatch {
            expected: "Object".into(),
            got: got.into(),
        };
        let map = match input {
            Value::Object(map) => map,
            Value::Null => return Err(type_mismatch("Null")),
            Value::Bool(_) => return Err(type_mismatch("Bool")),
            Value::Number(_) => return Err(type_mismatch("Number")),
            Value::String(_) => return Err(type_mismatch("String")),
            Value::Array(_) => return Err(type_mismatch("Array")),
        };
        let kwargs: HashMap<String, String> = map
            .into_iter()
            .map(|(key, value)| {
                let rendered = match value {
                    Value::String(s) => s,
                    other => other.to_string(),
                };
                (key, rendered)
            })
            .collect();
        let text = self.format(&kwargs)?;
        Ok(Value::String(text))
    }
}
/// Collects the distinct placeholder names in `template`, in first-appearance order.
///
/// A placeholder is the text between a single `{` and the next `}`. Doubled
/// braces (`{{` / `}}`) are escapes and are skipped. Empty placeholders (`{}`)
/// are ignored. A `{` with no closing `}` takes the rest of the template as
/// its name.
fn extract_variables(template: &str) -> Vec<String> {
    let mut found: Vec<String> = Vec::new();
    let mut it = template.chars().peekable();
    while let Some(c) = it.next() {
        match c {
            '{' if it.peek() == Some(&'{') => {
                // `{{` is an escaped literal brace, not a placeholder opener.
                it.next();
            }
            '{' => {
                // `take_while` also consumes the terminating `}` (if any).
                let placeholder: String = it.by_ref().take_while(|&x| x != '}').collect();
                if !placeholder.is_empty() && !found.contains(&placeholder) {
                    found.push(placeholder);
                }
            }
            '}' if it.peek() == Some(&'}') => {
                // `}}` is an escaped literal brace.
                it.next();
            }
            _ => {}
        }
    }
    found
}
/// Substitutes `{name}` placeholders in `template` with values from `variables`.
///
/// Escapes: `{{` renders as `{` and `}}` renders as `}`. A lone `}` is copied
/// through unchanged. A `{` with no closing `}` treats the rest of the template
/// as the placeholder name.
///
/// # Errors
///
/// Returns `CognisError::Other` when a placeholder (including the empty `{}`)
/// has no entry in `variables`. The message lists the available keys in sorted
/// order so it is identical across runs.
fn format_template_str(template: &str, variables: &HashMap<String, String>) -> Result<String> {
    let mut result = String::with_capacity(template.len());
    let mut chars = template.chars().peekable();
    while let Some(ch) = chars.next() {
        match ch {
            '{' if chars.peek() == Some(&'{') => {
                chars.next();
                result.push('{');
            }
            '{' => {
                // Consumes up to and including the closing `}` (or to the end).
                let name: String = chars.by_ref().take_while(|&c| c != '}').collect();
                let value = variables.get(&name).ok_or_else(|| {
                    // HashMap iteration order is unspecified; sort for a stable message.
                    let mut available: Vec<&String> = variables.keys().collect();
                    available.sort_unstable();
                    CognisError::Other(format!(
                        "Missing variable '{}'. Available: {:?}",
                        name, available
                    ))
                })?;
                result.push_str(value);
            }
            '}' => {
                // `}}` collapses to one brace; a lone `}` is kept verbatim.
                if chars.peek() == Some(&'}') {
                    chars.next();
                }
                result.push('}');
            }
            _ => result.push(ch),
        }
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auto_extract_variables() {
        let t = PromptTemplate::new("Hello {name}, you are {age} years old.");
        assert_eq!(t.variables, vec!["name", "age"]);
    }

    #[test]
    fn test_duplicate_variables_listed_once() {
        let t = PromptTemplate::new("{a} and {b} and {a}");
        assert_eq!(t.variables, vec!["a", "b"]);
    }

    #[test]
    fn test_format_with_all_variables() {
        let t = PromptTemplate::new("Hello {name}!");
        let mut vars = HashMap::new();
        vars.insert("name".into(), "World".into());
        assert_eq!(t.format(&vars).unwrap(), "Hello World!");
    }

    #[test]
    fn test_escaped_braces() {
        // `{{`/`}}` are literal braces and must not be treated as placeholders.
        let t = PromptTemplate::new("{{literal}} {x}");
        assert_eq!(t.variables, vec!["x"]);
        let mut vars = HashMap::new();
        vars.insert("x".into(), "1".into());
        assert_eq!(t.format(&vars).unwrap(), "{literal} 1");
    }

    #[test]
    fn test_partial_variables() {
        let t =
            PromptTemplate::new("Hello {name}, welcome to {place}!").with_partial("place", "Rust");
        assert_eq!(t.input_variables(), vec!["name"]);
        let mut vars = HashMap::new();
        vars.insert("name".into(), "Alice".into());
        assert_eq!(t.format(&vars).unwrap(), "Hello Alice, welcome to Rust!");
    }

    #[test]
    fn test_caller_value_overrides_partial() {
        let t = PromptTemplate::new("{greeting}, {name}").with_partial("greeting", "Hi");
        let mut vars = HashMap::new();
        vars.insert("greeting".into(), "Hello".into());
        vars.insert("name".into(), "Bob".into());
        assert_eq!(t.format(&vars).unwrap(), "Hello, Bob");
    }

    #[test]
    fn test_missing_variable_error() {
        let t = PromptTemplate::new("Hello {name}!");
        let vars = HashMap::new();
        let err = t.format(&vars).unwrap_err();
        let msg = format!("{}", err);
        assert!(msg.contains("Missing variable 'name'"));
    }

    #[tokio::test]
    async fn test_runnable_invoke() {
        let t = PromptTemplate::new("Hello {name}!");
        let result = t
            .invoke(serde_json::json!({"name": "World"}), None)
            .await
            .unwrap();
        assert_eq!(result, Value::String("Hello World!".into()));
    }

    #[tokio::test]
    async fn test_runnable_invoke_non_string_values() {
        let t = PromptTemplate::new("Count: {n}");
        let result = t.invoke(serde_json::json!({"n": 42}), None).await.unwrap();
        assert_eq!(result, Value::String("Count: 42".into()));
    }

    #[tokio::test]
    async fn test_runnable_invoke_rejects_non_object() {
        let t = PromptTemplate::new("Hello {name}!");
        let err = t
            .invoke(serde_json::json!("oops"), None)
            .await
            .unwrap_err();
        assert!(matches!(err, CognisError::TypeMismatch { .. }));
    }
}