use anyhow::{Context as _, Result};
use lazy_static::lazy_static;
use tera::Tera;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::node::Node;
lazy_static! {
    /// Process-wide Tera repository holding every compiled prompt template.
    ///
    /// Initialized on first access from the `*.prompt.md` files found under
    /// `src/transformers/prompts/` of this crate's manifest directory (the directory is
    /// resolved at compile time via `env!("CARGO_MANIFEST_DIR")`). More templates can be
    /// added at runtime through [`PromptTemplate::extend`] and
    /// [`PromptTemplate::try_compiled_from_str`].
    ///
    /// # Panics
    ///
    /// Panics on first access if the bundled templates fail to parse. Panicking rather than
    /// exiting the process guarantees the error is reported even when no tracing subscriber is
    /// installed, and lets callers such as test harnesses observe the failure instead of the
    /// whole process disappearing.
    static ref TEMPLATE_REPOSITORY: RwLock<Tera> = {
        let prefix = env!("CARGO_MANIFEST_DIR");
        let path = format!("{prefix}/src/transformers/prompts/**/*.prompt.md");
        match Tera::new(&path) {
            Ok(tera) => RwLock::new(tera),
            Err(e) => {
                tracing::error!("Parsing error(s): {e}");
                // `{e:?}` includes the error's source chain, which Tera's Display omits.
                panic!("Failed to parse prompt templates matching '{path}': {e:?}");
            }
        }
    };
}
/// A template paired with the optional Tera context used to render it.
///
/// Build one from a [`PromptTemplate`] via `to_prompt`, or directly from a `&'static str` or
/// `String`, add values with the `with_*` builder methods, then call `render`.
#[derive(Clone, Debug)]
pub struct Prompt {
    /// The template to render.
    template: PromptTemplate,
    /// Values exposed to the template. `None` means no context has been set; for
    /// `String`/`Static` templates that makes rendering return the raw text unprocessed.
    context: Option<tera::Context>,
}
/// Where the text of a prompt template comes from.
#[derive(Clone, Debug)]
pub enum PromptTemplate {
    /// Name (key) of a template registered in the global template repository.
    CompiledTemplate(String),
    /// An owned template string; rendered as a one-off Tera template when a context is
    /// supplied, otherwise returned verbatim.
    String(String),
    /// A `'static` template string; rendered as a one-off Tera template when a context is
    /// supplied, otherwise returned verbatim.
    Static(&'static str),
}
impl PromptTemplate {
    /// Creates a template that refers, by name, to a template in the global repository.
    ///
    /// The name is not validated here; an unknown name surfaces as an error from
    /// [`PromptTemplate::render`].
    pub fn from_compiled_template_name(name: impl Into<String>) -> PromptTemplate {
        PromptTemplate::CompiledTemplate(name.into())
    }

    /// Creates a template from an owned string, rendered on demand as a one-off Tera template.
    pub fn from_string(template: impl Into<String>) -> PromptTemplate {
        PromptTemplate::String(template.into())
    }

    /// Merges the templates of a custom [`Tera`] instance into the global template repository.
    ///
    /// # Errors
    ///
    /// Returns an error if Tera fails to extend the repository.
    pub async fn extend(tera: &Tera) -> Result<()> {
        TEMPLATE_REPOSITORY
            .write()
            .await
            .extend(tera)
            .context("Could not extend prompt repository with custom Tera instance")
    }

    /// Parses `template` and registers it in the global repository under a freshly generated
    /// UUID, returning a [`PromptTemplate::CompiledTemplate`] that refers to it.
    ///
    /// Every call adds a new entry to the repository and nothing in this module removes it,
    /// so compile a template once and reuse the returned value.
    ///
    /// # Errors
    ///
    /// Returns an error if Tera cannot parse the template.
    pub async fn try_compiled_from_str(
        template: impl AsRef<str> + Send + 'static,
    ) -> Result<PromptTemplate> {
        let id = Uuid::new_v4().to_string();
        let mut lock = TEMPLATE_REPOSITORY.write().await;
        lock.add_raw_template(&id, template.as_ref())
            .context("Failed to add raw template")?;
        Ok(PromptTemplate::CompiledTemplate(id))
    }

    /// Renders the template to a string.
    ///
    /// - `CompiledTemplate`: rendered from the global repository. A missing `context` is
    ///   treated as an empty one.
    /// - `String` / `Static`: rendered as a one-off Tera template when `context` is `Some`.
    ///   When it is `None`, the raw text is returned unchanged without any Tera processing.
    ///
    /// # Errors
    ///
    /// Returns an error if a compiled template does not exist, or if Tera fails to render.
    pub async fn render(&self, context: &Option<tera::Context>) -> Result<String> {
        match self {
            PromptTemplate::CompiledTemplate(id) => {
                Self::render_compiled(id, context.as_ref()).await
            }
            PromptTemplate::String(template) => Self::render_one_off(template, context.as_ref()),
            PromptTemplate::Static(template) => Self::render_one_off(template, context.as_ref()),
        }
    }

    /// Wraps a clone of this template in a [`Prompt`] whose context is present but empty.
    ///
    /// Because the context is `Some`, `String`/`Static` templates are always passed through
    /// Tera when the returned prompt is rendered.
    pub fn to_prompt(&self) -> Prompt {
        Prompt {
            template: self.clone(),
            context: Some(tera::Context::default()),
        }
    }

    /// Renders the repository template named `id`, using an empty context when none is given.
    async fn render_compiled(id: &str, context: Option<&tera::Context>) -> Result<String> {
        // Own the fallback explicitly instead of relying on temporary lifetime extension.
        let empty = tera::Context::new();
        let context = context.unwrap_or(&empty);

        let lock = TEMPLATE_REPOSITORY.read().await;
        // Field expressions in tracing macros are evaluated only when the event is enabled,
        // so the name list is no longer built on every render.
        tracing::debug!(
            id,
            available = Self::template_names(&lock),
            "Rendering template ..."
        );

        let rendered = lock.render(id, context);
        if let Err(err) = &rendered {
            tracing::error!(
                error = err.to_string(),
                available = Self::template_names(&lock),
                "Error rendering template {id}"
            );
        }
        rendered.with_context(|| format!("Failed to render template '{id}'"))
    }

    /// Renders `template` as a one-off Tera template when a context is supplied; otherwise
    /// returns it verbatim.
    fn render_one_off(template: &str, context: Option<&tera::Context>) -> Result<String> {
        match context {
            Some(context) => Tera::one_off(template, context, false)
                .context("Failed to render one-off template"),
            None => Ok(template.to_owned()),
        }
    }

    /// Comma-separated names of every template registered in `tera`, for diagnostics.
    fn template_names(tera: &Tera) -> String {
        tera.get_template_names().collect::<Vec<_>>().join(", ")
    }
}
impl From<&'static str> for PromptTemplate {
fn from(template: &'static str) -> Self {
PromptTemplate::Static(template)
}
}
impl From<String> for PromptTemplate {
fn from(template: String) -> Self {
PromptTemplate::String(template)
}
}
impl Prompt {
    /// Returns the prompt's context, first creating an empty one if none has been set.
    fn context_mut(&mut self) -> &mut tera::Context {
        self.context.get_or_insert_with(tera::Context::default)
    }

    /// Exposes `node` to the template under the `node` key.
    #[must_use]
    pub fn with_node(mut self, node: &Node) -> Self {
        self.context_mut().insert("node", node);
        self
    }

    /// Merges every entry of `new_context` into the prompt's context.
    #[must_use]
    pub fn with_context(mut self, new_context: impl Into<tera::Context>) -> Self {
        let incoming: tera::Context = new_context.into();
        self.context_mut().extend(incoming);
        self
    }

    /// Exposes a single `value` to the template under `key`.
    #[must_use]
    pub fn with_context_value(mut self, key: &str, value: impl Into<tera::Value>) -> Self {
        let converted: tera::Value = value.into();
        self.context_mut().insert(key, &converted);
        self
    }

    /// Renders the template with the accumulated context.
    ///
    /// # Errors
    ///
    /// Propagates any error from rendering the underlying [`PromptTemplate`].
    pub async fn render(&self) -> Result<String> {
        let Self { template, context } = self;
        template.render(context).await
    }
}
impl From<&'static str> for Prompt {
    /// Builds a context-free prompt from a string literal. Rendering it without adding any
    /// context returns the literal verbatim.
    fn from(prompt: &'static str) -> Self {
        let template = PromptTemplate::Static(prompt);
        Self {
            template,
            context: None,
        }
    }
}
impl From<String> for Prompt {
    /// Builds a context-free prompt from an owned string. Rendering it without adding any
    /// context returns the string verbatim.
    fn from(prompt: String) -> Self {
        let template = PromptTemplate::String(prompt);
        Self {
            template,
            context: None,
        }
    }
}
#[cfg(test)]
mod test {
    //! Tests for prompt rendering across compiled, owned and static templates.
    use super::*;

    #[tokio::test]
    async fn test_prompt() {
        let template = PromptTemplate::try_compiled_from_str("hello {{world}}")
            .await
            .unwrap();
        let prompt = template.to_prompt().with_context_value("world", "swiftide");
        assert_eq!(prompt.render().await.unwrap(), "hello swiftide");
    }

    #[tokio::test]
    async fn test_prompt_with_node() {
        let template = PromptTemplate::try_compiled_from_str("hello {{node.chunk}}")
            .await
            .unwrap();
        let node = Node::new("test");
        let prompt = template.to_prompt().with_node(&node);
        assert_eq!(prompt.render().await.unwrap(), "hello test");
    }

    #[tokio::test]
    async fn test_one_off_from_string() {
        let mut prompt: Prompt = "hello {{world}}".into();
        prompt = prompt.with_context_value("world", "swiftide");
        assert_eq!(prompt.render().await.unwrap(), "hello swiftide");
    }

    #[tokio::test]
    async fn test_string_prompt_without_context_is_verbatim() {
        // Without any context the raw text is returned and never passed through Tera.
        let prompt: Prompt = "hello {{world}}".into();
        assert_eq!(prompt.render().await.unwrap(), "hello {{world}}");
    }

    #[tokio::test]
    async fn test_rendering_unknown_compiled_template_fails() {
        let prompt =
            PromptTemplate::from_compiled_template_name("swiftide-missing-template").to_prompt();
        assert!(prompt.render().await.is_err());
    }

    #[tokio::test]
    async fn test_extending_with_custom_repository() {
        // An empty instance is all that's needed; `Tera::new` with a glob would walk the
        // filesystem from the working directory.
        let mut custom_tera = Tera::default();
        custom_tera
            .add_raw_template("hello", "hello {{world}}")
            .unwrap();
        PromptTemplate::extend(&custom_tera).await.unwrap();
        let prompt = PromptTemplate::from_compiled_template_name("hello")
            .to_prompt()
            .with_context_value("world", "swiftide");
        assert_eq!(prompt.render().await.unwrap(), "hello swiftide");
    }

    #[tokio::test]
    async fn test_coercion_to_prompt() {
        let raw: &str = "hello {{world}}";
        let prompt: Prompt = raw.into();
        assert_eq!(
            prompt
                .with_context_value("world", "swiftide")
                .render()
                .await
                .unwrap(),
            "hello swiftide"
        );
        let prompt: Prompt = raw.to_string().into();
        assert_eq!(
            prompt
                .with_context_value("world", "swiftide")
                .render()
                .await
                .unwrap(),
            "hello swiftide"
        );
    }

    #[tokio::test]
    async fn test_coercion_to_template() {
        let raw: &str = "hello {{world}}";
        let prompt: PromptTemplate = raw.into();
        assert_eq!(
            prompt
                .to_prompt()
                .with_context_value("world", "swiftide")
                .render()
                .await
                .unwrap(),
            "hello swiftide"
        );
        let prompt: PromptTemplate = raw.to_string().into();
        assert_eq!(
            prompt
                .to_prompt()
                .with_context_value("world", "swiftide")
                .render()
                .await
                .unwrap(),
            "hello swiftide"
        );
    }
}