tokenizers/processors/
mod.rs

1pub mod bert;
2pub mod roberta;
3pub mod sequence;
4pub mod template;
5
// Re-export the byte-level module here so it can also be used as a post-processor.
7pub use super::pre_tokenizers::byte_level;
8
9use serde::{Deserialize, Serialize};
10
11use crate::pre_tokenizers::byte_level::ByteLevel;
12use crate::processors::bert::BertProcessing;
13use crate::processors::roberta::RobertaProcessing;
14use crate::processors::sequence::Sequence;
15use crate::processors::template::TemplateProcessing;
16use crate::{Encoding, PostProcessor, Result};
17
18#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Eq)]
19#[serde(untagged)]
20pub enum PostProcessorWrapper {
21    // Roberta must be before Bert for deserialization (serde does not validate tags)
22    Roberta(RobertaProcessing),
23    Bert(BertProcessing),
24    ByteLevel(ByteLevel),
25    Template(TemplateProcessing),
26    Sequence(Sequence),
27}
28
29impl PostProcessor for PostProcessorWrapper {
30    fn added_tokens(&self, is_pair: bool) -> usize {
31        match self {
32            Self::Bert(bert) => bert.added_tokens(is_pair),
33            Self::ByteLevel(bl) => bl.added_tokens(is_pair),
34            Self::Roberta(roberta) => roberta.added_tokens(is_pair),
35            Self::Template(template) => template.added_tokens(is_pair),
36            Self::Sequence(bl) => bl.added_tokens(is_pair),
37        }
38    }
39
40    fn process_encodings(
41        &self,
42        encodings: Vec<Encoding>,
43        add_special_tokens: bool,
44    ) -> Result<Vec<Encoding>> {
45        match self {
46            Self::Bert(bert) => bert.process_encodings(encodings, add_special_tokens),
47            Self::ByteLevel(bl) => bl.process_encodings(encodings, add_special_tokens),
48            Self::Roberta(roberta) => roberta.process_encodings(encodings, add_special_tokens),
49            Self::Template(template) => template.process_encodings(encodings, add_special_tokens),
50            Self::Sequence(bl) => bl.process_encodings(encodings, add_special_tokens),
51        }
52    }
53}
54
55impl_enum_from!(BertProcessing, PostProcessorWrapper, Bert);
56impl_enum_from!(ByteLevel, PostProcessorWrapper, ByteLevel);
57impl_enum_from!(RobertaProcessing, PostProcessorWrapper, Roberta);
58impl_enum_from!(TemplateProcessing, PostProcessorWrapper, Template);
59impl_enum_from!(Sequence, PostProcessorWrapper, Sequence);
60
#[cfg(test)]
mod tests {
    use super::*;

    /// Error text serde_json reports when no variant of the untagged wrapper matches.
    const NO_VARIANT_MATCHED: &str =
        "data did not match any variant of untagged enum PostProcessorWrapper";

    /// Asserts that `json` is rejected by every `PostProcessorWrapper` variant.
    fn assert_no_variant_matches(json: &str) {
        match serde_json::from_str::<PostProcessorWrapper>(json) {
            Err(err) => assert_eq!(err.to_string(), NO_VARIANT_MATCHED),
            Ok(_) => panic!("Expected an error here"),
        }
    }

    #[test]
    fn deserialize_bert_roberta_correctly() {
        // Roberta: default value serializes to the expected JSON, and that JSON
        // deserializes through the wrapper as the Roberta variant.
        let roberta = RobertaProcessing::default();
        let roberta_json = r#"{
            "type":"RobertaProcessing",
            "sep":["</s>",2],
            "cls":["<s>",0],
            "trim_offsets":true,
            "add_prefix_space":true
        }"#
        .replace(char::is_whitespace, "");
        assert_eq!(serde_json::to_string(&roberta).unwrap(), roberta_json);
        let parsed: PostProcessorWrapper = serde_json::from_str(&roberta_json).unwrap();
        assert_eq!(parsed, PostProcessorWrapper::Roberta(roberta));

        // Bert: same round trip, must land on the Bert variant.
        let bert = BertProcessing::default();
        let bert_json = r#"{"type":"BertProcessing","sep":["[SEP]",102],"cls":["[CLS]",101]}"#;
        assert_eq!(serde_json::to_string(&bert).unwrap(), bert_json);
        let parsed: PostProcessorWrapper = serde_json::from_str(bert_json).unwrap();
        assert_eq!(parsed, PostProcessorWrapper::Bert(bert));
    }

    #[test]
    fn post_processor_deserialization_no_type() {
        // Untyped payload carrying only byte-level-style fields must be rejected.
        assert_no_variant_matches(
            r#"{"add_prefix_space": true, "trim_offsets": false, "use_regex": false}"#,
        );

        // Untyped payload with only `sep` and `cls` resolves to Bert.
        let bert: PostProcessorWrapper =
            serde_json::from_str(r#"{"sep":["[SEP]",102],"cls":["[CLS]",101]}"#).unwrap();
        assert!(matches!(bert, PostProcessorWrapper::Bert(_)));

        // Untyped payload with the extra Roberta fields resolves to Roberta.
        let roberta: PostProcessorWrapper = serde_json::from_str(
            r#"{"sep":["</s>",2], "cls":["<s>",0], "trim_offsets":true, "add_prefix_space":true}"#,
        )
        .unwrap();
        assert!(matches!(roberta, PostProcessorWrapper::Roberta(_)));

        // Typed Roberta payload that only provides `sep` must be rejected.
        assert_no_variant_matches(r#"{"type":"RobertaProcessing", "sep":["</s>",2] }"#);
    }
}