ds_r1_rs/training/
data.rs

1//! # Training Data
2//!
3//! Data structures and utilities for training data management.
4
5use rand::Rng;
6use serde::{Deserialize, Serialize};
7
8/// Type of problem for training
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
10pub enum ProblemType {
11    Math,
12    Code,
13    Logic,
14    General,
15}
16
17/// Individual training example
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct TrainingExample {
20    pub input: String,
21    pub target: String,
22    pub reasoning_chain: Option<Vec<String>>,
23    pub problem_type: ProblemType,
24}
25
26impl TrainingExample {
27    /// Create a new training example
28    pub fn new(input: String, target: String, problem_type: ProblemType) -> Self {
29        Self {
30            input,
31            target,
32            reasoning_chain: None,
33            problem_type,
34        }
35    }
36
37    /// Create a training example with reasoning chain
38    pub fn with_reasoning(
39        input: String,
40        target: String,
41        reasoning_chain: Vec<String>,
42        problem_type: ProblemType,
43    ) -> Self {
44        Self {
45            input,
46            target,
47            reasoning_chain: Some(reasoning_chain),
48            problem_type,
49        }
50    }
51}
52
53/// Batch of training examples
54#[derive(Debug, Clone)]
55pub struct TrainingBatch {
56    pub examples: Vec<TrainingExample>,
57    pub batch_size: usize,
58}
59
60impl TrainingBatch {
61    /// Create a new training batch
62    pub fn new(examples: Vec<TrainingExample>) -> Self {
63        let batch_size = examples.len();
64        Self {
65            examples,
66            batch_size,
67        }
68    }
69
70    /// Split examples by problem type
71    pub fn split_by_type(&self) -> std::collections::HashMap<ProblemType, Vec<&TrainingExample>> {
72        let mut map = std::collections::HashMap::new();
73
74        for example in &self.examples {
75            map.entry(example.problem_type.clone())
76                .or_insert_with(Vec::new)
77                .push(example);
78        }
79
80        map
81    }
82}
83
84/// Synthetic dataset generator for training
85pub struct SyntheticDataGenerator {
86    rng: rand::rngs::ThreadRng,
87}
88
89impl SyntheticDataGenerator {
90    /// Create a new synthetic data generator
91    pub fn new() -> Self {
92        Self { rng: rand::rng() }
93    }
94
95    /// Generate mathematical reasoning problems
96    pub fn generate_math_problems(&mut self, count: usize) -> Vec<TrainingExample> {
97        let mut examples = Vec::new();
98
99        for _ in 0..count {
100            let problem_type = self.rng.random_range(0..4);
101
102            let example = match problem_type {
103                0 => self.generate_addition_problem(),
104                1 => self.generate_subtraction_problem(),
105                2 => self.generate_multiplication_problem(),
106                3 => self.generate_simple_equation(),
107                _ => unreachable!(),
108            };
109
110            examples.push(example);
111        }
112
113        examples
114    }
115
116    /// Generate addition problems
117    fn generate_addition_problem(&mut self) -> TrainingExample {
118        let a = self.rng.random_range(1..100);
119        let b = self.rng.random_range(1..100);
120        let result = a + b;
121
122        let input = format!("What is {} + {}?", a, b);
123        let reasoning = vec![
124            format!("I need to add {} and {}", a, b),
125            format!("{} + {} = {}", a, b, result),
126        ];
127        let target = format!("{}", result);
128
129        TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Math)
130    }
131
132    /// Generate subtraction problems
133    fn generate_subtraction_problem(&mut self) -> TrainingExample {
134        let a = self.rng.random_range(50..200);
135        let b = self.rng.random_range(1..a);
136        let result = a - b;
137
138        let input = format!("What is {} - {}?", a, b);
139        let reasoning = vec![
140            format!("I need to subtract {} from {}", b, a),
141            format!("{} - {} = {}", a, b, result),
142        ];
143        let target = format!("{}", result);
144
145        TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Math)
146    }
147
148    /// Generate multiplication problems
149    fn generate_multiplication_problem(&mut self) -> TrainingExample {
150        let a = self.rng.random_range(2..20);
151        let b = self.rng.random_range(2..20);
152        let result = a * b;
153
154        let input = format!("What is {} × {}?", a, b);
155        let reasoning = vec![
156            format!("I need to multiply {} by {}", a, b),
157            format!("{} × {} = {}", a, b, result),
158        ];
159        let target = format!("{}", result);
160
161        TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Math)
162    }
163
164    /// Generate simple linear equations
165    fn generate_simple_equation(&mut self) -> TrainingExample {
166        let x = self.rng.random_range(1..20);
167        let b = self.rng.random_range(1..50);
168        let result = x + b;
169
170        let input = format!("Solve for x: x + {} = {}", b, result);
171        let reasoning = vec![
172            format!("I need to solve x + {} = {}", b, result),
173            format!("Subtracting {} from both sides: x = {} - {}", b, result, b),
174            format!("Therefore: x = {}", x),
175        ];
176        let target = format!("x = {}", x);
177
178        TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Math)
179    }
180
181    /// Generate code explanation examples
182    pub fn generate_code_examples(&mut self, count: usize) -> Vec<TrainingExample> {
183        let mut examples = Vec::new();
184
185        for _ in 0..count {
186            let example_type = self.rng.random_range(0..3);
187
188            let example = match example_type {
189                0 => self.generate_loop_explanation(),
190                1 => self.generate_function_explanation(),
191                2 => self.generate_conditional_explanation(),
192                _ => unreachable!(),
193            };
194
195            examples.push(example);
196        }
197
198        examples
199    }
200
201    /// Generate loop explanation
202    fn generate_loop_explanation(&mut self) -> TrainingExample {
203        let n = self.rng.random_range(3..10);
204        let code = format!("for i in range({}):\n    print(i)", n);
205
206        let input = format!("Explain what this code does:\n{}", code);
207        let reasoning = vec![
208            "This is a for loop in Python".to_string(),
209            format!("It iterates from 0 to {} (exclusive)", n),
210            "In each iteration, it prints the current value of i".to_string(),
211            format!("So it will print numbers 0, 1, 2, ..., {}", n - 1),
212        ];
213        let target = format!(
214            "This code prints numbers from 0 to {} using a for loop",
215            n - 1
216        );
217
218        TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Code)
219    }
220
221    /// Generate function explanation
222    fn generate_function_explanation(&mut self) -> TrainingExample {
223        let code = "def add_numbers(a, b):\n    return a + b";
224
225        let input = format!("Explain what this function does:\n{}", code);
226        let reasoning = vec![
227            "This defines a function called 'add_numbers'".to_string(),
228            "It takes two parameters: 'a' and 'b'".to_string(),
229            "The function returns the sum of a and b".to_string(),
230            "This is a simple addition function".to_string(),
231        ];
232        let target = "This function takes two numbers and returns their sum".to_string();
233
234        TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Code)
235    }
236
237    /// Generate conditional explanation
238    fn generate_conditional_explanation(&mut self) -> TrainingExample {
239        let threshold = self.rng.random_range(10..100);
240        let code = format!(
241            "if x > {}:\n    print('Large')\nelse:\n    print('Small')",
242            threshold
243        );
244
245        let input = format!("Explain what this code does:\n{}", code);
246        let reasoning = vec![
247            "This is an if-else conditional statement".to_string(),
248            format!("It checks if variable x is greater than {}", threshold),
249            "If x is greater, it prints 'Large'".to_string(),
250            "Otherwise, it prints 'Small'".to_string(),
251        ];
252        let target = format!(
253            "This code prints 'Large' if x > {}, otherwise prints 'Small'",
254            threshold
255        );
256
257        TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Code)
258    }
259
260    /// Generate logical reasoning tasks
261    pub fn generate_logic_problems(&mut self, count: usize) -> Vec<TrainingExample> {
262        let mut examples = Vec::new();
263
264        for _ in 0..count {
265            let problem_type = self.rng.random_range(0..3);
266
267            let example = match problem_type {
268                0 => self.generate_syllogism(),
269                1 => self.generate_pattern_recognition(),
270                2 => self.generate_simple_deduction(),
271                _ => unreachable!(),
272            };
273
274            examples.push(example);
275        }
276
277        examples
278    }
279
280    /// Generate syllogism problems
281    fn generate_syllogism(&mut self) -> TrainingExample {
282        let animals = ["cats", "dogs", "birds", "fish"];
283        let properties = ["mammals", "vertebrates", "animals", "living things"];
284
285        let animal = animals[self.rng.random_range(0..animals.len())];
286        let property = properties[self.rng.random_range(0..properties.len())];
287
288        let input = format!(
289            "All {} are {}. Fluffy is a cat. Is Fluffy a {}?",
290            animal, property, property
291        );
292        let reasoning = vec![
293            format!("Given: All {} are {}", animal, property),
294            "Given: Fluffy is a cat".to_string(),
295            format!("Since cats are {}, and Fluffy is a cat", property),
296            format!("Therefore: Fluffy is a {}", property),
297        ];
298        let target = format!("Yes, Fluffy is a {}", property);
299
300        TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Logic)
301    }
302
303    /// Generate pattern recognition
304    fn generate_pattern_recognition(&mut self) -> TrainingExample {
305        let start = self.rng.random_range(1..10);
306        let step = self.rng.random_range(2..5);
307        let sequence: Vec<i32> = (0..4).map(|i| start + i * step).collect();
308        let next = start + 4 * step;
309
310        let input = format!(
311            "What comes next in this sequence: {}, {}, {}, {}?",
312            sequence[0], sequence[1], sequence[2], sequence[3]
313        );
314        let reasoning = vec![
315            format!(
316                "Looking at the sequence: {}, {}, {}, {}",
317                sequence[0], sequence[1], sequence[2], sequence[3]
318            ),
319            format!("The difference between consecutive terms is {}", step),
320            format!("This is an arithmetic sequence with step {}", step),
321            format!(
322                "The next term would be {} + {} = {}",
323                sequence[3], step, next
324            ),
325        ];
326        let target = format!("{}", next);
327
328        TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Logic)
329    }
330
331    /// Generate simple deduction
332    fn generate_simple_deduction(&mut self) -> TrainingExample {
333        let colors = ["red", "blue", "green", "yellow"];
334        let objects = ["car", "house", "ball", "book"];
335
336        let color = colors[self.rng.random_range(0..colors.len())];
337        let object = objects[self.rng.random_range(0..objects.len())];
338
339        let input = format!(
340            "If all {} things are expensive, and this {} is {}, is it expensive?",
341            color, object, color
342        );
343        let reasoning = vec![
344            format!("Given: All {} things are expensive", color),
345            format!("Given: This {} is {}", object, color),
346            format!(
347                "Since the {} is {}, and all {} things are expensive",
348                object, color, color
349            ),
350            format!("Therefore: This {} is expensive", object),
351        ];
352        let target = "Yes, it is expensive".to_string();
353
354        TrainingExample::with_reasoning(input, target, reasoning, ProblemType::Logic)
355    }
356
357    /// Generate a mixed dataset with all problem types
358    pub fn generate_mixed_dataset(&mut self, total_count: usize) -> Vec<TrainingExample> {
359        let math_count = total_count / 3;
360        let code_count = total_count / 3;
361        let logic_count = total_count - math_count - code_count;
362
363        let mut examples = Vec::new();
364        examples.extend(self.generate_math_problems(math_count));
365        examples.extend(self.generate_code_examples(code_count));
366        examples.extend(self.generate_logic_problems(logic_count));
367
368        // Shuffle the examples
369        use rand::seq::SliceRandom;
370        examples.shuffle(&mut self.rng);
371
372        examples
373    }
374}
375
376impl Default for SyntheticDataGenerator {
377    fn default() -> Self {
378        Self::new()
379    }
380}
381
382/// Data loader for batching training examples
383pub struct DataLoader {
384    examples: Vec<TrainingExample>,
385    batch_size: usize,
386    current_index: usize,
387    shuffle: bool,
388    rng: rand::rngs::ThreadRng,
389}
390
391impl DataLoader {
392    /// Create a new data loader
393    pub fn new(examples: Vec<TrainingExample>, batch_size: usize, shuffle: bool) -> Self {
394        Self {
395            examples,
396            batch_size,
397            current_index: 0,
398            shuffle,
399            rng: rand::rng(),
400        }
401    }
402
403    /// Reset the data loader to the beginning
404    pub fn reset(&mut self) {
405        self.current_index = 0;
406        if self.shuffle {
407            use rand::seq::SliceRandom;
408            self.examples.shuffle(&mut self.rng);
409        }
410    }
411
412    /// Get the next batch
413    pub fn next_batch(&mut self) -> Option<TrainingBatch> {
414        if self.current_index >= self.examples.len() {
415            return None;
416        }
417
418        let end_index = (self.current_index + self.batch_size).min(self.examples.len());
419        let batch_examples = self.examples[self.current_index..end_index].to_vec();
420        self.current_index = end_index;
421
422        Some(TrainingBatch::new(batch_examples))
423    }
424
425    /// Check if there are more batches
426    pub fn has_next(&self) -> bool {
427        self.current_index < self.examples.len()
428    }
429
430    /// Get total number of examples
431    pub fn len(&self) -> usize {
432        self.examples.len()
433    }
434
435    /// Check if the data loader is empty
436    pub fn is_empty(&self) -> bool {
437        self.examples.is_empty()
438    }
439
440    /// Get number of batches per epoch
441    pub fn batches_per_epoch(&self) -> usize {
442        (self.examples.len() + self.batch_size - 1) / self.batch_size
443    }
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a math example without a reasoning chain.
    fn math(input: &str, target: &str) -> TrainingExample {
        TrainingExample::new(input.to_string(), target.to_string(), ProblemType::Math)
    }

    /// Asserts that `examples` has the expected length and that every entry
    /// has the expected type and carries a reasoning chain.
    fn assert_generated(examples: &[TrainingExample], expected_len: usize, expected: ProblemType) {
        assert_eq!(examples.len(), expected_len);
        for item in examples {
            assert_eq!(item.problem_type, expected);
            assert!(item.reasoning_chain.is_some());
        }
    }

    #[test]
    fn test_training_example_creation() {
        let example = math("2 + 2 = ?", "4");

        assert_eq!(example.input, "2 + 2 = ?");
        assert_eq!(example.target, "4");
        assert_eq!(example.problem_type, ProblemType::Math);
        assert!(example.reasoning_chain.is_none());
    }

    #[test]
    fn test_training_example_with_reasoning() {
        let steps = vec!["I need to add 2 and 2".to_string(), "2 + 2 = 4".to_string()];

        let example = TrainingExample::with_reasoning(
            "2 + 2 = ?".to_string(),
            "4".to_string(),
            steps.clone(),
            ProblemType::Math,
        );

        assert_eq!(example.reasoning_chain.as_deref(), Some(&steps[..]));
    }

    #[test]
    fn test_training_batch() {
        let batch = TrainingBatch::new(vec![math("2 + 2", "4"), math("3 * 3", "9")]);
        assert_eq!(batch.batch_size, 2);

        let groups = batch.split_by_type();
        assert_eq!(groups.len(), 1);
        assert!(groups.contains_key(&ProblemType::Math));
    }

    #[test]
    fn test_synthetic_data_generator() {
        let mut generator = SyntheticDataGenerator::new();

        // Math problems
        let math_problems = generator.generate_math_problems(5);
        assert_generated(&math_problems, 5, ProblemType::Math);

        // Code examples
        let code_examples = generator.generate_code_examples(3);
        assert_generated(&code_examples, 3, ProblemType::Code);

        // Logic problems
        let logic_problems = generator.generate_logic_problems(4);
        assert_generated(&logic_problems, 4, ProblemType::Logic);
    }

    #[test]
    fn test_mixed_dataset_generation() {
        let mut generator = SyntheticDataGenerator::new();
        let dataset = generator.generate_mixed_dataset(12);

        assert_eq!(dataset.len(), 12);

        // Every generated category must be represented at least once.
        let contains = |wanted: ProblemType| dataset.iter().any(|e| e.problem_type == wanted);
        assert!(contains(ProblemType::Math));
        assert!(contains(ProblemType::Code));
        assert!(contains(ProblemType::Logic));
    }

    #[test]
    fn test_data_loader() {
        let examples = vec![
            math("1", "a"),
            math("2", "b"),
            math("3", "c"),
            math("4", "d"),
            math("5", "e"),
        ];

        let mut loader = DataLoader::new(examples, 2, false);

        assert_eq!(loader.len(), 5);
        assert_eq!(loader.batches_per_epoch(), 3);
        assert!(loader.has_next());

        // Two full batches followed by one partial batch.
        let sizes: Vec<usize> = std::iter::from_fn(|| loader.next_batch())
            .map(|batch| batch.batch_size)
            .collect();
        assert_eq!(sizes, vec![2, 2, 1]);

        // The loader is exhausted afterwards.
        assert!(!loader.has_next());
        assert!(loader.next_batch().is_none());
    }

    #[test]
    fn test_data_loader_reset() {
        let mut loader = DataLoader::new(vec![math("1", "a"), math("2", "b")], 1, false);

        // Drain the loader.
        assert!(loader.next_batch().is_some());
        assert!(loader.next_batch().is_some());
        assert!(!loader.has_next());

        // After a reset, iteration starts over.
        loader.reset();
        assert!(loader.has_next());
        assert!(loader.next_batch().is_some());
        assert!(loader.has_next());
    }

    #[test]
    fn test_empty_data_loader() {
        let loader = DataLoader::new(Vec::new(), 5, false);

        assert!(loader.is_empty());
        assert_eq!(loader.len(), 0);
        assert_eq!(loader.batches_per_epoch(), 0);
    }
}