//! Text Datasets
//!
//! # File
//! `crates/axonml-text/src/datasets/mod.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use crate::tokenizer::Tokenizer;
use crate::vocab::Vocab;
use axonml_data::Dataset;
use axonml_tensor::Tensor;

// =============================================================================
// TextDataset
// =============================================================================

/// A dataset of text samples with class labels.
///
/// Texts are encoded on access: whitespace-tokenized, mapped through `vocab`,
/// and truncated/padded to `max_length`. Labels are one-hot encoded.
pub struct TextDataset {
    /// Raw text samples, one per example.
    texts: Vec<String>,
    /// Class label for each sample; parallel to `texts`.
    labels: Vec<usize>,
    /// Vocabulary used for token-to-index lookup.
    vocab: Vocab,
    /// Fixed encoded sequence length (truncate/pad target).
    max_length: usize,
    /// Number of classes, derived as `max(labels) + 1` (0 when `labels` is empty).
    num_classes: usize,
}

35impl TextDataset {
36    /// Creates a new `TextDataset`.
37    #[must_use]
38    pub fn new(texts: Vec<String>, labels: Vec<usize>, vocab: Vocab, max_length: usize) -> Self {
39        let num_classes = labels.iter().max().map_or(0, |&m| m + 1);
40        Self {
41            texts,
42            labels,
43            vocab,
44            max_length,
45            num_classes,
46        }
47    }
48
49    /// Creates a `TextDataset` from raw text samples with a tokenizer.
50    pub fn from_samples<T: Tokenizer>(
51        samples: &[(String, usize)],
52        tokenizer: &T,
53        min_freq: usize,
54        max_length: usize,
55    ) -> Self {
56        use std::collections::HashMap;
57
58        // Build vocabulary from tokenized text
59        let mut freq: HashMap<String, usize> = HashMap::new();
60        for (text, _) in samples {
61            for token in tokenizer.tokenize(text) {
62                *freq.entry(token).or_insert(0) += 1;
63            }
64        }
65
66        // Create vocabulary with tokens meeting min_freq threshold
67        let mut vocab = Vocab::with_special_tokens();
68        let mut tokens: Vec<_> = freq
69            .into_iter()
70            .filter(|(_, count)| *count >= min_freq)
71            .collect();
72        tokens.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
73        for (token, _) in tokens {
74            vocab.add_token(&token);
75        }
76
77        let texts: Vec<String> = samples.iter().map(|(t, _)| t.clone()).collect();
78        let labels: Vec<usize> = samples.iter().map(|(_, l)| *l).collect();
79
80        Self::new(texts, labels, vocab, max_length)
81    }
82
83    /// Returns the vocabulary.
84    #[must_use]
85    pub fn vocab(&self) -> &Vocab {
86        &self.vocab
87    }
88
89    /// Returns the number of classes.
90    #[must_use]
91    pub fn num_classes(&self) -> usize {
92        self.num_classes
93    }
94
95    /// Returns the maximum sequence length.
96    #[must_use]
97    pub fn max_length(&self) -> usize {
98        self.max_length
99    }
100
101    /// Encodes text to padded tensor.
102    fn encode_text(&self, text: &str) -> Tensor<f32> {
103        let tokens: Vec<&str> = text.split_whitespace().collect();
104        let mut indices: Vec<f32> = tokens
105            .iter()
106            .take(self.max_length)
107            .map(|t| self.vocab.token_to_index(t) as f32)
108            .collect();
109
110        // Pad to max_length
111        let pad_idx = self.vocab.pad_index().unwrap_or(0) as f32;
112        while indices.len() < self.max_length {
113            indices.push(pad_idx);
114        }
115
116        Tensor::from_vec(indices, &[self.max_length]).unwrap()
117    }
118}
119
120impl Dataset for TextDataset {
121    type Item = (Tensor<f32>, Tensor<f32>);
122
123    fn len(&self) -> usize {
124        self.texts.len()
125    }
126
127    fn get(&self, index: usize) -> Option<Self::Item> {
128        if index >= self.texts.len() {
129            return None;
130        }
131
132        let text = self.encode_text(&self.texts[index]);
133
134        // One-hot encode label
135        let mut label_vec = vec![0.0f32; self.num_classes];
136        label_vec[self.labels[index]] = 1.0;
137        let label = Tensor::from_vec(label_vec, &[self.num_classes]).unwrap();
138
139        Some((text, label))
140    }
141}
142
// =============================================================================
// LanguageModelDataset
// =============================================================================

/// A dataset for language modeling (next token prediction).
///
/// Stores the corpus as a flat list of token indices; each item pairs a
/// `sequence_length` window with the same window shifted forward one token.
pub struct LanguageModelDataset {
    /// Token indices for the whole corpus, in order.
    tokens: Vec<usize>,
    /// Length of each input/target window.
    sequence_length: usize,
    /// Vocabulary used to map tokens to indices.
    vocab: Vocab,
}

154impl LanguageModelDataset {
155    /// Creates a new `LanguageModelDataset`.
156    #[must_use]
157    pub fn new(text: &str, vocab: Vocab, sequence_length: usize) -> Self {
158        let tokens: Vec<usize> = text
159            .split_whitespace()
160            .map(|t| vocab.token_to_index(t))
161            .collect();
162
163        Self {
164            tokens,
165            sequence_length,
166            vocab,
167        }
168    }
169
170    /// Creates a dataset from text, building vocabulary automatically.
171    #[must_use]
172    pub fn from_text(text: &str, sequence_length: usize, min_freq: usize) -> Self {
173        let vocab = Vocab::from_text(text, min_freq);
174        Self::new(text, vocab, sequence_length)
175    }
176
177    /// Returns the vocabulary.
178    #[must_use]
179    pub fn vocab(&self) -> &Vocab {
180        &self.vocab
181    }
182}
183
184impl Dataset for LanguageModelDataset {
185    type Item = (Tensor<f32>, Tensor<f32>);
186
187    fn len(&self) -> usize {
188        if self.tokens.len() <= self.sequence_length {
189            0
190        } else {
191            self.tokens.len() - self.sequence_length
192        }
193    }
194
195    fn get(&self, index: usize) -> Option<Self::Item> {
196        if index >= self.len() {
197            return None;
198        }
199
200        // Input: tokens[index..index+sequence_length]
201        let input: Vec<f32> = self.tokens[index..index + self.sequence_length]
202            .iter()
203            .map(|&t| t as f32)
204            .collect();
205
206        // Target: tokens[index+1..index+sequence_length+1]
207        let target: Vec<f32> = self.tokens[(index + 1)..=(index + self.sequence_length)]
208            .iter()
209            .map(|&t| t as f32)
210            .collect();
211
212        Some((
213            Tensor::from_vec(input, &[self.sequence_length]).unwrap(),
214            Tensor::from_vec(target, &[self.sequence_length]).unwrap(),
215        ))
216    }
217}
218
// =============================================================================
// SyntheticSentimentDataset
// =============================================================================

/// A synthetic sentiment analysis dataset for testing.
///
/// Samples are generated deterministically from the item index, so no data is
/// stored; labels alternate between the two sentiment classes.
pub struct SyntheticSentimentDataset {
    /// Number of samples the dataset reports.
    size: usize,
    /// Length of each generated token sequence.
    max_length: usize,
    /// Number of distinct token ids generated.
    vocab_size: usize,
}

230impl SyntheticSentimentDataset {
231    /// Creates a new synthetic sentiment dataset.
232    #[must_use]
233    pub fn new(size: usize, max_length: usize, vocab_size: usize) -> Self {
234        Self {
235            size,
236            max_length,
237            vocab_size,
238        }
239    }
240
241    /// Creates a small test dataset.
242    #[must_use]
243    pub fn small() -> Self {
244        Self::new(100, 32, 1000)
245    }
246
247    /// Creates a standard training dataset.
248    #[must_use]
249    pub fn train() -> Self {
250        Self::new(10000, 64, 10000)
251    }
252
253    /// Creates a standard test dataset.
254    #[must_use]
255    pub fn test() -> Self {
256        Self::new(2000, 64, 10000)
257    }
258}
259
260impl Dataset for SyntheticSentimentDataset {
261    type Item = (Tensor<f32>, Tensor<f32>);
262
263    fn len(&self) -> usize {
264        self.size
265    }
266
267    fn get(&self, index: usize) -> Option<Self::Item> {
268        if index >= self.size {
269            return None;
270        }
271
272        // Generate deterministic "random" sequence
273        let seed = index as u32;
274        let label = index % 2; // Binary sentiment
275
276        let mut text = Vec::with_capacity(self.max_length);
277        for i in 0..self.max_length {
278            let token_seed = seed.wrapping_mul(1103515245).wrapping_add(12345 + i as u32);
279            let token = (token_seed as usize) % self.vocab_size;
280            // Bias tokens based on sentiment
281            let biased_token = if label == 1 {
282                (token + self.vocab_size / 2) % self.vocab_size
283            } else {
284                token
285            };
286            text.push(biased_token as f32);
287        }
288
289        let text_tensor = Tensor::from_vec(text, &[self.max_length]).unwrap();
290
291        // One-hot label
292        let mut label_vec = vec![0.0f32; 2];
293        label_vec[label] = 1.0;
294        let label_tensor = Tensor::from_vec(label_vec, &[2]).unwrap();
295
296        Some((text_tensor, label_tensor))
297    }
298}
299
// =============================================================================
// SyntheticSeq2SeqDataset
// =============================================================================

/// A synthetic sequence-to-sequence dataset for testing.
///
/// Source sequences are generated deterministically from the item index; the
/// target is the reversed source (a copy/reverse task).
pub struct SyntheticSeq2SeqDataset {
    /// Number of samples the dataset reports.
    size: usize,
    /// Length of each generated source sequence.
    src_length: usize,
    /// Declared target length. NOTE(review): `get` always produces a target
    /// of `src_length` elements, so a mismatched `tgt_length` makes the
    /// tensor-shape unwrap in `get` panic.
    tgt_length: usize,
    /// Number of distinct token ids generated.
    vocab_size: usize,
}

312impl SyntheticSeq2SeqDataset {
313    /// Creates a new synthetic seq2seq dataset.
314    #[must_use]
315    pub fn new(size: usize, src_length: usize, tgt_length: usize, vocab_size: usize) -> Self {
316        Self {
317            size,
318            src_length,
319            tgt_length,
320            vocab_size,
321        }
322    }
323
324    /// Creates a copy task dataset (target = reversed source).
325    #[must_use]
326    pub fn copy_task(size: usize, length: usize, vocab_size: usize) -> Self {
327        Self::new(size, length, length, vocab_size)
328    }
329}
330
331impl Dataset for SyntheticSeq2SeqDataset {
332    type Item = (Tensor<f32>, Tensor<f32>);
333
334    fn len(&self) -> usize {
335        self.size
336    }
337
338    fn get(&self, index: usize) -> Option<Self::Item> {
339        if index >= self.size {
340            return None;
341        }
342
343        let seed = index as u32;
344
345        // Generate source sequence
346        let mut src = Vec::with_capacity(self.src_length);
347        for i in 0..self.src_length {
348            let token_seed = seed.wrapping_mul(1103515245).wrapping_add(12345 + i as u32);
349            let token = (token_seed as usize) % self.vocab_size;
350            src.push(token as f32);
351        }
352
353        // Target is reversed source (simple copy task)
354        let tgt: Vec<f32> = src.iter().rev().copied().collect();
355
356        Some((
357            Tensor::from_vec(src, &[self.src_length]).unwrap(),
358            Tensor::from_vec(tgt, &[self.tgt_length]).unwrap(),
359        ))
360    }
361}
362
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_dataset() {
        let vocab = Vocab::from_tokens(&["hello", "world", "good", "bad", "<pad>", "<unk>"]);
        let ds = TextDataset::new(
            vec!["hello world".to_string(), "good bad".to_string()],
            vec![0, 1],
            vocab,
            10,
        );

        assert_eq!(ds.len(), 2);
        assert_eq!(ds.num_classes(), 2);

        let (text, label) = ds.get(0).unwrap();
        assert_eq!(text.shape(), &[10]);
        assert_eq!(label.shape(), &[2]);
    }

    #[test]
    fn test_language_model_dataset() {
        let corpus = "the quick brown fox jumps over the lazy dog";
        let ds = LanguageModelDataset::from_text(corpus, 3, 1);

        assert!(ds.len() > 0);

        let (input, target) = ds.get(0).unwrap();
        assert_eq!(input.shape(), &[3]);
        assert_eq!(target.shape(), &[3]);
    }

    #[test]
    fn test_synthetic_sentiment_dataset() {
        let ds = SyntheticSentimentDataset::small();

        assert_eq!(ds.len(), 100);

        let (text, label) = ds.get(0).unwrap();
        assert_eq!(text.shape(), &[32]);
        assert_eq!(label.shape(), &[2]);

        // The label must be one-hot: its components sum to exactly one.
        let total: f32 = label.to_vec().iter().sum();
        assert!((total - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_synthetic_sentiment_deterministic() {
        let ds = SyntheticSentimentDataset::small();

        // Fetching the same index twice must yield identical samples.
        let (first_text, first_label) = ds.get(5).unwrap();
        let (second_text, second_label) = ds.get(5).unwrap();

        assert_eq!(first_text.to_vec(), second_text.to_vec());
        assert_eq!(first_label.to_vec(), second_label.to_vec());
    }

    #[test]
    fn test_synthetic_seq2seq_dataset() {
        let ds = SyntheticSeq2SeqDataset::copy_task(100, 10, 50);

        assert_eq!(ds.len(), 100);

        let (src, tgt) = ds.get(0).unwrap();
        assert_eq!(src.shape(), &[10]);
        assert_eq!(tgt.shape(), &[10]);

        // The target must be the source, reversed.
        let expected: Vec<f32> = src.to_vec().into_iter().rev().collect();
        assert_eq!(tgt.to_vec(), expected);
    }

    #[test]
    fn test_text_dataset_padding() {
        let ds = TextDataset::new(
            vec!["a b".to_string()],
            vec![0],
            Vocab::with_special_tokens(),
            10,
        );
        let (text, _) = ds.get(0).unwrap();

        // Two tokens, padded out to the fixed length of 10.
        assert_eq!(text.shape(), &[10]);
    }
}
455}