use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BpeBuilder, Pair, BPE};
use ahash::AHashMap;
use serde::{
de::{Error, MapAccess, Visitor},
ser::SerializeStruct,
Deserialize, Deserializer, Serialize, Serializer,
};
impl Serialize for BPE {
    /// Serializes the model as a struct named `"BPE"`.
    ///
    /// Fields are written in this order: `type`, `dropout`, `unk_token`,
    /// `continuing_subword_prefix`, `end_of_word_suffix`, `fuse_unk`,
    /// `byte_fallback`, `ignore_merges`, `vocab`, `merges`.
    ///
    /// `merges` is emitted as a list of `[left, right]` token-string pairs,
    /// ordered by ascending merge rank.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the underlying serializer.
    ///
    /// # Panics
    ///
    /// Panics if a merge refers to a token id that is absent from `vocab_r`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // The length hint must equal the number of `serialize_field` calls
        // below (ten). Formats that rely on the declared length would otherwise
        // produce a truncated or malformed struct.
        let mut model = serializer.serialize_struct("BPE", 10)?;

        // Start with the "simple" ones.
        model.serialize_field("type", "BPE")?;
        model.serialize_field("dropout", &self.dropout)?;
        model.serialize_field("unk_token", &self.unk_token)?;
        model.serialize_field("continuing_subword_prefix", &self.continuing_subword_prefix)?;
        model.serialize_field("end_of_word_suffix", &self.end_of_word_suffix)?;
        model.serialize_field("fuse_unk", &self.fuse_unk)?;
        model.serialize_field("byte_fallback", &self.byte_fallback)?;
        model.serialize_field("ignore_merges", &self.ignore_merges)?;

        // The merges map has no inherent order, so sort the pairs by rank to
        // obtain a deterministic output.
        let mut merges: Vec<(&Pair, &u32)> = self
            .merges
            .iter()
            .map(|(pair, (rank, _))| (pair, rank))
            .collect();
        merges.sort_unstable_by_key(|k| *k.1);

        // Translate each id pair back into its token strings.
        let merges = merges
            .into_iter()
            .map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone()))
            .collect::<Vec<_>>();

        let ordered_vocab = OrderedVocabIter::new(&self.vocab_r);
        model.serialize_field("vocab", &ordered_vocab)?;
        model.serialize_field("merges", &merges)?;

        model.end()
    }
}
impl<'de> Deserialize<'de> for BPE {
    /// Deserializes a `BPE` model by handing the input to [`BPEVisitor`].
    ///
    /// The deserializer is told to expect a struct named `"BPE"` together
    /// with the list of field names given in `FIELDS`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Field names advertised to the deserializer.
        const FIELDS: &[&str] = &[
            "type",
            "dropout",
            "unk_token",
            "continuing_subword_prefix",
            "end_of_word_suffix",
            "fuse_unk",
            "byte_fallback",
            "ignore_merges",
            "vocab",
            "merges",
        ];

        deserializer.deserialize_struct("BPE", FIELDS, BPEVisitor)
    }
}
struct BPEVisitor;
impl<'de> Visitor<'de> for BPEVisitor {
type Value = BPE;
fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(fmt, "struct BPE")
}
fn visit_map<V>(self, mut map: V) -> std::result::Result<Self::Value, V::Error>
where
V: MapAccess<'de>,
{
let mut builder = BpeBuilder::new();
let mut vocab: Option<AHashMap<String, u32>> = None;
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum MergeType {
Tuple(Vec<(String, String)>),
Legacy(Vec<String>),
}
let mut merges: Option<MergeType> = None;
while let Some(key) = map.next_key::<String>()? {
match key.as_ref() {
"dropout" => {
if let Some(dropout) = map.next_value()? {
builder = builder.dropout(dropout);
}
}
"unk_token" => {
if let Some(unk) = map.next_value()? {
builder = builder.unk_token(unk);
}
}
"continuing_subword_prefix" => {
if let Some(prefix) = map.next_value()? {
builder = builder.continuing_subword_prefix(prefix);
}
}
"end_of_word_suffix" => {
if let Some(suffix) = map.next_value()? {
builder = builder.end_of_word_suffix(suffix);
}
}
"fuse_unk" => {
if let Some(suffix) = map.next_value()? {
builder = builder.fuse_unk(suffix);
}
}
"byte_fallback" => {
if let Some(suffix) = map.next_value()? {
builder = builder.byte_fallback(suffix);
}
}
"ignore_merges" => {
if let Some(suffix) = map.next_value()? {
builder = builder.ignore_merges(suffix);
}
}
"vocab" => vocab = Some(map.next_value()?),
"merges" => merges = Some(map.next_value()?),
"type" => match map.next_value()? {
"BPE" => {}
u => {
return Err(serde::de::Error::invalid_value(
serde::de::Unexpected::Str(u),
&"BPE",
))
}
},
_ => {}
}
}
if let (Some(vocab), Some(merges)) = (vocab, merges) {
let merges = match merges {
MergeType::Tuple(merges) => merges,
MergeType::Legacy(merges) => {
convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)?
}
};
builder = builder.vocab_and_merges(vocab, merges);
Ok(builder.build().map_err(Error::custom)?)
} else {
Err(Error::custom("Missing vocab/merges"))
}
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::models::bpe::Vocab;

    /// Builds a `Vocab` from `(token, id)` entries.
    fn vocab_of(entries: &[(&str, u32)]) -> Vocab {
        entries
            .iter()
            .map(|(token, id)| ((*token).to_string(), *id))
            .collect()
    }

    /// Checks the JSON form of a model, the round trip through it, and that
    /// the legacy string-based `merges` encoding still loads.
    #[test]
    fn test_serialization() {
        let simple_vocab = vocab_of(&[("<unk>", 0), ("a", 1), ("b", 2), ("ab", 3)]);
        let bpe = BpeBuilder::default()
            .vocab_and_merges(simple_vocab, vec![("a".to_string(), "b".to_string())])
            .unk_token("<unk>".to_string())
            .ignore_merges(true)
            .build()
            .unwrap();

        // Merges written as single strings must load into the same model.
        let legacy_json = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
        let legacy: BPE = serde_json::from_str(legacy_json).unwrap();
        assert_eq!(bpe, legacy);

        // Serialization emits merges as pairs.
        let serialized = serde_json::to_string(&bpe).unwrap();
        assert_eq!(
            serialized,
            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"#
        );
        let round_tripped: BPE = serde_json::from_str(&serialized).unwrap();
        assert_eq!(bpe, round_tripped);

        // Same checks with tokens that contain spaces.
        let spaced_vocab = vocab_of(&[("<unk>", 0), ("a", 1), ("b c d", 2), ("ab c d", 3)]);
        let bpe = BpeBuilder::default()
            .vocab_and_merges(spaced_vocab, vec![("a".to_string(), "b c d".to_string())])
            .unk_token("<unk>".to_string())
            .ignore_merges(true)
            .build()
            .unwrap();

        let serialized = serde_json::to_string(&bpe).unwrap();
        assert_eq!(
            serialized,
            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b c d":2,"ab c d":3},"merges":[["a","b c d"]]}"#
        );
        let round_tripped: BPE = serde_json::from_str(&serialized).unwrap();
        assert_eq!(bpe, round_tripped);
    }

    /// Checks that `ignore_merges` is read when present and that a document
    /// without the field equals a model whose flag is `false`.
    #[test]
    fn test_serialization_ignore_merges() {
        let vocab = vocab_of(&[("<unk>", 0), ("a", 1), ("b", 2)]);
        let mut bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![])
            .unk_token("<unk>".to_string())
            .ignore_merges(true)
            .build()
            .unwrap();

        // Field present and set to true.
        let with_flag = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2},"merges":[]}"#;
        let parsed: BPE = serde_json::from_str(with_flag).unwrap();
        assert_eq!(parsed, bpe);

        // Field absent from the document.
        bpe.ignore_merges = false;
        let without_flag = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"vocab":{"<unk>":0,"a":1,"b":2},"merges":[]}"#;
        let parsed: BPE = serde_json::from_str(without_flag).unwrap();
        assert_eq!(parsed, bpe);
    }
}