Skip to main content

flodl/data/datasets/
shakespeare.rs

1//! Shakespeare character-level language modeling dataset.
2//!
3//! Tokenizes text into character indices and creates input/target pairs
4//! for next-character prediction (shifted by one position).
5//!
6//! # Example
7//!
8//! ```ignore
9//! let text = std::fs::read_to_string("input.txt")?;
10//! let data = Shakespeare::parse(&text, 128)?;
11//! // data.data:    [N, 128] Int64 -- input sequences
12//! // data.targets: [N, 128] Int64 -- targets (shifted by 1)
13//! // data.vocab_size: ~65 unique characters
14//! ```
15
16use crate::data::BatchDataSet;
17use crate::tensor::{Device, Result, Tensor, TensorError};
18
/// Parsed Shakespeare character-level dataset.
///
/// Built by [`Shakespeare::parse`], which encodes every character of the
/// source text as its position in a sorted vocabulary and slices the result
/// into fixed-length input/target rows.
pub struct Shakespeare {
    /// Input sequences as `[N, seq_len]` Int64 (character indices).
    pub data: Tensor,
    /// Target sequences as `[N, seq_len]` Int64 (shifted by 1), i.e. row `i`
    /// holds the characters that follow each position of `data` row `i`.
    pub targets: Tensor,
    /// Number of unique characters in the vocabulary.
    pub vocab_size: usize,
    /// Character-to-index mapping as `(char, index)` pairs, sorted by char
    /// value; the index of each pair equals its position in this vector.
    pub char_to_idx: Vec<(char, usize)>,
    /// Index-to-character mapping: entry `i` is the character encoded as `i`.
    pub idx_to_char: Vec<char>,
}
32
33impl Shakespeare {
34    /// Parse raw text into character-level sequences.
35    ///
36    /// Creates non-overlapping windows of `seq_len` characters.
37    /// Input is `text[i..i+seq_len]`, target is `text[i+1..i+seq_len+1]`.
38    pub fn parse(text: &str, seq_len: usize) -> Result<Self> {
39        if text.len() < seq_len + 1 {
40            return Err(TensorError::new(&format!(
41                "Shakespeare: text length {} too short for seq_len {}",
42                text.len(), seq_len
43            )));
44        }
45
46        // Build vocabulary from sorted unique characters
47        let chars: Vec<char> = text.chars().collect();
48        let mut vocab: Vec<char> = chars.clone();
49        vocab.sort();
50        vocab.dedup();
51        let vocab_size = vocab.len();
52
53        // Build lookup table (char -> index)
54        let mut char_to_idx = Vec::with_capacity(vocab_size);
55        let mut lookup = [0usize; 256]; // ASCII fast path
56        for (idx, &ch) in vocab.iter().enumerate() {
57            char_to_idx.push((ch, idx));
58            if (ch as u32) < 256 {
59                lookup[ch as usize] = idx;
60            }
61        }
62
63        // Encode entire text to indices
64        let encoded: Vec<i64> = chars.iter().map(|&ch| {
65            if (ch as u32) < 256 {
66                lookup[ch as usize] as i64
67            } else {
68                // Fallback for non-ASCII (shouldn't happen in Shakespeare)
69                char_to_idx.iter()
70                    .find(|(c, _)| *c == ch)
71                    .map(|(_, i)| *i as i64)
72                    .unwrap_or(0)
73            }
74        }).collect();
75
76        // Create non-overlapping sequences
77        let n_sequences = (encoded.len() - 1) / seq_len;
78        if n_sequences == 0 {
79            return Err(TensorError::new("Shakespeare: not enough text for even one sequence"));
80        }
81
82        let mut input_data = Vec::with_capacity(n_sequences * seq_len);
83        let mut target_data = Vec::with_capacity(n_sequences * seq_len);
84
85        for i in 0..n_sequences {
86            let start = i * seq_len;
87            input_data.extend_from_slice(&encoded[start..start + seq_len]);
88            target_data.extend_from_slice(&encoded[start + 1..start + seq_len + 1]);
89        }
90
91        let data = Tensor::from_i64(&input_data, &[n_sequences as i64, seq_len as i64], Device::CPU)?;
92        let targets = Tensor::from_i64(&target_data, &[n_sequences as i64, seq_len as i64], Device::CPU)?;
93
94        Ok(Shakespeare {
95            data,
96            targets,
97            vocab_size,
98            char_to_idx,
99            idx_to_char: vocab,
100        })
101    }
102
103    /// Number of sequences.
104    pub fn len(&self) -> usize {
105        self.data.shape()[0] as usize
106    }
107
108    /// True if the dataset is empty.
109    pub fn is_empty(&self) -> bool {
110        self.len() == 0
111    }
112
113    /// Decode a sequence of indices back to a string.
114    pub fn decode(&self, indices: &[i64]) -> String {
115        indices.iter()
116            .map(|&i| {
117                self.idx_to_char.get(i as usize).copied().unwrap_or('?')
118            })
119            .collect()
120    }
121}
122
123impl BatchDataSet for Shakespeare {
124    fn len(&self) -> usize {
125        self.data.shape()[0] as usize
126    }
127
128    fn get_batch(&self, indices: &[usize]) -> Result<Vec<Tensor>> {
129        let idx: Vec<i64> = indices.iter().map(|&i| (i % self.len()) as i64).collect();
130        let idx_tensor = Tensor::from_i64(&idx, &[idx.len() as i64], Device::CPU)?;
131        let data = self.data.index_select(0, &idx_tensor)?;
132        let targets = self.targets.index_select(0, &idx_tensor)?;
133        Ok(vec![data, targets])
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Window count and vocabulary size for a short all-distinct text.
    #[test]
    fn parse_simple_text() {
        let text = "abcdefghijklmnop"; // 16 distinct chars
        let data = Shakespeare::parse(text, 4).unwrap();

        // 15 usable chars (need +1 for target), 15/4 = 3 sequences
        assert_eq!(data.data.shape(), &[3, 4]);
        assert_eq!(data.targets.shape(), &[3, 4]);
        // Every character in the text is unique, so the vocabulary has all 16.
        assert_eq!(data.vocab_size, 16);
    }

    /// The first target must be the second input character.
    #[test]
    fn target_is_shifted_by_one() {
        let text = "abcdefghij"; // 10 chars
        let data = Shakespeare::parse(text, 3).unwrap();

        // Sequence 0: input="abc", target="bcd"
        let input_row = data.data.select(0, 0).unwrap();
        let target_row = data.targets.select(0, 0).unwrap();

        let target_0 = target_row.select(0, 0).unwrap().to_i64_vec().unwrap()[0];
        let input_1 = input_row.select(0, 1).unwrap().to_i64_vec().unwrap()[0];

        // target[0] should equal input[1], and both should decode to 'b'
        assert_eq!(target_0, input_1);
        assert_eq!(data.decode(&[target_0]), "b");
    }

    /// The vocabulary is strictly increasing regardless of input order.
    #[test]
    fn vocab_is_sorted() {
        let text = "zyxwvutsrqponmlkjihgfedcba";
        let data = Shakespeare::parse(text, 3).unwrap();

        // idx_to_char should be sorted with no duplicates
        for pair in data.idx_to_char.windows(2) {
            assert!(pair[1] > pair[0]);
        }
    }

    /// Decoding the first encoded row reproduces the start of the text.
    #[test]
    fn decode_roundtrip() {
        let text = "hello world";
        let data = Shakespeare::parse(text, 4).unwrap();

        // Encode then decode first sequence
        let first_row = data.data.select(0, 0).unwrap();
        let seq: Vec<i64> = (0..4)
            .map(|j| first_row.select(0, j).unwrap().to_i64_vec().unwrap()[0])
            .collect();
        assert_eq!(data.decode(&seq), "hell");
    }

    /// Text shorter than one window plus one target character is rejected.
    #[test]
    fn text_too_short() {
        assert!(Shakespeare::parse("ab", 5).is_err());
    }
}