1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::PathBuf;
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct SynthConfig {
7 pub name: String,
8 #[serde(default)]
9 pub source: Option<SourceConfig>,
10 pub provider: ProviderConfig,
11 pub generation: GenerationConfig,
12 pub output: OutputConfig,
13 #[serde(default)]
14 pub validation: Option<ValidationConfig>,
15 #[serde(default)]
16 pub hub: Option<HubConfig>,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20#[serde(tag = "type", rename_all = "lowercase")]
21pub enum SourceConfig {
22 HuggingFace {
23 dataset: String,
24 #[serde(default)]
25 subset: Option<String>,
26 #[serde(default = "default_split")]
27 split: String,
28 #[serde(default)]
29 sample: Option<usize>,
30 #[serde(default)]
31 columns: Option<Vec<String>>,
32 },
33 Local {
34 path: PathBuf,
35 format: FileFormat,
36 #[serde(default)]
37 sample: Option<usize>,
38 },
39}
40
/// Serde fallback for the `split` field of `SourceConfig::HuggingFace`.
fn default_split() -> String {
    String::from("train")
}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46#[serde(rename_all = "lowercase")]
47pub enum FileFormat {
48 Json,
49 Jsonl,
50 Csv,
51 Parquet,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
55#[serde(tag = "type", rename_all = "lowercase")]
56pub enum ProviderConfig {
57 OpenAI {
58 model: String,
59 #[serde(default)]
60 api_key: Option<String>,
61 #[serde(default)]
62 base_url: Option<String>,
63 #[serde(default)]
64 temperature: Option<f32>,
65 #[serde(default)]
66 max_tokens: Option<u32>,
67 },
68 Anthropic {
69 model: String,
70 #[serde(default)]
71 api_key: Option<String>,
72 #[serde(default)]
73 temperature: Option<f32>,
74 #[serde(default)]
75 max_tokens: Option<u32>,
76 },
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct GenerationConfig {
81 pub task: GenerationTask,
82 #[serde(default = "default_count")]
83 pub count: usize,
84 #[serde(default)]
85 pub count_per_example: Option<usize>,
86 #[serde(default = "default_concurrency")]
87 pub concurrency: usize,
88 #[serde(default)]
89 pub strategy: Option<GenerationStrategy>,
90 #[serde(default)]
91 pub strategy_config: HashMap<String, serde_yaml::Value>,
92 #[serde(default)]
93 pub template: Option<String>,
94 #[serde(default)]
95 pub system_prompt: Option<String>,
96 #[serde(default)]
97 pub categories: Option<Vec<String>>,
98}
99
/// Serde fallback for `GenerationConfig::count`.
fn default_count() -> usize {
    const DEFAULT_COUNT: usize = 100;
    DEFAULT_COUNT
}
103
/// Serde fallback for `GenerationConfig::concurrency`.
fn default_concurrency() -> usize {
    const DEFAULT_CONCURRENCY: usize = 5;
    DEFAULT_CONCURRENCY
}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
109#[serde(rename_all = "snake_case")]
110pub enum GenerationTask {
111 Generate,
112 Augment,
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
116#[serde(rename_all = "snake_case")]
117pub enum GenerationStrategy {
118 Paraphrase,
119 StyleTransfer,
120 BackTranslation,
121 Custom,
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct OutputConfig {
126 pub format: OutputFormat,
127 pub path: PathBuf,
128 #[serde(default = "default_batch_size")]
129 pub batch_size: usize,
130}
131
/// Serde fallback for `OutputConfig::batch_size`.
fn default_batch_size() -> usize {
    const DEFAULT_BATCH_SIZE: usize = 100;
    DEFAULT_BATCH_SIZE
}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
137#[serde(rename_all = "lowercase")]
138pub enum OutputFormat {
139 Json,
140 Jsonl,
141 Csv,
142 Parquet,
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize, Default)]
146pub struct ValidationConfig {
147 #[serde(default)]
148 pub min_length: Option<usize>,
149 #[serde(default)]
150 pub max_length: Option<usize>,
151 #[serde(default)]
152 pub json: bool,
153 #[serde(default)]
154 pub json_schema: Option<Vec<String>>,
155 #[serde(default)]
156 pub blocklist: bool,
157 #[serde(default)]
158 pub repetition: bool,
159 #[serde(default)]
160 pub dedupe: Option<DedupeStrategy>,
161}
162
163#[derive(Debug, Clone, Serialize, Deserialize)]
164#[serde(rename_all = "lowercase")]
165pub enum DedupeStrategy {
166 Exact,
167 Normalized,
168 Jaccard,
169}
170
171#[derive(Debug, Clone, Serialize, Deserialize, Default)]
172pub struct HubConfig {
173 #[serde(default)]
174 pub token: Option<String>,
175 #[serde(default)]
176 pub repo: Option<String>,
177 #[serde(default)]
178 pub private: bool,
179}
180
181impl SynthConfig {
182 pub fn from_yaml(content: &str) -> crate::Result<Self> {
183 serde_yaml::from_str(content).map_err(Into::into)
184 }
185
186 pub fn from_file(path: &PathBuf) -> crate::Result<Self> {
187 let content = std::fs::read_to_string(path)?;
188 Self::from_yaml(&content)
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// An augmentation config parses, keeps its explicit values, and fills
    /// every omitted key with its serde default.
    #[test]
    fn test_parse_augment_config() {
        let yaml = r#"
name: "sentiment_augmentation"

source:
  type: huggingface
  dataset: "cornell-movie-review-data/rotten_tomatoes"
  split: "train"
  sample: 1000

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: augment
  count_per_example: 3
  concurrency: 10
  strategy: paraphrase

output:
  format: jsonl
  path: "./output/augmented.jsonl"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.name, "sentiment_augmentation");

        match &config.source {
            Some(SourceConfig::HuggingFace {
                dataset,
                subset,
                split,
                sample,
                columns,
            }) => {
                assert_eq!(dataset, "cornell-movie-review-data/rotten_tomatoes");
                assert!(subset.is_none());
                assert_eq!(split, "train");
                assert_eq!(*sample, Some(1000));
                assert!(columns.is_none());
            }
            other => panic!("expected a HuggingFace source, got {:?}", other),
        }

        match &config.provider {
            ProviderConfig::OpenAI {
                model,
                api_key,
                base_url,
                temperature,
                max_tokens,
            } => {
                assert_eq!(model, "gpt-4o-mini");
                assert!(api_key.is_none());
                assert!(base_url.is_none());
                assert!(temperature.is_none());
                assert!(max_tokens.is_none());
            }
            other => panic!("expected an OpenAI provider, got {:?}", other),
        }

        assert!(matches!(config.generation.task, GenerationTask::Augment));
        // `count` is absent from the document, so the default applies.
        assert_eq!(config.generation.count, 100);
        assert_eq!(config.generation.count_per_example, Some(3));
        assert_eq!(config.generation.concurrency, 10);
        assert!(matches!(
            config.generation.strategy,
            Some(GenerationStrategy::Paraphrase)
        ));
        assert!(config.generation.strategy_config.is_empty());
        assert!(config.generation.template.is_none());
        assert!(config.generation.system_prompt.is_none());
        assert!(config.generation.categories.is_none());

        assert!(matches!(config.output.format, OutputFormat::Jsonl));
        assert_eq!(
            config.output.path.to_str(),
            Some("./output/augmented.jsonl")
        );
        // `batch_size` is absent from the document, so the default applies.
        assert_eq!(config.output.batch_size, 100);

        assert!(config.validation.is_none());
        assert!(config.hub.is_none());
    }

    /// A source-less generation config parses with its categories and template.
    #[test]
    fn test_parse_generate_config() {
        let yaml = r#"
name: "product_reviews"

provider:
  type: anthropic
  model: "claude-haiku-4-5-20251001"

generation:
  task: generate
  count: 500
  concurrency: 5
  categories:
    - electronics
    - books
    - clothing
  template: |
    Generate a realistic {category} product review.
    Output only the review text.

output:
  format: parquet
  path: "./output/reviews.parquet"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.name, "product_reviews");
        assert!(config.source.is_none());

        match &config.provider {
            ProviderConfig::Anthropic { model, api_key, .. } => {
                assert_eq!(model, "claude-haiku-4-5-20251001");
                assert!(api_key.is_none());
            }
            other => panic!("expected an Anthropic provider, got {:?}", other),
        }

        assert!(matches!(config.generation.task, GenerationTask::Generate));
        assert_eq!(config.generation.count, 500);
        assert_eq!(config.generation.concurrency, 5);
        assert!(config.generation.strategy.is_none());
        assert!(config
            .generation
            .template
            .as_deref()
            .unwrap_or_default()
            .contains("{category}"));

        assert!(matches!(config.output.format, OutputFormat::Parquet));
        assert_eq!(config.output.path.to_str(), Some("./output/reviews.parquet"));
        assert_eq!(config.output.batch_size, 100);

        assert_eq!(
            config.generation.categories.unwrap(),
            vec!["electronics", "books", "clothing"]
        );
    }

    /// Every key of the `validation` section is picked up.
    #[test]
    fn test_parse_validation_config() {
        let yaml = r#"
name: "with_validation"

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: generate
  count: 10
  template: "Generate JSON: {\"q\": \"...\", \"a\": \"...\"}"

output:
  format: jsonl
  path: "./output.jsonl"

validation:
  min_length: 20
  max_length: 1000
  json: true
  json_schema:
    - question
    - answer
  blocklist: true
  repetition: true
  dedupe: normalized
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.generation.count, 10);
        // `concurrency` is absent from the document, so the default applies.
        assert_eq!(config.generation.concurrency, 5);

        let v = config.validation.unwrap();
        assert_eq!(v.min_length, Some(20));
        assert_eq!(v.max_length, Some(1000));
        assert!(v.json);
        assert_eq!(v.json_schema.unwrap(), vec!["question", "answer"]);
        assert!(v.blocklist);
        assert!(v.repetition);
        assert!(matches!(v.dedupe, Some(DedupeStrategy::Normalized)));
    }

    /// A `local` source parses, and a minimal document gets every default.
    #[test]
    fn test_parse_local_source_and_defaults() {
        let yaml = r#"
name: "local_source"

source:
  type: local
  path: "./data/input.csv"
  format: csv

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: augment

output:
  format: json
  path: "./out.json"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();

        match &config.source {
            Some(SourceConfig::Local {
                path,
                format,
                sample,
            }) => {
                assert_eq!(path.to_str(), Some("./data/input.csv"));
                assert!(matches!(format, FileFormat::Csv));
                assert!(sample.is_none());
            }
            other => panic!("expected a Local source, got {:?}", other),
        }

        assert_eq!(config.generation.count, 100);
        assert_eq!(config.generation.concurrency, 5);
        assert!(config.generation.count_per_example.is_none());
        assert!(config.generation.strategy_config.is_empty());
        assert!(matches!(config.output.format, OutputFormat::Json));
        assert_eq!(config.output.batch_size, 100);
    }

    /// A HuggingFace source without `split` falls back to `"train"`.
    #[test]
    fn test_huggingface_split_defaults_to_train() {
        let yaml = r#"
name: "default_split"

source:
  type: huggingface
  dataset: "some/dataset"

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: augment

output:
  format: jsonl
  path: "./out.jsonl"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        match &config.source {
            Some(SourceConfig::HuggingFace { split, sample, .. }) => {
                assert_eq!(split, "train");
                assert!(sample.is_none());
            }
            other => panic!("expected a HuggingFace source, got {:?}", other),
        }
    }

    /// An unknown `type` tag is a parse error rather than a silent fallback.
    #[test]
    fn test_unknown_provider_type_is_rejected() {
        let yaml = r#"
name: "bad_provider"

provider:
  type: not_a_provider
  model: "whatever"

generation:
  task: generate

output:
  format: jsonl
  path: "./output.jsonl"
"#;

        assert!(SynthConfig::from_yaml(yaml).is_err());
    }

    /// Leaving out a required section (`output` here) is a parse error.
    #[test]
    fn test_missing_required_section_is_rejected() {
        let yaml = r#"
name: "no_output"

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: generate
"#;

        assert!(SynthConfig::from_yaml(yaml).is_err());
    }
}