
axonml_text/datasets/mod.rs

//! Text Datasets
//!
//! Provides dataset implementations for common NLP tasks.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team
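//!
//! # Examples
//!
//! A minimal end-to-end sketch, mirroring the unit tests at the bottom of this
//! file (marked `ignore` because the exact re-export paths are an assumption):
//!
//! ```ignore
//! use axonml_data::Dataset;
//! use axonml_text::datasets::TextDataset;
//! use axonml_text::vocab::Vocab;
//!
//! let vocab = Vocab::from_tokens(&["hello", "world", "<pad>", "<unk>"]);
//! let dataset = TextDataset::new(vec!["hello world".to_string()], vec![0], vocab, 8);
//! let (text, label) = dataset.get(0).unwrap();
//! assert_eq!(text.shape(), &[8]);
//! assert_eq!(label.shape(), &[1]); // one class, since the only label is 0
//! ```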

use crate::tokenizer::Tokenizer;
use crate::vocab::Vocab;
use axonml_data::Dataset;
use axonml_tensor::Tensor;

// =============================================================================
// TextDataset
// =============================================================================

/// A dataset of text samples with labels.
pub struct TextDataset {
    texts: Vec<String>,
    labels: Vec<usize>,
    vocab: Vocab,
    max_length: usize,
    num_classes: usize,
}

impl TextDataset {
    /// Creates a new `TextDataset`.
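    ///
    /// # Examples
    ///
    /// A minimal sketch mirroring `test_text_dataset` below (marked `ignore`
    /// since import paths are assumed):
    ///
    /// ```ignore
    /// let vocab = Vocab::from_tokens(&["hello", "world", "good", "bad", "<pad>", "<unk>"]);
    /// let texts = vec!["hello world".to_string(), "good bad".to_string()];
    /// let dataset = TextDataset::new(texts, vec![0, 1], vocab, 10);
    /// assert_eq!(dataset.len(), 2);
    /// assert_eq!(dataset.num_classes(), 2);
    /// ```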
    #[must_use]
    pub fn new(texts: Vec<String>, labels: Vec<usize>, vocab: Vocab, max_length: usize) -> Self {
        // num_classes is inferred as max label + 1 (0 when labels is empty).
        let num_classes = labels.iter().max().map_or(0, |&m| m + 1);
        Self {
            texts,
            labels,
            vocab,
            max_length,
            num_classes,
        }
    }

    /// Creates a `TextDataset` from raw text samples with a tokenizer.
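    ///
    /// # Examples
    ///
    /// A sketch assuming a hypothetical `WhitespaceTokenizer` implementing the
    /// `Tokenizer` trait (no such type is defined in this module):
    ///
    /// ```ignore
    /// let samples = vec![
    ///     ("good movie".to_string(), 1),
    ///     ("bad movie".to_string(), 0),
    /// ];
    /// let tokenizer = WhitespaceTokenizer::new(); // hypothetical tokenizer
    /// let dataset = TextDataset::from_samples(&samples, &tokenizer, 1, 16);
    /// assert_eq!(dataset.len(), 2);
    /// assert_eq!(dataset.num_classes(), 2);
    /// ```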
    #[must_use]
    pub fn from_samples<T: Tokenizer>(
        samples: &[(String, usize)],
        tokenizer: &T,
        min_freq: usize,
        max_length: usize,
    ) -> Self {
        use std::collections::HashMap;

        // Build token frequency counts from the tokenized text.
        let mut freq: HashMap<String, usize> = HashMap::new();
        for (text, _) in samples {
            for token in tokenizer.tokenize(text) {
                *freq.entry(token).or_insert(0) += 1;
            }
        }

        // Keep tokens meeting the min_freq threshold, ordered by descending
        // frequency (ties broken alphabetically) for a deterministic vocabulary.
        let mut vocab = Vocab::with_special_tokens();
        let mut tokens: Vec<_> = freq
            .into_iter()
            .filter(|(_, count)| *count >= min_freq)
            .collect();
        tokens.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        for (token, _) in tokens {
            vocab.add_token(&token);
        }

        let texts: Vec<String> = samples.iter().map(|(t, _)| t.clone()).collect();
        let labels: Vec<usize> = samples.iter().map(|(_, l)| *l).collect();

        Self::new(texts, labels, vocab, max_length)
    }

    /// Returns the vocabulary.
    #[must_use]
    pub fn vocab(&self) -> &Vocab {
        &self.vocab
    }

    /// Returns the number of classes.
    #[must_use]
    pub fn num_classes(&self) -> usize {
        self.num_classes
    }

    /// Returns the maximum sequence length.
    #[must_use]
    pub fn max_length(&self) -> usize {
        self.max_length
    }

    /// Encodes text into a whitespace-tokenized index tensor, truncated and
    /// padded to `max_length`.
    ///
    /// Note: encoding always splits on whitespace, so vocabularies built with a
    /// different tokenizer in `from_samples` may not round-trip exactly.
    fn encode_text(&self, text: &str) -> Tensor<f32> {
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let mut indices: Vec<f32> = tokens
            .iter()
            .take(self.max_length)
            .map(|t| self.vocab.token_to_index(t) as f32)
            .collect();

        // Pad to max_length with the vocabulary's pad index (0 if absent).
        let pad_idx = self.vocab.pad_index().unwrap_or(0) as f32;
        while indices.len() < self.max_length {
            indices.push(pad_idx);
        }

        Tensor::from_vec(indices, &[self.max_length]).unwrap()
    }
}

impl Dataset for TextDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.texts.len()
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.texts.len() {
            return None;
        }

        let text = self.encode_text(&self.texts[index]);

        // One-hot encode the label.
        let mut label_vec = vec![0.0f32; self.num_classes];
        label_vec[self.labels[index]] = 1.0;
        let label = Tensor::from_vec(label_vec, &[self.num_classes]).unwrap();

        Some((text, label))
    }
}

// =============================================================================
// LanguageModelDataset
// =============================================================================

/// A dataset for language modeling (next-token prediction).
pub struct LanguageModelDataset {
    tokens: Vec<usize>,
    sequence_length: usize,
    vocab: Vocab,
}

impl LanguageModelDataset {
    /// Creates a new `LanguageModelDataset`.
    #[must_use]
    pub fn new(text: &str, vocab: Vocab, sequence_length: usize) -> Self {
        let tokens: Vec<usize> = text
            .split_whitespace()
            .map(|t| vocab.token_to_index(t))
            .collect();

        Self {
            tokens,
            sequence_length,
            vocab,
        }
    }

    /// Creates a dataset from text, building the vocabulary automatically.
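    ///
    /// # Examples
    ///
    /// Mirrors `test_language_model_dataset` below (marked `ignore` since
    /// import paths are assumed):
    ///
    /// ```ignore
    /// let text = "the quick brown fox jumps over the lazy dog";
    /// let dataset = LanguageModelDataset::from_text(text, 3, 1);
    /// // 9 tokens with sequence_length 3 yield 9 - 3 = 6 sliding windows.
    /// assert_eq!(dataset.len(), 6);
    /// let (input, target) = dataset.get(0).unwrap();
    /// assert_eq!(input.shape(), &[3]);
    /// assert_eq!(target.shape(), &[3]);
    /// ```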
    #[must_use]
    pub fn from_text(text: &str, sequence_length: usize, min_freq: usize) -> Self {
        let vocab = Vocab::from_text(text, min_freq);
        Self::new(text, vocab, sequence_length)
    }

    /// Returns the vocabulary.
    #[must_use]
    pub fn vocab(&self) -> &Vocab {
        &self.vocab
    }
}

impl Dataset for LanguageModelDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        if self.tokens.len() <= self.sequence_length {
            0
        } else {
            self.tokens.len() - self.sequence_length
        }
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.len() {
            return None;
        }

        // Input: tokens[index..index + sequence_length]
        let input: Vec<f32> = self.tokens[index..index + self.sequence_length]
            .iter()
            .map(|&t| t as f32)
            .collect();

        // Target, shifted one token ahead: tokens[index + 1..=index + sequence_length]
        let target: Vec<f32> = self.tokens[(index + 1)..=(index + self.sequence_length)]
            .iter()
            .map(|&t| t as f32)
            .collect();

        Some((
            Tensor::from_vec(input, &[self.sequence_length]).unwrap(),
            Tensor::from_vec(target, &[self.sequence_length]).unwrap(),
        ))
    }
}

// =============================================================================
// SyntheticSentimentDataset
// =============================================================================

/// A synthetic sentiment analysis dataset for testing.
pub struct SyntheticSentimentDataset {
    size: usize,
    max_length: usize,
    vocab_size: usize,
}

impl SyntheticSentimentDataset {
    /// Creates a new synthetic sentiment dataset.
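    ///
    /// # Examples
    ///
    /// Mirrors `test_synthetic_sentiment_dataset` below (marked `ignore` since
    /// import paths are assumed):
    ///
    /// ```ignore
    /// let dataset = SyntheticSentimentDataset::new(100, 32, 1000);
    /// let (text, label) = dataset.get(0).unwrap();
    /// assert_eq!(text.shape(), &[32]);
    /// assert_eq!(label.shape(), &[2]); // one-hot binary sentiment label
    /// ```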
    #[must_use]
    pub fn new(size: usize, max_length: usize, vocab_size: usize) -> Self {
        Self {
            size,
            max_length,
            vocab_size,
        }
    }

    /// Creates a small test dataset.
    #[must_use]
    pub fn small() -> Self {
        Self::new(100, 32, 1000)
    }

    /// Creates a standard training dataset.
    #[must_use]
    pub fn train() -> Self {
        Self::new(10000, 64, 10000)
    }

    /// Creates a standard test dataset.
    #[must_use]
    pub fn test() -> Self {
        Self::new(2000, 64, 10000)
    }
}

impl Dataset for SyntheticSentimentDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.size
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.size {
            return None;
        }

        // Generate a deterministic pseudo-random sequence from the index.
        let seed = index as u32;
        let label = index % 2; // Binary sentiment

        let mut text = Vec::with_capacity(self.max_length);
        for i in 0..self.max_length {
            // Linear-congruential-style mixing keeps samples reproducible.
            let token_seed = seed.wrapping_mul(1103515245).wrapping_add(12345 + i as u32);
            let token = (token_seed as usize) % self.vocab_size;
            // Bias tokens into a different half of the vocabulary per sentiment class.
            let biased_token = if label == 1 {
                (token + self.vocab_size / 2) % self.vocab_size
            } else {
                token
            };
            text.push(biased_token as f32);
        }

        let text_tensor = Tensor::from_vec(text, &[self.max_length]).unwrap();

        // One-hot label
        let mut label_vec = vec![0.0f32; 2];
        label_vec[label] = 1.0;
        let label_tensor = Tensor::from_vec(label_vec, &[2]).unwrap();

        Some((text_tensor, label_tensor))
    }
}

// =============================================================================
// SyntheticSeq2SeqDataset
// =============================================================================

/// A synthetic sequence-to-sequence dataset for testing.
pub struct SyntheticSeq2SeqDataset {
    size: usize,
    src_length: usize,
    tgt_length: usize,
    vocab_size: usize,
}

impl SyntheticSeq2SeqDataset {
    /// Creates a new synthetic seq2seq dataset.
    #[must_use]
    pub fn new(size: usize, src_length: usize, tgt_length: usize, vocab_size: usize) -> Self {
        Self {
            size,
            src_length,
            tgt_length,
            vocab_size,
        }
    }

    /// Creates a copy-task dataset whose target is the reversed source.
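    ///
    /// # Examples
    ///
    /// Mirrors `test_synthetic_seq2seq_dataset` below (marked `ignore` since
    /// import paths are assumed):
    ///
    /// ```ignore
    /// let dataset = SyntheticSeq2SeqDataset::copy_task(100, 10, 50);
    /// let (src, tgt) = dataset.get(0).unwrap();
    /// // The target is the source reversed.
    /// let reversed: Vec<f32> = src.to_vec().iter().rev().copied().collect();
    /// assert_eq!(tgt.to_vec(), reversed);
    /// ```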
    #[must_use]
    pub fn copy_task(size: usize, length: usize, vocab_size: usize) -> Self {
        Self::new(size, length, length, vocab_size)
    }
}

impl Dataset for SyntheticSeq2SeqDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.size
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.size {
            return None;
        }

        let seed = index as u32;

        // Generate the source sequence deterministically from the index.
        let mut src = Vec::with_capacity(self.src_length);
        for i in 0..self.src_length {
            let token_seed = seed.wrapping_mul(1103515245).wrapping_add(12345 + i as u32);
            let token = (token_seed as usize) % self.vocab_size;
            src.push(token as f32);
        }

        // The target is the reversed source (reversal copy task).
        let tgt: Vec<f32> = src.iter().rev().copied().collect();

        Some((
            Tensor::from_vec(src, &[self.src_length]).unwrap(),
            Tensor::from_vec(tgt, &[self.tgt_length]).unwrap(),
        ))
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_dataset() {
        let vocab = Vocab::from_tokens(&["hello", "world", "good", "bad", "<pad>", "<unk>"]);
        let texts = vec!["hello world".to_string(), "good bad".to_string()];
        let labels = vec![0, 1];

        let dataset = TextDataset::new(texts, labels, vocab, 10);

        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.num_classes(), 2);

        let (text, label) = dataset.get(0).unwrap();
        assert_eq!(text.shape(), &[10]);
        assert_eq!(label.shape(), &[2]);
    }

    #[test]
    fn test_language_model_dataset() {
        let text = "the quick brown fox jumps over the lazy dog";
        let dataset = LanguageModelDataset::from_text(text, 3, 1);

        assert!(dataset.len() > 0);

        let (input, target) = dataset.get(0).unwrap();
        assert_eq!(input.shape(), &[3]);
        assert_eq!(target.shape(), &[3]);
    }

    #[test]
    fn test_synthetic_sentiment_dataset() {
        let dataset = SyntheticSentimentDataset::small();

        assert_eq!(dataset.len(), 100);

        let (text, label) = dataset.get(0).unwrap();
        assert_eq!(text.shape(), &[32]);
        assert_eq!(label.shape(), &[2]);

        // Check that the label is one-hot.
        let label_vec = label.to_vec();
        let sum: f32 = label_vec.iter().sum();
        assert!((sum - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_synthetic_sentiment_deterministic() {
        let dataset = SyntheticSentimentDataset::small();

        let (text1, label1) = dataset.get(5).unwrap();
        let (text2, label2) = dataset.get(5).unwrap();

        assert_eq!(text1.to_vec(), text2.to_vec());
        assert_eq!(label1.to_vec(), label2.to_vec());
    }

    #[test]
    fn test_synthetic_seq2seq_dataset() {
        let dataset = SyntheticSeq2SeqDataset::copy_task(100, 10, 50);

        assert_eq!(dataset.len(), 100);

        let (src, tgt) = dataset.get(0).unwrap();
        assert_eq!(src.shape(), &[10]);
        assert_eq!(tgt.shape(), &[10]);

        // The target should be the reversed source.
        let src_vec = src.to_vec();
        let tgt_vec = tgt.to_vec();
        let reversed: Vec<f32> = src_vec.iter().rev().copied().collect();
        assert_eq!(tgt_vec, reversed);
    }

    #[test]
    fn test_text_dataset_padding() {
        let vocab = Vocab::with_special_tokens();
        let texts = vec!["a b".to_string()];
        let labels = vec![0];

        let dataset = TextDataset::new(texts, labels, vocab, 10);
        let (text, _) = dataset.get(0).unwrap();

        // Should be padded to length 10.
        assert_eq!(text.shape(), &[10]);
    }
}
433}