1pub mod bpe;
2pub mod byte_fallback;
3pub mod ctc;
4pub mod fuse;
5pub mod sequence;
6pub mod strip;
7pub mod wordpiece;
8
9pub use super::pre_tokenizers::byte_level;
11pub use super::pre_tokenizers::metaspace;
12
13use serde::{Deserialize, Deserializer, Serialize};
14
15use crate::decoders::bpe::BPEDecoder;
16use crate::decoders::byte_fallback::ByteFallback;
17use crate::decoders::ctc::CTC;
18use crate::decoders::fuse::Fuse;
19use crate::decoders::sequence::Sequence;
20use crate::decoders::strip::Strip;
21use crate::decoders::wordpiece::WordPiece;
22use crate::normalizers::replace::Replace;
23use crate::pre_tokenizers::byte_level::ByteLevel;
24use crate::pre_tokenizers::metaspace::Metaspace;
25use crate::{Decoder, Result};
26
27#[derive(Serialize, Clone, Debug)]
28#[serde(untagged)]
29pub enum DecoderWrapper {
30 BPE(BPEDecoder),
31 ByteLevel(ByteLevel),
32 WordPiece(WordPiece),
33 Metaspace(Metaspace),
34 CTC(CTC),
35 Sequence(Sequence),
36 Replace(Replace),
37 Fuse(Fuse),
38 Strip(Strip),
39 ByteFallback(ByteFallback),
40}
41
42impl<'de> Deserialize<'de> for DecoderWrapper {
43 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
44 where
45 D: Deserializer<'de>,
46 {
47 #[derive(Deserialize)]
48 pub struct Tagged {
49 #[serde(rename = "type")]
50 variant: EnumType,
51 #[serde(flatten)]
52 rest: serde_json::Value,
53 }
54 #[derive(Serialize, Deserialize)]
55 pub enum EnumType {
56 BPEDecoder,
57 ByteLevel,
58 WordPiece,
59 Metaspace,
60 CTC,
61 Sequence,
62 Replace,
63 Fuse,
64 Strip,
65 ByteFallback,
66 }
67
68 #[derive(Deserialize)]
69 #[serde(untagged)]
70 pub enum DecoderHelper {
71 Tagged(Tagged),
72 Legacy(serde_json::Value),
73 }
74
75 #[derive(Deserialize)]
76 #[serde(untagged)]
77 pub enum DecoderUntagged {
78 BPE(BPEDecoder),
79 ByteLevel(ByteLevel),
80 WordPiece(WordPiece),
81 Metaspace(Metaspace),
82 CTC(CTC),
83 Sequence(Sequence),
84 Replace(Replace),
85 Fuse(Fuse),
86 Strip(Strip),
87 ByteFallback(ByteFallback),
88 }
89
90 let helper = DecoderHelper::deserialize(deserializer).expect("Helper");
91 Ok(match helper {
92 DecoderHelper::Tagged(model) => {
93 let mut values: serde_json::Map<String, serde_json::Value> =
94 serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?;
95 values.insert(
96 "type".to_string(),
97 serde_json::to_value(&model.variant).map_err(serde::de::Error::custom)?,
98 );
99 let values = serde_json::Value::Object(values);
100 match model.variant {
101 EnumType::BPEDecoder => DecoderWrapper::BPE(
102 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
103 ),
104 EnumType::ByteLevel => DecoderWrapper::ByteLevel(
105 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
106 ),
107 EnumType::WordPiece => DecoderWrapper::WordPiece(
108 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
109 ),
110 EnumType::Metaspace => DecoderWrapper::Metaspace(
111 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
112 ),
113 EnumType::CTC => DecoderWrapper::CTC(
114 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
115 ),
116 EnumType::Sequence => DecoderWrapper::Sequence(
117 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
118 ),
119 EnumType::Replace => DecoderWrapper::Replace(
120 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
121 ),
122 EnumType::Fuse => DecoderWrapper::Fuse(
123 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
124 ),
125 EnumType::Strip => DecoderWrapper::Strip(
126 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
127 ),
128 EnumType::ByteFallback => DecoderWrapper::ByteFallback(
129 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
130 ),
131 }
132 }
133 DecoderHelper::Legacy(value) => {
134 let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
135 match untagged {
136 DecoderUntagged::BPE(dec) => DecoderWrapper::BPE(dec),
137 DecoderUntagged::ByteLevel(dec) => DecoderWrapper::ByteLevel(dec),
138 DecoderUntagged::WordPiece(dec) => DecoderWrapper::WordPiece(dec),
139 DecoderUntagged::Metaspace(dec) => DecoderWrapper::Metaspace(dec),
140 DecoderUntagged::CTC(dec) => DecoderWrapper::CTC(dec),
141 DecoderUntagged::Sequence(dec) => DecoderWrapper::Sequence(dec),
142 DecoderUntagged::Replace(dec) => DecoderWrapper::Replace(dec),
143 DecoderUntagged::Fuse(dec) => DecoderWrapper::Fuse(dec),
144 DecoderUntagged::Strip(dec) => DecoderWrapper::Strip(dec),
145 DecoderUntagged::ByteFallback(dec) => DecoderWrapper::ByteFallback(dec),
146 }
147 }
148 })
149 }
150}
151
152impl Decoder for DecoderWrapper {
153 fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
154 match self {
155 Self::BPE(bpe) => bpe.decode_chain(tokens),
156 Self::ByteLevel(bl) => bl.decode_chain(tokens),
157 Self::Metaspace(ms) => ms.decode_chain(tokens),
158 Self::WordPiece(wp) => wp.decode_chain(tokens),
159 Self::CTC(ctc) => ctc.decode_chain(tokens),
160 Self::Sequence(seq) => seq.decode_chain(tokens),
161 Self::Replace(seq) => seq.decode_chain(tokens),
162 Self::ByteFallback(bf) => bf.decode_chain(tokens),
163 Self::Strip(bf) => bf.decode_chain(tokens),
164 Self::Fuse(bf) => bf.decode_chain(tokens),
165 }
166 }
167}
168
169impl_enum_from!(BPEDecoder, DecoderWrapper, BPE);
170impl_enum_from!(ByteLevel, DecoderWrapper, ByteLevel);
171impl_enum_from!(ByteFallback, DecoderWrapper, ByteFallback);
172impl_enum_from!(Fuse, DecoderWrapper, Fuse);
173impl_enum_from!(Strip, DecoderWrapper, Strip);
174impl_enum_from!(Metaspace, DecoderWrapper, Metaspace);
175impl_enum_from!(WordPiece, DecoderWrapper, WordPiece);
176impl_enum_from!(CTC, DecoderWrapper, CTC);
177impl_enum_from!(Sequence, DecoderWrapper, Sequence);
178impl_enum_from!(Replace, DecoderWrapper, Replace);
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `json` as a [`DecoderWrapper`] and serializes it straight back.
    fn roundtrip(json: &str) -> String {
        let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
        serde_json::to_string(&decoder).unwrap()
    }

    /// Asserts that parsing `json` as a [`DecoderWrapper`] fails with exactly
    /// the `expected` error message.
    fn assert_parse_error(json: &str, expected: &str) {
        match serde_json::from_str::<DecoderWrapper>(json) {
            Err(err) => assert_eq!(err.to_string(), expected),
            Ok(_) => panic!("Expected error"),
        }
    }

    #[test]
    fn decoder_serialization() {
        // An older Metaspace payload (with `add_prefix_space`) must load and
        // then serialize in the current format.
        let legacy = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
        let current = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","prepend_scheme":"always","split":true}]}"#;
        assert_eq!(roundtrip(legacy), current);

        // The current format must be a fixed point of parse + serialize.
        assert_eq!(roundtrip(current), current);
    }

    #[test]
    fn decoder_serialization_other_no_arg() {
        let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","prepend_scheme":"always","split":true}]}"#;
        assert_eq!(roundtrip(json), json);
    }

    #[test]
    fn decoder_serialization_no_decode() {
        // An empty object inside a Sequence matches no decoder.
        assert_parse_error(
            r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","prepend_scheme":"always"}]}"#,
            "data did not match any variant of untagged enum DecoderUntagged",
        );

        // An object without a `type` tag that no decoder accepts.
        assert_parse_error(
            r#"{"replacement":"▁","prepend_scheme":"always"}"#,
            "data did not match any variant of untagged enum DecoderUntagged",
        );

        // A tagged Sequence that lacks its required field.
        assert_parse_error(
            r#"{"type":"Sequence","prepend_scheme":"always"}"#,
            "missing field `decoders`",
        );
    }
}