use std::collections::HashMap;
use async_trait::async_trait;
use synaptic_core::{RunnableConfig, SynapticError};
use synaptic_runnables::Runnable;
use crate::{FewShotExample, PromptTemplate};
/// A prompt template that assembles an optional prefix, a list of rendered
/// few-shot examples, and a rendered suffix into a single prompt string.
pub struct FewShotPromptTemplate {
/// Examples rendered, in order, through `example_prompt`.
examples: Vec<FewShotExample>,
/// Template applied to each example; `render` supplies it with the
/// variables `input` and `output` taken from the example.
example_prompt: PromptTemplate,
/// Optional text placed first in the output. It is copied verbatim; it is
/// not rendered as a template.
prefix: Option<String>,
/// Template rendered with the caller-supplied values and placed last.
suffix: PromptTemplate,
/// String placed between rendered examples and also between the
/// prefix / examples / suffix sections. `new` sets it to `"\n\n"`.
example_separator: String,
}
impl FewShotPromptTemplate {
    /// Creates a template from its examples, the per-example prompt, and the
    /// suffix prompt.
    ///
    /// The new template has no prefix and uses `"\n\n"` as its separator; use
    /// [`with_prefix`](Self::with_prefix) and
    /// [`with_separator`](Self::with_separator) to change either.
    pub fn new(
        examples: Vec<FewShotExample>,
        example_prompt: PromptTemplate,
        suffix: PromptTemplate,
    ) -> Self {
        let example_separator = String::from("\n\n");
        Self {
            prefix: None,
            example_separator,
            examples,
            example_prompt,
            suffix,
        }
    }

    /// Returns the template with `prefix` set as the text emitted before the
    /// examples, replacing any previously configured prefix.
    pub fn with_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            prefix: Some(prefix.into()),
            ..self
        }
    }

    /// Returns the template with `sep` as the separator used between examples
    /// and between the prefix / examples / suffix sections.
    pub fn with_separator(self, sep: impl Into<String>) -> Self {
        Self {
            example_separator: sep.into(),
            ..self
        }
    }

    /// Renders the complete prompt.
    ///
    /// The output consists of, in order: the prefix (if one is set, copied
    /// verbatim), the examples block (omitted when there are no examples), and
    /// the suffix rendered with `values`. Sections are joined with the
    /// configured separator, which is also used between individual examples.
    ///
    /// # Errors
    ///
    /// Returns `SynapticError::Prompt` carrying the stringified error of the
    /// first example or suffix render that fails.
    pub fn render(&self, values: &HashMap<String, String>) -> Result<String, SynapticError> {
        let separator = self.example_separator.as_str();

        // Render every example in order; the first failure aborts the whole render.
        let rendered_examples = self
            .examples
            .iter()
            .map(|example| {
                let mut vars = HashMap::new();
                vars.insert("input".to_string(), example.input.clone());
                vars.insert("output".to_string(), example.output.clone());
                self.example_prompt
                    .render(&vars)
                    .map_err(|e| SynapticError::Prompt(e.to_string()))
            })
            .collect::<Result<Vec<String>, SynapticError>>()?;

        // The suffix is rendered only after all examples succeeded.
        let suffix_text = self
            .suffix
            .render(values)
            .map_err(|e| SynapticError::Prompt(e.to_string()))?;

        // At most three sections: prefix, examples block, suffix.
        let mut sections: Vec<String> = Vec::with_capacity(3);
        sections.extend(self.prefix.iter().cloned());
        if !rendered_examples.is_empty() {
            sections.push(rendered_examples.join(separator));
        }
        sections.push(suffix_text);

        Ok(sections.join(separator))
    }
}
#[async_trait]
impl Runnable<HashMap<String, String>, String> for FewShotPromptTemplate {
async fn invoke(
&self,
input: HashMap<String, String>,
_config: &RunnableConfig,
) -> Result<String, SynapticError> {
self.render(&input)
}
}