burn_dragon_language 0.4.0

Language modeling components for burn_dragon
Documentation
use super::Tokenizer;

/// Number of ids reserved for raw bytes: one per possible `u8` value (0..=255).
const BYTE_VOCAB_SIZE: usize = u8::MAX as usize + 1;

/// Byte-level tokenizer: every UTF-8 byte of the input maps to its own id.
///
/// Optional special tokens (BOS, EOS, PAD) are placed immediately after the
/// byte range when requested at construction time.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ByteTokenizer {
    /// Whether BOS/EOS/PAD ids were allocated by [`ByteTokenizer::new`].
    add_special_tokens: bool,
    /// Beginning-of-sequence id, if special tokens are enabled.
    bos: Option<u32>,
    /// End-of-sequence id, if special tokens are enabled.
    eos: Option<u32>,
    /// Padding id, if special tokens are enabled.
    pad: Option<u32>,
    /// Total number of ids: the byte range plus any special tokens.
    vocab_size: usize,
}

impl ByteTokenizer {
    /// Creates a byte-level tokenizer.
    ///
    /// When `add_special_tokens` is `true`, three ids are allocated directly
    /// after the byte range: BOS = 256, EOS = 257, PAD = 258, for a total
    /// vocabulary of 259. When `false`, no special tokens exist and the
    /// vocabulary is exactly the 256 byte ids.
    pub fn new(add_special_tokens: bool) -> Self {
        if !add_special_tokens {
            return Self {
                add_special_tokens,
                bos: None,
                eos: None,
                pad: None,
                vocab_size: BYTE_VOCAB_SIZE,
            };
        }

        // Special ids are handed out consecutively, starting right after the last byte id.
        let first_special = BYTE_VOCAB_SIZE as u32;
        Self {
            add_special_tokens,
            bos: Some(first_special),
            eos: Some(first_special + 1),
            pad: Some(first_special + 2),
            vocab_size: BYTE_VOCAB_SIZE + 3,
        }
    }
}

impl Tokenizer for ByteTokenizer {
    /// Encodes `text` as its raw UTF-8 bytes, one id per byte.
    ///
    /// `add_bos` / `add_eos` prepend / append the BOS / EOS id. They are
    /// silently ignored when the tokenizer was built without special tokens,
    /// because no such ids exist in that configuration.
    fn encode(&self, text: &str, add_bos: bool, add_eos: bool) -> Vec<u32> {
        // One slot per byte plus room for the optional BOS/EOS markers.
        let mut tokens = Vec::with_capacity(text.len() + 2);
        if add_bos && let Some(bos) = self.bos {
            tokens.push(bos);
        }

        tokens.extend(text.bytes().map(u32::from));

        if add_eos && let Some(eos) = self.eos {
            tokens.push(eos);
        }

        tokens
    }

    /// Decodes ids back into a string.
    ///
    /// - BOS and PAD ids are skipped.
    /// - Decoding stops at the first EOS id; anything after it is ignored.
    /// - Any other id outside the byte range (>= 256) is silently dropped.
    /// - Byte sequences that are not valid UTF-8 are replaced with U+FFFD.
    fn decode(&self, ids: &[u32]) -> String {
        let mut bytes = Vec::with_capacity(ids.len());
        for &id in ids {
            if Some(id) == self.pad || Some(id) == self.bos {
                continue;
            }
            if Some(id) == self.eos {
                break;
            }
            // `try_from` succeeds exactly for ids in the byte range 0..=255.
            if let Ok(byte) = u8::try_from(id) {
                bytes.push(byte);
            }
        }
        // Take ownership of the buffer without copying when it is valid UTF-8
        // (the common case). Fall back to lossy replacement only when it is not.
        String::from_utf8(bytes)
            .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
    }

    /// Total vocabulary size: 256 byte ids plus any special tokens.
    fn len(&self) -> usize {
        self.vocab_size
    }

    /// Always `false` for tokenizers built by `new`, which reserves at least 256 ids.
    fn is_empty(&self) -> bool {
        self.vocab_size == 0
    }

    /// Beginning-of-sequence id, or `None` without special tokens.
    fn bos_id(&self) -> Option<u32> {
        self.bos
    }

    /// End-of-sequence id, or `None` without special tokens.
    fn eos_id(&self) -> Option<u32> {
        self.eos
    }

    /// Padding id, or `None` without special tokens.
    fn pad_id(&self) -> Option<u32> {
        self.pad
    }

    /// Always `None`: every byte has its own id, so there is no unknown token.
    fn unk_id(&self) -> Option<u32> {
        None
    }

    /// Returns `self` as `&dyn Any`, allowing callers to downcast to `ByteTokenizer`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// BOS/EOS wrap the payload on encode and are stripped again on decode.
    #[test]
    fn encode_decode_round_trip() {
        let tokenizer = ByteTokenizer::new(true);
        let encoded = tokenizer.encode("hello", true, true);
        assert_eq!(encoded.first().copied(), tokenizer.bos_id());
        assert_eq!(encoded.last().copied(), tokenizer.eos_id());
        let decoded = tokenizer.decode(&encoded);
        assert_eq!(decoded, "hello");
    }

    /// Everything after the first EOS is ignored.
    #[test]
    fn decode_truncates_at_eos() {
        let tokenizer = ByteTokenizer::new(true);
        let eos = tokenizer
            .eos_id()
            .expect("special tokens are enabled, so EOS must exist");
        let mut ids = tokenizer.encode("abc", false, false);
        ids.push(eos);
        ids.extend([b'd' as u32, b'e' as u32]);
        assert_eq!(tokenizer.decode(&ids), "abc");
    }

    /// Special ids sit directly after the 256 byte ids.
    #[test]
    fn special_token_ids_follow_byte_range() {
        let tokenizer = ByteTokenizer::new(true);
        assert_eq!(tokenizer.bos_id(), Some(256));
        assert_eq!(tokenizer.eos_id(), Some(257));
        assert_eq!(tokenizer.pad_id(), Some(258));
        assert_eq!(tokenizer.unk_id(), None);
        assert_eq!(tokenizer.len(), 259);
        assert!(!tokenizer.is_empty());
    }

    /// Without special tokens, BOS/EOS flags are no-ops and ids >= 256 are dropped.
    #[test]
    fn without_special_tokens_only_bytes_are_used() {
        let tokenizer = ByteTokenizer::new(false);
        assert_eq!(tokenizer.len(), 256);
        assert_eq!(tokenizer.bos_id(), None);
        assert_eq!(tokenizer.eos_id(), None);
        assert_eq!(tokenizer.pad_id(), None);
        assert_eq!(
            tokenizer.encode("hi", true, true),
            vec![b'h' as u32, b'i' as u32]
        );
        assert_eq!(
            tokenizer.decode(&[b'h' as u32, 256, 257, b'i' as u32]),
            "hi"
        );
    }

    /// BOS and PAD are skipped anywhere in the stream; unknown ids are dropped.
    #[test]
    fn decode_skips_bos_pad_and_out_of_range_ids() {
        let tokenizer = ByteTokenizer::new(true);
        let bos = tokenizer.bos_id().expect("BOS exists");
        let pad = tokenizer.pad_id().expect("PAD exists");
        let ids = [bos, b'a' as u32, pad, 1_000, b'b' as u32, pad];
        assert_eq!(tokenizer.decode(&ids), "ab");
    }

    /// Invalid UTF-8 byte sequences decode to the replacement character.
    #[test]
    fn decode_replaces_invalid_utf8() {
        let tokenizer = ByteTokenizer::new(true);
        assert_eq!(
            tokenizer.decode(&[b'a' as u32, 0xFF, b'b' as u32]),
            "a\u{FFFD}b"
        );
    }

    /// Multibyte characters survive a round trip, one id per UTF-8 byte.
    #[test]
    fn round_trips_multibyte_utf8() {
        let tokenizer = ByteTokenizer::new(true);
        let text = "héllo, 世界 🐉";
        let ids = tokenizer.encode(text, true, true);
        assert_eq!(ids.len(), text.len() + 2);
        assert_eq!(tokenizer.decode(&ids), text);
    }
}