use std::collections::HashMap;
use std::path::Path;
use minijinja::Environment;
use serde::Serialize;
use crate::loader::load_templates_from_dir;
use crate::{PromptError, PromptTemplate};
/// Registry of prompt templates backed by a `minijinja` [`Environment`].
///
/// Each template is stored twice: its source is compiled into `env` (used by
/// `render`), and the original `PromptTemplate` value is kept in `templates`
/// (used by `get_template` / `template_count`). `add_template` is the single
/// place that writes both, which keeps them in sync.
#[derive(Clone)]
pub struct PromptManager {
    // Template engine holding every successfully added template, keyed by name.
env: Environment<'static>,
    // Original template records, keyed by `PromptTemplate::name`.
templates: HashMap<String, PromptTemplate>,
}
impl PromptManager {
    /// Creates a manager with an empty environment and no registered templates.
    pub fn new() -> Self {
        let env = Environment::new();
        let templates = HashMap::new();
        Self { env, templates }
    }

    /// Creates a manager pre-populated with every template returned by
    /// `crate::builtin::builtin_templates()`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by [`PromptManager::add_template`].
    pub fn with_builtin_templates() -> Result<Self, PromptError> {
        let mut manager = Self::new();
        for builtin in crate::builtin::builtin_templates() {
            manager.add_template(builtin)?;
        }
        Ok(manager)
    }

    /// Compiles `template` into the environment and records it by name.
    ///
    /// Adding a template whose name is already present overwrites the stored
    /// `PromptTemplate` entry.
    ///
    /// # Errors
    ///
    /// Returns [`PromptError::InvalidTemplate`] (carrying the engine's message)
    /// when the template source is rejected. In that case nothing is recorded.
    pub fn add_template(&mut self, template: PromptTemplate) -> Result<(), PromptError> {
        let key = template.name.clone();
        let source = template.content.clone();
        if let Err(err) = self.env.add_template_owned(key.clone(), source) {
            return Err(PromptError::InvalidTemplate(err.to_string()));
        }
        self.templates.insert(key, template);
        Ok(())
    }

    /// Loads every template found by `load_templates_from_dir(dir)` and adds
    /// each one in order.
    ///
    /// # Errors
    ///
    /// Propagates loader errors and the first [`PromptManager::add_template`]
    /// failure. Templates added before the failing one remain registered.
    pub fn load_from_dir(&mut self, dir: &Path) -> Result<(), PromptError> {
        load_templates_from_dir(dir)?
            .into_iter()
            .try_for_each(|loaded| self.add_template(loaded))
    }

    /// Renders the template registered as `name` with the serializable `ctx`.
    ///
    /// # Errors
    ///
    /// - [`PromptError::TemplateNotFound`] if the environment cannot provide `name`.
    /// - [`PromptError::RenderError`] (with the engine's message) if rendering fails.
    pub fn render<T: Serialize>(&self, name: &str, ctx: T) -> Result<String, PromptError> {
        let compiled = match self.env.get_template(name) {
            Ok(found) => found,
            Err(_) => return Err(PromptError::TemplateNotFound(name.to_string())),
        };
        compiled
            .render(ctx)
            .map_err(|err| PromptError::RenderError(err.to_string()))
    }

    /// Returns the stored template record for `name`, if one was added.
    pub fn get_template(&self, name: &str) -> Option<&PromptTemplate> {
        let registry = &self.templates;
        registry.get(name)
    }

    /// Returns how many distinct template names are registered.
    pub fn template_count(&self) -> usize {
        let registry = &self.templates;
        registry.len()
    }
}
impl Default for PromptManager {
fn default() -> Self {
Self::new()
}
}
/// Summarises the manager by showing its template count and names.
///
/// The `minijinja` environment is intentionally left out, and
/// `finish_non_exhaustive` signals the omission with a trailing `..`.
impl std::fmt::Debug for PromptManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // HashMap iteration order is unspecified, so sort for stable,
        // reproducible debug output (logs, snapshot tests).
        let mut names: Vec<&str> = self.templates.keys().map(String::as_str).collect();
        names.sort_unstable();
        f.debug_struct("PromptManager")
            .field("template_count", &self.templates.len())
            .field("template_names", &names)
            .finish_non_exhaustive()
    }
}