//! Language modeling components for `burn_dragon`.
//!
//! Provides tokenizer backends (`char`, `byte`, `pretokenized`, `rust_bpe`)
//! behind the [`Tokenizer`] trait, plus the serde-driven [`TokenizerConfig`]
//! used to build, train, load, and save them.
pub mod byte;
pub mod char_vocab;
pub mod pretokenized;
pub mod rust_bpe;

use std::path::{Path, PathBuf};
use std::sync::Arc;

use anyhow::{Result, anyhow};
use byte::ByteTokenizer;
use char_vocab::CharVocab;
use pretokenized::PretokenizedTokenizer;
use rust_bpe::RustBpeTokenizer;
use serde::{Deserialize, Serialize};

/// Common interface implemented by every tokenizer backend.
pub trait Tokenizer: Send + Sync {
    /// Encodes `text` into token ids, optionally prepending BOS and/or
    /// appending EOS.
    fn encode(&self, text: &str, add_bos: bool, add_eos: bool) -> Vec<u32>;
    /// Decodes `ids` back into text.
    fn decode(&self, ids: &[u32]) -> String;
    /// Decodes `ids`, optionally stopping at the first EOS token.
    ///
    /// When `stop_at_eos` is true and [`Tokenizer::eos_id`] is defined,
    /// everything from the first EOS onward (inclusive) is dropped before
    /// decoding. Otherwise all ids are decoded.
    fn decode_with_options(&self, ids: &[u32], stop_at_eos: bool) -> String {
        // BUGFIX: both branches previously called `self.decode(ids)`,
        // making `stop_at_eos` a no-op. Now we truncate at the first EOS.
        if stop_at_eos {
            if let Some(eos) = self.eos_id() {
                let end = ids.iter().position(|&id| id == eos).unwrap_or(ids.len());
                return self.decode(&ids[..end]);
            }
        }
        self.decode(ids)
    }
    /// Number of entries in the vocabulary.
    fn len(&self) -> usize;
    /// True when the vocabulary has no entries.
    fn is_empty(&self) -> bool;
    /// Beginning-of-sequence token id, if the tokenizer defines one.
    fn bos_id(&self) -> Option<u32>;
    /// End-of-sequence token id, if the tokenizer defines one.
    fn eos_id(&self) -> Option<u32>;
    /// Padding token id, if the tokenizer defines one.
    fn pad_id(&self) -> Option<u32>;
    /// Unknown-token id, if the tokenizer defines one.
    fn unk_id(&self) -> Option<u32>;
    /// Escape hatch for downcasting to a concrete tokenizer type
    /// (used by `TokenizerConfig::save` / `validate_corpus`).
    fn as_any(&self) -> &dyn std::any::Any;
}

pub type SharedTokenizer = Arc<dyn Tokenizer>;

/// Top-level tokenizer configuration.
///
/// The `kind` variant is flattened into the same serialized object and
/// selected by its `type` tag (see [`TokenizerKind`]).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct TokenizerConfig {
    /// Optional override for where tokenizer state lives on disk.
    /// Relative paths are resolved against the cache dir by
    /// `storage_path`; ignored by the stateless kinds (`byte`,
    /// `pretokenized`).
    #[serde(default)]
    pub vocab_path: Option<PathBuf>,
    /// Which tokenizer backend to use, with its kind-specific options.
    #[serde(flatten)]
    pub kind: TokenizerKind,
}

impl Default for TokenizerConfig {
    fn default() -> Self {
        Self {
            vocab_path: None,
            kind: TokenizerKind::Char(CharTokenizerConfig::default()),
        }
    }
}

/// Selects a tokenizer backend.
///
/// Serialized with a `type` tag in snake_case, so valid values are
/// `"char"`, `"byte"`, `"pretokenized"`, and `"rust_bpe"` (these match
/// `TokenizerConfig::kind_name`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TokenizerKind {
    /// Character-level vocabulary fitted from the corpus.
    Char(CharTokenizerConfig),
    /// Byte-level tokenizer; needs no training and persists no state.
    Byte(ByteTokenizerConfig),
    /// Data tokenized elsewhere; only vocab size and special ids are needed.
    Pretokenized(PretokenizedTokenizerConfig),
    /// Byte-pair-encoding tokenizer (see the `rust_bpe` module).
    RustBpe(RustBpeTokenizerConfig),
}

/// Options for the character-level tokenizer.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct CharTokenizerConfig {
    /// Whether to reserve an unknown token (default: `true`). When
    /// `false`, every dataset character must be present in the fitted
    /// vocabulary (enforced by `TokenizerConfig::validate_corpus`).
    #[serde(default = "default_true")]
    pub include_unknown: bool,
}

impl Default for CharTokenizerConfig {
    fn default() -> Self {
        Self {
            include_unknown: true,
        }
    }
}

/// Options for the byte-level tokenizer.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ByteTokenizerConfig {
    /// Forwarded to `ByteTokenizer::new` (default: `true`); presumably
    /// controls BOS/EOS-style special tokens — confirm in the `byte` module.
    #[serde(default = "default_true")]
    pub add_special_tokens: bool,
}

impl Default for ByteTokenizerConfig {
    fn default() -> Self {
        Self {
            add_special_tokens: true,
        }
    }
}

/// Options for data that is already tokenized; all fields are forwarded
/// verbatim to `PretokenizedTokenizer::new`. No training or persistence.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct PretokenizedTokenizerConfig {
    /// Total vocabulary size of the pre-tokenized data (required).
    pub vocab_size: usize,
    /// Optional beginning-of-sequence token id.
    #[serde(default)]
    pub bos_id: Option<u32>,
    /// Optional end-of-sequence token id.
    #[serde(default)]
    pub eos_id: Option<u32>,
    /// Optional padding token id.
    #[serde(default)]
    pub pad_id: Option<u32>,
    /// Optional unknown-token id.
    #[serde(default)]
    pub unk_id: Option<u32>,
}

/// Options for the BPE tokenizer implemented in the `rust_bpe` module.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct RustBpeTokenizerConfig {
    /// Size of the mergeable (trained) portion of the vocabulary (required).
    pub mergeable_vocab_size: usize,
    /// Optional pre-tokenization split pattern; presumably a regex —
    /// confirm against `RustBpeTokenizer::new_untrained`.
    #[serde(default)]
    pub pattern: Option<String>,
    /// Optional beginning-of-sequence token id.
    #[serde(default)]
    pub bos_id: Option<u32>,
    /// Optional end-of-sequence token id.
    #[serde(default)]
    pub eos_id: Option<u32>,
    /// Optional padding token id.
    #[serde(default)]
    pub pad_id: Option<u32>,
    /// Optional unknown-token id.
    #[serde(default)]
    pub unk_id: Option<u32>,
}

impl TokenizerConfig {
    /// Resolves the on-disk location of this tokenizer's persisted state.
    ///
    /// Returns `None` for stateless kinds (`byte`, `pretokenized`).
    /// An absolute `vocab_path` is used verbatim; a relative one is joined
    /// onto `cache_dir`; when unset, a kind-specific default file name
    /// inside `cache_dir` is used.
    pub fn storage_path(&self, cache_dir: &Path) -> Option<PathBuf> {
        // Kind-specific default file name; stateless kinds have nothing
        // to store. (Previously the path-resolution logic below was
        // duplicated per stateful kind.)
        let default_name = match &self.kind {
            TokenizerKind::Char(_) => "vocab.json",
            TokenizerKind::RustBpe(_) => "tokenizer.rustbpe.json",
            TokenizerKind::Byte(_) | TokenizerKind::Pretokenized(_) => return None,
        };
        Some(match &self.vocab_path {
            Some(path) if path.is_absolute() => path.clone(),
            Some(path) => cache_dir.join(path),
            None => cache_dir.join(default_name),
        })
    }

    /// Instantiates a tokenizer, reading persisted state from `path` for
    /// stateful kinds (`char`, `rust_bpe`); stateless kinds ignore `path`.
    ///
    /// # Errors
    /// Propagates failures from `CharVocab::load` / `RustBpeTokenizer::load`.
    pub fn load(&self, path: &Path) -> Result<SharedTokenizer> {
        match &self.kind {
            TokenizerKind::Char(_) => {
                let vocab = CharVocab::load(path)?;
                Ok(Arc::new(vocab) as SharedTokenizer)
            }
            TokenizerKind::Byte(config) => {
                Ok(Arc::new(ByteTokenizer::new(config.add_special_tokens)) as SharedTokenizer)
            }
            TokenizerKind::Pretokenized(config) => Ok(Arc::new(PretokenizedTokenizer::new(
                config.vocab_size,
                config.bos_id,
                config.eos_id,
                config.pad_id,
                config.unk_id,
            )) as SharedTokenizer),
            TokenizerKind::RustBpe(config) => Ok(Arc::new(RustBpeTokenizer::load(
                path,
                config.mergeable_vocab_size,
                config.bos_id,
                config.eos_id,
                config.pad_id,
                config.unk_id,
            )?) as SharedTokenizer),
        }
    }

    /// Builds a tokenizer from scratch, training on `texts` where the kind
    /// requires it (`char` collects characters, `rust_bpe` learns merges);
    /// stateless kinds ignore `texts`.
    ///
    /// # Errors
    /// Propagates failures from vocabulary fitting or BPE training.
    pub fn fit<'a, I>(&self, texts: I) -> Result<SharedTokenizer>
    where
        I: Iterator<Item = &'a str>,
    {
        match &self.kind {
            TokenizerKind::Char(config) => {
                let vocab = CharVocab::fit(texts, config.include_unknown)?;
                Ok(Arc::new(vocab) as SharedTokenizer)
            }
            TokenizerKind::Byte(config) => {
                Ok(Arc::new(ByteTokenizer::new(config.add_special_tokens)) as SharedTokenizer)
            }
            TokenizerKind::Pretokenized(config) => Ok(Arc::new(PretokenizedTokenizer::new(
                config.vocab_size,
                config.bos_id,
                config.eos_id,
                config.pad_id,
                config.unk_id,
            )) as SharedTokenizer),
            TokenizerKind::RustBpe(config) => {
                let mut tokenizer = RustBpeTokenizer::new_untrained(
                    config.mergeable_vocab_size,
                    config.pattern.as_deref(),
                    config.bos_id,
                    config.eos_id,
                    config.pad_id,
                    config.unk_id,
                )?;
                tokenizer.train_from_texts(texts)?;
                Ok(Arc::new(tokenizer) as SharedTokenizer)
            }
        }
    }

    /// Persists `tokenizer` state to `path` for stateful kinds; a no-op
    /// for `byte` / `pretokenized`.
    ///
    /// # Errors
    /// Fails when `tokenizer`'s concrete type does not match this config's
    /// kind, or when the underlying save fails.
    pub fn save(&self, tokenizer: &dyn Tokenizer, path: &Path) -> Result<()> {
        match &self.kind {
            TokenizerKind::Char(_) => {
                let vocab = tokenizer
                    .as_any()
                    .downcast_ref::<CharVocab>()
                    .ok_or_else(|| anyhow!("expected char tokenizer"))?;
                vocab.save(path)
            }
            TokenizerKind::RustBpe(_) => {
                let tokenizer = tokenizer
                    .as_any()
                    .downcast_ref::<RustBpeTokenizer>()
                    .ok_or_else(|| anyhow!("expected rust_bpe tokenizer"))?;
                tokenizer.save(path)
            }
            // Stateless kinds: nothing to write.
            TokenizerKind::Byte(_) | TokenizerKind::Pretokenized(_) => Ok(()),
        }
    }

    /// True only for a char tokenizer configured without an unknown token,
    /// in which case every dataset character must be in the vocabulary.
    pub fn requires_strict_coverage(&self) -> bool {
        matches!(&self.kind, TokenizerKind::Char(config) if !config.include_unknown)
    }

    /// Checks that `text` is fully representable by `tokenizer`.
    ///
    /// Only the strict char case (`include_unknown == false`) can fail;
    /// every other configuration accepts any text.
    ///
    /// # Errors
    /// Returns an error naming the first character missing from a strict
    /// char vocabulary, or when `tokenizer` is not actually a `CharVocab`.
    pub fn validate_corpus(&self, tokenizer: &dyn Tokenizer, text: &str) -> Result<()> {
        match &self.kind {
            TokenizerKind::Char(config) if !config.include_unknown => {
                let vocab = tokenizer
                    .as_any()
                    .downcast_ref::<CharVocab>()
                    .ok_or_else(|| anyhow!("expected char tokenizer"))?;
                for ch in text.chars() {
                    if !vocab.contains(ch) {
                        return Err(anyhow!(
                            "vocabulary missing character {ch:?} found in dataset"
                        ));
                    }
                }
                Ok(())
            }
            // Char with an unk token, byte, pretokenized, rust_bpe: always
            // valid. (An explicit Byte|Pretokenized arm was redundant with
            // this catch-all and has been folded in.)
            _ => Ok(()),
        }
    }

    /// Stable lowercase identifier for the configured kind; matches the
    /// serde `type` tag of [`TokenizerKind`].
    pub fn kind_name(&self) -> &'static str {
        match &self.kind {
            TokenizerKind::Char(_) => "char",
            TokenizerKind::Byte(_) => "byte",
            TokenizerKind::Pretokenized(_) => "pretokenized",
            TokenizerKind::RustBpe(_) => "rust_bpe",
        }
    }
}

/// Serde default helper for boolean fields that default to enabled.
fn default_true() -> bool {
    true
}