use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use bytes::Bytes;
use ember_protocol::Frame;
use prost_reflect::{
DescriptorPool, DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage,
};
use thiserror::Error;
/// Largest descriptor, in bytes, that `SchemaRegistry::register` accepts (10 MiB).
const MAX_DESCRIPTOR_BYTES: usize = 10 * 1024 * 1024;
/// Maximum number of dot-separated segments allowed in a field path.
const MAX_FIELD_PATH_DEPTH: usize = 16;
/// Largest serialized message, in bytes, that `SchemaRegistry::validate` accepts (64 MiB).
const MAX_PROTO_VALUE_BYTES: usize = 64 * 1024 * 1024;
/// Maximum number of schemas `SchemaRegistry::register` allows in one registry.
const MAX_SCHEMAS: usize = 1024;
/// Errors produced by [`SchemaRegistry`] and by the field-path helpers in this file.
#[derive(Debug, Error)]
pub enum SchemaError {
/// The descriptor bytes failed to decode, or the decoded pool held no message types.
#[error("invalid descriptor: {0}")]
InvalidDescriptor(String),
/// No cached message descriptor matches the requested full name (carried in the payload).
#[error("unknown message type: {0}")]
UnknownMessageType(String),
/// A payload failed to decode, a raw value failed to parse, re-encoding failed,
/// or the addressed field is of a kind the scalar accessors do not handle.
#[error("validation failed: {0}")]
ValidationFailed(String),
/// `register` was called with a schema name that is already present.
#[error("schema already registered: {0}")]
AlreadyExists(String),
/// A path segment names no field, the path is empty or contains an empty segment,
/// or traversal reached a field that is not a message.
#[error("field not found: {0}")]
FieldNotFound(String),
/// Descriptor size in bytes, followed by the allowed maximum.
#[error("descriptor too large: {0} bytes (max {1})")]
DescriptorTooLarge(usize, usize),
/// Number of path segments supplied, followed by the allowed maximum.
#[error("field path too deep: {0} segments (max {1})")]
PathTooDeep(usize, usize),
/// Payload size in bytes, followed by the allowed maximum.
#[error("proto value too large: {0} bytes (max {1})")]
ValueTooLarge(usize, usize),
/// Number of schemas currently registered, followed by the allowed maximum.
#[error("schema limit reached: {0} schemas (max {1})")]
TooManySchemas(usize, usize),
}
/// Everything the registry keeps for one named schema.
struct RegisteredSchema {
/// Descriptor bytes exactly as supplied by the caller; exposed through
/// `SchemaRegistry::iter_schemas`.
descriptor_bytes: Bytes,
// Within this file only the tests read `pool`, so the dead-code lint is
// silenced for non-test builds.
#[cfg_attr(not(test), allow(dead_code))]
/// Descriptor pool decoded from `descriptor_bytes`.
pool: DescriptorPool,
/// Full names of the messages found in `pool`; returned by `SchemaRegistry::describe`.
message_types: Vec<String>,
}
/// Registry of protobuf schemas and the message descriptors they define.
pub struct SchemaRegistry {
/// Registered schemas keyed by schema name.
schemas: HashMap<String, RegisteredSchema>,
/// Message descriptors keyed by full message name, shared across all schemas.
/// Registering or restoring a schema overwrites entries with the same full name.
message_cache: HashMap<String, MessageDescriptor>,
}
/// A [`SchemaRegistry`] behind `Arc<RwLock<..>>`, as built by `SchemaRegistry::shared`.
pub type SharedSchemaRegistry = Arc<RwLock<SchemaRegistry>>;
impl SchemaRegistry {
pub fn new() -> Self {
Self {
schemas: HashMap::new(),
message_cache: HashMap::new(),
}
}
pub fn shared() -> SharedSchemaRegistry {
Arc::new(RwLock::new(Self::new()))
}
pub fn register(
&mut self,
name: String,
descriptor_bytes: Bytes,
) -> Result<Vec<String>, SchemaError> {
if self.schemas.contains_key(&name) {
return Err(SchemaError::AlreadyExists(name));
}
if self.schemas.len() >= MAX_SCHEMAS {
return Err(SchemaError::TooManySchemas(self.schemas.len(), MAX_SCHEMAS));
}
if descriptor_bytes.len() > MAX_DESCRIPTOR_BYTES {
return Err(SchemaError::DescriptorTooLarge(
descriptor_bytes.len(),
MAX_DESCRIPTOR_BYTES,
));
}
let pool = DescriptorPool::decode(descriptor_bytes.as_ref())
.map_err(|e| SchemaError::InvalidDescriptor(e.to_string()))?;
let message_types: Vec<String> = pool
.all_messages()
.map(|m| m.full_name().to_owned())
.collect();
if message_types.is_empty() {
return Err(SchemaError::InvalidDescriptor(
"no message types found in descriptor".into(),
));
}
for desc in pool.all_messages() {
self.message_cache.insert(desc.full_name().to_owned(), desc);
}
self.schemas.insert(
name,
RegisteredSchema {
descriptor_bytes,
pool,
message_types: message_types.clone(),
},
);
Ok(message_types)
}
pub fn validate(&self, message_type: &str, data: &[u8]) -> Result<(), SchemaError> {
if data.len() > MAX_PROTO_VALUE_BYTES {
return Err(SchemaError::ValueTooLarge(
data.len(),
MAX_PROTO_VALUE_BYTES,
));
}
let descriptor = self.find_message(message_type)?;
DynamicMessage::decode(descriptor, data)
.map_err(|e| SchemaError::ValidationFailed(e.to_string()))?;
Ok(())
}
pub fn schema_names(&self) -> Vec<String> {
let mut names: Vec<String> = self.schemas.keys().cloned().collect();
names.sort();
names
}
pub fn describe(&self, name: &str) -> Option<Vec<String>> {
self.schemas.get(name).map(|s| s.message_types.clone())
}
pub fn iter_schemas(&self) -> impl Iterator<Item = (&str, &Bytes)> {
self.schemas
.iter()
.map(|(name, schema)| (name.as_str(), &schema.descriptor_bytes))
}
pub fn restore(&mut self, name: String, descriptor_bytes: Bytes) {
if self.schemas.contains_key(&name) {
return;
}
let pool = match DescriptorPool::decode(descriptor_bytes.as_ref()) {
Ok(p) => p,
Err(e) => {
tracing::warn!(schema = %name, "failed to restore schema: {e}");
return;
}
};
let message_types: Vec<String> = pool
.all_messages()
.map(|m| m.full_name().to_owned())
.collect();
for desc in pool.all_messages() {
self.message_cache.insert(desc.full_name().to_owned(), desc);
}
self.schemas.insert(
name,
RegisteredSchema {
descriptor_bytes,
pool,
message_types,
},
);
}
pub fn get_field(
&self,
type_name: &str,
data: &[u8],
field_path: &str,
) -> Result<Frame, SchemaError> {
let descriptor = self.find_message(type_name)?;
let msg = DynamicMessage::decode(descriptor, data)
.map_err(|e| SchemaError::ValidationFailed(e.to_string()))?;
let (value, field_desc) = resolve_field_path(&msg, field_path)?;
value_to_frame(&value, &field_desc)
}
pub fn set_field(
&self,
type_name: &str,
data: &[u8],
field_path: &str,
raw_value: &str,
) -> Result<Bytes, SchemaError> {
let descriptor = self.find_message(type_name)?;
let mut msg = DynamicMessage::decode(descriptor, data)
.map_err(|e| SchemaError::ValidationFailed(e.to_string()))?;
let (parent, leaf_name, leaf_desc) = resolve_field_path_mut(&mut msg, field_path)?;
let parsed = parse_field_value(raw_value, &leaf_desc)?;
parent.set_field_by_name(&leaf_name, parsed);
let mut buf = Vec::new();
use prost_reflect::prost::Message;
msg.encode(&mut buf)
.map_err(|e| SchemaError::ValidationFailed(format!("re-encode failed: {e}")))?;
Ok(Bytes::from(buf))
}
pub fn clear_field(
&self,
type_name: &str,
data: &[u8],
field_path: &str,
) -> Result<Bytes, SchemaError> {
let descriptor = self.find_message(type_name)?;
let mut msg = DynamicMessage::decode(descriptor, data)
.map_err(|e| SchemaError::ValidationFailed(e.to_string()))?;
let (parent, leaf_name, _leaf_desc) = resolve_field_path_mut(&mut msg, field_path)?;
parent.clear_field_by_name(&leaf_name);
let mut buf = Vec::new();
use prost_reflect::prost::Message;
msg.encode(&mut buf)
.map_err(|e| SchemaError::ValidationFailed(format!("re-encode failed: {e}")))?;
Ok(Bytes::from(buf))
}
fn find_message(&self, message_type: &str) -> Result<MessageDescriptor, SchemaError> {
self.message_cache
.get(message_type)
.cloned()
.ok_or_else(|| SchemaError::UnknownMessageType(message_type.to_owned()))
}
}
/// Resolves the dot-separated `path` inside `msg` and returns an owned copy of
/// the addressed value together with its field descriptor.
///
/// Intermediate segments must be message-valued; an unset intermediate message
/// is traversed through its default value.
///
/// # Errors
///
/// * [`SchemaError::FieldNotFound`] if `path` is empty, contains an empty
///   segment, names a field that does not exist, or tries to descend through a
///   value that is not a message.
/// * [`SchemaError::PathTooDeep`] if `path` has more than
///   `MAX_FIELD_PATH_DEPTH` segments.
fn resolve_field_path(
    msg: &DynamicMessage,
    path: &str,
) -> Result<(prost_reflect::Value, FieldDescriptor), SchemaError> {
    if path.is_empty() {
        return Err(SchemaError::FieldNotFound("empty field path".into()));
    }
    let segments: Vec<&str> = path.split('.').collect();
    if segments.iter().any(|segment| segment.is_empty()) {
        return Err(SchemaError::FieldNotFound(format!(
            "invalid field path '{path}': empty segment"
        )));
    }
    if segments.len() > MAX_FIELD_PATH_DEPTH {
        return Err(SchemaError::PathTooDeep(
            segments.len(),
            MAX_FIELD_PATH_DEPTH,
        ));
    }
    // `split` always yields at least one segment, so this branch is defensive only.
    let Some((leaf, parents)) = segments.split_last() else {
        return Err(SchemaError::FieldNotFound(
            "failed to resolve field path".into(),
        ));
    };

    // Start at the borrowed root and only own the sub-message currently being
    // inspected, instead of deep-cloning the whole decoded message up front.
    let mut nested: Option<DynamicMessage> = None;
    for segment in parents {
        let current = nested.as_ref().unwrap_or(msg);
        let field_desc = current
            .descriptor()
            .get_field_by_name(segment)
            .ok_or_else(|| SchemaError::FieldNotFound(segment.to_string()))?;
        let value = current.get_field(&field_desc).into_owned();
        match value {
            prost_reflect::Value::Message(child) => nested = Some(child),
            _ => {
                return Err(SchemaError::FieldNotFound(format!(
                    "'{segment}' is not a message field, cannot traverse further"
                )));
            }
        }
    }

    let current = nested.as_ref().unwrap_or(msg);
    let field_desc = current
        .descriptor()
        .get_field_by_name(leaf)
        .ok_or_else(|| SchemaError::FieldNotFound(leaf.to_string()))?;
    let value = current.get_field(&field_desc).into_owned();
    Ok((value, field_desc))
}
fn value_to_frame(
value: &prost_reflect::Value,
field_desc: &FieldDescriptor,
) -> Result<Frame, SchemaError> {
if field_desc.is_list() || field_desc.is_map() {
return Err(SchemaError::ValidationFailed(
"use PROTO.GET for repeated/map fields".into(),
));
}
match value {
prost_reflect::Value::String(s) => Ok(Frame::Bulk(Bytes::from(s.clone()))),
prost_reflect::Value::Bytes(b) => Ok(Frame::Bulk(b.clone())),
prost_reflect::Value::I32(n) => Ok(Frame::Integer(i64::from(*n))),
prost_reflect::Value::I64(n) => Ok(Frame::Integer(*n)),
prost_reflect::Value::U32(n) => Ok(Frame::Integer(i64::from(*n))),
prost_reflect::Value::U64(n) => {
match i64::try_from(*n) {
Ok(i) => Ok(Frame::Integer(i)),
Err(_) => Ok(Frame::Bulk(Bytes::from(n.to_string()))),
}
}
prost_reflect::Value::F32(n) => Ok(Frame::Bulk(Bytes::from(format!("{n}")))),
prost_reflect::Value::F64(n) => Ok(Frame::Bulk(Bytes::from(format!("{n}")))),
prost_reflect::Value::Bool(b) => Ok(Frame::Integer(if *b { 1 } else { 0 })),
prost_reflect::Value::EnumNumber(n) => {
if let Kind::Enum(enum_desc) = field_desc.kind() {
if let Some(val) = enum_desc.get_value(*n) {
return Ok(Frame::Bulk(Bytes::from(val.name().to_owned())));
}
}
Ok(Frame::Integer(i64::from(*n)))
}
prost_reflect::Value::Message(_) => Err(SchemaError::ValidationFailed(
"use PROTO.GET for nested message fields".into(),
)),
prost_reflect::Value::List(_) => Err(SchemaError::ValidationFailed(
"use PROTO.GET for repeated fields".into(),
)),
prost_reflect::Value::Map(_) => Err(SchemaError::ValidationFailed(
"use PROTO.GET for map fields".into(),
)),
}
}
/// Resolves the dot-separated `path` inside `msg` for mutation.
///
/// Returns the message that directly contains the leaf field, the leaf field's
/// name, and its descriptor. Intermediate messages that are not set are created
/// on the way down, so `msg` may be modified even though the leaf is untouched.
///
/// # Errors
///
/// * [`SchemaError::FieldNotFound`] if `path` is empty, contains an empty
///   segment, names a field that does not exist, or tries to descend through a
///   field that is not a singular message (scalars, repeated fields and maps).
/// * [`SchemaError::PathTooDeep`] if `path` has more than
///   `MAX_FIELD_PATH_DEPTH` segments.
fn resolve_field_path_mut<'a>(
    msg: &'a mut DynamicMessage,
    path: &str,
) -> Result<(&'a mut DynamicMessage, String, FieldDescriptor), SchemaError> {
    if path.is_empty() {
        return Err(SchemaError::FieldNotFound("empty field path".into()));
    }
    let segments: Vec<&str> = path.split('.').collect();
    if segments.iter().any(|segment| segment.is_empty()) {
        return Err(SchemaError::FieldNotFound(format!(
            "invalid field path '{path}': empty segment"
        )));
    }
    if segments.len() > MAX_FIELD_PATH_DEPTH {
        return Err(SchemaError::PathTooDeep(
            segments.len(),
            MAX_FIELD_PATH_DEPTH,
        ));
    }
    // `split` always yields at least one segment, so this branch is defensive only.
    let Some((leaf, parents)) = segments.split_last() else {
        return Err(SchemaError::FieldNotFound(
            "failed to resolve field path".into(),
        ));
    };

    let mut current = msg;
    for segment in parents {
        let field_desc = current
            .descriptor()
            .get_field_by_name(segment)
            .ok_or_else(|| SchemaError::FieldNotFound(segment.to_string()))?;

        // Repeated-message and map fields also report `Kind::Message`. They must
        // be rejected here: storing a `Value::Message` into such a field is a
        // value-type mismatch that `set_field_by_name` does not tolerate.
        if field_desc.is_list() || field_desc.is_map() {
            return Err(SchemaError::FieldNotFound(format!(
                "'{segment}' is not a message field, cannot traverse further"
            )));
        }
        let Kind::Message(nested_desc) = field_desc.kind() else {
            return Err(SchemaError::FieldNotFound(format!(
                "'{segment}' is not a message field, cannot traverse further"
            )));
        };

        // Materialise the intermediate message so the leaf has somewhere to live.
        if !current.has_field_by_name(segment) {
            current.set_field_by_name(
                segment,
                prost_reflect::Value::Message(DynamicMessage::new(nested_desc)),
            );
        }

        let value = current.get_field_by_name_mut(segment).ok_or_else(|| {
            SchemaError::FieldNotFound(format!("failed to get mutable reference to '{segment}'"))
        })?;
        current = match value {
            prost_reflect::Value::Message(nested) => nested,
            _ => {
                return Err(SchemaError::FieldNotFound(format!(
                    "'{segment}' is not a message field"
                )));
            }
        };
    }

    let leaf_desc = current
        .descriptor()
        .get_field_by_name(leaf)
        .ok_or_else(|| SchemaError::FieldNotFound(leaf.to_string()))?;
    Ok((current, leaf.to_string(), leaf_desc))
}
fn parse_field_value(
raw: &str,
field_desc: &FieldDescriptor,
) -> Result<prost_reflect::Value, SchemaError> {
if field_desc.is_list() || field_desc.is_map() {
return Err(SchemaError::ValidationFailed(
"use PROTO.SET for repeated/map fields".into(),
));
}
match field_desc.kind() {
Kind::String => Ok(prost_reflect::Value::String(raw.to_owned())),
Kind::Bytes => Ok(prost_reflect::Value::Bytes(Bytes::from(raw.to_owned()))),
Kind::Bool => match raw {
"true" | "1" => Ok(prost_reflect::Value::Bool(true)),
"false" | "0" => Ok(prost_reflect::Value::Bool(false)),
_ => Err(SchemaError::ValidationFailed(format!(
"invalid bool value: '{raw}' (expected true/false/1/0)"
))),
},
Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => {
let n: i32 = raw
.parse()
.map_err(|e| SchemaError::ValidationFailed(format!("invalid int32 value: {e}")))?;
Ok(prost_reflect::Value::I32(n))
}
Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => {
let n: i64 = raw
.parse()
.map_err(|e| SchemaError::ValidationFailed(format!("invalid int64 value: {e}")))?;
Ok(prost_reflect::Value::I64(n))
}
Kind::Uint32 | Kind::Fixed32 => {
let n: u32 = raw
.parse()
.map_err(|e| SchemaError::ValidationFailed(format!("invalid uint32 value: {e}")))?;
Ok(prost_reflect::Value::U32(n))
}
Kind::Uint64 | Kind::Fixed64 => {
let n: u64 = raw
.parse()
.map_err(|e| SchemaError::ValidationFailed(format!("invalid uint64 value: {e}")))?;
Ok(prost_reflect::Value::U64(n))
}
Kind::Float => {
let n: f32 = raw
.parse()
.map_err(|e| SchemaError::ValidationFailed(format!("invalid float value: {e}")))?;
Ok(prost_reflect::Value::F32(n))
}
Kind::Double => {
let n: f64 = raw
.parse()
.map_err(|e| SchemaError::ValidationFailed(format!("invalid double value: {e}")))?;
Ok(prost_reflect::Value::F64(n))
}
Kind::Enum(enum_desc) => {
if let Some(val) = enum_desc.get_value_by_name(raw) {
return Ok(prost_reflect::Value::EnumNumber(val.number()));
}
let n: i32 = raw.parse().map_err(|_| {
SchemaError::ValidationFailed(format!(
"invalid enum value: '{raw}' (not a valid name or number)"
))
})?;
Ok(prost_reflect::Value::EnumNumber(n))
}
Kind::Message(_) => Err(SchemaError::ValidationFailed(
"use PROTO.SET for nested message fields".into(),
)),
}
}
/// Debug output shows only the number of schemas and cached messages, not their contents.
impl std::fmt::Debug for SchemaRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let schema_count = self.schemas.len();
        let cached_messages = self.message_cache.len();

        let mut summary = f.debug_struct("SchemaRegistry");
        summary.field("schema_count", &schema_count);
        summary.field("cached_messages", &cached_messages);
        summary.finish()
    }
}
impl Default for SchemaRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use prost_reflect::prost::Message;
    use prost_reflect::prost_types::{
        DescriptorProto, FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet,
    };
    use prost_reflect::Value;

    // Numeric values of `FieldDescriptorProto.Type` / `.Label` used by the fixtures.
    const TYPE_UINT64: i32 = 4;
    const TYPE_INT32: i32 = 5;
    const TYPE_BOOL: i32 = 8;
    const TYPE_STRING: i32 = 9;
    const TYPE_MESSAGE: i32 = 11;
    const LABEL_OPTIONAL: i32 = 1;

    // ----- descriptor fixtures -------------------------------------------------

    /// Builds an optional field of the given protobuf type.
    fn field_proto(name: &str, number: i32, proto_type: i32) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.to_owned()),
            number: Some(number),
            r#type: Some(proto_type),
            label: Some(LABEL_OPTIONAL),
            ..Default::default()
        }
    }

    /// Builds a message descriptor with the given fields.
    fn message_proto(name: &str, fields: Vec<FieldDescriptorProto>) -> DescriptorProto {
        DescriptorProto {
            name: Some(name.to_owned()),
            field: fields,
            ..Default::default()
        }
    }

    /// Encodes a one-file descriptor set named `<package>.proto`.
    fn encode_descriptor_set(package: &str, messages: Vec<DescriptorProto>) -> Bytes {
        let fds = FileDescriptorSet {
            file: vec![FileDescriptorProto {
                name: Some(format!("{package}.proto")),
                package: Some(package.to_owned()),
                message_type: messages,
                ..Default::default()
            }],
        };
        let mut buf = Vec::new();
        fds.encode(&mut buf).expect("encode descriptor");
        Bytes::from(buf)
    }

    /// Descriptor with one message holding one string field (number 1).
    fn make_descriptor(package: &str, message_name: &str, field_name: &str) -> Bytes {
        encode_descriptor_set(
            package,
            vec![message_proto(
                message_name,
                vec![field_proto(field_name, 1, TYPE_STRING)],
            )],
        )
    }

    /// `test.Inner { string value = 1; }` and `test.Outer { Inner inner = 1; }`.
    fn make_nested_descriptor() -> Bytes {
        let inner_ref = FieldDescriptorProto {
            type_name: Some(".test.Inner".into()),
            ..field_proto("inner", 1, TYPE_MESSAGE)
        };
        encode_descriptor_set(
            "test",
            vec![
                message_proto("Inner", vec![field_proto("value", 1, TYPE_STRING)]),
                message_proto("Outer", vec![inner_ref]),
            ],
        )
    }

    /// `test.Profile { string name = 1; int32 age = 2; bool active = 3; }`.
    fn make_multi_field_descriptor() -> Bytes {
        encode_descriptor_set(
            "test",
            vec![message_proto(
                "Profile",
                vec![
                    field_proto("name", 1, TYPE_STRING),
                    field_proto("age", 2, TYPE_INT32),
                    field_proto("active", 3, TYPE_BOOL),
                ],
            )],
        )
    }

    /// `test.BigNum { uint64 val = 1; }`.
    fn make_uint64_descriptor() -> Bytes {
        encode_descriptor_set(
            "test",
            vec![message_proto(
                "BigNum",
                vec![field_proto("val", 1, TYPE_UINT64)],
            )],
        )
    }

    // ----- registry / message helpers -----------------------------------------

    /// Registry containing exactly one schema.
    fn registry_with(name: &str, descriptor: Bytes) -> SchemaRegistry {
        let mut registry = SchemaRegistry::new();
        registry.register(name.into(), descriptor).unwrap();
        registry
    }

    fn users_registry() -> SchemaRegistry {
        registry_with("users", make_descriptor("test", "User", "name"))
    }

    fn profiles_registry() -> SchemaRegistry {
        registry_with("profiles", make_multi_field_descriptor())
    }

    fn nested_registry() -> SchemaRegistry {
        registry_with("nested", make_nested_descriptor())
    }

    fn bignum_registry() -> SchemaRegistry {
        registry_with("bignums", make_uint64_descriptor())
    }

    /// Empty dynamic message of `type_name`, looked up in the pool of `schema`.
    fn new_message(registry: &SchemaRegistry, schema: &str, type_name: &str) -> DynamicMessage {
        let pool = &registry.schemas[schema].pool;
        let descriptor = pool
            .get_message_by_name(type_name)
            .expect("message type is part of the schema");
        DynamicMessage::new(descriptor)
    }

    fn encode_message(msg: &DynamicMessage) -> Vec<u8> {
        let mut buf = Vec::new();
        msg.encode(&mut buf).expect("encode message");
        buf
    }

    /// Message of `type_name` with a single field set, encoded.
    fn encode_single_field(
        registry: &SchemaRegistry,
        schema: &str,
        type_name: &str,
        field_name: &str,
        value: Value,
    ) -> Vec<u8> {
        let mut msg = new_message(registry, schema, type_name);
        msg.set_field_by_name(field_name, value);
        encode_message(&msg)
    }

    fn encode_user(registry: &SchemaRegistry, name: &str) -> Vec<u8> {
        encode_single_field(
            registry,
            "users",
            "test.User",
            "name",
            Value::String(name.into()),
        )
    }

    fn encode_profile(registry: &SchemaRegistry, name: &str, age: i32, active: bool) -> Vec<u8> {
        let mut msg = new_message(registry, "profiles", "test.Profile");
        msg.set_field_by_name("name", Value::String(name.into()));
        msg.set_field_by_name("age", Value::I32(age));
        msg.set_field_by_name("active", Value::Bool(active));
        encode_message(&msg)
    }

    /// Encodes `test.Outer`; `inner.value` is populated only when `inner_value` is `Some`.
    fn encode_outer(registry: &SchemaRegistry, inner_value: Option<&str>) -> Vec<u8> {
        let mut outer = new_message(registry, "nested", "test.Outer");
        if let Some(value) = inner_value {
            let mut inner = new_message(registry, "nested", "test.Inner");
            inner.set_field_by_name("value", Value::String(value.into()));
            outer.set_field_by_name("inner", Value::Message(inner));
        }
        encode_message(&outer)
    }

    fn bulk(text: &'static str) -> Frame {
        Frame::Bulk(Bytes::from(text))
    }

    // ----- registration ---------------------------------------------------------

    #[test]
    fn register_and_describe() {
        let mut registry = SchemaRegistry::new();
        let desc = make_descriptor("test", "User", "name");
        let types = registry.register("users".into(), desc).unwrap();
        assert_eq!(types, vec!["test.User"]);
        let described = registry.describe("users").unwrap();
        assert_eq!(described, vec!["test.User"]);
    }

    #[test]
    fn double_registration_fails() {
        let mut registry = SchemaRegistry::new();
        let desc = make_descriptor("test", "User", "name");
        registry.register("users".into(), desc.clone()).unwrap();
        let err = registry.register("users".into(), desc).unwrap_err();
        assert!(matches!(err, SchemaError::AlreadyExists(_)));
    }

    #[test]
    fn invalid_descriptor_fails() {
        let mut registry = SchemaRegistry::new();
        let err = registry
            .register("bad".into(), Bytes::from("not a protobuf"))
            .unwrap_err();
        assert!(matches!(err, SchemaError::InvalidDescriptor(_)));
    }

    #[test]
    fn validate_valid_message() {
        let registry = users_registry();
        let data = encode_user(&registry, "alice");
        registry.validate("test.User", &data).unwrap();
    }

    #[test]
    fn validate_unknown_type_fails() {
        let registry = SchemaRegistry::new();
        let err = registry.validate("no.Such.Type", &[]).unwrap_err();
        assert!(matches!(err, SchemaError::UnknownMessageType(_)));
    }

    #[test]
    fn schema_names_sorted() {
        let mut registry = SchemaRegistry::new();
        registry
            .register("z-schema".into(), make_descriptor("z", "Z", "val"))
            .unwrap();
        registry
            .register("a-schema".into(), make_descriptor("a", "A", "val"))
            .unwrap();
        let names = registry.schema_names();
        assert_eq!(names, vec!["a-schema", "z-schema"]);
    }

    #[test]
    fn describe_unknown_returns_none() {
        let registry = SchemaRegistry::new();
        assert!(registry.describe("nope").is_none());
    }

    #[test]
    fn restore_is_idempotent() {
        let mut registry = SchemaRegistry::new();
        let desc = make_descriptor("test", "User", "name");
        registry.restore("users".into(), desc.clone());
        registry.restore("users".into(), desc);
        assert_eq!(registry.schema_names(), vec!["users"]);
    }

    #[test]
    fn iter_schemas_returns_all() {
        let mut registry = SchemaRegistry::new();
        registry
            .register("alpha".into(), make_descriptor("a", "A", "val"))
            .unwrap();
        registry
            .register("beta".into(), make_descriptor("b", "B", "val"))
            .unwrap();
        let mut names: Vec<_> = registry
            .iter_schemas()
            .map(|(name, _)| name.to_owned())
            .collect();
        names.sort();
        assert_eq!(names, vec!["alpha", "beta"]);
    }

    // ----- get_field --------------------------------------------------------------

    #[test]
    fn get_field_string() {
        let registry = users_registry();
        let data = encode_user(&registry, "alice");
        let frame = registry.get_field("test.User", &data, "name").unwrap();
        assert_eq!(frame, bulk("alice"));
    }

    #[test]
    fn get_field_default_value() {
        let registry = users_registry();
        let data = encode_message(&new_message(&registry, "users", "test.User"));
        let frame = registry.get_field("test.User", &data, "name").unwrap();
        assert_eq!(frame, bulk(""));
    }

    #[test]
    fn get_field_int() {
        let desc = encode_descriptor_set(
            "test",
            vec![message_proto(
                "Counter",
                vec![field_proto("count", 1, TYPE_INT32)],
            )],
        );
        let registry = registry_with("counters", desc);
        let data = encode_single_field(
            &registry,
            "counters",
            "test.Counter",
            "count",
            Value::I32(42),
        );
        let frame = registry.get_field("test.Counter", &data, "count").unwrap();
        assert_eq!(frame, Frame::Integer(42));
    }

    #[test]
    fn get_field_bool() {
        let desc = encode_descriptor_set(
            "test",
            vec![message_proto(
                "Flag",
                vec![field_proto("active", 1, TYPE_BOOL)],
            )],
        );
        let registry = registry_with("flags", desc);
        let data = encode_single_field(
            &registry,
            "flags",
            "test.Flag",
            "active",
            Value::Bool(true),
        );
        let frame = registry.get_field("test.Flag", &data, "active").unwrap();
        assert_eq!(frame, Frame::Integer(1));
    }

    #[test]
    fn get_field_nested_path() {
        let registry = nested_registry();
        let data = encode_outer(&registry, Some("hello"));
        let frame = registry
            .get_field("test.Outer", &data, "inner.value")
            .unwrap();
        assert_eq!(frame, bulk("hello"));
    }

    #[test]
    fn get_field_nonexistent() {
        let registry = users_registry();
        let data = encode_user(&registry, "alice");
        let err = registry
            .get_field("test.User", &data, "nonexistent")
            .unwrap_err();
        assert!(matches!(err, SchemaError::FieldNotFound(_)));
    }

    #[test]
    fn get_field_empty_path() {
        let registry = users_registry();
        let data = encode_user(&registry, "alice");
        let err = registry.get_field("test.User", &data, "").unwrap_err();
        assert!(matches!(err, SchemaError::FieldNotFound(_)));
    }

    // ----- set_field / clear_field -------------------------------------------------

    #[test]
    fn set_field_string() {
        let registry = profiles_registry();
        let data = encode_profile(&registry, "alice", 25, true);
        let new_data = registry
            .set_field("test.Profile", &data, "name", "bob")
            .unwrap();
        let frame = registry
            .get_field("test.Profile", &new_data, "name")
            .unwrap();
        assert_eq!(frame, bulk("bob"));
        // Untouched fields survive the round trip.
        let frame = registry
            .get_field("test.Profile", &new_data, "age")
            .unwrap();
        assert_eq!(frame, Frame::Integer(25));
    }

    #[test]
    fn set_field_int32() {
        let registry = profiles_registry();
        let data = encode_profile(&registry, "alice", 25, true);
        let new_data = registry
            .set_field("test.Profile", &data, "age", "30")
            .unwrap();
        let frame = registry
            .get_field("test.Profile", &new_data, "age")
            .unwrap();
        assert_eq!(frame, Frame::Integer(30));
    }

    #[test]
    fn set_field_bool() {
        let registry = profiles_registry();
        let data = encode_profile(&registry, "alice", 25, true);
        let new_data = registry
            .set_field("test.Profile", &data, "active", "false")
            .unwrap();
        let frame = registry
            .get_field("test.Profile", &new_data, "active")
            .unwrap();
        assert_eq!(frame, Frame::Integer(0));
    }

    #[test]
    fn set_field_invalid_int_value() {
        let registry = profiles_registry();
        let data = encode_profile(&registry, "alice", 25, true);
        let err = registry
            .set_field("test.Profile", &data, "age", "not_a_number")
            .unwrap_err();
        assert!(matches!(err, SchemaError::ValidationFailed(_)));
    }

    #[test]
    fn set_field_nonexistent() {
        let registry = profiles_registry();
        let data = encode_profile(&registry, "alice", 25, true);
        let err = registry
            .set_field("test.Profile", &data, "nonexistent", "value")
            .unwrap_err();
        assert!(matches!(err, SchemaError::FieldNotFound(_)));
    }

    #[test]
    fn clear_field_resets_to_default() {
        let registry = profiles_registry();
        let data = encode_profile(&registry, "alice", 25, true);
        let new_data = registry.clear_field("test.Profile", &data, "name").unwrap();
        let frame = registry
            .get_field("test.Profile", &new_data, "name")
            .unwrap();
        assert_eq!(frame, bulk(""));
        // Untouched fields survive the round trip.
        let frame = registry
            .get_field("test.Profile", &new_data, "age")
            .unwrap();
        assert_eq!(frame, Frame::Integer(25));
    }

    #[test]
    fn set_field_nested_path() {
        let registry = nested_registry();
        let data = encode_outer(&registry, Some("hello"));
        let new_data = registry
            .set_field("test.Outer", &data, "inner.value", "world")
            .unwrap();
        let frame = registry
            .get_field("test.Outer", &new_data, "inner.value")
            .unwrap();
        assert_eq!(frame, bulk("world"));
    }

    #[test]
    fn clear_field_nested_path() {
        let registry = nested_registry();
        let data = encode_outer(&registry, Some("hello"));
        let new_data = registry
            .clear_field("test.Outer", &data, "inner.value")
            .unwrap();
        let frame = registry
            .get_field("test.Outer", &new_data, "inner.value")
            .unwrap();
        assert_eq!(frame, bulk(""));
    }

    #[test]
    fn set_field_nested_creates_intermediate() {
        let registry = nested_registry();
        let data = encode_outer(&registry, None);
        let new_data = registry
            .set_field("test.Outer", &data, "inner.value", "auto")
            .unwrap();
        let frame = registry
            .get_field("test.Outer", &new_data, "inner.value")
            .unwrap();
        assert_eq!(frame, bulk("auto"));
    }

    // ----- limits and path syntax --------------------------------------------------

    #[test]
    fn descriptor_size_limit_exceeded() {
        let mut registry = SchemaRegistry::new();
        let oversized = Bytes::from(vec![0u8; MAX_DESCRIPTOR_BYTES + 1]);
        let err = registry.register("huge".into(), oversized).unwrap_err();
        assert!(matches!(err, SchemaError::DescriptorTooLarge(_, _)));
    }

    #[test]
    fn field_path_depth_limit_exceeded() {
        let registry = nested_registry();
        let data = encode_outer(&registry, None);
        let deep_path = (0..17)
            .map(|i| format!("f{i}"))
            .collect::<Vec<_>>()
            .join(".");

        let err = registry
            .get_field("test.Outer", &data, &deep_path)
            .unwrap_err();
        assert!(matches!(err, SchemaError::PathTooDeep(17, 16)));

        let err = registry
            .set_field("test.Outer", &data, &deep_path, "val")
            .unwrap_err();
        assert!(matches!(err, SchemaError::PathTooDeep(17, 16)));

        let err = registry
            .clear_field("test.Outer", &data, &deep_path)
            .unwrap_err();
        assert!(matches!(err, SchemaError::PathTooDeep(17, 16)));
    }

    #[test]
    fn double_dot_path_returns_error() {
        let registry = users_registry();
        let data = encode_user(&registry, "alice");
        let err = registry.get_field("test.User", &data, "a..b").unwrap_err();
        match err {
            SchemaError::FieldNotFound(msg) => assert!(msg.contains("empty segment"), "{msg}"),
            other => panic!("expected FieldNotFound, got {other:?}"),
        }
    }

    #[test]
    fn trailing_dot_path_returns_error() {
        let registry = users_registry();
        let data = encode_user(&registry, "alice");
        let err = registry.get_field("test.User", &data, "name.").unwrap_err();
        match err {
            SchemaError::FieldNotFound(msg) => assert!(msg.contains("empty segment"), "{msg}"),
            other => panic!("expected FieldNotFound, got {other:?}"),
        }
    }

    #[test]
    fn u64_overflow_returns_bulk_string() {
        let registry = bignum_registry();
        let data = encode_single_field(
            &registry,
            "bignums",
            "test.BigNum",
            "val",
            Value::U64(u64::MAX),
        );
        let frame = registry.get_field("test.BigNum", &data, "val").unwrap();
        assert_eq!(frame, bulk("18446744073709551615"));
    }

    #[test]
    fn u64_fits_in_i64_returns_integer() {
        let registry = bignum_registry();
        let data = encode_single_field(
            &registry,
            "bignums",
            "test.BigNum",
            "val",
            Value::U64(42),
        );
        let frame = registry.get_field("test.BigNum", &data, "val").unwrap();
        assert_eq!(frame, Frame::Integer(42));
    }

    #[test]
    fn value_too_large_rejected() {
        let registry = users_registry();
        let oversized = vec![0u8; MAX_PROTO_VALUE_BYTES + 1];
        let err = registry.validate("test.User", &oversized).unwrap_err();
        assert!(matches!(err, SchemaError::ValueTooLarge(_, _)));
    }

    #[test]
    fn value_at_limit_allowed() {
        let registry = users_registry();
        // All-zero bytes are not a valid message, so an error is still expected;
        // it just must not be the size-limit error.
        let at_limit = vec![0u8; MAX_PROTO_VALUE_BYTES];
        let err = registry.validate("test.User", &at_limit).unwrap_err();
        assert!(
            !matches!(err, SchemaError::ValueTooLarge(_, _)),
            "expected validation error, not size limit"
        );
    }

    #[test]
    fn schema_count_limit() {
        let mut registry = SchemaRegistry::new();
        for i in 0..MAX_SCHEMAS {
            let desc = make_descriptor(&format!("pkg{i}"), &format!("Msg{i}"), "val");
            registry
                .register(format!("schema-{i}"), desc)
                .unwrap_or_else(|e| panic!("failed to register schema {i}: {e}"));
        }
        let desc = make_descriptor("overflow", "Overflow", "val");
        let err = registry.register("overflow".into(), desc).unwrap_err();
        assert!(matches!(err, SchemaError::TooManySchemas(_, _)));
    }
}