tokenizers/processors/
mod.rs1pub mod bert;
2pub mod roberta;
3pub mod sequence;
4pub mod template;
5
6pub use super::pre_tokenizers::byte_level;
8
9use serde::{Deserialize, Serialize};
10
11use crate::pre_tokenizers::byte_level::ByteLevel;
12use crate::processors::bert::BertProcessing;
13use crate::processors::roberta::RobertaProcessing;
14use crate::processors::sequence::Sequence;
15use crate::processors::template::TemplateProcessing;
16use crate::{Encoding, PostProcessor, Result};
17
18#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Eq)]
19#[serde(untagged)]
20pub enum PostProcessorWrapper {
21 Roberta(RobertaProcessing),
23 Bert(BertProcessing),
24 ByteLevel(ByteLevel),
25 Template(TemplateProcessing),
26 Sequence(Sequence),
27}
28
29impl PostProcessor for PostProcessorWrapper {
30 fn added_tokens(&self, is_pair: bool) -> usize {
31 match self {
32 Self::Bert(bert) => bert.added_tokens(is_pair),
33 Self::ByteLevel(bl) => bl.added_tokens(is_pair),
34 Self::Roberta(roberta) => roberta.added_tokens(is_pair),
35 Self::Template(template) => template.added_tokens(is_pair),
36 Self::Sequence(bl) => bl.added_tokens(is_pair),
37 }
38 }
39
40 fn process_encodings(
41 &self,
42 encodings: Vec<Encoding>,
43 add_special_tokens: bool,
44 ) -> Result<Vec<Encoding>> {
45 match self {
46 Self::Bert(bert) => bert.process_encodings(encodings, add_special_tokens),
47 Self::ByteLevel(bl) => bl.process_encodings(encodings, add_special_tokens),
48 Self::Roberta(roberta) => roberta.process_encodings(encodings, add_special_tokens),
49 Self::Template(template) => template.process_encodings(encodings, add_special_tokens),
50 Self::Sequence(bl) => bl.process_encodings(encodings, add_special_tokens),
51 }
52 }
53}
54
55impl_enum_from!(BertProcessing, PostProcessorWrapper, Bert);
56impl_enum_from!(ByteLevel, PostProcessorWrapper, ByteLevel);
57impl_enum_from!(RobertaProcessing, PostProcessorWrapper, Roberta);
58impl_enum_from!(TemplateProcessing, PostProcessorWrapper, Template);
59impl_enum_from!(Sequence, PostProcessorWrapper, Sequence);
60
#[cfg(test)]
mod tests {
    use super::*;

    /// Message serde reports when no variant of the untagged enum accepts the input.
    const NO_VARIANT_MSG: &str =
        "data did not match any variant of untagged enum PostProcessorWrapper";

    /// Attempts to deserialize `json` into a `PostProcessorWrapper`.
    fn parse(json: &str) -> serde_json::Result<PostProcessorWrapper> {
        serde_json::from_str::<PostProcessorWrapper>(json)
    }

    /// Asserts that `json` is rejected and that the error text is the untagged-enum mismatch.
    fn assert_rejected(json: &str) {
        if let Err(err) = parse(json) {
            assert_eq!(err.to_string(), NO_VARIANT_MSG);
        } else {
            panic!("Expected an error here");
        }
    }

    /// Default Roberta and Bert processors must serialize to the expected tagged JSON, and that
    /// JSON must deserialize back into the matching wrapper variant.
    #[test]
    fn deserialize_bert_roberta_correctly() {
        // The literal is laid out for readability; all whitespace is dropped before comparing.
        let expected_roberta_json: String = r#"{
            "type":"RobertaProcessing",
            "sep":["</s>",2],
            "cls":["<s>",0],
            "trim_offsets":true,
            "add_prefix_space":true
        }"#
        .chars()
        .filter(|c| !c.is_whitespace())
        .collect();
        let default_roberta = RobertaProcessing::default();
        let serialized_roberta = serde_json::to_string(&default_roberta).unwrap();
        assert_eq!(serialized_roberta, expected_roberta_json);
        let parsed_roberta = parse(&expected_roberta_json).unwrap();
        assert_eq!(parsed_roberta, PostProcessorWrapper::Roberta(default_roberta));

        let expected_bert_json =
            r#"{"type":"BertProcessing","sep":["[SEP]",102],"cls":["[CLS]",101]}"#;
        let default_bert = BertProcessing::default();
        let serialized_bert = serde_json::to_string(&default_bert).unwrap();
        assert_eq!(serialized_bert, expected_bert_json);
        let parsed_bert = parse(expected_bert_json).unwrap();
        assert_eq!(parsed_bert, PostProcessorWrapper::Bert(default_bert));
    }

    /// Exercises deserialization of payloads that lack a `type` field, plus one tagged payload
    /// that is missing required fields.
    #[test]
    fn post_processor_deserialization_no_type() {
        // Untyped payload with only these three flags: no variant accepts it.
        assert_rejected(r#"{"add_prefix_space": true, "trim_offsets": false, "use_regex": false}"#);

        // Untyped payload with just `sep` and `cls` resolves to the Bert variant.
        let bert_like = parse(r#"{"sep":["[SEP]",102],"cls":["[CLS]",101]}"#);
        assert!(matches!(
            bert_like.unwrap(),
            PostProcessorWrapper::Bert(_)
        ));

        // Untyped payload carrying the full Roberta field set resolves to the Roberta variant.
        let roberta_like = parse(
            r#"{"sep":["</s>",2], "cls":["<s>",0], "trim_offsets":true, "add_prefix_space":true}"#,
        );
        assert!(matches!(
            roberta_like.unwrap(),
            PostProcessorWrapper::Roberta(_)
        ));

        // Tagged as Roberta but incomplete: must be rejected rather than matched elsewhere.
        assert_rejected(r#"{"type":"RobertaProcessing", "sep":["</s>",2] }"#);
    }
}