tokengeex 1.1.0

TokenGeeX is an efficient tokenizer for code based on UnigramLM and TokenMonster.
Documentation
use serde::{ser::SerializeStruct, Deserialize, Serialize};
use unicode_normalization::UnicodeNormalization;

/// A processor is a step of the tokenization pipeline. It can be used to
/// transform input sequences before they are fed to the model and to transform
/// the output sequences after they are generated by the model.
///
/// Both methods borrow their input and hand back a freshly allocated
/// `String`; the input is never modified in place.
pub trait Processor {
    /// Transforms the input sequence `s` before it is fed to the model and
    /// returns the transformed text.
    fn preprocess(&self, s: &str) -> String;

    /// Transforms the output sequence `s` after it was generated by the model
    /// and returns the transformed text.
    fn postprocess(&self, s: &str) -> String;
}

/// Closed set of the processors defined in this file, so that they can be
/// stored and (de)serialized through a single type.
///
/// The enum is `#[serde(untagged)]`: it writes no tag of its own and relies on
/// the wrapped processor's serialized form (each one emits its own `type`
/// field). When deserializing, serde's untagged representation tries the
/// variants in declaration order and keeps the first one that succeeds, so the
/// order of the variants below is significant.
#[derive(Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ProcessorWrapper {
    /// Wraps a [`CrlfProcessor`].
    Crlf(CrlfProcessor),
    /// Wraps a [`UnicodeProcessor`].
    Unicode(UnicodeProcessor),
}

impl Processor for ProcessorWrapper {
    fn preprocess(&self, s: &str) -> String {
        match self {
            ProcessorWrapper::Crlf(processor) => processor.preprocess(s),
            ProcessorWrapper::Unicode(processor) => processor.preprocess(s),
        }
    }

    fn postprocess(&self, s: &str) -> String {
        match self {
            ProcessorWrapper::Crlf(processor) => processor.postprocess(s),
            ProcessorWrapper::Unicode(processor) => processor.postprocess(s),
        }
    }
}

/// Replaces occurrences of `\r\n` by `\n`.
///
/// The replacement happens in `preprocess`; `postprocess` returns its input
/// unchanged, so the original line endings are not restored. The processor
/// carries no state and serializes as a struct whose only field is
/// `type = "crlf"`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct CrlfProcessor;

impl From<CrlfProcessor> for ProcessorWrapper {
    fn from(val: CrlfProcessor) -> Self {
        ProcessorWrapper::Crlf(val)
    }
}

impl Processor for CrlfProcessor {
    /// Returns a copy of `s` in which every `\r\n` pair has been collapsed to
    /// a single `\n`. A `\r` that is not directly followed by `\n` is kept.
    fn preprocess(&self, s: &str) -> String {
        // The result can only be shorter than (or as long as) the input.
        let mut out = String::with_capacity(s.len());

        // Cutting on "\r\n" and gluing the pieces back with "\n" replaces
        // every occurrence of the pair, scanning left to right.
        let mut pieces = s.split("\r\n");
        if let Some(first) = pieces.next() {
            out.push_str(first);
        }
        for piece in pieces {
            out.push('\n');
            out.push_str(piece);
        }

        out
    }

    /// Returns an owned copy of `s`; the line endings removed by
    /// `preprocess` are not put back.
    fn postprocess(&self, s: &str) -> String {
        s.to_owned()
    }
}

impl Serialize for CrlfProcessor {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut processor = serializer.serialize_struct("CrlfProcessor", 1)?;

        processor.serialize_field("type", "crlf")?;

        processor.end()
    }
}

impl<'de> serde::Deserialize<'de> for CrlfProcessor {
    fn deserialize<D>(deserializer: D) -> Result<CrlfProcessor, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct CrlfProcessorVisitor;

        impl<'de> serde::de::Visitor<'de> for CrlfProcessorVisitor {
            type Value = CrlfProcessor;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("struct CrlfProcessor")
            }

            fn visit_map<A>(self, mut map: A) -> Result<CrlfProcessor, A::Error>
            where
                A: serde::de::MapAccess<'de>,
            {
                while let Some(key) = map.next_key::<&str>()? {
                    match key {
                        "type" => {
                            let value = map.next_value::<String>()?;
                            if value != "crlf" {
                                return Err(serde::de::Error::unknown_variant(&value, &["crlf"]));
                            }
                        }
                        _ => {
                            let _: serde::de::IgnoredAny = map.next_value()?;
                        }
                    }
                }

                Ok(CrlfProcessor)
            }
        }

        deserializer.deserialize_struct("CrlfProcessor", &["type"], CrlfProcessorVisitor)
    }
}

/// Unicode normalizer.
///
/// Each variant selects the normalization form that `preprocess` applies to
/// its input; `postprocess` returns its input unchanged. The processor
/// serializes as a struct with `type = "unicode"` and a `form` field holding
/// the lowercase variant name.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum UnicodeProcessor {
    /// Normalization Form C: canonical decomposition followed by canonical
    /// composition.
    Nfc,
    /// Normalization Form D: canonical decomposition.
    Nfd,
    /// Normalization Form KC: compatibility decomposition followed by
    /// canonical composition.
    Nfkc,
    /// Normalization Form KD: compatibility decomposition.
    Nfkd,
}

impl From<UnicodeProcessor> for ProcessorWrapper {
    fn from(val: UnicodeProcessor) -> Self {
        ProcessorWrapper::Unicode(val)
    }
}

impl Processor for UnicodeProcessor {
    /// Returns `s` rewritten in the Unicode normalization form selected by
    /// `self`.
    fn preprocess(&self, s: &str) -> String {
        // Each arm builds the normalizing character iterator for one form and
        // gathers it into the returned `String`.
        match self {
            Self::Nfc => s.nfc().collect(),
            Self::Nfd => s.nfd().collect(),
            Self::Nfkc => s.nfkc().collect(),
            Self::Nfkd => s.nfkd().collect(),
        }
    }

    /// Returns an owned copy of `s`; the normalization applied by
    /// `preprocess` is not undone.
    fn postprocess(&self, s: &str) -> String {
        s.to_owned()
    }
}

impl Serialize for UnicodeProcessor {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut processor = serializer.serialize_struct("UnicodeProcessor", 2)?;

        processor.serialize_field("type", "unicode")?;
        processor.serialize_field(
            "form",
            match self {
                UnicodeProcessor::Nfc => "nfc",
                UnicodeProcessor::Nfd => "nfd",
                UnicodeProcessor::Nfkc => "nfkc",
                UnicodeProcessor::Nfkd => "nfkd",
            },
        )?;

        processor.end()
    }
}

impl<'de> serde::Deserialize<'de> for UnicodeProcessor {
    fn deserialize<D>(deserializer: D) -> Result<UnicodeProcessor, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct UnicodeProcessorVisitor;

        impl<'de> serde::de::Visitor<'de> for UnicodeProcessorVisitor {
            type Value = UnicodeProcessor;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("struct UnicodeProcessor")
            }

            fn visit_map<A>(self, mut map: A) -> Result<UnicodeProcessor, A::Error>
            where
                A: serde::de::MapAccess<'de>,
            {
                let mut form = None;

                while let Some(key) = map.next_key::<&str>()? {
                    match key {
                        "form" => {
                            form = Some(map.next_value::<String>()?);
                        }
                        _ => {
                            let _: serde::de::IgnoredAny = map.next_value()?;
                        }
                    }
                }

                let form = form.ok_or_else(|| serde::de::Error::missing_field("form"))?;

                Ok(match form.as_str() {
                    "nfc" => UnicodeProcessor::Nfc,
                    "nfd" => UnicodeProcessor::Nfd,
                    "nfkc" => UnicodeProcessor::Nfkc,
                    "nfkd" => UnicodeProcessor::Nfkd,
                    _ => {
                        return Err(serde::de::Error::unknown_variant(
                            &form,
                            &["nfc", "nfd", "nfkc", "nfkd"],
                        ))
                    }
                })
            }
        }

        deserializer.deserialize_struct(
            "UnicodeProcessor",
            &["type", "form"],
            UnicodeProcessorVisitor,
        )
    }
}