1use crate::config::{GenerationConfig, GenerationStrategy, GenerationTask, ProviderConfig, SourceConfig, SynthConfig};
2use crate::datasets::{DataSource, HuggingFaceSource, LocalSource, Record};
3use crate::providers::{create_provider, GenerationRequest, GenerationResponse, LLMProvider};
4use crate::{Error, Result};
5
6use super::prompt::{default_template_for_augment, default_template_for_generate, PromptBuilder};
7
8use futures::stream::{self, StreamExt};
9use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
10use std::sync::Arc;
11use tokio::sync::Mutex;
12
/// Thread-safe progress counters shared across concurrent generation tasks.
///
/// All fields are atomics so tasks can update them without a lock; reads use
/// `snapshot()` for a consistent-enough plain-value view.
#[derive(Debug, Default)]
pub struct GenerationStats {
    /// Number of tasks that completed successfully.
    pub completed: AtomicUsize,
    /// Number of tasks that returned an error.
    pub failed: AtomicUsize,
    /// Running sum of prompt (input) tokens reported by the provider.
    pub total_input_tokens: AtomicU64,
    /// Running sum of completion (output) tokens reported by the provider.
    pub total_output_tokens: AtomicU64,
}
20
21impl GenerationStats {
22 pub fn record_success(&self, response: &GenerationResponse) {
23 self.completed.fetch_add(1, Ordering::Relaxed);
24 self.total_input_tokens.fetch_add(response.input_tokens as u64, Ordering::Relaxed);
25 self.total_output_tokens.fetch_add(response.output_tokens as u64, Ordering::Relaxed);
26 }
27
28 pub fn record_failure(&self) {
29 self.failed.fetch_add(1, Ordering::Relaxed);
30 }
31
32 pub fn snapshot(&self) -> StatsSnapshot {
33 StatsSnapshot {
34 completed: self.completed.load(Ordering::Relaxed),
35 failed: self.failed.load(Ordering::Relaxed),
36 total_input_tokens: self.total_input_tokens.load(Ordering::Relaxed),
37 total_output_tokens: self.total_output_tokens.load(Ordering::Relaxed),
38 }
39 }
40}
41
/// Point-in-time, plain-value copy of [`GenerationStats`].
#[derive(Debug, Clone)]
pub struct StatsSnapshot {
    /// Successful task count at snapshot time.
    pub completed: usize,
    /// Failed task count at snapshot time.
    pub failed: usize,
    /// Total prompt tokens consumed so far.
    pub total_input_tokens: u64,
    /// Total completion tokens produced so far.
    pub total_output_tokens: u64,
}
49
/// One completed generation: the model output plus provenance and token usage.
#[derive(Debug, Clone)]
pub struct GenerationResult {
    /// Raw text returned by the provider (may contain a fenced code block).
    pub content: String,
    /// Index of the source record this was derived from (augment tasks only).
    pub source_index: Option<usize>,
    /// Category the prompt targeted (generate tasks with categories only).
    pub category: Option<String>,
    /// Prompt tokens billed for this request.
    pub input_tokens: u32,
    /// Completion tokens billed for this request.
    pub output_tokens: u32,
}
58
59impl GenerationResult {
60 pub fn parse_json(&self) -> Result<serde_json::Value> {
62 let content = self.extract_json_content();
63 serde_json::from_str(&content).map_err(|e| Error::Json(e))
64 }
65
66 pub fn parse_json_as<T: serde::de::DeserializeOwned>(&self) -> Result<T> {
68 let content = self.extract_json_content();
69 serde_json::from_str(&content).map_err(|e| Error::Json(e))
70 }
71
72 fn extract_json_content(&self) -> String {
73 let content = self.content.trim();
74
75 if let Some(start) = content.find("```json") {
77 if let Some(end) = content[start + 7..].find("```") {
78 return content[start + 7..start + 7 + end].trim().to_string();
79 }
80 }
81
82 if let Some(start) = content.find("```") {
84 if let Some(end) = content[start + 3..].find("```") {
85 let inner = content[start + 3..start + 3 + end].trim();
86 if let Some(newline) = inner.find('\n') {
87 return inner[newline + 1..].trim().to_string();
88 }
89 return inner.to_string();
90 }
91 }
92
93 content.to_string()
94 }
95}
96
// Internal unit of work: one prompt to send to the provider.
// The trailing underscore avoids a name clash with `config::GenerationTask`
// imported at the top of this file.
struct GenerationTask_ {
    /// Fully rendered user prompt.
    prompt: String,
    /// Optional system prompt applied to the request.
    system_prompt: Option<String>,
    /// Index of the originating source record (augment tasks only).
    source_index: Option<usize>,
    /// Target category (generate tasks with categories only).
    category: Option<String>,
}
103
/// Drives concurrent synthetic-data generation against a single LLM provider,
/// tracking progress in shared [`GenerationStats`].
pub struct GenerationEngine {
    /// Provider used for every request (boxed for runtime provider selection).
    provider: Box<dyn LLMProvider>,
    /// Generation settings (task kind, count, concurrency, templates, ...).
    config: GenerationConfig,
    /// Shared counters; also handed out via `stats()`.
    stats: Arc<GenerationStats>,
}
109
110impl GenerationEngine {
111 pub fn new(provider_config: &ProviderConfig, generation_config: GenerationConfig) -> Result<Self> {
112 let provider = create_provider(provider_config)?;
113 Ok(Self {
114 provider,
115 config: generation_config,
116 stats: Arc::new(GenerationStats::default()),
117 })
118 }
119
120 pub fn stats(&self) -> Arc<GenerationStats> {
121 Arc::clone(&self.stats)
122 }
123
124 pub fn provider(&self) -> &dyn LLMProvider {
125 self.provider.as_ref()
126 }
127
128 pub async fn run(&self, config: &SynthConfig) -> Result<Vec<GenerationResult>> {
130 let tasks = self.build_tasks(config).await?;
131 let results = Arc::new(Mutex::new(Vec::with_capacity(tasks.len())));
132
133 stream::iter(tasks)
134 .map(|task| {
135 let provider = &self.provider;
136 let stats = Arc::clone(&self.stats);
137 let results = Arc::clone(&results);
138 async move {
139 match self.execute_task(provider.as_ref(), task).await {
140 Ok(result) => {
141 stats.record_success(&GenerationResponse {
142 content: result.content.clone(),
143 input_tokens: result.input_tokens,
144 output_tokens: result.output_tokens,
145 });
146 results.lock().await.push(result);
147 }
148 Err(e) => {
149 stats.record_failure();
150 tracing::warn!("Generation failed: {}", e);
151 }
152 }
153 }
154 })
155 .buffer_unordered(self.config.concurrency)
156 .collect::<Vec<_>>()
157 .await;
158
159 let results = Arc::try_unwrap(results)
160 .map_err(|_| Error::Provider("Failed to unwrap results".to_string()))?
161 .into_inner();
162
163 Ok(results)
164 }
165
166 pub async fn run_with_callback<F>(&self, config: &SynthConfig, on_result: F) -> Result<()>
168 where
169 F: FnMut(GenerationResult) + Send,
170 {
171 let tasks = self.build_tasks(config).await?;
172 let callback = Arc::new(Mutex::new(on_result));
173
174 stream::iter(tasks)
175 .map(|task| {
176 let provider = &self.provider;
177 let stats = Arc::clone(&self.stats);
178 let callback = Arc::clone(&callback);
179 async move {
180 match self.execute_task(provider.as_ref(), task).await {
181 Ok(result) => {
182 stats.record_success(&GenerationResponse {
183 content: result.content.clone(),
184 input_tokens: result.input_tokens,
185 output_tokens: result.output_tokens,
186 });
187 callback.lock().await(result);
188 }
189 Err(e) => {
190 stats.record_failure();
191 tracing::warn!("Generation failed: {}", e);
192 }
193 }
194 }
195 })
196 .buffer_unordered(self.config.concurrency)
197 .collect::<Vec<_>>()
198 .await;
199
200 Ok(())
201 }
202
203 async fn build_tasks(&self, config: &SynthConfig) -> Result<Vec<GenerationTask_>> {
204 let prompt_builder = self.create_prompt_builder();
205
206 match &config.generation.task {
207 GenerationTask::Generate => self.build_generate_tasks(&prompt_builder),
208 GenerationTask::Augment => self.build_augment_tasks(config, &prompt_builder).await,
209 }
210 }
211
212 fn build_generate_tasks(&self, prompt_builder: &PromptBuilder) -> Result<Vec<GenerationTask_>> {
213 let categories = self.config.categories.as_ref();
214 let count = self.config.count;
215 let system_prompt = Some(prompt_builder.system_prompt().to_string());
216
217 let mut tasks = Vec::with_capacity(count);
218
219 if let Some(cats) = categories {
220 let per_category = count / cats.len();
221 let remainder = count % cats.len();
222
223 for (cat_idx, category) in cats.iter().enumerate() {
224 let cat_count = per_category + if cat_idx < remainder { 1 } else { 0 };
225 for i in 0..cat_count {
226 tasks.push(GenerationTask_ {
227 prompt: prompt_builder.build_for_category(category, i),
228 system_prompt: system_prompt.clone(),
229 source_index: None,
230 category: Some(category.clone()),
231 });
232 }
233 }
234 } else {
235 for i in 0..count {
236 tasks.push(GenerationTask_ {
237 prompt: prompt_builder.build_for_category("default", i),
238 system_prompt: system_prompt.clone(),
239 source_index: None,
240 category: None,
241 });
242 }
243 }
244
245 Ok(tasks)
246 }
247
248 async fn build_augment_tasks(&self, config: &SynthConfig, prompt_builder: &PromptBuilder) -> Result<Vec<GenerationTask_>> {
249 let source_config = config.source.as_ref()
250 .ok_or_else(|| Error::Config("Augment task requires a source configuration".to_string()))?;
251
252 let records = self.load_source_data(source_config.clone()).await?;
253 let count_per = self.config.count_per_example.unwrap_or(1);
254 let system_prompt = Some(prompt_builder.system_prompt().to_string());
255
256 let mut tasks = Vec::with_capacity(records.len() * count_per);
257
258 for record in &records {
259 for _ in 0..count_per {
260 tasks.push(GenerationTask_ {
261 prompt: prompt_builder.build_for_record(record),
262 system_prompt: system_prompt.clone(),
263 source_index: Some(record.index),
264 category: None,
265 });
266 }
267 }
268
269 Ok(tasks)
270 }
271
272 async fn load_source_data(&self, source_config: SourceConfig) -> Result<Vec<Record>> {
273 tokio::task::spawn_blocking(move || {
275 match source_config {
276 SourceConfig::HuggingFace { dataset, subset, split, sample, columns } => {
277 let mut source = HuggingFaceSource::new(
278 dataset,
279 subset,
280 split,
281 columns,
282 )?;
283 source.load(sample)
284 }
285 SourceConfig::Local { path, format, sample } => {
286 let mut source = LocalSource::new(path, format)?;
287 source.load(sample)
288 }
289 }
290 })
291 .await
292 .map_err(|e| Error::Dataset(format!("Task join error: {}", e)))?
293 }
294
295 fn create_prompt_builder(&self) -> PromptBuilder {
296 let is_augment = matches!(&self.config.task, GenerationTask::Augment);
297
298 let template = self.config.template.clone().unwrap_or_else(|| {
299 match &self.config.task {
300 GenerationTask::Generate => default_template_for_generate(),
301 GenerationTask::Augment => {
302 let strategy = self.config.strategy.as_ref()
303 .map(|s| match s {
304 GenerationStrategy::Paraphrase => "paraphrase",
305 GenerationStrategy::StyleTransfer => "style_transfer",
306 GenerationStrategy::BackTranslation => "back_translation",
307 GenerationStrategy::Custom => "custom",
308 })
309 .unwrap_or("paraphrase");
310 default_template_for_augment(strategy)
311 }
312 }
313 });
314
315 PromptBuilder::new(template, self.config.system_prompt.clone(), is_augment)
316 }
317
318 async fn execute_task(&self, provider: &dyn LLMProvider, task: GenerationTask_) -> Result<GenerationResult> {
319 let request = GenerationRequest {
320 prompt: task.prompt,
321 system_prompt: task.system_prompt,
322 temperature: None,
323 max_tokens: None,
324 };
325
326 let response = provider.generate(request).await?;
327
328 Ok(GenerationResult {
329 content: response.content,
330 source_index: task.source_index,
331 category: task.category,
332 input_tokens: response.input_tokens,
333 output_tokens: response.output_tokens,
334 })
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::*;

    /// Minimal config: generate ten items split across categories "A" and "B".
    fn test_config() -> SynthConfig {
        SynthConfig {
            name: "test".to_string(),
            source: None,
            provider: ProviderConfig::OpenAI {
                model: "gpt-4o-mini".to_string(),
                api_key: Some("test-key".to_string()),
                base_url: None,
                temperature: None,
                max_tokens: None,
            },
            generation: GenerationConfig {
                task: GenerationTask::Generate,
                count: 10,
                count_per_example: None,
                concurrency: 2,
                strategy: None,
                strategy_config: Default::default(),
                template: Some("Generate a {category} example".to_string()),
                system_prompt: None,
                categories: Some(vec!["A".to_string(), "B".to_string()]),
            },
            output: OutputConfig {
                format: OutputFormat::Jsonl,
                path: "./output.jsonl".into(),
                batch_size: 100,
            },
            validation: None,
            hub: None,
        }
    }

    #[test]
    fn test_build_generate_tasks() {
        let config = test_config();
        let engine = GenerationEngine::new(&config.provider, config.generation.clone()).unwrap();
        let builder = engine.create_prompt_builder();

        let tasks = engine.build_generate_tasks(&builder).unwrap();
        assert_eq!(tasks.len(), 10);

        // Ten tasks over two categories should split exactly evenly.
        let count_for =
            |name: &str| tasks.iter().filter(|t| t.category.as_deref() == Some(name)).count();
        assert_eq!(count_for("A"), 5);
        assert_eq!(count_for("B"), 5);
    }

    #[test]
    fn test_stats_tracking() {
        let stats = GenerationStats::default();

        // Two successes with known token usage, then one failure.
        for (input_tokens, output_tokens) in [(100u32, 50u32), (200, 100)] {
            stats.record_success(&GenerationResponse {
                content: "test".to_string(),
                input_tokens,
                output_tokens,
            });
        }
        stats.record_failure();

        let snap = stats.snapshot();
        assert_eq!(snap.completed, 2);
        assert_eq!(snap.failed, 1);
        assert_eq!(snap.total_input_tokens, 300);
        assert_eq!(snap.total_output_tokens, 150);
    }
}