use std::any::Any;
use std::fmt;
use std::sync::OnceLock;
use buffa::Message;
use buffa::view::{MessageView, OwnedView};
use bytes::Bytes;
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::codec::{CodecFormat, decode_json, decode_proto, encode_json, encode_proto};
use crate::error::ConnectError;
/// A type-erased message that a [`Payload`] can store, either as a replacement
/// installed by the caller or as a cached decode of the wire bytes.
///
/// Implementors must be `Send + Sync + 'static` so boxed messages can be shared
/// across threads and downcast through [`Any`].
pub trait AnyMessage: Send + Sync + 'static {
    /// Borrows the concrete message as `&dyn Any`, enabling `downcast_ref`.
    fn as_any(&self) -> &dyn Any;

    /// Mutably borrows the concrete message as `&mut dyn Any`, enabling `downcast_mut`.
    fn as_any_mut(&mut self) -> &mut dyn Any;

    /// Converts the boxed message into `Box<dyn Any>`, enabling an owned `downcast`.
    fn into_any(self: Box<Self>) -> Box<dyn Any>;

    /// Encodes the message with the codec selected by `format`.
    fn encode(&self, format: CodecFormat) -> Result<Bytes, ConnectError>;

    /// Returns the concrete type's name for diagnostics such as type-mismatch errors.
    ///
    /// The default uses [`std::any::type_name`], whose exact output is not
    /// guaranteed to be stable, so only use it for human-readable messages.
    fn type_name(&self) -> &'static str {
        std::any::type_name::<Self>()
    }
}
/// Any `'static` type that is a proto [`Message`] and serde [`Serialize`] is an
/// [`AnyMessage`]: proto encoding goes through `encode_proto`, JSON through
/// `encode_json`.
impl<T> AnyMessage for T
where
    T: Message + Serialize + 'static,
{
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self as &mut dyn Any
    }

    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self as Box<dyn Any>
    }

    fn encode(&self, format: CodecFormat) -> Result<Bytes, ConnectError> {
        // Exhaustive match so a new codec variant forces a decision here.
        match format {
            CodecFormat::Json => encode_json(self),
            CodecFormat::Proto => encode_proto(self),
        }
    }
}
/// A message body held as encoded wire bytes.
///
/// The body can additionally carry a lazily decoded copy of those bytes and a
/// caller-supplied replacement message.
pub struct Payload {
    /// The original wire bytes; never modified after construction.
    bytes: Bytes,
    /// Codec of `bytes`, also used when re-encoding a replacement.
    format: CodecFormat,
    /// First decode of `bytes` made by `message`, reused by later calls.
    decoded: OnceLock<Box<dyn AnyMessage>>,
    /// Message installed by `set_message`; when present it takes precedence.
    replaced: Option<Box<dyn AnyMessage>>,
}
impl Payload {
    /// Wraps encoded wire `bytes` that use the codec `format`.
    ///
    /// Nothing is decoded here. Decoding happens on demand in
    /// [`Payload::message`], [`Payload::take_message`], or [`Payload::view`].
    pub fn new(bytes: Bytes, format: CodecFormat) -> Self {
        Self {
            bytes,
            format,
            decoded: OnceLock::new(),
            replaced: None,
        }
    }

    /// Returns the original wire bytes passed to [`Payload::new`].
    ///
    /// [`Payload::set_message`] leaves these untouched. Use
    /// [`Payload::encoded`] to get bytes that reflect a replacement.
    pub fn bytes(&self) -> &Bytes {
        &self.bytes
    }

    /// Returns the codec the wire bytes are encoded with.
    pub fn format(&self) -> CodecFormat {
        self.format
    }

    /// Borrows the payload as a decoded `M`.
    ///
    /// A replacement installed with [`Payload::set_message`] takes precedence.
    /// Otherwise the wire bytes are decoded with the payload's codec on the
    /// first call and the result is cached. Later calls return the same cached
    /// value without decoding again.
    ///
    /// # Errors
    ///
    /// - The codec's error if the wire bytes cannot be decoded as `M`.
    /// - An internal error if the replacement is not an `M`.
    /// - An internal error if an earlier call cached a value of a different type.
    pub fn message<M>(&self) -> Result<&M, ConnectError>
    where
        M: Message + Serialize + DeserializeOwned + 'static,
    {
        if let Some(replaced) = &self.replaced {
            return replaced
                .as_any()
                .downcast_ref::<M>()
                .ok_or_else(|| Self::replacement_mismatch::<M>(replaced.type_name()));
        }
        let cached = match self.decoded.get() {
            Some(cached) => cached,
            None => {
                let fresh: M = Self::decode_wire(&self.bytes, self.format)?;
                // `set` only fails if a concurrent caller filled the cell first.
                // That value wins; the downcast below reports it if it isn't an `M`.
                let _ = self.decoded.set(Box::new(fresh));
                self.decoded.get().expect("decoded cell populated above")
            }
        };
        cached
            .as_any()
            .downcast_ref::<M>()
            .ok_or_else(|| Self::decoded_mismatch::<M>(cached.type_name()))
    }

    /// Consumes the payload and returns it as an owned `M`.
    ///
    /// Sources are tried in this order:
    /// 1. The replacement from [`Payload::set_message`].
    /// 2. A value cached by an earlier [`Payload::message`] call, moved out without re-decoding.
    /// 3. A fresh decode of the wire bytes.
    ///
    /// # Errors
    ///
    /// - The codec's error if a fresh decode is needed and the bytes cannot be decoded as `M`.
    /// - An internal error if the replacement or the cached value is not an `M`.
    pub fn take_message<M>(self) -> Result<M, ConnectError>
    where
        M: Message + DeserializeOwned + 'static,
    {
        let Self {
            bytes,
            format,
            decoded,
            replaced,
        } = self;
        if let Some(replaced) = replaced {
            let actual = replaced.type_name();
            return replaced
                .into_any()
                .downcast::<M>()
                .map(|boxed| *boxed)
                .map_err(|_| Self::replacement_mismatch::<M>(actual));
        }
        if let Some(cached) = decoded.into_inner() {
            let actual = cached.type_name();
            return cached
                .into_any()
                .downcast::<M>()
                .map(|boxed| *boxed)
                .map_err(|_| Self::decoded_mismatch::<M>(actual));
        }
        Self::decode_wire(&bytes, format)
    }

    /// Decodes a zero-copy view of the payload.
    ///
    /// Without a replacement, the view borrows from the original wire bytes,
    /// which must be proto-encoded. With a replacement installed, the view
    /// works regardless of [`Payload::format`]: on every call the replacement
    /// is re-encoded as proto and the view borrows from that new buffer.
    ///
    /// # Errors
    ///
    /// - An internal error if there is no replacement and the wire format is not proto.
    /// - The encoder's error if the replacement cannot be re-encoded.
    /// - An internal error if the re-encoded replacement does not decode as `V`.
    /// - An invalid-argument error if the original wire bytes do not decode as `V`.
    pub fn view<V>(&self) -> Result<OwnedView<V>, ConnectError>
    where
        V: MessageView<'static>,
    {
        if let Some(replaced) = &self.replaced {
            let bytes = replaced.encode(CodecFormat::Proto)?;
            return OwnedView::decode(bytes).map_err(|e| {
                ConnectError::internal(format!("failed to decode replacement as view: {e}"))
            });
        }
        if self.format != CodecFormat::Proto {
            return Err(ConnectError::internal(
                "Payload::view requires a proto-encoded wire; use Payload::message for JSON",
            ));
        }
        OwnedView::decode(self.bytes.clone()).map_err(|e| {
            ConnectError::invalid_argument(format!("failed to decode payload as view: {e}"))
        })
    }

    /// Installs `message` as a replacement for the payload's contents.
    ///
    /// The replacement takes precedence in [`Payload::message`],
    /// [`Payload::take_message`], [`Payload::view`] and [`Payload::encoded`].
    /// Calling this again supersedes the previous replacement. The original
    /// wire bytes are still available through [`Payload::bytes`].
    pub fn set_message<M>(&mut self, message: M)
    where
        M: AnyMessage,
    {
        self.replaced = Some(Box::new(message));
        // The cached decode can never be observed while a replacement exists,
        // so release it instead of holding it for the payload's lifetime.
        self.decoded.take();
    }

    /// Returns the bytes that represent the payload's current contents.
    ///
    /// With a replacement installed, the replacement is encoded with
    /// [`Payload::format`]. Otherwise this returns the original wire bytes,
    /// sharing the same underlying buffer.
    ///
    /// # Errors
    ///
    /// Propagates the encoder's error if the replacement cannot be encoded.
    pub fn encoded(&self) -> Result<Bytes, ConnectError> {
        match &self.replaced {
            Some(r) => r.encode(self.format),
            None => Ok(self.bytes.clone()),
        }
    }

    /// Decodes `bytes` as an `M` using the codec selected by `format`.
    fn decode_wire<M>(bytes: &Bytes, format: CodecFormat) -> Result<M, ConnectError>
    where
        M: Message + DeserializeOwned + 'static,
    {
        match format {
            CodecFormat::Proto => decode_proto(bytes),
            CodecFormat::Json => decode_json(bytes),
        }
    }

    /// Builds the internal error for a replacement message that is not an `M`.
    fn replacement_mismatch<M>(actual: &str) -> ConnectError {
        ConnectError::internal(format!(
            "payload replacement is a {}, not a {}",
            actual,
            std::any::type_name::<M>()
        ))
    }

    /// Builds the internal error for a cached decode that is not an `M`.
    fn decoded_mismatch<M>(actual: &str) -> ConnectError {
        ConnectError::internal(format!(
            "payload was previously decoded as a {}, not a {}",
            actual,
            std::any::type_name::<M>()
        ))
    }
}
/// Debug output is deliberately a summary rather than the body.
///
/// It shows the wire length, the codec, and whether a decoded copy or a
/// replacement is present. The body itself is never printed.
impl fmt::Debug for Payload {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            bytes,
            format,
            decoded,
            replaced,
        } = self;
        let mut out = f.debug_struct("Payload");
        out.field("len", &bytes.len());
        out.field("format", format);
        out.field("decoded", &decoded.get().is_some());
        out.field("replaced", &replaced.is_some());
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `Payload`: lazy decoding and caching, replacement via
    // `set_message`, zero-copy views, type-mismatch errors, and thread safety.
    use super::*;
    use buffa_types::google::protobuf::__buffa::view::StringValueView;
    use buffa_types::google::protobuf::StringValue;

    /// Builds a proto-encoded payload holding a `StringValue` with `value`.
    fn proto_payload(value: &str) -> Payload {
        let msg = StringValue {
            value: value.into(),
            ..Default::default()
        };
        Payload::new(encode_proto(&msg).unwrap(), CodecFormat::Proto)
    }

    #[test]
    fn message_decodes_and_caches() {
        let p = proto_payload("hello");
        let m1: &StringValue = p.message().unwrap();
        assert_eq!(m1.value, "hello");
        let m2: &StringValue = p.message().unwrap();
        assert!(std::ptr::eq(m1, m2), "second call should hit the cache");
    }

    #[test]
    fn message_decodes_json() {
        let bytes = encode_json(&StringValue {
            value: "json".into(),
            ..Default::default()
        })
        .unwrap();
        let p = Payload::new(bytes, CodecFormat::Json);
        let m: &StringValue = p.message().unwrap();
        assert_eq!(m.value, "json");
    }

    #[test]
    fn view_zero_copy_proto() {
        let p = proto_payload("zero copy");
        let v = p.view::<StringValueView>().unwrap();
        assert_eq!(v.value, "zero copy");
        let value_ptr = v.value.as_ptr() as usize;
        let bytes_range =
            p.bytes().as_ptr() as usize..p.bytes().as_ptr() as usize + p.bytes().len();
        assert!(
            bytes_range.contains(&value_ptr),
            "view should borrow from the payload's wire bytes"
        );
    }

    #[test]
    fn view_errors_on_json() {
        let bytes = encode_json(&StringValue {
            value: "x".into(),
            ..Default::default()
        })
        .unwrap();
        let p = Payload::new(bytes, CodecFormat::Json);
        let err = p.view::<StringValueView>().unwrap_err();
        assert!(
            err.message
                .as_deref()
                .unwrap_or_default()
                .contains("requires a proto-encoded wire"),
            "{err:?}"
        );
    }

    #[test]
    fn set_message_round_trips() {
        let mut p = proto_payload("before");
        p.set_message(StringValue {
            value: "after".into(),
            ..Default::default()
        });
        let m: &StringValue = p.message().unwrap();
        assert_eq!(m.value, "after");
        let v = p.view::<StringValueView>().unwrap();
        assert_eq!(v.value, "after");
        let encoded = p.encoded().unwrap();
        let rt: StringValue = decode_proto(&encoded).unwrap();
        assert_eq!(rt.value, "after");
        let orig: StringValue = decode_proto(p.bytes()).unwrap();
        assert_eq!(orig.value, "before");
    }

    #[test]
    fn set_message_round_trips_json_format() {
        let bytes = encode_json(&StringValue {
            value: "before".into(),
            ..Default::default()
        })
        .unwrap();
        let mut p = Payload::new(bytes, CodecFormat::Json);
        p.set_message(StringValue {
            value: "after".into(),
            ..Default::default()
        });
        let encoded = p.encoded().unwrap();
        let rt: StringValue = decode_json(&encoded).unwrap();
        assert_eq!(rt.value, "after");
    }

    /// A cached decode is discarded once a replacement is installed, so a
    /// different type can be read back without a "previously decoded" error.
    #[test]
    fn set_message_drops_decoded_cache() {
        use buffa_types::google::protobuf::Int32Value;
        let mut p = proto_payload("x");
        let _: &StringValue = p.message().unwrap();
        let before = format!("{p:?}");
        assert!(before.contains("decoded: true"), "{before}");
        p.set_message(Int32Value {
            value: 7,
            ..Default::default()
        });
        let after = format!("{p:?}");
        assert!(after.contains("decoded: false"), "{after}");
        assert!(after.contains("replaced: true"), "{after}");
        let m: &Int32Value = p.message().unwrap();
        assert_eq!(m.value, 7);
    }

    #[test]
    fn encoded_without_replacement_returns_original() {
        let p = proto_payload("x");
        assert!(std::ptr::eq(
            p.encoded().unwrap().as_ptr(),
            p.bytes().as_ptr()
        ));
    }

    #[test]
    fn message_wrong_type_errors() {
        use buffa_types::google::protobuf::Int32Value;
        let p = proto_payload("x");
        let _: &StringValue = p.message().unwrap();
        let err = p.message::<Int32Value>().unwrap_err();
        let msg = err.message.as_deref().unwrap_or_default();
        assert!(msg.contains("previously decoded as a"), "{err:?}");
        assert!(msg.contains("StringValue"), "{err:?}");
        assert!(msg.contains("Int32Value"), "{err:?}");
    }

    #[test]
    fn message_decode_error_is_invalid_argument() {
        use crate::ErrorCode;
        let p = Payload::new(Bytes::from_static(&[0xff, 0xff, 0xff]), CodecFormat::Proto);
        let err = p.message::<StringValue>().unwrap_err();
        assert_eq!(err.code, ErrorCode::InvalidArgument, "{err:?}");
    }

    #[test]
    fn message_replacement_wrong_type_errors() {
        use buffa_types::google::protobuf::Int32Value;
        let mut p = proto_payload("x");
        p.set_message(Int32Value {
            value: 7,
            ..Default::default()
        });
        let err = p.message::<StringValue>().unwrap_err();
        let msg = err.message.as_deref().unwrap_or_default();
        assert!(msg.contains("replacement is a"), "{err:?}");
        assert!(msg.contains("Int32Value"), "{err:?}");
        assert!(msg.contains("StringValue"), "{err:?}");
    }

    #[test]
    fn view_replaced_json_format_payload() {
        let bytes = encode_json(&StringValue {
            value: "before".into(),
            ..Default::default()
        })
        .unwrap();
        let mut p = Payload::new(bytes, CodecFormat::Json);
        p.set_message(StringValue {
            value: "after".into(),
            ..Default::default()
        });
        let v = p.view::<StringValueView>().unwrap();
        assert_eq!(v.value, "after");
    }

    #[test]
    fn set_message_twice_supersedes() {
        let mut p = proto_payload("original");
        p.set_message(StringValue {
            value: "first".into(),
            ..Default::default()
        });
        p.set_message(StringValue {
            value: "second".into(),
            ..Default::default()
        });
        let m: &StringValue = p.message().unwrap();
        assert_eq!(m.value, "second");
    }

    #[test]
    fn take_message_decodes_fresh_when_no_cache() {
        let p = proto_payload("fresh");
        let m: StringValue = p.take_message().unwrap();
        assert_eq!(m.value, "fresh");
    }

    /// `take_message` decodes with the payload's codec, not just proto.
    #[test]
    fn take_message_decodes_json() {
        let bytes = encode_json(&StringValue {
            value: "json".into(),
            ..Default::default()
        })
        .unwrap();
        let p = Payload::new(bytes, CodecFormat::Json);
        let m: StringValue = p.take_message().unwrap();
        assert_eq!(m.value, "json");
    }

    #[test]
    fn take_message_reuses_cache() {
        let p = proto_payload("cached");
        let _ = p.message::<StringValue>().unwrap();
        let m: StringValue = p.take_message().unwrap();
        assert_eq!(m.value, "cached");
    }

    #[test]
    fn take_message_returns_replacement() {
        let mut p = Payload::new(Bytes::from_static(&[0xff, 0xff, 0xff]), CodecFormat::Proto);
        p.set_message(StringValue {
            value: "replaced".into(),
            ..Default::default()
        });
        let m: StringValue = p.take_message().unwrap();
        assert_eq!(m.value, "replaced");
    }

    #[test]
    fn take_message_wrong_cached_type_errors() {
        use buffa_types::google::protobuf::Int32Value;
        let p = proto_payload("x");
        let _: &StringValue = p.message().unwrap();
        let err = p.take_message::<Int32Value>().unwrap_err();
        let msg = err.message.as_deref().unwrap_or_default();
        assert!(msg.contains("previously decoded as a"), "{err:?}");
        assert!(msg.contains("StringValue"), "{err:?}");
        assert!(msg.contains("Int32Value"), "{err:?}");
    }

    #[test]
    fn take_message_wrong_replacement_type_errors() {
        use buffa_types::google::protobuf::Int32Value;
        let mut p = proto_payload("x");
        p.set_message(Int32Value {
            value: 7,
            ..Default::default()
        });
        let err = p.take_message::<StringValue>().unwrap_err();
        let msg = err.message.as_deref().unwrap_or_default();
        assert!(msg.contains("replacement is a"), "{err:?}");
    }

    #[test]
    fn take_message_decode_error_is_invalid_argument() {
        use crate::ErrorCode;
        let p = Payload::new(Bytes::from_static(&[0xff, 0xff, 0xff]), CodecFormat::Proto);
        let err = p.take_message::<StringValue>().unwrap_err();
        assert_eq!(err.code, ErrorCode::InvalidArgument, "{err:?}");
    }

    /// A boxed `AnyMessage` can be mutated in place through `as_any_mut` and
    /// moved back out through `into_any`.
    #[test]
    fn any_message_downcasts_through_trait_object() {
        let mut boxed: Box<dyn AnyMessage> = Box::new(StringValue {
            value: "a".into(),
            ..Default::default()
        });
        assert!(boxed.type_name().contains("StringValue"));
        boxed
            .as_any_mut()
            .downcast_mut::<StringValue>()
            .expect("boxed value is a StringValue")
            .value = "b".into();
        let owned = boxed.into_any().downcast::<StringValue>().unwrap();
        assert_eq!(owned.value, "b");
    }

    #[test]
    fn payload_debug_redacts_body() {
        let p = proto_payload("secret");
        let dbg = format!("{p:?}");
        assert!(!dbg.contains("secret"), "Debug must not leak body: {dbg}");
        assert!(dbg.contains("Proto"), "{dbg}");
    }

    #[test]
    fn payload_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Payload>();
        assert_send_sync::<Box<dyn AnyMessage>>();
    }

    #[test]
    fn message_concurrent_same_type() {
        let p = proto_payload("race");
        std::thread::scope(|s| {
            let handles: Vec<_> = (0..16)
                .map(|_| {
                    let p = &p;
                    s.spawn(move || p.message::<StringValue>().unwrap() as *const _ as usize)
                })
                .collect();
            let addrs: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();
            assert!(
                addrs.iter().all(|&a| a == addrs[0]),
                "all callers should observe the same cached value"
            );
        });
    }
}