tokenizers/decoders/
mod.rs

1pub mod bpe;
2pub mod byte_fallback;
3pub mod ctc;
4pub mod fuse;
5pub mod sequence;
6pub mod strip;
7pub mod wordpiece;
8
9// Re-export these as decoders
10pub use super::pre_tokenizers::byte_level;
11pub use super::pre_tokenizers::metaspace;
12
13use serde::{Deserialize, Deserializer, Serialize};
14
15use crate::decoders::bpe::BPEDecoder;
16use crate::decoders::byte_fallback::ByteFallback;
17use crate::decoders::ctc::CTC;
18use crate::decoders::fuse::Fuse;
19use crate::decoders::sequence::Sequence;
20use crate::decoders::strip::Strip;
21use crate::decoders::wordpiece::WordPiece;
22use crate::normalizers::replace::Replace;
23use crate::pre_tokenizers::byte_level::ByteLevel;
24use crate::pre_tokenizers::metaspace::Metaspace;
25use crate::{Decoder, Result};
26
27#[derive(Serialize, Clone, Debug)]
28#[serde(untagged)]
29pub enum DecoderWrapper {
30    BPE(BPEDecoder),
31    ByteLevel(ByteLevel),
32    WordPiece(WordPiece),
33    Metaspace(Metaspace),
34    CTC(CTC),
35    Sequence(Sequence),
36    Replace(Replace),
37    Fuse(Fuse),
38    Strip(Strip),
39    ByteFallback(ByteFallback),
40}
41
42impl<'de> Deserialize<'de> for DecoderWrapper {
43    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
44    where
45        D: Deserializer<'de>,
46    {
47        #[derive(Deserialize)]
48        pub struct Tagged {
49            #[serde(rename = "type")]
50            variant: EnumType,
51            #[serde(flatten)]
52            rest: serde_json::Value,
53        }
54        #[derive(Serialize, Deserialize)]
55        pub enum EnumType {
56            BPEDecoder,
57            ByteLevel,
58            WordPiece,
59            Metaspace,
60            CTC,
61            Sequence,
62            Replace,
63            Fuse,
64            Strip,
65            ByteFallback,
66        }
67
68        #[derive(Deserialize)]
69        #[serde(untagged)]
70        pub enum DecoderHelper {
71            Tagged(Tagged),
72            Legacy(serde_json::Value),
73        }
74
75        #[derive(Deserialize)]
76        #[serde(untagged)]
77        pub enum DecoderUntagged {
78            BPE(BPEDecoder),
79            ByteLevel(ByteLevel),
80            WordPiece(WordPiece),
81            Metaspace(Metaspace),
82            CTC(CTC),
83            Sequence(Sequence),
84            Replace(Replace),
85            Fuse(Fuse),
86            Strip(Strip),
87            ByteFallback(ByteFallback),
88        }
89
90        let helper = DecoderHelper::deserialize(deserializer).expect("Helper");
91        Ok(match helper {
92            DecoderHelper::Tagged(model) => {
93                let mut values: serde_json::Map<String, serde_json::Value> =
94                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?;
95                values.insert(
96                    "type".to_string(),
97                    serde_json::to_value(&model.variant).map_err(serde::de::Error::custom)?,
98                );
99                let values = serde_json::Value::Object(values);
100                match model.variant {
101                    EnumType::BPEDecoder => DecoderWrapper::BPE(
102                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
103                    ),
104                    EnumType::ByteLevel => DecoderWrapper::ByteLevel(
105                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
106                    ),
107                    EnumType::WordPiece => DecoderWrapper::WordPiece(
108                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
109                    ),
110                    EnumType::Metaspace => DecoderWrapper::Metaspace(
111                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
112                    ),
113                    EnumType::CTC => DecoderWrapper::CTC(
114                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
115                    ),
116                    EnumType::Sequence => DecoderWrapper::Sequence(
117                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
118                    ),
119                    EnumType::Replace => DecoderWrapper::Replace(
120                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
121                    ),
122                    EnumType::Fuse => DecoderWrapper::Fuse(
123                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
124                    ),
125                    EnumType::Strip => DecoderWrapper::Strip(
126                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
127                    ),
128                    EnumType::ByteFallback => DecoderWrapper::ByteFallback(
129                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
130                    ),
131                }
132            }
133            DecoderHelper::Legacy(value) => {
134                let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
135                match untagged {
136                    DecoderUntagged::BPE(dec) => DecoderWrapper::BPE(dec),
137                    DecoderUntagged::ByteLevel(dec) => DecoderWrapper::ByteLevel(dec),
138                    DecoderUntagged::WordPiece(dec) => DecoderWrapper::WordPiece(dec),
139                    DecoderUntagged::Metaspace(dec) => DecoderWrapper::Metaspace(dec),
140                    DecoderUntagged::CTC(dec) => DecoderWrapper::CTC(dec),
141                    DecoderUntagged::Sequence(dec) => DecoderWrapper::Sequence(dec),
142                    DecoderUntagged::Replace(dec) => DecoderWrapper::Replace(dec),
143                    DecoderUntagged::Fuse(dec) => DecoderWrapper::Fuse(dec),
144                    DecoderUntagged::Strip(dec) => DecoderWrapper::Strip(dec),
145                    DecoderUntagged::ByteFallback(dec) => DecoderWrapper::ByteFallback(dec),
146                }
147            }
148        })
149    }
150}
151
152impl Decoder for DecoderWrapper {
153    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
154        match self {
155            Self::BPE(bpe) => bpe.decode_chain(tokens),
156            Self::ByteLevel(bl) => bl.decode_chain(tokens),
157            Self::Metaspace(ms) => ms.decode_chain(tokens),
158            Self::WordPiece(wp) => wp.decode_chain(tokens),
159            Self::CTC(ctc) => ctc.decode_chain(tokens),
160            Self::Sequence(seq) => seq.decode_chain(tokens),
161            Self::Replace(seq) => seq.decode_chain(tokens),
162            Self::ByteFallback(bf) => bf.decode_chain(tokens),
163            Self::Strip(bf) => bf.decode_chain(tokens),
164            Self::Fuse(bf) => bf.decode_chain(tokens),
165        }
166    }
167}
168
169impl_enum_from!(BPEDecoder, DecoderWrapper, BPE);
170impl_enum_from!(ByteLevel, DecoderWrapper, ByteLevel);
171impl_enum_from!(ByteFallback, DecoderWrapper, ByteFallback);
172impl_enum_from!(Fuse, DecoderWrapper, Fuse);
173impl_enum_from!(Strip, DecoderWrapper, Strip);
174impl_enum_from!(Metaspace, DecoderWrapper, Metaspace);
175impl_enum_from!(WordPiece, DecoderWrapper, WordPiece);
176impl_enum_from!(CTC, DecoderWrapper, CTC);
177impl_enum_from!(Sequence, DecoderWrapper, Sequence);
178impl_enum_from!(Replace, DecoderWrapper, Replace);
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as a `DecoderWrapper` and serializes the result back to a JSON string.
    fn roundtrip(input: &str) -> String {
        let parsed: DecoderWrapper = serde_json::from_str(input).unwrap();
        serde_json::to_string(&parsed).unwrap()
    }

    /// Checks that parsing `input` fails and that the error displays exactly as `message`.
    fn assert_parse_error(input: &str, message: &str) {
        match serde_json::from_str::<DecoderWrapper>(input) {
            Ok(_) => panic!("Expected error"),
            Err(err) => assert_eq!(format!("{err}"), message),
        }
    }

    #[test]
    fn decoder_serialization() {
        // The older layout (with `add_prefix_space`) must load and re-serialize in the
        // current layout.
        let legacy = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true,"prepend_scheme":"always"}]}"#;
        let current = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","prepend_scheme":"always","split":true}]}"#;
        assert_eq!(roundtrip(legacy), current);

        // The current layout must round-trip unchanged.
        assert_eq!(roundtrip(current), current);
    }

    #[test]
    fn decoder_serialization_other_no_arg() {
        let current = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","prepend_scheme":"always","split":true}]}"#;
        assert_eq!(roundtrip(current), current);
    }

    #[test]
    fn decoder_serialization_no_decode() {
        // An empty object nested inside a sequence matches no decoder.
        assert_parse_error(
            r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","prepend_scheme":"always"}]}"#,
            "data did not match any variant of untagged enum DecoderUntagged",
        );

        // Fields without a `type` take the untagged path and match nothing.
        assert_parse_error(
            r#"{"replacement":"▁","prepend_scheme":"always"}"#,
            "data did not match any variant of untagged enum DecoderUntagged",
        );

        // With an explicit `type`, the error comes from the named decoder itself.
        assert_parse_error(
            r#"{"type":"Sequence","prepend_scheme":"always"}"#,
            "missing field `decoders`",
        );
    }
}