use crate::datasets::Record;
use crate::{Error, Result};
use regex::Regex;
use serde_json::Value;
use std::collections::HashMap;
pub struct PromptBuilder {
template: String,
system_prompt: String,
}
impl PromptBuilder {
pub fn new(template: String, system_prompt: Option<String>, is_augment: bool) -> Self {
let system_prompt = system_prompt.unwrap_or_else(|| {
if is_augment {
default_system_prompt_augment().to_string()
} else {
default_system_prompt_generate().to_string()
}
});
Self {
template,
system_prompt,
}
}
pub fn system_prompt(&self) -> &str {
&self.system_prompt
}
pub fn build_for_category(&self, category: &str, index: usize) -> String {
let mut vars = HashMap::new();
vars.insert("category".to_string(), Value::String(category.to_string()));
vars.insert("index".to_string(), Value::Number(index.into()));
self.substitute(&vars)
}
pub fn build_for_record(&self, record: &Record) -> String {
let vars = self.extract_vars(&record.data);
self.substitute(&vars)
}
pub fn required_variables(&self) -> Vec<String> {
let re = Regex::new(r"\{(\w+)\}").unwrap();
re.captures_iter(&self.template)
.map(|cap| cap[1].to_string())
.collect()
}
pub fn validate_for_generate(&self, categories: &Option<Vec<String>>) -> Result<()> {
let required = self.required_variables();
for var in &required {
match var.as_str() {
"category" => {
if categories.is_none()
|| categories.as_ref().map(|c| c.is_empty()).unwrap_or(true)
{
return Err(Error::Config(
"Template uses {category} but no categories provided".to_string(),
));
}
}
"index" => {} other => {
return Err(Error::Config(format!(
"Template uses {{{}}} which is not available in generate mode. Available: {{category}}, {{index}}",
other
)));
}
}
}
Ok(())
}
pub fn validate_for_augment(&self, available_columns: &[String]) -> Result<()> {
let required = self.required_variables();
for var in &required {
if var != "index" && !available_columns.contains(var) {
return Err(Error::Config(format!(
"Template uses {{{}}} but source data only has columns: {:?}",
var, available_columns
)));
}
}
Ok(())
}
fn substitute(&self, vars: &HashMap<String, Value>) -> String {
let mut result = self.template.clone();
for (key, value) in vars {
let placeholder = format!("{{{}}}", key);
let replacement = match value {
Value::String(s) => s.clone(),
Value::Number(n) => n.to_string(),
Value::Bool(b) => b.to_string(),
Value::Null => "null".to_string(),
Value::Array(arr) => serde_json::to_string(arr).unwrap_or_default(),
Value::Object(obj) => serde_json::to_string(obj).unwrap_or_default(),
};
result = result.replace(&placeholder, &replacement);
}
result
}
fn extract_vars(&self, data: &Value) -> HashMap<String, Value> {
let mut vars = HashMap::new();
if let Value::Object(map) = data {
for (key, value) in map {
vars.insert(key.clone(), value.clone());
}
}
vars
}
}
pub fn default_system_prompt_generate() -> &'static str {
r#"You are a synthetic data generation assistant. Your task is to generate realistic, high-quality training data.
Rules:
- Output ONLY the requested content, nothing else
- No explanations, meta-commentary, or surrounding text
- No markdown formatting unless explicitly requested
- Generate diverse, realistic examples that could plausibly exist in the real world
- Vary your outputs - avoid repetitive patterns or templates
- Match the tone and style appropriate for the content type
- If generating text that would have a label (sentiment, category, etc.), make the content clearly match that label"#
}
pub fn default_system_prompt_augment() -> &'static str {
r#"You are a data augmentation assistant. Your task is to transform input data while preserving its essential properties.
Rules:
- Output ONLY the transformed content, nothing else
- No explanations, meta-commentary, or surrounding text
- No markdown formatting unless explicitly requested
- Preserve the original meaning, sentiment, and intent
- If the data has a label (positive/negative, category, etc.), the augmented version must retain the same label
- Make meaningful changes - simple word swaps are not sufficient
- The output should be natural and fluent"#
}
pub fn default_template_for_generate() -> String {
r#"Generate a realistic example of: {category}
Requirements:
- Authentic, natural language
- Specific details that make it believable
- 2-5 sentences unless otherwise specified
- Diverse - vary structure and content"#
.to_string()
}
pub fn default_template_for_augment(strategy: &str) -> String {
match strategy {
"paraphrase" => {
r#"Paraphrase the following text. Preserve the original meaning and sentiment exactly.
Input: {text}
Paraphrased version:"#
.to_string()
}
"style_transfer" => {
r#"Rewrite the following text in a different style while preserving the core meaning.
Input: {text}
Target style: {style}
Rewritten version:"#
.to_string()
}
"back_translation" => {
r#"Rephrase this text as if it was translated to another language and back. Keep the same meaning but use different word choices and sentence structures.
Input: {text}
Rephrased version:"#
.to_string()
}
_ => {
r#"Transform the following text while preserving its meaning:
Input: {text}
Transformed version:"#
.to_string()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_build_for_category() {
let builder = PromptBuilder::new(
"Generate a {category} review (item #{index})".to_string(),
None,
false,
);
let result = builder.build_for_category("electronics", 5);
assert_eq!(result, "Generate a electronics review (item #5)");
}
#[test]
fn test_build_for_record() {
let builder =
PromptBuilder::new("Paraphrase: {text}\nLabel: {label}".to_string(), None, true);
let record = Record {
data: serde_json::json!({
"text": "This movie is great!",
"label": 1
}),
index: 0,
};
let result = builder.build_for_record(&record);
assert_eq!(result, "Paraphrase: This movie is great!\nLabel: 1");
}
#[test]
fn test_required_variables() {
let builder = PromptBuilder::new(
"Hello {name}, your {item} for {category} is ready".to_string(),
None,
false,
);
let vars = builder.required_variables();
assert!(vars.contains(&"name".to_string()));
assert!(vars.contains(&"item".to_string()));
assert!(vars.contains(&"category".to_string()));
}
#[test]
fn test_validate_generate_missing_categories() {
let builder = PromptBuilder::new("Generate a {category} example".to_string(), None, false);
let result = builder.validate_for_generate(&None);
assert!(result.is_err());
}
#[test]
fn test_validate_generate_with_categories() {
let builder = PromptBuilder::new("Generate a {category} example".to_string(), None, false);
let result = builder.validate_for_generate(&Some(vec!["test".to_string()]));
assert!(result.is_ok());
}
#[test]
fn test_validate_augment_missing_column() {
let builder =
PromptBuilder::new("Paraphrase: {text} with {missing}".to_string(), None, true);
let result = builder.validate_for_augment(&["text".to_string()]);
assert!(result.is_err());
}
#[test]
fn test_validate_augment_valid() {
let builder = PromptBuilder::new("Paraphrase: {text}".to_string(), None, true);
let result = builder.validate_for_augment(&["text".to_string(), "label".to_string()]);
assert!(result.is_ok());
}
#[test]
fn test_default_system_prompts_exist() {
assert!(!default_system_prompt_generate().is_empty());
assert!(!default_system_prompt_augment().is_empty());
}
}