//! Language modeling components for `burn_dragon`.
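//!
//! A minimal construction sketch, mirroring the test at the bottom of this
//! file. The import paths are assumptions about the crate layout, and the
//! fence is `ignore` because the example needs a corpus file on disk:
//!
//! ```ignore
//! use burn_dragon_language::dataset::LocalTextDataset;
//! use burn_dragon_language::tokenizer::{TokenizerConfig, TokenizerKind};
//!
//! let cfg = TokenizerConfig {
//!     vocab_path: None,
//!     kind: TokenizerKind::Byte(Default::default()),
//! };
//! // One document per non-empty line of `corpus.txt`; tokenizer artifacts
//! // are cached under `cache/`.
//! let dataset = LocalTextDataset::new("cache", "corpus.txt", 128, 8, 0.9, &cfg)
//!     .expect("dataset");
//! assert!(dataset.token_count() > dataset.block_size());
//! ```
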
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

use burn::tensor::backend::Backend;

use super::DatasetSplit;
use super::scheduler::{SequenceBatch, TokenSequenceDataset};
use crate::tokenizer::{SharedTokenizer, TokenizerConfig};

/// Token-level language-modeling dataset backed by a newline-delimited local
/// text file. Each non-empty line is treated as one document; documents are
/// tokenized and concatenated into a single token stream.
#[derive(Clone)]
pub struct LocalTextDataset {
    /// All documents' tokens, concatenated in file order (EOS included).
    tokens: Vec<u32>,
    /// Number of leading tokens reserved for the training split.
    train_len: usize,
    block_size: usize,
    batch_size: usize,
    train_split_ratio: f32,
    tokenizer: SharedTokenizer,
    source_path: PathBuf,
    document_count: usize,
    /// `Some(n)` when every document shares the same logical length `n`
    /// (token count excluding the trailing EOS); `None` otherwise.
    preferred_logical_document_tokens: Option<usize>,
}

impl LocalTextDataset {
    /// Builds the dataset from a newline-delimited text file, loading or
    /// fitting the tokenizer and materializing the full token stream in
    /// memory.
    pub fn new(
        cache_dir: impl AsRef<Path>,
        source_path: impl AsRef<Path>,
        block_size: usize,
        batch_size: usize,
        train_split_ratio: f32,
        tokenizer_cfg: &TokenizerConfig,
    ) -> io::Result<Self> {
        let cache_dir = cache_dir.as_ref();
        fs::create_dir_all(cache_dir)?;

        let source_path = source_path.as_ref().to_path_buf();
        let raw = fs::read_to_string(&source_path)?;
        let documents = raw
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty())
            .map(ToOwned::to_owned)
            .collect::<Vec<_>>();
        if documents.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "local text dataset {} did not contain any non-empty lines",
                    source_path.display()
                ),
            ));
        }

        // Reuse a tokenizer cached on disk when the config names a storage
        // path; otherwise fit a fresh one on the documents (persisting it
        // when a path is available).
        let tokenizer_path = tokenizer_cfg.storage_path(cache_dir);
        let tokenizer = if let Some(path) = tokenizer_path {
            if path.is_file() {
                tokenizer_cfg.load(&path).map_err(io::Error::other)?
            } else {
                let tokenizer = tokenizer_cfg
                    .fit(documents.iter().map(String::as_str))
                    .map_err(io::Error::other)?;
                tokenizer_cfg
                    .save(&*tokenizer, &path)
                    .map_err(io::Error::other)?;
                tokenizer
            }
        } else {
            tokenizer_cfg
                .fit(documents.iter().map(String::as_str))
                .map_err(io::Error::other)?
        };

        // Encode every document and concatenate the streams, recording each
        // document's token span and logical length for the split bookkeeping
        // below.
        let mut tokens = Vec::new();
        let mut logical_lengths = Vec::with_capacity(documents.len());
        let mut document_spans = Vec::with_capacity(documents.len());
        for document in &documents {
            tokenizer_cfg
                .validate_corpus(&*tokenizer, document.as_str())
                .map_err(io::Error::other)?;
            // Encode with a trailing EOS so document boundaries survive
            // concatenation into the shared stream.
            let encoded = tokenizer.encode(document, false, true);
            if encoded.len() <= 1 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!(
                        "local text document from {} encoded to {} tokens; expected at least 2 with EOS",
                        source_path.display(),
                        encoded.len()
                    ),
                ));
            }
            // The logical length excludes the trailing EOS; the span keeps it.
            logical_lengths.push(encoded.len() - 1);
            document_spans.push(encoded.len());
            tokens.extend(encoded);
        }

        // Sampling one block needs `block_size` inputs plus a shifted target,
        // so the stream must be strictly longer than `block_size + 1` tokens.
        if tokens.len() <= block_size + 1 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "encoded local text dataset {} has {} tokens; more than {} are required for block size {}",
                    source_path.display(),
                    tokens.len(),
                    block_size + 1,
                    block_size
                ),
            ));
        }

        // Assign whole documents to the training split according to the
        // ratio, then grow the split document by document until it holds at
        // least one full training block.
        let split_ratio = train_split_ratio.clamp(0.0, 1.0);
        let mut train_document_count = ((documents.len() as f32) * split_ratio) as usize;
        train_document_count = train_document_count.clamp(1, documents.len());
        let minimum_train_tokens = block_size + 1;
        let mut train_len = document_spans
            .iter()
            .take(train_document_count)
            .copied()
            .sum::<usize>();
        while train_len < minimum_train_tokens && train_document_count < documents.len() {
            train_len += document_spans[train_document_count];
            train_document_count += 1;
        }
        if train_len < minimum_train_tokens {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "local text dataset {} cannot satisfy block size {} with {} documents",
                    source_path.display(),
                    block_size,
                    documents.len()
                ),
            ));
        }

        // Record the shared logical document length when every document
        // agrees on one; heterogeneous corpora leave this as `None`.
        let preferred_logical_document_tokens = logical_lengths
            .first()
            .copied()
            .filter(|length| logical_lengths.iter().all(|candidate| candidate == length));

        Ok(Self {
            tokens,
            train_len,
            block_size,
            batch_size,
            train_split_ratio: split_ratio,
            tokenizer,
            source_path,
            document_count: documents.len(),
            preferred_logical_document_tokens,
        })
    }

    pub fn source_path(&self) -> &Path {
        &self.source_path
    }

    pub fn document_count(&self) -> usize {
        self.document_count
    }

    pub fn tokenizer(&self) -> SharedTokenizer {
        self.tokenizer.clone()
    }

    pub fn train_split_ratio(&self) -> f32 {
        self.train_split_ratio
    }

    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    pub fn block_size(&self) -> usize {
        self.block_size
    }

    pub fn token_count(&self) -> usize {
        self.tokens.len()
    }

    /// Copies `dst.len()` tokens starting at `start` into `dst`.
    ///
    /// Panics if `start + dst.len()` exceeds the token stream.
    pub fn copy_token_range(&self, start: usize, dst: &mut [u32]) {
        let end = start + dst.len();
        dst.copy_from_slice(&self.tokens[start..end]);
    }

    pub fn train_len(&self) -> usize {
        self.train_len
    }

    pub fn steps_per_epoch(&self, split: DatasetSplit) -> usize {
        TokenSequenceDataset::steps_per_epoch(self, split)
    }

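    /// Samples one `SequenceBatch` from the requested split by delegating to
    /// the scheduler. A hedged sketch (`NdArray`, its default device, and the
    /// `Train` variant name are assumptions about the enabled `burn` backend
    /// and `DatasetSplit`):
    ///
    /// ```ignore
    /// use burn::backend::NdArray;
    ///
    /// let batch = dataset.sample_batch::<NdArray>(DatasetSplit::Train, &Default::default());
    /// ```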
    pub fn sample_batch<B: Backend>(
        &self,
        split: DatasetSplit,
        device: &B::Device,
    ) -> SequenceBatch<B> {
        super::scheduler::sample_batch(self, split, device)
    }

    pub fn decode(&self, tokens: &[i64]) -> String {
        TokenSequenceDataset::decode(self, tokens)
    }
}

// The trait impl mirrors the inherent accessors so scheduler code can stay
// generic over `TokenSequenceDataset`.
impl TokenSequenceDataset for LocalTextDataset {
    fn tokenizer(&self) -> SharedTokenizer {
        self.tokenizer.clone()
    }

    fn token_count(&self) -> usize {
        self.tokens.len()
    }

    fn copy_token_range(&self, start: usize, dst: &mut [u32]) {
        let end = start + dst.len();
        dst.copy_from_slice(&self.tokens[start..end]);
    }

    fn train_len(&self) -> usize {
        self.train_len
    }

    fn block_size(&self) -> usize {
        self.block_size
    }

    fn batch_size(&self) -> usize {
        self.batch_size
    }

    fn train_split_ratio(&self) -> f32 {
        self.train_split_ratio
    }

    fn preferred_logical_document_tokens(&self, _split: DatasetSplit) -> Option<usize> {
        self.preferred_logical_document_tokens
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::{TokenizerConfig, TokenizerKind};
    use tempfile::tempdir;

    #[test]
    fn local_text_dataset_uses_non_empty_lines_as_documents() {
        let dir = tempdir().expect("tempdir");
        let source = dir.path().join("docs.txt");
        fs::write(&source, "alpha|a=1\n\nbeta_|a=2\n gamma|a=3 \n").expect("write source");

        let dataset = LocalTextDataset::new(
            dir.path().join("cache"),
            &source,
            8,
            2,
            1.0,
            &TokenizerConfig {
                vocab_path: None,
                kind: TokenizerKind::Byte(Default::default()),
            },
        )
        .expect("dataset");

        assert_eq!(dataset.document_count(), 3);
        assert_eq!(dataset.preferred_logical_document_tokens, Some(9));
        assert!(dataset.token_count() > dataset.block_size());
    }
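
    // Sketch of the error path for a source with no usable documents; it
    // reuses the byte tokenizer config from the test above.
    #[test]
    fn local_text_dataset_rejects_sources_without_documents() {
        let dir = tempdir().expect("tempdir");
        let source = dir.path().join("empty.txt");
        fs::write(&source, "\n   \n").expect("write source");

        let err = LocalTextDataset::new(
            dir.path().join("cache"),
            &source,
            8,
            2,
            1.0,
            &TokenizerConfig {
                vocab_path: None,
                kind: TokenizerKind::Byte(Default::default()),
            },
        )
        .err()
        .expect("whitespace-only source must be rejected");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }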
}