Skip to main content

synth_claw/config/
schema.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::PathBuf;
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct SynthConfig {
7    pub name: String,
8    #[serde(default)]
9    pub source: Option<SourceConfig>,
10    pub provider: ProviderConfig,
11    pub generation: GenerationConfig,
12    pub output: OutputConfig,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16#[serde(tag = "type", rename_all = "lowercase")]
17pub enum SourceConfig {
18    HuggingFace {
19        dataset: String,
20        #[serde(default)]
21        subset: Option<String>,
22        #[serde(default = "default_split")]
23        split: String,
24        #[serde(default)]
25        sample: Option<usize>,
26        #[serde(default)]
27        columns: Option<Vec<String>>,
28    },
29    Local {
30        path: PathBuf,
31        format: FileFormat,
32        #[serde(default)]
33        sample: Option<usize>,
34    },
35}
36
/// Serde default for the `split` field of `SourceConfig::HuggingFace`.
fn default_split() -> String {
    String::from("train")
}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
42#[serde(rename_all = "lowercase")]
43pub enum FileFormat {
44    Json,
45    Jsonl,
46    Csv,
47    Parquet,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51#[serde(tag = "type", rename_all = "lowercase")]
52pub enum ProviderConfig {
53    OpenAI {
54        model: String,
55        #[serde(default)]
56        api_key: Option<String>,
57        #[serde(default)]
58        base_url: Option<String>,
59        #[serde(default)]
60        temperature: Option<f32>,
61        #[serde(default)]
62        max_tokens: Option<u32>,
63    },
64    Anthropic {
65        model: String,
66        #[serde(default)]
67        api_key: Option<String>,
68        #[serde(default)]
69        temperature: Option<f32>,
70        #[serde(default)]
71        max_tokens: Option<u32>,
72    },
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct GenerationConfig {
77    pub task: GenerationTask,
78    #[serde(default = "default_count")]
79    pub count: usize,
80    #[serde(default)]
81    pub count_per_example: Option<usize>,
82    #[serde(default = "default_concurrency")]
83    pub concurrency: usize,
84    #[serde(default)]
85    pub strategy: Option<GenerationStrategy>,
86    #[serde(default)]
87    pub strategy_config: HashMap<String, serde_yaml::Value>,
88    #[serde(default)]
89    pub template: Option<String>,
90    #[serde(default)]
91    pub system_prompt: Option<String>,
92    #[serde(default)]
93    pub categories: Option<Vec<String>>,
94}
95
/// Serde default for `GenerationConfig::count`.
fn default_count() -> usize {
    const DEFAULT_COUNT: usize = 100;
    DEFAULT_COUNT
}
99
/// Serde default for `GenerationConfig::concurrency`.
fn default_concurrency() -> usize {
    const DEFAULT_CONCURRENCY: usize = 5;
    DEFAULT_CONCURRENCY
}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
105#[serde(rename_all = "snake_case")]
106pub enum GenerationTask {
107    Generate,
108    Augment,
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
112#[serde(rename_all = "snake_case")]
113pub enum GenerationStrategy {
114    Paraphrase,
115    StyleTransfer,
116    BackTranslation,
117    Custom,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct OutputConfig {
122    pub format: OutputFormat,
123    pub path: PathBuf,
124    #[serde(default = "default_batch_size")]
125    pub batch_size: usize,
126}
127
/// Serde default for `OutputConfig::batch_size`.
fn default_batch_size() -> usize {
    const DEFAULT_BATCH_SIZE: usize = 100;
    DEFAULT_BATCH_SIZE
}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
133#[serde(rename_all = "lowercase")]
134pub enum OutputFormat {
135    Json,
136    Jsonl,
137    Csv,
138    Parquet,
139}
140
141impl SynthConfig {
142    pub fn from_yaml(content: &str) -> crate::Result<Self> {
143        serde_yaml::from_str(content).map_err(Into::into)
144    }
145
146    pub fn from_file(path: &PathBuf) -> crate::Result<Self> {
147        let content = std::fs::read_to_string(path)?;
148        Self::from_yaml(&content)
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// An augmentation config with a HuggingFace source parses, explicit
    /// values are kept, and omitted keys fall back to their serde defaults.
    #[test]
    fn test_parse_augment_config() {
        let yaml = r#"
name: "sentiment_augmentation"

source:
  type: huggingface
  dataset: "cornell-movie-review-data/rotten_tomatoes"
  split: "train"
  sample: 1000

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: augment
  count_per_example: 3
  concurrency: 10
  strategy: paraphrase

output:
  format: jsonl
  path: "./output/augmented.jsonl"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.name, "sentiment_augmentation");

        match &config.source {
            Some(SourceConfig::HuggingFace {
                dataset,
                subset,
                split,
                sample,
                columns,
            }) => {
                assert_eq!(dataset.as_str(), "cornell-movie-review-data/rotten_tomatoes");
                assert_eq!(split.as_str(), "train");
                assert_eq!(*sample, Some(1000));
                assert!(subset.is_none());
                assert!(columns.is_none());
            }
            other => panic!("expected a HuggingFace source, got {:?}", other),
        }

        assert!(matches!(config.provider, ProviderConfig::OpenAI { .. }));

        assert!(matches!(config.generation.task, GenerationTask::Augment));
        assert!(matches!(
            config.generation.strategy,
            Some(GenerationStrategy::Paraphrase)
        ));
        assert_eq!(config.generation.count_per_example, Some(3));
        assert_eq!(config.generation.concurrency, 10);
        // `count` is not present in the YAML, so the serde default applies.
        assert_eq!(config.generation.count, 100);
        assert!(config.generation.strategy_config.is_empty());

        assert!(matches!(config.output.format, OutputFormat::Jsonl));
        assert_eq!(config.output.path, PathBuf::from("./output/augmented.jsonl"));
        // `batch_size` is not present in the YAML, so the serde default applies.
        assert_eq!(config.output.batch_size, 100);
    }

    /// A generation config without a `source` section parses, and the list
    /// and template values come through intact.
    #[test]
    fn test_parse_generate_config() {
        let yaml = r#"
name: "product_reviews"

provider:
  type: anthropic
  model: "claude-haiku-4-5-20251001"

generation:
  task: generate
  count: 500
  concurrency: 5
  categories:
    - electronics
    - books
    - clothing
  template: |
    Generate a realistic {category} product review.
    Output only the review text.

output:
  format: parquet
  path: "./output/reviews.parquet"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.name, "product_reviews");
        assert!(config.source.is_none());
        assert!(matches!(config.provider, ProviderConfig::Anthropic { .. }));

        assert!(matches!(config.generation.task, GenerationTask::Generate));
        assert_eq!(config.generation.count, 500);
        assert_eq!(config.generation.concurrency, 5);
        assert!(config.generation.count_per_example.is_none());
        assert!(config.generation.strategy.is_none());
        assert!(config.generation.system_prompt.is_none());

        let categories = config.generation.categories.as_ref().unwrap();
        assert_eq!(categories.len(), 3);
        assert_eq!(categories[0].as_str(), "electronics");
        assert_eq!(categories[1].as_str(), "books");
        assert_eq!(categories[2].as_str(), "clothing");

        let template = config.generation.template.as_ref().unwrap();
        assert!(template.starts_with("Generate a realistic {category} product review."));
        assert!(template.contains("Output only the review text."));

        assert!(matches!(config.output.format, OutputFormat::Parquet));
        assert_eq!(config.output.path, PathBuf::from("./output/reviews.parquet"));
        assert_eq!(config.output.batch_size, 100);
    }

    /// A `type` tag that names no known provider is a parse error, not a
    /// silent fallback.
    #[test]
    fn test_rejects_unknown_provider_type() {
        let yaml = r#"
name: "bad_provider"

provider:
  type: not_a_provider
  model: "whatever"

generation:
  task: generate

output:
  format: jsonl
  path: "./output/out.jsonl"
"#;

        assert!(SynthConfig::from_yaml(yaml).is_err());
    }
}