burn_dragon_language 0.4.0

Language modeling components for burn_dragon
Documentation
pub mod byte;
pub mod char_vocab;

use std::sync::Arc;

#[cfg(feature = "train")]
use std::path::{Path, PathBuf};

#[cfg(feature = "train")]
use anyhow::{Result, anyhow};
#[cfg(feature = "train")]
use byte::ByteTokenizer;
#[cfg(feature = "train")]
use char_vocab::CharVocab;
#[cfg(feature = "train")]
use serde::{Deserialize, Serialize};

pub trait Tokenizer: Send + Sync {
    fn encode(&self, text: &str, add_bos: bool, add_eos: bool) -> Vec<u32>;
    fn decode(&self, ids: &[u32]) -> String;
    fn len(&self) -> usize;
    fn is_empty(&self) -> bool;
    fn bos_id(&self) -> Option<u32>;
    fn eos_id(&self) -> Option<u32>;
    fn pad_id(&self) -> Option<u32>;
    fn unk_id(&self) -> Option<u32>;
    fn as_any(&self) -> &dyn std::any::Any;
}

pub type SharedTokenizer = Arc<dyn Tokenizer>;

#[cfg(feature = "train")]
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct TokenizerConfig {
    #[serde(default)]
    pub vocab_path: Option<PathBuf>,
    #[serde(flatten)]
    pub kind: TokenizerKind,
}

#[cfg(feature = "train")]
impl Default for TokenizerConfig {
    fn default() -> Self {
        Self {
            vocab_path: None,
            kind: TokenizerKind::Char(CharTokenizerConfig::default()),
        }
    }
}

#[cfg(feature = "train")]
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TokenizerKind {
    Char(CharTokenizerConfig),
    Byte(ByteTokenizerConfig),
}

#[cfg(feature = "train")]
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct CharTokenizerConfig {
    #[serde(default = "default_true")]
    pub include_unknown: bool,
}

#[cfg(feature = "train")]
impl Default for CharTokenizerConfig {
    fn default() -> Self {
        Self {
            include_unknown: true,
        }
    }
}

#[cfg(feature = "train")]
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ByteTokenizerConfig {
    #[serde(default = "default_true")]
    pub add_special_tokens: bool,
}

#[cfg(feature = "train")]
impl Default for ByteTokenizerConfig {
    fn default() -> Self {
        Self {
            add_special_tokens: true,
        }
    }
}

#[cfg(feature = "train")]
impl TokenizerConfig {
    pub fn storage_path(&self, cache_dir: &Path) -> Option<PathBuf> {
        match &self.kind {
            TokenizerKind::Char(_) => Some(match &self.vocab_path {
                Some(path) if path.is_absolute() => path.clone(),
                Some(path) => cache_dir.join(path),
                None => cache_dir.join("vocab.json"),
            }),
            TokenizerKind::Byte(_) => None,
        }
    }

    pub fn load(&self, path: &Path) -> Result<SharedTokenizer> {
        match &self.kind {
            TokenizerKind::Char(_) => {
                let vocab = CharVocab::load(path)?;
                Ok(Arc::new(vocab) as SharedTokenizer)
            }
            TokenizerKind::Byte(config) => {
                Ok(Arc::new(ByteTokenizer::new(config.add_special_tokens)) as SharedTokenizer)
            }
        }
    }

    pub fn fit<'a, I>(&self, texts: I) -> Result<SharedTokenizer>
    where
        I: Iterator<Item = &'a str>,
    {
        match &self.kind {
            TokenizerKind::Char(config) => {
                let vocab = CharVocab::fit(texts, config.include_unknown)?;
                Ok(Arc::new(vocab) as SharedTokenizer)
            }
            TokenizerKind::Byte(config) => {
                Ok(Arc::new(ByteTokenizer::new(config.add_special_tokens)) as SharedTokenizer)
            }
        }
    }

    pub fn save(&self, tokenizer: &dyn Tokenizer, path: &Path) -> Result<()> {
        match &self.kind {
            TokenizerKind::Char(_) => {
                let vocab = tokenizer
                    .as_any()
                    .downcast_ref::<CharVocab>()
                    .ok_or_else(|| anyhow!("expected char tokenizer"))?;
                vocab.save(path)
            }
            TokenizerKind::Byte(_) => Ok(()),
        }
    }

    pub fn requires_strict_coverage(&self) -> bool {
        matches!(&self.kind, TokenizerKind::Char(config) if !config.include_unknown)
    }

    pub fn validate_corpus(&self, tokenizer: &dyn Tokenizer, text: &str) -> Result<()> {
        match &self.kind {
            TokenizerKind::Char(config) if !config.include_unknown => {
                let vocab = tokenizer
                    .as_any()
                    .downcast_ref::<CharVocab>()
                    .ok_or_else(|| anyhow!("expected char tokenizer"))?;
                for ch in text.chars() {
                    if !vocab.contains(ch) {
                        return Err(anyhow!(
                            "vocabulary missing character {ch:?} found in dataset"
                        ));
                    }
                }
                Ok(())
            }
            TokenizerKind::Byte(_) => Ok(()),
            _ => Ok(()),
        }
    }

    pub fn kind_name(&self) -> &'static str {
        match &self.kind {
            TokenizerKind::Char(_) => "char",
            TokenizerKind::Byte(_) => "byte",
        }
    }
}

#[cfg(feature = "train")]
fn default_true() -> bool {
    true
}