use alloc::boxed::Box;
#[cfg(feature = "json")]
pub use crate::any_registry::JsonAnyEntry;
#[cfg(feature = "json")]
pub use crate::extension_registry::JsonExtEntry;
/// Deprecated alias of [`JsonAnyEntry`], kept so pre-0.3.0 code still compiles.
#[cfg(feature = "json")]
#[deprecated(since = "0.3.0", note = "renamed to JsonAnyEntry")]
pub type AnyTypeEntry = JsonAnyEntry;
/// Deprecated alias of [`JsonExtEntry`], kept so pre-0.3.0 code still compiles.
#[cfg(feature = "json")]
#[deprecated(since = "0.3.0", note = "renamed to JsonExtEntry")]
pub type ExtensionRegistryEntry = JsonExtEntry;
/// Text-format handlers for one `google.protobuf.Any` payload type.
///
/// Registered through `TypeRegistry::register_text_any`, which keys the entry
/// by `type_url`.
#[cfg(feature = "text")]
pub struct TextAnyEntry {
/// Type URL handled by this entry; used verbatim as the lookup key.
pub type_url: &'static str,
/// Writes the packed `Any.value` bytes as a text-format message body
/// (`any_encode_text::<M>` is a generic implementation).
pub text_encode: fn(&[u8], &mut crate::text::TextEncoder<'_>) -> core::fmt::Result,
/// Parses a text-format message body and returns its binary encoding
/// (`any_merge_text::<M>` is a generic implementation).
pub text_merge: AnyTextMergeFn,
}
#[cfg(feature = "text")]
impl core::fmt::Debug for TextAnyEntry {
    /// Shows only the type URL. The function pointers carry no useful
    /// diagnostic information, so the output is marked non-exhaustive.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("TextAnyEntry");
        out.field("type_url", &self.type_url);
        out.finish_non_exhaustive()
    }
}
/// Signature of [`TextAnyEntry::text_merge`]: consumes a text-format message
/// body from the decoder and returns the message's binary encoding.
#[cfg(feature = "text")]
pub type AnyTextMergeFn =
fn(&mut crate::text::TextDecoder<'_>) -> Result<alloc::vec::Vec<u8>, crate::text::ParseError>;
/// Text-format handlers for one extension field.
///
/// Registered through `TypeRegistry::register_text_ext`, which indexes the
/// entry both by `(extendee, number)` and by `full_name`.
#[cfg(feature = "text")]
pub struct TextExtEntry {
/// Extension field number.
pub number: u32,
/// Full name of the extension; the key of the by-name index.
pub full_name: &'static str,
/// Full name of the extended message type (e.g. `pkg.Carrier` in the tests).
pub extendee: &'static str,
/// Writes the value stored under field `number` in the carrier's unknown
/// fields as text (see `message_encode_text` / `group_encode_text`).
pub text_encode: fn(
u32,
&crate::unknown_fields::UnknownFields,
&mut crate::text::TextEncoder<'_>,
) -> core::fmt::Result,
/// Parses the extension's text value into unknown-field records
/// (see `message_merge_text` / `group_merge_text`).
pub text_merge: ExtTextMergeFn,
}
#[cfg(feature = "text")]
impl core::fmt::Debug for TextExtEntry {
    /// Prints the identifying fields (number, full name, extendee). The
    /// function pointers are omitted and the output is marked non-exhaustive.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self {
            number,
            full_name,
            extendee,
            ..
        } = self;
        let mut out = f.debug_struct("TextExtEntry");
        out.field("number", number);
        out.field("full_name", full_name);
        out.field("extendee", extendee);
        out.finish_non_exhaustive()
    }
}
/// Signature of [`TextExtEntry::text_merge`]: parses the extension's text value
/// from the decoder and returns the unknown-field records that encode it under
/// the given field number.
#[cfg(feature = "text")]
pub type ExtTextMergeFn =
fn(
&mut crate::text::TextDecoder<'_>,
u32,
)
-> Result<alloc::vec::Vec<crate::unknown_fields::UnknownField>, crate::text::ParseError>;
/// Text-format `Any` handlers keyed by their exact type URL.
#[cfg(feature = "text")]
#[derive(Default)]
struct TextAnyMap {
    /// Owned type URL -> handler.
    entries: hashbrown::HashMap<alloc::string::String, TextAnyEntry>,
}
#[cfg(feature = "text")]
impl TextAnyMap {
    /// Returns the handler registered for exactly `type_url`, if any.
    fn lookup(&self, type_url: &str) -> Option<&TextAnyEntry> {
        let Self { entries } = self;
        entries.get(type_url)
    }
}
/// Text-format extension handlers, indexed by `(extendee, number)` and by
/// extension full name.
#[cfg(feature = "text")]
#[derive(Default)]
struct TextExtMap {
    /// Primary index: `(extendee full name, field number)` -> handler.
    by_number: hashbrown::HashMap<(alloc::string::String, u32), TextExtEntry>,
    /// Secondary index: extension full name -> key into `by_number`.
    by_name: hashbrown::HashMap<alloc::string::String, (alloc::string::String, u32)>,
}
/// Borrowed form of a `TextExtMap::by_number` key.
///
/// It lets lookups probe the map without allocating an owned `String` for the
/// extendee on every call.
#[cfg(feature = "text")]
struct ExtKeyRef<'a>(&'a str, u32);
#[cfg(feature = "text")]
impl core::hash::Hash for ExtKeyRef<'_> {
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        // Must hash exactly like the owned `(String, u32)` key. A `String`
        // hashes as its `str` contents, and tuples hash their fields in order,
        // so hashing `(&str, u32)` yields the same value.
        core::hash::Hash::hash(&(self.0, self.1), state);
    }
}
#[cfg(feature = "text")]
impl hashbrown::Equivalent<(alloc::string::String, u32)> for ExtKeyRef<'_> {
    fn equivalent(&self, key: &(alloc::string::String, u32)) -> bool {
        self.1 == key.1 && self.0 == key.0
    }
}
#[cfg(feature = "text")]
impl TextExtMap {
    /// Returns the handler for extension field `number` of `extendee`, if any.
    fn by_number(&self, extendee: &str, number: u32) -> Option<&TextExtEntry> {
        self.by_number.get(&ExtKeyRef(extendee, number))
    }
    /// Returns the handler registered under extension name `full_name`, if any.
    fn by_name(&self, full_name: &str) -> Option<&TextExtEntry> {
        self.by_name
            .get(full_name)
            .and_then(|key| self.by_number.get(key))
    }
}
/// Collection of `Any` handlers and extension handlers for the JSON and/or
/// text formats.
///
/// Populate it with the `register_*` methods and query it with the `*_by_*`
/// methods, or install it process-wide with [`set_type_registry`]. Each half
/// exists only when the corresponding `json` / `text` feature is enabled.
#[derive(Default)]
pub struct TypeRegistry {
// JSON `Any` handlers, keyed by type URL.
#[cfg(feature = "json")]
json_any: crate::any_registry::AnyRegistry,
// JSON extension handlers.
#[cfg(feature = "json")]
json_ext: crate::extension_registry::ExtensionRegistry,
// Text-format `Any` handlers, keyed by type URL.
#[cfg(feature = "text")]
text_any: TextAnyMap,
// Text-format extension handlers, indexed by (extendee, number) and by name.
#[cfg(feature = "text")]
text_ext: TextExtMap,
}
impl TypeRegistry {
    /// Creates an empty registry; identical to [`TypeRegistry::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a JSON `Any` handler to the JSON `Any` registry.
    #[cfg(feature = "json")]
    pub fn register_json_any(&mut self, entry: JsonAnyEntry) {
        self.json_any.register(entry);
    }

    /// Adds a JSON extension handler to the JSON extension registry.
    #[cfg(feature = "json")]
    pub fn register_json_ext(&mut self, entry: JsonExtEntry) {
        self.json_ext.register(entry);
    }

    /// Adds a text-format `Any` handler keyed by `entry.type_url`. It replaces
    /// any handler previously registered for the same URL.
    #[cfg(feature = "text")]
    pub fn register_text_any(&mut self, entry: TextAnyEntry) {
        use alloc::borrow::ToOwned;
        self.text_any
            .entries
            .insert(entry.type_url.to_owned(), entry);
    }

    /// Adds a text-format extension handler, indexed both by
    /// `(extendee, number)` and by `full_name`.
    ///
    /// An entry with the same `(extendee, number)` as an existing one replaces
    /// it. When that happens, the replaced entry's `full_name` is also removed
    /// from the name index (if it still points at this key). Otherwise a
    /// lookup of the old name would return the new, differently-named entry.
    #[cfg(feature = "text")]
    pub fn register_text_ext(&mut self, entry: TextExtEntry) {
        use alloc::borrow::ToOwned;
        let full_name = entry.full_name;
        let key = (entry.extendee.to_owned(), entry.number);
        let map = &mut self.text_ext;
        if let Some(replaced) = map.by_number.insert(key.clone(), entry) {
            // Drop the stale alias, but only if it has not been re-pointed at a
            // different key by a later registration.
            if replaced.full_name != full_name
                && map.by_name.get(replaced.full_name) == Some(&key)
            {
                map.by_name.remove(replaced.full_name);
            }
        }
        map.by_name.insert(full_name.to_owned(), key);
    }

    /// Looks up a JSON `Any` handler by type URL.
    #[cfg(feature = "json")]
    pub fn json_any_by_url(&self, type_url: &str) -> Option<&JsonAnyEntry> {
        self.json_any.lookup(type_url)
    }

    /// Looks up a JSON extension handler by extendee name and field number.
    #[cfg(feature = "json")]
    pub fn json_ext_by_number(&self, extendee: &str, number: u32) -> Option<&JsonExtEntry> {
        self.json_ext.by_number(extendee, number)
    }

    /// Looks up a JSON extension handler by extension full name.
    #[cfg(feature = "json")]
    pub fn json_ext_by_name(&self, full_name: &str) -> Option<&JsonExtEntry> {
        self.json_ext.by_name(full_name)
    }

    /// Looks up a text-format `Any` handler by exact type URL.
    #[cfg(feature = "text")]
    pub fn text_any_by_url(&self, type_url: &str) -> Option<&TextAnyEntry> {
        self.text_any.lookup(type_url)
    }

    /// Looks up a text-format extension handler by extendee name and field number.
    #[cfg(feature = "text")]
    pub fn text_ext_by_number(&self, extendee: &str, number: u32) -> Option<&TextExtEntry> {
        self.text_ext.by_number(extendee, number)
    }

    /// Looks up a text-format extension handler by extension full name.
    #[cfg(feature = "text")]
    pub fn text_ext_by_name(&self, full_name: &str) -> Option<&TextExtEntry> {
        self.text_ext.by_name(full_name)
    }
}
/// Deprecated alias of [`TypeRegistry`], kept so pre-0.3.0 code still compiles.
#[deprecated(since = "0.3.0", note = "renamed to TypeRegistry")]
pub type JsonRegistry = TypeRegistry;
impl core::fmt::Debug for TypeRegistry {
    /// Registry contents are deliberately elided. Only the type name is
    /// printed, as `TypeRegistry { .. }`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("TypeRegistry");
        builder.finish_non_exhaustive()
    }
}
/// Process-global text `Any` map published by [`set_type_registry`].
/// It is null until first installed and again after [`clear_text_registry`].
/// Non-null values come from `Box::into_raw` and are never freed by this
/// module, because readers hold `&'static` borrows into them.
#[cfg(feature = "text")]
static TEXT_ANY: core::sync::atomic::AtomicPtr<TextAnyMap> =
core::sync::atomic::AtomicPtr::new(core::ptr::null_mut());
/// Process-global text extension map; same publication and lifetime rules as
/// `TEXT_ANY`.
#[cfg(feature = "text")]
static TEXT_EXT: core::sync::atomic::AtomicPtr<TextExtMap> =
core::sync::atomic::AtomicPtr::new(core::ptr::null_mut());
/// Installs `reg` as the process-wide type registry.
///
/// - JSON halves (`json` feature): handed to the `any_registry` /
///   `extension_registry` globals.
/// - Text halves (`text` feature): boxed and published through `TEXT_ANY` /
///   `TEXT_EXT` with `Release` ordering, pairing with the `Acquire` loads in
///   the `global_text_*` lookups.
///
/// Previously installed text maps are intentionally leaked rather than freed.
/// The lookups return `&'static` references into them, and those may still be
/// live. The two text maps are published one after the other, so a concurrent
/// reader can briefly see the new `Any` map next to the old extension map.
pub fn set_type_registry(reg: TypeRegistry) {
    let TypeRegistry {
        #[cfg(feature = "json")]
        json_any,
        #[cfg(feature = "json")]
        json_ext,
        #[cfg(feature = "text")]
        text_any,
        #[cfg(feature = "text")]
        text_ext,
    } = reg;

    // NOTE(review): the allow presumably covers set_any_registry /
    // set_extension_registry being deprecated in favour of this function.
    #[cfg(feature = "json")]
    #[allow(deprecated)]
    {
        crate::any_registry::set_any_registry(Box::new(json_any));
        crate::extension_registry::set_extension_registry(Box::new(json_ext));
    }

    #[cfg(feature = "text")]
    {
        use core::sync::atomic::{AtomicPtr, Ordering};

        // Boxes `map` and publishes it in `slot`. The previous pointer is
        // overwritten without being freed (see the function docs).
        fn publish<T>(slot: &AtomicPtr<T>, map: T) {
            slot.store(Box::into_raw(Box::new(map)), Ordering::Release);
        }

        publish(&TEXT_ANY, text_any);
        publish(&TEXT_EXT, text_ext);
    }
}
#[deprecated(since = "0.3.0", note = "renamed to set_type_registry")]
pub fn set_json_registry(reg: TypeRegistry) {
set_type_registry(reg);
}
/// Looks up a text `Any` handler in the globally installed registry.
///
/// Returns `None` when no registry is installed or the URL is unknown.
#[cfg(feature = "text")]
pub(crate) fn global_text_any(type_url: &str) -> Option<&'static TextAnyEntry> {
    use core::sync::atomic::Ordering;
    // SAFETY: a non-null `TEXT_ANY` always comes from `Box::into_raw` in
    // `set_type_registry`. Nothing in this module frees or mutates a published
    // map (old pointers are discarded, not dropped), so the pointee stays valid
    // and unaliased-by-writers for `'static`. The `Acquire` load pairs with the
    // `Release` publication.
    let map: &'static TextAnyMap = unsafe { TEXT_ANY.load(Ordering::Acquire).as_ref() }?;
    map.lookup(type_url)
}
/// Looks up a text extension handler by extendee and field number in the
/// globally installed registry.
///
/// Returns `None` when no registry is installed or nothing matches.
#[cfg(feature = "text")]
pub(crate) fn global_text_ext_by_number(
    extendee: &str,
    number: u32,
) -> Option<&'static TextExtEntry> {
    use core::sync::atomic::Ordering;
    // SAFETY: a non-null `TEXT_EXT` always comes from `Box::into_raw` in
    // `set_type_registry`. Published maps are never freed or mutated by this
    // module, so the reference is valid for `'static`.
    let map: &'static TextExtMap = unsafe { TEXT_EXT.load(Ordering::Acquire).as_ref() }?;
    map.by_number(extendee, number)
}
/// Looks up a text extension handler by extension full name in the globally
/// installed registry.
///
/// Returns `None` when no registry is installed or the name is unknown.
#[cfg(feature = "text")]
pub(crate) fn global_text_ext_by_name(full_name: &str) -> Option<&'static TextExtEntry> {
    use core::sync::atomic::Ordering;
    // SAFETY: a non-null `TEXT_EXT` always comes from `Box::into_raw` in
    // `set_type_registry`. Published maps are never freed or mutated by this
    // module, so the reference is valid for `'static`.
    let map: &'static TextExtMap = unsafe { TEXT_EXT.load(Ordering::Acquire).as_ref() }?;
    map.by_name(full_name)
}
/// Uninstalls the global text registries, so later `global_text_*` lookups
/// return `None`.
///
/// The previously installed maps are leaked rather than freed, because
/// `&'static` references handed out earlier may still be live. The function
/// is hidden from docs and, judging by its use in this file, appears intended
/// for tests.
#[cfg(feature = "text")]
#[doc(hidden)]
pub fn clear_text_registry() {
    use core::sync::atomic::Ordering;
    let no_any: *mut TextAnyMap = core::ptr::null_mut();
    let no_ext: *mut TextExtMap = core::ptr::null_mut();
    TEXT_ANY.store(no_any, Ordering::Release);
    TEXT_EXT.store(no_ext, Ordering::Release);
}
/// Decodes `bytes` as message `M` and converts it to a JSON value.
///
/// # Errors
/// Returns the `Display` text of the decode error or the serde_json error.
#[cfg(feature = "json")]
pub fn any_to_json<M>(bytes: &[u8]) -> Result<serde_json::Value, alloc::string::String>
where
    M: crate::Message + serde::Serialize,
{
    use alloc::string::ToString;
    let message = M::decode_from_slice(bytes).map_err(|err| err.to_string())?;
    serde_json::to_value(&message).map_err(|err| err.to_string())
}
/// Deserializes a JSON value into message `M` and returns its binary encoding.
///
/// # Errors
/// Returns the `Display` text of the serde_json error.
#[cfg(feature = "json")]
pub fn any_from_json<M>(v: serde_json::Value) -> Result<alloc::vec::Vec<u8>, alloc::string::String>
where
    M: crate::Message + for<'de> serde::Deserialize<'de>,
{
    use alloc::string::ToString;
    serde_json::from_value::<M>(v)
        .map(|message| message.encode_to_vec())
        .map_err(|err| err.to_string())
}
/// Writes the packed `Any.value` bytes, decoded as `M`, as a text-format
/// message body.
///
/// Corrupt bytes panic in debug builds. In release builds they fall back to
/// printing `M::default()`.
#[cfg(feature = "text")]
pub fn any_encode_text<M>(bytes: &[u8], enc: &mut crate::text::TextEncoder<'_>) -> core::fmt::Result
where
    M: crate::Message + crate::text::TextFormat + Default,
{
    let message = M::decode_from_slice(bytes).unwrap_or_else(|err| {
        if cfg!(debug_assertions) {
            panic!("any_encode_text: corrupt Any.value bytes: {:?}", Some(&err));
        }
        M::default()
    });
    enc.write_map_entry(|inner| message.encode_text(inner))
}
/// Parses a text-format message body as `M` and returns its binary encoding,
/// suitable for storing in `Any.value`.
#[cfg(feature = "text")]
pub fn any_merge_text<M>(
    dec: &mut crate::text::TextDecoder<'_>,
) -> Result<alloc::vec::Vec<u8>, crate::text::ParseError>
where
    M: crate::Message + crate::text::TextFormat + Default,
{
    let mut parsed: M = Default::default();
    dec.merge_message(&mut parsed)?;
    let encoded = parsed.encode_to_vec();
    Ok(encoded)
}
/// Writes the message-typed extension stored under field `n` of `f` as a
/// text-format message body.
///
/// If decoding yields nothing, the `unwrap_or_default` fallback prints an
/// empty/default message.
#[cfg(feature = "text")]
pub fn message_encode_text<M>(
    n: u32,
    f: &crate::unknown_fields::UnknownFields,
    enc: &mut crate::text::TextEncoder<'_>,
) -> core::fmt::Result
where
    M: crate::Message + crate::text::TextFormat + Default,
{
    use crate::extension::codecs::MessageCodec;
    use crate::extension::ExtensionCodec;
    let value = MessageCodec::<M>::decode(n, f).unwrap_or_default();
    enc.write_map_entry(|inner| value.encode_text(inner))
}
/// Parses a text-format message body as `M` and returns a single
/// length-delimited unknown-field record holding its encoding under field `n`.
#[cfg(feature = "text")]
pub fn message_merge_text<M>(
    dec: &mut crate::text::TextDecoder<'_>,
    n: u32,
) -> Result<alloc::vec::Vec<crate::unknown_fields::UnknownField>, crate::text::ParseError>
where
    M: crate::Message + crate::text::TextFormat + Default,
{
    use crate::unknown_fields::{UnknownField, UnknownFieldData};
    let mut parsed = M::default();
    dec.merge_message(&mut parsed)?;
    let record = UnknownField {
        number: n,
        data: UnknownFieldData::LengthDelimited(parsed.encode_to_vec()),
    };
    Ok(alloc::vec![record])
}
/// Writes the group-typed extension stored under field `n` of `f` as a
/// text-format message body.
///
/// If decoding yields nothing, the `unwrap_or_default` fallback prints an
/// empty/default message.
#[cfg(feature = "text")]
pub fn group_encode_text<M>(
    n: u32,
    f: &crate::unknown_fields::UnknownFields,
    enc: &mut crate::text::TextEncoder<'_>,
) -> core::fmt::Result
where
    M: crate::Message + crate::text::TextFormat + Default,
{
    use crate::extension::codecs::GroupCodec;
    use crate::extension::ExtensionCodec;
    let value = GroupCodec::<M>::decode(n, f).unwrap_or_default();
    enc.write_map_entry(|inner| value.encode_text(inner))
}
/// Parses a text-format message body as `M` and returns a single group record
/// under field `n`.
///
/// The group body is stored as nested unknown fields, obtained by re-parsing
/// the message's own encoding.
#[cfg(feature = "text")]
pub fn group_merge_text<M>(
    dec: &mut crate::text::TextDecoder<'_>,
    n: u32,
) -> Result<alloc::vec::Vec<crate::unknown_fields::UnknownField>, crate::text::ParseError>
where
    M: crate::Message + crate::text::TextFormat + Default,
{
    use crate::text::{ParseError, ParseErrorKind};
    use crate::unknown_fields::{UnknownField, UnknownFieldData, UnknownFields};

    let mut group = M::default();
    dec.merge_message(&mut group)?;
    let encoded = group.encode_to_vec();
    // Re-parsing bytes we just produced should never fail. If it does, report
    // an internal error with no meaningful source position.
    let Ok(inner) = UnknownFields::decode_from_slice(&encoded) else {
        return Err(ParseError::new(
            0,
            0,
            ParseErrorKind::Internal("re-decoding freshly-encoded group message bytes failed"),
        ));
    };
    Ok(alloc::vec![UnknownField {
        number: n,
        data: UnknownFieldData::Group(inner),
    }])
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes every test that installs or clears the process-global
    /// registries.
    ///
    /// `set_type_registry` replaces the JSON *and* the text globals in one
    /// call, so the `json` and `text` submodules must share ONE lock. With a
    /// lock per submodule, a JSON test could install an empty text registry
    /// between a text test's install and its assertions (and vice versa). That
    /// makes both suites flaky when both features are enabled.
    #[cfg(any(feature = "json", feature = "text"))]
    static GLOBAL_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Takes [`GLOBAL_LOCK`] and recovers from poisoning. Without this, one
    /// failing test would turn every later global-state test into a
    /// `PoisonError` panic that hides the real failure.
    #[cfg(any(feature = "json", feature = "text"))]
    fn global_lock() -> std::sync::MutexGuard<'static, ()> {
        GLOBAL_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn default_is_empty() {
        let reg = TypeRegistry::default();
        #[cfg(feature = "json")]
        assert!(reg.json_any_by_url("anything").is_none());
        #[cfg(feature = "text")]
        assert!(reg.text_any_by_url("anything").is_none());
        #[cfg(not(any(feature = "json", feature = "text")))]
        let _ = reg;
    }

    #[test]
    fn debug_impl() {
        let reg = TypeRegistry::new();
        let s = alloc::format!("{reg:?}");
        assert!(s.contains("TypeRegistry"), "{s}");
    }

    #[cfg(feature = "json")]
    mod json {
        use super::*;
        use crate::any_registry::with_any_registry;
        use crate::extension_registry::{extension_registry, helpers};

        fn dummy_to_json(_: &[u8]) -> Result<serde_json::Value, alloc::string::String> {
            Ok(serde_json::json!({"ok": true}))
        }

        fn dummy_from_json(
            _: serde_json::Value,
        ) -> Result<alloc::vec::Vec<u8>, alloc::string::String> {
            Ok(alloc::vec![1, 2, 3])
        }

        fn any_entry(url: &'static str, wkt: bool) -> JsonAnyEntry {
            JsonAnyEntry {
                type_url: url,
                to_json: dummy_to_json,
                from_json: dummy_from_json,
                is_wkt: wkt,
            }
        }

        fn ext_entry(num: u32, name: &'static str, ext: &'static str) -> JsonExtEntry {
            JsonExtEntry {
                number: num,
                full_name: name,
                extendee: ext,
                to_json: helpers::int32_to_json,
                from_json: helpers::int32_from_json,
            }
        }

        #[test]
        fn register_json_any_and_ext_independently() {
            let mut reg = TypeRegistry::new();
            reg.register_json_any(any_entry("type.googleapis.com/test.Foo", false));
            reg.register_json_ext(ext_entry(100, "test.ext", "test.Foo"));
            assert!(reg
                .json_any_by_url("type.googleapis.com/test.Foo")
                .is_some());
            assert!(reg.json_ext_by_number("test.Foo", 100).is_some());
            assert!(reg.json_ext_by_name("test.ext").is_some());
        }

        #[test]
        fn set_type_registry_installs_json_halves() {
            let _g = super::global_lock();
            let mut reg = TypeRegistry::new();
            reg.register_json_any(any_entry("type.googleapis.com/test.Unified", false));
            reg.register_json_ext(ext_entry(200, "test.unified_ext", "test.Unified"));
            set_type_registry(reg);
            with_any_registry(|r| {
                let r = r.expect("any registry installed");
                assert!(r.lookup("type.googleapis.com/test.Unified").is_some());
            });
            let ext = extension_registry().expect("extension registry installed");
            assert_eq!(ext.by_name("test.unified_ext").map(|e| e.number), Some(200));
            crate::any_registry::clear_any_registry();
        }

        #[test]
        #[allow(deprecated)]
        fn deprecated_aliases_compile() {
            let _: AnyTypeEntry = any_entry("x", false);
            let _: ExtensionRegistryEntry = ext_entry(1, "a", "B");
            let _: JsonRegistry = TypeRegistry::new();
            let _g = super::global_lock();
            set_json_registry(TypeRegistry::new());
            crate::any_registry::clear_any_registry();
        }
    }

    #[cfg(feature = "text")]
    mod text {
        use super::*;
        use crate::unknown_fields::{UnknownField, UnknownFieldData, UnknownFields};

        #[derive(Default, Clone, PartialEq, Debug)]
        struct Inner {
            n: i32,
        }

        unsafe impl crate::DefaultInstance for Inner {
            fn default_instance() -> &'static Self {
                static I: crate::__private::OnceBox<Inner> = crate::__private::OnceBox::new();
                I.get_or_init(|| alloc::boxed::Box::new(Inner::default()))
            }
        }

        impl crate::Message for Inner {
            fn compute_size(&self) -> u32 {
                if self.n != 0 {
                    1 + crate::encoding::varint_len(self.n as i64 as u64) as u32
                } else {
                    0
                }
            }
            fn write_to(&self, buf: &mut impl bytes::BufMut) {
                if self.n != 0 {
                    crate::encoding::Tag::new(1, crate::encoding::WireType::Varint).encode(buf);
                    crate::encoding::encode_varint(self.n as i64 as u64, buf);
                }
            }
            fn merge_field(
                &mut self,
                tag: crate::encoding::Tag,
                buf: &mut impl bytes::Buf,
                _: u32,
            ) -> Result<(), crate::DecodeError> {
                if tag.field_number() == 1 && tag.wire_type() == crate::encoding::WireType::Varint {
                    self.n = crate::encoding::decode_varint(buf)? as i32;
                    Ok(())
                } else {
                    crate::encoding::skip_field(tag, buf)
                }
            }
            fn cached_size(&self) -> u32 {
                self.compute_size()
            }
            fn clear(&mut self) {
                self.n = 0;
            }
        }

        impl crate::text::TextFormat for Inner {
            fn encode_text(&self, enc: &mut crate::text::TextEncoder<'_>) -> core::fmt::Result {
                if self.n != 0 {
                    enc.write_field_name("n")?;
                    enc.write_i32(self.n)?;
                }
                Ok(())
            }
            fn merge_text(
                &mut self,
                dec: &mut crate::text::TextDecoder<'_>,
            ) -> Result<(), crate::text::ParseError> {
                while let Some(name) = dec.read_field_name()? {
                    match name {
                        "n" => self.n = dec.read_i32()?,
                        _ => dec.skip_value()?,
                    }
                }
                Ok(())
            }
        }

        fn fields_from(records: alloc::vec::Vec<UnknownField>) -> UnknownFields {
            let mut f = UnknownFields::new();
            for r in records {
                f.push(r);
            }
            f
        }

        #[test]
        fn message_text_roundtrip() {
            let mut dec = crate::text::TextDecoder::new("f { n: 7 }");
            dec.read_field_name().unwrap();
            let records = message_merge_text::<Inner>(&mut dec, 50).unwrap();
            assert_eq!(records.len(), 1);
            assert_eq!(records[0].number, 50);
            let UnknownFieldData::LengthDelimited(ref bytes) = records[0].data else {
                panic!("expected LengthDelimited, got {:?}", records[0].data);
            };
            assert_eq!(*bytes, alloc::vec![0x08, 0x07]);
            let fields = fields_from(records);
            let mut s = alloc::string::String::new();
            let mut enc = crate::text::TextEncoder::new(&mut s);
            enc.write_extension_name("pkg.ext").unwrap();
            message_encode_text::<Inner>(50, &fields, &mut enc).unwrap();
            assert_eq!(s, "[pkg.ext] {n: 7}");
        }

        #[test]
        fn group_text_roundtrip() {
            let mut dec = crate::text::TextDecoder::new("f { n: 7 }");
            dec.read_field_name().unwrap();
            let records = group_merge_text::<Inner>(&mut dec, 121).unwrap();
            assert_eq!(records.len(), 1);
            assert_eq!(records[0].number, 121);
            let UnknownFieldData::Group(ref inner) = records[0].data else {
                panic!("expected Group, got {:?}", records[0].data);
            };
            let inner_vec: alloc::vec::Vec<_> = inner.iter().collect();
            assert_eq!(inner_vec.len(), 1);
            assert_eq!(inner_vec[0].number, 1);
            assert_eq!(inner_vec[0].data, UnknownFieldData::Varint(7));
            let fields = fields_from(records);
            let mut s = alloc::string::String::new();
            let mut enc = crate::text::TextEncoder::new(&mut s);
            enc.write_extension_name("pkg.groupfield").unwrap();
            group_encode_text::<Inner>(121, &fields, &mut enc).unwrap();
            assert_eq!(s, "[pkg.groupfield] {n: 7}");
        }

        #[test]
        fn helpers_satisfy_fn_pointer_signature() {
            let _: fn(&[u8], &mut crate::text::TextEncoder<'_>) -> core::fmt::Result =
                any_encode_text::<Inner>;
            let _: AnyTextMergeFn = any_merge_text::<Inner>;
            let _: fn(u32, &UnknownFields, &mut crate::text::TextEncoder<'_>) -> core::fmt::Result =
                message_encode_text::<Inner>;
            let _: ExtTextMergeFn = message_merge_text::<Inner>;
            let _: fn(u32, &UnknownFields, &mut crate::text::TextEncoder<'_>) -> core::fmt::Result =
                group_encode_text::<Inner>;
            let _: ExtTextMergeFn = group_merge_text::<Inner>;
        }

        #[test]
        fn register_and_lookup_text_any() {
            let mut reg = TypeRegistry::new();
            reg.register_text_any(TextAnyEntry {
                type_url: "type.example.com/Inner",
                text_encode: any_encode_text::<Inner>,
                text_merge: any_merge_text::<Inner>,
            });
            assert!(reg.text_any_by_url("type.example.com/Inner").is_some());
            assert!(reg.text_any_by_url("type.example.com/Missing").is_none());
        }

        #[test]
        fn register_and_lookup_text_ext_both_axes() {
            let mut reg = TypeRegistry::new();
            reg.register_text_ext(TextExtEntry {
                number: 50,
                full_name: "pkg.inner_ext",
                extendee: "pkg.Carrier",
                text_encode: message_encode_text::<Inner>,
                text_merge: message_merge_text::<Inner>,
            });
            let e = reg.text_ext_by_number("pkg.Carrier", 50).unwrap();
            assert_eq!(e.full_name, "pkg.inner_ext");
            assert!(reg.text_ext_by_number("pkg.Carrier", 99).is_none());
            assert!(reg.text_ext_by_number("other.Msg", 50).is_none());
            assert_eq!(reg.text_ext_by_name("pkg.inner_ext").unwrap().number, 50);
            assert!(reg.text_ext_by_name("pkg.missing").is_none());
        }

        #[test]
        fn set_type_registry_installs_text_halves() {
            let _g = super::global_lock();
            let mut reg = TypeRegistry::new();
            reg.register_text_any(TextAnyEntry {
                type_url: "type.example.com/Global",
                text_encode: any_encode_text::<Inner>,
                text_merge: any_merge_text::<Inner>,
            });
            reg.register_text_ext(TextExtEntry {
                number: 77,
                full_name: "pkg.global_ext",
                extendee: "pkg.Msg",
                text_encode: message_encode_text::<Inner>,
                text_merge: message_merge_text::<Inner>,
            });
            set_type_registry(reg);
            assert!(global_text_any("type.example.com/Global").is_some());
            assert!(global_text_any("type.example.com/Absent").is_none());
            assert_eq!(
                global_text_ext_by_name("pkg.global_ext").map(|e| e.number),
                Some(77)
            );
            assert!(global_text_ext_by_number("pkg.Msg", 77).is_some());
            clear_text_registry();
            assert!(global_text_any("type.example.com/Global").is_none());
        }
    }
}