use crate::tokenizer::Tokenizer;
use crate::vocab::Vocab;
use axonml_data::Dataset;
use axonml_tensor::Tensor;

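/// A text-classification dataset pairing raw strings with integer labels.
///
/// Items are `(encoded_text, one_hot_label)` tensor pairs. A usage sketch,
/// following the same API the tests below exercise:
///
/// ```ignore
/// let vocab = Vocab::from_tokens(&["good", "bad", "<pad>", "<unk>"]);
/// let dataset = TextDataset::new(
///     vec!["good".to_string(), "bad".to_string()],
///     vec![0, 1],
///     vocab,
///     8,
/// );
/// assert_eq!(dataset.num_classes(), 2);
/// ```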
pub struct TextDataset {
    texts: Vec<String>,
    labels: Vec<usize>,
    vocab: Vocab,
    max_length: usize,
    num_classes: usize,
}

impl TextDataset {
    /// Creates a dataset from pre-collected texts and labels.
    ///
    /// `num_classes` is inferred as `max(label) + 1`.
    #[must_use]
    pub fn new(texts: Vec<String>, labels: Vec<usize>, vocab: Vocab, max_length: usize) -> Self {
        let num_classes = labels.iter().max().map_or(0, |&m| m + 1);
        Self {
            texts,
            labels,
            vocab,
            max_length,
            num_classes,
        }
    }

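    /// Builds the dataset and its vocabulary from `(text, label)` samples,
    /// keeping tokens that occur at least `min_freq` times.
    ///
    /// A usage sketch; `WhitespaceTokenizer` here is a hypothetical
    /// `Tokenizer` implementation, not part of this crate:
    ///
    /// ```ignore
    /// let samples = vec![
    ///     ("good movie".to_string(), 1),
    ///     ("bad movie".to_string(), 0),
    /// ];
    /// let dataset = TextDataset::from_samples(&samples, &WhitespaceTokenizer, 1, 16);
    /// assert_eq!(dataset.num_classes(), 2);
    /// ```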
    #[must_use]
    pub fn from_samples<T: Tokenizer>(
        samples: &[(String, usize)],
        tokenizer: &T,
        min_freq: usize,
        max_length: usize,
    ) -> Self {
        use std::collections::HashMap;

        // Count token frequencies across all samples.
        let mut freq: HashMap<String, usize> = HashMap::new();
        for (text, _) in samples {
            for token in tokenizer.tokenize(text) {
                *freq.entry(token).or_insert(0) += 1;
            }
        }

        let mut vocab = Vocab::with_special_tokens();

        // Keep tokens at or above the frequency threshold, adding the most
        // frequent first (ties broken alphabetically for determinism).
        let mut tokens: Vec<_> = freq
            .into_iter()
            .filter(|(_, count)| *count >= min_freq)
            .collect();
        tokens.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        for (token, _) in tokens {
            vocab.add_token(&token);
        }

        let texts: Vec<String> = samples.iter().map(|(t, _)| t.clone()).collect();
        let labels: Vec<usize> = samples.iter().map(|(_, l)| *l).collect();

        Self::new(texts, labels, vocab, max_length)
    }

    /// Returns the vocabulary used for encoding.
    #[must_use]
    pub fn vocab(&self) -> &Vocab {
        &self.vocab
    }

    /// Returns the number of distinct label classes.
    #[must_use]
    pub fn num_classes(&self) -> usize {
        self.num_classes
    }

    /// Returns the fixed length texts are truncated or padded to.
    #[must_use]
    pub fn max_length(&self) -> usize {
        self.max_length
    }

    /// Encodes `text` as a fixed-length tensor of vocabulary indices.
    ///
    /// Splits on whitespace (independently of any tokenizer used to build
    /// the vocabulary), truncates to `max_length`, and pads the remainder.
    fn encode_text(&self, text: &str) -> Tensor<f32> {
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let mut indices: Vec<f32> = tokens
            .iter()
            .take(self.max_length)
            .map(|t| self.vocab.token_to_index(t) as f32)
            .collect();

        // Pad with the vocabulary's pad index, falling back to 0.
        let pad_idx = self.vocab.pad_index().unwrap_or(0) as f32;
        while indices.len() < self.max_length {
            indices.push(pad_idx);
        }

        Tensor::from_vec(indices, &[self.max_length]).unwrap()
    }
}

impl Dataset for TextDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.texts.len()
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.texts.len() {
            return None;
        }

        let text = self.encode_text(&self.texts[index]);

        // One-hot encode the label.
        let mut label_vec = vec![0.0f32; self.num_classes];
        label_vec[self.labels[index]] = 1.0;
        let label = Tensor::from_vec(label_vec, &[self.num_classes]).unwrap();

        Some((text, label))
    }
}

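/// A next-token language-modeling dataset over whitespace-tokenized text.
///
/// Each item pairs a window of `sequence_length` token indices with the same
/// window shifted one position to the right. A sketch:
///
/// ```ignore
/// let dataset = LanguageModelDataset::from_text("the quick brown fox", 2, 1);
/// // input = tokens[0..2], target = tokens[1..3]
/// let (input, target) = dataset.get(0).unwrap();
/// ```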
pub struct LanguageModelDataset {
    tokens: Vec<usize>,
    sequence_length: usize,
    vocab: Vocab,
}

impl LanguageModelDataset {
    /// Tokenizes `text` by whitespace against an existing vocabulary.
    #[must_use]
    pub fn new(text: &str, vocab: Vocab, sequence_length: usize) -> Self {
        let tokens: Vec<usize> = text
            .split_whitespace()
            .map(|t| vocab.token_to_index(t))
            .collect();

        Self {
            tokens,
            sequence_length,
            vocab,
        }
    }

    /// Builds the vocabulary from `text` itself, then tokenizes it.
    #[must_use]
    pub fn from_text(text: &str, sequence_length: usize, min_freq: usize) -> Self {
        let vocab = Vocab::from_text(text, min_freq);
        Self::new(text, vocab, sequence_length)
    }

    #[must_use]
    pub fn vocab(&self) -> &Vocab {
        &self.vocab
    }
}

impl Dataset for LanguageModelDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        if self.tokens.len() <= self.sequence_length {
            0
        } else {
            self.tokens.len() - self.sequence_length
        }
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.len() {
            return None;
        }

        // Input window of `sequence_length` tokens starting at `index`.
        let input: Vec<f32> = self.tokens[index..index + self.sequence_length]
            .iter()
            .map(|&t| t as f32)
            .collect();

        // Target is the same window shifted one token to the right.
        let target: Vec<f32> = self.tokens[(index + 1)..=(index + self.sequence_length)]
            .iter()
            .map(|&t| t as f32)
            .collect();

        Some((
            Tensor::from_vec(input, &[self.sequence_length]).unwrap(),
            Tensor::from_vec(target, &[self.sequence_length]).unwrap(),
        ))
    }
}

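/// A deterministic synthetic binary-sentiment dataset for smoke tests.
///
/// Tokens are derived from the item index with a fixed LCG-style hash, so no
/// data files are needed and every run produces identical items. A sketch:
///
/// ```ignore
/// let dataset = SyntheticSentimentDataset::small(); // 100 items of length 32
/// let (text, label) = dataset.get(0).unwrap();
/// assert_eq!(label.shape(), &[2]); // one-hot binary label
/// ```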
pub struct SyntheticSentimentDataset {
    size: usize,
    max_length: usize,
    vocab_size: usize,
}

impl SyntheticSentimentDataset {
    #[must_use]
    pub fn new(size: usize, max_length: usize, vocab_size: usize) -> Self {
        Self {
            size,
            max_length,
            vocab_size,
        }
    }

    /// 100 items of length 32 over a 1000-token vocabulary.
    #[must_use]
    pub fn small() -> Self {
        Self::new(100, 32, 1000)
    }

    /// 10000 training items of length 64 over a 10000-token vocabulary.
    #[must_use]
    pub fn train() -> Self {
        Self::new(10000, 64, 10000)
    }

    /// 2000 test items of length 64 over a 10000-token vocabulary.
    #[must_use]
    pub fn test() -> Self {
        Self::new(2000, 64, 10000)
    }
}

impl Dataset for SyntheticSentimentDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.size
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.size {
            return None;
        }

        // Deterministic pseudo-random generation keyed on the index.
        let seed = index as u32;
        let label = index % 2;

        let mut text = Vec::with_capacity(self.max_length);
        for i in 0..self.max_length {
            let token_seed = seed.wrapping_mul(1103515245).wrapping_add(12345 + i as u32);
            let token = (token_seed as usize) % self.vocab_size;
            // Shift positive-class tokens into the other half of the
            // vocabulary so the label is recoverable from the tokens.
            let biased_token = if label == 1 {
                (token + self.vocab_size / 2) % self.vocab_size
            } else {
                token
            };
            text.push(biased_token as f32);
        }

        let text_tensor = Tensor::from_vec(text, &[self.max_length]).unwrap();

        // One-hot label over the two classes.
        let mut label_vec = vec![0.0f32; 2];
        label_vec[label] = 1.0;
        let label_tensor = Tensor::from_vec(label_vec, &[2]).unwrap();

        Some((text_tensor, label_tensor))
    }
}

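/// A deterministic synthetic sequence-to-sequence dataset.
///
/// The target is the source sequence reversed, so `copy_task` is really a
/// reversal task with matching source and target lengths. A sketch:
///
/// ```ignore
/// let dataset = SyntheticSeq2SeqDataset::copy_task(100, 10, 50);
/// let (src, tgt) = dataset.get(0).unwrap(); // tgt is src reversed
/// ```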
pub struct SyntheticSeq2SeqDataset {
    size: usize,
    src_length: usize,
    tgt_length: usize,
    vocab_size: usize,
}

impl SyntheticSeq2SeqDataset {
    #[must_use]
    pub fn new(size: usize, src_length: usize, tgt_length: usize, vocab_size: usize) -> Self {
        Self {
            size,
            src_length,
            tgt_length,
            vocab_size,
        }
    }

    /// Equal-length source/target pairs; the target is the reversed source.
    #[must_use]
    pub fn copy_task(size: usize, length: usize, vocab_size: usize) -> Self {
        Self::new(size, length, length, vocab_size)
    }
}

impl Dataset for SyntheticSeq2SeqDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.size
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.size {
            return None;
        }

        let seed = index as u32;

        // Deterministic pseudo-random source sequence keyed on the index.
        let mut src = Vec::with_capacity(self.src_length);
        for i in 0..self.src_length {
            let token_seed = seed.wrapping_mul(1103515245).wrapping_add(12345 + i as u32);
            let token = (token_seed as usize) % self.vocab_size;
            src.push(token as f32);
        }

        // Target is the reversed source, truncated or zero-padded to
        // `tgt_length`; reversing alone would leave `src_length` elements and
        // the `unwrap` below would panic whenever `tgt_length != src_length`.
        let mut tgt: Vec<f32> = src.iter().rev().copied().take(self.tgt_length).collect();
        tgt.resize(self.tgt_length, 0.0);

        Some((
            Tensor::from_vec(src, &[self.src_length]).unwrap(),
            Tensor::from_vec(tgt, &[self.tgt_length]).unwrap(),
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_dataset() {
        let vocab = Vocab::from_tokens(&["hello", "world", "good", "bad", "<pad>", "<unk>"]);
        let texts = vec!["hello world".to_string(), "good bad".to_string()];
        let labels = vec![0, 1];

        let dataset = TextDataset::new(texts, labels, vocab, 10);

        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.num_classes(), 2);

        let (text, label) = dataset.get(0).unwrap();
        assert_eq!(text.shape(), &[10]);
        assert_eq!(label.shape(), &[2]);
    }

    #[test]
    fn test_language_model_dataset() {
        let text = "the quick brown fox jumps over the lazy dog";
        let dataset = LanguageModelDataset::from_text(text, 3, 1);

        assert!(dataset.len() > 0);

        let (input, target) = dataset.get(0).unwrap();
        assert_eq!(input.shape(), &[3]);
        assert_eq!(target.shape(), &[3]);
    }

    #[test]
    fn test_synthetic_sentiment_dataset() {
        let dataset = SyntheticSentimentDataset::small();

        assert_eq!(dataset.len(), 100);

        let (text, label) = dataset.get(0).unwrap();
        assert_eq!(text.shape(), &[32]);
        assert_eq!(label.shape(), &[2]);

        // The label must be one-hot: its entries sum to exactly one.
        let label_vec = label.to_vec();
        let sum: f32 = label_vec.iter().sum();
        assert!((sum - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_synthetic_sentiment_deterministic() {
        let dataset = SyntheticSentimentDataset::small();

        let (text1, label1) = dataset.get(5).unwrap();
        let (text2, label2) = dataset.get(5).unwrap();

        assert_eq!(text1.to_vec(), text2.to_vec());
        assert_eq!(label1.to_vec(), label2.to_vec());
    }

    #[test]
    fn test_synthetic_seq2seq_dataset() {
        let dataset = SyntheticSeq2SeqDataset::copy_task(100, 10, 50);

        assert_eq!(dataset.len(), 100);

        let (src, tgt) = dataset.get(0).unwrap();
        assert_eq!(src.shape(), &[10]);
        assert_eq!(tgt.shape(), &[10]);

        // The target should be the source reversed.
        let src_vec = src.to_vec();
        let tgt_vec = tgt.to_vec();
        let reversed: Vec<f32> = src_vec.iter().rev().copied().collect();
        assert_eq!(tgt_vec, reversed);
    }

    #[test]
    fn test_text_dataset_padding() {
        let vocab = Vocab::with_special_tokens();
        let texts = vec!["a b".to_string()];
        let labels = vec![0];

        let dataset = TextDataset::new(texts, labels, vocab, 10);
        let (text, _) = dataset.get(0).unwrap();

        assert_eq!(text.shape(), &[10]);
        // Positions past the two real tokens should hold the pad index.
        let pad = dataset.vocab().pad_index().unwrap_or(0) as f32;
        assert!(text.to_vec()[2..].iter().all(|&v| v == pad));
    }
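
    // A sketch test for `TextDataset::from_samples`. It assumes `Tokenizer`
    // exposes `fn tokenize(&self, text: &str) -> Vec<String>`, matching how
    // `from_samples` consumes it above; `WhitespaceTokenizer` is local to
    // this test, not part of the crate.
    struct WhitespaceTokenizer;

    impl Tokenizer for WhitespaceTokenizer {
        fn tokenize(&self, text: &str) -> Vec<String> {
            text.split_whitespace().map(str::to_string).collect()
        }
    }

    #[test]
    fn test_from_samples() {
        let samples = vec![
            ("good movie".to_string(), 1),
            ("bad movie".to_string(), 0),
        ];

        let dataset = TextDataset::from_samples(&samples, &WhitespaceTokenizer, 1, 8);

        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.num_classes(), 2);

        let (text, label) = dataset.get(0).unwrap();
        assert_eq!(text.shape(), &[8]);
        assert_eq!(label.shape(), &[2]);
    }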
}