burn_dragon_language 0.4.0

Language modeling components for burn_dragon
Documentation
mod factory;
mod huggingface;
pub mod scheduler;
mod shakespeare;

use crate::tokenizer::SharedTokenizer;

pub use factory::build_dataset;
pub use huggingface::HuggingFaceDataset;
pub use scheduler::{RandomDataLoader, SequenceBatch, TokenSequenceDataset};
pub use shakespeare::ShakespeareDataset;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DatasetSplit {
    Train,
    Val,
}

#[derive(Clone)]
pub enum Dataset {
    Shakespeare(ShakespeareDataset),
    HuggingFace(HuggingFaceDataset),
}

impl Dataset {
    pub fn from_shakespeare(dataset: ShakespeareDataset) -> Self {
        Self::Shakespeare(dataset)
    }

    pub fn from_huggingface(dataset: HuggingFaceDataset) -> Self {
        Self::HuggingFace(dataset)
    }

    pub fn tokenizer(&self) -> SharedTokenizer {
        TokenSequenceDataset::tokenizer(self)
    }

    pub fn train_split_ratio(&self) -> f32 {
        TokenSequenceDataset::train_split_ratio(self)
    }

    pub fn batch_size(&self) -> usize {
        TokenSequenceDataset::batch_size(self)
    }

    pub fn steps_per_epoch(&self, split: DatasetSplit) -> usize {
        TokenSequenceDataset::steps_per_epoch(self, split)
    }
}

impl TokenSequenceDataset for Dataset {
    fn tokenizer(&self) -> SharedTokenizer {
        match self {
            Dataset::Shakespeare(dataset) => dataset.tokenizer(),
            Dataset::HuggingFace(dataset) => dataset.tokenizer(),
        }
    }

    fn tokens(&self) -> &[u32] {
        match self {
            Dataset::Shakespeare(dataset) => dataset.tokens(),
            Dataset::HuggingFace(dataset) => dataset.tokens(),
        }
    }

    fn train_len(&self) -> usize {
        match self {
            Dataset::Shakespeare(dataset) => dataset.train_len(),
            Dataset::HuggingFace(dataset) => dataset.train_len(),
        }
    }

    fn block_size(&self) -> usize {
        match self {
            Dataset::Shakespeare(dataset) => dataset.block_size(),
            Dataset::HuggingFace(dataset) => dataset.block_size(),
        }
    }

    fn batch_size(&self) -> usize {
        match self {
            Dataset::Shakespeare(dataset) => dataset.batch_size(),
            Dataset::HuggingFace(dataset) => dataset.batch_size(),
        }
    }

    fn train_split_ratio(&self) -> f32 {
        match self {
            Dataset::Shakespeare(dataset) => dataset.train_split_ratio(),
            Dataset::HuggingFace(dataset) => dataset.train_split_ratio(),
        }
    }

    fn split_offset_and_span(&self, split: DatasetSplit) -> (usize, usize) {
        match self {
            Dataset::Shakespeare(dataset) => {
                TokenSequenceDataset::split_offset_and_span(dataset, split)
            }
            Dataset::HuggingFace(dataset) => {
                TokenSequenceDataset::split_offset_and_span(dataset, split)
            }
        }
    }

    fn steps_per_epoch(&self, split: DatasetSplit) -> usize {
        match self {
            Dataset::Shakespeare(dataset) => TokenSequenceDataset::steps_per_epoch(dataset, split),
            Dataset::HuggingFace(dataset) => TokenSequenceDataset::steps_per_epoch(dataset, split),
        }
    }

    fn decode(&self, tokens: &[i64]) -> String {
        match self {
            Dataset::Shakespeare(dataset) => TokenSequenceDataset::decode(dataset, tokens),
            Dataset::HuggingFace(dataset) => TokenSequenceDataset::decode(dataset, tokens),
        }
    }
}

pub type ShakespeareSplit = DatasetSplit;
pub type ShakespeareBatch<B> = SequenceBatch<B>;
pub type ShakespeareRandomDataLoader<B> = RandomDataLoader<B>;