use crate::error::{Error, Result};
use derive_builder::Builder;
use std::{collections::HashMap, path::PathBuf};
use tera::{Context, Tera};
/// A set of Tera templates loaded from disk, used to render prompts by name.
#[derive(Debug, Clone)]
pub struct PromptEnv {
    /// Every template discovered under the directory given to [`PromptEnv::new`].
    template: Tera,
}
impl Default for PromptEnv {
fn default() -> Self {
PromptEnv::new(PathBuf::from("prompts"))
}
}
impl PromptEnv {
pub fn new(path: PathBuf) -> Self {
PromptEnv {
template: Tera::new(&format!("{}/**/*", path.display()))
.expect("failed to load templates to tera"),
}
}
pub fn generate_prompt(
&self,
path: &str,
variables: HashMap<String, String>,
) -> Result<String> {
let prompt = PromptBuilder::default()
.path(path.to_string())
.variables(variables)
.build()
.map_err(|e| Error::PromptError(format!("error building prompt: {e:?}")))?;
let ctx = Context::from_serialize(prompt.variables.clone())
.map_err(|e| Error::PromptError(format!("error creating variable context: {e:?}")))?;
self.template
.render(&prompt.path, &ctx)
.map_err(|e| Error::PromptError(format!("error rendering prompt: {e:?}")))
}
}
/// A template path paired with the variables used to render it.
#[derive(Builder, Debug, Clone)]
#[builder(pattern = "owned", setter(into), build_fn(error = "Error"))]
// NOTE(review): presumably silences `missing_docs` for the undocumented items
// that `derive_builder` generates; confirm before removing.
#[allow(missing_docs)]
pub struct Prompt {
    /// Name of the template to render, as registered in a [`PromptEnv`].
    path: String,
    /// Values serialized into the Tera context when rendering the template.
    variables: HashMap<String, String>,
}
impl Prompt {
pub fn builder() -> PromptBuilder {
PromptBuilder::default()
}
}
impl PromptBuilder {
    /// Adds one template variable, creating the variable map on first use.
    ///
    /// If `key` is already present, its previous value is replaced.
    pub fn with_variable(mut self, key: String, value: String) -> Self {
        self.variables
            .get_or_insert_with(HashMap::new)
            .insert(key, value);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prompt_env_creation_with_examples_path() {
        // Each directory is loaded on its own, so every one of them
        // (not just the root) must contain at least one template.
        let dirs = [
            "examples/prompts",
            "examples/prompts/system",
            "examples/prompts/user",
        ];
        for dir in dirs {
            let env = PromptEnv::new(PathBuf::from(dir));
            let has_templates = env.template.get_template_names().next().is_some();
            assert!(
                has_templates,
                "PromptEnv should have loaded templates from {dir}"
            );
        }
    }
}