#[cfg(feature = "schema-registry")]
mod client;
pub mod glue;
#[cfg(feature = "schema-registry")]
mod http;
#[cfg(feature = "schema-registry")]
#[cfg_attr(docsrs, doc(cfg(feature = "schema-registry")))]
pub use client::{ConfluentSchemaRegistry, ConfluentSchemaRegistryBuilder};
use self::glue::{GlueSchema, GlueSchemaRegistryClient, GlueSchemaVersionId};
use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use bytes::{BufMut, Bytes, BytesMut};
use parking_lot::{Mutex, RwLock};
use tokio::sync::oneshot;
use crate::error::{KrafkaError, Result};
use tracing::debug;
/// Boxed-future view of a [`SchemaRegistryClient`] that can be held behind a
/// `dyn` reference.
///
/// Only the by-id lookup is exposed.
trait ErasedSchemaRegistryClient: Send + Sync {
    /// Looks up the schema registered under `id`, returning the lookup as a
    /// pinned, boxed future that borrows `self`.
    fn get_schema_by_id_erased<'r>(
        &'r self,
        id: SchemaId,
    ) -> Pin<Box<dyn Future<Output = Result<Schema>> + Send + 'r>>;
}
/// Every [`SchemaRegistryClient`] is usable through the erased trait by
/// boxing the future returned from `get_schema_by_id`.
impl<T> ErasedSchemaRegistryClient for T
where
    T: SchemaRegistryClient,
{
    fn get_schema_by_id_erased<'a>(
        &'a self,
        id: SchemaId,
    ) -> Pin<Box<dyn Future<Output = Result<Schema>> + Send + 'a>> {
        let lookup = self.get_schema_by_id(id);
        Box::pin(lookup)
    }
}
/// Boxed-future view of a [`GlueSchemaRegistryClient`] that can be held
/// behind a `dyn` reference.
///
/// Only the lookup by schema version id is exposed.
trait ErasedGlueSchemaRegistryClient: Send + Sync {
    /// Looks up the Glue schema identified by `id`, returning the lookup as a
    /// pinned, boxed future that borrows `self`.
    fn get_schema_by_version_id_erased<'r>(
        &'r self,
        id: GlueSchemaVersionId,
    ) -> Pin<Box<dyn Future<Output = Result<GlueSchema>> + Send + 'r>>;
}
/// Every [`GlueSchemaRegistryClient`] is usable through the erased trait by
/// boxing the future returned from `get_schema_by_version_id`.
impl<T> ErasedGlueSchemaRegistryClient for T
where
    T: GlueSchemaRegistryClient,
{
    fn get_schema_by_version_id_erased<'a>(
        &'a self,
        id: GlueSchemaVersionId,
    ) -> Pin<Box<dyn Future<Output = Result<GlueSchema>> + Send + 'a>> {
        let lookup = self.get_schema_by_version_id(id);
        Box::pin(lookup)
    }
}
/// Numeric identifier of a schema; this is the value carried in the 4-byte
/// id field of the wire-format header.
pub type SchemaId = u32;
/// Version number of a schema, as used together with a subject name.
pub type SchemaVersion = i32;
/// Schema languages a registry entry can be written in.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SchemaType {
    /// Avro schema (rendered as `"AVRO"`).
    Avro,
    /// Protobuf schema (rendered as `"PROTOBUF"`).
    Protobuf,
    /// JSON schema (rendered as `"JSON"`).
    Json,
}
impl SchemaType {
pub fn as_str(&self) -> &'static str {
match self {
Self::Avro => "AVRO",
Self::Protobuf => "PROTOBUF",
Self::Json => "JSON",
}
}
}
/// Displays the same upper-case label that [`SchemaType::as_str`] returns.
impl fmt::Display for SchemaType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = self.as_str();
        f.write_str(label)
    }
}
/// Parses a schema type from its label, ignoring ASCII case.
impl std::str::FromStr for SchemaType {
    type Err = KrafkaError;

    /// Accepts `"AVRO"`, `"PROTOBUF"` and `"JSON"` in any ASCII case.
    ///
    /// # Errors
    ///
    /// Returns a schema-registry error naming the input when it matches none
    /// of the known labels.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // Candidates are tried in the order Avro, Protobuf, Json; their
        // `as_str` labels are exactly the strings being compared against.
        [Self::Avro, Self::Protobuf, Self::Json]
            .into_iter()
            .find(|candidate| s.eq_ignore_ascii_case(candidate.as_str()))
            .ok_or_else(|| KrafkaError::schema_registry(format!("unknown schema type: '{s}'")))
    }
}
/// A named reference from one schema to a specific version of another
/// subject's schema.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct SchemaReference {
    /// Name under which the referenced schema is referred to.
    pub name: String,
    /// Subject that holds the referenced schema.
    pub subject: String,
    /// Version of the referenced schema within `subject`.
    pub version: SchemaVersion,
}
impl SchemaReference {
pub fn new(
name: impl Into<String>,
subject: impl Into<String>,
version: SchemaVersion,
) -> Self {
Self {
name: name.into(),
subject: subject.into(),
version,
}
}
}
/// A schema as known to the registry, with optional subject/version
/// information and references to other schemas.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct Schema {
    /// Registry-assigned identifier.
    pub id: SchemaId,
    /// Language the schema text is written in.
    pub schema_type: SchemaType,
    /// The schema definition text.
    pub schema: String,
    /// Version within `subject`, when known.
    pub version: Option<SchemaVersion>,
    /// Subject the schema is registered under, when known.
    pub subject: Option<String>,
    /// Other schemas this schema refers to.
    pub references: Vec<SchemaReference>,
}
impl Schema {
pub fn new(id: SchemaId, schema_type: SchemaType, schema: impl Into<String>) -> Self {
Self {
id,
schema_type,
schema: schema.into(),
version: None,
subject: None,
references: Vec::new(),
}
}
pub fn with_subject(mut self, subject: impl Into<String>, version: SchemaVersion) -> Self {
self.subject = Some(subject.into());
self.version = Some(version);
self
}
pub fn with_references(mut self, references: Vec<SchemaReference>) -> Self {
self.references = references;
self
}
}
/// First byte of a Confluent-framed message.
const MAGIC_BYTE: u8 = 0;
/// Header length in bytes: the magic byte followed by a 32-bit schema id.
const HEADER_SIZE: usize = 1 + std::mem::size_of::<u32>();
/// Frames `payload` with the wire header: the magic byte, then `schema_id`
/// as four big-endian bytes, then the payload itself.
pub fn encode_wire_format(schema_id: SchemaId, payload: &[u8]) -> Bytes {
    let total_len = HEADER_SIZE + payload.len();
    let mut framed = BytesMut::with_capacity(total_len);
    framed.put_u8(MAGIC_BYTE);
    framed.put_slice(&schema_id.to_be_bytes());
    framed.put_slice(payload);
    framed.freeze()
}
/// Splits a framed message into its schema id and a borrowed view of the
/// payload that follows the header.
///
/// # Errors
///
/// Returns the header validation error when `data` is shorter than the
/// header or does not start with the magic byte.
pub fn decode_wire_format(data: &[u8]) -> Result<(SchemaId, &[u8])> {
    validate_wire_header(data).map(|schema_id| (schema_id, &data[HEADER_SIZE..]))
}
/// Like [`decode_wire_format`], but returns the payload as a [`Bytes`] slice
/// of the input buffer.
///
/// # Errors
///
/// Returns the header validation error when `data` is shorter than the
/// header or does not start with the magic byte.
pub fn decode_wire_format_bytes(data: &Bytes) -> Result<(SchemaId, Bytes)> {
    validate_wire_header(data).map(|schema_id| (schema_id, data.slice(HEADER_SIZE..)))
}
/// Checks the wire header at the start of `data` and returns the schema id
/// it carries.
///
/// # Errors
///
/// Returns a serialization error when `data` holds fewer than
/// `HEADER_SIZE` bytes, or when its first byte is not `MAGIC_BYTE`. The
/// length check takes precedence.
fn validate_wire_header(data: &[u8]) -> Result<SchemaId> {
    let Some(header) = data.get(..HEADER_SIZE) else {
        return Err(KrafkaError::serialization(format!(
            "wire format data too short: expected at least {HEADER_SIZE} bytes, got {}",
            data.len()
        )));
    };
    let magic = header[0];
    if magic != MAGIC_BYTE {
        return Err(KrafkaError::serialization(format!(
            "invalid wire format magic byte: expected 0x{MAGIC_BYTE:02X}, got 0x{:02X}",
            magic
        )));
    }
    // The id occupies the four bytes after the magic byte, big-endian.
    let id_bytes = [header[1], header[2], header[3], header[4]];
    Ok(u32::from_be_bytes(id_bytes))
}
/// Outcome of inspecting the leading bytes of a message with
/// [`detect_wire_format`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DetectedWireFormat {
    /// A complete Confluent header was found.
    Confluent {
        /// Schema id read from the header.
        schema_id: SchemaId,
        /// Byte offset at which the payload starts.
        payload_offset: usize,
    },
    /// A complete Glue header was found.
    Glue {
        /// Schema version id read from the header.
        version_id: GlueSchemaVersionId,
        /// Byte offset at which the payload starts.
        payload_offset: usize,
    },
    /// The first byte matches the Confluent magic byte but the data is too
    /// short to hold a full header.
    InvalidConfluent,
    /// The first byte matches the Glue header version byte but the header is
    /// truncated or carries an unrecognised compression byte.
    InvalidGlue,
    /// The data is empty or starts with a byte that matches neither format.
    Unknown,
}
/// Classifies `data` by its leading byte and, when a full header is
/// present, extracts the schema identifier and payload offset.
///
/// No registry lookup is performed; only the header bytes are examined.
pub fn detect_wire_format(data: &[u8]) -> DetectedWireFormat {
    match data.first().copied() {
        None => DetectedWireFormat::Unknown,
        Some(MAGIC_BYTE) if data.len() < HEADER_SIZE => DetectedWireFormat::InvalidConfluent,
        Some(MAGIC_BYTE) => DetectedWireFormat::Confluent {
            schema_id: u32::from_be_bytes([data[1], data[2], data[3], data[4]]),
            payload_offset: HEADER_SIZE,
        },
        Some(glue::GLUE_HEADER_VERSION_BYTE) => {
            let Some(header) = data.get(..glue::GLUE_HEADER_SIZE) else {
                return DetectedWireFormat::InvalidGlue;
            };
            // The second header byte must be one of the known compression markers.
            let known_compression = [
                glue::GLUE_COMPRESSION_NONE_BYTE,
                glue::GLUE_COMPRESSION_ZLIB_BYTE,
            ]
            .contains(&header[1]);
            if !known_compression {
                return DetectedWireFormat::InvalidGlue;
            }
            // The remainder of the header is the 16-byte schema version id.
            let mut version_bytes = [0u8; 16];
            version_bytes.copy_from_slice(&header[2..]);
            DetectedWireFormat::Glue {
                version_id: GlueSchemaVersionId::from_bytes(version_bytes),
                payload_offset: glue::GLUE_HEADER_SIZE,
            }
        }
        Some(_) => DetectedWireFormat::Unknown,
    }
}
/// Registry-independent schema format reported for a decoded message.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SchemaFormat {
    /// Avro.
    Avro,
    /// JSON.
    Json,
    /// Protobuf.
    Protobuf,
    /// No recognised header was found, so the format is not known.
    Unknown,
}
impl From<SchemaType> for SchemaFormat {
fn from(value: SchemaType) -> Self {
match value {
SchemaType::Avro => Self::Avro,
SchemaType::Json => Self::Json,
SchemaType::Protobuf => Self::Protobuf,
}
}
}
impl From<glue::GlueDataFormat> for SchemaFormat {
fn from(value: glue::GlueDataFormat) -> Self {
match value {
glue::GlueDataFormat::Avro => Self::Avro,
glue::GlueDataFormat::Json => Self::Json,
glue::GlueDataFormat::Protobuf => Self::Protobuf,
}
}
}
/// Schema information attached to a decoded message, tagged by the registry
/// flavour it came from.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SchemaMetadata {
    /// Schema resolved through a Confluent-style registry.
    Confluent(Schema),
    /// Schema resolved through a Glue registry.
    Glue(GlueSchema),
}
impl SchemaMetadata {
pub fn schema_format(&self) -> SchemaFormat {
match self {
Self::Confluent(schema) => schema.schema_type.into(),
Self::Glue(schema) => schema.data_format.into(),
}
}
}
/// Result of [`WireFormatDecoder::decode`].
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecodedMessage {
    /// Format of the schema that describes `payload`, or
    /// [`SchemaFormat::Unknown`] when no valid header was detected.
    pub schema_format: SchemaFormat,
    /// Message bytes with the header removed; the untouched input when no
    /// valid header was detected.
    pub payload: Bytes,
    /// Schema fetched from the registry, when a header was detected.
    pub schema_metadata: Option<SchemaMetadata>,
}
/// Decoder that detects the framing of a message and resolves its schema
/// through whichever borrowed registry matches that framing.
#[derive(Clone, Copy, Default)]
pub struct WireFormatDecoder<'a> {
    /// Registry consulted for Confluent-framed messages, if configured.
    confluent: Option<&'a dyn ErasedSchemaRegistryClient>,
    /// Registry consulted for Glue-framed messages, if configured.
    glue: Option<&'a dyn ErasedGlueSchemaRegistryClient>,
}
/// Reports only which registries are configured; the registries themselves
/// are not printed.
impl fmt::Debug for WireFormatDecoder<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_confluent = self.confluent.is_some();
        let has_glue = self.glue.is_some();
        // NOTE(review): the label is "SchemaDecoder", not this type's name —
        // confirm whether that is intentional before changing it.
        let mut out = f.debug_struct("SchemaDecoder");
        out.field("has_confluent", &has_confluent);
        out.field("has_glue", &has_glue);
        out.finish()
    }
}
impl<'a> WireFormatDecoder<'a> {
pub fn new() -> Self {
Self::default()
}
pub fn confluent(registry: &'a impl SchemaRegistryClient) -> Self {
Self::new().with_confluent(registry)
}
pub fn glue(registry: &'a impl GlueSchemaRegistryClient) -> Self {
Self::new().with_glue(registry)
}
pub fn with_confluent(mut self, registry: &'a impl SchemaRegistryClient) -> Self {
self.confluent = Some(registry);
self
}
pub fn with_glue(mut self, registry: &'a impl GlueSchemaRegistryClient) -> Self {
self.glue = Some(registry);
self
}
pub async fn decode(&self, data: Bytes) -> Result<DecodedMessage> {
match detect_wire_format(&data) {
DetectedWireFormat::Confluent { schema_id, .. } => {
let registry = self.confluent.ok_or_else(|| {
KrafkaError::config(
"schema decoder missing Confluent registry for Confluent-framed payload",
)
})?;
let (_, payload) = decode_wire_format_bytes(&data)?;
let schema = registry.get_schema_by_id_erased(schema_id).await?;
Ok(DecodedMessage {
schema_format: schema.schema_type.into(),
payload,
schema_metadata: Some(SchemaMetadata::Confluent(schema)),
})
}
DetectedWireFormat::Glue { version_id, .. } => {
let registry = self.glue.ok_or_else(|| {
KrafkaError::config(
"schema decoder missing Glue registry for Glue-framed payload",
)
})?;
let (_, payload) = glue::decode_glue_wire_format_bytes(&data)?;
let schema = registry.get_schema_by_version_id_erased(version_id).await?;
Ok(DecodedMessage {
schema_format: schema.data_format.into(),
payload,
schema_metadata: Some(SchemaMetadata::Glue(schema)),
})
}
DetectedWireFormat::InvalidConfluent
| DetectedWireFormat::InvalidGlue
| DetectedWireFormat::Unknown => Ok(DecodedMessage {
schema_format: SchemaFormat::Unknown,
payload: data,
schema_metadata: None,
}),
}
}
}
/// How the registry subject is derived from a topic and/or record name.
#[non_exhaustive]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SubjectNameStrategy {
    /// `<topic>-key` or `<topic>-value` (the default).
    #[default]
    TopicName,
    /// The record name on its own.
    RecordName,
    /// `<topic>-<record name>`.
    TopicRecordName,
}
/// Builds the invalid-state error reported when a lookup for `id` is
/// cancelled before it produced a result.
fn schema_lookup_cancelled_error(id: SchemaId) -> KrafkaError {
    let message = format!("schema lookup cancelled before completion for id {id}");
    KrafkaError::invalid_state(message)
}
impl SubjectNameStrategy {
    /// Computes the subject name for a message.
    ///
    /// * `topic` – topic the message belongs to.
    /// * `record_name` – record name; required by the `RecordName` and
    ///   `TopicRecordName` strategies, ignored by `TopicName`.
    /// * `is_key` – selects the `key` or `value` suffix for `TopicName`;
    ///   ignored by the other strategies.
    ///
    /// # Errors
    ///
    /// Returns a config error when the strategy needs a record name and
    /// `record_name` is `None`.
    pub fn subject_name(
        &self,
        topic: &str,
        record_name: Option<&str>,
        is_key: bool,
    ) -> Result<String> {
        match (self, record_name) {
            (Self::TopicName, _) if is_key => Ok(format!("{topic}-key")),
            (Self::TopicName, _) => Ok(format!("{topic}-value")),
            (Self::RecordName, Some(name)) => Ok(name.to_owned()),
            (Self::RecordName, None) => Err(KrafkaError::config(
                "RecordName strategy requires a record name",
            )),
            (Self::TopicRecordName, Some(name)) => Ok(format!("{topic}-{name}")),
            (Self::TopicRecordName, None) => Err(KrafkaError::config(
                "TopicRecordName strategy requires a record name",
            )),
        }
    }
}
/// Asynchronous client for a schema registry.
///
/// The lookup and registration methods must be implemented. The remaining
/// methods have default bodies that immediately resolve to a
/// "not implemented for this registry" error.
pub trait SchemaRegistryClient: Send + Sync {
    /// Fetches the schema with the given id.
    fn get_schema_by_id(&self, id: SchemaId) -> impl Future<Output = Result<Schema>> + Send + '_;

    /// Fetches the latest schema registered under `subject`.
    fn get_latest_schema<'a>(
        &'a self,
        subject: &'a str,
    ) -> impl Future<Output = Result<Schema>> + Send + 'a;

    /// Fetches a specific `version` of the schema registered under `subject`.
    fn get_schema_by_version<'a>(
        &'a self,
        subject: &'a str,
        version: SchemaVersion,
    ) -> impl Future<Output = Result<Schema>> + Send + 'a;

    /// Registers `schema` under `subject` and returns its id.
    fn register_schema<'a>(
        &'a self,
        subject: &'a str,
        schema: &'a str,
        schema_type: SchemaType,
        references: &'a [SchemaReference],
    ) -> impl Future<Output = Result<SchemaId>> + Send + 'a;

    /// Checks a schema for compatibility against `subject`.
    ///
    /// The default implementation resolves to an error.
    fn check_compatibility<'a>(
        &'a self,
        _subject: &'a str,
        _schema: &'a str,
        _schema_type: SchemaType,
        _references: &'a [SchemaReference],
    ) -> impl Future<Output = Result<bool>> + Send + 'a {
        let unsupported =
            KrafkaError::schema_registry("check_compatibility: not implemented for this registry");
        std::future::ready(Err(unsupported))
    }

    /// Deletes a subject, returning a list of schema versions.
    ///
    /// The default implementation resolves to an error.
    fn delete_subject<'a>(
        &'a self,
        _subject: &'a str,
        _permanent: bool,
    ) -> impl Future<Output = Result<Vec<SchemaVersion>>> + Send + 'a {
        let unsupported =
            KrafkaError::schema_registry("delete_subject: not implemented for this registry");
        std::future::ready(Err(unsupported))
    }

    /// Lists subjects.
    ///
    /// The default implementation resolves to an error.
    fn get_subjects(&self) -> impl Future<Output = Result<Vec<String>>> + Send + '_ {
        let unsupported =
            KrafkaError::schema_registry("get_subjects: not implemented for this registry");
        std::future::ready(Err(unsupported))
    }

    /// Lists the versions registered under a subject.
    ///
    /// The default implementation resolves to an error.
    fn get_versions<'a>(
        &'a self,
        _subject: &'a str,
    ) -> impl Future<Output = Result<Vec<SchemaVersion>>> + Send + 'a {
        let unsupported =
            KrafkaError::schema_registry("get_versions: not implemented for this registry");
        std::future::ready(Err(unsupported))
    }
}
/// Cache-management operations shared by schema caches, independent of the
/// identifier type they are keyed by.
pub trait AnySchemaCache: Send + Sync {
    /// Identifier type used as the cache key.
    type Id: Copy + Send + Sync;

    /// Number of cached entries.
    fn cache_len(&self) -> usize;

    /// Whether the cache holds no entries.
    fn cache_is_empty(&self) -> bool;

    /// Removes every cached entry.
    fn clear_cache(&self);

    /// Removes the entry for `id`, if any.
    fn invalidate(&self, id: Self::Id);

    /// Invalidates every cached entry.
    fn invalidate_all(&self);

    /// Populates the cache for each of `ids`, returning a boxed future.
    fn warm_cache<'a>(
        &'a self,
        ids: &'a [Self::Id],
    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
}
/// Wrapper around a registry client `C` that caches schemas by id and
/// coalesces concurrent lookups of the same id.
pub struct CachedSchemaRegistry<C> {
    /// The wrapped registry client.
    inner: C,
    /// Cached schemas keyed by id.
    cache: RwLock<HashMap<SchemaId, Schema>>,
    /// Ids in the order they were inserted; the front is evicted first when
    /// the cache is bounded and full.
    insertion_order: RwLock<VecDeque<SchemaId>>,
    /// Upper bound on cached entries; `None` means unbounded.
    max_entries: Option<usize>,
    /// Counter used to hand out a distinct token to each in-flight fetch.
    in_flight_token: AtomicU64,
    /// Incremented on every invalidation so that fetches started before it
    /// can skip inserting their result.
    invalidation_generation: AtomicU64,
    /// By-id fetches currently in progress, with the callers waiting on them.
    in_flight: Mutex<HashMap<SchemaId, SchemaInFlightEntry>>,
}
/// Bookkeeping for one in-progress by-id fetch.
#[derive(Default)]
struct SchemaInFlightEntry {
    /// Token identifying the fetch that owns this entry.
    token: u64,
    /// Channels of callers waiting for the owning fetch to finish.
    waiters: Vec<oneshot::Sender<Result<Schema>>>,
}
/// Entry bound used by `CachedSchemaRegistry::new`.
const DEFAULT_MAX_CACHE_ENTRIES: usize = 1000;
impl<C: SchemaRegistryClient> CachedSchemaRegistry<C> {
/// Wraps `inner` in a cache bounded to `DEFAULT_MAX_CACHE_ENTRIES` entries.
pub fn new(inner: C) -> Self {
    let default_bound = DEFAULT_MAX_CACHE_ENTRIES;
    Self::with_max_entries(inner, default_bound)
}
/// Wraps `inner` in an unbounded cache whose map is pre-sized for
/// `capacity` entries.
///
/// No entry limit is set, so nothing is ever evicted.
pub fn with_capacity(inner: C, capacity: usize) -> Self {
    Self {
        inner,
        cache: RwLock::new(HashMap::with_capacity(capacity)),
        // Insertion order is only recorded when `max_entries` is set, so an
        // unbounded cache never pushes here; reserve nothing for it.
        insertion_order: RwLock::new(VecDeque::new()),
        max_entries: None,
        in_flight_token: AtomicU64::new(0),
        invalidation_generation: AtomicU64::new(0),
        in_flight: Mutex::new(HashMap::new()),
    }
}
pub fn with_max_entries(inner: C, max_entries: usize) -> Self {
let max_entries = max_entries.max(1);
Self {
inner,
cache: RwLock::new(HashMap::with_capacity(max_entries)),
insertion_order: RwLock::new(VecDeque::with_capacity(max_entries)),
max_entries: Some(max_entries),
in_flight_token: AtomicU64::new(0),
invalidation_generation: AtomicU64::new(0),
in_flight: Mutex::new(HashMap::new()),
}
}
pub fn inner(&self) -> &C {
&self.inner
}
/// Number of schemas currently cached.
pub fn cache_len(&self) -> usize {
    let entries = self.cache.read();
    entries.len()
}
/// Whether no schema is currently cached.
pub fn cache_is_empty(&self) -> bool {
    let entries = self.cache.read();
    entries.is_empty()
}
/// Empties the schema map and then the insertion-order queue, taking each
/// write lock in turn. In-flight lookups are not touched.
fn clear_cache_storage(&self) {
    {
        let mut entries = self.cache.write();
        entries.clear();
    }
    let mut order = self.insertion_order.write();
    order.clear();
}
/// Drops every cached schema and cancels all in-flight lookups.
///
/// The invalidation generation is bumped first, then the in-flight table is
/// drained, the storage is cleared, and finally every waiter of a drained
/// lookup is sent the "lookup cancelled" error.
pub fn clear_cache(&self) {
    self.invalidation_generation.fetch_add(1, Ordering::SeqCst);
    // The in-flight lock is released at the end of this statement, before
    // the cache locks are taken.
    let drained: Vec<_> = self.in_flight.lock().drain().collect();
    self.clear_cache_storage();
    for (id, entry) in drained {
        entry.waiters.into_iter().for_each(|waiter| {
            // A waiter that already went away is not an error.
            let _ = waiter.send(Err(schema_lookup_cancelled_error(id)));
        });
    }
}
/// Removes `schema_id` from the cache and cancels an in-flight lookup for
/// it, if there is one.
///
/// Waiters of the cancelled lookup receive the "lookup cancelled" error.
/// The invalidation generation is bumped as well.
pub fn invalidate(&self, schema_id: SchemaId) {
    self.invalidation_generation.fetch_add(1, Ordering::SeqCst);
    // The in-flight lock is released at the end of this statement.
    let cancelled = self.in_flight.lock().remove(&schema_id);
    self.cache.write().remove(&schema_id);
    self.insertion_order
        .write()
        .retain(|cached_id| *cached_id != schema_id);
    let pending = cancelled.map(|entry| entry.waiters).unwrap_or_default();
    pending.into_iter().for_each(|waiter| {
        let _ = waiter.send(Err(schema_lookup_cancelled_error(schema_id)));
    });
}
pub fn invalidate_all(&self) {
self.clear_cache();
}
/// Looks up each id in `schema_ids` one after another so that it ends up in
/// the cache. Ids that repeat within the slice are looked up once.
///
/// # Errors
///
/// Stops at, and returns, the first lookup error.
pub async fn warm_cache(&self, schema_ids: &[SchemaId]) -> Result<()> {
    let mut already_requested = HashSet::with_capacity(schema_ids.len());
    let unique_ids = schema_ids
        .iter()
        .copied()
        .filter(|id| already_requested.insert(*id));
    for id in unique_ids {
        self.get_schema_by_id_impl(id).await?;
    }
    Ok(())
}
/// Returns the schema for `id`, serving it from the cache when possible and
/// otherwise fetching it from the inner registry.
///
/// Concurrent lookups of the same uncached id are coalesced: the first
/// caller (the leader) performs the fetch, while later callers register a
/// oneshot channel and receive a clone of the leader's result.
///
/// # Errors
///
/// Propagates the inner registry's error. A waiting caller receives the
/// "lookup cancelled" error when the in-flight entry is removed by
/// `invalidate`/`clear_cache` or when the leader is dropped before finishing.
async fn get_schema_by_id_impl(&self, id: SchemaId) -> Result<Schema> {
    // Fast path: already cached.
    if let Some(schema) = self.cache.read().get(&id) {
        debug!(schema_id = id, "schema cache hit");
        return Ok(schema.clone());
    }

    // Decide, under the in-flight lock, whether to wait for an existing
    // fetch or to become the leader of a new one.
    let (waiter_rx, leader_token) = {
        let mut in_flight = self.in_flight.lock();
        // Re-check the cache now that the in-flight lock is held.
        if let Some(schema) = self.cache.read().get(&id) {
            return Ok(schema.clone());
        }
        if let Some(entry) = in_flight.get_mut(&id) {
            let (tx, rx) = oneshot::channel();
            entry.waiters.push(tx);
            (Some(rx), None)
        } else {
            let token = self.in_flight_token.fetch_add(1, Ordering::SeqCst) + 1;
            in_flight.insert(
                id,
                SchemaInFlightEntry {
                    token,
                    waiters: Vec::new(),
                },
            );
            (None, Some(token))
        }
    };

    if let Some(rx) = waiter_rx {
        // A dropped sender is reported as a cancelled lookup.
        return rx.await.map_err(|_| schema_lookup_cancelled_error(id))?;
    }

    /// Removes the leader's in-flight entry and notifies its waiters if the
    /// leader goes away before marking the fetch as completed.
    struct InFlightSchemaFetchGuard<'a> {
        in_flight: &'a Mutex<HashMap<SchemaId, SchemaInFlightEntry>>,
        id: SchemaId,
        token: u64,
        completed: bool,
    }

    impl Drop for InFlightSchemaFetchGuard<'_> {
        fn drop(&mut self) {
            if self.completed {
                return;
            }
            let waiters = {
                let mut in_flight = self.in_flight.lock();
                // Only touch the entry if it still belongs to this fetch.
                if matches!(in_flight.get(&self.id), Some(entry) if entry.token == self.token) {
                    in_flight
                        .remove(&self.id)
                        .map(|entry| entry.waiters)
                        .unwrap_or_default()
                } else {
                    Vec::new()
                }
            };
            for waiter in waiters {
                let _ = waiter.send(Err(schema_lookup_cancelled_error(self.id)));
            }
        }
    }

    // Defensive: the block above yields a token whenever it does not yield a
    // receiver, so this branch is not expected to be taken.
    let Some(leader_token) = leader_token else {
        return Err(schema_lookup_cancelled_error(id));
    };
    let mut guard = InFlightSchemaFetchGuard {
        in_flight: &self.in_flight,
        id,
        token: leader_token,
        completed: false,
    };

    let result = self.inner.get_schema_by_id(id).await;

    if let Ok(schema) = &result {
        // The entry disappears (or changes token) when the id was invalidated
        // while the fetch was running; in that case the result is not cached.
        let should_insert = {
            let in_flight = self.in_flight.lock();
            matches!(in_flight.get(&id), Some(entry) if entry.token == leader_token)
        };
        if should_insert {
            debug!(schema_id = id, "schema cache miss — fetched from registry");
            // Shared helper: replaces an existing entry, or evicts the oldest
            // entry when bounded and full before inserting.
            self.insert_cache_entry(id, schema.clone());
        } else {
            debug!(
                schema_id = id,
                "schema fetch completed after invalidation; skipping cache insert"
            );
        }
    }

    // Take over the waiters that joined this fetch, if it is still ours.
    let waiters = {
        let mut in_flight = self.in_flight.lock();
        if matches!(in_flight.get(&id), Some(entry) if entry.token == leader_token) {
            in_flight
                .remove(&id)
                .map(|entry| entry.waiters)
                .unwrap_or_default()
        } else {
            Vec::new()
        }
    };
    for waiter in waiters {
        let _ = waiter.send(result.clone());
    }
    guard.completed = true;
    result
}
/// Fetches the latest schema for `subject` from the inner registry and
/// caches it by id, unless an invalidation happened while the fetch was
/// running.
///
/// The inner registry is always queried; the cache is only written to.
async fn get_latest_schema_impl(&self, subject: &str) -> Result<Schema> {
    let generation_at_start = self.invalidation_generation.load(Ordering::SeqCst);
    let fetched = self.inner.get_latest_schema(subject).await?;
    let schema_id = fetched.id;
    self.insert_cache_entry_if_current(schema_id, fetched.clone(), generation_at_start);
    Ok(fetched)
}
/// Fetches `version` of `subject` from the inner registry and caches it by
/// id, unless an invalidation happened while the fetch was running.
///
/// The inner registry is always queried; the cache is only written to.
async fn get_schema_by_version_impl(
    &self,
    subject: &str,
    version: SchemaVersion,
) -> Result<Schema> {
    let generation_at_start = self.invalidation_generation.load(Ordering::SeqCst);
    let fetched = self.inner.get_schema_by_version(subject, version).await?;
    let schema_id = fetched.id;
    self.insert_cache_entry_if_current(schema_id, fetched.clone(), generation_at_start);
    Ok(fetched)
}
/// Forwards the registration to the inner registry; the cache is not
/// touched.
async fn register_schema_impl(
    &self,
    subject: &str,
    schema: &str,
    schema_type: SchemaType,
    references: &[SchemaReference],
) -> Result<SchemaId> {
    let registration = self
        .inner
        .register_schema(subject, schema, schema_type, references);
    registration.await
}
/// Returns the schema for `id`, from the cache when present and otherwise
/// from the inner registry.
pub async fn get_schema_by_id(&self, id: SchemaId) -> Result<Schema> {
    let lookup = self.get_schema_by_id_impl(id);
    lookup.await
}
/// Fetches the latest schema for `subject` from the inner registry and
/// records it in the cache.
pub async fn get_latest_schema(&self, subject: &str) -> Result<Schema> {
    let lookup = self.get_latest_schema_impl(subject);
    lookup.await
}
/// Fetches `version` of `subject` from the inner registry and records it in
/// the cache.
pub async fn get_schema_by_version(
    &self,
    subject: &str,
    version: SchemaVersion,
) -> Result<Schema> {
    let lookup = self.get_schema_by_version_impl(subject, version);
    lookup.await
}
/// Registers `schema` under `subject` through the inner registry and
/// returns the id it reports.
pub async fn register_schema(
    &self,
    subject: &str,
    schema: &str,
    schema_type: SchemaType,
    references: &[SchemaReference],
) -> Result<SchemaId> {
    let registration = self.register_schema_impl(subject, schema, schema_type, references);
    registration.await
}
fn insert_cache_entry(&self, id: SchemaId, schema: Schema) {
let mut cache = self.cache.write();
if let Some(existing) = cache.get_mut(&id) {
*existing = schema;
return;
}
if let Some(max_entries) = self.max_entries {
let mut insertion_order = self.insertion_order.write();
if cache.len() >= max_entries
&& let Some(evicted) = insertion_order.pop_front()
{
cache.remove(&evicted);
}
insertion_order.push_back(id);
}
cache.insert(id, schema);
}
/// Stores `schema` under `id` only if the invalidation generation still
/// equals `observed_generation`; otherwise logs and leaves the cache
/// untouched.
fn insert_cache_entry_if_current(
    &self,
    id: SchemaId,
    schema: Schema,
    observed_generation: u64,
) {
    let current_generation = self.invalidation_generation.load(Ordering::SeqCst);
    if current_generation == observed_generation {
        self.insert_cache_entry(id, schema);
    } else {
        debug!(
            schema_id = id,
            "schema fetch completed after invalidation; skipping cache insert"
        );
    }
}
}
/// Prints the current number of cached entries and the configured bound;
/// the inner client is not printed, so `C` need not implement `Debug`.
impl<C> fmt::Debug for CachedSchemaRegistry<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let cached = self.cache.read().len();
        let mut out = f.debug_struct("CachedSchemaRegistry");
        out.field("cache_len", &cached);
        out.field("max_entries", &self.max_entries);
        out.finish()
    }
}
/// The caching wrapper is itself a registry client: lookups and
/// registration go through the `*_impl` methods, everything else is
/// forwarded to the inner client unchanged.
impl<C: SchemaRegistryClient> SchemaRegistryClient for CachedSchemaRegistry<C> {
    async fn get_schema_by_id(&self, id: SchemaId) -> Result<Schema> {
        Self::get_schema_by_id_impl(self, id).await
    }

    async fn get_latest_schema(&self, subject: &str) -> Result<Schema> {
        Self::get_latest_schema_impl(self, subject).await
    }

    async fn get_schema_by_version(&self, subject: &str, version: SchemaVersion) -> Result<Schema> {
        Self::get_schema_by_version_impl(self, subject, version).await
    }

    async fn register_schema(
        &self,
        subject: &str,
        schema: &str,
        schema_type: SchemaType,
        references: &[SchemaReference],
    ) -> Result<SchemaId> {
        Self::register_schema_impl(self, subject, schema, schema_type, references).await
    }

    fn check_compatibility<'a>(
        &'a self,
        subject: &'a str,
        schema: &'a str,
        schema_type: SchemaType,
        references: &'a [SchemaReference],
    ) -> impl Future<Output = Result<bool>> + Send + 'a {
        let delegate = &self.inner;
        delegate.check_compatibility(subject, schema, schema_type, references)
    }

    fn delete_subject<'a>(
        &'a self,
        subject: &'a str,
        permanent: bool,
    ) -> impl Future<Output = Result<Vec<SchemaVersion>>> + Send + 'a {
        let delegate = &self.inner;
        delegate.delete_subject(subject, permanent)
    }

    fn get_subjects(&self) -> impl Future<Output = Result<Vec<String>>> + Send + '_ {
        let delegate = &self.inner;
        delegate.get_subjects()
    }

    fn get_versions<'a>(
        &'a self,
        subject: &'a str,
    ) -> impl Future<Output = Result<Vec<SchemaVersion>>> + Send + 'a {
        let delegate = &self.inner;
        delegate.get_versions(subject)
    }
}
impl<C: SchemaRegistryClient> AnySchemaCache for CachedSchemaRegistry<C> {
type Id = SchemaId;
fn cache_len(&self) -> usize {
Self::cache_len(self)
}
fn cache_is_empty(&self) -> bool {
Self::cache_is_empty(self)
}
fn clear_cache(&self) {
Self::clear_cache(self)
}
fn invalidate(&self, id: Self::Id) {
Self::invalidate(self, id)
}
fn invalidate_all(&self) {
Self::invalidate_all(self)
}
fn warm_cache<'a>(
&'a self,
ids: &'a [Self::Id],
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
Box::pin(async move { Self::warm_cache(self, ids).await })
}
}
/// Encoder that turns a serialized payload into a framed message for a
/// given topic.
pub trait SchemaEncoder: Send + Sync {
    /// Encodes `payload` for `topic`.
    ///
    /// `record_name` and `is_key` are made available for deriving the
    /// subject; how they are used is up to the implementation.
    fn encode(
        &self,
        payload: Bytes,
        topic: &str,
        record_name: Option<&str>,
        is_key: bool,
    ) -> Pin<Box<dyn Future<Output = Result<Bytes>> + Send + '_>>;
}
/// [`SchemaEncoder`] that registers one fixed schema per subject and frames
/// payloads with the resulting schema id.
#[cfg(feature = "schema-registry")]
#[cfg_attr(docsrs, doc(cfg(feature = "schema-registry")))]
pub struct ConfluentSchemaEncoder<C> {
    /// Registry used to register the schema.
    registry: C,
    /// Schema definition text sent on registration.
    schema: String,
    /// Language of `schema`.
    schema_type: SchemaType,
    /// How subjects are derived from topic/record name.
    strategy: SubjectNameStrategy,
    /// References sent along with the schema on registration.
    references: Vec<SchemaReference>,
    /// Schema ids already obtained, keyed by subject.
    id_cache: parking_lot::RwLock<std::collections::HashMap<String, SchemaId>>,
}
#[cfg(feature = "schema-registry")]
impl<C: SchemaRegistryClient> ConfluentSchemaEncoder<C> {
    /// Starts building an encoder.
    pub fn builder() -> ConfluentSchemaEncoderBuilder<C> {
        ConfluentSchemaEncoderBuilder::new()
    }

    /// Returns the schema id for `subject`, registering the schema with the
    /// registry the first time a subject is seen and remembering the id
    /// afterwards.
    async fn resolve_id(&self, subject: &str) -> Result<SchemaId> {
        // The read lock is released at the end of this statement, before any
        // await point.
        let remembered = self.id_cache.read().get(subject).copied();
        if let Some(id) = remembered {
            return Ok(id);
        }
        let registered = self
            .registry
            .register_schema(subject, &self.schema, self.schema_type, &self.references)
            .await?;
        self.id_cache
            .write()
            .insert(subject.to_string(), registered);
        Ok(registered)
    }
}
/// Prints the schema type, the subject strategy and how many subjects have
/// a remembered id; the registry and schema text are not printed.
#[cfg(feature = "schema-registry")]
impl<C: SchemaRegistryClient> fmt::Debug for ConfluentSchemaEncoder<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let cached_subjects = self.id_cache.read().len();
        let mut out = f.debug_struct("ConfluentSchemaEncoder");
        out.field("schema_type", &self.schema_type);
        out.field("strategy", &self.strategy);
        out.field("cached_subjects", &cached_subjects);
        out.finish()
    }
}
#[cfg(feature = "schema-registry")]
impl<C: SchemaRegistryClient> SchemaEncoder for ConfluentSchemaEncoder<C> {
    /// Derives the subject from the configured strategy, resolves its schema
    /// id and prepends the wire header to `payload`.
    fn encode(
        &self,
        payload: Bytes,
        topic: &str,
        record_name: Option<&str>,
        is_key: bool,
    ) -> Pin<Box<dyn Future<Output = Result<Bytes>> + Send + '_>> {
        // The returned future may only borrow `self`, so the string
        // arguments are copied before entering it.
        let owned_topic = String::from(topic);
        let owned_record_name: Option<String> = record_name.map(String::from);
        let encoding = async move {
            let subject =
                self.strategy
                    .subject_name(&owned_topic, owned_record_name.as_deref(), is_key)?;
            let schema_id = self.resolve_id(&subject).await?;
            Ok(encode_wire_format(schema_id, &payload))
        };
        Box::pin(encoding)
    }
}
/// Builder for [`ConfluentSchemaEncoder`].
#[cfg(feature = "schema-registry")]
#[cfg_attr(docsrs, doc(cfg(feature = "schema-registry")))]
pub struct ConfluentSchemaEncoderBuilder<C> {
    /// Registry to use; must be set before `build`.
    registry: Option<C>,
    /// Schema text; must be set before `build`.
    schema: Option<String>,
    /// Language of the schema.
    schema_type: SchemaType,
    /// Subject naming strategy.
    strategy: SubjectNameStrategy,
    /// References sent with the schema.
    references: Vec<SchemaReference>,
}
#[cfg(feature = "schema-registry")]
impl<C: SchemaRegistryClient> ConfluentSchemaEncoderBuilder<C> {
    /// Starts with no registry, no schema, Avro as the schema type, the
    /// `TopicName` strategy and no references.
    fn new() -> Self {
        ConfluentSchemaEncoderBuilder {
            registry: None,
            schema: None,
            schema_type: SchemaType::Avro,
            strategy: SubjectNameStrategy::TopicName,
            references: Vec::new(),
        }
    }

    /// Sets the registry client.
    pub fn registry(self, registry: C) -> Self {
        ConfluentSchemaEncoderBuilder {
            registry: Some(registry),
            ..self
        }
    }

    /// Sets the schema text together with its type.
    pub fn schema(self, schema: impl Into<String>, schema_type: SchemaType) -> Self {
        ConfluentSchemaEncoderBuilder {
            schema: Some(schema.into()),
            schema_type,
            ..self
        }
    }

    /// Sets the subject naming strategy.
    pub fn strategy(self, strategy: SubjectNameStrategy) -> Self {
        ConfluentSchemaEncoderBuilder { strategy, ..self }
    }

    /// Sets the schema references.
    pub fn references(self, references: Vec<SchemaReference>) -> Self {
        ConfluentSchemaEncoderBuilder { references, ..self }
    }

    /// Finishes the builder.
    ///
    /// # Errors
    ///
    /// Returns a config error when the registry is missing, or — checked
    /// second — when the schema is missing.
    pub fn build(self) -> Result<ConfluentSchemaEncoder<C>> {
        let ConfluentSchemaEncoderBuilder {
            registry,
            schema,
            schema_type,
            strategy,
            references,
        } = self;
        let Some(registry) = registry else {
            return Err(KrafkaError::config(
                "ConfluentSchemaEncoder: registry must be set",
            ));
        };
        let Some(schema) = schema else {
            return Err(KrafkaError::config(
                "ConfluentSchemaEncoder: schema must be set",
            ));
        };
        Ok(ConfluentSchemaEncoder {
            registry,
            schema,
            schema_type,
            strategy,
            references,
            id_cache: parking_lot::RwLock::new(std::collections::HashMap::new()),
        })
    }
}
/// Decoder that turns a framed message back into its serialized payload.
pub trait SchemaDecoder: Send + Sync {
    /// Decodes `payload` received on `topic`.
    ///
    /// `is_key` tells the implementation whether the bytes are a record key
    /// or a record value.
    fn decode(
        &self,
        payload: Bytes,
        topic: &str,
        is_key: bool,
    ) -> Pin<Box<dyn Future<Output = Result<Bytes>> + Send + '_>>;
}
/// Stateless [`SchemaDecoder`] that strips the Confluent wire header when
/// one is present.
#[cfg(feature = "schema-registry")]
#[cfg_attr(docsrs, doc(cfg(feature = "schema-registry")))]
#[derive(Debug, Default, Clone, Copy)]
pub struct ConfluentSchemaDecoder;
#[cfg(feature = "schema-registry")]
impl ConfluentSchemaDecoder {
    /// Creates the decoder; it carries no state.
    pub fn new() -> Self {
        ConfluentSchemaDecoder
    }
}
#[cfg(feature = "schema-registry")]
impl SchemaDecoder for ConfluentSchemaDecoder {
    /// Returns the bytes after the Confluent header when `payload` carries a
    /// complete one; any other input is returned unchanged.
    ///
    /// The topic and key/value flag are not consulted, and no registry
    /// lookup is performed.
    fn decode(
        &self,
        payload: Bytes,
        _topic: &str,
        _is_key: bool,
    ) -> Pin<Box<dyn Future<Output = Result<Bytes>> + Send + '_>> {
        let decoding = async move {
            let framed = matches!(
                detect_wire_format(&payload),
                DetectedWireFormat::Confluent { .. }
            );
            if framed {
                decode_wire_format_bytes(&payload).map(|(_, inner)| inner)
            } else {
                Ok(payload)
            }
        };
        Box::pin(decoding)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
use super::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use tokio::sync::{Notify, Semaphore};
/// Test helper: unwraps an `Ok`, panicking with the displayed error
/// otherwise.
fn ok<T, E: std::fmt::Display>(result: std::result::Result<T, E>) -> T {
    result.unwrap_or_else(|err| unreachable!("expected Ok(..), got Err({err})"))
}
/// Test helper: unwraps an `Err`, panicking when the result is `Ok`.
fn err<T, E: std::fmt::Display>(result: std::result::Result<T, E>) -> E {
    let Err(error) = result else {
        unreachable!("expected Err(..), got Ok(..)");
    };
    error
}
/// Test helper: unwraps the output of a joined task, panicking with the
/// join error otherwise.
fn join_ok<T>(result: std::result::Result<T, tokio::task::JoinError>) -> T {
    result.unwrap_or_else(|err| unreachable!("spawned task failed unexpectedly: {err}"))
}
/// Encoding and then decoding yields the original schema id and payload.
#[test]
fn test_wire_format_roundtrip() {
    let original: &[u8] = b"hello world";
    let framed = encode_wire_format(42, original);
    let (schema_id, body) = ok(decode_wire_format(&framed));
    assert_eq!(schema_id, 42);
    assert_eq!(body, original);
}
/// An empty payload encodes to exactly the header and decodes back to an
/// empty payload.
#[test]
fn test_wire_format_empty_payload() {
    let framed = encode_wire_format(1, b"");
    assert_eq!(framed.len(), HEADER_SIZE);
    let (schema_id, body) = ok(decode_wire_format(&framed));
    assert_eq!(schema_id, 1);
    assert!(body.is_empty());
}
/// The largest schema id survives a round trip.
#[test]
fn test_wire_format_max_schema_id() {
    let framed = encode_wire_format(u32::MAX, b"data");
    let (schema_id, _body) = ok(decode_wire_format(&framed));
    assert_eq!(schema_id, u32::MAX);
}
/// The header is the magic byte followed by the id in big-endian order.
#[test]
fn test_wire_format_header_bytes() {
    let framed = encode_wire_format(256, b"x");
    let (header, body) = framed.split_at(5);
    assert_eq!(header, &[0x00, 0x00, 0x00, 0x01, 0x00]);
    assert_eq!(body, b"x");
}
/// A wrong first byte is rejected with a message mentioning the magic byte.
#[test]
fn test_wire_format_invalid_magic_byte() {
    let bad_magic = [0x01, 0, 0, 0, 1, 0x42];
    let outcome = decode_wire_format(&bad_magic);
    assert!(outcome.is_err());
    let message = err(outcome).to_string();
    assert!(message.contains("magic byte"));
}
/// Fewer bytes than a header are rejected with a "too short" message.
#[test]
fn test_wire_format_too_short() {
    let outcome = decode_wire_format(&[0x00, 0, 0]);
    assert!(outcome.is_err());
    let message = err(outcome).to_string();
    assert!(message.contains("too short"));
}
/// Empty input is rejected.
#[test]
fn test_wire_format_empty_data() {
    let nothing: &[u8] = &[];
    assert!(decode_wire_format(nothing).is_err());
}
/// A Confluent-framed message is detected with its id and a payload offset
/// of five.
#[test]
fn test_detect_wire_format_confluent() {
    let framed = encode_wire_format(42, b"data");
    let expected = DetectedWireFormat::Confluent {
        schema_id: 42,
        payload_offset: 5,
    };
    assert_eq!(detect_wire_format(&framed), expected);
}
/// A Glue-framed message is detected with its version id and a payload
/// offset of eighteen.
#[test]
fn test_detect_wire_format_glue() {
    let version_id: GlueSchemaVersionId =
        "550e8400-e29b-41d4-a716-446655440000".parse().unwrap();
    let framed =
        glue::encode_glue_wire_format(version_id, b"data", glue::GlueCompression::None).unwrap();
    let expected = DetectedWireFormat::Glue {
        version_id,
        payload_offset: 18,
    };
    assert_eq!(detect_wire_format(&framed), expected);
}
/// Empty input and an unrecognised first byte are both reported as unknown.
#[test]
fn test_detect_wire_format_unknown() {
    let empty: &[u8] = &[];
    let unrecognised: &[u8] = &[0x99, 0x00, 0x00];
    for input in [empty, unrecognised] {
        assert_eq!(detect_wire_format(input), DetectedWireFormat::Unknown);
    }
}
/// A schema id of zero is still detected as Confluent framing.
#[test]
fn test_detect_wire_format_confluent_schema_id_zero() {
    let framed = [MAGIC_BYTE, 0x00, 0x00, 0x00, 0x00, 0x41];
    let expected = DetectedWireFormat::Confluent {
        schema_id: 0,
        payload_offset: HEADER_SIZE,
    };
    assert_eq!(detect_wire_format(&framed), expected);
}
/// A recognised first byte followed by a truncated header is reported as
/// the matching "invalid" variant.
#[test]
fn test_detect_wire_format_invalid_known_headers() {
    let short_confluent = [MAGIC_BYTE, 0x01, 0x02];
    assert_eq!(
        detect_wire_format(&short_confluent),
        DetectedWireFormat::InvalidConfluent
    );

    let short_glue = [
        glue::GLUE_HEADER_VERSION_BYTE,
        glue::GLUE_COMPRESSION_NONE_BYTE,
    ];
    assert_eq!(
        detect_wire_format(&short_glue),
        DetectedWireFormat::InvalidGlue
    );
}
/// The all-zero version id is accepted by detection like any other id.
#[test]
fn test_detect_wire_format_glue_accepts_non_rfc_uuid_layout() {
    let nil: GlueSchemaVersionId = "00000000-0000-0000-0000-000000000000".parse().unwrap();
    let framed = glue::encode_glue_wire_format(nil, b"data", glue::GlueCompression::None).unwrap();
    let expected = DetectedWireFormat::Glue {
        version_id: nil,
        payload_offset: glue::GLUE_HEADER_SIZE,
    };
    assert_eq!(detect_wire_format(&framed), expected);
}
/// Stateless Glue registry stand-in used by the decoder tests.
struct DecoderMockGlueRegistry;
/// Answers every lookup with a JSON schema carrying the requested id, and
/// every registration with a fixed version id.
impl glue::GlueSchemaRegistryClient for DecoderMockGlueRegistry {
    async fn get_schema_by_version_id(
        &self,
        id: GlueSchemaVersionId,
    ) -> Result<glue::GlueSchema> {
        let definition = r#"{"type":"object"}"#;
        let schema = glue::GlueSchema::new(id, glue::GlueDataFormat::Json, definition);
        Ok(schema)
    }

    async fn register_schema(
        &self,
        _schema_name: &str,
        _schema: &str,
        _data_format: glue::GlueDataFormat,
    ) -> Result<GlueSchemaVersionId> {
        let fixed_id = "550e8400-e29b-41d4-a716-446655440000"
            .parse::<GlueSchemaVersionId>()
            .unwrap();
        Ok(fixed_id)
    }
}
/// A Confluent-framed message is decoded through the Confluent registry:
/// the header is stripped and the fetched schema is attached as metadata.
#[tokio::test]
async fn test_schema_decoder_confluent() {
    let registry = CachedSchemaRegistry::new(MockRegistry::new());
    // Fixed: the argument had been garbled to `®istry` (a mis-decoded
    // `&reg` entity); the decoder borrows the registry.
    let decoder = WireFormatDecoder::confluent(&registry);
    let encoded = encode_wire_format(7, b"payload");
    let decoded = ok(decoder.decode(encoded).await);
    assert_eq!(decoded.schema_format, SchemaFormat::Avro);
    assert_eq!(&decoded.payload[..], b"payload");
    match decoded.schema_metadata {
        Some(SchemaMetadata::Confluent(schema)) => assert_eq!(schema.id, 7),
        _ => unreachable!("expected confluent metadata"),
    }
}
/// A Glue-framed message is decoded through the Glue registry: the header
/// is stripped and the fetched schema is attached as metadata.
#[tokio::test]
async fn test_schema_decoder_glue() {
    let registry = glue::CachedGlueSchemaRegistry::new(DecoderMockGlueRegistry);
    // Fixed: the argument had been garbled to `®istry` (a mis-decoded
    // `&reg` entity); the decoder borrows the registry.
    let decoder = WireFormatDecoder::glue(&registry);
    let version_id: GlueSchemaVersionId =
        "550e8400-e29b-41d4-a716-446655440000".parse().unwrap();
    let encoded =
        glue::encode_glue_wire_format(version_id, b"payload", glue::GlueCompression::None)
            .unwrap();
    let decoded = ok(decoder.decode(encoded).await);
    assert_eq!(decoded.schema_format, SchemaFormat::Json);
    assert_eq!(&decoded.payload[..], b"payload");
    match decoded.schema_metadata {
        Some(SchemaMetadata::Glue(schema)) => assert_eq!(schema.schema_version_id, version_id),
        _ => unreachable!("expected glue metadata"),
    }
}
/// Unframed data is returned unchanged with an unknown format and no
/// metadata.
#[tokio::test]
async fn test_schema_decoder_unknown_passthrough() {
    let plain = Bytes::from_static(b"plain-data");
    let decoder = WireFormatDecoder::new();
    let message = ok(decoder.decode(plain).await);
    assert_eq!(message.schema_format, SchemaFormat::Unknown);
    assert_eq!(&message.payload[..], b"plain-data");
    assert!(message.schema_metadata.is_none());
}
/// Decoding a Confluent-framed message without a Confluent registry fails
/// with a message that names the missing registry.
#[tokio::test]
async fn test_schema_decoder_missing_registry_is_error() {
    let framed = encode_wire_format(1, b"x");
    let decoder = WireFormatDecoder::new();
    let outcome = decoder.decode(framed).await;
    assert!(outcome.is_err());
    let message = err(outcome).to_string();
    assert!(message.contains("missing Confluent registry"));
}
/// A truncated Confluent header is passed through unchanged.
#[tokio::test]
async fn test_schema_decoder_truncated_confluent_header_passthrough() {
    let truncated = Bytes::from_static(&[MAGIC_BYTE, 0x00, 0x01]);
    let decoder = WireFormatDecoder::new();
    let message = ok(decoder.decode(truncated.clone()).await);
    assert_eq!(message.schema_format, SchemaFormat::Unknown);
    assert_eq!(message.payload, truncated);
    assert!(message.schema_metadata.is_none());
}
/// Four bytes starting with `glue::GLUE_HEADER_VERSION_BYTE` are passed
/// through unchanged as `Unknown` rather than being decoded or rejected.
#[tokio::test]
async fn test_schema_decoder_truncated_glue_header_passthrough() {
    let short_frame = Bytes::from_static(&[glue::GLUE_HEADER_VERSION_BYTE, 0x00, 0x01, 0x02]);
    let decoder = WireFormatDecoder::new();
    let outcome = ok(decoder.decode(short_frame.clone()).await);

    assert_eq!(outcome.schema_format, SchemaFormat::Unknown);
    assert_eq!(outcome.payload, short_frame);
    assert!(outcome.schema_metadata.is_none());
}
/// `SubjectNameStrategy::default()` is the `TopicName` strategy.
#[test]
fn test_subject_default_is_topic_name() {
    let strategy = SubjectNameStrategy::default();
    assert_eq!(strategy, SubjectNameStrategy::TopicName);
}
/// `TopicName` maps topic `orders` with `is_key = true` to `orders-key`.
#[test]
fn test_subject_topic_name_key() {
    let strategy = SubjectNameStrategy::TopicName;
    let subject = ok(strategy.subject_name("orders", None, true));
    assert_eq!(subject, "orders-key");
}
/// `TopicName` maps topic `orders` with `is_key = false` to `orders-value`.
#[test]
fn test_subject_topic_name_value() {
    let strategy = SubjectNameStrategy::TopicName;
    let subject = ok(strategy.subject_name("orders", None, false));
    assert_eq!(subject, "orders-value");
}
/// `RecordName` uses the record name itself as the subject.
#[test]
fn test_subject_record_name() {
    let strategy = SubjectNameStrategy::RecordName;
    let subject = ok(strategy.subject_name("orders", Some("com.example.Order"), false));
    assert_eq!(subject, "com.example.Order");
}
/// `RecordName` without a record name is an error.
#[test]
fn test_subject_record_name_missing() {
    let strategy = SubjectNameStrategy::RecordName;
    let outcome = strategy.subject_name("orders", None, false);
    assert!(outcome.is_err());
}
/// `TopicRecordName` joins topic `orders` and record `Order` as `orders-Order`.
#[test]
fn test_subject_topic_record_name() {
    let strategy = SubjectNameStrategy::TopicRecordName;
    let subject = ok(strategy.subject_name("orders", Some("Order"), true));
    assert_eq!(subject, "orders-Order");
}
/// `TopicRecordName` without a record name is an error.
#[test]
fn test_subject_topic_record_name_missing() {
    let strategy = SubjectNameStrategy::TopicRecordName;
    let outcome = strategy.subject_name("orders", None, true);
    assert!(outcome.is_err());
}
/// `Display` for each `SchemaType` variant yields its upper-case name.
#[test]
fn test_schema_type_display() {
    let expectations = [
        (SchemaType::Avro, "AVRO"),
        (SchemaType::Protobuf, "PROTOBUF"),
        (SchemaType::Json, "JSON"),
    ];
    for (schema_type, rendered) in expectations {
        assert_eq!(schema_type.to_string(), rendered);
    }
}
/// Upper-case type names parse into the matching `SchemaType` variant.
#[test]
fn test_schema_type_from_str() {
    let expectations = [
        ("AVRO", SchemaType::Avro),
        ("PROTOBUF", SchemaType::Protobuf),
        ("JSON", SchemaType::Json),
    ];
    for (text, schema_type) in expectations {
        assert_eq!(ok(text.parse::<SchemaType>()), schema_type);
    }
}
/// Parsing an unsupported type name fails, and the error names the input.
#[test]
fn test_schema_type_from_str_unknown() {
    let outcome = "XML".parse::<SchemaType>();
    assert!(outcome.is_err());
    let message = err(outcome).to_string();
    assert!(message.contains("XML"));
}
/// `Schema::new` stores id, type and definition and leaves version, subject
/// and references empty.
#[test]
fn test_schema_new() {
    let definition = r#"{"type":"string"}"#;
    let schema = Schema::new(1, SchemaType::Avro, definition);

    assert_eq!(schema.id, 1);
    assert_eq!(schema.schema_type, SchemaType::Avro);
    assert_eq!(schema.schema, definition);
    assert!(schema.version.is_none());
    assert!(schema.subject.is_none());
    assert!(schema.references.is_empty());
}
/// `with_subject` records both the subject name and the version.
#[test]
fn test_schema_with_subject() {
    let base = Schema::new(1, SchemaType::Avro, "{}");
    let schema = base.with_subject("my-topic-value", 3);

    assert_eq!(schema.subject.as_deref(), Some("my-topic-value"));
    assert_eq!(schema.version, Some(3));
}
/// `with_references` stores the supplied reference list unchanged.
#[test]
fn test_schema_with_references() {
    let expected = vec![SchemaReference::new("Ref", "ref-subject", 1)];
    let base = Schema::new(1, SchemaType::Avro, "{}");
    let schema = base.with_references(expected.clone());
    assert_eq!(schema.references, expected);
}
/// `SchemaReference::new` stores name, subject and version as given.
#[test]
fn test_schema_reference_new() {
    let reference = SchemaReference::new("com.example.Address", "address-value", 2);

    assert_eq!(reference.name, "com.example.Address");
    assert_eq!(reference.subject, "address-value");
    assert_eq!(reference.version, 2);
}
/// Registry test double that answers immediately and counts id lookups.
struct MockRegistry {
    /// Number of `get_schema_by_id` calls received so far.
    get_by_id_calls: AtomicU32,
}

impl MockRegistry {
    /// Creates a mock whose lookup counter starts at zero.
    fn new() -> Self {
        let get_by_id_calls = AtomicU32::default();
        Self { get_by_id_calls }
    }

    /// Returns how many id lookups have reached this mock.
    fn get_by_id_call_count(&self) -> u32 {
        let counter = &self.get_by_id_calls;
        counter.load(Ordering::SeqCst)
    }
}
/// Canned `SchemaRegistryClient` answers; only id lookups are counted.
impl SchemaRegistryClient for MockRegistry {
    /// Bumps the lookup counter and returns an Avro string schema with `id`.
    async fn get_schema_by_id(&self, id: SchemaId) -> Result<Schema> {
        self.get_by_id_calls.fetch_add(1, Ordering::SeqCst);
        let schema = Schema::new(id, SchemaType::Avro, r#"{"type":"string"}"#);
        Ok(schema)
    }

    /// Returns schema id 100 tagged with `subject` at version 1.
    async fn get_latest_schema(&self, subject: &str) -> Result<Schema> {
        let base = Schema::new(100, SchemaType::Avro, r#"{"type":"string"}"#);
        Ok(base.with_subject(subject, 1))
    }

    /// Returns schema id 100 tagged with `subject` at the requested `version`.
    async fn get_schema_by_version(
        &self,
        subject: &str,
        version: SchemaVersion,
    ) -> Result<Schema> {
        let base = Schema::new(100, SchemaType::Avro, r#"{"type":"string"}"#);
        Ok(base.with_subject(subject, version))
    }

    /// Ignores all arguments and reports id 42.
    async fn register_schema(
        &self,
        _subject: &str,
        _schema: &str,
        _schema_type: SchemaType,
        _references: &[SchemaReference],
    ) -> Result<SchemaId> {
        let registered: SchemaId = 42;
        Ok(registered)
    }
}
/// Registry test double whose lookups wait on a semaphore until the test
/// calls [`BlockingMockRegistry::release`].
struct BlockingMockRegistry {
    /// Number of `get_schema_by_id` calls received.
    get_by_id_calls: AtomicU32,
    /// Number of `get_latest_schema` calls received.
    get_latest_calls: AtomicU32,
    /// Number of `get_schema_by_version` calls received.
    get_by_version_calls: AtomicU32,
    /// Signalled each time a lookup enters the mock.
    started: Notify,
    /// Gate the lookups acquire a permit from before answering.
    release: Semaphore,
    /// Lookups counted since the last `release()`.
    waiting_calls: AtomicU32,
}

impl BlockingMockRegistry {
    /// Creates a mock with all counters at zero and no permits available.
    fn new() -> Self {
        Self {
            get_by_id_calls: AtomicU32::default(),
            get_latest_calls: AtomicU32::default(),
            get_by_version_calls: AtomicU32::default(),
            started: Notify::new(),
            release: Semaphore::new(0),
            waiting_calls: AtomicU32::default(),
        }
    }

    /// Number of id lookups that have reached this mock.
    fn get_by_id_call_count(&self) -> u32 {
        let counter = &self.get_by_id_calls;
        counter.load(Ordering::SeqCst)
    }

    /// Number of latest-schema lookups that have reached this mock.
    fn get_latest_call_count(&self) -> u32 {
        let counter = &self.get_latest_calls;
        counter.load(Ordering::SeqCst)
    }

    /// Number of by-version lookups that have reached this mock.
    fn get_by_version_call_count(&self) -> u32 {
        let counter = &self.get_by_version_calls;
        counter.load(Ordering::SeqCst)
    }

    /// Waits until the `started` notifier fires.
    async fn wait_started(&self) {
        let signal = self.started.notified();
        signal.await;
    }

    /// Resets the waiting-call counter and adds that many permits to the gate.
    fn release(&self) {
        let pending = self.waiting_calls.swap(0, Ordering::SeqCst) as usize;
        self.release.add_permits(pending);
    }
}
/// `SchemaRegistryClient` for the blocking mock.
///
/// Each lookup method runs the same sequence: bump its own call counter,
/// signal `started`, bump `waiting_calls`, then wait for a permit from the
/// `release` semaphore before returning a canned schema. `register_schema`
/// does not block.
impl SchemaRegistryClient for BlockingMockRegistry {
/// Counts the call, waits at the gate, then returns an Avro schema with `id`.
async fn get_schema_by_id(&self, id: SchemaId) -> Result<Schema> {
self.get_by_id_calls.fetch_add(1, Ordering::SeqCst);
// Wake any task currently parked in `wait_started()`.
self.started.notify_waiters();
// Counted so that `release()` knows how many permits to add.
self.waiting_calls.fetch_add(1, Ordering::SeqCst);
// NOTE(review): the permit is bound to `_` and so dropped at once; with
// tokio's `Semaphore` a dropped permit is handed back, so calls made after
// the first `release()` may pass without blocking. Tests that call this
// mock directly after a release appear to rely on that. Confirm before
// changing.
let _ = self
.release
.acquire()
.await
.expect("blocking registry release permit");
Ok(Schema::new(id, SchemaType::Avro, r#"{"type":"string"}"#))
}
/// Counts the call, waits at the gate, then returns schema id 100 tagged
/// with `subject` at version 1.
async fn get_latest_schema(&self, subject: &str) -> Result<Schema> {
self.get_latest_calls.fetch_add(1, Ordering::SeqCst);
self.started.notify_waiters();
self.waiting_calls.fetch_add(1, Ordering::SeqCst);
// Same gate as `get_schema_by_id`; the permit is dropped immediately.
let _ = self
.release
.acquire()
.await
.expect("blocking registry release permit");
Ok(Schema::new(100, SchemaType::Avro, r#"{"type":"string"}"#).with_subject(subject, 1))
}
/// Counts the call, waits at the gate, then returns schema id 100 tagged
/// with `subject` at the requested `version`.
async fn get_schema_by_version(
&self,
subject: &str,
version: SchemaVersion,
) -> Result<Schema> {
self.get_by_version_calls.fetch_add(1, Ordering::SeqCst);
self.started.notify_waiters();
self.waiting_calls.fetch_add(1, Ordering::SeqCst);
// Same gate as `get_schema_by_id`; the permit is dropped immediately.
let _ = self
.release
.acquire()
.await
.expect("blocking registry release permit");
Ok(Schema::new(100, SchemaType::Avro, r#"{"type":"string"}"#)
.with_subject(subject, version))
}
/// Ignores all arguments and reports id 42 without touching the gate.
async fn register_schema(
&self,
_subject: &str,
_schema: &str,
_schema_type: SchemaType,
_references: &[SchemaReference],
) -> Result<SchemaId> {
Ok(42)
}
}
/// The first lookup of an id reaches the inner registry; the second is served
/// from the cache and returns an equal schema.
#[tokio::test]
async fn test_cache_miss_then_hit() {
    let cached = CachedSchemaRegistry::new(MockRegistry::new());

    let fetched = ok(cached.get_schema_by_id(1).await);
    assert_eq!(cached.inner().get_by_id_call_count(), 1);
    assert_eq!(cached.cache_len(), 1);

    let replayed = ok(cached.get_schema_by_id(1).await);
    assert_eq!(cached.inner().get_by_id_call_count(), 1);
    assert_eq!(fetched, replayed);
}
/// Two distinct ids produce two cache entries and two inner calls; repeating
/// both lookups adds no further inner calls.
#[tokio::test]
async fn test_cache_different_ids() {
    let cached = CachedSchemaRegistry::new(MockRegistry::new());

    for id in [1, 2] {
        ok(cached.get_schema_by_id(id).await);
    }
    assert_eq!(cached.inner().get_by_id_call_count(), 2);
    assert_eq!(cached.cache_len(), 2);

    for id in [1, 2] {
        ok(cached.get_schema_by_id(id).await);
    }
    assert_eq!(cached.inner().get_by_id_call_count(), 2);
}
/// With the default constructor, `insertion_order` grows with each cached id
/// and shrinks when an id is invalidated.
#[tokio::test]
async fn test_default_cache_is_bounded_and_populates_insertion_order() {
    let cached = CachedSchemaRegistry::new(MockRegistry::new());

    for id in [1, 2] {
        ok(cached.get_schema_by_id(id).await);
    }
    assert_eq!(cached.cache_len(), 2);
    assert_eq!(cached.insertion_order.read().len(), 2);

    cached.invalidate(1);
    assert_eq!(cached.cache_len(), 1);
    assert_eq!(cached.insertion_order.read().len(), 1);
}
/// `clear_cache` empties the cache, so the next lookup of the same id goes
/// back to the inner registry.
#[tokio::test]
async fn test_cache_clear() {
    let cached = CachedSchemaRegistry::new(MockRegistry::new());

    ok(cached.get_schema_by_id(1).await);
    assert_eq!(cached.cache_len(), 1);

    cached.clear_cache();
    assert_eq!(cached.cache_len(), 0);

    ok(cached.get_schema_by_id(1).await);
    let inner_calls = cached.inner().get_by_id_call_count();
    assert_eq!(inner_calls, 2);
}
/// `invalidate(1)` removes only id 1: id 2 is still a cache hit, while id 1
/// needs a fresh inner lookup.
#[tokio::test]
async fn test_cache_invalidate_single_entry() {
    let cached = CachedSchemaRegistry::new(MockRegistry::new());

    for id in [1, 2] {
        ok(cached.get_schema_by_id(id).await);
    }
    assert_eq!(cached.cache_len(), 2);

    cached.invalidate(1);
    assert_eq!(cached.cache_len(), 1);

    // Untouched entry: no extra inner call.
    ok(cached.get_schema_by_id(2).await);
    assert_eq!(cached.inner().get_by_id_call_count(), 2);

    // Invalidated entry: one more inner call.
    ok(cached.get_schema_by_id(1).await);
    assert_eq!(cached.inner().get_by_id_call_count(), 3);
}
/// Invalidates id 7 while a fetch for it is still pending in the blocking
/// mock, then checks that the finished fetch left the cache empty and that a
/// second lookup reaches the inner registry again (two inner calls in total).
#[tokio::test]
async fn test_cache_invalidate_does_not_repopulate_from_inflight_fetch() {
let cached = Arc::new(CachedSchemaRegistry::new(BlockingMockRegistry::new()));
// Start the lookup on its own task; it waits at the mock's gate.
let first = {
let cached = cached.clone();
tokio::spawn(async move { ok(cached.get_schema_by_id(7).await) })
};
// Wait until the mock signals that the lookup has entered it.
cached.inner().wait_started().await;
cached.invalidate(7);
// Open the gate from a separate task after a short delay.
{
let cached = cached.clone();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
cached.inner().release();
});
}
// Bounded wait so a stuck fetch fails the test instead of hanging it.
let _ = tokio::time::timeout(std::time::Duration::from_secs(5), first)
.await
.expect("in-flight fetch did not complete")
.expect("in-flight task failed");
// The invalidated fetch must not have populated the cache.
assert_eq!(cached.cache_len(), 0);
let second = {
let cached = cached.clone();
tokio::spawn(async move { ok(cached.get_schema_by_id(7).await) })
};
cached.inner().wait_started().await;
cached.inner().release();
let _ = join_ok(second.await);
// Both lookups went to the inner registry.
assert_eq!(cached.inner().get_by_id_call_count(), 2);
}
/// Clears the whole cache while a fetch for id 7 is still pending in the
/// blocking mock, then checks that the finished fetch left the cache empty and
/// that a second lookup reaches the inner registry again (two inner calls).
#[tokio::test]
async fn test_cache_clear_does_not_repopulate_from_inflight_fetch() {
let cached = Arc::new(CachedSchemaRegistry::new(BlockingMockRegistry::new()));
// Start the lookup on its own task; it waits at the mock's gate.
let first = {
let cached = cached.clone();
tokio::spawn(async move { ok(cached.get_schema_by_id(7).await) })
};
// Wait until the mock signals that the lookup has entered it.
cached.inner().wait_started().await;
cached.clear_cache();
// Open the gate from a separate task after a short delay.
{
let cached = cached.clone();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
cached.inner().release();
});
}
// Bounded wait so a stuck fetch fails the test instead of hanging it.
let _ = tokio::time::timeout(std::time::Duration::from_secs(5), first)
.await
.expect("in-flight fetch did not complete")
.expect("in-flight task failed");
// The fetch that was pending during the clear must not have been cached.
assert_eq!(cached.cache_len(), 0);
let second = {
let cached = cached.clone();
tokio::spawn(async move { ok(cached.get_schema_by_id(7).await) })
};
cached.inner().wait_started().await;
cached.inner().release();
let _ = join_ok(second.await);
// Both lookups went to the inner registry.
assert_eq!(cached.inner().get_by_id_call_count(), 2);
}
/// With fetches for ids 7 and 8 both pending, invalidates only id 7 and checks
/// that id 8 is still cached afterwards (no new inner call) while id 7 needs
/// exactly one more inner call.
#[tokio::test]
async fn test_invalidate_single_id_does_not_block_other_inflight_cache_insert() {
let cached = Arc::new(CachedSchemaRegistry::new(BlockingMockRegistry::new()));
let id_cancelled = 7;
let id_unrelated = 8;
let t1 = {
let cached = cached.clone();
tokio::spawn(async move { ok(cached.get_schema_by_id(id_cancelled).await) })
};
let t2 = {
let cached = cached.clone();
tokio::spawn(async move { ok(cached.get_schema_by_id(id_unrelated).await) })
};
// Poll (bounded to 5s) until both lookups have reached the inner mock.
tokio::time::timeout(std::time::Duration::from_secs(5), async {
while cached.inner().get_by_id_call_count() < 2 {
tokio::task::yield_now().await;
}
})
.await
.expect("both in-flight lookups did not start");
// Invalidate one id while both fetches are pending, then open the gate.
cached.invalidate(id_cancelled);
cached.inner().release();
let _ = join_ok(t1.await);
let _ = join_ok(t2.await);
let calls_after_inflight = cached.inner().get_by_id_call_count();
// The unrelated id must be a cache hit: inner call count unchanged.
ok(cached.get_schema_by_id(id_unrelated).await);
assert_eq!(cached.inner().get_by_id_call_count(), calls_after_inflight);
// The invalidated id must miss and go to the inner mock once more.
let miss = {
let cached = cached.clone();
tokio::spawn(async move { ok(cached.get_schema_by_id(id_cancelled).await) })
};
cached.inner().wait_started().await;
cached.inner().release();
let _ = join_ok(miss.await);
assert_eq!(
cached.inner().get_by_id_call_count(),
calls_after_inflight + 1
);
}
/// `warm_cache` with repeated ids fetches each distinct id once, and later
/// lookups of those ids are all cache hits.
#[tokio::test]
async fn test_cache_warm_cache_deduplicates_ids() {
    let cached = CachedSchemaRegistry::new(MockRegistry::new());

    ok(cached.warm_cache(&[1, 2, 1, 2, 3]).await);
    assert_eq!(cached.inner().get_by_id_call_count(), 3);
    assert_eq!(cached.cache_len(), 3);

    for id in [1, 2, 3] {
        ok(cached.get_schema_by_id(id).await);
    }
    assert_eq!(cached.inner().get_by_id_call_count(), 3);
}
/// Two concurrent lookups of the same uncached id result in a single call to
/// the inner registry and return equal schemas.
#[tokio::test]
async fn test_cache_coalesces_concurrent_misses() {
let cached = Arc::new(CachedSchemaRegistry::new(BlockingMockRegistry::new()));
let first = {
let cached = cached.clone();
tokio::spawn(async move { ok(cached.get_schema_by_id(7).await) })
};
// Make sure the first lookup is already inside the mock before starting
// the second one.
cached.inner().wait_started().await;
let second = {
let cached = cached.clone();
tokio::spawn(async move { ok(cached.get_schema_by_id(7).await) })
};
// Give the second task a chance to run before the gate is opened.
tokio::task::yield_now().await;
cached.inner().release();
let first_schema = join_ok(first.await);
let second_schema = join_ok(second.await);
assert_eq!(first_schema, second_schema);
// Only one of the two lookups reached the inner registry.
assert_eq!(cached.inner().get_by_id_call_count(), 1);
}
/// Aborts a lookup while it is pending in the blocking mock, then checks that
/// a new lookup for the same id still reaches the inner registry and completes
/// instead of waiting on the aborted one.
#[tokio::test]
async fn test_cache_coalescer_cleans_up_when_leader_is_cancelled() {
let cached = Arc::new(CachedSchemaRegistry::new(BlockingMockRegistry::new()));
let first = {
let cached = cached.clone();
tokio::spawn(async move { ok(cached.get_schema_by_id(9).await) })
};
cached.inner().wait_started().await;
// Cancel the pending lookup and yield so the runtime can process the abort.
first.abort();
tokio::task::yield_now().await;
let second = {
let cached = cached.clone();
tokio::spawn(async move { ok(cached.get_schema_by_id(9).await) })
};
// Bounded wait: the second lookup must itself enter the inner mock.
tokio::time::timeout(
std::time::Duration::from_secs(5),
cached.inner().wait_started(),
)
.await
.expect("second lookup did not reach inner registry");
cached.inner().release();
// Bounded wait so a stuck second lookup fails the test instead of hanging.
let schema = tokio::time::timeout(std::time::Duration::from_secs(5), second)
.await
.expect("second lookup timed out")
.expect("second task failed");
assert_eq!(schema.id, 9);
}
/// A schema returned by `get_latest_schema` is cached by id, so a following
/// id lookup does not reach the inner registry.
#[tokio::test]
async fn test_cache_get_latest_populates_id_cache() {
    let cached = CachedSchemaRegistry::new(MockRegistry::new());

    let latest = ok(cached.get_latest_schema("test-value").await);
    assert_eq!(cached.cache_len(), 1);

    let looked_up = ok(cached.get_schema_by_id(latest.id).await);
    assert_eq!(cached.inner().get_by_id_call_count(), 0);
    assert_eq!(looked_up.id, latest.id);
}
/// Invalidates id 100 while a `get_latest_schema` call (which the mock answers
/// with id 100) is still pending, then checks that the result was not cached
/// and that a later id lookup has to go to the inner registry.
#[tokio::test]
async fn test_invalidate_drops_inflight_get_latest_cache_population() {
let cached = Arc::new(CachedSchemaRegistry::new(BlockingMockRegistry::new()));
let latest = {
let cached = cached.clone();
tokio::spawn(async move { ok(cached.get_latest_schema("test-value").await) })
};
// Poll (bounded to 5s) until the latest-schema call has reached the mock.
tokio::time::timeout(std::time::Duration::from_secs(5), async {
while cached.inner().get_latest_call_count() < 1 {
tokio::task::yield_now().await;
}
})
.await
.expect("latest lookup did not start");
cached.invalidate(100);
cached.inner().release();
let _ = join_ok(latest.await);
// The in-flight result must not have been inserted.
assert_eq!(cached.cache_len(), 0);
// NOTE(review): this lookup runs with no further `release()`; it appears
// to rely on the mock's gate staying open after the release above.
ok(cached.get_schema_by_id(100).await);
assert_eq!(cached.inner().get_by_id_call_count(), 1);
}
/// A schema returned by `get_schema_by_version` is cached by id, so a
/// following id lookup does not reach the inner registry.
#[tokio::test]
async fn test_cache_get_by_version_populates_id_cache() {
    let cached = CachedSchemaRegistry::new(MockRegistry::new());

    let versioned = ok(cached.get_schema_by_version("test-value", 1).await);
    assert_eq!(cached.cache_len(), 1);

    let looked_up = ok(cached.get_schema_by_id(versioned.id).await);
    assert_eq!(cached.inner().get_by_id_call_count(), 0);
    assert_eq!(looked_up.id, versioned.id);
}
/// Invalidates id 100 while a `get_schema_by_version` call (which the mock
/// answers with id 100) is still pending, then checks that the result was not
/// cached and that a later id lookup has to go to the inner registry.
#[tokio::test]
async fn test_invalidate_drops_inflight_get_by_version_cache_population() {
let cached = Arc::new(CachedSchemaRegistry::new(BlockingMockRegistry::new()));
let by_version = {
let cached = cached.clone();
tokio::spawn(async move { ok(cached.get_schema_by_version("test-value", 1).await) })
};
// Poll (bounded to 5s) until the by-version call has reached the mock.
tokio::time::timeout(std::time::Duration::from_secs(5), async {
while cached.inner().get_by_version_call_count() < 1 {
tokio::task::yield_now().await;
}
})
.await
.expect("version lookup did not start");
cached.invalidate(100);
cached.inner().release();
let _ = join_ok(by_version.await);
// The in-flight result must not have been inserted.
assert_eq!(cached.cache_len(), 0);
// NOTE(review): this lookup runs with no further `release()`; it appears
// to rely on the mock's gate staying open after the release above.
ok(cached.get_schema_by_id(100).await);
assert_eq!(cached.inner().get_by_id_call_count(), 1);
}
/// `register_schema` on the cached registry returns the id reported by the
/// inner mock (42).
#[tokio::test]
async fn test_cache_register_forwards() {
    let cached = CachedSchemaRegistry::new(MockRegistry::new());
    let outcome = cached
        .register_schema("test-value", "{}", SchemaType::Avro, &[])
        .await;
    let registered = ok(outcome);
    assert_eq!(registered, 42);
}
/// Compile-time check that the core schema types are `Send + Sync`.
#[test]
fn test_send_sync() {
    fn require_send_sync<T>()
    where
        T: Send + Sync,
    {
    }

    require_send_sync::<Schema>();
    require_send_sync::<SchemaReference>();
    require_send_sync::<SchemaType>();
    require_send_sync::<SubjectNameStrategy>();
    require_send_sync::<CachedSchemaRegistry<MockRegistry>>();
}
/// Compile-time check that `ErasedSchemaRegistryClient` can be used as a
/// trait object.
#[test]
fn test_erased_object_safe() {
    fn _accepts_trait_object(_client: &dyn ErasedSchemaRegistryClient) {}
}
/// The `Debug` output of `CachedSchemaRegistry` includes a `cache_len` entry.
#[test]
fn test_cached_debug() {
    let registry = CachedSchemaRegistry::new(MockRegistry::new());
    let rendered = format!("{registry:?}");
    assert!(rendered.contains("cache_len"));
}
/// Encoding then decoding returns the original schema id and payload.
#[test]
fn test_wire_format_bytes_roundtrip() {
    let body = b"hello world";
    let framed = encode_wire_format(42, body);

    let (schema_id, unframed) = ok(decode_wire_format_bytes(&framed));
    assert_eq!(schema_id, 42);
    assert_eq!(&unframed[..], body);
}
/// A frame with an empty body decodes to its id and an empty payload.
#[test]
fn test_wire_format_bytes_empty_payload() {
    let framed = encode_wire_format(1, b"");

    let (schema_id, unframed) = ok(decode_wire_format_bytes(&framed));
    assert_eq!(schema_id, 1);
    assert!(unframed.is_empty());
}
/// A frame whose first byte is `0x01` is rejected with a "magic byte" error.
#[test]
fn test_wire_format_bytes_invalid_magic() {
    let bad_frame = Bytes::from_static(&[0x01, 0, 0, 0, 1, 0x42]);

    let outcome = decode_wire_format_bytes(&bad_frame);
    assert!(outcome.is_err());
    let message = err(outcome).to_string();
    assert!(message.contains("magic byte"));
}
/// A three-byte input is rejected with a "too short" error.
#[test]
fn test_wire_format_bytes_too_short() {
    let short_frame = Bytes::from_static(&[0x00, 0, 0]);

    let outcome = decode_wire_format_bytes(&short_frame);
    assert!(outcome.is_err());
    let message = err(outcome).to_string();
    assert!(message.contains("too short"));
}
/// Decodes a frame via `decode_wire_format_bytes` and checks the payload
/// content.
// NOTE(review): despite the name, only the payload bytes are compared here;
// buffer sharing itself is not asserted.
#[test]
fn test_wire_format_bytes_zero_copy() {
    let framed = encode_wire_format(99, b"shared");
    let (_schema_id, unframed) = ok(decode_wire_format_bytes(&framed));
    assert_eq!(&unframed[..], b"shared");
}
/// Lower-case type names parse into the matching `SchemaType` variant.
#[test]
fn test_schema_type_from_str_lowercase() {
    let expectations = [
        ("avro", SchemaType::Avro),
        ("protobuf", SchemaType::Protobuf),
        ("json", SchemaType::Json),
    ];
    for (text, schema_type) in expectations {
        assert_eq!(ok(text.parse::<SchemaType>()), schema_type);
    }
}
/// Mixed-case type names parse into the matching `SchemaType` variant.
#[test]
fn test_schema_type_from_str_mixed_case() {
    let expectations = [
        ("Avro", SchemaType::Avro),
        ("ProtobuF", SchemaType::Protobuf),
        ("Json", SchemaType::Json),
    ];
    for (text, schema_type) in expectations {
        assert_eq!(ok(text.parse::<SchemaType>()), schema_type);
    }
}
/// A registry built with `with_capacity` starts empty and caches a lookup.
#[tokio::test]
async fn test_cache_with_capacity() {
    let cached = CachedSchemaRegistry::with_capacity(MockRegistry::new(), 100);
    assert_eq!(cached.cache_len(), 0);

    ok(cached.get_schema_by_id(1).await);
    assert_eq!(cached.cache_len(), 1);
}
/// With `with_max_entries(_, 1)`, caching id 2 pushes out id 1, so looking up
/// id 1 again costs a third inner call.
#[tokio::test]
async fn test_cache_with_max_entries_evicts_oldest_entry() {
    let cached = CachedSchemaRegistry::with_max_entries(MockRegistry::new(), 1);

    for id in [1, 2] {
        ok(cached.get_schema_by_id(id).await);
    }
    assert_eq!(cached.cache_len(), 1);
    assert_eq!(cached.inner().get_by_id_call_count(), 2);

    ok(cached.get_schema_by_id(1).await);
    assert_eq!(cached.inner().get_by_id_call_count(), 3);
}
/// Calls `CachedSchemaRegistry`'s lookup and registration methods from a
/// module whose `use` list does not bring `SchemaRegistryClient` into scope.
mod inherent_api_tests {
    use crate::Result;
    use crate::schema_registry::{
        CachedSchemaRegistry, Schema, SchemaReference, SchemaType, SchemaVersion,
    };

    /// Stateless registry stub with fixed answers.
    struct InherentMockRegistry;

    // The trait is named by full path on purpose: importing it here would
    // defeat the point of this module.
    impl crate::schema_registry::SchemaRegistryClient for InherentMockRegistry {
        /// Returns an Avro string schema carrying `id`.
        async fn get_schema_by_id(&self, id: u32) -> Result<Schema> {
            let schema = Schema::new(id, SchemaType::Avro, r#"{"type":"string"}"#);
            Ok(schema)
        }

        /// Returns schema id 7 tagged with `subject` at version 1.
        async fn get_latest_schema(&self, subject: &str) -> Result<Schema> {
            let owned_subject = subject.to_string();
            let base = Schema::new(7, SchemaType::Avro, r#"{"type":"string"}"#);
            Ok(base.with_subject(owned_subject, 1))
        }

        /// Returns schema id 9 tagged with `subject` at the requested `version`.
        async fn get_schema_by_version(
            &self,
            subject: &str,
            version: SchemaVersion,
        ) -> Result<Schema> {
            let owned_subject = subject.to_string();
            let base = Schema::new(9, SchemaType::Avro, r#"{"type":"string"}"#);
            Ok(base.with_subject(owned_subject, version))
        }

        /// Ignores all arguments and reports id 42.
        async fn register_schema(
            &self,
            _subject: &str,
            _schema: &str,
            _schema_type: SchemaType,
            _references: &[SchemaReference],
        ) -> Result<u32> {
            Ok(42)
        }
    }

    /// Every lookup and the registration call resolve and return the stub's
    /// fixed ids.
    #[tokio::test]
    async fn cached_schema_registry_methods_work_without_trait_import() {
        let registry = CachedSchemaRegistry::new(InherentMockRegistry);

        let fetched = registry.get_schema_by_id(1).await.unwrap();
        assert_eq!(fetched.id, 1);

        let newest = registry.get_latest_schema("orders-value").await.unwrap();
        assert_eq!(newest.id, 7);

        let versioned = registry
            .get_schema_by_version("orders-value", 3)
            .await
            .unwrap();
        assert_eq!(versioned.id, 9);

        let definition = r#"{"type":"string"}"#;
        let new_id = registry
            .register_schema("orders-value", definition, SchemaType::Avro, &[])
            .await
            .unwrap();
        assert_eq!(new_id, 42);
    }
}
/// Drives `CachedSchemaRegistry` purely through a `dyn AnySchemaCache`
/// reference: warm, inspect, invalidate one id, then invalidate everything.
#[tokio::test]
async fn test_any_schema_cache_trait_for_confluent_cache() {
    let cached = CachedSchemaRegistry::new(MockRegistry::new());
    let erased: &dyn AnySchemaCache<Id = SchemaId> = &cached;

    // Three ids requested, two of them distinct.
    ok(erased.warm_cache(&[11, 12, 11]).await);
    assert_eq!(erased.cache_len(), 2);
    assert!(!erased.cache_is_empty());

    erased.invalidate(11);
    assert_eq!(erased.cache_len(), 1);

    erased.invalidate_all();
    assert!(erased.cache_is_empty());
}
/// `SchemaEncoder` test double that frames payloads with preset schema ids.
struct FixedIdEncoder {
    /// Schema id written when encoding a key.
    key_id: SchemaId,
    /// Schema id written when encoding a value.
    value_id: SchemaId,
}

impl SchemaEncoder for FixedIdEncoder {
    /// Frames `payload` with `key_id` or `value_id` depending on `is_key`;
    /// topic and record name are ignored. The returned future is already
    /// resolved.
    fn encode(
        &self,
        payload: Bytes,
        _topic: &str,
        _record_name: Option<&str>,
        is_key: bool,
    ) -> Pin<Box<dyn Future<Output = Result<Bytes>> + Send + '_>> {
        let schema_id = match is_key {
            true => self.key_id,
            false => self.value_id,
        };
        let framed = encode_wire_format(schema_id, &payload);
        Box::pin(std::future::ready(Ok(framed)))
    }
}
/// Encoding a value adds five bytes of framing and embeds the id the mock
/// registry reports (42) in front of the unchanged payload.
#[cfg(feature = "schema-registry")]
#[tokio::test]
async fn test_confluent_schema_encoder_value_only() {
    let encoder = ConfluentSchemaEncoder::builder()
        .registry(MockRegistry::new())
        .schema(r#"{"type":"string"}"#, SchemaType::Avro)
        .build()
        .unwrap();

    let body = Bytes::from_static(b"hello");
    let framed = encoder
        .encode(body.clone(), "orders", None, false)
        .await
        .unwrap();
    assert_eq!(framed.len(), 5 + body.len());

    let (schema_id, unframed) = decode_wire_format(&framed).unwrap();
    assert_eq!(schema_id, 42);
    assert_eq!(unframed, &body[..]);
}
/// Encoding a key adds five bytes of framing and embeds the id the mock
/// registry reports (42) in front of the unchanged payload.
#[cfg(feature = "schema-registry")]
#[tokio::test]
async fn test_confluent_schema_encoder_key_only() {
    let encoder = ConfluentSchemaEncoder::builder()
        .registry(MockRegistry::new())
        .schema(r#"{"type":"string"}"#, SchemaType::Avro)
        .build()
        .unwrap();

    let key_body = Bytes::from_static(b"my-key");
    let framed = encoder
        .encode(key_body.clone(), "orders", None, true)
        .await
        .unwrap();
    assert_eq!(framed.len(), 5 + key_body.len());

    let (schema_id, unframed) = decode_wire_format(&framed).unwrap();
    assert_eq!(schema_id, 42);
    assert_eq!(unframed, &key_body[..]);
}
/// One encoder can frame both a key and a value; with the mock registry both
/// frames carry id 42.
#[cfg(feature = "schema-registry")]
#[tokio::test]
async fn test_confluent_schema_encoder_both() {
    let encoder = ConfluentSchemaEncoder::builder()
        .registry(MockRegistry::new())
        .schema(r#"{"type":"string"}"#, SchemaType::Avro)
        .build()
        .unwrap();

    let key_frame = encoder
        .encode(Bytes::from_static(b"k"), "t", None, true)
        .await
        .unwrap();
    let value_frame = encoder
        .encode(Bytes::from_static(b"v"), "t", None, false)
        .await
        .unwrap();

    let (key_schema_id, _) = decode_wire_format(&key_frame).unwrap();
    let (value_schema_id, _) = decode_wire_format(&value_frame).unwrap();
    assert_eq!(key_schema_id, 42);
    assert_eq!(value_schema_id, 42);
}
/// Five encodes for the same topic and role leave exactly one entry in the
/// encoder's `id_cache`.
#[cfg(feature = "schema-registry")]
#[tokio::test]
async fn test_confluent_schema_encoder_id_cached() {
    let encoder = ConfluentSchemaEncoder::builder()
        .registry(MockRegistry::new())
        .schema(r#"{"type":"string"}"#, SchemaType::Avro)
        .build()
        .unwrap();

    for _round in 0..5 {
        let framed = encoder
            .encode(Bytes::from_static(b"x"), "t", None, false)
            .await
            .unwrap();
        let (schema_id, _) = decode_wire_format(&framed).unwrap();
        assert_eq!(schema_id, 42);
    }

    let cached_subjects = encoder.id_cache.read().len();
    assert_eq!(cached_subjects, 1);
}
/// `FixedIdEncoder` frames a value with its `value_id`.
#[tokio::test]
async fn test_fixed_encoder_frames_value() {
    let encoder = FixedIdEncoder {
        key_id: 1,
        value_id: 7,
    };

    let body = Bytes::from_static(b"payload");
    let framed = encoder.encode(body, "t", None, false).await.unwrap();

    let (schema_id, unframed) = decode_wire_format(&framed).unwrap();
    assert_eq!(schema_id, 7);
    assert_eq!(unframed, b"payload");
}
/// An empty key is still framed (with `key_id`), and a value encoded by the
/// same encoder gets `value_id`.
#[tokio::test]
async fn test_encoder_encodes_empty_key() {
    let encoder = FixedIdEncoder {
        key_id: 3,
        value_id: 9,
    };

    let key_frame = encoder.encode(Bytes::new(), "t", None, true).await.unwrap();
    let (key_schema_id, key_body) = decode_wire_format(&key_frame).unwrap();
    assert_eq!(key_schema_id, 3);
    assert_eq!(key_body, b"");

    let value_frame = encoder
        .encode(Bytes::from_static(b"val"), "t", None, false)
        .await
        .unwrap();
    let (value_schema_id, value_body) = decode_wire_format(&value_frame).unwrap();
    assert_eq!(value_schema_id, 9);
    assert_eq!(value_body, b"val");
}
/// With the `TopicName` strategy, encoding a value for topic `orders` caches
/// the id under the subject `orders-value`.
#[cfg(feature = "schema-registry")]
#[tokio::test]
async fn test_confluent_schema_encoder_topic_name_strategy_subject() {
    let encoder = ConfluentSchemaEncoder::builder()
        .registry(MockRegistry::new())
        .schema(r#"{"type":"string"}"#, SchemaType::Avro)
        .strategy(SubjectNameStrategy::TopicName)
        .build()
        .unwrap();

    let _framed = encoder
        .encode(Bytes::from_static(b"x"), "orders", None, false)
        .await
        .unwrap();

    let has_subject = encoder.id_cache.read().contains_key("orders-value");
    assert!(has_subject);
}
/// Compile-time check that `ConfluentSchemaEncoder<MockRegistry>` is
/// `Send + Sync`.
#[cfg(feature = "schema-registry")]
#[test]
fn test_confluent_schema_encoder_is_send_sync() {
    fn require_send_sync<T>()
    where
        T: Send + Sync,
    {
    }

    require_send_sync::<ConfluentSchemaEncoder<MockRegistry>>();
}
/// Compile-time check that `SchemaEncoder` can be used as a trait object.
#[cfg(feature = "schema-registry")]
#[test]
fn test_confluent_schema_encoder_is_object_safe() {
    fn _accepts_dyn_encoder(_encoder: &dyn SchemaEncoder) {}
}
/// The encoder's `Debug` output names the type but does not include the
/// schema definition text.
#[cfg(feature = "schema-registry")]
#[test]
fn test_confluent_schema_encoder_debug() {
    let encoder = ConfluentSchemaEncoder::builder()
        .registry(MockRegistry::new())
        .schema(r#"{"type":"string"}"#, SchemaType::Avro)
        .build()
        .unwrap();

    let rendered = format!("{encoder:?}");
    assert!(rendered.contains("ConfluentSchemaEncoder"));
    assert!(!rendered.contains(r#"{"type":"string"}"#));
}
/// Decoding a framed value returns just the inner bytes.
#[cfg(feature = "schema-registry")]
#[tokio::test]
async fn test_confluent_schema_decoder_strips_header() {
    let body = b"avro data";
    let framed = encode_wire_format(42, body);

    let decoder = ConfluentSchemaDecoder::new();
    let unframed = decoder.decode(framed, "orders", false).await.unwrap();
    assert_eq!(&unframed[..], body);
}
/// Decoding a framed key returns just the inner bytes.
#[cfg(feature = "schema-registry")]
#[tokio::test]
async fn test_confluent_schema_decoder_key_strips_header() {
    let body = b"key-bytes";
    let framed = encode_wire_format(7, body);

    let decoder = ConfluentSchemaDecoder::new();
    let unframed = decoder.decode(framed, "orders", true).await.unwrap();
    assert_eq!(&unframed[..], body);
}
/// A `b"plain text"` payload is returned unchanged.
#[cfg(feature = "schema-registry")]
#[tokio::test]
async fn test_confluent_schema_decoder_non_confluent_passthrough() {
    let input = Bytes::from_static(b"plain text");

    let decoder = ConfluentSchemaDecoder::new();
    let output = decoder.decode(input.clone(), "t", false).await.unwrap();
    assert_eq!(output, input);
}
/// An empty payload is returned unchanged.
#[cfg(feature = "schema-registry")]
#[tokio::test]
async fn test_confluent_schema_decoder_empty_payload_passthrough() {
    let input = Bytes::new();

    let decoder = ConfluentSchemaDecoder::new();
    let output = decoder.decode(input.clone(), "t", false).await.unwrap();
    assert_eq!(output, input);
}
/// Decoding a framed message yields the inner bytes with the inner length.
// NOTE(review): despite the name, only content and length are compared here;
// buffer sharing itself is not asserted.
#[cfg(feature = "schema-registry")]
#[tokio::test]
async fn test_confluent_schema_decoder_zero_copy() {
    let body = b"data";
    let framed = encode_wire_format(1, body);

    let decoder = ConfluentSchemaDecoder::new();
    let unframed = decoder.decode(framed.clone(), "t", false).await.unwrap();
    assert_eq!(&unframed[..], body);
    assert_eq!(unframed.len(), body.len());
}
/// Compile-time check that `ConfluentSchemaDecoder` is `Send + Sync`.
#[cfg(feature = "schema-registry")]
#[test]
fn test_confluent_schema_decoder_is_send_sync() {
    fn require_send_sync<T>()
    where
        T: Send + Sync,
    {
    }

    require_send_sync::<ConfluentSchemaDecoder>();
}
/// The decoder's `Debug` output names the type.
#[cfg(feature = "schema-registry")]
#[test]
fn test_confluent_schema_decoder_debug() {
    let decoder = ConfluentSchemaDecoder::new();
    let rendered = format!("{decoder:?}");
    assert!(rendered.contains("ConfluentSchemaDecoder"));
}
/// Compile-time check that `SchemaDecoder` can be used as a trait object.
#[cfg(feature = "schema-registry")]
#[test]
fn test_confluent_schema_decoder_is_object_safe() {
    fn _accepts_dyn_decoder(_decoder: &dyn SchemaDecoder) {}
}
}