Skip to main content

axonml_text/datasets/
mod.rs

1//! Text Datasets
2//!
3//! # File
4//! `crates/axonml-text/src/datasets/mod.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr. — AutomataNexus LLC
8//! ORCID: 0009-0005-2158-7060
9//!
10//! # Updated
11//! April 14, 2026 11:15 PM EST
12//!
13//! # Disclaimer
14//! Use at own risk. This software is provided "as is", without warranty of any
15//! kind, express or implied. The author and AutomataNexus shall not be held
16//! liable for any damages arising from the use of this software.
17
18use crate::tokenizer::Tokenizer;
19use crate::vocab::Vocab;
20use axonml_data::Dataset;
21use axonml_tensor::Tensor;
22
23// =============================================================================
24// TextDataset
25// =============================================================================
26
27/// A dataset of text samples with labels.
/// A dataset of text samples with labels.
///
/// Samples are encoded lazily on access: the stored tokenizer splits the
/// text, tokens are mapped through `vocab`, and the result is truncated or
/// padded to `max_length`.
pub struct TextDataset {
    // Raw text samples, one per dataset item.
    texts: Vec<String>,
    // Class-index label for each sample; parallel to `texts`.
    labels: Vec<usize>,
    // Vocabulary mapping tokens to indices.
    vocab: Vocab,
    // Fixed encoded sequence length (truncate/pad target).
    max_length: usize,
    // Derived as max(labels) + 1 at construction (0 when labels is empty).
    num_classes: usize,
    /// Tokenizer used for encoding text (must match how vocab was built).
    tokenizer: Box<dyn Tokenizer>,
}
37
38impl TextDataset {
39    /// Creates a new `TextDataset` with a whitespace tokenizer.
40    #[must_use]
41    pub fn new(texts: Vec<String>, labels: Vec<usize>, vocab: Vocab, max_length: usize) -> Self {
42        Self::with_tokenizer(
43            texts,
44            labels,
45            vocab,
46            max_length,
47            crate::tokenizer::WhitespaceTokenizer::new(),
48        )
49    }
50
51    /// Creates a new `TextDataset` with a specified tokenizer.
52    pub fn with_tokenizer<T: Tokenizer + 'static>(
53        texts: Vec<String>,
54        labels: Vec<usize>,
55        vocab: Vocab,
56        max_length: usize,
57        tokenizer: T,
58    ) -> Self {
59        let num_classes = labels.iter().max().map_or(0, |&m| m + 1);
60        Self {
61            texts,
62            labels,
63            vocab,
64            max_length,
65            num_classes,
66            tokenizer: Box::new(tokenizer),
67        }
68    }
69
70    /// Creates a `TextDataset` from raw text samples with a tokenizer.
71    ///
72    /// The same tokenizer is stored and used for both vocabulary building
73    /// and encoding at access time.
74    pub fn from_samples<T: Tokenizer + Clone + 'static>(
75        samples: &[(String, usize)],
76        tokenizer: &T,
77        min_freq: usize,
78        max_length: usize,
79    ) -> Self {
80        use std::collections::HashMap;
81
82        // Build vocabulary from tokenized text
83        let mut freq: HashMap<String, usize> = HashMap::new();
84        for (text, _) in samples {
85            for token in tokenizer.tokenize(text) {
86                *freq.entry(token).or_insert(0) += 1;
87            }
88        }
89
90        // Create vocabulary with tokens meeting min_freq threshold
91        let mut vocab = Vocab::with_special_tokens();
92        let mut tokens: Vec<_> = freq
93            .into_iter()
94            .filter(|(_, count)| *count >= min_freq)
95            .collect();
96        tokens.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
97        for (token, _) in tokens {
98            vocab.add_token(&token);
99        }
100
101        let texts: Vec<String> = samples.iter().map(|(t, _)| t.clone()).collect();
102        let labels: Vec<usize> = samples.iter().map(|(_, l)| *l).collect();
103
104        Self::with_tokenizer(texts, labels, vocab, max_length, tokenizer.clone())
105    }
106
107    /// Returns the vocabulary.
108    #[must_use]
109    pub fn vocab(&self) -> &Vocab {
110        &self.vocab
111    }
112
113    /// Returns the number of classes.
114    #[must_use]
115    pub fn num_classes(&self) -> usize {
116        self.num_classes
117    }
118
119    /// Returns the maximum sequence length.
120    #[must_use]
121    pub fn max_length(&self) -> usize {
122        self.max_length
123    }
124
125    /// Encodes text to padded tensor using the stored tokenizer.
126    fn encode_text(&self, text: &str) -> Tensor<f32> {
127        let tokens = self.tokenizer.tokenize(text);
128        let mut indices: Vec<f32> = tokens
129            .iter()
130            .take(self.max_length)
131            .map(|t| self.vocab.token_to_index(t) as f32)
132            .collect();
133
134        // Pad to max_length
135        let pad_idx = self.vocab.pad_index().unwrap_or(0) as f32;
136        while indices.len() < self.max_length {
137            indices.push(pad_idx);
138        }
139
140        Tensor::from_vec(indices, &[self.max_length]).unwrap()
141    }
142}
143
144impl Dataset for TextDataset {
145    type Item = (Tensor<f32>, Tensor<f32>);
146
147    fn len(&self) -> usize {
148        self.texts.len()
149    }
150
151    fn get(&self, index: usize) -> Option<Self::Item> {
152        if index >= self.texts.len() {
153            return None;
154        }
155
156        let text = self.encode_text(&self.texts[index]);
157
158        // Class index label (compatible with CrossEntropyLoss)
159        let label = Tensor::from_vec(vec![self.labels[index] as f32], &[1]).unwrap();
160
161        Some((text, label))
162    }
163}
164
165// =============================================================================
166// LanguageModelDataset
167// =============================================================================
168
169/// A dataset for language modeling (next token prediction).
/// A dataset for language modeling (next token prediction).
///
/// Holds a flat stream of token indices; each item is a window of
/// `sequence_length` tokens paired with the same window shifted one token
/// to the right (see the `Dataset` impl).
pub struct LanguageModelDataset {
    // Whitespace-split text mapped to vocabulary indices.
    tokens: Vec<usize>,
    // Length of each (input, target) window.
    sequence_length: usize,
    // Vocabulary used to build `tokens`; exposed via `vocab()`.
    vocab: Vocab,
}
175
176impl LanguageModelDataset {
177    /// Creates a new `LanguageModelDataset`.
178    #[must_use]
179    pub fn new(text: &str, vocab: Vocab, sequence_length: usize) -> Self {
180        let tokens: Vec<usize> = text
181            .split_whitespace()
182            .map(|t| vocab.token_to_index(t))
183            .collect();
184
185        Self {
186            tokens,
187            sequence_length,
188            vocab,
189        }
190    }
191
192    /// Creates a dataset from text, building vocabulary automatically.
193    #[must_use]
194    pub fn from_text(text: &str, sequence_length: usize, min_freq: usize) -> Self {
195        let vocab = Vocab::from_text(text, min_freq);
196        Self::new(text, vocab, sequence_length)
197    }
198
199    /// Returns the vocabulary.
200    #[must_use]
201    pub fn vocab(&self) -> &Vocab {
202        &self.vocab
203    }
204}
205
206impl Dataset for LanguageModelDataset {
207    type Item = (Tensor<f32>, Tensor<f32>);
208
209    fn len(&self) -> usize {
210        if self.tokens.len() <= self.sequence_length {
211            0
212        } else {
213            self.tokens.len() - self.sequence_length
214        }
215    }
216
217    fn get(&self, index: usize) -> Option<Self::Item> {
218        if index >= self.len() {
219            return None;
220        }
221
222        // Input: tokens[index..index+sequence_length]
223        let input: Vec<f32> = self.tokens[index..index + self.sequence_length]
224            .iter()
225            .map(|&t| t as f32)
226            .collect();
227
228        // Target: tokens[index+1..index+sequence_length+1]
229        let target: Vec<f32> = self.tokens[(index + 1)..=(index + self.sequence_length)]
230            .iter()
231            .map(|&t| t as f32)
232            .collect();
233
234        Some((
235            Tensor::from_vec(input, &[self.sequence_length]).unwrap(),
236            Tensor::from_vec(target, &[self.sequence_length]).unwrap(),
237        ))
238    }
239}
240
241// =============================================================================
242// SyntheticSentimentDataset
243// =============================================================================
244
245/// A synthetic sentiment analysis dataset for testing.
/// A synthetic sentiment analysis dataset for testing.
///
/// Stores no data: every sample is generated deterministically from its
/// index (see the `Dataset` impl for the generation scheme).
pub struct SyntheticSentimentDataset {
    // Number of samples the dataset reports.
    size: usize,
    // Length of each generated token sequence.
    max_length: usize,
    // Token ids are drawn modulo this value.
    vocab_size: usize,
}
251
252impl SyntheticSentimentDataset {
253    /// Creates a new synthetic sentiment dataset.
254    #[must_use]
255    pub fn new(size: usize, max_length: usize, vocab_size: usize) -> Self {
256        Self {
257            size,
258            max_length,
259            vocab_size,
260        }
261    }
262
263    /// Creates a small test dataset.
264    #[must_use]
265    pub fn small() -> Self {
266        Self::new(100, 32, 1000)
267    }
268
269    /// Creates a standard training dataset.
270    #[must_use]
271    pub fn train() -> Self {
272        Self::new(10000, 64, 10000)
273    }
274
275    /// Creates a standard test dataset.
276    #[must_use]
277    pub fn test() -> Self {
278        Self::new(2000, 64, 10000)
279    }
280}
281
282impl Dataset for SyntheticSentimentDataset {
283    type Item = (Tensor<f32>, Tensor<f32>);
284
285    fn len(&self) -> usize {
286        self.size
287    }
288
289    fn get(&self, index: usize) -> Option<Self::Item> {
290        if index >= self.size {
291            return None;
292        }
293
294        // Generate deterministic "random" sequence
295        let seed = index as u32;
296        let label = index % 2; // Binary sentiment
297
298        let mut text = Vec::with_capacity(self.max_length);
299        for i in 0..self.max_length {
300            let token_seed = seed.wrapping_mul(1103515245).wrapping_add(12345 + i as u32);
301            let token = (token_seed as usize) % self.vocab_size;
302            // Bias tokens based on sentiment
303            let biased_token = if label == 1 {
304                (token + self.vocab_size / 2) % self.vocab_size
305            } else {
306                token
307            };
308            text.push(biased_token as f32);
309        }
310
311        let text_tensor = Tensor::from_vec(text, &[self.max_length]).unwrap();
312
313        // Class index label (compatible with CrossEntropyLoss)
314        let label_tensor = Tensor::from_vec(vec![label as f32], &[1]).unwrap();
315
316        Some((text_tensor, label_tensor))
317    }
318}
319
320// =============================================================================
// SyntheticSeq2SeqDataset
322// =============================================================================
323
324/// A synthetic sequence-to-sequence dataset for testing.
/// A synthetic sequence-to-sequence dataset for testing.
///
/// Stores no data: each (source, target) pair is generated
/// deterministically from its index (see the `Dataset` impl).
pub struct SyntheticSeq2SeqDataset {
    // Number of samples the dataset reports.
    size: usize,
    // Length of each generated source sequence.
    src_length: usize,
    // Length of each target sequence tensor.
    tgt_length: usize,
    // Token ids are drawn modulo this value.
    vocab_size: usize,
}
331
332impl SyntheticSeq2SeqDataset {
333    /// Creates a new synthetic seq2seq dataset.
334    #[must_use]
335    pub fn new(size: usize, src_length: usize, tgt_length: usize, vocab_size: usize) -> Self {
336        Self {
337            size,
338            src_length,
339            tgt_length,
340            vocab_size,
341        }
342    }
343
344    /// Creates a copy task dataset (target = reversed source).
345    #[must_use]
346    pub fn copy_task(size: usize, length: usize, vocab_size: usize) -> Self {
347        Self::new(size, length, length, vocab_size)
348    }
349}
350
351impl Dataset for SyntheticSeq2SeqDataset {
352    type Item = (Tensor<f32>, Tensor<f32>);
353
354    fn len(&self) -> usize {
355        self.size
356    }
357
358    fn get(&self, index: usize) -> Option<Self::Item> {
359        if index >= self.size {
360            return None;
361        }
362
363        let seed = index as u32;
364
365        // Generate source sequence
366        let mut src = Vec::with_capacity(self.src_length);
367        for i in 0..self.src_length {
368            let token_seed = seed.wrapping_mul(1103515245).wrapping_add(12345 + i as u32);
369            let token = (token_seed as usize) % self.vocab_size;
370            src.push(token as f32);
371        }
372
373        // Target is reversed source (simple copy task)
374        let tgt: Vec<f32> = src.iter().rev().copied().collect();
375
376        Some((
377            Tensor::from_vec(src, &[self.src_length]).unwrap(),
378            Tensor::from_vec(tgt, &[self.tgt_length]).unwrap(),
379        ))
380    }
381}
382
383// =============================================================================
384// Tests
385// =============================================================================
386
387#[cfg(test)]
388mod tests {
389    use super::*;
390
391    #[test]
392    fn test_text_dataset() {
393        let vocab = Vocab::from_tokens(&["hello", "world", "good", "bad", "<pad>", "<unk>"]);
394        let texts = vec!["hello world".to_string(), "good bad".to_string()];
395        let labels = vec![0, 1];
396
397        let dataset = TextDataset::new(texts, labels, vocab, 10);
398
399        assert_eq!(dataset.len(), 2);
400        assert_eq!(dataset.num_classes(), 2);
401
402        let (text, label) = dataset.get(0).unwrap();
403        assert_eq!(text.shape(), &[10]);
404        // Class index label (scalar [1])
405        assert_eq!(label.shape(), &[1]);
406        assert_eq!(label.to_vec()[0], 0.0);
407    }
408
409    #[test]
410    fn test_language_model_dataset() {
411        let text = "the quick brown fox jumps over the lazy dog";
412        let dataset = LanguageModelDataset::from_text(text, 3, 1);
413
414        assert!(dataset.len() > 0);
415
416        let (input, target) = dataset.get(0).unwrap();
417        assert_eq!(input.shape(), &[3]);
418        assert_eq!(target.shape(), &[3]);
419    }
420
421    #[test]
422    fn test_synthetic_sentiment_dataset() {
423        let dataset = SyntheticSentimentDataset::small();
424
425        assert_eq!(dataset.len(), 100);
426
427        let (text, label) = dataset.get(0).unwrap();
428        assert_eq!(text.shape(), &[32]);
429        // Class index label
430        assert_eq!(label.shape(), &[1]);
431        let label_val = label.to_vec()[0];
432        assert!(label_val == 0.0 || label_val == 1.0);
433    }
434
435    #[test]
436    fn test_synthetic_sentiment_deterministic() {
437        let dataset = SyntheticSentimentDataset::small();
438
439        let (text1, label1) = dataset.get(5).unwrap();
440        let (text2, label2) = dataset.get(5).unwrap();
441
442        assert_eq!(text1.to_vec(), text2.to_vec());
443        assert_eq!(label1.to_vec(), label2.to_vec());
444        // Label should be class index
445        assert_eq!(label1.shape(), &[1]);
446    }
447
448    #[test]
449    fn test_synthetic_seq2seq_dataset() {
450        let dataset = SyntheticSeq2SeqDataset::copy_task(100, 10, 50);
451
452        assert_eq!(dataset.len(), 100);
453
454        let (src, tgt) = dataset.get(0).unwrap();
455        assert_eq!(src.shape(), &[10]);
456        assert_eq!(tgt.shape(), &[10]);
457
458        // Target should be reversed source
459        let src_vec = src.to_vec();
460        let tgt_vec = tgt.to_vec();
461        let reversed: Vec<f32> = src_vec.iter().rev().copied().collect();
462        assert_eq!(tgt_vec, reversed);
463    }
464
465    #[test]
466    fn test_text_dataset_padding() {
467        let vocab = Vocab::with_special_tokens();
468        let texts = vec!["a b".to_string()];
469        let labels = vec![0];
470
471        let dataset = TextDataset::new(texts, labels, vocab, 10);
472        let (text, _) = dataset.get(0).unwrap();
473
474        // Should be padded to length 10
475        assert_eq!(text.shape(), &[10]);
476    }
477}