// synth_claw/config/schema.rs

use std::collections::HashMap;
use std::path::{Path, PathBuf};

use serde::{Deserialize, Serialize};

5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct SynthConfig {
7    pub name: String,
8    #[serde(default)]
9    pub source: Option<SourceConfig>,
10    pub provider: ProviderConfig,
11    pub generation: GenerationConfig,
12    pub output: OutputConfig,
13    #[serde(default)]
14    pub validation: Option<ValidationConfig>,
15    #[serde(default)]
16    pub hub: Option<HubConfig>,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20#[serde(tag = "type", rename_all = "lowercase")]
21pub enum SourceConfig {
22    HuggingFace {
23        dataset: String,
24        #[serde(default)]
25        subset: Option<String>,
26        #[serde(default = "default_split")]
27        split: String,
28        #[serde(default)]
29        sample: Option<usize>,
30        #[serde(default)]
31        columns: Option<Vec<String>>,
32    },
33    Local {
34        path: PathBuf,
35        format: FileFormat,
36        #[serde(default)]
37        sample: Option<usize>,
38    },
39}
40
/// Split used when a `huggingface` source does not name one.
fn default_split() -> String {
    String::from("train")
}

45#[derive(Debug, Clone, Serialize, Deserialize)]
46#[serde(rename_all = "lowercase")]
47pub enum FileFormat {
48    Json,
49    Jsonl,
50    Csv,
51    Parquet,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
55#[serde(tag = "type", rename_all = "lowercase")]
56pub enum ProviderConfig {
57    OpenAI {
58        model: String,
59        #[serde(default)]
60        api_key: Option<String>,
61        #[serde(default)]
62        base_url: Option<String>,
63        #[serde(default)]
64        temperature: Option<f32>,
65        #[serde(default)]
66        max_tokens: Option<u32>,
67    },
68    Anthropic {
69        model: String,
70        #[serde(default)]
71        api_key: Option<String>,
72        #[serde(default)]
73        temperature: Option<f32>,
74        #[serde(default)]
75        max_tokens: Option<u32>,
76    },
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct GenerationConfig {
81    pub task: GenerationTask,
82    #[serde(default = "default_count")]
83    pub count: usize,
84    #[serde(default)]
85    pub count_per_example: Option<usize>,
86    #[serde(default = "default_concurrency")]
87    pub concurrency: usize,
88    #[serde(default)]
89    pub strategy: Option<GenerationStrategy>,
90    #[serde(default)]
91    pub strategy_config: HashMap<String, serde_yaml::Value>,
92    #[serde(default)]
93    pub template: Option<String>,
94    #[serde(default)]
95    pub system_prompt: Option<String>,
96    #[serde(default)]
97    pub categories: Option<Vec<String>>,
98}
99
/// Fallback for `generation.count` when the YAML omits it.
fn default_count() -> usize {
    const DEFAULT: usize = 100;
    DEFAULT
}

/// Fallback for `generation.concurrency` when the YAML omits it.
fn default_concurrency() -> usize {
    const DEFAULT: usize = 5;
    DEFAULT
}

108#[derive(Debug, Clone, Serialize, Deserialize)]
109#[serde(rename_all = "snake_case")]
110pub enum GenerationTask {
111    Generate,
112    Augment,
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
116#[serde(rename_all = "snake_case")]
117pub enum GenerationStrategy {
118    Paraphrase,
119    StyleTransfer,
120    BackTranslation,
121    Custom,
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct OutputConfig {
126    pub format: OutputFormat,
127    pub path: PathBuf,
128    #[serde(default = "default_batch_size")]
129    pub batch_size: usize,
130}
131
/// Fallback for `output.batch_size` when the YAML omits it.
fn default_batch_size() -> usize {
    const DEFAULT: usize = 100;
    DEFAULT
}

136#[derive(Debug, Clone, Serialize, Deserialize)]
137#[serde(rename_all = "lowercase")]
138pub enum OutputFormat {
139    Json,
140    Jsonl,
141    Csv,
142    Parquet,
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize, Default)]
146pub struct ValidationConfig {
147    #[serde(default)]
148    pub min_length: Option<usize>,
149    #[serde(default)]
150    pub max_length: Option<usize>,
151    #[serde(default)]
152    pub json: bool,
153    #[serde(default)]
154    pub json_schema: Option<Vec<String>>,
155    #[serde(default)]
156    pub blocklist: bool,
157    #[serde(default)]
158    pub repetition: bool,
159    #[serde(default)]
160    pub dedupe: Option<DedupeStrategy>,
161}
162
163#[derive(Debug, Clone, Serialize, Deserialize)]
164#[serde(rename_all = "lowercase")]
165pub enum DedupeStrategy {
166    Exact,
167    Normalized,
168    Jaccard,
169}
170
171#[derive(Debug, Clone, Serialize, Deserialize, Default)]
172pub struct HubConfig {
173    #[serde(default)]
174    pub token: Option<String>,
175    #[serde(default)]
176    pub repo: Option<String>,
177    #[serde(default)]
178    pub private: bool,
179}
180
181impl SynthConfig {
182    pub fn from_yaml(content: &str) -> crate::Result<Self> {
183        serde_yaml::from_str(content).map_err(Into::into)
184    }
185
186    pub fn from_file(path: &PathBuf) -> crate::Result<Self> {
187        let content = std::fs::read_to_string(path)?;
188        Self::from_yaml(&content)
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_augment_config() {
        let yaml = r#"
name: "sentiment_augmentation"

source:
  type: huggingface
  dataset: "cornell-movie-review-data/rotten_tomatoes"
  split: "train"
  sample: 1000

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: augment
  count_per_example: 3
  concurrency: 10
  strategy: paraphrase

output:
  format: jsonl
  path: "./output/augmented.jsonl"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.name, "sentiment_augmentation");
        assert!(matches!(config.provider, ProviderConfig::OpenAI { .. }));
        assert!(matches!(config.generation.task, GenerationTask::Augment));
        assert_eq!(config.generation.count_per_example, Some(3));
        assert_eq!(config.generation.concurrency, 10);
        assert!(matches!(
            config.generation.strategy,
            Some(GenerationStrategy::Paraphrase)
        ));
        assert!(matches!(config.output.format, OutputFormat::Jsonl));
        match config.source {
            Some(SourceConfig::HuggingFace {
                dataset,
                split,
                sample,
                ..
            }) => {
                assert_eq!(dataset, "cornell-movie-review-data/rotten_tomatoes");
                assert_eq!(split, "train");
                assert_eq!(sample, Some(1000));
            }
            other => panic!("expected a huggingface source, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_generate_config() {
        let yaml = r#"
name: "product_reviews"

provider:
  type: anthropic
  model: "claude-haiku-4-5-20251001"

generation:
  task: generate
  count: 500
  concurrency: 5
  categories:
    - electronics
    - books
    - clothing
  template: |
    Generate a realistic {category} product review.
    Output only the review text.

output:
  format: parquet
  path: "./output/reviews.parquet"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.name, "product_reviews");
        assert!(config.source.is_none());
        assert!(matches!(config.provider, ProviderConfig::Anthropic { .. }));
        assert!(matches!(config.generation.task, GenerationTask::Generate));
        assert_eq!(config.generation.count, 500);
        assert_eq!(config.generation.categories.as_ref().unwrap().len(), 3);
        assert!(config
            .generation
            .template
            .as_deref()
            .unwrap()
            .contains("{category}"));
        assert!(matches!(config.output.format, OutputFormat::Parquet));
    }

    #[test]
    fn test_parse_validation_config() {
        let yaml = r#"
name: "with_validation"

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: generate
  count: 10
  template: "Generate JSON: {\"q\": \"...\", \"a\": \"...\"}"

output:
  format: jsonl
  path: "./output.jsonl"

validation:
  min_length: 20
  max_length: 1000
  json: true
  json_schema:
    - question
    - answer
  blocklist: true
  repetition: true
  dedupe: normalized
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        let v = config.validation.unwrap();
        assert_eq!(v.min_length, Some(20));
        assert_eq!(v.max_length, Some(1000));
        assert!(v.json);
        assert_eq!(v.json_schema.unwrap(), vec!["question", "answer"]);
        assert!(v.blocklist);
        assert!(v.repetition);
        assert!(matches!(v.dedupe, Some(DedupeStrategy::Normalized)));
    }

    #[test]
    fn test_defaults_applied() {
        let yaml = r#"
name: "minimal"

provider:
  type: anthropic
  model: "claude-haiku-4-5-20251001"

generation:
  task: generate

output:
  format: json
  path: "./out.json"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        assert!(config.source.is_none());
        assert!(config.validation.is_none());
        assert!(config.hub.is_none());
        assert_eq!(config.generation.count, 100);
        assert_eq!(config.generation.concurrency, 5);
        assert!(config.generation.count_per_example.is_none());
        assert!(config.generation.strategy.is_none());
        assert!(config.generation.strategy_config.is_empty());
        assert!(config.generation.template.is_none());
        assert!(config.generation.system_prompt.is_none());
        assert!(config.generation.categories.is_none());
        assert_eq!(config.output.batch_size, 100);
        assert!(matches!(config.output.format, OutputFormat::Json));
    }

    #[test]
    fn test_huggingface_split_defaults_to_train() {
        let yaml = r#"
name: "default_split"

source:
  type: huggingface
  dataset: "imdb"

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: augment

output:
  format: jsonl
  path: "./out.jsonl"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        match config.source {
            Some(SourceConfig::HuggingFace {
                dataset,
                subset,
                split,
                sample,
                columns,
            }) => {
                assert_eq!(dataset, "imdb");
                assert_eq!(split, "train");
                assert!(subset.is_none());
                assert!(sample.is_none());
                assert!(columns.is_none());
            }
            other => panic!("expected a huggingface source, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_local_source() {
        let yaml = r#"
name: "local_source"

source:
  type: local
  path: "./data/input.csv"
  format: csv

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: augment

output:
  format: csv
  path: "./out.csv"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        match config.source {
            Some(SourceConfig::Local {
                path,
                format,
                sample,
            }) => {
                assert_eq!(path, PathBuf::from("./data/input.csv"));
                assert!(matches!(format, FileFormat::Csv));
                assert!(sample.is_none());
            }
            other => panic!("expected a local source, got {:?}", other),
        }
    }

    #[test]
    fn test_unknown_provider_is_rejected() {
        let yaml = r#"
name: "bad_provider"

provider:
  type: gemini
  model: "gemini-pro"

generation:
  task: generate

output:
  format: jsonl
  path: "./out.jsonl"
"#;

        assert!(SynthConfig::from_yaml(yaml).is_err());
    }

    #[test]
    fn test_missing_output_is_rejected() {
        let yaml = r#"
name: "no_output"

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: generate
"#;

        assert!(SynthConfig::from_yaml(yaml).is_err());
    }

    #[test]
    fn test_from_file_reads_yaml() {
        let yaml = r#"
name: "from_file"

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: generate

output:
  format: jsonl
  path: "./out.jsonl"
"#;

        let path = std::env::temp_dir().join(format!(
            "synth_claw_schema_test_{}.yaml",
            std::process::id()
        ));
        std::fs::write(&path, yaml).expect("write temp config");
        let result = SynthConfig::from_file(&path);
        let _ = std::fs::remove_file(&path);
        assert_eq!(result.unwrap().name, "from_file");

        let missing = std::env::temp_dir().join("synth_claw_schema_test_missing.yaml");
        let _ = std::fs::remove_file(&missing);
        assert!(SynthConfig::from_file(&missing).is_err());
    }
}