mod template;
pub use template::{PromptTemplate, TemplateError, TemplateResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A named collection of [`PromptTemplate`]s with optional category grouping.
///
/// Templates are stored under unique names. Names can additionally be filed
/// under categories with [`PromptLibrary::add_with_category`]; a category is
/// only reported while at least one template name is filed under it.
#[derive(Debug, Default)]
pub struct PromptLibrary {
    /// Templates keyed by name.
    templates: HashMap<String, PromptTemplate>,
    /// Category -> template names, in insertion order and without duplicates.
    categories: HashMap<String, Vec<String>>,
}

impl PromptLibrary {
    /// Creates an empty library.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `template` under `name`, replacing any existing template with that
    /// name. Category memberships already recorded for `name` are left unchanged.
    pub fn add(&mut self, name: impl Into<String>, template: PromptTemplate) {
        self.templates.insert(name.into(), template);
    }

    /// Stores `template` under `name` and files `name` under `category`.
    ///
    /// Re-adding a name to a category it already belongs to replaces the
    /// template but does not list the name twice in
    /// [`PromptLibrary::list_by_category`].
    pub fn add_with_category(
        &mut self,
        name: impl Into<String>,
        template: PromptTemplate,
        category: impl Into<String>,
    ) {
        let name = name.into();
        let members = self.categories.entry(category.into()).or_default();
        // Guard against duplicate entries when the same name is re-registered.
        if !members.contains(&name) {
            members.push(name.clone());
        }
        self.templates.insert(name, template);
    }

    /// Returns the template stored under `name`, if any.
    pub fn get(&self, name: &str) -> Option<&PromptTemplate> {
        self.templates.get(name)
    }

    /// Removes and returns the template stored under `name`.
    ///
    /// The name is also removed from every category, and categories left with
    /// no members are dropped so they no longer appear in
    /// [`PromptLibrary::categories`].
    pub fn remove(&mut self, name: &str) -> Option<PromptTemplate> {
        self.categories.retain(|_, members| {
            members.retain(|member| member != name);
            !members.is_empty()
        });
        self.templates.remove(name)
    }

    /// Returns the names of all stored templates, sorted lexicographically so
    /// the result does not depend on hash-map iteration order.
    pub fn list(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.templates.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }

    /// Returns the template names filed under `category` in insertion order,
    /// or an empty list if the category is unknown.
    pub fn list_by_category(&self, category: &str) -> Vec<&str> {
        self.categories
            .get(category)
            .map(|names| names.iter().map(String::as_str).collect())
            .unwrap_or_default()
    }

    /// Returns the names of all categories that currently have members,
    /// sorted lexicographically.
    pub fn categories(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.categories.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }

    /// Returns `true` if a template is stored under `name`.
    pub fn has(&self, name: &str) -> bool {
        self.templates.contains_key(name)
    }

    /// Returns the number of stored templates.
    pub fn len(&self) -> usize {
        self.templates.len()
    }

    /// Returns `true` if the library holds no templates.
    pub fn is_empty(&self) -> bool {
        self.templates.is_empty()
    }

    /// Renders the template stored under `name` with `variables`.
    ///
    /// # Errors
    ///
    /// Returns [`TemplateError::NotFound`] if no template is stored under
    /// `name`; otherwise propagates any error from [`PromptTemplate::render`].
    pub fn render(
        &self,
        name: &str,
        variables: &HashMap<String, String>,
    ) -> TemplateResult<String> {
        let template = self
            .get(name)
            .ok_or_else(|| TemplateError::NotFound(name.to_string()))?;
        template.render(variables)
    }
}
/// Serializable description of a prompt, e.g. as loaded from a config file.
///
/// Only `template` and `defaults` take part in [`PromptConfig::into_template`];
/// the other fields are descriptive metadata that the conversion discards.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptConfig {
    /// Identifier of the prompt.
    pub name: String,
    /// Human-readable description; empty when omitted from the input.
    #[serde(default)]
    pub description: String,
    /// Raw template text passed to [`PromptTemplate::new`].
    pub template: String,
    /// Default variable values; empty when omitted from the input.
    #[serde(default)]
    pub defaults: HashMap<String, String>,
    /// Optional category label; `None` when omitted from the input.
    #[serde(default)]
    pub category: Option<String>,
    /// Free-form tags; empty when omitted from the input.
    #[serde(default)]
    pub tags: Vec<String>,
}

impl PromptConfig {
    /// Creates a config with the given name and template text and no metadata.
    pub fn new(name: impl Into<String>, template: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: String::default(),
            template: template.into(),
            defaults: HashMap::default(),
            category: None,
            tags: Vec::default(),
        }
    }

    /// Replaces the description.
    pub fn with_description(self, description: impl Into<String>) -> Self {
        Self {
            description: description.into(),
            ..self
        }
    }

    /// Sets (or overwrites) the default value for variable `key`.
    pub fn with_default(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut config = self;
        let (key, value) = (key.into(), value.into());
        config.defaults.insert(key, value);
        config
    }

    /// Sets the category label.
    pub fn with_category(self, category: impl Into<String>) -> Self {
        Self {
            category: Some(category.into()),
            ..self
        }
    }

    /// Appends a tag (duplicates are kept).
    pub fn with_tag(self, tag: impl Into<String>) -> Self {
        let mut config = self;
        let tag = tag.into();
        config.tags.push(tag);
        config
    }

    /// Converts into a [`PromptTemplate`] built from `template`, with every
    /// entry of `defaults` registered via `with_default`.
    pub fn into_template(self) -> PromptTemplate {
        let Self {
            template, defaults, ..
        } = self;
        defaults
            .into_iter()
            .fold(PromptTemplate::new(&template), |acc, (key, value)| {
                acc.with_default(key, value)
            })
    }
}
/// Incrementally assembles a multi-section prompt with variable substitution.
///
/// Sections are rendered in insertion order and separated by a blank line. A
/// section with a non-empty name is preceded by a `## <name>` header. Variables
/// are substituted into section content (not headers) wherever `{{key}}` or
/// `{key}` appears.
#[derive(Debug, Default)]
pub struct PromptBuilder {
    /// `(header, content)` pairs in the order they were added.
    sections: Vec<(String, String)>,
    /// Values substituted for `{{key}}` / `{key}` placeholders.
    variables: HashMap<String, String>,
}

impl PromptBuilder {
    /// Creates a builder with no sections and no variables.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a section. An empty `name` renders the content without a header.
    pub fn section(mut self, name: impl Into<String>, content: impl Into<String>) -> Self {
        self.sections.push((name.into(), content.into()));
        self
    }

    /// Sets a variable, overwriting any previous value for `key`.
    pub fn var(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.variables.insert(key.into(), value.into());
        self
    }

    /// Sets several variables at once; later pairs overwrite earlier ones.
    pub fn vars(
        mut self,
        vars: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
    ) -> Self {
        for (k, v) in vars {
            self.variables.insert(k.into(), v.into());
        }
        self
    }

    /// Renders all sections into a single prompt string.
    ///
    /// Placeholders that name an unknown variable are left untouched.
    /// Substitution is single-pass: text inserted from a variable's value is
    /// never re-scanned. Values containing `{...}` are therefore emitted
    /// verbatim, and the output does not depend on hash-map iteration order.
    pub fn build(self) -> String {
        let mut result = String::new();
        for (name, content) in &self.sections {
            if !result.is_empty() {
                result.push_str("\n\n");
            }
            if !name.is_empty() {
                result.push_str("## ");
                result.push_str(name);
                result.push_str("\n\n");
            }
            Self::substitute_into(&mut result, content, &self.variables);
        }
        result
    }

    /// Appends `content` to `out`, replacing `{{key}}` and `{key}` placeholders
    /// whose key is present in `variables`; everything else is copied literally.
    fn substitute_into(out: &mut String, content: &str, variables: &HashMap<String, String>) {
        let mut rest = content;
        while let Some(open) = rest.find('{') {
            out.push_str(&rest[..open]);
            let tail = &rest[open..];
            match Self::match_placeholder(tail, variables) {
                Some((value, consumed)) => {
                    out.push_str(value);
                    rest = &tail[consumed..];
                }
                None => {
                    // Not a known placeholder: keep this brace and resume just
                    // after it, so e.g. `{{{key}}}` still yields `{value}`.
                    out.push('{');
                    rest = &tail[1..];
                }
            }
        }
        out.push_str(rest);
    }

    /// Matches a placeholder at the start of `tail`, which must begin with `{`.
    ///
    /// Tries the `{{key}}` form first, then `{key}`. Returns the variable's
    /// value and the placeholder's length in bytes, or `None` if no known key
    /// matches.
    fn match_placeholder<'v>(
        tail: &str,
        variables: &'v HashMap<String, String>,
    ) -> Option<(&'v str, usize)> {
        if let Some(inner) = tail.strip_prefix("{{") {
            if let Some(end) = inner.find("}}") {
                if let Some(value) = variables.get(&inner[..end]) {
                    // "{{" + key + "}}"
                    return Some((value.as_str(), end + 4));
                }
            }
        }
        let inner = &tail[1..];
        let end = inner.find('}')?;
        variables
            .get(&inner[..end])
            // "{" + key + "}"
            .map(|value| (value.as_str(), end + 2))
    }
}
/// Factory for a set of ready-made, general-purpose prompt templates.
pub struct CommonPrompts;

impl CommonPrompts {
    /// Summarization prompt; variables `text` and `style` (default `"concise"`).
    pub fn summarize() -> PromptTemplate {
        const TEXT: &str = "Summarize the following text in {{style}} style:\n\n{{text}}\n\nSummary:";
        PromptTemplate::new(TEXT).with_default("style", "concise")
    }

    /// Translation prompt; variables `text` and `target_language`.
    pub fn translate() -> PromptTemplate {
        const TEXT: &str =
            "Translate the following text to {{target_language}}:\n\n{{text}}\n\nTranslation:";
        PromptTemplate::new(TEXT)
    }

    /// Context-grounded question answering; variables `context` and `question`.
    pub fn qa() -> PromptTemplate {
        const TEXT: &str = "Context:\n{{context}}\n\nQuestion: {{question}}\n\nAnswer based only on the context provided:";
        PromptTemplate::new(TEXT)
    }

    /// Code review prompt; variables `code` and `language` (default empty).
    pub fn code_review() -> PromptTemplate {
        const TEXT: &str = "Review the following {{language}} code for issues, improvements, and best practices:\n\n```{{language}}\n{{code}}\n```\n\nProvide a detailed review:";
        PromptTemplate::new(TEXT).with_default("language", "")
    }

    /// Single-label classification; variables `categories` and `text`.
    pub fn classify() -> PromptTemplate {
        const TEXT: &str = "Classify the following text into one of these categories: {{categories}}\n\nText: {{text}}\n\nCategory:";
        PromptTemplate::new(TEXT)
    }

    /// Entity extraction as a JSON array; variables `text` and `entity_type`
    /// (default `"named"`).
    pub fn extract_entities() -> PromptTemplate {
        const TEXT: &str = "Extract all {{entity_type}} entities from the following text:\n\n{{text}}\n\nEntities (as JSON array):";
        PromptTemplate::new(TEXT).with_default("entity_type", "named")
    }

    /// Tone rewrite; variables `text` and `tone` (default `"professional"`).
    pub fn rewrite() -> PromptTemplate {
        const TEXT: &str = "Rewrite the following text in {{tone}} tone:\n\n{{text}}\n\nRewritten:";
        PromptTemplate::new(TEXT).with_default("tone", "professional")
    }

    /// Step-by-step reasoning cue appended to `question`.
    pub fn chain_of_thought() -> PromptTemplate {
        const TEXT: &str = "{{question}}\n\nLet's think step by step:";
        PromptTemplate::new(TEXT)
    }

    /// Builds a [`PromptLibrary`] holding every prompt above, each filed under
    /// a category.
    pub fn library() -> PromptLibrary {
        // (name, constructor, category), registered in this order.
        let registry: &[(&str, fn() -> PromptTemplate, &str)] = &[
            ("summarize", Self::summarize, "text"),
            ("translate", Self::translate, "text"),
            ("qa", Self::qa, "qa"),
            ("code_review", Self::code_review, "code"),
            ("classify", Self::classify, "classification"),
            ("extract_entities", Self::extract_entities, "extraction"),
            ("rewrite", Self::rewrite, "text"),
            ("chain_of_thought", Self::chain_of_thought, "reasoning"),
        ];
        let mut library = PromptLibrary::new();
        for &(name, make, category) in registry {
            library.add_with_category(name, make(), category);
        }
        library
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned variable map from borrowed key/value pairs.
    fn vars_of(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_prompt_library() {
        let mut library = PromptLibrary::new();
        library.add("test", PromptTemplate::new("Hello {{name}}"));

        assert!(library.has("test"));
        assert!(!library.has("nonexistent"));
        assert_eq!(library.len(), 1);
    }

    #[test]
    fn test_prompt_library_categories() {
        let mut library = PromptLibrary::new();
        let entries: &[(&str, &str, &str)] = &[
            ("t1", "Template 1", "cat1"),
            ("t2", "Template 2", "cat1"),
            ("t3", "Template 3", "cat2"),
        ];
        for &(name, text, category) in entries {
            library.add_with_category(name, PromptTemplate::new(text), category);
        }

        assert_eq!(library.list_by_category("cat1").len(), 2);
        assert_eq!(library.list_by_category("cat2").len(), 1);
        assert_eq!(library.categories().len(), 2);
    }

    #[test]
    fn test_prompt_library_render() {
        let mut library = PromptLibrary::new();
        library.add("greet", PromptTemplate::new("Hello, {{name}}!"));

        let rendered = library
            .render("greet", &vars_of(&[("name", "World")]))
            .unwrap();
        assert_eq!(rendered, "Hello, World!");
    }

    #[test]
    fn test_prompt_config() {
        let config = PromptConfig::new("test", "Hello {{name}}")
            .with_description("A test template")
            .with_default("name", "World")
            .with_category("greetings")
            .with_tag("simple");

        assert_eq!(config.name, "test");
        assert_eq!(config.description, "A test template");
        assert_eq!(config.defaults.get("name").map(String::as_str), Some("World"));
        assert_eq!(config.category.as_deref(), Some("greetings"));
        assert_eq!(config.tags, vec!["simple"]);
    }

    #[test]
    fn test_prompt_config_to_template() {
        let template = PromptConfig::new("test", "Hello {{name}}")
            .with_default("name", "World")
            .into_template();

        let rendered = template.render(&HashMap::new()).unwrap();
        assert_eq!(rendered, "Hello World");
    }

    #[test]
    fn test_prompt_builder() {
        let prompt = PromptBuilder::new()
            .section("Context", "You are a helpful assistant.")
            .section("Task", "Answer the question: {{question}}")
            .var("question", "What is 2+2?")
            .build();

        for expected in &[
            "## Context",
            "You are a helpful assistant.",
            "## Task",
            "What is 2+2?",
        ] {
            assert!(prompt.contains(expected), "missing {:?} in {:?}", expected, prompt);
        }
    }

    #[test]
    fn test_prompt_builder_no_header() {
        let prompt = PromptBuilder::new()
            .section("", "Just some content")
            .build();

        assert!(!prompt.contains("##"));
        assert!(prompt.contains("Just some content"));
    }

    #[test]
    fn test_common_prompts() {
        let mut vars = vars_of(&[("text", "Long text here")]);

        let summary = CommonPrompts::summarize().render(&vars).unwrap();
        assert!(summary.contains("Long text here"));
        assert!(summary.contains("concise"));

        vars.insert("target_language".to_string(), "Spanish".to_string());
        let translation = CommonPrompts::translate().render(&vars).unwrap();
        assert!(translation.contains("Spanish"));
    }

    #[test]
    fn test_common_prompts_library() {
        let library = CommonPrompts::library();
        for &name in &["summarize", "translate", "qa", "code_review"] {
            assert!(library.has(name), "library is missing {:?}", name);
        }
        assert!(library.len() >= 8);
    }

    #[test]
    fn test_library_remove() {
        let mut library = PromptLibrary::new();
        library.add_with_category("test", PromptTemplate::new("Hello"), "cat1");
        assert!(library.has("test"));
        assert_eq!(library.list_by_category("cat1").len(), 1);

        library.remove("test");
        assert!(!library.has("test"));
        assert_eq!(library.list_by_category("cat1").len(), 0);
    }
}