1use crate::tokenizer::Tokenizer;
19use crate::vocab::Vocab;
20use axonml_data::Dataset;
21use axonml_tensor::Tensor;
22
/// A dataset of labeled texts for classification tasks.
///
/// Each item is a fixed-length tensor of vocabulary indices paired with a
/// single-element label tensor.
pub struct TextDataset {
    /// Raw input texts, one per sample.
    texts: Vec<String>,
    /// Class label per sample; parallel to `texts`.
    labels: Vec<usize>,
    /// Vocabulary mapping tokens to indices.
    vocab: Vocab,
    /// Fixed sequence length; texts are truncated or padded to this.
    max_length: usize,
    /// Number of distinct classes, computed as `max(labels) + 1`.
    num_classes: usize,
    /// Tokenizer used to split texts into tokens.
    tokenizer: Box<dyn Tokenizer>,
}
37
38impl TextDataset {
39 #[must_use]
41 pub fn new(texts: Vec<String>, labels: Vec<usize>, vocab: Vocab, max_length: usize) -> Self {
42 Self::with_tokenizer(
43 texts,
44 labels,
45 vocab,
46 max_length,
47 crate::tokenizer::WhitespaceTokenizer::new(),
48 )
49 }
50
51 pub fn with_tokenizer<T: Tokenizer + 'static>(
53 texts: Vec<String>,
54 labels: Vec<usize>,
55 vocab: Vocab,
56 max_length: usize,
57 tokenizer: T,
58 ) -> Self {
59 let num_classes = labels.iter().max().map_or(0, |&m| m + 1);
60 Self {
61 texts,
62 labels,
63 vocab,
64 max_length,
65 num_classes,
66 tokenizer: Box::new(tokenizer),
67 }
68 }
69
70 pub fn from_samples<T: Tokenizer + Clone + 'static>(
75 samples: &[(String, usize)],
76 tokenizer: &T,
77 min_freq: usize,
78 max_length: usize,
79 ) -> Self {
80 use std::collections::HashMap;
81
82 let mut freq: HashMap<String, usize> = HashMap::new();
84 for (text, _) in samples {
85 for token in tokenizer.tokenize(text) {
86 *freq.entry(token).or_insert(0) += 1;
87 }
88 }
89
90 let mut vocab = Vocab::with_special_tokens();
92 let mut tokens: Vec<_> = freq
93 .into_iter()
94 .filter(|(_, count)| *count >= min_freq)
95 .collect();
96 tokens.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
97 for (token, _) in tokens {
98 vocab.add_token(&token);
99 }
100
101 let texts: Vec<String> = samples.iter().map(|(t, _)| t.clone()).collect();
102 let labels: Vec<usize> = samples.iter().map(|(_, l)| *l).collect();
103
104 Self::with_tokenizer(texts, labels, vocab, max_length, tokenizer.clone())
105 }
106
107 #[must_use]
109 pub fn vocab(&self) -> &Vocab {
110 &self.vocab
111 }
112
113 #[must_use]
115 pub fn num_classes(&self) -> usize {
116 self.num_classes
117 }
118
119 #[must_use]
121 pub fn max_length(&self) -> usize {
122 self.max_length
123 }
124
125 fn encode_text(&self, text: &str) -> Tensor<f32> {
127 let tokens = self.tokenizer.tokenize(text);
128 let mut indices: Vec<f32> = tokens
129 .iter()
130 .take(self.max_length)
131 .map(|t| self.vocab.token_to_index(t) as f32)
132 .collect();
133
134 let pad_idx = self.vocab.pad_index().unwrap_or(0) as f32;
136 while indices.len() < self.max_length {
137 indices.push(pad_idx);
138 }
139
140 Tensor::from_vec(indices, &[self.max_length]).unwrap()
141 }
142}
143
144impl Dataset for TextDataset {
145 type Item = (Tensor<f32>, Tensor<f32>);
146
147 fn len(&self) -> usize {
148 self.texts.len()
149 }
150
151 fn get(&self, index: usize) -> Option<Self::Item> {
152 if index >= self.texts.len() {
153 return None;
154 }
155
156 let text = self.encode_text(&self.texts[index]);
157
158 let label = Tensor::from_vec(vec![self.labels[index] as f32], &[1]).unwrap();
160
161 Some((text, label))
162 }
163}
164
/// A dataset for next-token language modeling over a tokenized corpus.
pub struct LanguageModelDataset {
    /// The full corpus as vocabulary indices, in order.
    tokens: Vec<usize>,
    /// Length of each input/target window.
    sequence_length: usize,
    /// Vocabulary used to index the corpus.
    vocab: Vocab,
}
175
176impl LanguageModelDataset {
177 #[must_use]
179 pub fn new(text: &str, vocab: Vocab, sequence_length: usize) -> Self {
180 let tokens: Vec<usize> = text
181 .split_whitespace()
182 .map(|t| vocab.token_to_index(t))
183 .collect();
184
185 Self {
186 tokens,
187 sequence_length,
188 vocab,
189 }
190 }
191
192 #[must_use]
194 pub fn from_text(text: &str, sequence_length: usize, min_freq: usize) -> Self {
195 let vocab = Vocab::from_text(text, min_freq);
196 Self::new(text, vocab, sequence_length)
197 }
198
199 #[must_use]
201 pub fn vocab(&self) -> &Vocab {
202 &self.vocab
203 }
204}
205
206impl Dataset for LanguageModelDataset {
207 type Item = (Tensor<f32>, Tensor<f32>);
208
209 fn len(&self) -> usize {
210 if self.tokens.len() <= self.sequence_length {
211 0
212 } else {
213 self.tokens.len() - self.sequence_length
214 }
215 }
216
217 fn get(&self, index: usize) -> Option<Self::Item> {
218 if index >= self.len() {
219 return None;
220 }
221
222 let input: Vec<f32> = self.tokens[index..index + self.sequence_length]
224 .iter()
225 .map(|&t| t as f32)
226 .collect();
227
228 let target: Vec<f32> = self.tokens[(index + 1)..=(index + self.sequence_length)]
230 .iter()
231 .map(|&t| t as f32)
232 .collect();
233
234 Some((
235 Tensor::from_vec(input, &[self.sequence_length]).unwrap(),
236 Tensor::from_vec(target, &[self.sequence_length]).unwrap(),
237 ))
238 }
239}
240
/// A deterministic, procedurally generated binary-sentiment dataset
/// (no I/O), useful for smoke tests and benchmarks.
pub struct SyntheticSentimentDataset {
    /// Number of samples.
    size: usize,
    /// Token sequence length of each sample.
    max_length: usize,
    /// Number of distinct token ids generated.
    vocab_size: usize,
}
251
252impl SyntheticSentimentDataset {
253 #[must_use]
255 pub fn new(size: usize, max_length: usize, vocab_size: usize) -> Self {
256 Self {
257 size,
258 max_length,
259 vocab_size,
260 }
261 }
262
263 #[must_use]
265 pub fn small() -> Self {
266 Self::new(100, 32, 1000)
267 }
268
269 #[must_use]
271 pub fn train() -> Self {
272 Self::new(10000, 64, 10000)
273 }
274
275 #[must_use]
277 pub fn test() -> Self {
278 Self::new(2000, 64, 10000)
279 }
280}
281
282impl Dataset for SyntheticSentimentDataset {
283 type Item = (Tensor<f32>, Tensor<f32>);
284
285 fn len(&self) -> usize {
286 self.size
287 }
288
289 fn get(&self, index: usize) -> Option<Self::Item> {
290 if index >= self.size {
291 return None;
292 }
293
294 let seed = index as u32;
296 let label = index % 2; let mut text = Vec::with_capacity(self.max_length);
299 for i in 0..self.max_length {
300 let token_seed = seed.wrapping_mul(1103515245).wrapping_add(12345 + i as u32);
301 let token = (token_seed as usize) % self.vocab_size;
302 let biased_token = if label == 1 {
304 (token + self.vocab_size / 2) % self.vocab_size
305 } else {
306 token
307 };
308 text.push(biased_token as f32);
309 }
310
311 let text_tensor = Tensor::from_vec(text, &[self.max_length]).unwrap();
312
313 let label_tensor = Tensor::from_vec(vec![label as f32], &[1]).unwrap();
315
316 Some((text_tensor, label_tensor))
317 }
318}
319
/// A deterministic, procedurally generated sequence-to-sequence dataset
/// (no I/O) for smoke-testing encoder/decoder models.
pub struct SyntheticSeq2SeqDataset {
    /// Number of sample pairs.
    size: usize,
    /// Source sequence length.
    src_length: usize,
    /// Target sequence length.
    tgt_length: usize,
    /// Number of distinct token ids generated.
    vocab_size: usize,
}
331
332impl SyntheticSeq2SeqDataset {
333 #[must_use]
335 pub fn new(size: usize, src_length: usize, tgt_length: usize, vocab_size: usize) -> Self {
336 Self {
337 size,
338 src_length,
339 tgt_length,
340 vocab_size,
341 }
342 }
343
344 #[must_use]
346 pub fn copy_task(size: usize, length: usize, vocab_size: usize) -> Self {
347 Self::new(size, length, length, vocab_size)
348 }
349}
350
351impl Dataset for SyntheticSeq2SeqDataset {
352 type Item = (Tensor<f32>, Tensor<f32>);
353
354 fn len(&self) -> usize {
355 self.size
356 }
357
358 fn get(&self, index: usize) -> Option<Self::Item> {
359 if index >= self.size {
360 return None;
361 }
362
363 let seed = index as u32;
364
365 let mut src = Vec::with_capacity(self.src_length);
367 for i in 0..self.src_length {
368 let token_seed = seed.wrapping_mul(1103515245).wrapping_add(12345 + i as u32);
369 let token = (token_seed as usize) % self.vocab_size;
370 src.push(token as f32);
371 }
372
373 let tgt: Vec<f32> = src.iter().rev().copied().collect();
375
376 Some((
377 Tensor::from_vec(src, &[self.src_length]).unwrap(),
378 Tensor::from_vec(tgt, &[self.tgt_length]).unwrap(),
379 ))
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_dataset() {
        let vocab = Vocab::from_tokens(&["hello", "world", "good", "bad", "<pad>", "<unk>"]);
        let dataset = TextDataset::new(
            vec!["hello world".to_string(), "good bad".to_string()],
            vec![0, 1],
            vocab,
            10,
        );

        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.num_classes(), 2);

        let (encoded, label) = dataset.get(0).unwrap();
        assert_eq!(encoded.shape(), &[10]);
        assert_eq!(label.shape(), &[1]);
        assert_eq!(label.to_vec()[0], 0.0);
    }

    #[test]
    fn test_language_model_dataset() {
        let corpus = "the quick brown fox jumps over the lazy dog";
        let dataset = LanguageModelDataset::from_text(corpus, 3, 1);

        assert!(dataset.len() > 0);

        let (input, target) = dataset.get(0).unwrap();
        assert_eq!(input.shape(), &[3]);
        assert_eq!(target.shape(), &[3]);
    }

    #[test]
    fn test_synthetic_sentiment_dataset() {
        let dataset = SyntheticSentimentDataset::small();
        assert_eq!(dataset.len(), 100);

        let (text, label) = dataset.get(0).unwrap();
        assert_eq!(text.shape(), &[32]);
        assert_eq!(label.shape(), &[1]);

        let value = label.to_vec()[0];
        assert!(value == 0.0 || value == 1.0);
    }

    #[test]
    fn test_synthetic_sentiment_deterministic() {
        let dataset = SyntheticSentimentDataset::small();

        let (first_text, first_label) = dataset.get(5).unwrap();
        let (second_text, second_label) = dataset.get(5).unwrap();

        assert_eq!(first_text.to_vec(), second_text.to_vec());
        assert_eq!(first_label.to_vec(), second_label.to_vec());
        assert_eq!(first_label.shape(), &[1]);
    }

    #[test]
    fn test_synthetic_seq2seq_dataset() {
        let dataset = SyntheticSeq2SeqDataset::copy_task(100, 10, 50);
        assert_eq!(dataset.len(), 100);

        let (src, tgt) = dataset.get(0).unwrap();
        assert_eq!(src.shape(), &[10]);
        assert_eq!(tgt.shape(), &[10]);

        // Target must be the exact reversal of the source.
        let expected: Vec<f32> = src.to_vec().iter().rev().copied().collect();
        assert_eq!(tgt.to_vec(), expected);
    }

    #[test]
    fn test_text_dataset_padding() {
        let dataset = TextDataset::new(
            vec!["a b".to_string()],
            vec![0],
            Vocab::with_special_tokens(),
            10,
        );

        let (encoded, _) = dataset.get(0).unwrap();
        assert_eq!(encoded.shape(), &[10]);
    }
}
477}