use std::collections::HashMap;
use async_trait::async_trait;
use serde_json::Value;
use synaptic_core::{Message, RunnableConfig, SynapticError};
use synaptic_runnables::Runnable;
use crate::PromptTemplate;
/// One entry of a [`ChatPromptTemplate`].
///
/// [`ChatPromptTemplate::format`] processes entries in order. Each text
/// variant produces exactly one message. A `Placeholder` produces however many
/// messages its input value contains.
pub enum MessageTemplate {
/// Rendered against the string-valued inputs into one system message.
System(PromptTemplate),
/// Rendered against the string-valued inputs into one human message.
Human(PromptTemplate),
/// Rendered against the string-valued inputs into one AI message.
AI(PromptTemplate),
/// Name of an input key whose value is deserialized into a list of
/// messages and spliced into the output at this position.
Placeholder(String),
}
/// A chat prompt: an ordered list of [`MessageTemplate`]s that
/// [`ChatPromptTemplate::format`] turns into a `Vec<Message>`.
pub struct ChatPromptTemplate {
/// Templates in output order; the produced messages follow this order.
templates: Vec<MessageTemplate>,
}
impl ChatPromptTemplate {
pub fn new(templates: Vec<MessageTemplate>) -> Self {
Self { templates }
}
pub fn from_messages(templates: Vec<MessageTemplate>) -> Self {
Self::new(templates)
}
pub fn format(&self, values: &HashMap<String, Value>) -> Result<Vec<Message>, SynapticError> {
let mut messages = Vec::new();
let string_values: HashMap<String, String> = values
.iter()
.filter_map(|(k, v)| {
if let Value::String(s) = v {
Some((k.clone(), s.clone()))
} else {
None
}
})
.collect();
for template in &self.templates {
match template {
MessageTemplate::System(pt) => {
let content = pt
.render(&string_values)
.map_err(|e| SynapticError::Prompt(e.to_string()))?;
messages.push(Message::system(content));
}
MessageTemplate::Human(pt) => {
let content = pt
.render(&string_values)
.map_err(|e| SynapticError::Prompt(e.to_string()))?;
messages.push(Message::human(content));
}
MessageTemplate::AI(pt) => {
let content = pt
.render(&string_values)
.map_err(|e| SynapticError::Prompt(e.to_string()))?;
messages.push(Message::ai(content));
}
MessageTemplate::Placeholder(key) => {
let value = values.get(key).ok_or_else(|| {
SynapticError::Prompt(format!("missing placeholder: {key}"))
})?;
let msgs: Vec<Message> =
serde_json::from_value(value.clone()).map_err(|e| {
SynapticError::Prompt(format!(
"invalid messages for placeholder '{key}': {e}"
))
})?;
messages.extend(msgs);
}
}
}
Ok(messages)
}
}
#[async_trait]
impl Runnable<HashMap<String, Value>, Vec<Message>> for ChatPromptTemplate {
    /// Turns the input variables into chat messages using
    /// [`ChatPromptTemplate::format`].
    ///
    /// The runnable configuration is ignored; any formatting error is returned
    /// unchanged.
    async fn invoke(
        &self,
        input: HashMap<String, Value>,
        _config: &RunnableConfig,
    ) -> Result<Vec<Message>, SynapticError> {
        let rendered = Self::format(self, &input)?;
        Ok(rendered)
    }
}