use core::marker::PhantomData;
use crate::unknown_fields::UnknownFields;
#[derive(Debug)]
pub struct Extension<C: ExtensionCodec> {
number: u32,
extendee: &'static str,
default: Option<fn() -> C::Value>,
_codec: PhantomData<fn() -> C>,
}
impl<C: ExtensionCodec> Extension<C> {
    /// Single place where a handle is assembled; both public constructors
    /// forward here.
    const fn from_parts(
        number: u32,
        extendee: &'static str,
        default: Option<fn() -> C::Value>,
    ) -> Self {
        Self {
            _codec: PhantomData,
            default,
            extendee,
            number,
        }
    }

    /// Creates a handle for extension field `number` on the message whose
    /// fully-qualified proto name is `extendee`. No default is declared.
    pub const fn new(number: u32, extendee: &'static str) -> Self {
        Self::from_parts(number, extendee, None)
    }

    /// Like [`Extension::new`], but also records `default` as the constructor
    /// of the declared default value. `ExtensionSet::extension_or_default`
    /// calls it when the extension is absent.
    pub const fn with_default(
        number: u32,
        extendee: &'static str,
        default: fn() -> C::Value,
    ) -> Self {
        Self::from_parts(number, extendee, Some(default))
    }

    /// The extension's field number.
    pub const fn number(&self) -> u32 {
        self.number
    }

    /// Fully-qualified proto name of the message this extension extends.
    pub const fn extendee(&self) -> &'static str {
        self.extendee
    }
}
/// Panics unless `ext` was declared as an extension of the message named
/// `expected`.
///
/// `#[track_caller]` attributes the panic to this function's caller rather
/// than to this helper.
#[track_caller]
#[inline]
fn assert_extendee<C: ExtensionCodec>(ext: &Extension<C>, expected: &'static str) {
    let actual = ext.extendee;
    assert_eq!(
        actual, expected,
        "extension at field {} extends `{}`, not `{}`",
        ext.number, actual, expected
    );
}
// Written by hand instead of derived: `#[derive(Clone, Copy)]` would require
// `C: Clone + Copy`, which the codec marker types do not implement. Every
// field is `Copy` whatever `C` is, so the handle can be copied freely.
impl<C: ExtensionCodec> Copy for Extension<C> {}

impl<C: ExtensionCodec> Clone for Extension<C> {
    #[inline]
    fn clone(&self) -> Self {
        // `Self: Copy`, so cloning is a plain bitwise copy.
        *self
    }
}
/// Describes how an extension value is stored in, and recovered from, a
/// message's unknown fields.
pub trait ExtensionCodec {
    /// What decoding yields. The codecs in this module use `Option<Value>` for
    /// singular fields and `Vec<_>` for repeated ones.
    type Output;
    /// The value accepted when setting the extension.
    type Value;

    /// Reads the extension stored under field `number` from `fields`.
    fn decode(number: u32, fields: &UnknownFields) -> Self::Output;

    /// Appends the wire form of `value` under field `number` to `fields`.
    /// Implementations only append; removing earlier records is the caller's
    /// job.
    fn encode(number: u32, value: Self::Value, fields: &mut UnknownFields);
}
/// Typed extension accessors for a message that keeps its extensions in its
/// unknown fields.
///
/// Implementors provide `PROTO_FQN` and shared/mutable access to their
/// [`UnknownFields`]. Every accessor below has a default implementation.
/// All of them except [`has_extension`](ExtensionSet::has_extension) panic
/// when given an extension declared for a different message.
pub trait ExtensionSet {
    /// Fully-qualified proto name of the implementing message. It is compared
    /// with each extension's `extendee`.
    const PROTO_FQN: &'static str;

    /// The message's unknown fields, where extension records live.
    fn unknown_fields(&self) -> &UnknownFields;

    /// Mutable access to the message's unknown fields.
    fn unknown_fields_mut(&mut self) -> &mut UnknownFields;

    /// Decodes `ext`. The shape of the result is the codec's `Output`.
    ///
    /// # Panics
    ///
    /// If `ext` does not extend this message.
    #[track_caller]
    fn extension<C: ExtensionCodec>(&self, ext: &Extension<C>) -> C::Output {
        assert_extendee(ext, Self::PROTO_FQN);
        let fields = self.unknown_fields();
        C::decode(ext.number, fields)
    }

    /// Removes every existing record for `ext`'s field number, then appends
    /// the encoding of `value`.
    ///
    /// # Panics
    ///
    /// If `ext` does not extend this message.
    #[track_caller]
    fn set_extension<C: ExtensionCodec>(&mut self, ext: &Extension<C>, value: C::Value) {
        assert_extendee(ext, Self::PROTO_FQN);
        let number = ext.number;
        self.unknown_fields_mut().retain(|f| f.number != number);
        C::encode(number, value, self.unknown_fields_mut());
    }

    /// Whether any record with `ext`'s field number is present, whether or not
    /// it decodes. Returns `false` instead of panicking when `ext` extends a
    /// different message.
    fn has_extension<C: ExtensionCodec>(&self, ext: &Extension<C>) -> bool {
        ext.extendee == Self::PROTO_FQN
            && self.unknown_fields().iter().any(|f| f.number == ext.number)
    }

    /// Removes every record with `ext`'s field number.
    ///
    /// # Panics
    ///
    /// If `ext` does not extend this message.
    #[track_caller]
    fn clear_extension<C: ExtensionCodec>(&mut self, ext: &Extension<C>) {
        assert_extendee(ext, Self::PROTO_FQN);
        let number = ext.number;
        self.unknown_fields_mut().retain(|f| f.number != number);
    }

    /// Value of a singular extension. If it is absent, falls back to the
    /// declared default (when `ext` has one), and otherwise to
    /// `Default::default()` of the value type.
    ///
    /// # Panics
    ///
    /// If `ext` does not extend this message.
    #[must_use]
    #[track_caller]
    fn extension_or_default<C>(&self, ext: &Extension<C>) -> C::Value
    where
        C: ExtensionCodec<Output = Option<<C as ExtensionCodec>::Value>>,
        C::Value: Default,
    {
        match self.extension(ext) {
            Some(present) => present,
            None => match ext.default {
                Some(make_default) => make_default(),
                None => Default::default(),
            },
        }
    }
}
pub mod codecs {
use core::marker::PhantomData;
use alloc::string::String;
use alloc::vec::Vec;
use bytes::Buf;
use crate::encoding::{decode_varint, encode_varint};
use crate::message::Message;
use crate::types::{
zigzag_decode_i32, zigzag_decode_i64, zigzag_encode_i32, zigzag_encode_i64,
};
use crate::unknown_fields::{UnknownField, UnknownFieldData, UnknownFields};
use super::ExtensionCodec;
/// Per-element codec. The repeated wrappers use it to decode and encode
/// individual elements.
pub trait SingularCodec {
    /// Decoded element type.
    type Value;

    /// Decodes one record's payload. Returns `None` when the wire type does
    /// not match or the payload is malformed.
    fn decode_one(data: &UnknownFieldData) -> Option<Self::Value>;

    /// Appends every element decodable from the packed payload `bytes` to
    /// `out`. Codecs that cannot be packed implement this as a no-op.
    fn decode_packed(bytes: &[u8], out: &mut Vec<Self::Value>);

    /// Encodes one element as a stand-alone record payload.
    fn encode_one(value: &Self::Value) -> UnknownFieldData;
}

/// A [`SingularCodec`] whose elements can be concatenated into one packed
/// length-delimited payload.
pub trait PackableCodec: SingularCodec {
    /// Appends the packed encoding of `value` to `buf`.
    fn encode_packed(value: &Self::Value, buf: &mut Vec<u8>);
}
// Defines a unit-struct codec for a proto scalar carried as a varint record.
//
// - `$name`: the codec type to define.
// - `$ty`: the Rust value type.
// - `|$d| $decode`: converts the raw `u64` varint bound to `$d` into `$ty`.
// - `|$e| $encode`: converts the `$ty` value bound to `$e` into the raw `u64`
//   to write.
//
// The generated type implements `ExtensionCodec` (singular, `Option<$ty>`),
// `SingularCodec`, and `PackableCodec`. The caller-chosen idents `$d`/`$e` are
// bound at exactly the points where `$decode`/`$encode` expand, which is what
// lets the invocation's closures refer to them.
macro_rules! varint_codec {
($name:ident, $ty:ty, |$d:ident| $decode:expr, |$e:ident| $encode:expr $(,)?) => {
#[doc = concat!("Codec for the `", stringify!($name), "` proto scalar type.")]
pub struct $name;
impl ExtensionCodec for $name {
type Value = $ty;
type Output = Option<$ty>;
// Scans occurrences of `number` newest-first and returns the first one that
// decodes. The last varint record wins; records of other wire types are
// skipped.
fn decode(number: u32, fields: &UnknownFields) -> Option<$ty> {
fields
.iter()
.rev()
.filter(|f| f.number == number)
.find_map(|f| Self::decode_one(&f.data))
}
// Appends one varint record. Earlier records are not removed here.
fn encode(number: u32, value: $ty, fields: &mut UnknownFields) {
fields.push(UnknownField {
number,
data: Self::encode_one(&value),
});
}
}
impl SingularCodec for $name {
type Value = $ty;
// Only `Varint` records decode; any other wire type yields `None`.
fn decode_one(data: &UnknownFieldData) -> Option<$ty> {
match *data {
UnknownFieldData::Varint($d) => Some($decode),
_ => None,
}
}
// Decodes consecutive varints. At the first malformed varint it stops
// silently, keeping the values decoded so far.
fn decode_packed(bytes: &[u8], out: &mut Vec<$ty>) {
let mut buf = bytes;
while buf.has_remaining() {
match decode_varint(&mut buf) {
Ok($d) => out.push($decode),
Err(_) => return,
}
}
}
// `$e` is rebound by value so `$encode` sees a `$ty`, not a `&$ty`.
fn encode_one($e: &$ty) -> UnknownFieldData {
let $e = *$e;
UnknownFieldData::Varint($encode)
}
}
impl PackableCodec for $name {
// Appends the bare varint (no tag) to the packed buffer.
fn encode_packed($e: &$ty, buf: &mut Vec<u8>) {
let $e = *$e;
encode_varint($encode, buf);
}
}
};
}
// Varint-carried scalars.
// - `int32` sign-extends negative values to 64 bits when encoding and
//   truncates to the low 32 bits when decoding.
// - `uint32` likewise truncates on decode.
// - `sint32`/`sint64` use zigzag encoding.
// - `bool` reads any non-zero varint as `true` and writes 0 or 1.
varint_codec!(Int32, i32, |raw| raw as i32, |val| val as i64 as u64);
varint_codec!(Int64, i64, |raw| raw as i64, |val| val as u64);
varint_codec!(Uint32, u32, |raw| raw as u32, |val| val as u64);
varint_codec!(Uint64, u64, |raw| raw, |val| val);
varint_codec!(
    Sint32,
    i32,
    |raw| zigzag_decode_i32(raw as u32),
    |val| zigzag_encode_i32(val) as u64
);
varint_codec!(
    Sint64,
    i64,
    |raw| zigzag_decode_i64(raw),
    |val| zigzag_encode_i64(val)
);
varint_codec!(Bool, bool, |raw| raw != 0, |val| val as u64);
pub struct EnumI32;
impl ExtensionCodec for EnumI32 {
type Value = i32;
type Output = Option<i32>;
fn decode(number: u32, fields: &UnknownFields) -> Option<i32> {
Int32::decode(number, fields)
}
fn encode(number: u32, value: i32, fields: &mut UnknownFields) {
Int32::encode(number, value, fields)
}
}
impl SingularCodec for EnumI32 {
type Value = i32;
fn decode_one(data: &UnknownFieldData) -> Option<i32> {
Int32::decode_one(data)
}
fn decode_packed(bytes: &[u8], out: &mut Vec<i32>) {
Int32::decode_packed(bytes, out)
}
fn encode_one(value: &i32) -> UnknownFieldData {
Int32::encode_one(value)
}
}
impl PackableCodec for EnumI32 {
fn encode_packed(value: &i32, buf: &mut Vec<u8>) {
Int32::encode_packed(value, buf)
}
}
// Defines a unit-struct codec for a proto scalar carried as a 4-byte
// `Fixed32` record.
//
// - `|$d| $decode`: maps the raw `u32` bound to `$d` into `$ty`.
// - `|$e| $encode`: maps the `$ty` bound to `$e` back to the raw `u32`.
//
// Packed payloads are read and written as consecutive little-endian `u32`s.
// Apart from the record variant and the packed element size, this mirrors
// `varint_codec!`.
macro_rules! fixed32_codec {
($name:ident, $ty:ty, |$d:ident| $decode:expr, |$e:ident| $encode:expr $(,)?) => {
#[doc = concat!("Codec for the `", stringify!($name), "` proto scalar type.")]
pub struct $name;
impl ExtensionCodec for $name {
type Value = $ty;
type Output = Option<$ty>;
// Newest `Fixed32` record for `number` wins; other wire types are skipped.
fn decode(number: u32, fields: &UnknownFields) -> Option<$ty> {
fields
.iter()
.rev()
.filter(|f| f.number == number)
.find_map(|f| Self::decode_one(&f.data))
}
// Appends one `Fixed32` record. Earlier records are not removed here.
fn encode(number: u32, value: $ty, fields: &mut UnknownFields) {
fields.push(UnknownField {
number,
data: Self::encode_one(&value),
});
}
}
impl SingularCodec for $name {
type Value = $ty;
fn decode_one(data: &UnknownFieldData) -> Option<$ty> {
match *data {
UnknownFieldData::Fixed32($d) => Some($decode),
_ => None,
}
}
// Reads whole 4-byte little-endian elements. A trailing partial element
// (fewer than 4 bytes) is ignored.
fn decode_packed(bytes: &[u8], out: &mut Vec<$ty>) {
let mut buf = bytes;
while buf.remaining() >= 4 {
let $d = buf.get_u32_le();
out.push($decode);
}
}
fn encode_one($e: &$ty) -> UnknownFieldData {
let $e = *$e;
UnknownFieldData::Fixed32($encode)
}
}
impl PackableCodec for $name {
// Appends the 4 little-endian bytes of the encoded value.
fn encode_packed($e: &$ty, buf: &mut Vec<u8>) {
use bytes::BufMut;
let $e = *$e;
buf.put_u32_le($encode);
}
}
};
}
// Defines a unit-struct codec for a proto scalar carried as an 8-byte
// `Fixed64` record.
//
// - `|$d| $decode`: maps the raw `u64` bound to `$d` into `$ty`.
// - `|$e| $encode`: maps the `$ty` bound to `$e` back to the raw `u64`.
//
// Packed payloads are read and written as consecutive little-endian `u64`s.
// Apart from the record variant and the packed element size, this mirrors
// `fixed32_codec!`.
macro_rules! fixed64_codec {
($name:ident, $ty:ty, |$d:ident| $decode:expr, |$e:ident| $encode:expr $(,)?) => {
#[doc = concat!("Codec for the `", stringify!($name), "` proto scalar type.")]
pub struct $name;
impl ExtensionCodec for $name {
type Value = $ty;
type Output = Option<$ty>;
// Newest `Fixed64` record for `number` wins; other wire types are skipped.
fn decode(number: u32, fields: &UnknownFields) -> Option<$ty> {
fields
.iter()
.rev()
.filter(|f| f.number == number)
.find_map(|f| Self::decode_one(&f.data))
}
// Appends one `Fixed64` record. Earlier records are not removed here.
fn encode(number: u32, value: $ty, fields: &mut UnknownFields) {
fields.push(UnknownField {
number,
data: Self::encode_one(&value),
});
}
}
impl SingularCodec for $name {
type Value = $ty;
fn decode_one(data: &UnknownFieldData) -> Option<$ty> {
match *data {
UnknownFieldData::Fixed64($d) => Some($decode),
_ => None,
}
}
// Reads whole 8-byte little-endian elements. A trailing partial element
// (fewer than 8 bytes) is ignored.
fn decode_packed(bytes: &[u8], out: &mut Vec<$ty>) {
let mut buf = bytes;
while buf.remaining() >= 8 {
let $d = buf.get_u64_le();
out.push($decode);
}
}
fn encode_one($e: &$ty) -> UnknownFieldData {
let $e = *$e;
UnknownFieldData::Fixed64($encode)
}
}
impl PackableCodec for $name {
// Appends the 8 little-endian bytes of the encoded value.
fn encode_packed($e: &$ty, buf: &mut Vec<u8>) {
use bytes::BufMut;
let $e = *$e;
buf.put_u64_le($encode);
}
}
};
}
// Fixed-width scalars. Signed and floating-point values are stored as the raw
// bit pattern of the same width: `as` reinterprets signed integers, and
// `from_bits`/`to_bits` handle floats.
fixed32_codec!(Fixed32, u32, |bits| bits, |val| val);
fixed32_codec!(Sfixed32, i32, |bits| bits as i32, |val| val as u32);
fixed32_codec!(Float, f32, |bits| f32::from_bits(bits), |val| val.to_bits());
fixed64_codec!(Fixed64, u64, |bits| bits, |val| val);
fixed64_codec!(Sfixed64, i64, |bits| bits as i64, |val| val as u64);
fixed64_codec!(Double, f64, |bits| f64::from_bits(bits), |val| val.to_bits());
/// Codec for `string` extensions, stored as a length-delimited UTF-8 record.
///
/// A record whose payload is not valid UTF-8 decodes as `None`. The singular
/// read therefore skips it in favour of an older valid record, and
/// `Repeated<StringCodec>` drops it.
pub struct StringCodec;

impl ExtensionCodec for StringCodec {
    type Value = String;
    type Output = Option<String>;

    /// Returns the newest occurrence of `number` that is a length-delimited,
    /// valid-UTF-8 record.
    fn decode(number: u32, fields: &UnknownFields) -> Option<String> {
        fields
            .iter()
            .rev()
            .filter(|f| f.number == number)
            .find_map(|f| Self::decode_one(&f.data))
    }

    /// Appends `value` as one length-delimited record, reusing its buffer.
    fn encode(number: u32, value: String, fields: &mut UnknownFields) {
        fields.push(UnknownField {
            number,
            data: UnknownFieldData::LengthDelimited(value.into_bytes()),
        });
    }
}

impl SingularCodec for StringCodec {
    type Value = String;

    fn decode_one(data: &UnknownFieldData) -> Option<String> {
        match data {
            // Validate in place before allocating. Invalid payloads are
            // rejected without copying them; only valid UTF-8 is copied into
            // a new `String`.
            UnknownFieldData::LengthDelimited(bytes) => {
                core::str::from_utf8(bytes).ok().map(String::from)
            }
            _ => None,
        }
    }

    // Strings are never packed, so there is nothing to decode.
    fn decode_packed(_bytes: &[u8], _out: &mut Vec<String>) {}

    fn encode_one(value: &String) -> UnknownFieldData {
        UnknownFieldData::LengthDelimited(value.as_bytes().to_vec())
    }
}
/// Codec for `bytes` extensions, stored as a raw length-delimited payload.
pub struct BytesCodec;

impl ExtensionCodec for BytesCodec {
    type Value = Vec<u8>;
    type Output = Option<Vec<u8>>;

    /// The newest length-delimited occurrence of `number` wins. Records of
    /// other wire types are ignored.
    fn decode(number: u32, fields: &UnknownFields) -> Self::Output {
        for field in fields.iter().rev() {
            if field.number != number {
                continue;
            }
            if let Some(payload) = Self::decode_one(&field.data) {
                return Some(payload);
            }
        }
        None
    }

    /// Appends `value` as one length-delimited record without copying it.
    fn encode(number: u32, value: Vec<u8>, fields: &mut UnknownFields) {
        let data = UnknownFieldData::LengthDelimited(value);
        fields.push(UnknownField { number, data });
    }
}

impl SingularCodec for BytesCodec {
    type Value = Vec<u8>;

    fn decode_one(data: &UnknownFieldData) -> Option<Vec<u8>> {
        if let UnknownFieldData::LengthDelimited(bytes) = data {
            Some(bytes.to_vec())
        } else {
            None
        }
    }

    // Byte strings are never packed, so there is nothing to decode.
    fn decode_packed(_bytes: &[u8], _out: &mut Vec<Vec<u8>>) {}

    fn encode_one(value: &Vec<u8>) -> UnknownFieldData {
        UnknownFieldData::LengthDelimited(value.to_vec())
    }
}
/// Codec for message-typed extensions, stored as a length-delimited record
/// containing the encoded message.
pub struct MessageCodec<M>(PhantomData<fn() -> M>);

impl<M: Message + Default> ExtensionCodec for MessageCodec<M> {
    type Value = M;
    type Output = Option<M>;

    /// Merges every length-delimited occurrence of `number`, in wire order,
    /// into a single message. Records of other wire types are ignored.
    /// Returns `None` if there is no such occurrence, or if any of them fails
    /// to decode.
    fn decode(number: u32, fields: &UnknownFields) -> Option<M> {
        let payloads = fields
            .iter()
            .filter(|f| f.number == number)
            .filter_map(|f| match &f.data {
                UnknownFieldData::LengthDelimited(bytes) => Some(bytes),
                _ => None,
            });
        let mut merged: Option<M> = None;
        for bytes in payloads {
            merged
                .get_or_insert_with(M::default)
                .merge_from_slice(bytes)
                .ok()?;
        }
        merged
    }

    /// Appends the encoded `value` as one length-delimited record.
    fn encode(number: u32, value: M, fields: &mut UnknownFields) {
        fields.push(UnknownField {
            number,
            data: Self::encode_one(&value),
        });
    }
}

impl<M: Message + Default> SingularCodec for MessageCodec<M> {
    type Value = M;

    /// Decodes a single length-delimited record as a fresh message.
    fn decode_one(data: &UnknownFieldData) -> Option<M> {
        if let UnknownFieldData::LengthDelimited(bytes) = data {
            let mut msg = M::default();
            match msg.merge_from_slice(bytes) {
                Ok(_) => Some(msg),
                Err(_) => None,
            }
        } else {
            None
        }
    }

    // Messages are never packed, so there is nothing to decode.
    fn decode_packed(_bytes: &[u8], _out: &mut Vec<M>) {}

    fn encode_one(value: &M) -> UnknownFieldData {
        UnknownFieldData::LengthDelimited(value.encode_to_vec())
    }
}
/// Codec for group-encoded extensions. Each value is stored as a `Group`
/// record whose contents are the message's fields.
pub struct GroupCodec<M>(PhantomData<fn() -> M>);

/// Re-serialises the fields captured inside a group record into message
/// bytes that `Message::merge_from_slice` can consume.
fn group_payload(inner: &UnknownFields) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(inner.encoded_len());
    inner.write_to(&mut bytes);
    bytes
}

/// Encodes `value`, then re-parses the bytes into the field list stored
/// inside a `Group` record.
fn message_to_group<M: Message>(value: &M) -> UnknownFields {
    let encoded = value.encode_to_vec();
    UnknownFields::decode_from_slice(&encoded)
        .expect("BUG: re-decoding freshly-encoded message bytes failed")
}

impl<M: Message + Default> ExtensionCodec for GroupCodec<M> {
    type Value = M;
    type Output = Option<M>;

    /// Merges every `Group` occurrence of `number`, in wire order, into one
    /// message. Records of other wire types are ignored. Returns `None` if
    /// there is no such occurrence, or if any of them fails to decode.
    fn decode(number: u32, fields: &UnknownFields) -> Option<M> {
        let groups = fields
            .iter()
            .filter(|f| f.number == number)
            .filter_map(|f| match &f.data {
                UnknownFieldData::Group(inner) => Some(inner),
                _ => None,
            });
        let mut merged: Option<M> = None;
        for inner in groups {
            let target = merged.get_or_insert_with(M::default);
            target.merge_from_slice(&group_payload(inner)).ok()?;
        }
        merged
    }

    /// Appends `value` as one `Group` record.
    fn encode(number: u32, value: M, fields: &mut UnknownFields) {
        fields.push(UnknownField {
            number,
            data: UnknownFieldData::Group(message_to_group(&value)),
        });
    }
}

impl<M: Message + Default> SingularCodec for GroupCodec<M> {
    type Value = M;

    /// Decodes a single `Group` record as a fresh message.
    fn decode_one(data: &UnknownFieldData) -> Option<M> {
        match data {
            UnknownFieldData::Group(inner) => {
                let payload = group_payload(inner);
                let mut msg = M::default();
                msg.merge_from_slice(&payload).ok()?;
                Some(msg)
            }
            _ => None,
        }
    }

    // Groups are never packed, so there is nothing to decode.
    fn decode_packed(_bytes: &[u8], _out: &mut Vec<M>) {}

    fn encode_one(value: &M) -> UnknownFieldData {
        UnknownFieldData::Group(message_to_group(value))
    }
}
/// Codec for unpacked repeated extensions: one record per element.
///
/// Decoding also unpacks length-delimited payloads, so it reads data written
/// in either packed or unpacked form.
pub struct Repeated<C>(PhantomData<fn() -> C>);

impl<C: SingularCodec> ExtensionCodec for Repeated<C> {
    type Value = Vec<C::Value>;
    type Output = Vec<C::Value>;

    fn decode(number: u32, fields: &UnknownFields) -> Vec<C::Value> {
        decode_repeated::<C>(number, fields)
    }

    /// Appends one record per element, in order. An empty `value` appends
    /// nothing.
    fn encode(number: u32, value: Vec<C::Value>, fields: &mut UnknownFields) {
        value
            .iter()
            .map(|element| UnknownField {
                number,
                data: C::encode_one(element),
            })
            .for_each(|record| fields.push(record));
    }
}
/// Codec for packed repeated extensions of packable scalars: all elements are
/// written into a single length-delimited record.
///
/// Decoding accepts both packed and unpacked records.
pub struct PackedRepeated<C>(PhantomData<fn() -> C>);

impl<C: PackableCodec> ExtensionCodec for PackedRepeated<C> {
    type Value = Vec<C::Value>;
    type Output = Vec<C::Value>;

    fn decode(number: u32, fields: &UnknownFields) -> Vec<C::Value> {
        decode_repeated::<C>(number, fields)
    }

    /// Writes every element of `value` into one packed record. An empty
    /// `value` writes nothing, so setting an empty list leaves the extension
    /// absent.
    fn encode(number: u32, value: Vec<C::Value>, fields: &mut UnknownFields) {
        if value.is_empty() {
            return;
        }
        // Each element of this module's packable codecs encodes to at least
        // one byte (a varint, or 4/8 fixed bytes). The element count is
        // therefore a lower bound on the payload size and saves the first
        // few regrowths.
        let mut buf = Vec::with_capacity(value.len());
        for v in &value {
            C::encode_packed(v, &mut buf);
        }
        fields.push(UnknownField {
            number,
            data: UnknownFieldData::LengthDelimited(buf),
        });
    }
}
/// Shared decoder for [`Repeated`] and [`PackedRepeated`].
///
/// Walks every record with field `number` in wire order. A record that
/// decodes as a single element contributes that element. Otherwise, a
/// length-delimited record is unpacked with [`SingularCodec::decode_packed`].
/// Any other record is skipped.
fn decode_repeated<C: SingularCodec>(number: u32, fields: &UnknownFields) -> Vec<C::Value> {
    let mut values = Vec::new();
    for field in fields.iter().filter(|f| f.number == number) {
        match (C::decode_one(&field.data), &field.data) {
            (Some(element), _) => values.push(element),
            (None, UnknownFieldData::LengthDelimited(packed)) => {
                C::decode_packed(packed, &mut values)
            }
            (None, _) => {}
        }
    }
    values
}
}
#[cfg(test)]
mod tests {
use super::codecs::*;
use super::*;
use crate::unknown_fields::{UnknownField, UnknownFieldData};
use alloc::string::{String, ToString};
use alloc::{vec, vec::Vec};
const CARRIER: &str = "test.Carrier";
#[derive(Default)]
struct Carrier {
unknown: UnknownFields,
}
impl ExtensionSet for Carrier {
const PROTO_FQN: &'static str = CARRIER;
fn unknown_fields(&self) -> &UnknownFields {
&self.unknown
}
fn unknown_fields_mut(&mut self) -> &mut UnknownFields {
&mut self.unknown
}
}
fn varint(number: u32, v: u64) -> UnknownField {
UnknownField {
number,
data: UnknownFieldData::Varint(v),
}
}
fn fixed32(number: u32, v: u32) -> UnknownField {
UnknownField {
number,
data: UnknownFieldData::Fixed32(v),
}
}
fn fixed64(number: u32, v: u64) -> UnknownField {
UnknownField {
number,
data: UnknownFieldData::Fixed64(v),
}
}
fn ld(number: u32, data: Vec<u8>) -> UnknownField {
UnknownField {
number,
data: UnknownFieldData::LengthDelimited(data),
}
}
fn group(number: u32, inner: UnknownFields) -> UnknownField {
UnknownField {
number,
data: UnknownFieldData::Group(inner),
}
}
// `Extension::new` is a `const fn`: the handle can be built in a `const` item,
// and its accessors report the arguments it was given.
#[test]
fn extension_const_fn() {
const EXT: Extension<Int32> = Extension::new(50001, CARRIER);
assert_eq!(EXT.number(), 50001);
assert_eq!(EXT.extendee(), CARRIER);
let copy = EXT; assert_eq!(copy.number(), 50001);
}
// The same holds for `Extension::with_default`, here given a const default
// constructor.
#[test]
fn extension_with_default_const_fn() {
const fn seven() -> i32 {
7
}
const E: Extension<Int32> = Extension::with_default(1, CARRIER, seven);
assert_eq!(E.number(), 1);
assert_eq!(E.extendee(), CARRIER);
let copy = E; assert_eq!(copy.number(), 1);
}
// An extension declared for a different message. The tests below use it to
// check extendee-mismatch handling.
const WRONG: Extension<Int32> = Extension::new(1, "other.Message");
// Reading a mismatched extension panics, and the message names both messages.
#[test]
#[should_panic(expected = "extends `other.Message`, not `test.Carrier`")]
fn extension_panics_on_extendee_mismatch() {
let c = Carrier::default();
let _ = c.extension(&WRONG);
}
// Writing a mismatched extension panics.
#[test]
#[should_panic(expected = "extends `other.Message`, not `test.Carrier`")]
fn set_extension_panics_on_extendee_mismatch() {
let mut c = Carrier::default();
c.set_extension(&WRONG, 42);
}
// Clearing a mismatched extension panics.
#[test]
#[should_panic(expected = "extends `other.Message`, not `test.Carrier`")]
fn clear_extension_panics_on_extendee_mismatch() {
let mut c = Carrier::default();
c.clear_extension(&WRONG);
}
// `extension_or_default` also panics on a mismatch.
#[test]
#[should_panic(expected = "extends `other.Message`, not `test.Carrier`")]
fn extension_or_default_panics_on_extendee_mismatch() {
let c = Carrier::default();
let _ = c.extension_or_default(&WRONG);
}
// `has_extension` is the exception: on a mismatch it reports `false` instead
// of panicking, even though a record with the same field number exists.
#[test]
fn has_extension_returns_false_on_extendee_mismatch() {
let mut c = Carrier::default();
c.unknown.push(varint(1, 42));
assert!(!c.has_extension(&WRONG));
const RIGHT: Extension<Int32> = Extension::new(1, CARRIER);
assert!(c.has_extension(&RIGHT));
}
// When nothing is set, `extension` is `None` but `extension_or_default`
// returns the declared default.
#[test]
fn extension_or_default_returns_declared_default_when_absent() {
const fn seven() -> i32 {
7
}
const E: Extension<Int32> = Extension::with_default(1, CARRIER, seven);
let c = Carrier::default();
assert_eq!(c.extension(&E), None);
assert!(!c.has_extension(&E));
assert_eq!(c.extension_or_default(&E), 7);
}
// A value that has been set takes precedence over the declared default.
#[test]
fn extension_or_default_set_value_wins() {
const fn seven() -> i32 {
7
}
const E: Extension<Int32> = Extension::with_default(1, CARRIER, seven);
let mut c = Carrier::default();
c.set_extension(&E, 99);
assert_eq!(c.extension_or_default(&E), 99);
assert_eq!(c.extension(&E), Some(99));
}
// With no declared default, the fallback is the value type's `Default`.
#[test]
fn extension_or_default_falls_back_to_type_default() {
const E: Extension<Int32> = Extension::new(1, CARRIER);
let c = Carrier::default();
assert_eq!(c.extension_or_default(&E), 0);
assert_eq!(c.extension(&E), None);
}
// The default constructor runs on every call, so each result is a separate
// `String` allocation.
#[test]
fn extension_or_default_string_allocates_per_call() {
fn hello() -> String {
String::from("hello")
}
const E: Extension<StringCodec> = Extension::with_default(1, CARRIER, hello);
let c = Carrier::default();
assert_eq!(c.extension_or_default(&E), "hello");
let a = c.extension_or_default(&E);
let b = c.extension_or_default(&E);
assert_eq!(a, b);
assert_ne!(a.as_ptr(), b.as_ptr());
}
// Declared defaults also work for `bytes` extensions.
#[test]
fn extension_or_default_bytes() {
fn blob() -> Vec<u8> {
alloc::vec![0xDE, 0xAD]
}
const E: Extension<BytesCodec> = Extension::with_default(1, CARRIER, blob);
let c = Carrier::default();
assert_eq!(c.extension_or_default(&E), alloc::vec![0xDE, 0xAD]);
}
// An explicitly set zero counts as present, so it wins over the declared
// default.
#[test]
fn extension_or_default_zero_is_present_not_default() {
const fn seven() -> i32 {
7
}
const E: Extension<Int32> = Extension::with_default(1, CARRIER, seven);
let mut c = Carrier::default();
c.set_extension(&E, 0);
assert_eq!(c.extension_or_default(&E), 0);
}
// Set/get round-trips for singular scalars. The test also pins the raw record
// written by `set_extension` for int32, sint32 (zigzag), sfixed32 and bool.
#[test]
fn singular_scalar_roundtrip() {
#[rustfmt::skip]
let int32_cases: &[(i32, u64)] = &[
(0, 0),
(42, 42),
// A negative int32 is written sign-extended to 64 bits.
(-7, (-7_i64) as u64),
];
for &(v, wire) in int32_cases {
let mut c = Carrier::default();
const E: Extension<Int32> = Extension::new(1, CARRIER);
c.set_extension(&E, v);
assert_eq!(
c.unknown.iter().next().unwrap().data,
UnknownFieldData::Varint(wire),
"int32 {v}"
);
assert_eq!(c.extension(&E), Some(v), "int32 {v}");
}
// Zigzag mapping: 0, -1, 1, -2 map to 0, 1, 2, 3.
#[rustfmt::skip]
let sint32_cases: &[(i32, u64)] = &[
(0, 0),
(-1, 1),
(1, 2),
(-2, 3),
];
for &(v, wire) in sint32_cases {
let mut c = Carrier::default();
const E: Extension<Sint32> = Extension::new(1, CARRIER);
c.set_extension(&E, v);
assert_eq!(
c.unknown.iter().next().unwrap().data,
UnknownFieldData::Varint(wire),
"sint32 {v}"
);
assert_eq!(c.extension(&E), Some(v), "sint32 {v}");
}
// sfixed32 stores the two's-complement bit pattern: -1 becomes u32::MAX.
let mut c = Carrier::default();
const SF32: Extension<Sfixed32> = Extension::new(1, CARRIER);
c.set_extension(&SF32, -1);
assert_eq!(
c.unknown.iter().next().unwrap().data,
UnknownFieldData::Fixed32(u32::MAX)
);
assert_eq!(c.extension(&SF32), Some(-1));
// bool decodes any non-zero varint as `true` and encodes `true` as 1.
const B: Extension<Bool> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(varint(1, 7));
assert_eq!(c.extension(&B), Some(true));
let mut c = Carrier::default();
c.set_extension(&B, true);
assert_eq!(
c.unknown.iter().next().unwrap().data,
UnknownFieldData::Varint(1)
);
// Float and double round-trip through their bit patterns.
const F: Extension<Float> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.set_extension(&F, 1.5_f32);
assert_eq!(c.extension(&F), Some(1.5_f32));
const D: Extension<Double> = Extension::new(2, CARRIER);
c.set_extension(&D, -0.25_f64);
assert_eq!(c.extension(&D), Some(-0.25_f64));
// uint32 decoding keeps only the low 32 bits of a wider varint.
const U: Extension<Uint32> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(varint(1, 0x1_0000_002A)); assert_eq!(c.extension(&U), Some(0x0000_002A));
}
// When a singular field occurs more than once, the last record wins.
#[test]
fn singular_scalar_last_wins() {
const E: Extension<Int32> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(varint(1, 5));
c.unknown.push(varint(1, 7));
assert_eq!(c.extension(&E), Some(7));
}
// Records of the wrong wire type are skipped. A trailing length-delimited
// record does not hide an earlier varint, and on its own it reads as `None`.
#[test]
fn singular_scalar_wrong_wire_type_skipped() {
const E: Extension<Int32> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(varint(1, 5));
c.unknown.push(ld(1, vec![0x00])); assert_eq!(c.extension(&E), Some(5));
let mut c = Carrier::default();
c.unknown.push(ld(1, vec![0x00]));
assert_eq!(c.extension(&E), None);
}
// An extension that is absent, or present only under another field number,
// reads as `None`.
#[test]
fn singular_absent_returns_none() {
const E: Extension<Int32> = Extension::new(1, CARRIER);
let c = Carrier::default();
assert_eq!(c.extension(&E), None);
let mut c = Carrier::default();
c.unknown.push(varint(2, 5));
assert_eq!(c.extension(&E), None);
}
// Setting zero still writes a record, so presence is explicit.
#[test]
fn explicit_presence_with_zero_value() {
const E: Extension<Int32> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
assert!(!c.has_extension(&E));
c.set_extension(&E, 0);
assert!(c.has_extension(&E));
assert_eq!(c.extension(&E), Some(0));
}
// Setting again replaces the earlier record instead of appending another.
#[test]
fn set_clears_prior_occurrences() {
const E: Extension<Int32> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.set_extension(&E, 5);
c.set_extension(&E, 7);
assert_eq!(c.unknown.iter().filter(|f| f.number == 1).count(), 1);
assert_eq!(c.extension(&E), Some(7));
}
// Setting one extension leaves records for other field numbers alone.
#[test]
fn set_preserves_other_fields() {
const E: Extension<Int32> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(varint(2, 99));
c.set_extension(&E, 5);
assert_eq!(c.unknown.len(), 2);
assert!(c.unknown.iter().any(|f| f.number == 2));
}
// Clearing removes only the extension's own records; the field-2 record
// remains.
#[test]
fn clear_extension() {
const E: Extension<Int32> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.set_extension(&E, 5);
c.unknown.push(varint(2, 99));
c.clear_extension(&E);
assert!(!c.has_extension(&E));
assert_eq!(c.unknown.len(), 1); }
// A string set via the API reads back unchanged.
#[test]
fn string_roundtrip() {
const E: Extension<StringCodec> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.set_extension(&E, "hello".to_string());
assert_eq!(c.extension(&E), Some("hello".to_string()));
}
// A payload that is not valid UTF-8 does not decode as a string.
#[test]
fn string_invalid_utf8_is_none() {
const E: Extension<StringCodec> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(ld(1, vec![0xFF, 0xFE]));
assert_eq!(c.extension(&E), None);
}
// Arbitrary bytes round-trip unchanged.
#[test]
fn bytes_roundtrip() {
const E: Extension<BytesCodec> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.set_extension(&E, vec![0xDE, 0xAD, 0xBE, 0xEF]);
assert_eq!(c.extension(&E), Some(vec![0xDE, 0xAD, 0xBE, 0xEF]));
}
// Unpacked encoding: each varint record contributes one element.
#[test]
fn repeated_scalar_unpacked() {
const E: Extension<Repeated<Int32>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(varint(1, 1));
c.unknown.push(varint(1, 2));
c.unknown.push(varint(1, 3));
assert_eq!(c.extension(&E), vec![1, 2, 3]);
}
// Packed encoding: one length-delimited payload holds the varints 1, 2 and
// 300 (0xAC 0x02).
#[test]
fn repeated_scalar_packed() {
const E: Extension<Repeated<Int32>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(ld(1, vec![0x01, 0x02, 0xAC, 0x02]));
assert_eq!(c.extension(&E), vec![1, 2, 300]);
}
// Packed and unpacked records for the same field are concatenated in wire
// order: varint 1, packed [2, 3], varint 4.
#[test]
fn repeated_scalar_mixed_packed_unpacked() {
const E: Extension<Repeated<Int32>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(varint(1, 1));
c.unknown.push(ld(1, vec![0x02, 0x03])); c.unknown.push(varint(1, 4));
assert_eq!(c.extension(&E), vec![1, 2, 3, 4]);
}
// A packed fixed32 payload is read as consecutive little-endian u32 values.
#[test]
fn repeated_scalar_fixed32_packed() {
const E: Extension<Repeated<Fixed32>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(ld(1, vec![1, 0, 0, 0, 2, 0, 0, 0]));
assert_eq!(c.extension(&E), vec![1_u32, 2_u32]);
}
// A fixed32 record in an int32 list is neither a varint nor a packed
// payload, so it is skipped.
#[test]
fn repeated_scalar_wrong_wire_type_skipped() {
const E: Extension<Repeated<Int32>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(varint(1, 1));
c.unknown.push(fixed32(1, 0xDEAD)); c.unknown.push(varint(1, 2));
assert_eq!(c.extension(&E), vec![1, 2]);
}
// `Repeated` writes one record per element, and the list reads back
// unchanged.
#[test]
fn repeated_scalar_set_roundtrip() {
const E: Extension<Repeated<Sint32>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.set_extension(&E, vec![-1, 0, 1, -100]);
assert_eq!(c.extension(&E), vec![-1, 0, 1, -100]);
assert_eq!(c.unknown.len(), 4);
}
// Repeated strings round-trip in order.
#[test]
fn repeated_string() {
const E: Extension<Repeated<StringCodec>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.set_extension(&E, vec!["a".to_string(), "b".to_string()]);
assert_eq!(c.extension(&E), vec!["a".to_string(), "b".to_string()]);
}
// An element that is not valid UTF-8 (0xFF) is dropped, and its neighbours
// are kept.
#[test]
fn repeated_string_malformed_element_skipped() {
const E: Extension<Repeated<StringCodec>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(ld(1, b"ok".to_vec()));
c.unknown.push(ld(1, vec![0xFF])); c.unknown.push(ld(1, b"also ok".to_vec()));
assert_eq!(
c.extension(&E),
vec!["ok".to_string(), "also ok".to_string()]
);
}
// Setting an empty list writes no records, so the extension is absent.
#[test]
fn repeated_empty_set_not_present() {
const E: Extension<Repeated<Int32>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.set_extension(&E, vec![]);
assert!(!c.has_extension(&E));
assert_eq!(c.extension(&E), Vec::<i32>::new());
}
// Re-setting a packed list replaces the earlier packed record.
#[test]
fn packed_repeated_set_twice_one_record() {
const E: Extension<PackedRepeated<Int32>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.set_extension(&E, vec![1, 2]);
c.set_extension(&E, vec![3, 4, 5]);
assert_eq!(c.unknown.len(), 1);
assert_eq!(c.extension(&E), vec![3, 4, 5]);
}
// Pins the packed wire bytes for [1, 2, 300].
#[test]
fn packed_repeated_encode_one_record() {
const E: Extension<PackedRepeated<Int32>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.set_extension(&E, vec![1, 2, 300]);
assert_eq!(c.unknown.len(), 1);
assert_eq!(
c.unknown.iter().next().unwrap().data,
UnknownFieldData::LengthDelimited(vec![0x01, 0x02, 0xAC, 0x02])
);
assert_eq!(c.extension(&E), vec![1, 2, 300]);
}
// A packed extension still reads unpacked records.
#[test]
fn packed_repeated_decode_accepts_unpacked() {
const E: Extension<PackedRepeated<Int32>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(varint(1, 1));
c.unknown.push(varint(1, 2));
assert_eq!(c.extension(&E), vec![1, 2]);
}
// Setting an empty packed list writes no record at all.
#[test]
fn packed_repeated_empty_not_present() {
const E: Extension<PackedRepeated<Int32>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.set_extension(&E, vec![]);
assert!(!c.has_extension(&E));
}
// Packed sfixed64 values, including extremes, round-trip through one record.
#[test]
fn packed_repeated_fixed64_roundtrip() {
const E: Extension<PackedRepeated<Sfixed64>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.set_extension(&E, vec![-1_i64, 0, i64::MAX]);
assert_eq!(c.unknown.len(), 1);
assert_eq!(c.extension(&E), vec![-1_i64, 0, i64::MAX]);
}
// Hand-written test message with two int32 fields: `a` (field 1) and `b`
// (field 2). Each is written only when non-zero. Fields that `merge_field`
// does not recognise are skipped and discarded, so `unknown` is written back
// on encode but never filled in by decoding.
#[derive(Clone, Default, PartialEq, Debug)]
struct TestMsg {
a: i32,
b: i32,
unknown: UnknownFields,
}
// SAFETY(review): the contract of `crate::DefaultInstance` is not visible in
// this file. Presumably it requires a `'static` instance equal to
// `TestMsg::default()`, which `OnceBox::get_or_init` supplies here. Confirm
// against the trait's documentation.
unsafe impl crate::DefaultInstance for TestMsg {
fn default_instance() -> &'static Self {
static INST: crate::__private::OnceBox<TestMsg> = crate::__private::OnceBox::new();
INST.get_or_init(|| alloc::boxed::Box::new(TestMsg::default()))
}
}
impl crate::Message for TestMsg {
// Size of what `write_to` emits: for each non-zero field, a 1-byte tag
// (the tags 1 << 3 and 2 << 3 are both below 128) plus the sign-extended
// varint; then the unknown fields.
fn compute_size(&self) -> u32 {
let mut n = 0;
if self.a != 0 {
n += 1 + crate::encoding::varint_len(self.a as i64 as u64);
}
if self.b != 0 {
n += 1 + crate::encoding::varint_len(self.b as i64 as u64);
}
n += self.unknown.encoded_len();
n as u32
}
// Writes field 1, then field 2 (each only when non-zero), then any unknown
// fields.
fn write_to(&self, buf: &mut impl bytes::BufMut) {
use crate::encoding::encode_varint;
if self.a != 0 {
encode_varint(1 << 3, buf);
encode_varint(self.a as i64 as u64, buf);
}
if self.b != 0 {
encode_varint(2 << 3, buf);
encode_varint(self.b as i64 as u64, buf);
}
self.unknown.write_to(buf);
}
// Decodes fields 1 and 2 as int32. Any other field is skipped, not stored
// in `unknown`.
fn merge_field(
&mut self,
tag: crate::encoding::Tag,
buf: &mut impl bytes::Buf,
_depth: u32,
) -> Result<(), crate::DecodeError> {
match tag.field_number() {
1 => self.a = crate::types::decode_int32(buf)?,
2 => self.b = crate::types::decode_int32(buf)?,
_ => crate::encoding::skip_field(tag, buf)?,
}
Ok(())
}
// There is no size cache; the size is recomputed on every call.
fn cached_size(&self) -> u32 {
self.compute_size()
}
fn clear(&mut self) {
*self = Self::default();
}
}
// A single record `08 05` (field 1, varint 5) decodes to a = 5 with b unset.
#[test]
fn message_single_record() {
const E: Extension<MessageCodec<TestMsg>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(ld(1, vec![0x08, 0x05]));
let got = c.extension(&E).expect("decoded");
assert_eq!(got.a, 5);
assert_eq!(got.b, 0);
}
// A message split across two records (a = 5 in one, b = 7 in the other) is
// merged into one value.
#[test]
fn message_split_records_merge() {
const E: Extension<MessageCodec<TestMsg>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(ld(1, vec![0x08, 0x05])); c.unknown.push(ld(1, vec![0x10, 0x07])); let got = c.extension(&E).expect("decoded");
assert_eq!(got.a, 5);
assert_eq!(got.b, 7);
}
// No record means no message.
#[test]
fn message_absent_is_none() {
const E: Extension<MessageCodec<TestMsg>> = Extension::new(1, CARRIER);
let c = Carrier::default();
assert!(c.extension(&E).is_none());
}
// A truncated varint (`08 80`) makes the whole message decode fail.
#[test]
fn message_malformed_is_none() {
const E: Extension<MessageCodec<TestMsg>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(ld(1, vec![0x08, 0x80]));
assert!(c.extension(&E).is_none());
}
// A message set through the API, including a negative field, reads back
// unchanged.
#[test]
fn message_set_get_roundtrip() {
const E: Extension<MessageCodec<TestMsg>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.set_extension(
&E,
TestMsg {
a: 3,
b: -1,
unknown: UnknownFields::new(),
},
);
let got = c.extension(&E).expect("decoded");
assert_eq!(got.a, 3);
assert_eq!(got.b, -1);
}
// Unknown fields inside the message do not break decoding, and when encoding,
// the message's own unknown fields (field 99 = 42, bytes `98 06 2A`) are
// carried into the extension payload.
#[test]
fn message_unknown_fields_survive_roundtrip() {
const E: Extension<MessageCodec<TestMsg>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(ld(1, vec![0x08, 0x05, 0x98, 0x06, 0x2A]));
let got = c.extension(&E).expect("decoded");
assert_eq!(got.a, 5);
let mut inner_unk = UnknownFields::new();
inner_unk.push(varint(99, 42));
let msg = TestMsg {
a: 5,
b: 0,
unknown: inner_unk,
};
let mut c = Carrier::default();
c.set_extension(&E, msg);
use crate::Message;
let mut roundtrip = TestMsg::default();
if let UnknownFieldData::LengthDelimited(bytes) = &c.unknown.iter().next().unwrap().data {
assert!(bytes.len() > 2, "payload includes unknown-field bytes");
assert!(bytes.windows(3).any(|w| w == [0x98, 0x06, 0x2A]));
roundtrip.merge_from_slice(bytes).unwrap();
assert_eq!(roundtrip.a, 5);
} else {
panic!("expected LengthDelimited");
}
}
// For repeated messages, each record is a separate element; records are not
// merged.
#[test]
fn repeated_message() {
const E: Extension<Repeated<MessageCodec<TestMsg>>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(ld(1, vec![0x08, 0x01])); c.unknown.push(ld(1, vec![0x08, 0x02])); let got = c.extension(&E);
assert_eq!(got.len(), 2);
assert_eq!(got[0].a, 1);
assert_eq!(got[1].a, 2);
}
// A malformed element (`08 80`) is dropped, and the elements around it are
// kept.
#[test]
fn repeated_message_malformed_element_skipped() {
const E: Extension<Repeated<MessageCodec<TestMsg>>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(ld(1, vec![0x08, 0x01])); c.unknown.push(ld(1, vec![0x08, 0x80])); c.unknown.push(ld(1, vec![0x08, 0x03])); let got = c.extension(&E);
assert_eq!(got.len(), 2);
assert_eq!(got[0].a, 1);
assert_eq!(got[1].a, 3);
}
// Group contents containing only field 1 (`a`) with value `v`.
fn inner_a(v: u64) -> UnknownFields {
let mut inner = UnknownFields::new();
inner.push(varint(1, v));
inner
}
// A single group record decodes to a message with a = 5.
#[test]
fn group_single_record() {
const E: Extension<GroupCodec<TestMsg>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(group(1, inner_a(5)));
let got = c.extension(&E).expect("decoded");
assert_eq!(got.a, 5);
assert_eq!(got.b, 0);
}
// Two group records for the same field are merged into one message.
#[test]
fn group_split_records_merge() {
const E: Extension<GroupCodec<TestMsg>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(group(1, inner_a(5)));
let mut inner_b = UnknownFields::new();
inner_b.push(varint(2, 7));
c.unknown.push(group(1, inner_b));
let got = c.extension(&E).expect("decoded");
assert_eq!(got.a, 5);
assert_eq!(got.b, 7);
}
// No record means no group.
#[test]
fn group_absent_is_none() {
const E: Extension<GroupCodec<TestMsg>> = Extension::new(1, CARRIER);
let c = Carrier::default();
assert!(c.extension(&E).is_none());
}
// The group codec ignores length-delimited records, even when they hold a
// valid message encoding.
#[test]
fn group_wrong_wire_type_is_none() {
const E: Extension<GroupCodec<TestMsg>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(ld(1, vec![0x08, 0x05]));
assert!(c.extension(&E).is_none());
}
// Setting through the group codec writes a `Group` record, and the message
// reads back unchanged.
#[test]
fn group_set_get_roundtrip() {
const E: Extension<GroupCodec<TestMsg>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.set_extension(
&E,
TestMsg {
a: 3,
b: -1,
unknown: UnknownFields::new(),
},
);
match &c.unknown.iter().next().unwrap().data {
UnknownFieldData::Group(_) => {}
other => panic!("expected Group, got {other:?}"),
}
let got = c.extension(&E).expect("decoded");
assert_eq!(got.a, 3);
assert_eq!(got.b, -1);
}
// An empty message still writes a (contentless) group record, so the
// extension is present and decodes to the default message.
#[test]
fn group_set_empty_message() {
const E: Extension<GroupCodec<TestMsg>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.set_extension(&E, TestMsg::default());
assert!(c.has_extension(&E));
let got = c.extension(&E).expect("decoded");
assert_eq!(got, TestMsg::default());
}
// For repeated groups, each group record is a separate element.
#[test]
fn repeated_group() {
const E: Extension<Repeated<GroupCodec<TestMsg>>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(group(1, inner_a(1)));
c.unknown.push(group(1, inner_a(2)));
let got = c.extension(&E);
assert_eq!(got.len(), 2);
assert_eq!(got[0].a, 1);
assert_eq!(got[1].a, 2);
}
// Repeated groups are written as one `Group` record per element and read
// back in order.
#[test]
fn repeated_group_set_get_roundtrip() {
const E: Extension<Repeated<GroupCodec<TestMsg>>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
let msgs = vec![
TestMsg {
a: 10,
b: 0,
unknown: UnknownFields::new(),
},
TestMsg {
a: 0,
b: 20,
unknown: UnknownFields::new(),
},
];
c.set_extension(&E, msgs.clone());
assert_eq!(c.unknown.len(), 2);
for f in c.unknown.iter() {
assert!(matches!(f.data, UnknownFieldData::Group(_)));
}
assert_eq!(c.extension(&E), msgs);
}
// Round-trips boundary values through the scalar codecs that the earlier
// tests do not cover.
#[test]
fn scalar_roundtrip_table() {
// rt!(Codec, value): set `value` on a fresh carrier, then expect it back.
macro_rules! rt {
($codec:ty, $v:expr) => {{
const E: Extension<$codec> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.set_extension(&E, $v);
assert_eq!(c.extension(&E), Some($v), stringify!($codec));
}};
}
rt!(Int64, -1_i64);
rt!(Uint64, u64::MAX);
rt!(Sint64, i64::MIN);
rt!(Fixed32, u32::MAX);
rt!(Fixed64, u64::MAX);
rt!(Sfixed64, -1_i64);
rt!(EnumI32, 42);
}
// In a packed fixed64 payload, the trailing 3 bytes (less than one element)
// are ignored.
#[test]
fn fixed64_packed_decode_partial_tail_ignored() {
const E: Extension<Repeated<Fixed64>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
let mut payload = 42_u64.to_le_bytes().to_vec();
payload.extend_from_slice(&[0x01, 0x02, 0x03]);
c.unknown.push(ld(1, payload));
assert_eq!(c.extension(&E), vec![42_u64]);
}
// Packed varint decoding stops at the truncated varint `0x80` and keeps the
// values decoded before it.
#[test]
fn varint_packed_decode_malformed_tail_stops() {
const E: Extension<Repeated<Int32>> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(ld(1, vec![0x01, 0x80]));
assert_eq!(c.extension(&E), vec![1]);
}
// fixed64 reads `Fixed64` records and ignores a `Fixed32` record.
#[test]
fn fixed64_wrong_wire_type() {
const E: Extension<Fixed64> = Extension::new(1, CARRIER);
let mut c = Carrier::default();
c.unknown.push(fixed64(1, 0xDEADBEEF));
assert_eq!(c.extension(&E), Some(0xDEADBEEF_u64));
let mut c = Carrier::default();
c.unknown.push(fixed32(1, 0)); assert_eq!(c.extension(&E), None);
}
}