use crate::tokenizer::{BpeTokenizer, CharTokenizer, MiTokenizer};
/// Small built-in sample corpus: a short stretch of Shakespearean dialogue
/// between "First Citizen", "Second Citizen" and "All".
///
/// Used by `TokenDataset::shakespeare` as its source text. Every source line
/// of the literal ends with `\n\`, so the resulting string contains explicit
/// newlines only; the trailing backslash suppresses the source line break.
pub const SHAKESPEARE: &str = "First Citizen:\n\
Before we proceed any further, hear me speak.\n\
\n\
All:\n\
Speak, speak.\n\
\n\
First Citizen:\n\
You are all resolved rather to die than to famish?\n\
\n\
All:\n\
Resolved. resolved.\n\
\n\
First Citizen:\n\
First, you know Caius Marcius is chief enemy to the people.\n\
\n\
All:\n\
We know't, we know't.\n\
\n\
First Citizen:\n\
Let us kill him, and we'll have corn at our own price.\n\
Is't a verdict?\n\
\n\
All:\n\
No more talking on't; let it be done: away, away!\n\
\n\
Second Citizen:\n\
One word, good citizens.\n\
\n\
First Citizen:\n\
We are accounted poor citizens, the patricians good.\n\
What authority surfeits on would relieve us: if they\n\
would yield us but the superfluity, while it were\n\
wholesome, we might guess they relieved us humanely;\n\
but they think we are too dear: the leanness that\n\
afflicts us, the object of our misery, is as an\n\
inventory to particularise their abundance; our\n\
sufferance is a gain to them Let us revenge this with\n\
our pikes, ere we become rakes: for the gods know I\n\
speak this in hunger for bread, not in thirst for revenge.\n";
/// A token sequence paired with the [`CharTokenizer`] that produced it.
///
/// The constructors in this file build the tokenizer from the input text and
/// then encode that same text, so `tokens` and `tokenizer` belong together.
pub struct TokenDataset {
/// Token ids obtained by encoding the source text with `tokenizer`.
pub tokens: Vec<usize>,
/// Tokenizer built from the source text; also used by `decode`.
pub tokenizer: CharTokenizer,
}
impl TokenDataset {
    /// Builds a dataset by fitting a [`CharTokenizer`] on `text` and then
    /// encoding that same text with it.
    pub fn from_text(text: &str) -> Self {
        let tokenizer = CharTokenizer::from_text(text);
        let tokens = tokenizer.encode(text);
        Self { tokens, tokenizer }
    }

    /// Builds a dataset from the embedded [`SHAKESPEARE`] sample.
    pub fn shakespeare() -> Self {
        Self::from_text(SHAKESPEARE)
    }

    /// Returns the tokenizer's `vocab_size`.
    pub fn vocab_size(&self) -> usize {
        self.tokenizer.vocab_size
    }

    /// Returns the number of tokens in the dataset.
    pub fn len(&self) -> usize {
        self.tokens.len()
    }

    /// Returns `true` when the dataset holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }

    /// Returns a window of tokens whose position is derived from `seed`.
    ///
    /// The start index is `seed % (len - length)`, so equal seeds always yield
    /// the same window and the start is always strictly below `len - length`.
    /// When the dataset holds `length` tokens or fewer, the whole token slice
    /// is returned instead (it may then be shorter than `length`).
    pub fn sample_window(&self, length: usize, seed: u64) -> &[usize] {
        let max_start = self.tokens.len().saturating_sub(length);
        if max_start == 0 {
            return &self.tokens;
        }
        // Reduce in u64 before narrowing: casting the seed to usize first
        // would drop its high bits on targets where usize is under 64 bits.
        let start = (seed % (max_start as u64)) as usize;
        // start < max_start == len - length, so the window always fits.
        &self.tokens[start..start + length]
    }

    /// Returns the number of start positions [`Self::sample_window`] chooses
    /// from: `len - length`, saturating at zero.
    pub fn num_windows(&self, length: usize) -> usize {
        self.tokens.len().saturating_sub(length)
    }

    /// Decodes `tokens` back into text using this dataset's tokenizer.
    pub fn decode(&self, tokens: &[usize]) -> String {
        self.tokenizer.decode(tokens)
    }

    /// Builds a dataset from a line-oriented file of (optionally quoted) strings.
    ///
    /// Each non-blank line is trimmed, stripped of one pair of surrounding
    /// double quotes if present, and has the two-character sequences `\n` and
    /// `\t` replaced by a real newline / tab. Lines are joined with `'\n'`.
    /// Other escape sequences are left untouched.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if `path` cannot be read as UTF-8 text.
    pub fn from_jsonl(path: &std::path::Path) -> std::io::Result<Self> {
        let content = std::fs::read_to_string(path)?;
        Ok(Self::from_text(&Self::jsonl_to_text(&content)))
    }

    /// Builds a dataset from the full contents of a single text file.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if `path` cannot be read as UTF-8 text.
    pub fn from_file(path: &std::path::Path) -> std::io::Result<Self> {
        let text = std::fs::read_to_string(path)?;
        Ok(Self::from_text(&text))
    }

    /// Builds a dataset from every `.txt` entry directly inside `path`.
    ///
    /// Files are concatenated in file-name order, separated by `'\n'`.
    /// Directory entries that cannot be inspected are skipped.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the directory or one of the selected files
    /// cannot be read, and `ErrorKind::NotFound` if the combined text is empty.
    pub fn from_dir(path: &std::path::Path) -> std::io::Result<Self> {
        let text = Self::read_txt_dir(path)?;
        Ok(Self::from_text(&text))
    }

    /// Picks a loader from the shape of `path`: directories go to
    /// [`Self::from_dir`], files with a `jsonl` extension to
    /// [`Self::from_jsonl`], and everything else to [`Self::from_file`].
    ///
    /// # Errors
    ///
    /// Propagates the error of whichever loader is selected.
    pub fn from_path(path: &std::path::Path) -> std::io::Result<Self> {
        if path.is_dir() {
            Self::from_dir(path)
        } else if path.extension().map_or(false, |ext| ext == "jsonl") {
            Self::from_jsonl(path)
        } else {
            Self::from_file(path)
        }
    }

    /// Converts the raw contents of a JSONL-style file into plain training text.
    fn jsonl_to_text(content: &str) -> String {
        let mut text = String::new();
        for line in content.lines().map(str::trim).filter(|l| !l.is_empty()) {
            // Strip the quotes only when both an opening and a distinct closing
            // quote exist; a line that is a lone `"` is kept verbatim instead
            // of producing an invalid slice range.
            let body = line
                .strip_prefix('"')
                .and_then(|rest| rest.strip_suffix('"'))
                .unwrap_or(line);
            let unescaped = body.replace("\\n", "\n").replace("\\t", "\t");
            if !text.is_empty() {
                text.push('\n');
            }
            text.push_str(&unescaped);
        }
        text
    }

    /// Reads and joins all `.txt` files in `path`, failing if nothing was read.
    fn read_txt_dir(path: &std::path::Path) -> std::io::Result<String> {
        let mut entries: Vec<_> = std::fs::read_dir(path)?
            .filter_map(|entry| entry.ok())
            .filter(|entry| entry.path().extension().map_or(false, |ext| ext == "txt"))
            .collect();
        // Sort by name so the concatenation order does not depend on the
        // order in which the OS lists the directory.
        entries.sort_by_cached_key(|entry| entry.file_name());

        let mut text = String::new();
        for entry in entries {
            let content = std::fs::read_to_string(entry.path())?;
            if !text.is_empty() {
                text.push('\n');
            }
            text.push_str(&content);
        }
        if text.is_empty() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "no .txt files in directory",
            ));
        }
        Ok(text)
    }
}
/// A token sequence paired with the [`BpeTokenizer`] that produced it.
///
/// The constructors in this file train the tokenizer on the input text and
/// then encode that same text with it.
pub struct BpeDataset {
/// Token ids obtained by encoding the source text with `tokenizer`.
pub tokens: Vec<usize>,
/// Tokenizer trained on the source text; also used by `decode`.
pub tokenizer: BpeTokenizer,
}
impl BpeDataset {
    /// Trains a [`BpeTokenizer`] on `text` and encodes that same text with it.
    ///
    /// `target_vocab` is forwarded unchanged to `BpeTokenizer::train`
    /// (presumably the desired vocabulary size — see the tokenizer module).
    pub fn from_text(text: &str, target_vocab: usize) -> Self {
        let tokenizer = BpeTokenizer::train(text, target_vocab);
        let tokens = tokenizer.encode(text);
        Self { tokens, tokenizer }
    }

    /// Builds a dataset from a line-oriented file of (optionally quoted) strings.
    ///
    /// Each non-blank line is trimmed, stripped of one pair of surrounding
    /// double quotes if present, and has the two-character sequences `\n` and
    /// `\t` replaced by a real newline / tab. Lines are joined with `'\n'`.
    /// Other escape sequences are left untouched.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if `path` cannot be read as UTF-8 text.
    pub fn from_jsonl(path: &std::path::Path, target_vocab: usize) -> std::io::Result<Self> {
        let content = std::fs::read_to_string(path)?;
        Ok(Self::from_text(&Self::jsonl_to_text(&content), target_vocab))
    }

    /// Builds a dataset from the full contents of a single text file.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if `path` cannot be read as UTF-8 text.
    pub fn from_file(path: &std::path::Path, target_vocab: usize) -> std::io::Result<Self> {
        let text = std::fs::read_to_string(path)?;
        Ok(Self::from_text(&text, target_vocab))
    }

    /// Builds a dataset from every `.txt` entry directly inside `path`.
    ///
    /// Files are concatenated in file-name order, separated by `'\n'`.
    /// Directory entries that cannot be inspected are skipped.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the directory or one of the selected files
    /// cannot be read, and `ErrorKind::NotFound` if the combined text is empty.
    pub fn from_dir(path: &std::path::Path, target_vocab: usize) -> std::io::Result<Self> {
        let text = Self::read_txt_dir(path)?;
        Ok(Self::from_text(&text, target_vocab))
    }

    /// Picks a loader from the shape of `path`: directories go to
    /// [`Self::from_dir`], files with a `jsonl` extension to
    /// [`Self::from_jsonl`], and everything else to [`Self::from_file`].
    ///
    /// # Errors
    ///
    /// Propagates the error of whichever loader is selected.
    pub fn from_path(path: &std::path::Path, target_vocab: usize) -> std::io::Result<Self> {
        if path.is_dir() {
            Self::from_dir(path, target_vocab)
        } else if path.extension().map_or(false, |ext| ext == "jsonl") {
            Self::from_jsonl(path, target_vocab)
        } else {
            Self::from_file(path, target_vocab)
        }
    }

    /// Returns the tokenizer's `vocab_size`.
    pub fn vocab_size(&self) -> usize {
        self.tokenizer.vocab_size
    }

    /// Returns the number of tokens in the dataset.
    pub fn len(&self) -> usize {
        self.tokens.len()
    }

    /// Returns `true` when the dataset holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }

    /// Decodes `tokens` back into text using this dataset's tokenizer.
    pub fn decode(&self, tokens: &[usize]) -> String {
        self.tokenizer.decode(tokens)
    }

    /// Converts the raw contents of a JSONL-style file into plain training text.
    fn jsonl_to_text(content: &str) -> String {
        let mut text = String::new();
        for line in content.lines().map(str::trim).filter(|l| !l.is_empty()) {
            // Strip the quotes only when both an opening and a distinct closing
            // quote exist; a line that is a lone `"` is kept verbatim instead
            // of producing an invalid slice range.
            let body = line
                .strip_prefix('"')
                .and_then(|rest| rest.strip_suffix('"'))
                .unwrap_or(line);
            let unescaped = body.replace("\\n", "\n").replace("\\t", "\t");
            if !text.is_empty() {
                text.push('\n');
            }
            text.push_str(&unescaped);
        }
        text
    }

    /// Reads and joins all `.txt` files in `path`, failing if nothing was read.
    fn read_txt_dir(path: &std::path::Path) -> std::io::Result<String> {
        let mut entries: Vec<_> = std::fs::read_dir(path)?
            .filter_map(|entry| entry.ok())
            .filter(|entry| entry.path().extension().map_or(false, |ext| ext == "txt"))
            .collect();
        // Sort by name so the concatenation order does not depend on the
        // order in which the OS lists the directory.
        entries.sort_by_cached_key(|entry| entry.file_name());

        let mut text = String::new();
        for entry in entries {
            let content = std::fs::read_to_string(entry.path())?;
            if !text.is_empty() {
                text.push('\n');
            }
            text.push_str(&content);
        }
        if text.is_empty() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "no .txt files in directory",
            ));
        }
        Ok(text)
    }
}
/// A token sequence paired with the [`MiTokenizer`] that produced it.
///
/// The constructors in this file train the tokenizer on the input text and
/// then encode that same text with it.
pub struct MiDataset {
/// Token ids obtained by encoding the source text with `tokenizer`.
pub tokens: Vec<usize>,
/// Tokenizer trained on the source text; also used by `decode`.
pub tokenizer: MiTokenizer,
}
impl MiDataset {
    /// Trains a [`MiTokenizer`] on `text` and encodes that same text with it.
    ///
    /// `target_vocab` is forwarded unchanged to `MiTokenizer::train`
    /// (presumably the desired vocabulary size — see the tokenizer module).
    pub fn from_text(text: &str, target_vocab: usize) -> Self {
        let tokenizer = MiTokenizer::train(text, target_vocab);
        let tokens = tokenizer.encode(text);
        Self { tokens, tokenizer }
    }

    /// Builds a dataset from a line-oriented file of (optionally quoted) strings.
    ///
    /// Each non-blank line is trimmed, stripped of one pair of surrounding
    /// double quotes if present, and has the two-character sequences `\n` and
    /// `\t` replaced by a real newline / tab. Lines are joined with `'\n'`.
    /// Other escape sequences are left untouched.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if `path` cannot be read as UTF-8 text.
    pub fn from_jsonl(path: &std::path::Path, target_vocab: usize) -> std::io::Result<Self> {
        let content = std::fs::read_to_string(path)?;
        Ok(Self::from_text(&Self::jsonl_to_text(&content), target_vocab))
    }

    /// Builds a dataset from the full contents of a single text file.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if `path` cannot be read as UTF-8 text.
    pub fn from_file(path: &std::path::Path, target_vocab: usize) -> std::io::Result<Self> {
        let text = std::fs::read_to_string(path)?;
        Ok(Self::from_text(&text, target_vocab))
    }

    /// Builds a dataset from every `.txt` entry directly inside `path`.
    ///
    /// Files are concatenated in file-name order, separated by `'\n'`.
    /// Directory entries that cannot be inspected are skipped.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the directory or one of the selected files
    /// cannot be read, and `ErrorKind::NotFound` if the combined text is empty.
    pub fn from_dir(path: &std::path::Path, target_vocab: usize) -> std::io::Result<Self> {
        let text = Self::read_txt_dir(path)?;
        Ok(Self::from_text(&text, target_vocab))
    }

    /// Picks a loader from the shape of `path`: directories go to
    /// [`Self::from_dir`], files with a `jsonl` extension to
    /// [`Self::from_jsonl`], and everything else to [`Self::from_file`].
    ///
    /// # Errors
    ///
    /// Propagates the error of whichever loader is selected.
    pub fn from_path(path: &std::path::Path, target_vocab: usize) -> std::io::Result<Self> {
        if path.is_dir() {
            Self::from_dir(path, target_vocab)
        } else if path.extension().map_or(false, |ext| ext == "jsonl") {
            Self::from_jsonl(path, target_vocab)
        } else {
            Self::from_file(path, target_vocab)
        }
    }

    /// Returns the tokenizer's `vocab_size`.
    pub fn vocab_size(&self) -> usize {
        self.tokenizer.vocab_size
    }

    /// Returns the number of tokens in the dataset.
    pub fn len(&self) -> usize {
        self.tokens.len()
    }

    /// Returns `true` when the dataset holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }

    /// Decodes `tokens` back into text using this dataset's tokenizer.
    pub fn decode(&self, tokens: &[usize]) -> String {
        self.tokenizer.decode(tokens)
    }

    /// Converts the raw contents of a JSONL-style file into plain training text.
    fn jsonl_to_text(content: &str) -> String {
        let mut text = String::new();
        for line in content.lines().map(str::trim).filter(|l| !l.is_empty()) {
            // Strip the quotes only when both an opening and a distinct closing
            // quote exist; a line that is a lone `"` is kept verbatim instead
            // of producing an invalid slice range.
            let body = line
                .strip_prefix('"')
                .and_then(|rest| rest.strip_suffix('"'))
                .unwrap_or(line);
            let unescaped = body.replace("\\n", "\n").replace("\\t", "\t");
            if !text.is_empty() {
                text.push('\n');
            }
            text.push_str(&unescaped);
        }
        text
    }

    /// Reads and joins all `.txt` files in `path`, failing if nothing was read.
    fn read_txt_dir(path: &std::path::Path) -> std::io::Result<String> {
        let mut entries: Vec<_> = std::fs::read_dir(path)?
            .filter_map(|entry| entry.ok())
            .filter(|entry| entry.path().extension().map_or(false, |ext| ext == "txt"))
            .collect();
        // Sort by name so the concatenation order does not depend on the
        // order in which the OS lists the directory.
        entries.sort_by_cached_key(|entry| entry.file_name());

        let mut text = String::new();
        for entry in entries {
            let content = std::fs::read_to_string(entry.path())?;
            if !text.is_empty() {
                text.push('\n');
            }
            text.push_str(&content);
        }
        if text.is_empty() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "no .txt files in directory",
            ));
        }
        Ok(text)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The embedded sample must yield a reasonably sized token stream and vocabulary.
    #[test]
    fn shakespeare_loads() {
        let dataset = TokenDataset::shakespeare();
        let token_count = dataset.len();
        let vocab = dataset.vocab_size();
        assert!(token_count > 900, "Shakespeare should have >900 tokens, got {}", token_count);
        assert!(vocab > 30, "vocab should be >30, got {}", vocab);
    }

    /// Sampling twice with the same length and seed must return the same slice.
    #[test]
    fn window_sampling_deterministic() {
        let dataset = TokenDataset::shakespeare();
        let first = dataset.sample_window(32, 42);
        let second = dataset.sample_window(32, 42);
        assert_eq!(first, second, "same seed should produce same window");
    }

    /// Decoding a 20-token window must produce text of length 20.
    #[test]
    fn roundtrip_decode() {
        let dataset = TokenDataset::shakespeare();
        let decoded = dataset.decode(dataset.sample_window(20, 0));
        assert_eq!(decoded.len(), 20, "decoded text should match window length");
    }
}