use minijinja::Environment;
use serde::Serialize;
/// One piece of prompt content.
///
/// A prompt is an ordered list of parts. Text-only renderings of a prompt keep
/// the `Text` parts and skip the others.
///
/// `PartialEq`/`Eq` are derived so parts (and `Vec<PromptPart>`) can be compared
/// directly, e.g. with `assert_eq!`, instead of matching on each variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptPart {
    /// Plain text content.
    Text(String),
    /// Image bytes together with the media type describing them.
    Image {
        /// Media type label for `data` (presumably a MIME type such as
        /// `"image/png"`; it is not validated here).
        media_type: String,
        /// Raw image bytes, stored as given.
        data: Vec<u8>,
    },
}
/// Conversion of a value into prompt content.
///
/// Every method has a default, but `to_prompt_parts` and
/// `to_prompt_parts_with_mode` are defined in terms of each other. An
/// implementor must therefore override at least one of the two. Otherwise the
/// two defaults call each other without end (unbounded recursion).
pub trait ToPrompt {
    /// Returns the prompt parts for the given `mode`.
    ///
    /// The default ignores `mode` and forwards to `to_prompt_parts`.
    fn to_prompt_parts_with_mode(&self, _mode: &str) -> Vec<PromptPart> {
        self.to_prompt_parts()
    }

    /// Returns the text of the prompt for the given `mode`.
    ///
    /// Text parts are concatenated with no separator, in order. Non-text parts
    /// (such as images) are skipped.
    fn to_prompt_with_mode(&self, mode: &str) -> String {
        let mut rendered = String::new();
        for part in self.to_prompt_parts_with_mode(mode) {
            if let PromptPart::Text(fragment) = part {
                rendered.push_str(&fragment);
            }
        }
        rendered
    }

    /// Returns the prompt parts using the `"full"` mode.
    fn to_prompt_parts(&self) -> Vec<PromptPart> {
        const DEFAULT_MODE: &str = "full";
        self.to_prompt_parts_with_mode(DEFAULT_MODE)
    }

    /// Returns the prompt text using the `"full"` mode.
    fn to_prompt(&self) -> String {
        const DEFAULT_MODE: &str = "full";
        self.to_prompt_with_mode(DEFAULT_MODE)
    }

    /// Returns a schema description for the implementing type.
    ///
    /// The default returns an empty string.
    fn prompt_schema() -> String {
        String::new()
    }
}
/// An owned string becomes a single text part holding a copy of itself.
impl ToPrompt for String {
    fn to_prompt(&self) -> String {
        self.to_owned()
    }

    fn to_prompt_parts(&self) -> Vec<PromptPart> {
        let text = self.to_prompt();
        vec![PromptPart::Text(text)]
    }
}
/// A borrowed string slice becomes a single text part holding an owned copy.
impl ToPrompt for &str {
    fn to_prompt(&self) -> String {
        String::from(*self)
    }

    fn to_prompt_parts(&self) -> Vec<PromptPart> {
        let owned = String::from(*self);
        vec![PromptPart::Text(owned)]
    }
}
/// A boolean renders as the literal word `true` or `false`.
impl ToPrompt for bool {
    fn to_prompt(&self) -> String {
        let word = if *self { "true" } else { "false" };
        word.to_owned()
    }

    fn to_prompt_parts(&self) -> Vec<PromptPart> {
        vec![PromptPart::Text(self.to_prompt())]
    }
}
/// A single character becomes a one-character text part.
impl ToPrompt for char {
    fn to_prompt(&self) -> String {
        String::from(*self)
    }

    fn to_prompt_parts(&self) -> Vec<PromptPart> {
        let single = String::from(*self);
        vec![PromptPart::Text(single)]
    }
}
macro_rules! impl_to_prompt_for_numbers {
($($t:ty),*) => {
$(
impl ToPrompt for $t {
fn to_prompt_parts(&self) -> Vec<PromptPart> {
vec![PromptPart::Text(self.to_string())]
}
fn to_prompt(&self) -> String {
self.to_string()
}
}
)*
};
}
impl_to_prompt_for_numbers!(
i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64
);
/// A vector renders as a bracketed, comma-separated list of its items'
/// `to_prompt` text, e.g. `[a, b, c]`. The parts form is that whole string as
/// one text part.
impl<T: ToPrompt> ToPrompt for Vec<T> {
    fn to_prompt_parts(&self) -> Vec<PromptPart> {
        let rendered = self.to_prompt();
        vec![PromptPart::Text(rendered)]
    }

    fn to_prompt(&self) -> String {
        let mut out = String::from("[");
        for (index, item) in self.iter().enumerate() {
            // A separator goes before every item except the first.
            if index > 0 {
                out.push_str(", ");
            }
            out.push_str(&item.to_prompt());
        }
        out.push(']');
        out
    }
}
/// `Some(v)` renders exactly as `v`, and `None` renders as the empty string.
/// The parts form is always a single text part, even for `None`.
impl<T: ToPrompt> ToPrompt for Option<T> {
    fn to_prompt_parts(&self) -> Vec<PromptPart> {
        let rendered = self.to_prompt();
        vec![PromptPart::Text(rendered)]
    }

    fn to_prompt(&self) -> String {
        self.as_ref()
            .map(|inner| inner.to_prompt())
            .unwrap_or_default()
    }
}
/// Renders `template` as a minijinja template with `context` as its variables.
///
/// A fresh `Environment` is built on every call, and the template is registered
/// under the name `"prompt"`.
///
/// # Errors
///
/// Returns the `minijinja::Error` raised when registering the template fails
/// (e.g. a syntax error) or when rendering it fails.
pub fn render_prompt<T: Serialize>(template: &str, context: T) -> Result<String, minijinja::Error> {
    // NOTE(review): minijinja chooses auto-escaping from the template name. The
    // extension-less name below is relied on by the JSON-producing tests in
    // this file to keep output unescaped — confirm before renaming it.
    const TEMPLATE_NAME: &str = "prompt";

    let mut environment = Environment::new();
    environment.add_template(TEMPLATE_NAME, template)?;
    // Bound to a local so it is dropped before `environment`, which it borrows.
    let compiled = environment.get_template(TEMPLATE_NAME)?;
    compiled.render(context)
}
/// Renders a minijinja template, binding each `key = value` pair into its
/// context.
///
/// ```ignore
/// let text = prompt!("Hello {{ name }}", name = "Yui")?;
/// let fixed = prompt!("A static prompt")?; // no bindings; trailing comma optional
/// ```
///
/// Expands to a call to `$crate::prompt::render_prompt`, so it evaluates to
/// `Result<String, minijinja::Error>`. The expansion names `minijinja::context!`
/// directly, so the calling crate must have `minijinja` in scope.
#[macro_export]
macro_rules! prompt {
    ($template:expr, $($key:ident = $value:expr),* $(,)?) => {
        $crate::prompt::render_prompt($template, minijinja::context!($($key => $value),*))
    };
    // Template with no bindings and no trailing comma. The arm above needs a
    // `,` after the template, so `prompt!("text")` would not match it.
    ($template:expr) => {
        $crate::prompt::render_prompt($template, minijinja::context!())
    };
}
#[cfg(test)]
mod tests {
    // Unit tests for the `ToPrompt` impls, the `prompt!` macro, `render_prompt`,
    // and the default method of `ToPromptSet`.
    use super::*;
    use serde::Serialize;
    use std::fmt::Display;

    /// Asserts that `parts` holds exactly one `PromptPart::Text` equal to `expected`.
    fn assert_single_text(parts: &[PromptPart], expected: &str) {
        assert_eq!(parts.len(), 1, "expected exactly one prompt part");
        match &parts[0] {
            PromptPart::Text(text) => assert_eq!(text, expected),
            _ => panic!("Expected PromptPart::Text"),
        }
    }

    enum TestEnum {
        VariantA,
        VariantB,
    }

    impl Display for TestEnum {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                TestEnum::VariantA => write!(f, "Variant A"),
                TestEnum::VariantB => write!(f, "Variant B"),
            }
        }
    }

    impl ToPrompt for TestEnum {
        fn to_prompt_parts(&self) -> Vec<PromptPart> {
            vec![PromptPart::Text(self.to_string())]
        }
        fn to_prompt(&self) -> String {
            self.to_string()
        }
    }

    #[test]
    fn test_to_prompt_for_enum() {
        let variant = TestEnum::VariantA;
        assert_eq!(variant.to_prompt(), "Variant A");
    }

    #[test]
    fn test_to_prompt_for_enum_variant_b() {
        let variant = TestEnum::VariantB;
        assert_eq!(variant.to_prompt(), "Variant B");
    }

    #[test]
    fn test_to_prompt_for_string() {
        let s = "hello world";
        assert_eq!(s.to_prompt(), "hello world");
    }

    #[test]
    fn test_to_prompt_for_number() {
        let n = 42;
        assert_eq!(n.to_prompt(), "42");
    }

    #[test]
    fn test_to_prompt_for_bool_char_float_and_owned_string() {
        assert_eq!(true.to_prompt(), "true");
        assert_eq!(false.to_prompt(), "false");
        assert_eq!('x'.to_prompt(), "x");
        assert_eq!(1.5f64.to_prompt(), "1.5");
        assert_eq!(String::from("owned").to_prompt(), "owned");
        assert_single_text(&'y'.to_prompt_parts(), "y");
    }

    #[test]
    fn test_to_prompt_for_option_some() {
        let opt: Option<String> = Some("hello".to_string());
        assert_eq!(opt.to_prompt(), "hello");
    }

    #[test]
    fn test_to_prompt_for_option_none() {
        let opt: Option<String> = None;
        assert_eq!(opt.to_prompt(), "");
    }

    #[test]
    fn test_to_prompt_for_option_number() {
        let opt_some: Option<i32> = Some(42);
        assert_eq!(opt_some.to_prompt(), "42");
        let opt_none: Option<i32> = None;
        assert_eq!(opt_none.to_prompt(), "");
    }

    #[test]
    fn test_to_prompt_parts_for_option() {
        let opt: Option<String> = Some("test".to_string());
        assert_single_text(&opt.to_prompt_parts(), "test");
    }

    #[derive(Serialize)]
    struct SystemInfo {
        version: &'static str,
        os: &'static str,
    }

    #[test]
    fn test_prompt_macro_simple() {
        let user = "Yui";
        let task = "implementation";
        let prompt = prompt!(
            "User {{user}} is working on the {{task}}.",
            user = user,
            task = task
        )
        .unwrap();
        assert_eq!(prompt, "User Yui is working on the implementation.");
    }

    #[test]
    fn test_prompt_macro_with_struct() {
        let sys = SystemInfo {
            version: "0.1.0",
            os: "Rust",
        };
        let prompt = prompt!("System: {{sys.version}} on {{sys.os}}", sys = sys).unwrap();
        assert_eq!(prompt, "System: 0.1.0 on Rust");
    }

    #[test]
    fn test_prompt_macro_mixed() {
        let user = "Mai";
        let sys = SystemInfo {
            version: "0.1.0",
            os: "Rust",
        };
        let prompt = prompt!(
            "User {{user}} is using {{sys.os}} v{{sys.version}}.",
            user = user,
            sys = sys
        )
        .unwrap();
        assert_eq!(prompt, "User Mai is using Rust v0.1.0.");
    }

    #[test]
    fn test_to_prompt_for_vec_of_strings() {
        let items = vec!["apple", "banana", "cherry"];
        assert_eq!(items.to_prompt(), "[apple, banana, cherry]");
    }

    #[test]
    fn test_to_prompt_for_vec_of_numbers() {
        let numbers = vec![1, 2, 3, 42];
        assert_eq!(numbers.to_prompt(), "[1, 2, 3, 42]");
    }

    #[test]
    fn test_to_prompt_for_empty_vec() {
        let empty: Vec<String> = vec![];
        assert_eq!(empty.to_prompt(), "[]");
    }

    #[test]
    fn test_to_prompt_for_nested_vec() {
        let nested = vec![vec![1, 2], vec![3, 4]];
        assert_eq!(nested.to_prompt(), "[[1, 2], [3, 4]]");
    }

    #[test]
    fn test_to_prompt_parts_for_vec() {
        let items = vec!["a", "b", "c"];
        assert_single_text(&items.to_prompt_parts(), "[a, b, c]");
    }

    #[test]
    fn test_to_prompt_for_option_vec() {
        let opt_vec_some: Option<Vec<String>> = Some(vec!["a".to_string(), "b".to_string()]);
        assert_eq!(opt_vec_some.to_prompt(), "[a, b]");
        let opt_vec_none: Option<Vec<String>> = None;
        assert_eq!(opt_vec_none.to_prompt(), "");
    }

    #[test]
    fn test_to_prompt_for_vec_option() {
        let vec_opts = vec![Some("hello".to_string()), None, Some("world".to_string())];
        assert_eq!(vec_opts.to_prompt(), "[hello, , world]");
    }

    #[test]
    fn test_to_prompt_for_option_none_with_parts() {
        let opt: Option<String> = None;
        assert_single_text(&opt.to_prompt_parts(), "");
    }

    // A type that customises only the mode-aware method, so the remaining
    // `ToPrompt` methods run through their default forwarding.
    struct ModeAware;

    impl ToPrompt for ModeAware {
        fn to_prompt_parts_with_mode(&self, mode: &str) -> Vec<PromptPart> {
            vec![
                PromptPart::Text(format!("mode={mode};")),
                PromptPart::Image {
                    media_type: "image/png".to_string(),
                    data: vec![0, 1, 2],
                },
                PromptPart::Text("end".to_string()),
            ]
        }
    }

    #[test]
    fn test_default_methods_use_full_mode_and_skip_images() {
        let value = ModeAware;
        assert_eq!(value.to_prompt(), "mode=full;end");
        assert_eq!(value.to_prompt_with_mode("brief"), "mode=brief;end");
        assert_eq!(value.to_prompt_parts().len(), 3);
        assert_eq!(ModeAware::prompt_schema(), "");
    }

    // A prompt set with a single target "a"; any other target is reported
    // as not found.
    struct SingleTarget;

    impl ToPromptSet for SingleTarget {
        fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<PromptPart>, PromptSetError> {
            match target {
                "a" => Ok(vec![
                    PromptPart::Text("first".to_string()),
                    PromptPart::Image {
                        media_type: "image/png".to_string(),
                        data: vec![],
                    },
                    PromptPart::Text("second".to_string()),
                ]),
                other => Err(PromptSetError::TargetNotFound {
                    target: other.to_string(),
                    available: vec!["a".to_string()],
                }),
            }
        }
    }

    #[test]
    fn test_to_prompt_set_joins_text_with_newlines() {
        let text = ToPromptSet::to_prompt_for(&SingleTarget, "a").unwrap();
        assert_eq!(text, "first\nsecond");
        let err = ToPromptSet::to_prompt_for(&SingleTarget, "missing").unwrap_err();
        assert_eq!(
            err.to_string(),
            r#"Target 'missing' not found. Available targets: ["a"]"#
        );
    }

    #[test]
    fn test_prompt_macro_no_args() {
        let prompt = prompt!("This is a static prompt.",).unwrap();
        assert_eq!(prompt, "This is a static prompt.");
    }

    #[test]
    fn test_render_prompt_reports_syntax_errors() {
        assert!(render_prompt("{{ unclosed", ()).is_err());
    }

    #[test]
    fn test_render_prompt_with_json_value_dot_notation() {
        use serde_json::json;
        let context = json!({
            "user": {
                "name": "Alice",
                "age": 30,
                "profile": {
                    "role": "Developer"
                }
            }
        });
        let template =
            "{{ user.name }} is {{ user.age }} years old and works as {{ user.profile.role }}";
        let result = render_prompt(template, &context).unwrap();
        assert_eq!(result, "Alice is 30 years old and works as Developer");
    }

    #[test]
    fn test_render_prompt_with_hashmap_json_value() {
        use serde_json::json;
        use std::collections::HashMap;
        let mut context = HashMap::new();
        context.insert(
            "step_1_output".to_string(),
            json!({
                "result": "success",
                "data": {
                    "count": 42
                }
            }),
        );
        context.insert("task".to_string(), json!("analysis"));
        let template = "Task: {{ task }}, Result: {{ step_1_output.result }}, Count: {{ step_1_output.data.count }}";
        let result = render_prompt(template, &context).unwrap();
        assert_eq!(result, "Task: analysis, Result: success, Count: 42");
    }

    #[test]
    fn test_render_prompt_with_array_in_json_template() {
        use serde_json::json;
        use std::collections::HashMap;
        let mut context = HashMap::new();
        context.insert(
            "user_request".to_string(),
            json!({
                "narrative_keywords": ["betrayal", "redemption", "sacrifice"]
            }),
        );
        let template = r#"{"keywords": {{ user_request.narrative_keywords }}}"#;
        let result = render_prompt(template, &context).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["keywords"][0], "betrayal");
        assert_eq!(parsed["keywords"][1], "redemption");
        assert_eq!(parsed["keywords"][2], "sacrifice");
    }

    #[test]
    fn test_render_prompt_with_object_in_json_template() {
        use serde_json::json;
        use std::collections::HashMap;
        let mut context = HashMap::new();
        context.insert(
            "user_request".to_string(),
            json!({
                "config": {
                    "theme": "dark_fantasy",
                    "complexity": 5
                }
            }),
        );
        let template = r#"{"settings": {{ user_request.config }}}"#;
        let result = render_prompt(template, &context).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["settings"]["theme"], "dark_fantasy");
        assert_eq!(parsed["settings"]["complexity"], 5);
    }

    #[test]
    fn test_render_prompt_mixed_json_template() {
        use serde_json::json;
        use std::collections::HashMap;
        let mut context = HashMap::new();
        context.insert(
            "world_concept".to_string(),
            json!({
                "concept": "A world where identity is volatile"
            }),
        );
        context.insert(
            "user_request".to_string(),
            json!({
                "narrative_keywords": ["betrayal", "redemption"],
                "theme": "dark fantasy"
            }),
        );
        let template = r#"{"concept": "{{ world_concept.concept }}", "keywords": {{ user_request.narrative_keywords }}, "theme": "{{ user_request.theme }}"}"#;
        let result = render_prompt(template, &context).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&result).unwrap();
        assert_eq!(parsed["concept"], "A world where identity is volatile");
        assert_eq!(parsed["keywords"][0], "betrayal");
        assert_eq!(parsed["theme"], "dark fantasy");
    }
}
/// Errors produced when building a prompt for a named target (see `ToPromptSet`).
#[derive(Debug, thiserror::Error)]
pub enum PromptSetError {
    /// The requested target name is not one of the known targets.
    #[error("Target '{target}' not found. Available targets: {available:?}")]
    TargetNotFound {
        /// The target name that was requested.
        target: String,
        /// Target names that could have been requested instead.
        available: Vec<String>,
    },
    /// Rendering the template for a known target failed.
    #[error("Failed to render prompt for target '{target}': {source}")]
    RenderFailed {
        /// The target whose prompt was being rendered.
        target: String,
        /// The underlying minijinja error.
        source: minijinja::Error,
    },
}
/// A value that can produce different prompts depending on a named target.
pub trait ToPromptSet {
    /// Returns the prompt parts for `target`, or an error if the target is
    /// unknown or cannot be rendered.
    fn to_prompt_parts_for(&self, target: &str) -> Result<Vec<PromptPart>, PromptSetError>;

    /// Returns the text of the prompt for `target`.
    ///
    /// Text parts are joined with `'\n'`. Non-text parts (such as images) are
    /// skipped. Errors from `to_prompt_parts_for` are passed through unchanged.
    fn to_prompt_for(&self, target: &str) -> Result<String, PromptSetError> {
        self.to_prompt_parts_for(target).map(|parts| {
            let mut lines: Vec<&str> = Vec::with_capacity(parts.len());
            for part in &parts {
                if let PromptPart::Text(text) = part {
                    lines.push(text.as_str());
                }
            }
            lines.join("\n")
        })
    }
}
/// A value that renders a prompt specialised for some `target` of type `T`.
pub trait ToPromptFor<T> {
    /// Renders the prompt for `target` in the given `mode`.
    fn to_prompt_for_with_mode(&self, target: &T, mode: &str) -> String;

    /// Renders the prompt for `target` using the `"full"` mode.
    fn to_prompt_for(&self, target: &T) -> String {
        const DEFAULT_MODE: &str = "full";
        self.to_prompt_for_with_mode(target, DEFAULT_MODE)
    }
}