burn_dragon_language 0.4.0

Language modeling components for burn_dragon
Documentation
use anyhow::{Context, Result};

use crate::config::{
    DatasetConfig, DatasetSourceConfig, HuggingFaceDatasetConfig, HuggingFaceRecordFormat,
    TrainingHyperparameters,
};

use super::{Dataset, HuggingFaceDataset, ShakespeareDataset};

pub fn build_dataset(
    cfg: &DatasetConfig,
    training: &TrainingHyperparameters,
) -> Result<(Dataset, String)> {
    let dataset = match &cfg.source {
        DatasetSourceConfig::Shakespeare { url } => Dataset::from_shakespeare(
            ShakespeareDataset::new_with_source(
                &cfg.cache_dir,
                training.block_size,
                training.batch_size,
                cfg.train_split_ratio,
                &cfg.tokenizer,
                url.as_deref(),
            )
            .with_context(|| "failed to prepare Shakespeare dataset")?,
        ),
        DatasetSourceConfig::HuggingFace(hf_cfg) => Dataset::from_huggingface(
            HuggingFaceDataset::new(
                &cfg.cache_dir,
                training.block_size,
                training.batch_size,
                cfg.train_split_ratio,
                &cfg.tokenizer,
                hf_cfg,
            )
            .with_context(|| {
                format!("failed to prepare Hugging Face dataset {}", hf_cfg.repo_id)
            })?,
        ),
        DatasetSourceConfig::DeepMath {
            revision,
            max_records,
        } => {
            let config = deepmath_config(revision, *max_records);
            Dataset::from_huggingface(
                HuggingFaceDataset::new(
                    &cfg.cache_dir,
                    training.block_size,
                    training.batch_size,
                    cfg.train_split_ratio,
                    &cfg.tokenizer,
                    &config,
                )
                .with_context(|| "failed to prepare DeepMath-103K dataset")?,
            )
        }
        DatasetSourceConfig::TinyChat {
            revision,
            max_records,
        } => {
            let config = tinychat_config(revision, *max_records);
            Dataset::from_huggingface(
                HuggingFaceDataset::new(
                    &cfg.cache_dir,
                    training.block_size,
                    training.batch_size,
                    cfg.train_split_ratio,
                    &cfg.tokenizer,
                    &config,
                )
                .with_context(|| "failed to prepare TinyChat dataset")?,
            )
        }
        DatasetSourceConfig::WebscaleRl {
            revision,
            max_records,
        } => {
            let config = webscale_rl_config(revision, *max_records);
            Dataset::from_huggingface(
                HuggingFaceDataset::new(
                    &cfg.cache_dir,
                    training.block_size,
                    training.batch_size,
                    cfg.train_split_ratio,
                    &cfg.tokenizer,
                    &config,
                )
                .with_context(|| "failed to prepare Webscale-RL dataset")?,
            )
        }
        DatasetSourceConfig::PoetryFoundation {
            revision,
            max_records,
        } => {
            let config = poetry_foundation_config(revision, *max_records);
            Dataset::from_huggingface(
                HuggingFaceDataset::new(
                    &cfg.cache_dir,
                    training.block_size,
                    training.batch_size,
                    cfg.train_split_ratio,
                    &cfg.tokenizer,
                    &config,
                )
                .with_context(|| "failed to prepare Poetry Foundation Poems dataset")?,
            )
        }
    };

    let description = match &dataset {
        Dataset::Shakespeare(ds) => format!(
            "Prepared Shakespeare dataset with batch_size={}, block_size={}, split_ratio={}",
            ds.batch_size(),
            ds.block_size(),
            ds.train_split_ratio()
        ),
        Dataset::HuggingFace(ds) => format!(
            "Prepared Hugging Face dataset {} (rev: {}) with batch_size={}, block_size={}, split_ratio={}",
            ds.repo_id(),
            ds.revision().unwrap_or("main"),
            ds.batch_size(),
            ds.block_size(),
            ds.train_split_ratio()
        ),
    };

    Ok((dataset, description))
}

/// Builds the preset Hugging Face configuration for the `zwhe99/DeepMath-103K` dataset.
///
/// The training files are ten Parquet shards named `data/train-00000-of-00010.parquet`
/// through `data/train-00009-of-00010.parquet`. No validation files and no access token
/// are set. The text fields are `question` and `final_answer`, paired with a
/// "Question/Answer" template.
///
/// `revision` and `max_records` are copied into the returned config unchanged.
fn deepmath_config(
    revision: &Option<String>,
    max_records: Option<usize>,
) -> HuggingFaceDatasetConfig {
    // The shard count appears both as the loop bound and inside each file name.
    const SHARD_COUNT: usize = 10;

    let mut train_files = Vec::with_capacity(SHARD_COUNT);
    for shard in 0..SHARD_COUNT {
        train_files.push(format!(
            "data/train-{shard:05}-of-{SHARD_COUNT:05}.parquet"
        ));
    }

    let text_fields = vec![String::from("question"), String::from("final_answer")];
    let template = String::from("Question:\n{question}\n\nAnswer:\n{final_answer}");

    HuggingFaceDatasetConfig {
        repo_id: String::from("zwhe99/DeepMath-103K"),
        token: None,
        revision: revision.clone(),
        format: HuggingFaceRecordFormat::Parquet,
        train_files,
        validation_files: Vec::new(),
        text_fields,
        field_separator: String::from("\n\n"),
        template: Some(template),
        max_records,
    }
}

/// Builds the preset Hugging Face configuration for the `starhopp3r/TinyChat` dataset.
///
/// The training data is the single plain-text file `tinychat.txt`. No validation files,
/// access token or template are set, and the only text field is `text`.
///
/// `revision` and `max_records` are copied into the returned config unchanged.
fn tinychat_config(
    revision: &Option<String>,
    max_records: Option<usize>,
) -> HuggingFaceDatasetConfig {
    let repo_id = String::from("starhopp3r/TinyChat");
    let train_files = vec![String::from("tinychat.txt")];
    let text_fields = vec![String::from("text")];
    let field_separator = String::from("\n\n");

    HuggingFaceDatasetConfig {
        repo_id,
        token: None,
        revision: revision.clone(),
        format: HuggingFaceRecordFormat::Text,
        train_files,
        validation_files: Vec::new(),
        text_fields,
        field_separator,
        template: None,
        max_records,
    }
}

/// Builds the preset Hugging Face configuration for the `Salesforce/Webscale-RL` dataset.
///
/// The training files are twelve Parquet parts named `data/part-0.parquet` through
/// `data/part-11.parquet`. No validation files and no access token are set. The text
/// fields are `pretrain_text`, `question`, `answer`, `domain` and `persona`, paired with
/// a template that lays them out as context, domain/persona, question and answer.
///
/// `revision` and `max_records` are copied into the returned config unchanged.
fn webscale_rl_config(
    revision: &Option<String>,
    max_records: Option<usize>,
) -> HuggingFaceDatasetConfig {
    let train_files: Vec<String> = (0..12)
        .map(|part| format!("data/part-{part}.parquet"))
        .collect();

    let text_fields: Vec<String> = ["pretrain_text", "question", "answer", "domain", "persona"]
        .iter()
        .map(|field| String::from(*field))
        .collect();

    // Split across literals only to keep lines short; the pieces join with nothing between.
    let template = concat!(
        "Context:\n{pretrain_text}\n\n",
        "Domain: {domain}\nPersona: {persona}\n\n",
        "Question: {question}\nAnswer: {answer}",
    );

    HuggingFaceDatasetConfig {
        repo_id: String::from("Salesforce/Webscale-RL"),
        token: None,
        revision: revision.clone(),
        format: HuggingFaceRecordFormat::Parquet,
        train_files,
        validation_files: Vec::new(),
        text_fields,
        field_separator: String::from("\n\n"),
        template: Some(String::from(template)),
        max_records,
    }
}

/// Builds the preset Hugging Face configuration for the
/// `suayptalha/Poetry-Foundation-Poems` dataset.
///
/// The training data is the single CSV file `PoetryFoundationData.csv`. No validation
/// files and no access token are set. The text fields are the `Title` and `Poem` columns,
/// paired with a template that puts the title first, then the poem, then a run of
/// trailing newlines.
///
/// `revision` and `max_records` are copied into the returned config unchanged.
fn poetry_foundation_config(
    revision: &Option<String>,
    max_records: Option<usize>,
) -> HuggingFaceDatasetConfig {
    let repo_id = String::from("suayptalha/Poetry-Foundation-Poems");
    let train_files = vec![String::from("PoetryFoundationData.csv")];
    let text_fields = vec![String::from("Title"), String::from("Poem")];
    let field_separator = String::from("\n\n\n");
    let template = String::from("{Title}\n\n\n{Poem}\n\n\n\n\n\n");

    HuggingFaceDatasetConfig {
        repo_id,
        token: None,
        revision: revision.clone(),
        format: HuggingFaceRecordFormat::Csv,
        train_files,
        validation_files: Vec::new(),
        text_fields,
        field_separator,
        template: Some(template),
        max_records,
    }
}