synth_claw/generation/
prompt.rs1use crate::datasets::Record;
2use crate::{Error, Result};
3use regex::Regex;
4use serde_json::Value;
5use std::collections::HashMap;
6
/// Renders user prompts from a `{placeholder}` template and holds the system prompt that
/// goes with them.
#[derive(Debug, Clone)]
pub struct PromptBuilder {
    /// Prompt template; `{name}` placeholders are filled in by the `build_for_*` methods.
    template: String,
    /// System prompt returned by `system_prompt()`; chosen in `new` when none is supplied.
    system_prompt: String,
}
11
12impl PromptBuilder {
13 pub fn new(template: String, system_prompt: Option<String>, is_augment: bool) -> Self {
14 let system_prompt = system_prompt.unwrap_or_else(|| {
15 if is_augment {
16 default_system_prompt_augment().to_string()
17 } else {
18 default_system_prompt_generate().to_string()
19 }
20 });
21 Self {
22 template,
23 system_prompt,
24 }
25 }
26
27 pub fn system_prompt(&self) -> &str {
28 &self.system_prompt
29 }
30
31 pub fn build_for_category(&self, category: &str, index: usize) -> String {
33 let mut vars = HashMap::new();
34 vars.insert("category".to_string(), Value::String(category.to_string()));
35 vars.insert("index".to_string(), Value::Number(index.into()));
36 self.substitute(&vars)
37 }
38
39 pub fn build_for_record(&self, record: &Record) -> String {
41 let vars = self.extract_vars(&record.data);
42 self.substitute(&vars)
43 }
44
45 pub fn required_variables(&self) -> Vec<String> {
47 let re = Regex::new(r"\{(\w+)\}").unwrap();
48 re.captures_iter(&self.template)
49 .map(|cap| cap[1].to_string())
50 .collect()
51 }
52
53 pub fn validate_for_generate(&self, categories: &Option<Vec<String>>) -> Result<()> {
55 let required = self.required_variables();
56
57 for var in &required {
58 match var.as_str() {
59 "category" => {
60 if categories.is_none()
61 || categories.as_ref().map(|c| c.is_empty()).unwrap_or(true)
62 {
63 return Err(Error::Config(
64 "Template uses {category} but no categories provided".to_string(),
65 ));
66 }
67 }
68 "index" => {} other => {
70 return Err(Error::Config(format!(
71 "Template uses {{{}}} which is not available in generate mode. Available: {{category}}, {{index}}",
72 other
73 )));
74 }
75 }
76 }
77 Ok(())
78 }
79
80 pub fn validate_for_augment(&self, available_columns: &[String]) -> Result<()> {
82 let required = self.required_variables();
83
84 for var in &required {
85 if var != "index" && !available_columns.contains(var) {
86 return Err(Error::Config(format!(
87 "Template uses {{{}}} but source data only has columns: {:?}",
88 var, available_columns
89 )));
90 }
91 }
92 Ok(())
93 }
94
95 fn substitute(&self, vars: &HashMap<String, Value>) -> String {
96 let mut result = self.template.clone();
97 for (key, value) in vars {
98 let placeholder = format!("{{{}}}", key);
99 let replacement = match value {
100 Value::String(s) => s.clone(),
101 Value::Number(n) => n.to_string(),
102 Value::Bool(b) => b.to_string(),
103 Value::Null => "null".to_string(),
104 Value::Array(arr) => serde_json::to_string(arr).unwrap_or_default(),
105 Value::Object(obj) => serde_json::to_string(obj).unwrap_or_default(),
106 };
107 result = result.replace(&placeholder, &replacement);
108 }
109 result
110 }
111
112 fn extract_vars(&self, data: &Value) -> HashMap<String, Value> {
113 let mut vars = HashMap::new();
114 if let Value::Object(map) = data {
115 for (key, value) in map {
116 vars.insert(key.clone(), value.clone());
117 }
118 }
119 vars
120 }
121}
122
/// Default system prompt for generate mode.
///
/// [`PromptBuilder::new`] falls back to this when no system prompt is given and `is_augment`
/// is false.
pub fn default_system_prompt_generate() -> &'static str {
    // One literal per line so each rule can be edited independently; `concat!` keeps the
    // result a single `&'static str` with no trailing newline.
    concat!(
        "You are a synthetic data generation assistant. Your task is to generate realistic, high-quality training data.\n",
        "\n",
        "Rules:\n",
        "- Output ONLY the requested content, nothing else\n",
        "- No explanations, meta-commentary, or surrounding text\n",
        "- No markdown formatting unless explicitly requested\n",
        "- Generate diverse, realistic examples that could plausibly exist in the real world\n",
        "- Vary your outputs - avoid repetitive patterns or templates\n",
        "- Match the tone and style appropriate for the content type\n",
        "- If generating text that would have a label (sentiment, category, etc.), make the content clearly match that label"
    )
}
135
/// Default system prompt for augment mode.
///
/// [`PromptBuilder::new`] falls back to this when no system prompt is given and `is_augment`
/// is true.
pub fn default_system_prompt_augment() -> &'static str {
    // Built line by line with `concat!`; the last rule has no trailing newline.
    concat!(
        "You are a data augmentation assistant. Your task is to transform input data while preserving its essential properties.\n",
        "\n",
        "Rules:\n",
        "- Output ONLY the transformed content, nothing else\n",
        "- No explanations, meta-commentary, or surrounding text\n",
        "- No markdown formatting unless explicitly requested\n",
        "- Preserve the original meaning, sentiment, and intent\n",
        "- If the data has a label (positive/negative, category, etc.), the augmented version must retain the same label\n",
        "- Make meaningful changes - simple word swaps are not sufficient\n",
        "- The output should be natural and fluent"
    )
}
148
/// Default generate-mode template; it references only the `{category}` placeholder.
pub fn default_template_for_generate() -> String {
    let lines = [
        "Generate a realistic example of: {category}",
        "",
        "Requirements:",
        "- Authentic, natural language",
        "- Specific details that make it believable",
        "- 2-5 sentences unless otherwise specified",
        "- Diverse - vary structure and content",
    ];
    lines.join("\n")
}
159
/// Default augment-mode template for `strategy`.
///
/// Recognised strategies are `"paraphrase"`, `"style_transfer"` and `"back_translation"`;
/// any other value gets a generic "transform while preserving meaning" template. Every
/// template references `{text}`; `"style_transfer"` additionally references `{style}`.
pub fn default_template_for_augment(strategy: &str) -> String {
    // All strategies share one layout: instruction, blank line, input field(s), blank line,
    // answer label. Only the three parts below vary.
    let (instruction, extra_fields, answer_label) = match strategy {
        "paraphrase" => (
            "Paraphrase the following text. Preserve the original meaning and sentiment exactly.",
            "",
            "Paraphrased version:",
        ),
        "style_transfer" => (
            "Rewrite the following text in a different style while preserving the core meaning.",
            "Target style: {style}\n",
            "Rewritten version:",
        ),
        "back_translation" => (
            "Rephrase this text as if it was translated to another language and back. Keep the same meaning but use different word choices and sentence structures.",
            "",
            "Rephrased version:",
        ),
        _ => (
            "Transform the following text while preserving its meaning:",
            "",
            "Transformed version:",
        ),
    };
    format!(
        "{}\n\nInput: {{text}}\n{}\n{}",
        instruction, extra_fields, answer_label
    )
}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Generate-mode builder using the default system prompt.
    fn generate_builder(template: &str) -> PromptBuilder {
        PromptBuilder::new(template.to_string(), None, false)
    }

    /// Augment-mode builder using the default system prompt.
    fn augment_builder(template: &str) -> PromptBuilder {
        PromptBuilder::new(template.to_string(), None, true)
    }

    #[test]
    fn test_build_for_category() {
        let builder = generate_builder("Generate a {category} review (item #{index})");
        assert_eq!(
            builder.build_for_category("electronics", 5),
            "Generate a electronics review (item #5)"
        );
    }

    #[test]
    fn test_build_for_record() {
        let builder = augment_builder("Paraphrase: {text}\nLabel: {label}");
        let record = Record {
            data: serde_json::json!({
                "text": "This movie is great!",
                "label": 1
            }),
            index: 0,
        };
        assert_eq!(
            builder.build_for_record(&record),
            "Paraphrase: This movie is great!\nLabel: 1"
        );
    }

    #[test]
    fn test_required_variables() {
        let builder = generate_builder("Hello {name}, your {item} for {category} is ready");
        let vars = builder.required_variables();
        for expected in ["name", "item", "category"].iter() {
            assert!(vars.contains(&expected.to_string()), "missing variable {}", expected);
        }
    }

    #[test]
    fn test_validate_generate_missing_categories() {
        let builder = generate_builder("Generate a {category} example");
        assert!(builder.validate_for_generate(&None).is_err());
    }

    #[test]
    fn test_validate_generate_with_categories() {
        let builder = generate_builder("Generate a {category} example");
        let categories = Some(vec!["test".to_string()]);
        assert!(builder.validate_for_generate(&categories).is_ok());
    }

    #[test]
    fn test_validate_augment_missing_column() {
        let builder = augment_builder("Paraphrase: {text} with {missing}");
        let columns = ["text".to_string()];
        assert!(builder.validate_for_augment(&columns).is_err());
    }

    #[test]
    fn test_validate_augment_valid() {
        let builder = augment_builder("Paraphrase: {text}");
        let columns = ["text".to_string(), "label".to_string()];
        assert!(builder.validate_for_augment(&columns).is_ok());
    }

    #[test]
    fn test_default_system_prompts_exist() {
        for prompt in [default_system_prompt_generate(), default_system_prompt_augment()].iter() {
            assert!(!prompt.is_empty());
        }
    }
}