use std::collections::HashSet;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::cache_inner::InMemoryCache;
use crate::error::{Result, SchemaRegError};
use crate::traits::{AnySchemaCache, SchemaRegistryClient};
use crate::types::{
CompatibilityLevel, Schema, SchemaId, SchemaReference, SchemaType, SchemaVersion,
};
/// Builds the error describing a schema lookup for `id` that was cancelled
/// before it completed.
///
/// This function is handed to `InMemoryCache::new` when the registry wrapper is
/// constructed; presumably the cache calls it when an in-flight fetch is
/// abandoned (confirm against `InMemoryCache`).
fn schema_lookup_cancelled_error(id: SchemaId) -> SchemaRegError {
    let message = format!("schema lookup cancelled before completion for id {id}");
    SchemaRegError::invalid_state(message)
}
/// Aggregate error returned by `CachedSchemaRegistry::warm_cache` when at
/// least one of the requested schema IDs could not be loaded.
#[derive(Debug, Clone)]
pub struct WarmCacheError {
/// One `(id, error)` pair for every lookup that failed. `warm_cache` only
/// builds this error when the list is non-empty.
pub failures: Vec<(SchemaId, SchemaRegError)>,
}
/// Renders a one-line summary: the number of failed IDs followed by an
/// ` id <id>: <error>;` fragment for each failure, in stored order.
impl fmt::Display for WarmCacheError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let total = self.failures.len();
        write!(f, "warm_cache failed for {} schema ID(s):", total)?;
        // Stops at, and propagates, the first formatter error.
        self.failures
            .iter()
            .try_for_each(|(id, e)| write!(f, " id {id}: {e};"))
    }
}
// Marker impl so `WarmCacheError` can be used wherever a `std::error::Error`
// is expected. No `source()` override: the individual causes are only
// reachable through the public `failures` field and the `Display` output.
impl std::error::Error for WarmCacheError {}
/// A schema-registry client wrapper that keeps schemas in an in-memory cache
/// keyed by `SchemaId`.
pub struct CachedSchemaRegistry<C> {
// The wrapped client; every cache miss and every non-cached operation is
// forwarded to it.
inner: C,
// Schemas keyed by ID. Constructed with an entry limit in `with_max_entries`.
cache: InMemoryCache<SchemaId, Schema>,
}
/// Cache entry limit used by `CachedSchemaRegistry::new`.
pub const DEFAULT_MAX_CACHE_ENTRIES: usize = 1000;
impl<C: SchemaRegistryClient> CachedSchemaRegistry<C> {
    /// Wraps `inner` with a cache limited to [`DEFAULT_MAX_CACHE_ENTRIES`] entries.
    pub fn new(inner: C) -> Self {
        Self::with_max_entries(inner, DEFAULT_MAX_CACHE_ENTRIES)
    }

    /// Wraps `inner` with a cache limited to `max_entries` entries.
    ///
    /// A limit of `0` is raised to `1`.
    pub fn with_max_entries(inner: C, max_entries: usize) -> Self {
        let capacity = Some(usize::max(max_entries, 1));
        let cache = InMemoryCache::new(capacity, schema_lookup_cancelled_error);
        Self { inner, cache }
    }

    /// Borrows the wrapped client.
    pub fn inner(&self) -> &C {
        &self.inner
    }

    /// Number of entries the cache currently reports.
    pub fn cache_len(&self) -> usize {
        self.cache.len()
    }

    /// Whether the cache currently reports no entries.
    pub fn cache_is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Clears the whole cache.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }

    /// Invalidates the cache entry for a single schema ID.
    pub fn invalidate(&self, schema_id: impl Into<SchemaId>) {
        let id: SchemaId = schema_id.into();
        self.cache.invalidate(id);
    }

    /// Clears the whole cache; performs the same call as [`Self::clear_cache`].
    pub fn invalidate_all(&self) {
        self.cache.clear();
    }

    /// Pre-loads the cache with the given schema IDs.
    ///
    /// Duplicate IDs are collapsed first. The remaining IDs are looked up in
    /// batches of at most 16; all lookups of a batch are awaited together
    /// before the next batch starts.
    ///
    /// # Errors
    ///
    /// Returns a [`WarmCacheError`] listing every ID whose lookup failed.
    /// Successful lookups are not rolled back. IDs pass through a `HashSet`,
    /// so the order of `failures` is unspecified.
    pub async fn warm_cache(
        &self,
        schema_ids: impl IntoIterator<Item = impl Into<SchemaId>>,
    ) -> std::result::Result<(), WarmCacheError> {
        const WARM_CONCURRENCY: usize = 16;

        // Deduplicate through a set, then flatten so the IDs can be chunked.
        let pending: Vec<SchemaId> = schema_ids
            .into_iter()
            .map(Into::into)
            .collect::<HashSet<SchemaId>>()
            .into_iter()
            .collect();
        if pending.is_empty() {
            return Ok(());
        }

        let mut failures: Vec<(SchemaId, SchemaRegError)> = Vec::new();
        for batch in pending.chunks(WARM_CONCURRENCY) {
            let lookups = batch.iter().map(|&id| async move {
                let outcome = self
                    .cache
                    .get_or_fetch(id, || self.inner.get_schema_by_id(id))
                    .await;
                (id, outcome)
            });
            let outcomes = futures::future::join_all(lookups).await;
            // Keep only the errors; fetched schemas are already in the cache.
            failures.extend(
                outcomes
                    .into_iter()
                    .filter_map(|(id, outcome)| outcome.err().map(|e| (id, e))),
            );
        }

        if failures.is_empty() {
            return Ok(());
        }
        Err(WarmCacheError { failures })
    }

    /// Invalidates every cached schema whose `subject` equals `subject`.
    ///
    /// Schemas cached without a subject are left alone.
    pub fn invalidate_subject(&self, subject: &str) {
        let stale: Vec<SchemaId> = self
            .cache
            .keys_matching(|cached: &Schema| cached.subject.as_deref() == Some(subject));
        stale.into_iter().for_each(|id| {
            self.cache.invalidate(id);
        });
    }

    /// Looks up a schema by ID through the cache.
    ///
    /// The wrapped client's `get_schema_by_id` is supplied to the cache as the
    /// fetch function; the tests in this file expect it to be called only on a
    /// miss.
    pub async fn get_schema_by_id(&self, id: SchemaId) -> Result<Arc<Schema>> {
        self.cache
            .get_or_fetch(id, || self.inner.get_schema_by_id(id))
            .await
    }

    /// Fetches the latest schema for `subject` from the wrapped client.
    ///
    /// The request always goes to the wrapped client. The result is then
    /// offered to the by-ID cache together with the cache generation captured
    /// *before* the request. The tests in this file expect an invalidation that
    /// happens while the request is in flight to prevent the insert.
    pub async fn get_latest_schema(&self, subject: &str) -> Result<Arc<Schema>> {
        let snapshot = self.cache.generation();
        let fetched = self.inner.get_latest_schema(subject).await?;
        self.cache
            .insert_if_current(fetched.id, Arc::clone(&fetched), snapshot);
        Ok(fetched)
    }

    /// Fetches a specific `version` of `subject` from the wrapped client.
    ///
    /// Same caching behaviour as [`Self::get_latest_schema`]: always forwarded,
    /// with the result offered to the by-ID cache under the generation captured
    /// before the request.
    pub async fn get_schema_by_version(
        &self,
        subject: &str,
        version: SchemaVersion,
    ) -> Result<Arc<Schema>> {
        let snapshot = self.cache.generation();
        let fetched = self.inner.get_schema_by_version(subject, version).await?;
        self.cache
            .insert_if_current(fetched.id, Arc::clone(&fetched), snapshot);
        Ok(fetched)
    }

    /// Registers a schema via the wrapped client. The cache is not touched.
    pub async fn register_schema(
        &self,
        subject: &str,
        schema: &str,
        schema_type: SchemaType,
        references: &[SchemaReference],
    ) -> Result<SchemaId> {
        let registration = self
            .inner
            .register_schema(subject, schema, schema_type, references);
        registration.await
    }
}
/// Debug output shows the cache and its length only.
///
/// The wrapped client `inner` is left out on purpose so that `C` does not need
/// to implement `Debug`; `finish_non_exhaustive` adds a trailing `..` so
/// readers can tell a field was omitted.
impl<C> fmt::Debug for CachedSchemaRegistry<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CachedSchemaRegistry")
            .field("cache_len", &self.cache.len())
            .field("cache", &self.cache)
            .finish_non_exhaustive()
    }
}
/// Lets the caching wrapper stand in wherever a `SchemaRegistryClient` is
/// expected.
///
/// The four schema lookup/registration methods delegate to the inherent
/// methods of the same name (path calls on `Self` resolve to the inherent
/// method first). Everything else is forwarded to the wrapped client as-is.
impl<C: SchemaRegistryClient> SchemaRegistryClient for CachedSchemaRegistry<C> {
    /// Lookup by ID via the inherent, cache-backed `get_schema_by_id`.
    async fn get_schema_by_id(&self, id: SchemaId) -> Result<Arc<Schema>> {
        Self::get_schema_by_id(self, id).await
    }

    /// Latest schema via the inherent `get_latest_schema`.
    async fn get_latest_schema(&self, subject: &str) -> Result<Arc<Schema>> {
        Self::get_latest_schema(self, subject).await
    }

    /// Schema by version via the inherent `get_schema_by_version`.
    async fn get_schema_by_version(
        &self,
        subject: &str,
        version: SchemaVersion,
    ) -> Result<Arc<Schema>> {
        Self::get_schema_by_version(self, subject, version).await
    }

    /// Registration via the inherent `register_schema`.
    async fn register_schema(
        &self,
        subject: &str,
        schema: &str,
        schema_type: SchemaType,
        references: &[SchemaReference],
    ) -> Result<SchemaId> {
        Self::register_schema(self, subject, schema, schema_type, references).await
    }

    /// Forwarded to the wrapped client unchanged.
    fn check_compatibility<'a>(
        &'a self,
        subject: &'a str,
        schema: &'a str,
        schema_type: SchemaType,
        references: &'a [SchemaReference],
    ) -> impl Future<Output = Result<bool>> + Send + 'a {
        C::check_compatibility(&self.inner, subject, schema, schema_type, references)
    }

    /// Forwarded to the wrapped client unchanged.
    ///
    /// NOTE(review): the local cache is not touched here, so schemas cached
    /// for `subject` remain until they are invalidated some other way. Confirm
    /// whether `invalidate_subject` should be called after a successful delete.
    fn delete_subject<'a>(
        &'a self,
        subject: &'a str,
        permanent: bool,
    ) -> impl Future<Output = Result<Vec<SchemaVersion>>> + Send + 'a {
        C::delete_subject(&self.inner, subject, permanent)
    }

    /// Forwarded to the wrapped client unchanged.
    fn get_subjects(&self) -> impl Future<Output = Result<Vec<String>>> + Send + '_ {
        C::get_subjects(&self.inner)
    }

    /// Forwarded to the wrapped client unchanged.
    fn get_versions<'a>(
        &'a self,
        subject: &'a str,
    ) -> impl Future<Output = Result<Vec<SchemaVersion>>> + Send + 'a {
        C::get_versions(&self.inner, subject)
    }

    /// Forwarded to the wrapped client unchanged.
    fn health_check(&self) -> impl Future<Output = Result<()>> + Send + '_ {
        C::health_check(&self.inner)
    }

    /// Forwarded to the wrapped client unchanged.
    fn set_compatibility<'a>(
        &'a self,
        subject: &'a str,
        level: CompatibilityLevel,
    ) -> impl Future<Output = Result<()>> + Send + 'a {
        C::set_compatibility(&self.inner, subject, level)
    }

    /// Forwarded to the wrapped client unchanged.
    fn get_compatibility<'a>(
        &'a self,
        subject: &'a str,
    ) -> impl Future<Output = Result<CompatibilityLevel>> + Send + 'a {
        C::get_compatibility(&self.inner, subject)
    }
}
/// Exposes the cache-management operations through the object-safe
/// `AnySchemaCache` trait.
///
/// Each method calls the inherent method of the same name; method-call syntax
/// on `self` picks the inherent method ahead of the trait method, so these do
/// not recurse.
impl<C: SchemaRegistryClient> AnySchemaCache for CachedSchemaRegistry<C> {
    type Id = SchemaId;

    fn cache_len(&self) -> usize {
        self.cache_len()
    }

    fn cache_is_empty(&self) -> bool {
        self.cache_is_empty()
    }

    fn clear_cache(&self) {
        self.clear_cache()
    }

    fn invalidate(&self, id: Self::Id) {
        self.invalidate(id)
    }

    fn invalidate_all(&self) {
        self.invalidate_all()
    }

    /// Boxed variant of the inherent `warm_cache`.
    ///
    /// A `WarmCacheError` is flattened into a single
    /// `SchemaRegError::invalid_state` carrying its `Display` text, so the
    /// per-ID error values are only available as part of the message.
    fn warm_cache<'a>(
        &'a self,
        ids: &'a [Self::Id],
    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
        Box::pin(async move {
            match self.warm_cache(ids.iter().copied()).await {
                Ok(()) => Ok(()),
                Err(failed) => Err(SchemaRegError::invalid_state(failed.to_string())),
            }
        })
    }
}
// Unit tests for `CachedSchemaRegistry`, driven by two in-file mock clients:
// `MockRegistry` answers immediately, `BlockingMockRegistry` parks every
// lookup on a semaphore so tests can act while a request is in flight.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
use super::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering as AtomicOrdering};
use tokio::sync::{Notify, Semaphore};
/// Returns the `Ok` value; on `Err` aborts the test via `unreachable!` with
/// the error's `Display` text.
fn ok<T, E: std::fmt::Display>(result: std::result::Result<T, E>) -> T {
match result {
Ok(v) => v,
Err(e) => unreachable!("expected Ok(..), got Err({e})"),
}
}
/// Same as `ok`, specialised for the result of awaiting a spawned task.
fn join_ok<T>(result: std::result::Result<T, tokio::task::JoinError>) -> T {
match result {
Ok(v) => v,
Err(e) => unreachable!("spawned task failed: {e}"),
}
}
/// Non-blocking mock client that counts `get_schema_by_id` calls.
struct MockRegistry {
get_by_id_calls: AtomicU32,
}
impl MockRegistry {
fn new() -> Self {
Self {
get_by_id_calls: AtomicU32::new(0),
}
}
/// Number of `get_schema_by_id` calls that reached this mock.
fn get_by_id_call_count(&self) -> u32 {
self.get_by_id_calls.load(AtomicOrdering::SeqCst)
}
}
impl SchemaRegistryClient for MockRegistry {
// Counts the call and returns an Avro string schema carrying the requested id.
async fn get_schema_by_id(&self, id: SchemaId) -> Result<Arc<Schema>> {
self.get_by_id_calls.fetch_add(1, AtomicOrdering::SeqCst);
Ok(Arc::new(Schema::new(
id,
crate::types::SchemaType::Avro,
r#"{"type":"string"}"#,
)))
}
// Always returns schema id 100, attached to `subject` at version 1.
async fn get_latest_schema(&self, subject: &str) -> Result<Arc<Schema>> {
Ok(Arc::new(
Schema::new(
SchemaId::from(100u32),
crate::types::SchemaType::Avro,
r#"{"type":"string"}"#,
)
.with_subject(subject, 1i32),
))
}
// Always returns schema id 100, attached to `subject` at the requested version.
async fn get_schema_by_version(
&self,
subject: &str,
version: SchemaVersion,
) -> Result<Arc<Schema>> {
Ok(Arc::new(
Schema::new(
SchemaId::from(100u32),
crate::types::SchemaType::Avro,
r#"{"type":"string"}"#,
)
.with_subject(subject, version),
))
}
// Ignores its arguments and always reports id 42.
async fn register_schema(
&self,
_subject: &str,
_schema: &str,
_schema_type: SchemaType,
_references: &[SchemaReference],
) -> Result<SchemaId> {
Ok(SchemaId::from(42u32))
}
}
/// Mock client whose lookups announce themselves via `started` and then wait
/// for a permit on `release` (created with zero permits).
struct BlockingMockRegistry {
get_by_id_calls: AtomicU32,
get_latest_calls: AtomicU32,
get_by_version_calls: AtomicU32,
// Signalled with `notify_waiters` each time a lookup begins.
started: Notify,
// Lookups acquire from this semaphore before returning.
release: Semaphore,
// Lookups that have started since the last `release()` call.
waiting_calls: AtomicU32,
}
impl BlockingMockRegistry {
fn new() -> Self {
Self {
get_by_id_calls: AtomicU32::new(0),
get_latest_calls: AtomicU32::new(0),
get_by_version_calls: AtomicU32::new(0),
started: Notify::new(),
release: Semaphore::new(0),
waiting_calls: AtomicU32::new(0),
}
}
fn get_by_id_call_count(&self) -> u32 {
self.get_by_id_calls.load(AtomicOrdering::SeqCst)
}
fn get_latest_call_count(&self) -> u32 {
self.get_latest_calls.load(AtomicOrdering::SeqCst)
}
fn get_by_version_call_count(&self) -> u32 {
self.get_by_version_calls.load(AtomicOrdering::SeqCst)
}
/// Waits for the next `started` notification.
async fn wait_started(&self) {
self.started.notified().await;
}
/// Resets `waiting_calls` to zero and adds that many permits to `release`.
fn release(&self) {
let waiting = self.waiting_calls.swap(0, AtomicOrdering::SeqCst);
self.release.add_permits(waiting as usize);
}
}
impl SchemaRegistryClient for BlockingMockRegistry {
async fn get_schema_by_id(&self, id: SchemaId) -> Result<Arc<Schema>> {
self.get_by_id_calls.fetch_add(1, AtomicOrdering::SeqCst);
self.started.notify_waiters();
self.waiting_calls.fetch_add(1, AtomicOrdering::SeqCst);
// NOTE(review): `let _ =` does not bind the permit, so it is dropped at the
// end of this statement. With tokio's `SemaphorePermit` that presumably
// returns the permit to the semaphore, meaning that once `release()` has
// added a permit, later lookups may no longer block. The direct
// `get_schema_by_id` calls at the end of the two `test_invalidate_drops_*`
// tests appear to depend on this — confirm before switching to `forget()`.
let _ = self
.release
.acquire()
.await
.expect("blocking registry release permit");
Ok(Arc::new(Schema::new(
id,
crate::types::SchemaType::Avro,
r#"{"type":"string"}"#,
)))
}
async fn get_latest_schema(&self, subject: &str) -> Result<Arc<Schema>> {
self.get_latest_calls.fetch_add(1, AtomicOrdering::SeqCst);
self.started.notify_waiters();
self.waiting_calls.fetch_add(1, AtomicOrdering::SeqCst);
// Permit is not bound (`let _ =`) — same caveat as in `get_schema_by_id`.
let _ = self
.release
.acquire()
.await
.expect("blocking registry release permit");
Ok(Arc::new(
Schema::new(
SchemaId::from(100u32),
crate::types::SchemaType::Avro,
r#"{"type":"string"}"#,
)
.with_subject(subject, 1i32),
))
}
async fn get_schema_by_version(
&self,
subject: &str,
version: SchemaVersion,
) -> Result<Arc<Schema>> {
self.get_by_version_calls
.fetch_add(1, AtomicOrdering::SeqCst);
self.started.notify_waiters();
self.waiting_calls.fetch_add(1, AtomicOrdering::SeqCst);
// Permit is not bound (`let _ =`) — same caveat as in `get_schema_by_id`.
let _ = self
.release
.acquire()
.await
.expect("blocking registry release permit");
Ok(Arc::new(
Schema::new(
SchemaId::from(100u32),
crate::types::SchemaType::Avro,
r#"{"type":"string"}"#,
)
.with_subject(subject, version),
))
}
// Never blocks; always reports id 42.
async fn register_schema(
&self,
_subject: &str,
_schema: &str,
_schema_type: SchemaType,
_references: &[SchemaReference],
) -> Result<SchemaId> {
Ok(SchemaId::from(42u32))
}
}
/// A repeated lookup of the same id is answered without a second inner call.
#[tokio::test]
async fn test_cache_miss_then_hit() {
let cached = CachedSchemaRegistry::new(MockRegistry::new());
let s1 = ok(cached.get_schema_by_id(SchemaId::from(1u32)).await);
assert_eq!(cached.inner().get_by_id_call_count(), 1);
assert_eq!(cached.cache_len(), 1);
let s2 = ok(cached.get_schema_by_id(SchemaId::from(1u32)).await);
assert_eq!(cached.inner().get_by_id_call_count(), 1);
assert_eq!(s1, s2);
}
/// Distinct ids are cached independently: two inner calls, then none.
#[tokio::test]
async fn test_cache_different_ids() {
let cached = CachedSchemaRegistry::new(MockRegistry::new());
ok(cached.get_schema_by_id(SchemaId::from(1u32)).await);
ok(cached.get_schema_by_id(SchemaId::from(2u32)).await);
assert_eq!(cached.inner().get_by_id_call_count(), 2);
assert_eq!(cached.cache_len(), 2);
ok(cached.get_schema_by_id(SchemaId::from(1u32)).await);
ok(cached.get_schema_by_id(SchemaId::from(2u32)).await);
assert_eq!(cached.inner().get_by_id_call_count(), 2);
}
/// After `clear_cache` the cache is empty and the next lookup hits the inner client.
#[tokio::test]
async fn test_cache_clear() {
let cached = CachedSchemaRegistry::new(MockRegistry::new());
ok(cached.get_schema_by_id(SchemaId::from(1u32)).await);
assert_eq!(cached.cache_len(), 1);
cached.clear_cache();
assert_eq!(cached.cache_len(), 0);
ok(cached.get_schema_by_id(SchemaId::from(1u32)).await);
assert_eq!(cached.inner().get_by_id_call_count(), 2);
}
/// `invalidate` removes only the named id; the other entry keeps being served from cache.
#[tokio::test]
async fn test_cache_invalidate_single_entry() {
let cached = CachedSchemaRegistry::new(MockRegistry::new());
ok(cached.get_schema_by_id(SchemaId::from(1u32)).await);
ok(cached.get_schema_by_id(SchemaId::from(2u32)).await);
cached.invalidate(1u32);
assert_eq!(cached.cache_len(), 1);
ok(cached.get_schema_by_id(SchemaId::from(2u32)).await);
assert_eq!(cached.inner().get_by_id_call_count(), 2);
ok(cached.get_schema_by_id(SchemaId::from(1u32)).await);
assert_eq!(cached.inner().get_by_id_call_count(), 3);
}
/// `warm_cache` fetches each distinct id once, and later lookups are cache hits.
#[tokio::test]
async fn test_cache_warm_cache_deduplicates_ids() {
let cached = CachedSchemaRegistry::new(MockRegistry::new());
ok(cached.warm_cache([1u32, 2u32, 1u32, 2u32, 3u32]).await);
assert_eq!(cached.inner().get_by_id_call_count(), 3);
assert_eq!(cached.cache_len(), 3);
ok(cached.get_schema_by_id(SchemaId::from(1u32)).await);
ok(cached.get_schema_by_id(SchemaId::from(2u32)).await);
ok(cached.get_schema_by_id(SchemaId::from(3u32)).await);
assert_eq!(cached.inner().get_by_id_call_count(), 3);
}
/// Two concurrent misses for the same id result in a single inner call.
#[tokio::test]
async fn test_cache_coalesces_concurrent_misses() {
let cached = Arc::new(CachedSchemaRegistry::new(BlockingMockRegistry::new()));
let first = {
let c = cached.clone();
tokio::spawn(async move { ok(c.get_schema_by_id(SchemaId::from(7u32)).await) })
};
// Wait until the first lookup is parked inside the mock.
cached.inner().wait_started().await;
let second = {
let c = cached.clone();
tokio::spawn(async move { ok(c.get_schema_by_id(SchemaId::from(7u32)).await) })
};
// Give the second task a chance to run before the mock is released.
tokio::task::yield_now().await;
cached.inner().release();
let s1 = join_ok(first.await);
let s2 = join_ok(second.await);
assert_eq!(s1, s2);
assert_eq!(cached.inner().get_by_id_call_count(), 1);
}
/// Aborting the task that started a fetch must not strand later lookups for
/// the same id: the second lookup has to reach the inner client and finish
/// within the 5-second timeouts.
#[tokio::test]
async fn test_cache_coalescer_cleans_up_when_leader_is_cancelled() {
let cached = Arc::new(CachedSchemaRegistry::new(BlockingMockRegistry::new()));
let first = {
let c = cached.clone();
tokio::spawn(async move { ok(c.get_schema_by_id(SchemaId::from(9u32)).await) })
};
cached.inner().wait_started().await;
first.abort();
// Let the runtime process the abort before starting the second lookup.
tokio::task::yield_now().await;
let second = {
let c = cached.clone();
tokio::spawn(async move { ok(c.get_schema_by_id(SchemaId::from(9u32)).await) })
};
tokio::time::timeout(
std::time::Duration::from_secs(5),
cached.inner().wait_started(),
)
.await
.expect("second lookup did not reach inner registry");
cached.inner().release();
let schema = tokio::time::timeout(std::time::Duration::from_secs(5), second)
.await
.expect("second lookup timed out")
.expect("second task failed");
assert_eq!(schema.id, 9u32);
}
/// A schema returned by `get_latest_schema` can afterwards be read by id
/// without any `get_schema_by_id` call on the inner client.
#[tokio::test]
async fn test_cache_get_latest_populates_id_cache() {
let cached = CachedSchemaRegistry::new(MockRegistry::new());
let schema = ok(cached.get_latest_schema("test-value").await);
assert_eq!(cached.cache_len(), 1);
let by_id = ok(cached.get_schema_by_id(schema.id).await);
assert_eq!(cached.inner().get_by_id_call_count(), 0);
assert_eq!(by_id.id, schema.id);
}
/// Same as above, for `get_schema_by_version`.
#[tokio::test]
async fn test_cache_get_by_version_populates_id_cache() {
let cached = CachedSchemaRegistry::new(MockRegistry::new());
let schema = ok(cached
.get_schema_by_version("test-value", SchemaVersion::new(1))
.await);
assert_eq!(cached.cache_len(), 1);
let by_id = ok(cached.get_schema_by_id(schema.id).await);
assert_eq!(cached.inner().get_by_id_call_count(), 0);
assert_eq!(by_id.id, schema.id);
}
/// `register_schema` returns whatever the inner client returns (42 for the mock).
#[tokio::test]
async fn test_cache_register_forwards() {
let cached = CachedSchemaRegistry::new(MockRegistry::new());
let id = ok(cached
.register_schema("test-value", "{}", crate::types::SchemaType::Avro, &[])
.await);
assert_eq!(id, SchemaId::from(42u32));
}
/// With a limit of one entry, loading id 2 pushes id 1 out: the cache stays
/// at length 1 and re-reading id 1 costs a third inner call.
#[tokio::test]
async fn test_cache_with_max_entries_evicts_oldest_entry() {
let cached = CachedSchemaRegistry::with_max_entries(MockRegistry::new(), 1);
ok(cached.get_schema_by_id(SchemaId::from(1u32)).await);
ok(cached.get_schema_by_id(SchemaId::from(2u32)).await);
assert_eq!(cached.cache_len(), 1);
assert_eq!(cached.inner().get_by_id_call_count(), 2);
ok(cached.get_schema_by_id(SchemaId::from(1u32)).await);
assert_eq!(cached.inner().get_by_id_call_count(), 3);
}
/// Smoke test: a custom limit yields an empty cache that accepts an entry.
#[tokio::test]
async fn test_cache_with_max_entries() {
let cached = CachedSchemaRegistry::with_max_entries(MockRegistry::new(), 100);
assert_eq!(cached.cache_len(), 0);
ok(cached.get_schema_by_id(SchemaId::from(1u32)).await);
assert_eq!(cached.cache_len(), 1);
}
/// Exercises the cache through a `dyn AnySchemaCache` reference: warm (with a
/// duplicate id), inspect, invalidate one id, invalidate all.
#[tokio::test]
async fn test_any_schema_cache_trait() {
let cached = CachedSchemaRegistry::new(MockRegistry::new());
let generic: &dyn AnySchemaCache<Id = SchemaId> = &cached;
ok(generic
.warm_cache(&[
SchemaId::from(11u32),
SchemaId::from(12u32),
SchemaId::from(11u32),
])
.await);
assert_eq!(generic.cache_len(), 2);
assert!(!generic.cache_is_empty());
generic.invalidate(SchemaId::from(11u32));
assert_eq!(generic.cache_len(), 1);
generic.invalidate_all();
assert!(generic.cache_is_empty());
}
/// Invalidating an id while its fetch is in flight must leave the cache empty
/// once that fetch completes, and the next lookup must go to the inner client.
#[tokio::test]
async fn test_invalidate_does_not_repopulate_from_inflight_fetch() {
let cached = Arc::new(CachedSchemaRegistry::new(BlockingMockRegistry::new()));
let first = {
let c = cached.clone();
tokio::spawn(async move { ok(c.get_schema_by_id(SchemaId::from(7u32)).await) })
};
cached.inner().wait_started().await;
// Invalidate while the first fetch is still parked in the mock.
cached.invalidate(7u32);
{
let c = cached.clone();
// Release the mock from a separate task after a short delay.
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
c.inner().release();
});
}
let _ = tokio::time::timeout(std::time::Duration::from_secs(5), first)
.await
.expect("in-flight fetch timed out")
.expect("in-flight task failed");
assert_eq!(cached.cache_len(), 0);
let second = {
let c = cached.clone();
tokio::spawn(async move { ok(c.get_schema_by_id(SchemaId::from(7u32)).await) })
};
cached.inner().wait_started().await;
cached.inner().release();
let _ = join_ok(second.await);
assert_eq!(cached.inner().get_by_id_call_count(), 2);
}
/// Invalidating id 100 while `get_latest_schema` is in flight must prevent the
/// returned schema from being cached.
#[tokio::test]
async fn test_invalidate_drops_inflight_get_latest_cache_population() {
let cached = Arc::new(CachedSchemaRegistry::new(BlockingMockRegistry::new()));
let latest = {
let c = cached.clone();
tokio::spawn(async move { ok(c.get_latest_schema("test-value").await) })
};
// Poll (bounded by 5 s) until the request has reached the mock.
tokio::time::timeout(std::time::Duration::from_secs(5), async {
while cached.inner().get_latest_call_count() < 1 {
tokio::task::yield_now().await;
}
})
.await
.expect("latest lookup did not start");
cached.invalidate(100u32);
cached.inner().release();
let _ = join_ok(latest.await);
assert_eq!(cached.cache_len(), 0);
// Runs on the test task itself with no further `release()` call; see the
// NOTE(review) in `BlockingMockRegistry::get_schema_by_id` about why this
// presumably does not block.
ok(cached.get_schema_by_id(SchemaId::from(100u32)).await);
assert_eq!(cached.inner().get_by_id_call_count(), 1);
}
/// Same scenario as above, for `get_schema_by_version`.
#[tokio::test]
async fn test_invalidate_drops_inflight_get_by_version_cache_population() {
let cached = Arc::new(CachedSchemaRegistry::new(BlockingMockRegistry::new()));
let by_version = {
let c = cached.clone();
tokio::spawn(async move {
ok(c.get_schema_by_version("test-value", SchemaVersion::new(1))
.await)
})
};
// Poll (bounded by 5 s) until the request has reached the mock.
tokio::time::timeout(std::time::Duration::from_secs(5), async {
while cached.inner().get_by_version_call_count() < 1 {
tokio::task::yield_now().await;
}
})
.await
.expect("version lookup did not start");
cached.invalidate(100u32);
cached.inner().release();
let _ = join_ok(by_version.await);
assert_eq!(cached.cache_len(), 0);
// Runs on the test task itself with no further `release()` call; see the
// NOTE(review) in `BlockingMockRegistry::get_schema_by_id`.
ok(cached.get_schema_by_id(SchemaId::from(100u32)).await);
assert_eq!(cached.inner().get_by_id_call_count(), 1);
}
/// Compile-time check that the wrapper is `Send + Sync` for a `Send + Sync` client.
#[test]
fn test_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<CachedSchemaRegistry<MockRegistry>>();
}
/// The `Debug` output includes the `cache_len` field.
#[test]
fn test_debug() {
let cached = CachedSchemaRegistry::new(MockRegistry::new());
let dbg = format!("{cached:?}");
assert!(dbg.contains("cache_len"));
}
}