use crate::tokenizer::Tokenizer;
use crate::vocab::Vocab;
use axonml_data::Dataset;
use axonml_tensor::Tensor;
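
/// A text-classification dataset pairing raw texts with integer class labels.
/// Items are encoded on demand: each text becomes a `[max_length]` tensor of
/// vocabulary indices (truncated or padded) and each label becomes a one-hot
/// vector of length `num_classes`.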
pub struct TextDataset {
    texts: Vec<String>,
    labels: Vec<usize>,
    vocab: Vocab,
    max_length: usize,
    num_classes: usize,
}

impl TextDataset {
    /// Creates a dataset from parallel `texts` and `labels`. The number of
    /// classes is inferred as `max(labels) + 1`.
    #[must_use]
    pub fn new(texts: Vec<String>, labels: Vec<usize>, vocab: Vocab, max_length: usize) -> Self {
        debug_assert_eq!(texts.len(), labels.len());
        let num_classes = labels.iter().max().map_or(0, |&m| m + 1);
        Self {
            texts,
            labels,
            vocab,
            max_length,
            num_classes,
        }
    }
    /// Builds a dataset and its vocabulary from `(text, label)` samples.
    /// Tokens occurring at least `min_freq` times are added to the vocabulary
    /// in descending frequency order, with ties broken alphabetically so that
    /// construction is deterministic.
    ///
    /// Note: the vocabulary is built with `tokenizer`, while `encode_text`
    /// splits on whitespace, so a non-whitespace tokenizer will not match the
    /// encoding step.
    #[must_use]
    pub fn from_samples<T: Tokenizer>(
        samples: &[(String, usize)],
        tokenizer: &T,
        min_freq: usize,
        max_length: usize,
    ) -> Self {
        use std::collections::HashMap;

        // Count token frequencies across the whole corpus.
        let mut freq: HashMap<String, usize> = HashMap::new();
        for (text, _) in samples {
            for token in tokenizer.tokenize(text) {
                *freq.entry(token).or_insert(0) += 1;
            }
        }

        // Keep tokens at or above the threshold, most frequent first.
        let mut vocab = Vocab::with_special_tokens();
        let mut tokens: Vec<_> = freq
            .into_iter()
            .filter(|(_, count)| *count >= min_freq)
            .collect();
        tokens.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        for (token, _) in tokens {
            vocab.add_token(&token);
        }

        let texts: Vec<String> = samples.iter().map(|(t, _)| t.clone()).collect();
        let labels: Vec<usize> = samples.iter().map(|(_, l)| *l).collect();

        Self::new(texts, labels, vocab, max_length)
    }

    /// Returns a reference to the dataset's vocabulary.
    #[must_use]
    pub fn vocab(&self) -> &Vocab {
        &self.vocab
    }

    /// Returns the number of classes inferred from the labels.
    #[must_use]
    pub fn num_classes(&self) -> usize {
        self.num_classes
    }

    /// Returns the fixed length texts are padded or truncated to.
    #[must_use]
    pub fn max_length(&self) -> usize {
        self.max_length
    }

    /// Encodes `text` as a `[max_length]` tensor of vocabulary indices,
    /// truncating long inputs and padding short ones with the pad index.
    fn encode_text(&self, text: &str) -> Tensor<f32> {
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let mut indices: Vec<f32> = tokens
            .iter()
            .take(self.max_length)
            .map(|t| self.vocab.token_to_index(t) as f32)
            .collect();

        // Pad out to the fixed length with the vocabulary's pad index.
        let pad_idx = self.vocab.pad_index().unwrap_or(0) as f32;
        while indices.len() < self.max_length {
            indices.push(pad_idx);
        }

        Tensor::from_vec(indices, &[self.max_length]).unwrap()
    }
}

impl Dataset for TextDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.texts.len()
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.texts.len() {
            return None;
        }

        let text = self.encode_text(&self.texts[index]);

        // One-hot encode the label.
        let mut label_vec = vec![0.0f32; self.num_classes];
        label_vec[self.labels[index]] = 1.0;
        let label = Tensor::from_vec(label_vec, &[self.num_classes]).unwrap();

        Some((text, label))
    }
}
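
/// A next-token language-modeling dataset over a single whitespace-tokenized
/// stream. Item `i` pairs `tokens[i..i + sequence_length]` with the same
/// window shifted one position to the right.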
pub struct LanguageModelDataset {
    tokens: Vec<usize>,
    sequence_length: usize,
    vocab: Vocab,
}

impl LanguageModelDataset {
    /// Tokenizes `text` on whitespace and maps each token through `vocab`.
    #[must_use]
    pub fn new(text: &str, vocab: Vocab, sequence_length: usize) -> Self {
        let tokens: Vec<usize> = text
            .split_whitespace()
            .map(|t| vocab.token_to_index(t))
            .collect();

        Self {
            tokens,
            sequence_length,
            vocab,
        }
    }

    /// Builds a vocabulary from `text` itself, then tokenizes it.
    #[must_use]
    pub fn from_text(text: &str, sequence_length: usize, min_freq: usize) -> Self {
        let vocab = Vocab::from_text(text, min_freq);
        Self::new(text, vocab, sequence_length)
    }

    /// Returns a reference to the dataset's vocabulary.
    #[must_use]
    pub fn vocab(&self) -> &Vocab {
        &self.vocab
    }
}

impl Dataset for LanguageModelDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        // A window needs `sequence_length` input tokens plus one target token.
        if self.tokens.len() <= self.sequence_length {
            0
        } else {
            self.tokens.len() - self.sequence_length
        }
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.len() {
            return None;
        }

        // Input window starting at `index`.
        let input: Vec<f32> = self.tokens[index..index + self.sequence_length]
            .iter()
            .map(|&t| t as f32)
            .collect();

        // Target window: the input shifted right by one token.
        let target: Vec<f32> = self.tokens[(index + 1)..=(index + self.sequence_length)]
            .iter()
            .map(|&t| t as f32)
            .collect();

        Some((
            Tensor::from_vec(input, &[self.sequence_length]).unwrap(),
            Tensor::from_vec(target, &[self.sequence_length]).unwrap(),
        ))
    }
}
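
/// A deterministic synthetic binary-sentiment dataset. Each sequence is
/// generated from its index with an LCG-style hash; items labeled 1 have
/// their tokens shifted by half the vocabulary so the two classes have
/// distinct token distributions.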
pub struct SyntheticSentimentDataset {
    size: usize,
    max_length: usize,
    vocab_size: usize,
}
229
230impl SyntheticSentimentDataset {
231 #[must_use]
233 pub fn new(size: usize, max_length: usize, vocab_size: usize) -> Self {
234 Self {
235 size,
236 max_length,
237 vocab_size,
238 }
239 }
240
241 #[must_use]
243 pub fn small() -> Self {
244 Self::new(100, 32, 1000)
245 }
246
247 #[must_use]
249 pub fn train() -> Self {
250 Self::new(10000, 64, 10000)
251 }
252
253 #[must_use]
255 pub fn test() -> Self {
256 Self::new(2000, 64, 10000)
257 }
258}

impl Dataset for SyntheticSentimentDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.size
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.size {
            return None;
        }

        // Everything is derived deterministically from the index.
        let seed = index as u32;
        let label = index % 2;

        let mut text = Vec::with_capacity(self.max_length);
        for i in 0..self.max_length {
            // LCG-style hash of (seed, position) -> pseudo-random token id.
            let token_seed = seed.wrapping_mul(1103515245).wrapping_add(12345 + i as u32);
            let token = (token_seed as usize) % self.vocab_size;
            // Positive items draw from a shifted region of the vocabulary.
            let biased_token = if label == 1 {
                (token + self.vocab_size / 2) % self.vocab_size
            } else {
                token
            };
            text.push(biased_token as f32);
        }

        let text_tensor = Tensor::from_vec(text, &[self.max_length]).unwrap();

        // One-hot encode the binary label.
        let mut label_vec = vec![0.0f32; 2];
        label_vec[label] = 1.0;
        let label_tensor = Tensor::from_vec(label_vec, &[2]).unwrap();

        Some((text_tensor, label_tensor))
    }
}
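
/// A deterministic synthetic sequence-to-sequence dataset. Sources are
/// generated from the item index and targets are the reversed sources, so it
/// is effectively a sequence-reversal task.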
pub struct SyntheticSeq2SeqDataset {
    size: usize,
    src_length: usize,
    tgt_length: usize,
    vocab_size: usize,
}
311
312impl SyntheticSeq2SeqDataset {
313 #[must_use]
315 pub fn new(size: usize, src_length: usize, tgt_length: usize, vocab_size: usize) -> Self {
316 Self {
317 size,
318 src_length,
319 tgt_length,
320 vocab_size,
321 }
322 }
323
324 #[must_use]
326 pub fn copy_task(size: usize, length: usize, vocab_size: usize) -> Self {
327 Self::new(size, length, length, vocab_size)
328 }
329}

impl Dataset for SyntheticSeq2SeqDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.size
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.size {
            return None;
        }

        let seed = index as u32;

        // Generate the source sequence deterministically from the index.
        let mut src = Vec::with_capacity(self.src_length);
        for i in 0..self.src_length {
            let token_seed = seed.wrapping_mul(1103515245).wrapping_add(12345 + i as u32);
            let token = (token_seed as usize) % self.vocab_size;
            src.push(token as f32);
        }

        // The target is the reversed source, which has `src_length` elements;
        // building it with shape `[tgt_length]` therefore requires the two
        // lengths to match.
        debug_assert_eq!(self.src_length, self.tgt_length);
        let tgt: Vec<f32> = src.iter().rev().copied().collect();

        Some((
            Tensor::from_vec(src, &[self.src_length]).unwrap(),
            Tensor::from_vec(tgt, &[self.tgt_length]).unwrap(),
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_dataset() {
        let vocab = Vocab::from_tokens(&["hello", "world", "good", "bad", "<pad>", "<unk>"]);
        let texts = vec!["hello world".to_string(), "good bad".to_string()];
        let labels = vec![0, 1];

        let dataset = TextDataset::new(texts, labels, vocab, 10);

        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.num_classes(), 2);

        let (text, label) = dataset.get(0).unwrap();
        assert_eq!(text.shape(), &[10]);
        assert_eq!(label.shape(), &[2]);
    }
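
    // Added check: `get` past the end returns `None` instead of panicking,
    // using only the constructors already exercised in this module.
    #[test]
    fn test_text_dataset_out_of_range() {
        let vocab = Vocab::with_special_tokens();
        let texts = vec!["hello".to_string()];
        let labels = vec![0];

        let dataset = TextDataset::new(texts, labels, vocab, 4);
        assert!(dataset.get(0).is_some());
        assert!(dataset.get(1).is_none());
    }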

    #[test]
    fn test_language_model_dataset() {
        let text = "the quick brown fox jumps over the lazy dog";
        let dataset = LanguageModelDataset::from_text(text, 3, 1);

        assert!(dataset.len() > 0);

        let (input, target) = dataset.get(0).unwrap();
        assert_eq!(input.shape(), &[3]);
        assert_eq!(target.shape(), &[3]);
    }
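
    // Added check: the target window equals the next item's input window,
    // i.e. targets are inputs shifted right by one token.
    #[test]
    fn test_language_model_shifted_target() {
        let text = "a b c d e f";
        let dataset = LanguageModelDataset::from_text(text, 3, 1);

        let (_, target) = dataset.get(0).unwrap();
        let (next_input, _) = dataset.get(1).unwrap();

        assert_eq!(target.to_vec(), next_input.to_vec());
    }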

    #[test]
    fn test_synthetic_sentiment_dataset() {
        let dataset = SyntheticSentimentDataset::small();

        assert_eq!(dataset.len(), 100);

        let (text, label) = dataset.get(0).unwrap();
        assert_eq!(text.shape(), &[32]);
        assert_eq!(label.shape(), &[2]);

        // One-hot labels should sum to one.
        let label_vec = label.to_vec();
        let sum: f32 = label_vec.iter().sum();
        assert!((sum - 1.0).abs() < 0.001);
    }
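
    // Added check: the label is one-hot at position `index % 2`, matching the
    // dataset's generation rule.
    #[test]
    fn test_synthetic_sentiment_label_parity() {
        let dataset = SyntheticSentimentDataset::small();

        for index in 0..4 {
            let (_, label) = dataset.get(index).unwrap();
            assert!((label.to_vec()[index % 2] - 1.0).abs() < 0.001);
        }
    }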

    #[test]
    fn test_synthetic_sentiment_deterministic() {
        let dataset = SyntheticSentimentDataset::small();

        let (text1, label1) = dataset.get(5).unwrap();
        let (text2, label2) = dataset.get(5).unwrap();

        assert_eq!(text1.to_vec(), text2.to_vec());
        assert_eq!(label1.to_vec(), label2.to_vec());
    }

    #[test]
    fn test_synthetic_seq2seq_dataset() {
        let dataset = SyntheticSeq2SeqDataset::copy_task(100, 10, 50);

        assert_eq!(dataset.len(), 100);

        let (src, tgt) = dataset.get(0).unwrap();
        assert_eq!(src.shape(), &[10]);
        assert_eq!(tgt.shape(), &[10]);

        // The target should be the source reversed.
        let src_vec = src.to_vec();
        let tgt_vec = tgt.to_vec();
        let reversed: Vec<f32> = src_vec.iter().rev().copied().collect();
        assert_eq!(tgt_vec, reversed);
    }

    #[test]
    fn test_text_dataset_padding() {
        let vocab = Vocab::with_special_tokens();
        let texts = vec!["a b".to_string()];
        let labels = vec![0];

        let dataset = TextDataset::new(texts, labels, vocab, 10);
        let (text, _) = dataset.get(0).unwrap();

        // The two tokens are padded out to the full max_length.
        assert_eq!(text.shape(), &[10]);
    }
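
    // Added check: a text shorter than the window produces an empty dataset,
    // per the `len` implementation above.
    #[test]
    fn test_language_model_short_text() {
        let dataset = LanguageModelDataset::from_text("only three tokens", 5, 1);

        assert_eq!(dataset.len(), 0);
        assert!(dataset.get(0).is_none());
    }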
}