tokenizers 0.22.2

Provides an implementation of today's most used tokenizers, with a focus on performance and versatility.
Documentation
use crate::processors::PostProcessorWrapper;
use crate::tokenizer::{Encoding, PostProcessor, Result};
use crate::utils::macro_rules_attribute;
use serde::{Deserialize, Serialize};

/// A post-processor that chains several post-processors together.
///
/// The wrapped processors are applied one after another, in the order they
/// were given, each one receiving the encodings produced by the previous one.
#[derive(Clone, Debug, PartialEq, Eq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct Sequence {
    // Processors to apply, in execution order. (De)serialization of this
    // struct is generated by `impl_serde_type!`, defined elsewhere in the crate.
    processors: Vec<PostProcessorWrapper>,
}

impl Sequence {
    pub fn new(processors: Vec<PostProcessorWrapper>) -> Self {
        Self { processors }
    }

    pub fn get(&self, index: usize) -> Option<&PostProcessorWrapper> {
        self.processors.get(index)
    }

    pub fn get_mut(&mut self, index: usize) -> Option<&mut PostProcessorWrapper> {
        self.processors.get_mut(index)
    }

    pub fn set_mut(&mut self, index: usize, post_proc: PostProcessorWrapper) {
        self.processors[index] = post_proc;
    }
}

impl AsRef<[PostProcessorWrapper]> for Sequence {
    /// Views the chained processors as a read-only slice, in execution order.
    fn as_ref(&self) -> &[PostProcessorWrapper] {
        self.processors.as_slice()
    }
}

impl AsMut<[PostProcessorWrapper]> for Sequence {
    /// Views the chained processors as a mutable slice, in execution order.
    /// Elements can be modified in place but not added or removed.
    fn as_mut(&mut self) -> &mut [PostProcessorWrapper] {
        self.processors.as_mut_slice()
    }
}

impl IntoIterator for Sequence {
    type Item = PostProcessorWrapper;
    type IntoIter = std::vec::IntoIter<Self::Item>;

    fn into_iter(self) -> Self::IntoIter {
        self.processors.into_iter()
    }
}

impl PostProcessor for Sequence {
    /// Adds up the token counts reported by every processor in the chain.
    /// `is_pair` is forwarded unchanged to each of them.
    fn added_tokens(&self, is_pair: bool) -> usize {
        let mut total: usize = 0;
        for processor in &self.processors {
            total += processor.added_tokens(is_pair);
        }
        total
    }

    /// Threads `encodings` through every processor in order. The output of
    /// one processor becomes the input of the next.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by any processor. Later processors
    /// are not run.
    fn process_encodings(
        &self,
        encodings: Vec<Encoding>,
        add_special_tokens: bool,
    ) -> Result<Vec<Encoding>> {
        self.processors
            .iter()
            .try_fold(encodings, |current, processor| {
                processor.process_encodings(current, add_special_tokens)
            })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::processors::{ByteLevel, PostProcessorWrapper};
    use crate::tokenizer::{Encoding, PostProcessor};
    use ahash::AHashMap;
    use std::iter::FromIterator;

    /// Offsets expected for `sample_tokens()` after a `ByteLevel` processor
    /// with `trim_offsets(true)` has run over them.
    const TRIMMED_OFFSETS: [(usize, usize); 5] = [(0, 0), (4, 9), (13, 18), (18, 23), (29, 29)];

    /// Byte-level tokens shared by the input and by every expected output.
    fn sample_tokens() -> Vec<String> {
        ["Ġ", "ĠĠĠĠHelloĠĠ", "ĠĠHello", "HelloĠĠ", "ĠĠĠĠ"]
            .iter()
            .map(|&token| String::from(token))
            .collect()
    }

    /// Runs `processor` on clones of `encoding` and `pair`, without adding
    /// special tokens. Unwraps the result.
    fn run<P: PostProcessor>(processor: &P, encoding: &Encoding, pair: Option<&Encoding>) -> Encoding {
        processor
            .process(encoding.clone(), pair.cloned(), false)
            .unwrap()
    }

    #[test]
    fn process_chain() {
        let start = Encoding::new(
            vec![0; 5],
            vec![0; 5],
            sample_tokens(),
            vec![],
            vec![(0, 1), (0, 11), (11, 18), (18, 25), (25, 29)],
            vec![],
            vec![],
            vec![],
            AHashMap::new(),
        );

        let bytelevel = ByteLevel::default().trim_offsets(true);
        let sequence = Sequence::new(vec![PostProcessorWrapper::ByteLevel(bytelevel)]);

        let expected = Encoding::new(
            vec![0; 5],
            vec![0; 5],
            sample_tokens(),
            vec![],
            TRIMMED_OFFSETS.to_vec(),
            vec![],
            vec![],
            vec![],
            AHashMap::from_iter(vec![(0, 0..5)]),
        );

        // A one-element sequence must behave exactly like its only member.
        assert_eq!(expected, run(&bytelevel, &start, None));
        assert_eq!(expected, run(&sequence, &start, None));

        let pair_expected = Encoding::new(
            vec![0; 10],
            vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
            [sample_tokens(), sample_tokens()].concat(),
            vec![],
            [TRIMMED_OFFSETS, TRIMMED_OFFSETS].concat(),
            vec![],
            vec![],
            vec![],
            AHashMap::from_iter(vec![(0, 0..5), (1, 5..10)]),
        );

        assert_eq!(pair_expected, run(&bytelevel, &start, Some(&start)));
        assert_eq!(pair_expected, run(&sequence, &start, Some(&start)));
    }
}