1use crate::config::{GenerationConfig, GenerationStrategy, GenerationTask, ProviderConfig, SourceConfig, SynthConfig};
2use crate::datasets::{DataSource, HuggingFaceSource, LocalSource, Record};
3use crate::providers::{create_provider, GenerationRequest, GenerationResponse, LLMProvider};
4use crate::{Error, Result};
5
6use super::prompt::{default_template_for_augment, default_template_for_generate, PromptBuilder};
7
8use futures::stream::{self, StreamExt};
9use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
10use std::sync::Arc;
11use tokio::sync::Mutex;
12
/// Lock-free counters accumulated across concurrent generation tasks.
///
/// Every field is atomic so worker tasks can update the stats concurrently
/// without a mutex; read a point-in-time view via [`GenerationStats::snapshot`].
#[derive(Debug, Default)]
pub struct GenerationStats {
    /// Number of tasks that completed successfully.
    pub completed: AtomicUsize,
    /// Number of tasks that failed.
    pub failed: AtomicUsize,
    /// Sum of input tokens reported by the provider for successful tasks.
    pub total_input_tokens: AtomicU64,
    /// Sum of output tokens reported by the provider for successful tasks.
    pub total_output_tokens: AtomicU64,
}
20
21impl GenerationStats {
22 pub fn record_success(&self, response: &GenerationResponse) {
23 self.completed.fetch_add(1, Ordering::Relaxed);
24 self.total_input_tokens.fetch_add(response.input_tokens as u64, Ordering::Relaxed);
25 self.total_output_tokens.fetch_add(response.output_tokens as u64, Ordering::Relaxed);
26 }
27
28 pub fn record_failure(&self) {
29 self.failed.fetch_add(1, Ordering::Relaxed);
30 }
31
32 pub fn snapshot(&self) -> StatsSnapshot {
33 StatsSnapshot {
34 completed: self.completed.load(Ordering::Relaxed),
35 failed: self.failed.load(Ordering::Relaxed),
36 total_input_tokens: self.total_input_tokens.load(Ordering::Relaxed),
37 total_output_tokens: self.total_output_tokens.load(Ordering::Relaxed),
38 }
39 }
40}
41
/// Plain-value copy of [`GenerationStats`] taken at a single point in time.
#[derive(Debug, Clone)]
pub struct StatsSnapshot {
    /// Successful task count at snapshot time.
    pub completed: usize,
    /// Failed task count at snapshot time.
    pub failed: usize,
    /// Accumulated input tokens from successful tasks.
    pub total_input_tokens: u64,
    /// Accumulated output tokens from successful tasks.
    pub total_output_tokens: u64,
}
49
/// One generated item, with provenance and per-call token accounting.
#[derive(Debug, Clone)]
pub struct GenerationResult {
    /// Text produced by the provider.
    pub content: String,
    /// Index of the source record this was derived from (set by augment tasks).
    pub source_index: Option<usize>,
    /// Category the prompt targeted (set by generate tasks with categories).
    pub category: Option<String>,
    /// Input tokens reported by the provider for this call.
    pub input_tokens: u32,
    /// Output tokens reported by the provider for this call.
    pub output_tokens: u32,
}
58
/// Internal unit of work: a fully-built prompt waiting to be executed.
///
/// The trailing underscore avoids a name clash with `config::GenerationTask`,
/// which is imported at the top of this file.
struct GenerationTask_ {
    /// Rendered user prompt.
    prompt: String,
    /// Optional system prompt shared by all tasks of a run.
    system_prompt: Option<String>,
    /// Source record index (augment tasks only).
    source_index: Option<usize>,
    /// Target category (generate tasks with categories only).
    category: Option<String>,
}
65
/// Drives concurrent LLM generation: builds prompts from configuration, fans
/// tasks out to the provider, and accumulates shared [`GenerationStats`].
pub struct GenerationEngine {
    /// Backend used to execute every generation request.
    provider: Box<dyn LLMProvider>,
    /// Generation settings (task kind, count, concurrency, templates, ...).
    config: GenerationConfig,
    /// Shared via `Arc` so callers can observe progress while a run is in flight.
    stats: Arc<GenerationStats>,
}
71
72impl GenerationEngine {
73 pub fn new(provider_config: &ProviderConfig, generation_config: GenerationConfig) -> Result<Self> {
74 let provider = create_provider(provider_config)?;
75 Ok(Self {
76 provider,
77 config: generation_config,
78 stats: Arc::new(GenerationStats::default()),
79 })
80 }
81
82 pub fn stats(&self) -> Arc<GenerationStats> {
83 Arc::clone(&self.stats)
84 }
85
86 pub fn provider(&self) -> &dyn LLMProvider {
87 self.provider.as_ref()
88 }
89
90 pub async fn run(&self, config: &SynthConfig) -> Result<Vec<GenerationResult>> {
92 let tasks = self.build_tasks(config).await?;
93 let results = Arc::new(Mutex::new(Vec::with_capacity(tasks.len())));
94
95 stream::iter(tasks)
96 .map(|task| {
97 let provider = &self.provider;
98 let stats = Arc::clone(&self.stats);
99 let results = Arc::clone(&results);
100 async move {
101 match self.execute_task(provider.as_ref(), task).await {
102 Ok(result) => {
103 stats.record_success(&GenerationResponse {
104 content: result.content.clone(),
105 input_tokens: result.input_tokens,
106 output_tokens: result.output_tokens,
107 });
108 results.lock().await.push(result);
109 }
110 Err(e) => {
111 stats.record_failure();
112 tracing::warn!("Generation failed: {}", e);
113 }
114 }
115 }
116 })
117 .buffer_unordered(self.config.concurrency)
118 .collect::<Vec<_>>()
119 .await;
120
121 let results = Arc::try_unwrap(results)
122 .map_err(|_| Error::Provider("Failed to unwrap results".to_string()))?
123 .into_inner();
124
125 Ok(results)
126 }
127
128 pub async fn run_with_callback<F>(&self, config: &SynthConfig, on_result: F) -> Result<()>
130 where
131 F: FnMut(GenerationResult) + Send,
132 {
133 let tasks = self.build_tasks(config).await?;
134 let callback = Arc::new(Mutex::new(on_result));
135
136 stream::iter(tasks)
137 .map(|task| {
138 let provider = &self.provider;
139 let stats = Arc::clone(&self.stats);
140 let callback = Arc::clone(&callback);
141 async move {
142 match self.execute_task(provider.as_ref(), task).await {
143 Ok(result) => {
144 stats.record_success(&GenerationResponse {
145 content: result.content.clone(),
146 input_tokens: result.input_tokens,
147 output_tokens: result.output_tokens,
148 });
149 callback.lock().await(result);
150 }
151 Err(e) => {
152 stats.record_failure();
153 tracing::warn!("Generation failed: {}", e);
154 }
155 }
156 }
157 })
158 .buffer_unordered(self.config.concurrency)
159 .collect::<Vec<_>>()
160 .await;
161
162 Ok(())
163 }
164
165 async fn build_tasks(&self, config: &SynthConfig) -> Result<Vec<GenerationTask_>> {
166 let prompt_builder = self.create_prompt_builder();
167
168 match &config.generation.task {
169 GenerationTask::Generate => self.build_generate_tasks(&prompt_builder),
170 GenerationTask::Augment => self.build_augment_tasks(config, &prompt_builder).await,
171 }
172 }
173
174 fn build_generate_tasks(&self, prompt_builder: &PromptBuilder) -> Result<Vec<GenerationTask_>> {
175 let categories = self.config.categories.as_ref();
176 let count = self.config.count;
177 let system_prompt = Some(prompt_builder.system_prompt().to_string());
178
179 let mut tasks = Vec::with_capacity(count);
180
181 if let Some(cats) = categories {
182 let per_category = count / cats.len();
183 let remainder = count % cats.len();
184
185 for (cat_idx, category) in cats.iter().enumerate() {
186 let cat_count = per_category + if cat_idx < remainder { 1 } else { 0 };
187 for i in 0..cat_count {
188 tasks.push(GenerationTask_ {
189 prompt: prompt_builder.build_for_category(category, i),
190 system_prompt: system_prompt.clone(),
191 source_index: None,
192 category: Some(category.clone()),
193 });
194 }
195 }
196 } else {
197 for i in 0..count {
198 tasks.push(GenerationTask_ {
199 prompt: prompt_builder.build_for_category("default", i),
200 system_prompt: system_prompt.clone(),
201 source_index: None,
202 category: None,
203 });
204 }
205 }
206
207 Ok(tasks)
208 }
209
210 async fn build_augment_tasks(&self, config: &SynthConfig, prompt_builder: &PromptBuilder) -> Result<Vec<GenerationTask_>> {
211 let source_config = config.source.as_ref()
212 .ok_or_else(|| Error::Config("Augment task requires a source configuration".to_string()))?;
213
214 let records = self.load_source_data(source_config.clone()).await?;
215 let count_per = self.config.count_per_example.unwrap_or(1);
216 let system_prompt = Some(prompt_builder.system_prompt().to_string());
217
218 let mut tasks = Vec::with_capacity(records.len() * count_per);
219
220 for record in &records {
221 for _ in 0..count_per {
222 tasks.push(GenerationTask_ {
223 prompt: prompt_builder.build_for_record(record),
224 system_prompt: system_prompt.clone(),
225 source_index: Some(record.index),
226 category: None,
227 });
228 }
229 }
230
231 Ok(tasks)
232 }
233
234 async fn load_source_data(&self, source_config: SourceConfig) -> Result<Vec<Record>> {
235 tokio::task::spawn_blocking(move || {
237 match source_config {
238 SourceConfig::HuggingFace { dataset, subset, split, sample, columns } => {
239 let mut source = HuggingFaceSource::new(
240 dataset,
241 subset,
242 split,
243 columns,
244 )?;
245 source.load(sample)
246 }
247 SourceConfig::Local { path, format, sample } => {
248 let mut source = LocalSource::new(path, format)?;
249 source.load(sample)
250 }
251 }
252 })
253 .await
254 .map_err(|e| Error::Dataset(format!("Task join error: {}", e)))?
255 }
256
257 fn create_prompt_builder(&self) -> PromptBuilder {
258 let is_augment = matches!(&self.config.task, GenerationTask::Augment);
259
260 let template = self.config.template.clone().unwrap_or_else(|| {
261 match &self.config.task {
262 GenerationTask::Generate => default_template_for_generate(),
263 GenerationTask::Augment => {
264 let strategy = self.config.strategy.as_ref()
265 .map(|s| match s {
266 GenerationStrategy::Paraphrase => "paraphrase",
267 GenerationStrategy::StyleTransfer => "style_transfer",
268 GenerationStrategy::BackTranslation => "back_translation",
269 GenerationStrategy::Custom => "custom",
270 })
271 .unwrap_or("paraphrase");
272 default_template_for_augment(strategy)
273 }
274 }
275 });
276
277 PromptBuilder::new(template, self.config.system_prompt.clone(), is_augment)
278 }
279
280 async fn execute_task(&self, provider: &dyn LLMProvider, task: GenerationTask_) -> Result<GenerationResult> {
281 let request = GenerationRequest {
282 prompt: task.prompt,
283 system_prompt: task.system_prompt,
284 temperature: None,
285 max_tokens: None,
286 };
287
288 let response = provider.generate(request).await?;
289
290 Ok(GenerationResult {
291 content: response.content,
292 source_index: task.source_index,
293 category: task.category,
294 input_tokens: response.input_tokens,
295 output_tokens: response.output_tokens,
296 })
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::*;

    /// Minimal generate-task fixture: 10 items split over categories "A" and "B".
    fn test_config() -> SynthConfig {
        SynthConfig {
            name: "test".to_string(),
            source: None,
            provider: ProviderConfig::OpenAI {
                model: "gpt-4o-mini".to_string(),
                api_key: Some("test-key".to_string()),
                base_url: None,
                temperature: None,
                max_tokens: None,
            },
            generation: GenerationConfig {
                task: GenerationTask::Generate,
                count: 10,
                count_per_example: None,
                concurrency: 2,
                strategy: None,
                strategy_config: Default::default(),
                template: Some("Generate a {category} example".to_string()),
                system_prompt: None,
                categories: Some(vec!["A".to_string(), "B".to_string()]),
            },
            output: OutputConfig {
                format: OutputFormat::Jsonl,
                path: "./output.jsonl".into(),
                batch_size: 100,
            },
        }
    }

    #[test]
    fn test_build_generate_tasks() {
        let config = test_config();
        let engine = GenerationEngine::new(&config.provider, config.generation.clone()).unwrap();
        let builder = engine.create_prompt_builder();

        let tasks = engine.build_generate_tasks(&builder).unwrap();
        assert_eq!(tasks.len(), 10);

        // 10 tasks across 2 categories must split evenly: 5 apiece.
        let count_for =
            |name: &str| tasks.iter().filter(|t| t.category.as_deref() == Some(name)).count();
        assert_eq!(count_for("A"), 5);
        assert_eq!(count_for("B"), 5);
    }

    #[test]
    fn test_stats_tracking() {
        let stats = GenerationStats::default();

        // Two successes with known token counts, then one failure.
        for (input, output) in [(100, 50), (200, 100)] {
            stats.record_success(&GenerationResponse {
                content: "test".to_string(),
                input_tokens: input,
                output_tokens: output,
            });
        }
        stats.record_failure();

        let snapshot = stats.snapshot();
        assert_eq!(snapshot.completed, 2);
        assert_eq!(snapshot.failed, 1);
        assert_eq!(snapshot.total_input_tokens, 300);
        assert_eq!(snapshot.total_output_tokens, 150);
    }
}