//! Text datasets: classification, language modeling, and synthetic
//! benchmark tasks built on `axonml_data::Dataset`.

use crate::tokenizer::Tokenizer;
use crate::vocab::Vocab;
use axonml_data::Dataset;
use axonml_tensor::Tensor;

/// A text-classification dataset that encodes each text as a fixed-length
/// sequence of vocabulary indices paired with a one-hot label.
pub struct TextDataset {
    texts: Vec<String>,
    labels: Vec<usize>,
    vocab: Vocab,
    max_length: usize,
    num_classes: usize,
}

impl TextDataset {
    /// Creates a dataset from texts and labels; `num_classes` is inferred
    /// as the largest label plus one.
    #[must_use]
    pub fn new(texts: Vec<String>, labels: Vec<usize>, vocab: Vocab, max_length: usize) -> Self {
        let num_classes = labels.iter().max().map_or(0, |&m| m + 1);
        Self {
            texts,
            labels,
            vocab,
            max_length,
            num_classes,
        }
    }
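
    /// Builds a dataset from `(text, label)` samples, constructing the
    /// vocabulary from tokens that occur at least `min_freq` times.
    ///
    /// A minimal usage sketch (assumes some `Tokenizer` implementation
    /// `tok` is in scope, so it is not compiled as a doc-test):
    ///
    /// ```ignore
    /// let samples = vec![("good movie".to_string(), 1), ("bad movie".to_string(), 0)];
    /// let dataset = TextDataset::from_samples(&samples, &tok, 1, 32);
    /// assert_eq!(dataset.num_classes(), 2);
    /// ```
    #[must_use]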
    pub fn from_samples<T: Tokenizer>(
        samples: &[(String, usize)],
        tokenizer: &T,
        min_freq: usize,
        max_length: usize,
    ) -> Self {
        use std::collections::HashMap;

        // Count token frequencies across the whole corpus.
        let mut freq: HashMap<String, usize> = HashMap::new();
        for (text, _) in samples {
            for token in tokenizer.tokenize(text) {
                *freq.entry(token).or_insert(0) += 1;
            }
        }

        // Keep tokens at or above `min_freq`, ordered by descending
        // frequency with ties broken alphabetically for determinism.
        let mut vocab = Vocab::with_special_tokens();
        let mut tokens: Vec<_> = freq
            .into_iter()
            .filter(|(_, count)| *count >= min_freq)
            .collect();
        tokens.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        for (token, _) in tokens {
            vocab.add_token(&token);
        }

        let texts: Vec<String> = samples.iter().map(|(t, _)| t.clone()).collect();
        let labels: Vec<usize> = samples.iter().map(|(_, l)| *l).collect();

        Self::new(texts, labels, vocab, max_length)
    }

    /// Returns the vocabulary used for encoding.
    #[must_use]
    pub fn vocab(&self) -> &Vocab {
        &self.vocab
    }

    /// Returns the number of distinct classes.
    #[must_use]
    pub fn num_classes(&self) -> usize {
        self.num_classes
    }

    /// Returns the maximum sequence length in tokens.
    #[must_use]
    pub fn max_length(&self) -> usize {
        self.max_length
    }

    /// Encodes a text as a `[max_length]` tensor of vocabulary indices:
    /// whitespace-tokenized, truncated, then padded with the pad index.
    fn encode_text(&self, text: &str) -> Tensor<f32> {
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let mut indices: Vec<f32> = tokens
            .iter()
            .take(self.max_length)
            .map(|t| self.vocab.token_to_index(t) as f32)
            .collect();

        // Pad to `max_length`, falling back to index 0 if the vocabulary
        // defines no pad token.
        let pad_idx = self.vocab.pad_index().unwrap_or(0) as f32;
        while indices.len() < self.max_length {
            indices.push(pad_idx);
        }

        Tensor::from_vec(indices, &[self.max_length]).unwrap()
    }
}

impl Dataset for TextDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.texts.len()
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.texts.len() {
            return None;
        }

        let text = self.encode_text(&self.texts[index]);

        // One-hot encode the label.
        let mut label_vec = vec![0.0f32; self.num_classes];
        label_vec[self.labels[index]] = 1.0;
        let label = Tensor::from_vec(label_vec, &[self.num_classes]).unwrap();

        Some((text, label))
    }
}

/// A next-token language-modeling dataset over a whitespace-tokenized text:
/// item `i` pairs `tokens[i..i + n]` with `tokens[i + 1..=i + n]`.
pub struct LanguageModelDataset {
    tokens: Vec<usize>,
    sequence_length: usize,
    vocab: Vocab,
}

impl LanguageModelDataset {
    /// Creates a dataset by encoding `text` with an existing vocabulary.
    #[must_use]
    pub fn new(text: &str, vocab: Vocab, sequence_length: usize) -> Self {
        let tokens: Vec<usize> = text
            .split_whitespace()
            .map(|t| vocab.token_to_index(t))
            .collect();

        Self {
            tokens,
            sequence_length,
            vocab,
        }
    }

    /// Creates a dataset, building the vocabulary from `text` with minimum
    /// token frequency `min_freq`.
    #[must_use]
    pub fn from_text(text: &str, sequence_length: usize, min_freq: usize) -> Self {
        let vocab = Vocab::from_text(text, min_freq);
        Self::new(text, vocab, sequence_length)
    }

    /// Returns the vocabulary used for encoding.
    #[must_use]
    pub fn vocab(&self) -> &Vocab {
        &self.vocab
    }
}

impl Dataset for LanguageModelDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        if self.tokens.len() <= self.sequence_length {
            0
        } else {
            self.tokens.len() - self.sequence_length
        }
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.len() {
            return None;
        }

        // Input window of `sequence_length` tokens starting at `index`.
        let input: Vec<f32> = self.tokens[index..index + self.sequence_length]
            .iter()
            .map(|&t| t as f32)
            .collect();

        // Target is the same window shifted right by one token.
        let target: Vec<f32> = self.tokens[(index + 1)..=(index + self.sequence_length)]
            .iter()
            .map(|&t| t as f32)
            .collect();

        Some((
            Tensor::from_vec(input, &[self.sequence_length]).unwrap(),
            Tensor::from_vec(target, &[self.sequence_length]).unwrap(),
        ))
    }
}

/// A deterministic synthetic binary-sentiment dataset: tokens are derived
/// from the sample index, and class-1 samples shift tokens by half the
/// vocabulary so the classes are separable.
pub struct SyntheticSentimentDataset {
    size: usize,
    max_length: usize,
    vocab_size: usize,
}
220
221impl SyntheticSentimentDataset {
222 #[must_use]
224 pub fn new(size: usize, max_length: usize, vocab_size: usize) -> Self {
225 Self {
226 size,
227 max_length,
228 vocab_size,
229 }
230 }
231
232 #[must_use]
234 pub fn small() -> Self {
235 Self::new(100, 32, 1000)
236 }
237
238 #[must_use]
240 pub fn train() -> Self {
241 Self::new(10000, 64, 10000)
242 }
243
244 #[must_use]
246 pub fn test() -> Self {
247 Self::new(2000, 64, 10000)
248 }
249}

impl Dataset for SyntheticSentimentDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.size
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.size {
            return None;
        }

        let seed = index as u32;
        let label = index % 2;

        // Generate a pseudo-random token sequence seeded by the index so
        // that samples are deterministic.
        let mut text = Vec::with_capacity(self.max_length);
        for i in 0..self.max_length {
            let token_seed = seed.wrapping_mul(1103515245).wrapping_add(12345 + i as u32);
            let token = (token_seed as usize) % self.vocab_size;
            let biased_token = if label == 1 {
                // Class-1 samples shift tokens by half the vocabulary.
                (token + self.vocab_size / 2) % self.vocab_size
            } else {
                token
            };
            text.push(biased_token as f32);
        }

        let text_tensor = Tensor::from_vec(text, &[self.max_length]).unwrap();

        // One-hot encode the binary label.
        let mut label_vec = vec![0.0f32; 2];
        label_vec[label] = 1.0;
        let label_tensor = Tensor::from_vec(label_vec, &[2]).unwrap();

        Some((text_tensor, label_tensor))
    }
}

/// A deterministic synthetic sequence-to-sequence dataset whose target is
/// the reversed source sequence.
pub struct SyntheticSeq2SeqDataset {
    size: usize,
    src_length: usize,
    tgt_length: usize,
    vocab_size: usize,
}

impl SyntheticSeq2SeqDataset {
    /// Creates a synthetic dataset with the given size, source and target
    /// lengths, and vocabulary size.
    #[must_use]
    pub fn new(size: usize, src_length: usize, tgt_length: usize, vocab_size: usize) -> Self {
        Self {
            size,
            src_length,
            tgt_length,
            vocab_size,
        }
    }

    /// A configuration with equal source and target lengths; note that
    /// `get` still produces the reversed source as the target.
    #[must_use]
    pub fn copy_task(size: usize, length: usize, vocab_size: usize) -> Self {
        Self::new(size, length, length, vocab_size)
    }
}

impl Dataset for SyntheticSeq2SeqDataset {
    type Item = (Tensor<f32>, Tensor<f32>);

    fn len(&self) -> usize {
        self.size
    }

    fn get(&self, index: usize) -> Option<Self::Item> {
        if index >= self.size {
            return None;
        }

        let seed = index as u32;

        // Generate a pseudo-random source sequence seeded by the index.
        let mut src = Vec::with_capacity(self.src_length);
        for i in 0..self.src_length {
            let token_seed = seed.wrapping_mul(1103515245).wrapping_add(12345 + i as u32);
            let token = (token_seed as usize) % self.vocab_size;
            src.push(token as f32);
        }

        // The target is the reversed source, truncated or zero-padded to
        // `tgt_length` so the tensor shape is valid even when the lengths
        // differ.
        let mut tgt: Vec<f32> = src.iter().rev().copied().collect();
        tgt.resize(self.tgt_length, 0.0);

        Some((
            Tensor::from_vec(src, &[self.src_length]).unwrap(),
            Tensor::from_vec(tgt, &[self.tgt_length]).unwrap(),
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_dataset() {
        let vocab = Vocab::from_tokens(&["hello", "world", "good", "bad", "<pad>", "<unk>"]);
        let texts = vec!["hello world".to_string(), "good bad".to_string()];
        let labels = vec![0, 1];

        let dataset = TextDataset::new(texts, labels, vocab, 10);

        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.num_classes(), 2);

        let (text, label) = dataset.get(0).unwrap();
        assert_eq!(text.shape(), &[10]);
        assert_eq!(label.shape(), &[2]);
    }
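
    // Illustrative sketch: `num_classes` is inferred as `max(label) + 1`,
    // and `get` past the end returns `None`.
    #[test]
    fn test_text_dataset_bounds_and_classes() {
        let vocab = Vocab::from_tokens(&["hello", "world", "<pad>", "<unk>"]);
        let texts = vec!["hello".to_string(), "world".to_string()];
        let labels = vec![0, 3];

        let dataset = TextDataset::new(texts, labels, vocab, 4);

        assert_eq!(dataset.num_classes(), 4);
        assert!(dataset.get(2).is_none());
    }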

    #[test]
    fn test_language_model_dataset() {
        let text = "the quick brown fox jumps over the lazy dog";
        let dataset = LanguageModelDataset::from_text(text, 3, 1);

        assert!(dataset.len() > 0);

        let (input, target) = dataset.get(0).unwrap();
        assert_eq!(input.shape(), &[3]);
        assert_eq!(target.shape(), &[3]);
    }
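
    // Illustrative sketch: each target is the input shifted left by one
    // token, so `target[..n - 1]` must equal `input[1..]`.
    #[test]
    fn test_language_model_shift_by_one() {
        let text = "a b c d e f";
        let dataset = LanguageModelDataset::from_text(text, 3, 1);

        let (input, target) = dataset.get(0).unwrap();
        assert_eq!(&input.to_vec()[1..], &target.to_vec()[..2]);
    }

    // Illustrative sketch: a text with no more tokens than the sequence
    // length yields an empty dataset.
    #[test]
    fn test_language_model_short_text() {
        let dataset = LanguageModelDataset::from_text("a b c", 3, 1);

        assert_eq!(dataset.len(), 0);
        assert!(dataset.get(0).is_none());
    }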

    #[test]
    fn test_synthetic_sentiment_dataset() {
        let dataset = SyntheticSentimentDataset::small();

        assert_eq!(dataset.len(), 100);

        let (text, label) = dataset.get(0).unwrap();
        assert_eq!(text.shape(), &[32]);
        assert_eq!(label.shape(), &[2]);

        // The label must be a valid one-hot vector.
        let label_vec = label.to_vec();
        let sum: f32 = label_vec.iter().sum();
        assert!((sum - 1.0).abs() < 0.001);
    }
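
    // Illustrative sketch: labels alternate with the index (`index % 2`),
    // so even indices are class 0 and odd indices are class 1.
    #[test]
    fn test_synthetic_sentiment_label_alternation() {
        let dataset = SyntheticSentimentDataset::small();

        let (_, label0) = dataset.get(0).unwrap();
        let (_, label1) = dataset.get(1).unwrap();

        assert_eq!(label0.to_vec(), vec![1.0, 0.0]);
        assert_eq!(label1.to_vec(), vec![0.0, 1.0]);
    }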

    #[test]
    fn test_synthetic_sentiment_deterministic() {
        let dataset = SyntheticSentimentDataset::small();

        let (text1, label1) = dataset.get(5).unwrap();
        let (text2, label2) = dataset.get(5).unwrap();

        assert_eq!(text1.to_vec(), text2.to_vec());
        assert_eq!(label1.to_vec(), label2.to_vec());
    }

    #[test]
    fn test_synthetic_seq2seq_dataset() {
        let dataset = SyntheticSeq2SeqDataset::copy_task(100, 10, 50);

        assert_eq!(dataset.len(), 100);

        let (src, tgt) = dataset.get(0).unwrap();
        assert_eq!(src.shape(), &[10]);
        assert_eq!(tgt.shape(), &[10]);

        // The target must be the reversed source sequence.
        let src_vec = src.to_vec();
        let tgt_vec = tgt.to_vec();
        let reversed: Vec<f32> = src_vec.iter().rev().copied().collect();
        assert_eq!(tgt_vec, reversed);
    }
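
    // Illustrative sketch: samples are generated deterministically from the
    // index, and every token stays below `vocab_size`.
    #[test]
    fn test_synthetic_seq2seq_deterministic() {
        let dataset = SyntheticSeq2SeqDataset::copy_task(100, 10, 50);

        let (src1, _) = dataset.get(7).unwrap();
        let (src2, _) = dataset.get(7).unwrap();

        assert_eq!(src1.to_vec(), src2.to_vec());
        assert!(src1.to_vec().iter().all(|&t| t < 50.0));
    }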

    #[test]
    fn test_text_dataset_padding() {
        let vocab = Vocab::with_special_tokens();
        let texts = vec!["a b".to_string()];
        let labels = vec![0];

        let dataset = TextDataset::new(texts, labels, vocab, 10);
        let (text, _) = dataset.get(0).unwrap();

        assert_eq!(text.shape(), &[10]);
    }
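
    // Illustrative sketch: all positions past the tokens are filled with a
    // single pad value, whatever index the vocabulary assigns to `<pad>`.
    #[test]
    fn test_text_dataset_pad_value_uniform() {
        let vocab = Vocab::with_special_tokens();
        let texts = vec!["a b".to_string()];
        let labels = vec![0];

        let dataset = TextDataset::new(texts, labels, vocab, 10);
        let (text, _) = dataset.get(0).unwrap();

        let values = text.to_vec();
        let pad = values[2];
        assert!(values[2..].iter().all(|&v| v == pad));
    }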
}