use std::any::TypeId;
use std::collections::HashMap;
use std::collections::HashSet;
use anyhow::anyhow;
use anyhow::bail;
use anyhow::ensure;
use anyhow::Result;
use fbthrift::binary_protocol;
use fbthrift::compact_protocol;
use fbthrift::simplejson_protocol;
use fbthrift::GetUri;
use protocol::StandardProtocol;
use crate::universal_name::ensure_registered;
use crate::universal_name::get_universal_hash;
use crate::universal_name::get_universal_hash_prefix_sha_256;
use crate::universal_name::UniversalHashAlgorithm;
const UNIVERSAL_HASH_PREFIX_SHA_256_LEN: i8 = 16;
pub struct AnyRegistry {
uri_to_typeid: HashMap<&'static str, TypeId>,
typeid_to_uri: HashMap<TypeId, &'static str>,
alg_to_hashes: HashMap<UniversalHashAlgorithm, HashSet<Vec<u8>>>,
}
pub trait SerializeRef = binary_protocol::SerializeRef
+ compact_protocol::SerializeRef
+ simplejson_protocol::SerializeRef;
pub trait DeserializeSlice = binary_protocol::DeserializeSlice
+ compact_protocol::DeserializeSlice
+ simplejson_protocol::DeserializeSlice;
impl AnyRegistry {
pub fn new() -> Self {
let mut alg_to_hashes = HashMap::new();
alg_to_hashes.insert(UniversalHashAlgorithm::Sha2_256, HashSet::new());
Self {
uri_to_typeid: HashMap::new(),
typeid_to_uri: HashMap::new(),
alg_to_hashes,
}
}
pub fn register_type<T: 'static + GetUri>(&mut self) -> Result<bool> {
let uri = T::uri();
let type_id = TypeId::of::<T>();
if self.uri_to_typeid.contains_key(uri) || self.typeid_to_uri.contains_key(&type_id) {
return Ok(false);
}
self.uri_to_typeid.insert(uri, type_id);
self.typeid_to_uri.insert(type_id, uri);
for (alg, hashes) in self.alg_to_hashes.iter_mut() {
let hash = get_universal_hash(*alg, uri)?;
hashes.insert(hash);
}
Ok(true)
}
pub fn has_type<T: 'static + GetUri>(&self, obj: &any::Any) -> Result<bool> {
let type_uri = T::uri();
let type_hash_prefix_sha2_256 =
get_universal_hash_prefix_sha_256(type_uri, UNIVERSAL_HASH_PREFIX_SHA_256_LEN)?;
if let Some(obj_hash_prefix_sha2_256) = &obj.typeHashPrefixSha2_256 {
Ok(type_hash_prefix_sha2_256 == *obj_hash_prefix_sha2_256)
} else if let Some(obj_uri) = obj.r#type.as_ref() {
Ok(type_uri == obj_uri)
} else {
bail!("No type information found");
}
}
pub fn store<T: 'static + SerializeRef>(
&mut self,
obj: &T,
protocol: StandardProtocol,
) -> Result<any::Any> {
let type_id = TypeId::of::<T>();
let uri = self
.typeid_to_uri
.get(&type_id)
.ok_or_else(|| anyhow!("Type {:?} not registered", type_id))?;
let hash_prefix =
get_universal_hash_prefix_sha_256(uri, UNIVERSAL_HASH_PREFIX_SHA_256_LEN)?;
Ok(any::Any {
r#type: Some(uri.to_string()),
typeHashPrefixSha2_256: Some(hash_prefix),
protocol: Some(protocol),
data: serialize(obj, protocol)?,
..Default::default()
})
}
pub fn load<T: DeserializeSlice>(&self, obj: &any::Any) -> Result<T> {
if let Some(type_hash_prefix_sha2_256) = &obj.typeHashPrefixSha2_256 {
let hashes = self
.alg_to_hashes
.get(&UniversalHashAlgorithm::Sha2_256)
.ok_or_else(|| anyhow!("Hash algorithm Sha2_256 not supported"))?;
ensure_registered(hashes, type_hash_prefix_sha2_256)?;
} else if let Some(uri) = obj.r#type.as_ref() {
ensure!(
self.uri_to_typeid.contains_key(uri.as_str()),
"Type '{}' not registered",
uri
);
} else {
bail!("No type information found");
}
deserialize(&obj.data, obj.protocol.unwrap_or(StandardProtocol::Compact))
}
}
fn serialize<T: SerializeRef>(obj: &T, protocol: StandardProtocol) -> Result<Vec<u8>> {
match protocol {
StandardProtocol::Binary => Ok(binary_protocol::serialize(obj).to_vec()),
StandardProtocol::Compact => Ok(compact_protocol::serialize(obj).to_vec()),
StandardProtocol::SimpleJson => Ok(simplejson_protocol::serialize(obj).to_vec()),
unsupported => Err(anyhow!("Unsupported protocol: {:?}", unsupported)),
}
}
fn deserialize<T: DeserializeSlice>(data: &[u8], protocol: StandardProtocol) -> Result<T> {
match protocol {
StandardProtocol::Binary => binary_protocol::deserialize(data),
StandardProtocol::Compact => compact_protocol::deserialize(data),
StandardProtocol::SimpleJson => simplejson_protocol::deserialize(data),
unsupported => Err(anyhow!("Unsupported protocol: {:?}", unsupported)),
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use maplit::btreemap;
use testset::struct_map_string_i32;
use super::*;
fn get_test_object() -> struct_map_string_i32 {
struct_map_string_i32 {
field_1: btreemap! {"Answer to the Ultimate Question of Life, the Universe, and Everything.".to_owned() => 42},
..Default::default()
}
}
fn test_serialize_round_trip_for_protocol(protocol: StandardProtocol) -> Result<()> {
let original = get_test_object();
let serialized = serialize(&original, protocol)?;
let deserialized = deserialize(&serialized, protocol)?;
assert_eq!(original, deserialized);
Ok(())
}
#[test]
fn test_round_trip() -> Result<()> {
let protocols = vec![
StandardProtocol::Binary,
StandardProtocol::Compact,
StandardProtocol::SimpleJson,
];
for protocol in protocols {
test_serialize_round_trip_for_protocol(protocol)?;
}
Ok(())
}
#[test]
fn test_unsupported_protocol() {
let test_object = get_test_object();
assert!(serialize(&test_object, StandardProtocol::Custom).is_err());
assert!(serialize(&test_object, StandardProtocol::Json).is_err());
assert!(deserialize::<struct_map_string_i32>(b"", StandardProtocol::Custom).is_err());
assert!(deserialize::<struct_map_string_i32>(b"", StandardProtocol::Json).is_err());
}
}