flodl/data/datasets/
shakespeare.rs1use crate::data::BatchDataSet;
17use crate::tensor::{Device, Result, Tensor, TensorError};
18
/// Character-level next-character-prediction dataset built from a plain-text corpus.
///
/// When built by [`Shakespeare::parse`], `data` and `targets` are CPU tensors of
/// `i64` vocabulary indices with shape `[n_sequences, seq_len]`. Each target row is
/// its input row shifted one character forward in the source text.
pub struct Shakespeare {
    /// Input windows; one vocabulary index per character.
    pub data: Tensor,
    /// Next-character targets, same shape as `data`: element `[i][j]` is the
    /// character that follows `data[i][j]` in the source text.
    pub targets: Tensor,
    /// Number of distinct characters in the source text.
    pub vocab_size: usize,
    /// `(character, index)` pairs, in ascending character order.
    pub char_to_idx: Vec<(char, usize)>,
    /// Sorted vocabulary; position `i` holds the character encoded as `i`.
    pub idx_to_char: Vec<char>,
}
32
33impl Shakespeare {
34 pub fn parse(text: &str, seq_len: usize) -> Result<Self> {
39 if text.len() < seq_len + 1 {
40 return Err(TensorError::new(&format!(
41 "Shakespeare: text length {} too short for seq_len {}",
42 text.len(), seq_len
43 )));
44 }
45
46 let chars: Vec<char> = text.chars().collect();
48 let mut vocab: Vec<char> = chars.clone();
49 vocab.sort();
50 vocab.dedup();
51 let vocab_size = vocab.len();
52
53 let mut char_to_idx = Vec::with_capacity(vocab_size);
55 let mut lookup = [0usize; 256]; for (idx, &ch) in vocab.iter().enumerate() {
57 char_to_idx.push((ch, idx));
58 if (ch as u32) < 256 {
59 lookup[ch as usize] = idx;
60 }
61 }
62
63 let encoded: Vec<i64> = chars.iter().map(|&ch| {
65 if (ch as u32) < 256 {
66 lookup[ch as usize] as i64
67 } else {
68 char_to_idx.iter()
70 .find(|(c, _)| *c == ch)
71 .map(|(_, i)| *i as i64)
72 .unwrap_or(0)
73 }
74 }).collect();
75
76 let n_sequences = (encoded.len() - 1) / seq_len;
78 if n_sequences == 0 {
79 return Err(TensorError::new("Shakespeare: not enough text for even one sequence"));
80 }
81
82 let mut input_data = Vec::with_capacity(n_sequences * seq_len);
83 let mut target_data = Vec::with_capacity(n_sequences * seq_len);
84
85 for i in 0..n_sequences {
86 let start = i * seq_len;
87 input_data.extend_from_slice(&encoded[start..start + seq_len]);
88 target_data.extend_from_slice(&encoded[start + 1..start + seq_len + 1]);
89 }
90
91 let data = Tensor::from_i64(&input_data, &[n_sequences as i64, seq_len as i64], Device::CPU)?;
92 let targets = Tensor::from_i64(&target_data, &[n_sequences as i64, seq_len as i64], Device::CPU)?;
93
94 Ok(Shakespeare {
95 data,
96 targets,
97 vocab_size,
98 char_to_idx,
99 idx_to_char: vocab,
100 })
101 }
102
103 pub fn len(&self) -> usize {
105 self.data.shape()[0] as usize
106 }
107
108 pub fn is_empty(&self) -> bool {
110 self.len() == 0
111 }
112
113 pub fn decode(&self, indices: &[i64]) -> String {
115 indices.iter()
116 .map(|&i| {
117 self.idx_to_char.get(i as usize).copied().unwrap_or('?')
118 })
119 .collect()
120 }
121}
122
123impl BatchDataSet for Shakespeare {
124 fn len(&self) -> usize {
125 self.data.shape()[0] as usize
126 }
127
128 fn get_batch(&self, indices: &[usize]) -> Result<Vec<Tensor>> {
129 let idx: Vec<i64> = indices.iter().map(|&i| (i % self.len()) as i64).collect();
130 let idx_tensor = Tensor::from_i64(&idx, &[idx.len() as i64], Device::CPU)?;
131 let data = self.data.index_select(0, &idx_tensor)?;
132 let targets = self.targets.index_select(0, &idx_tensor)?;
133 Ok(vec![data, targets])
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_simple_text() {
        // 16 distinct chars -> (16 - 1) / 4 = 3 full windows of 4.
        let text = "abcdefghijklmnop";
        let data = Shakespeare::parse(text, 4).unwrap();

        assert_eq!(data.data.shape(), &[3, 4]);
        assert_eq!(data.targets.shape(), &[3, 4]);
        assert_eq!(data.vocab_size, 16);
    }

    #[test]
    fn target_is_shifted_by_one() {
        let text = "abcdefghij";
        let data = Shakespeare::parse(text, 3).unwrap();

        let input_0 = data.data.select(0, 0).unwrap()
            .select(0, 0).unwrap().to_i64_vec().unwrap()[0];
        let target_0 = data.targets.select(0, 0).unwrap()
            .select(0, 0).unwrap().to_i64_vec().unwrap()[0];
        let input_1 = data.data.select(0, 0).unwrap()
            .select(0, 1).unwrap().to_i64_vec().unwrap()[0];

        // The vocabulary is 'a'..='j', so 'a' -> 0 and 'b' -> 1.
        assert_eq!(input_0, 0);
        assert_eq!(target_0, 1);
        assert_eq!(target_0, input_1);
    }

    #[test]
    fn vocab_is_sorted() {
        let text = "zyxwvutsrqponmlkjihgfedcba";
        let data = Shakespeare::parse(text, 3).unwrap();

        assert_eq!(data.vocab_size, 26);
        for i in 1..data.idx_to_char.len() {
            assert!(data.idx_to_char[i] > data.idx_to_char[i - 1]);
        }
        // char_to_idx and idx_to_char must describe the same mapping.
        for (pos, &(ch, idx)) in data.char_to_idx.iter().enumerate() {
            assert_eq!(idx, pos);
            assert_eq!(data.idx_to_char[idx], ch);
        }
    }

    #[test]
    fn decode_roundtrip() {
        let text = "hello world";
        let data = Shakespeare::parse(text, 4).unwrap();

        let seq: Vec<i64> = (0..4)
            .map(|j| data.data.select(0, 0).unwrap()
                .select(0, j).unwrap().to_i64_vec().unwrap()[0])
            .collect();
        assert_eq!(data.decode(&seq), "hell");
    }

    #[test]
    fn non_ascii_roundtrip() {
        // Vocabulary sorts to ['a', 'b', 'ñ', '→']; '→' lies above U+00FF.
        let text = "añ→b→añ";
        let data = Shakespeare::parse(text, 2).unwrap();
        assert_eq!(data.vocab_size, 4);

        let row: Vec<i64> = (0..2)
            .map(|j| data.data.select(0, 1).unwrap()
                .select(0, j).unwrap().to_i64_vec().unwrap()[0])
            .collect();
        assert_eq!(data.decode(&row), "→b");
    }

    #[test]
    fn decode_out_of_range_is_placeholder() {
        let data = Shakespeare::parse("abcdef", 2).unwrap();
        assert_eq!(data.decode(&[-1, 6, 0]), "??a");
    }

    #[test]
    fn text_too_short() {
        assert!(Shakespeare::parse("ab", 5).is_err());
        // 6 bytes but only 3 chars: cannot fill one window of 3 plus a target.
        assert!(Shakespeare::parse("ééé", 3).is_err());
    }

    #[test]
    fn get_batch_wraps_indices() {
        let data = Shakespeare::parse("abcdefghij", 3).unwrap();
        // 3 sequences, so index 4 wraps to row 1, which starts at 'd' (index 3).
        let batch = data.get_batch(&[0, 4]).unwrap();
        assert_eq!(batch.len(), 2);
        let wrapped = batch[0].select(0, 1).unwrap()
            .select(0, 0).unwrap().to_i64_vec().unwrap()[0];
        assert_eq!(wrapped, 3);
    }
}