1use rand::Rng;
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
10pub enum ProblemType {
11 Math,
12 Code,
13 Logic,
14 General,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct TrainingExample {
20 pub input: String,
21 pub target: String,
22 pub reasoning_chain: Option<Vec<String>>,
23 pub problem_type: ProblemType,
24}
25
26impl TrainingExample {
27 pub fn new(input: String, target: String, problem_type: ProblemType) -> Self {
29 Self {
30 input,
31 target,
32 reasoning_chain: None,
33 problem_type,
34 }
35 }
36
37 pub fn with_reasoning(
39 input: String,
40 target: String,
41 reasoning_chain: Vec<String>,
42 problem_type: ProblemType,
43 ) -> Self {
44 Self {
45 input,
46 target,
47 reasoning_chain: Some(reasoning_chain),
48 problem_type,
49 }
50 }
51}
52
53#[derive(Debug, Clone)]
55pub struct TrainingBatch {
56 pub examples: Vec<TrainingExample>,
57 pub batch_size: usize,
58}
59
60impl TrainingBatch {
61 pub fn new(examples: Vec<TrainingExample>) -> Self {
63 let batch_size = examples.len();
64 Self {
65 examples,
66 batch_size,
67 }
68 }
69
70 pub fn split_by_type(&self) -> std::collections::HashMap<ProblemType, Vec<&TrainingExample>> {
72 let mut map = std::collections::HashMap::new();
73
74 for example in &self.examples {
75 map.entry(example.problem_type.clone())
76 .or_insert_with(Vec::new)
77 .push(example);
78 }
79
80 map
81 }
82}
83
84pub struct SyntheticDataGenerator {
86 rng: rand::rngs::ThreadRng,
87}
88
89impl SyntheticDataGenerator {
90 pub fn new() -> Self {
92 Self { rng: rand::rng() }
93 }
94
95 pub fn generate_math_problems(&mut self, count: usize) -> Vec<TrainingExample> {
97 let mut examples = Vec::new();
98
99 for _ in 0..count {
100 let problem_type = self.rng.random_range(0..4);
101
102 let example = match problem_type {
103 0 => self.generate_addition_problem(),
104 1 => self.generate_subtraction_problem(),
105 2 => self.generate_multiplication_problem(),
106 3 => self.generate_simple_equation(),
107 _ => unreachable!(),
108 };
109
110 examples.push(example);
111 }
112
113 examples
114 }
115
116 fn generate_addition_problem(&mut self) -> TrainingExample {
118 let a = self.rng.random_range(1..100);
119 let b = self.rng.random_range(1..100);
120 let result = a + b;
121
122 let input = format!("What is {} + {}?", a, b);
123 let reasoning = vec![
124 format!("I need to add {} and {}", a, b),
125 format!("{} + {} = {}", a, b, result),
126 ];
127 let target = format!("{}", result);
128
129 TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Math)
130 }
131
132 fn generate_subtraction_problem(&mut self) -> TrainingExample {
134 let a = self.rng.random_range(50..200);
135 let b = self.rng.random_range(1..a);
136 let result = a - b;
137
138 let input = format!("What is {} - {}?", a, b);
139 let reasoning = vec![
140 format!("I need to subtract {} from {}", b, a),
141 format!("{} - {} = {}", a, b, result),
142 ];
143 let target = format!("{}", result);
144
145 TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Math)
146 }
147
148 fn generate_multiplication_problem(&mut self) -> TrainingExample {
150 let a = self.rng.random_range(2..20);
151 let b = self.rng.random_range(2..20);
152 let result = a * b;
153
154 let input = format!("What is {} × {}?", a, b);
155 let reasoning = vec![
156 format!("I need to multiply {} by {}", a, b),
157 format!("{} × {} = {}", a, b, result),
158 ];
159 let target = format!("{}", result);
160
161 TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Math)
162 }
163
164 fn generate_simple_equation(&mut self) -> TrainingExample {
166 let x = self.rng.random_range(1..20);
167 let b = self.rng.random_range(1..50);
168 let result = x + b;
169
170 let input = format!("Solve for x: x + {} = {}", b, result);
171 let reasoning = vec![
172 format!("I need to solve x + {} = {}", b, result),
173 format!("Subtracting {} from both sides: x = {} - {}", b, result, b),
174 format!("Therefore: x = {}", x),
175 ];
176 let target = format!("x = {}", x);
177
178 TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Math)
179 }
180
181 pub fn generate_code_examples(&mut self, count: usize) -> Vec<TrainingExample> {
183 let mut examples = Vec::new();
184
185 for _ in 0..count {
186 let example_type = self.rng.random_range(0..3);
187
188 let example = match example_type {
189 0 => self.generate_loop_explanation(),
190 1 => self.generate_function_explanation(),
191 2 => self.generate_conditional_explanation(),
192 _ => unreachable!(),
193 };
194
195 examples.push(example);
196 }
197
198 examples
199 }
200
201 fn generate_loop_explanation(&mut self) -> TrainingExample {
203 let n = self.rng.random_range(3..10);
204 let code = format!("for i in range({}):\n print(i)", n);
205
206 let input = format!("Explain what this code does:\n{}", code);
207 let reasoning = vec![
208 "This is a for loop in Python".to_string(),
209 format!("It iterates from 0 to {} (exclusive)", n),
210 "In each iteration, it prints the current value of i".to_string(),
211 format!("So it will print numbers 0, 1, 2, ..., {}", n - 1),
212 ];
213 let target = format!(
214 "This code prints numbers from 0 to {} using a for loop",
215 n - 1
216 );
217
218 TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Code)
219 }
220
221 fn generate_function_explanation(&mut self) -> TrainingExample {
223 let code = "def add_numbers(a, b):\n return a + b";
224
225 let input = format!("Explain what this function does:\n{}", code);
226 let reasoning = vec![
227 "This defines a function called 'add_numbers'".to_string(),
228 "It takes two parameters: 'a' and 'b'".to_string(),
229 "The function returns the sum of a and b".to_string(),
230 "This is a simple addition function".to_string(),
231 ];
232 let target = "This function takes two numbers and returns their sum".to_string();
233
234 TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Code)
235 }
236
237 fn generate_conditional_explanation(&mut self) -> TrainingExample {
239 let threshold = self.rng.random_range(10..100);
240 let code = format!(
241 "if x > {}:\n print('Large')\nelse:\n print('Small')",
242 threshold
243 );
244
245 let input = format!("Explain what this code does:\n{}", code);
246 let reasoning = vec![
247 "This is an if-else conditional statement".to_string(),
248 format!("It checks if variable x is greater than {}", threshold),
249 "If x is greater, it prints 'Large'".to_string(),
250 "Otherwise, it prints 'Small'".to_string(),
251 ];
252 let target = format!(
253 "This code prints 'Large' if x > {}, otherwise prints 'Small'",
254 threshold
255 );
256
257 TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Code)
258 }
259
260 pub fn generate_logic_problems(&mut self, count: usize) -> Vec<TrainingExample> {
262 let mut examples = Vec::new();
263
264 for _ in 0..count {
265 let problem_type = self.rng.random_range(0..3);
266
267 let example = match problem_type {
268 0 => self.generate_syllogism(),
269 1 => self.generate_pattern_recognition(),
270 2 => self.generate_simple_deduction(),
271 _ => unreachable!(),
272 };
273
274 examples.push(example);
275 }
276
277 examples
278 }
279
280 fn generate_syllogism(&mut self) -> TrainingExample {
282 let animals = ["cats", "dogs", "birds", "fish"];
283 let properties = ["mammals", "vertebrates", "animals", "living things"];
284
285 let animal = animals[self.rng.random_range(0..animals.len())];
286 let property = properties[self.rng.random_range(0..properties.len())];
287
288 let input = format!(
289 "All {} are {}. Fluffy is a cat. Is Fluffy a {}?",
290 animal, property, property
291 );
292 let reasoning = vec![
293 format!("Given: All {} are {}", animal, property),
294 "Given: Fluffy is a cat".to_string(),
295 format!("Since cats are {}, and Fluffy is a cat", property),
296 format!("Therefore: Fluffy is a {}", property),
297 ];
298 let target = format!("Yes, Fluffy is a {}", property);
299
300 TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Logic)
301 }
302
303 fn generate_pattern_recognition(&mut self) -> TrainingExample {
305 let start = self.rng.random_range(1..10);
306 let step = self.rng.random_range(2..5);
307 let sequence: Vec<i32> = (0..4).map(|i| start + i * step).collect();
308 let next = start + 4 * step;
309
310 let input = format!(
311 "What comes next in this sequence: {}, {}, {}, {}?",
312 sequence[0], sequence[1], sequence[2], sequence[3]
313 );
314 let reasoning = vec![
315 format!(
316 "Looking at the sequence: {}, {}, {}, {}",
317 sequence[0], sequence[1], sequence[2], sequence[3]
318 ),
319 format!("The difference between consecutive terms is {}", step),
320 format!("This is an arithmetic sequence with step {}", step),
321 format!(
322 "The next term would be {} + {} = {}",
323 sequence[3], step, next
324 ),
325 ];
326 let target = format!("{}", next);
327
328 TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Logic)
329 }
330
331 fn generate_simple_deduction(&mut self) -> TrainingExample {
333 let colors = ["red", "blue", "green", "yellow"];
334 let objects = ["car", "house", "ball", "book"];
335
336 let color = colors[self.rng.random_range(0..colors.len())];
337 let object = objects[self.rng.random_range(0..objects.len())];
338
339 let input = format!(
340 "If all {} things are expensive, and this {} is {}, is it expensive?",
341 color, object, color
342 );
343 let reasoning = vec![
344 format!("Given: All {} things are expensive", color),
345 format!("Given: This {} is {}", object, color),
346 format!(
347 "Since the {} is {}, and all {} things are expensive",
348 object, color, color
349 ),
350 format!("Therefore: This {} is expensive", object),
351 ];
352 let target = "Yes, it is expensive".to_string();
353
354 TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Logic)
355 }
356
357 pub fn generate_mixed_dataset(&mut self, total_count: usize) -> Vec<TrainingExample> {
359 let math_count = total_count / 3;
360 let code_count = total_count / 3;
361 let logic_count = total_count - math_count - code_count;
362
363 let mut examples = Vec::new();
364 examples.extend(self.generate_math_problems(math_count));
365 examples.extend(self.generate_code_examples(code_count));
366 examples.extend(self.generate_logic_problems(logic_count));
367
368 use rand::seq::SliceRandom;
370 examples.shuffle(&mut self.rng);
371
372 examples
373 }
374}
375
376impl Default for SyntheticDataGenerator {
377 fn default() -> Self {
378 Self::new()
379 }
380}
381
382pub struct DataLoader {
384 examples: Vec<TrainingExample>,
385 batch_size: usize,
386 current_index: usize,
387 shuffle: bool,
388 rng: rand::rngs::ThreadRng,
389}
390
391impl DataLoader {
392 pub fn new(examples: Vec<TrainingExample>, batch_size: usize, shuffle: bool) -> Self {
394 Self {
395 examples,
396 batch_size,
397 current_index: 0,
398 shuffle,
399 rng: rand::rng(),
400 }
401 }
402
403 pub fn reset(&mut self) {
405 self.current_index = 0;
406 if self.shuffle {
407 use rand::seq::SliceRandom;
408 self.examples.shuffle(&mut self.rng);
409 }
410 }
411
412 pub fn next_batch(&mut self) -> Option<TrainingBatch> {
414 if self.current_index >= self.examples.len() {
415 return None;
416 }
417
418 let end_index = (self.current_index + self.batch_size).min(self.examples.len());
419 let batch_examples = self.examples[self.current_index..end_index].to_vec();
420 self.current_index = end_index;
421
422 Some(TrainingBatch::new(batch_examples))
423 }
424
425 pub fn has_next(&self) -> bool {
427 self.current_index < self.examples.len()
428 }
429
430 pub fn len(&self) -> usize {
432 self.examples.len()
433 }
434
435 pub fn is_empty(&self) -> bool {
437 self.examples.is_empty()
438 }
439
440 pub fn batches_per_epoch(&self) -> usize {
442 (self.examples.len() + self.batch_size - 1) / self.batch_size
443 }
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a math example without a reasoning chain.
    fn math_example(input: &str, target: &str) -> TrainingExample {
        TrainingExample::new(input.to_owned(), target.to_owned(), ProblemType::Math)
    }

    /// Asserts that every example has the expected type and a reasoning chain.
    fn assert_all_reasoned(examples: &[TrainingExample], expected: ProblemType) {
        for example in examples {
            assert_eq!(example.problem_type, expected);
            assert!(example.reasoning_chain.is_some());
        }
    }

    #[test]
    fn test_training_example_creation() {
        let example = math_example("2 + 2 = ?", "4");

        assert_eq!(example.input, "2 + 2 = ?");
        assert_eq!(example.target, "4");
        assert_eq!(example.problem_type, ProblemType::Math);
        assert_eq!(example.reasoning_chain, None);
    }

    #[test]
    fn test_training_example_with_reasoning() {
        let steps: Vec<String> = ["I need to add 2 and 2", "2 + 2 = 4"]
            .iter()
            .map(|step| step.to_string())
            .collect();

        let example = TrainingExample::with_reasoning(
            String::from("2 + 2 = ?"),
            String::from("4"),
            steps.clone(),
            ProblemType::Math,
        );

        assert_eq!(example.reasoning_chain, Some(steps));
    }

    #[test]
    fn test_training_batch() {
        let batch = TrainingBatch::new(vec![
            math_example("2 + 2", "4"),
            math_example("3 * 3", "9"),
        ]);
        assert_eq!(batch.batch_size, 2);

        // Both examples are math, so exactly one group is expected.
        let groups = batch.split_by_type();
        assert_eq!(groups.len(), 1);
        assert!(groups.contains_key(&ProblemType::Math));
    }

    #[test]
    fn test_synthetic_data_generator() {
        let mut generator = SyntheticDataGenerator::new();

        let math = generator.generate_math_problems(5);
        assert_eq!(math.len(), 5);
        assert_all_reasoned(&math, ProblemType::Math);

        let code = generator.generate_code_examples(3);
        assert_eq!(code.len(), 3);
        assert_all_reasoned(&code, ProblemType::Code);

        let logic = generator.generate_logic_problems(4);
        assert_eq!(logic.len(), 4);
        assert_all_reasoned(&logic, ProblemType::Logic);
    }

    #[test]
    fn test_mixed_dataset_generation() {
        let dataset = SyntheticDataGenerator::new().generate_mixed_dataset(12);
        assert_eq!(dataset.len(), 12);

        let contains = |wanted: ProblemType| {
            dataset
                .iter()
                .any(|example| example.problem_type == wanted)
        };

        assert!(contains(ProblemType::Math));
        assert!(contains(ProblemType::Code));
        assert!(contains(ProblemType::Logic));
    }

    #[test]
    fn test_data_loader() {
        let examples: Vec<TrainingExample> =
            [("1", "a"), ("2", "b"), ("3", "c"), ("4", "d"), ("5", "e")]
                .iter()
                .map(|(input, target)| math_example(input, target))
                .collect();

        let mut loader = DataLoader::new(examples, 2, false);

        assert_eq!(loader.len(), 5);
        assert_eq!(loader.batches_per_epoch(), 3);
        assert!(loader.has_next());

        // Drain the loader: five examples in batches of two give 2, 2, 1.
        let sizes: Vec<usize> = std::iter::from_fn(|| loader.next_batch())
            .map(|batch| batch.batch_size)
            .collect();
        assert_eq!(sizes, [2, 2, 1]);

        assert!(!loader.has_next());
        assert!(loader.next_batch().is_none());
    }

    #[test]
    fn test_data_loader_reset() {
        let mut loader = DataLoader::new(
            vec![math_example("1", "a"), math_example("2", "b")],
            1,
            false,
        );

        // Consume both single-example batches.
        assert!(loader.next_batch().is_some());
        assert!(loader.next_batch().is_some());
        assert!(!loader.has_next());

        loader.reset();
        assert!(loader.has_next());
        assert!(loader.next_batch().is_some());
        assert!(loader.has_next());
    }

    #[test]
    fn test_empty_data_loader() {
        let loader = DataLoader::new(Vec::new(), 5, false);

        assert!(loader.is_empty());
        assert_eq!(loader.len(), 0);
        assert_eq!(loader.batches_per_epoch(), 0);
    }
}