use std::collections::HashMap;
use async_trait::async_trait;
use tera::{Context, Tera};
#[cfg(feature = "tracing")]
use tracing::instrument;
use crate::error::AnchorChainError;
use crate::node::Node;
/// A [`Node`] that turns a map of template variables into prompt text.
///
/// It renders a single Tera template, which [`Prompt::new`] registers under
/// the internal name `"prompt"`.
#[derive(Debug)]
pub struct Prompt {
    /// Tera engine that owns the one template this prompt renders.
    tera: Tera,
}
impl Prompt {
pub fn new(template: &str) -> Self {
let mut tera = Tera::default();
tera.add_raw_template("prompt", template)
.expect("Error creating template");
Prompt { tera }
}
}
#[async_trait]
impl Node for Prompt {
    /// Template variables, keyed by the variable name used in the template.
    type Input = HashMap<String, String>;
    /// The rendered prompt text.
    type Output = String;

    /// Renders the `"prompt"` template, using `input` as the Tera context.
    ///
    /// # Errors
    ///
    /// Returns an error in two cases:
    /// - `input` cannot be serialized into a Tera [`Context`];
    /// - Tera fails to render the template, for example because it references
    ///   a variable that is missing from `input`.
    #[cfg_attr(feature = "tracing", instrument)]
    async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
        let context = Context::from_serialize(input)?;
        // `Tera::render` already returns an owned `String`, so no further
        // conversion (and no extra copy of the prompt) is needed.
        Ok(self.tera.render("prompt", &context)?)
    }
}