dreamwell_intelligence/
dataset.rs1use crate::tokenizer::{BpeTokenizer, CharTokenizer, MiTokenizer};
7
/// Built-in sample corpus: a short excerpt of Shakespeare dialogue.
///
/// Every line of the excerpt, including the final one, is terminated by `\n`.
pub const SHAKESPEARE: &str = concat!(
    "First Citizen:\n",
    "Before we proceed any further, hear me speak.\n",
    "\n",
    "All:\n",
    "Speak, speak.\n",
    "\n",
    "First Citizen:\n",
    "You are all resolved rather to die than to famish?\n",
    "\n",
    "All:\n",
    "Resolved. resolved.\n",
    "\n",
    "First Citizen:\n",
    "First, you know Caius Marcius is chief enemy to the people.\n",
    "\n",
    "All:\n",
    "We know't, we know't.\n",
    "\n",
    "First Citizen:\n",
    "Let us kill him, and we'll have corn at our own price.\n",
    "Is't a verdict?\n",
    "\n",
    "All:\n",
    "No more talking on't; let it be done: away, away!\n",
    "\n",
    "Second Citizen:\n",
    "One word, good citizens.\n",
    "\n",
    "First Citizen:\n",
    "We are accounted poor citizens, the patricians good.\n",
    "What authority surfeits on would relieve us: if they\n",
    "would yield us but the superfluity, while it were\n",
    "wholesome, we might guess they relieved us humanely;\n",
    "but they think we are too dear: the leanness that\n",
    "afflicts us, the object of our misery, is as an\n",
    "inventory to particularise their abundance; our\n",
    "sufferance is a gain to them Let us revenge this with\n",
    "our pikes, ere we become rakes: for the gods know I\n",
    "speak this in hunger for bread, not in thirst for revenge.\n",
);
48
49pub struct TokenDataset {
51 pub tokens: Vec<usize>,
52 pub tokenizer: CharTokenizer,
53}
54
55impl TokenDataset {
56 pub fn from_text(text: &str) -> Self {
58 let tokenizer = CharTokenizer::from_text(text);
59 let tokens = tokenizer.encode(text);
60 Self { tokens, tokenizer }
61 }
62
63 pub fn shakespeare() -> Self {
65 Self::from_text(SHAKESPEARE)
66 }
67
68 pub fn vocab_size(&self) -> usize {
70 self.tokenizer.vocab_size
71 }
72
73 pub fn len(&self) -> usize {
75 self.tokens.len()
76 }
77
78 pub fn sample_window(&self, length: usize, seed: u64) -> &[usize] {
81 let max_start = self.tokens.len().saturating_sub(length);
82 if max_start == 0 {
83 return &self.tokens;
84 }
85 let start = (seed as usize) % max_start;
86 let end = (start + length).min(self.tokens.len());
87 &self.tokens[start..end]
88 }
89
90 pub fn num_windows(&self, length: usize) -> usize {
92 self.tokens.len().saturating_sub(length)
93 }
94
95 pub fn decode(&self, tokens: &[usize]) -> String {
97 self.tokenizer.decode(tokens)
98 }
99
100 pub fn from_jsonl(path: &std::path::Path) -> std::io::Result<Self> {
103 let content = std::fs::read_to_string(path)?;
104 let mut text = String::new();
105 for line in content.lines() {
106 let trimmed = line.trim();
107 if trimmed.is_empty() {
108 continue;
109 }
110 let unquoted = if trimmed.starts_with('"') && trimmed.ends_with('"') {
112 &trimmed[1..trimmed.len() - 1]
113 } else {
114 trimmed
115 };
116 let unescaped = unquoted.replace("\\n", "\n").replace("\\t", "\t");
118 if !text.is_empty() {
119 text.push('\n');
120 }
121 text.push_str(&unescaped);
122 }
123 Ok(Self::from_text(&text))
124 }
125
126 pub fn from_file(path: &std::path::Path) -> std::io::Result<Self> {
128 let text = std::fs::read_to_string(path)?;
129 Ok(Self::from_text(&text))
130 }
131
132 pub fn from_dir(path: &std::path::Path) -> std::io::Result<Self> {
134 let mut text = String::new();
135 let mut entries: Vec<_> = std::fs::read_dir(path)?
136 .filter_map(|e| e.ok())
137 .filter(|e| e.path().extension().map(|x| x == "txt").unwrap_or(false))
138 .collect();
139 entries.sort_by_key(|e| e.file_name());
140 for entry in entries {
141 let content = std::fs::read_to_string(entry.path())?;
142 if !text.is_empty() {
143 text.push('\n');
144 }
145 text.push_str(&content);
146 }
147 if text.is_empty() {
148 return Err(std::io::Error::new(
149 std::io::ErrorKind::NotFound,
150 "no .txt files in directory",
151 ));
152 }
153 Ok(Self::from_text(&text))
154 }
155
156 pub fn from_path(path: &std::path::Path) -> std::io::Result<Self> {
158 if path.is_dir() {
159 Self::from_dir(path)
160 } else if path.extension().map(|x| x == "jsonl").unwrap_or(false) {
161 Self::from_jsonl(path)
162 } else {
163 Self::from_file(path)
164 }
165 }
166}
167
168pub struct BpeDataset {
171 pub tokens: Vec<usize>,
172 pub tokenizer: BpeTokenizer,
173}
174
175impl BpeDataset {
176 pub fn from_text(text: &str, target_vocab: usize) -> Self {
179 let tokenizer = BpeTokenizer::train(text, target_vocab);
180 let tokens = tokenizer.encode(text);
181 Self { tokens, tokenizer }
182 }
183
184 pub fn from_jsonl(path: &std::path::Path, target_vocab: usize) -> std::io::Result<Self> {
186 let content = std::fs::read_to_string(path)?;
187 let mut text = String::new();
188 for line in content.lines() {
189 let trimmed = line.trim();
190 if trimmed.is_empty() {
191 continue;
192 }
193 let unquoted = if trimmed.starts_with('"') && trimmed.ends_with('"') {
194 &trimmed[1..trimmed.len() - 1]
195 } else {
196 trimmed
197 };
198 let unescaped = unquoted.replace("\\n", "\n").replace("\\t", "\t");
199 if !text.is_empty() {
200 text.push('\n');
201 }
202 text.push_str(&unescaped);
203 }
204 Ok(Self::from_text(&text, target_vocab))
205 }
206
207 pub fn from_file(path: &std::path::Path, target_vocab: usize) -> std::io::Result<Self> {
209 let text = std::fs::read_to_string(path)?;
210 Ok(Self::from_text(&text, target_vocab))
211 }
212
213 pub fn vocab_size(&self) -> usize {
214 self.tokenizer.vocab_size
215 }
216 pub fn len(&self) -> usize {
217 self.tokens.len()
218 }
219
220 pub fn decode(&self, tokens: &[usize]) -> String {
221 self.tokenizer.decode(tokens)
222 }
223}
224
225pub struct MiDataset {
228 pub tokens: Vec<usize>,
229 pub tokenizer: MiTokenizer,
230}
231
232impl MiDataset {
233 pub fn from_text(text: &str, target_vocab: usize) -> Self {
234 let tokenizer = MiTokenizer::train(text, target_vocab);
235 let tokens = tokenizer.encode(text);
236 Self { tokens, tokenizer }
237 }
238
239 pub fn from_jsonl(path: &std::path::Path, target_vocab: usize) -> std::io::Result<Self> {
240 let content = std::fs::read_to_string(path)?;
241 let mut text = String::new();
242 for line in content.lines() {
243 let trimmed = line.trim();
244 if trimmed.is_empty() {
245 continue;
246 }
247 let unquoted = if trimmed.starts_with('"') && trimmed.ends_with('"') {
248 &trimmed[1..trimmed.len() - 1]
249 } else {
250 trimmed
251 };
252 let unescaped = unquoted.replace("\\n", "\n").replace("\\t", "\t");
253 if !text.is_empty() {
254 text.push('\n');
255 }
256 text.push_str(&unescaped);
257 }
258 Ok(Self::from_text(&text, target_vocab))
259 }
260
261 pub fn from_file(path: &std::path::Path, target_vocab: usize) -> std::io::Result<Self> {
262 let text = std::fs::read_to_string(path)?;
263 Ok(Self::from_text(&text, target_vocab))
264 }
265
266 pub fn from_dir(path: &std::path::Path, target_vocab: usize) -> std::io::Result<Self> {
268 let mut text = String::new();
269 let mut entries: Vec<_> = std::fs::read_dir(path)?
270 .filter_map(|e| e.ok())
271 .filter(|e| e.path().extension().map(|x| x == "txt").unwrap_or(false))
272 .collect();
273 entries.sort_by_key(|e| e.file_name());
274 for entry in entries {
275 let content = std::fs::read_to_string(entry.path())?;
276 if !text.is_empty() {
277 text.push('\n');
278 }
279 text.push_str(&content);
280 }
281 if text.is_empty() {
282 return Err(std::io::Error::new(
283 std::io::ErrorKind::NotFound,
284 "no .txt files in directory",
285 ));
286 }
287 Ok(Self::from_text(&text, target_vocab))
288 }
289
290 pub fn from_path(path: &std::path::Path, target_vocab: usize) -> std::io::Result<Self> {
292 if path.is_dir() {
293 Self::from_dir(path, target_vocab)
294 } else if path.extension().map(|x| x == "jsonl").unwrap_or(false) {
295 Self::from_jsonl(path, target_vocab)
296 } else {
297 Self::from_file(path, target_vocab)
298 }
299 }
300
301 pub fn vocab_size(&self) -> usize {
302 self.tokenizer.vocab_size
303 }
304 pub fn len(&self) -> usize {
305 self.tokens.len()
306 }
307 pub fn decode(&self, tokens: &[usize]) -> String {
308 self.tokenizer.decode(tokens)
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// The embedded corpus tokenizes into more than 900 tokens over more than 30 symbols.
    #[test]
    fn shakespeare_loads() {
        let ds = TokenDataset::shakespeare();
        assert!(ds.len() > 900, "Shakespeare should have >900 tokens, got {}", ds.len());
        assert!(ds.vocab_size() > 30, "vocab should be >30, got {}", ds.vocab_size());
    }

    /// The same seed always selects the same window, and the window has the requested size.
    #[test]
    fn window_sampling_deterministic() {
        let ds = TokenDataset::shakespeare();
        let first = ds.sample_window(32, 42);
        let second = ds.sample_window(32, 42);
        assert_eq!(first, second, "same seed should produce same window");
        assert_eq!(first.len(), 32, "window should have the requested length");
    }

    /// Decoding a sampled window reproduces the source text it was cut from.
    #[test]
    fn roundtrip_decode() {
        let ds = TokenDataset::shakespeare();
        // Seed 0 always maps to start index 0, so the window covers the first 20 tokens.
        let window = ds.sample_window(20, 0);
        let text = ds.decode(window);
        assert_eq!(text.len(), 20, "decoded text should match window length");
        assert_eq!(
            text,
            &SHAKESPEARE[..20],
            "decoded window should reproduce the source text"
        );
    }
}