Skip to main content

axonml_text/datasets/
mod.rs

1//! Text Datasets
2//!
3//! Provides dataset implementations for common NLP tasks.
4//!
5//! @version 0.1.0
6//! @author `AutomataNexus` Development Team
7
8use crate::tokenizer::Tokenizer;
9use crate::vocab::Vocab;
10use axonml_data::Dataset;
11use axonml_tensor::Tensor;
12
13// =============================================================================
14// TextDataset
15// =============================================================================
16
/// A dataset of text samples with labels.
///
/// Each item is a `(text, label)` pair: the text encoded as a fixed-length
/// tensor of vocabulary indices (truncated/padded to `max_length`) and the
/// label one-hot encoded over `num_classes`.
pub struct TextDataset {
    /// Raw text samples, one per item.
    texts: Vec<String>,
    /// Class label for each sample; parallel to `texts`.
    labels: Vec<usize>,
    /// Vocabulary mapping tokens to indices.
    vocab: Vocab,
    /// Fixed sequence length every encoded text is padded/truncated to.
    max_length: usize,
    /// Number of classes, inferred in `new` as `max(labels) + 1` (0 if empty).
    num_classes: usize,
}
25
26impl TextDataset {
27    /// Creates a new `TextDataset`.
28    #[must_use]
29    pub fn new(texts: Vec<String>, labels: Vec<usize>, vocab: Vocab, max_length: usize) -> Self {
30        let num_classes = labels.iter().max().map_or(0, |&m| m + 1);
31        Self {
32            texts,
33            labels,
34            vocab,
35            max_length,
36            num_classes,
37        }
38    }
39
40    /// Creates a `TextDataset` from raw text samples with a tokenizer.
41    pub fn from_samples<T: Tokenizer>(
42        samples: &[(String, usize)],
43        tokenizer: &T,
44        min_freq: usize,
45        max_length: usize,
46    ) -> Self {
47        use std::collections::HashMap;
48
49        // Build vocabulary from tokenized text
50        let mut freq: HashMap<String, usize> = HashMap::new();
51        for (text, _) in samples {
52            for token in tokenizer.tokenize(text) {
53                *freq.entry(token).or_insert(0) += 1;
54            }
55        }
56
57        // Create vocabulary with tokens meeting min_freq threshold
58        let mut vocab = Vocab::with_special_tokens();
59        let mut tokens: Vec<_> = freq
60            .into_iter()
61            .filter(|(_, count)| *count >= min_freq)
62            .collect();
63        tokens.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
64        for (token, _) in tokens {
65            vocab.add_token(&token);
66        }
67
68        let texts: Vec<String> = samples.iter().map(|(t, _)| t.clone()).collect();
69        let labels: Vec<usize> = samples.iter().map(|(_, l)| *l).collect();
70
71        Self::new(texts, labels, vocab, max_length)
72    }
73
74    /// Returns the vocabulary.
75    #[must_use]
76    pub fn vocab(&self) -> &Vocab {
77        &self.vocab
78    }
79
80    /// Returns the number of classes.
81    #[must_use]
82    pub fn num_classes(&self) -> usize {
83        self.num_classes
84    }
85
86    /// Returns the maximum sequence length.
87    #[must_use]
88    pub fn max_length(&self) -> usize {
89        self.max_length
90    }
91
92    /// Encodes text to padded tensor.
93    fn encode_text(&self, text: &str) -> Tensor<f32> {
94        let tokens: Vec<&str> = text.split_whitespace().collect();
95        let mut indices: Vec<f32> = tokens
96            .iter()
97            .take(self.max_length)
98            .map(|t| self.vocab.token_to_index(t) as f32)
99            .collect();
100
101        // Pad to max_length
102        let pad_idx = self.vocab.pad_index().unwrap_or(0) as f32;
103        while indices.len() < self.max_length {
104            indices.push(pad_idx);
105        }
106
107        Tensor::from_vec(indices, &[self.max_length]).unwrap()
108    }
109}
110
111impl Dataset for TextDataset {
112    type Item = (Tensor<f32>, Tensor<f32>);
113
114    fn len(&self) -> usize {
115        self.texts.len()
116    }
117
118    fn get(&self, index: usize) -> Option<Self::Item> {
119        if index >= self.texts.len() {
120            return None;
121        }
122
123        let text = self.encode_text(&self.texts[index]);
124
125        // One-hot encode label
126        let mut label_vec = vec![0.0f32; self.num_classes];
127        label_vec[self.labels[index]] = 1.0;
128        let label = Tensor::from_vec(label_vec, &[self.num_classes]).unwrap();
129
130        Some((text, label))
131    }
132}
133
134// =============================================================================
135// LanguageModelDataset
136// =============================================================================
137
/// A dataset for language modeling (next token prediction).
///
/// Yields overlapping `(input, target)` windows over a single token stream,
/// where the target is the input shifted one position to the right.
pub struct LanguageModelDataset {
    /// The whole corpus as vocabulary indices (whitespace-tokenized in `new`).
    tokens: Vec<usize>,
    /// Length of each input/target window.
    sequence_length: usize,
    /// Vocabulary used to index tokens.
    vocab: Vocab,
}
144
145impl LanguageModelDataset {
146    /// Creates a new `LanguageModelDataset`.
147    #[must_use]
148    pub fn new(text: &str, vocab: Vocab, sequence_length: usize) -> Self {
149        let tokens: Vec<usize> = text
150            .split_whitespace()
151            .map(|t| vocab.token_to_index(t))
152            .collect();
153
154        Self {
155            tokens,
156            sequence_length,
157            vocab,
158        }
159    }
160
161    /// Creates a dataset from text, building vocabulary automatically.
162    #[must_use]
163    pub fn from_text(text: &str, sequence_length: usize, min_freq: usize) -> Self {
164        let vocab = Vocab::from_text(text, min_freq);
165        Self::new(text, vocab, sequence_length)
166    }
167
168    /// Returns the vocabulary.
169    #[must_use]
170    pub fn vocab(&self) -> &Vocab {
171        &self.vocab
172    }
173}
174
175impl Dataset for LanguageModelDataset {
176    type Item = (Tensor<f32>, Tensor<f32>);
177
178    fn len(&self) -> usize {
179        if self.tokens.len() <= self.sequence_length {
180            0
181        } else {
182            self.tokens.len() - self.sequence_length
183        }
184    }
185
186    fn get(&self, index: usize) -> Option<Self::Item> {
187        if index >= self.len() {
188            return None;
189        }
190
191        // Input: tokens[index..index+sequence_length]
192        let input: Vec<f32> = self.tokens[index..index + self.sequence_length]
193            .iter()
194            .map(|&t| t as f32)
195            .collect();
196
197        // Target: tokens[index+1..index+sequence_length+1]
198        let target: Vec<f32> = self.tokens[(index + 1)..=(index + self.sequence_length)]
199            .iter()
200            .map(|&t| t as f32)
201            .collect();
202
203        Some((
204            Tensor::from_vec(input, &[self.sequence_length]).unwrap(),
205            Tensor::from_vec(target, &[self.sequence_length]).unwrap(),
206        ))
207    }
208}
209
210// =============================================================================
211// SyntheticSentimentDataset
212// =============================================================================
213
/// A synthetic sentiment analysis dataset for testing.
///
/// Items are generated deterministically from the index (no stored data):
/// binary labels alternate with the index, and token ids are biased by the
/// label so the task is learnable.
pub struct SyntheticSentimentDataset {
    /// Number of samples the dataset reports.
    size: usize,
    /// Length of every generated token sequence.
    max_length: usize,
    /// Token ids are drawn modulo this vocabulary size.
    vocab_size: usize,
}
220
221impl SyntheticSentimentDataset {
222    /// Creates a new synthetic sentiment dataset.
223    #[must_use]
224    pub fn new(size: usize, max_length: usize, vocab_size: usize) -> Self {
225        Self {
226            size,
227            max_length,
228            vocab_size,
229        }
230    }
231
232    /// Creates a small test dataset.
233    #[must_use]
234    pub fn small() -> Self {
235        Self::new(100, 32, 1000)
236    }
237
238    /// Creates a standard training dataset.
239    #[must_use]
240    pub fn train() -> Self {
241        Self::new(10000, 64, 10000)
242    }
243
244    /// Creates a standard test dataset.
245    #[must_use]
246    pub fn test() -> Self {
247        Self::new(2000, 64, 10000)
248    }
249}
250
251impl Dataset for SyntheticSentimentDataset {
252    type Item = (Tensor<f32>, Tensor<f32>);
253
254    fn len(&self) -> usize {
255        self.size
256    }
257
258    fn get(&self, index: usize) -> Option<Self::Item> {
259        if index >= self.size {
260            return None;
261        }
262
263        // Generate deterministic "random" sequence
264        let seed = index as u32;
265        let label = index % 2; // Binary sentiment
266
267        let mut text = Vec::with_capacity(self.max_length);
268        for i in 0..self.max_length {
269            let token_seed = seed.wrapping_mul(1103515245).wrapping_add(12345 + i as u32);
270            let token = (token_seed as usize) % self.vocab_size;
271            // Bias tokens based on sentiment
272            let biased_token = if label == 1 {
273                (token + self.vocab_size / 2) % self.vocab_size
274            } else {
275                token
276            };
277            text.push(biased_token as f32);
278        }
279
280        let text_tensor = Tensor::from_vec(text, &[self.max_length]).unwrap();
281
282        // One-hot label
283        let mut label_vec = vec![0.0f32; 2];
284        label_vec[label] = 1.0;
285        let label_tensor = Tensor::from_vec(label_vec, &[2]).unwrap();
286
287        Some((text_tensor, label_tensor))
288    }
289}
290
291// =============================================================================
292// SyntheticSequenceDataset
293// =============================================================================
294
/// A synthetic sequence-to-sequence dataset for testing.
///
/// Source sequences are generated deterministically from the index; the
/// target is the reversed source (a copy/reverse task).
pub struct SyntheticSeq2SeqDataset {
    /// Number of samples the dataset reports.
    size: usize,
    /// Length of every source sequence.
    src_length: usize,
    /// Length of every target sequence.
    tgt_length: usize,
    /// Token ids are drawn modulo this vocabulary size.
    vocab_size: usize,
}
302
303impl SyntheticSeq2SeqDataset {
304    /// Creates a new synthetic seq2seq dataset.
305    #[must_use]
306    pub fn new(size: usize, src_length: usize, tgt_length: usize, vocab_size: usize) -> Self {
307        Self {
308            size,
309            src_length,
310            tgt_length,
311            vocab_size,
312        }
313    }
314
315    /// Creates a copy task dataset (target = reversed source).
316    #[must_use]
317    pub fn copy_task(size: usize, length: usize, vocab_size: usize) -> Self {
318        Self::new(size, length, length, vocab_size)
319    }
320}
321
322impl Dataset for SyntheticSeq2SeqDataset {
323    type Item = (Tensor<f32>, Tensor<f32>);
324
325    fn len(&self) -> usize {
326        self.size
327    }
328
329    fn get(&self, index: usize) -> Option<Self::Item> {
330        if index >= self.size {
331            return None;
332        }
333
334        let seed = index as u32;
335
336        // Generate source sequence
337        let mut src = Vec::with_capacity(self.src_length);
338        for i in 0..self.src_length {
339            let token_seed = seed.wrapping_mul(1103515245).wrapping_add(12345 + i as u32);
340            let token = (token_seed as usize) % self.vocab_size;
341            src.push(token as f32);
342        }
343
344        // Target is reversed source (simple copy task)
345        let tgt: Vec<f32> = src.iter().rev().copied().collect();
346
347        Some((
348            Tensor::from_vec(src, &[self.src_length]).unwrap(),
349            Tensor::from_vec(tgt, &[self.tgt_length]).unwrap(),
350        ))
351    }
352}
353
354// =============================================================================
355// Tests
356// =============================================================================
357
#[cfg(test)]
mod tests {
    use super::*;

    /// TextDataset reports its length and inferred class count, and yields
    /// tensors with the expected shapes.
    #[test]
    fn test_text_dataset() {
        let vocab = Vocab::from_tokens(&["hello", "world", "good", "bad", "<pad>", "<unk>"]);
        let texts = vec!["hello world".to_string(), "good bad".to_string()];
        let labels = vec![0, 1];

        let dataset = TextDataset::new(texts, labels, vocab, 10);

        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.num_classes(), 2);

        // Text is padded to max_length; label is one-hot over num_classes.
        let (text, label) = dataset.get(0).unwrap();
        assert_eq!(text.shape(), &[10]);
        assert_eq!(label.shape(), &[2]);
    }

    /// Sliding windows over a 9-token corpus yield non-empty input/target
    /// pairs of the configured sequence length.
    #[test]
    fn test_language_model_dataset() {
        let text = "the quick brown fox jumps over the lazy dog";
        let dataset = LanguageModelDataset::from_text(text, 3, 1);

        assert!(dataset.len() > 0);

        let (input, target) = dataset.get(0).unwrap();
        assert_eq!(input.shape(), &[3]);
        assert_eq!(target.shape(), &[3]);
    }

    /// Synthetic sentiment items have the configured shapes and a valid
    /// one-hot label (components sum to 1).
    #[test]
    fn test_synthetic_sentiment_dataset() {
        let dataset = SyntheticSentimentDataset::small();

        assert_eq!(dataset.len(), 100);

        let (text, label) = dataset.get(0).unwrap();
        assert_eq!(text.shape(), &[32]);
        assert_eq!(label.shape(), &[2]);

        // Check label is one-hot
        let label_vec = label.to_vec();
        let sum: f32 = label_vec.iter().sum();
        assert!((sum - 1.0).abs() < 0.001);
    }

    /// Repeated `get` calls for the same index must return identical data,
    /// since generation is seeded purely by the index.
    #[test]
    fn test_synthetic_sentiment_deterministic() {
        let dataset = SyntheticSentimentDataset::small();

        let (text1, label1) = dataset.get(5).unwrap();
        let (text2, label2) = dataset.get(5).unwrap();

        assert_eq!(text1.to_vec(), text2.to_vec());
        assert_eq!(label1.to_vec(), label2.to_vec());
    }

    /// Copy-task seq2seq items: the target equals the reversed source.
    #[test]
    fn test_synthetic_seq2seq_dataset() {
        let dataset = SyntheticSeq2SeqDataset::copy_task(100, 10, 50);

        assert_eq!(dataset.len(), 100);

        let (src, tgt) = dataset.get(0).unwrap();
        assert_eq!(src.shape(), &[10]);
        assert_eq!(tgt.shape(), &[10]);

        // Target should be reversed source
        let src_vec = src.to_vec();
        let tgt_vec = tgt.to_vec();
        let reversed: Vec<f32> = src_vec.iter().rev().copied().collect();
        assert_eq!(tgt_vec, reversed);
    }

    /// Texts shorter than `max_length` are right-padded up to it.
    #[test]
    fn test_text_dataset_padding() {
        let vocab = Vocab::with_special_tokens();
        let texts = vec!["a b".to_string()];
        let labels = vec![0];

        let dataset = TextDataset::new(texts, labels, vocab, 10);
        let (text, _) = dataset.get(0).unwrap();

        // Should be padded to length 10
        assert_eq!(text.shape(), &[10]);
    }
}