//! `burn_dragon_language` 0.5.0
//!
//! Language modeling components for `burn_dragon`: dataset sources
//! (Shakespeare, local text, Hugging Face, universality) behind a common
//! [`TokenSequenceDataset`] interface, plus batching/scheduling utilities.
mod factory;
mod huggingface;
mod local_text;
pub mod scheduler;
mod shakespeare;
mod universality;

use crate::tokenizer::SharedTokenizer;

pub use factory::build_dataset;
pub use huggingface::HuggingFaceDataset;
pub use local_text::LocalTextDataset;
pub use scheduler::{
    RandomDataLoader, SequenceBatch, StreamingDataLoader, TokenSequenceDataset,
    sample_batch_with_shape,
};
pub use shakespeare::ShakespeareDataset;
pub use universality::UniversalityDataset;

/// Which portion of a dataset to draw samples from.
///
/// The boundary between the two splits is determined by the dataset's
/// `train_split_ratio()` (see [`TokenSequenceDataset`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DatasetSplit {
    /// The training portion of the token stream.
    Train,
    /// The held-out validation portion of the token stream.
    Val,
}

/// Closed set of dataset backends, unified behind one concrete type.
///
/// Each variant wraps a source that implements [`TokenSequenceDataset`];
/// the enum itself also implements the trait by delegating to the wrapped
/// value, so callers can hold a `Dataset` without generics or trait objects.
#[derive(Clone)]
pub enum Dataset {
    /// Bundled Tiny Shakespeare corpus.
    Shakespeare(ShakespeareDataset),
    /// Tokenized text loaded from a local file.
    LocalText(LocalTextDataset),
    /// Corpus fetched from the Hugging Face hub.
    HuggingFace(HuggingFaceDataset),
    /// Synthetic corpus for universality experiments.
    Universality(UniversalityDataset),
}

impl Dataset {
    pub fn from_shakespeare(dataset: ShakespeareDataset) -> Self {
        Self::Shakespeare(dataset)
    }

    pub fn from_huggingface(dataset: HuggingFaceDataset) -> Self {
        Self::HuggingFace(dataset)
    }

    pub fn from_local_text(dataset: LocalTextDataset) -> Self {
        Self::LocalText(dataset)
    }

    pub fn from_universality(dataset: UniversalityDataset) -> Self {
        Self::Universality(dataset)
    }

    pub fn tokenizer(&self) -> SharedTokenizer {
        TokenSequenceDataset::tokenizer(self)
    }

    pub fn train_split_ratio(&self) -> f32 {
        TokenSequenceDataset::train_split_ratio(self)
    }

    pub fn batch_size(&self) -> usize {
        TokenSequenceDataset::batch_size(self)
    }

    pub fn steps_per_epoch(&self, split: DatasetSplit) -> usize {
        TokenSequenceDataset::steps_per_epoch(self, split)
    }
}

impl TokenSequenceDataset for Dataset {
    fn tokenizer(&self) -> SharedTokenizer {
        match self {
            Dataset::Shakespeare(dataset) => dataset.tokenizer(),
            Dataset::LocalText(dataset) => dataset.tokenizer(),
            Dataset::HuggingFace(dataset) => dataset.tokenizer(),
            Dataset::Universality(dataset) => dataset.tokenizer(),
        }
    }

    fn token_count(&self) -> usize {
        match self {
            Dataset::Shakespeare(dataset) => dataset.token_count(),
            Dataset::LocalText(dataset) => dataset.token_count(),
            Dataset::HuggingFace(dataset) => dataset.token_count(),
            Dataset::Universality(dataset) => dataset.token_count(),
        }
    }

    fn copy_token_range(&self, start: usize, dst: &mut [u32]) {
        match self {
            Dataset::Shakespeare(dataset) => dataset.copy_token_range(start, dst),
            Dataset::LocalText(dataset) => dataset.copy_token_range(start, dst),
            Dataset::HuggingFace(dataset) => dataset.copy_token_range(start, dst),
            Dataset::Universality(dataset) => dataset.copy_token_range(start, dst),
        }
    }

    fn train_len(&self) -> usize {
        match self {
            Dataset::Shakespeare(dataset) => dataset.train_len(),
            Dataset::LocalText(dataset) => dataset.train_len(),
            Dataset::HuggingFace(dataset) => dataset.train_len(),
            Dataset::Universality(dataset) => dataset.train_len(),
        }
    }

    fn block_size(&self) -> usize {
        match self {
            Dataset::Shakespeare(dataset) => dataset.block_size(),
            Dataset::LocalText(dataset) => dataset.block_size(),
            Dataset::HuggingFace(dataset) => dataset.block_size(),
            Dataset::Universality(dataset) => dataset.block_size(),
        }
    }

    fn batch_size(&self) -> usize {
        match self {
            Dataset::Shakespeare(dataset) => dataset.batch_size(),
            Dataset::LocalText(dataset) => dataset.batch_size(),
            Dataset::HuggingFace(dataset) => dataset.batch_size(),
            Dataset::Universality(dataset) => dataset.batch_size(),
        }
    }

    fn train_split_ratio(&self) -> f32 {
        match self {
            Dataset::Shakespeare(dataset) => dataset.train_split_ratio(),
            Dataset::LocalText(dataset) => dataset.train_split_ratio(),
            Dataset::HuggingFace(dataset) => dataset.train_split_ratio(),
            Dataset::Universality(dataset) => dataset.train_split_ratio(),
        }
    }

    fn preferred_logical_document_tokens(&self, split: DatasetSplit) -> Option<usize> {
        match self {
            Dataset::Shakespeare(dataset) => dataset.preferred_logical_document_tokens(split),
            Dataset::LocalText(dataset) => dataset.preferred_logical_document_tokens(split),
            Dataset::HuggingFace(dataset) => dataset.preferred_logical_document_tokens(split),
            Dataset::Universality(dataset) => dataset.preferred_logical_document_tokens(split),
        }
    }

    fn split_offset_and_span(&self, split: DatasetSplit) -> (usize, usize) {
        match self {
            Dataset::Shakespeare(dataset) => {
                TokenSequenceDataset::split_offset_and_span(dataset, split)
            }
            Dataset::LocalText(dataset) => {
                TokenSequenceDataset::split_offset_and_span(dataset, split)
            }
            Dataset::HuggingFace(dataset) => {
                TokenSequenceDataset::split_offset_and_span(dataset, split)
            }
            Dataset::Universality(dataset) => {
                TokenSequenceDataset::split_offset_and_span(dataset, split)
            }
        }
    }

    fn steps_per_epoch(&self, split: DatasetSplit) -> usize {
        match self {
            Dataset::Shakespeare(dataset) => TokenSequenceDataset::steps_per_epoch(dataset, split),
            Dataset::LocalText(dataset) => TokenSequenceDataset::steps_per_epoch(dataset, split),
            Dataset::HuggingFace(dataset) => TokenSequenceDataset::steps_per_epoch(dataset, split),
            Dataset::Universality(dataset) => TokenSequenceDataset::steps_per_epoch(dataset, split),
        }
    }

    fn decode(&self, tokens: &[i64]) -> String {
        match self {
            Dataset::Shakespeare(dataset) => TokenSequenceDataset::decode(dataset, tokens),
            Dataset::LocalText(dataset) => TokenSequenceDataset::decode(dataset, tokens),
            Dataset::HuggingFace(dataset) => TokenSequenceDataset::decode(dataset, tokens),
            Dataset::Universality(dataset) => TokenSequenceDataset::decode(dataset, tokens),
        }
    }
}

// Backward-compatibility aliases from when Shakespeare was the only dataset;
// new code should use the generic names directly.

/// Deprecated-in-spirit alias for [`DatasetSplit`].
pub type ShakespeareSplit = DatasetSplit;
/// Deprecated-in-spirit alias for [`SequenceBatch`].
pub type ShakespeareBatch<B> = SequenceBatch<B>;
/// Deprecated-in-spirit alias for [`RandomDataLoader`].
pub type ShakespeareRandomDataLoader<B> = RandomDataLoader<B>;