#[cfg(feature = "glue")]
mod client;
#[cfg(feature = "glue")]
pub use client::{AwsGlueSchemaRegistry, AwsGlueSchemaRegistryBuilder};
use std::borrow::Cow;
use std::collections::HashSet;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use bytes::{BufMut, Bytes, BytesMut};
#[cfg(feature = "glue")]
use flate2::Compression;
#[cfg(feature = "glue")]
use flate2::read::ZlibDecoder;
#[cfg(feature = "glue")]
use flate2::write::ZlibEncoder;
#[cfg(feature = "glue")]
use std::io::{Read, Write};
use crate::cache_inner::InMemoryCache;
use crate::error::{Result, SchemaRegError};
use crate::traits::AnySchemaCache;
/// First byte of every Glue wire-format header; `validate_glue_wire_header` rejects any other value.
pub(crate) const GLUE_HEADER_VERSION_BYTE: u8 = 0x03;
/// Value of the header's compression byte (second byte) for an uncompressed payload.
pub(crate) const GLUE_COMPRESSION_NONE_BYTE: u8 = 0x00;
/// Value of the header's compression byte (second byte) for a ZLIB-compressed payload.
pub(crate) const GLUE_COMPRESSION_ZLIB_BYTE: u8 = 0x05;
/// Total header length in bytes: version byte + compression byte + 16-byte schema version ID.
pub(crate) const GLUE_HEADER_SIZE: usize = 18;
/// Number of raw bytes in a schema version ID (a UUID).
const UUID_SIZE: usize = 16;
/// Error returned when warming the cache failed for one or more schema version IDs.
#[derive(Debug, Clone)]
pub struct WarmGlueCacheError {
    /// Every schema version ID that could not be loaded, paired with the error it produced.
    pub failures: Vec<(GlueSchemaVersionId, SchemaRegError)>,
}

impl fmt::Display for WarmGlueCacheError {
    /// Writes a one-line summary: the failure count followed by one ` id <id>: <error>;`
    /// entry per failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let failed = self.failures.len();
        write!(
            f,
            "warm_cache failed for {} Glue schema version ID(s):",
            failed
        )?;
        self.failures
            .iter()
            .try_for_each(|(id, e)| write!(f, " id {id}: {e};"))
    }
}

impl std::error::Error for WarmGlueCacheError {}
/// Identifier of a Glue schema version, stored as the 16 raw bytes of a UUID.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct GlueSchemaVersionId([u8; UUID_SIZE]);

impl GlueSchemaVersionId {
    /// Wraps the raw UUID bytes without any validation.
    pub fn from_bytes(bytes: [u8; UUID_SIZE]) -> Self {
        GlueSchemaVersionId(bytes)
    }

    /// Borrows the raw UUID bytes.
    pub fn as_bytes(&self) -> &[u8; UUID_SIZE] {
        let GlueSchemaVersionId(raw) = self;
        raw
    }
}

impl From<[u8; UUID_SIZE]> for GlueSchemaVersionId {
    fn from(bytes: [u8; UUID_SIZE]) -> Self {
        GlueSchemaVersionId(bytes)
    }
}

impl From<GlueSchemaVersionId> for [u8; UUID_SIZE] {
    fn from(id: GlueSchemaVersionId) -> Self {
        let GlueSchemaVersionId(raw) = id;
        raw
    }
}

impl fmt::Display for GlueSchemaVersionId {
    /// Formats the ID as a lowercase, hyphenated UUID (`8-4-4-4-12` hex digits).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for (index, byte) in self.0.iter().enumerate() {
            // A hyphen precedes the bytes that open the 2nd through 5th UUID groups.
            if matches!(index, 4 | 6 | 8 | 10) {
                f.write_str("-")?;
            }
            write!(f, "{byte:02x}")?;
        }
        Ok(())
    }
}

impl fmt::Debug for GlueSchemaVersionId {
    /// Formats as `GlueSchemaVersionId(<hyphenated uuid>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("GlueSchemaVersionId(")?;
        fmt::Display::fmt(self, f)?;
        f.write_str(")")
    }
}
impl FromStr for GlueSchemaVersionId {
    type Err = SchemaRegError;

    /// Parses a hyphenated UUID string (`xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx`).
    ///
    /// Hex digits may be upper or lower case.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_state` error when the input is not exactly 36 bytes long, when a
    /// dash is missing at byte offset 8, 13, 18 or 23, or when any other position holds a
    /// non-hexadecimal character.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // Byte offsets of the first hex digit of each of the 16 encoded bytes.
        const HEX_OFFSETS: [usize; UUID_SIZE] =
            [0, 2, 4, 6, 9, 11, 14, 16, 19, 21, 24, 26, 28, 30, 32, 34];
        const DASH_OFFSETS: [usize; 4] = [8, 13, 18, 23];

        let raw = s.as_bytes();
        if raw.len() != 36 {
            return Err(SchemaRegError::invalid_state(format!(
                "invalid UUID: expected 36 characters, got {}",
                raw.len()
            )));
        }
        // The length check above keeps every fixed offset below in bounds.
        if !DASH_OFFSETS.iter().all(|&offset| raw[offset] == b'-') {
            return Err(SchemaRegError::invalid_state(
                "invalid UUID format: expected dashes at positions 8, 13, 18, 23",
            ));
        }

        let mut decoded = [0u8; UUID_SIZE];
        for (slot, &offset) in decoded.iter_mut().zip(HEX_OFFSETS.iter()) {
            *slot = parse_hex_byte(raw[offset], raw[offset + 1]).ok_or_else(|| {
                SchemaRegError::invalid_state("invalid UUID: non-hexadecimal character")
            })?;
        }
        Ok(Self(decoded))
    }
}
/// Combines two ASCII hex digits (`hi` is the most significant nibble) into one byte.
///
/// Returns `None` if either input is not an ASCII hexadecimal digit.
fn parse_hex_byte(hi: u8, lo: u8) -> Option<u8> {
    let high_nibble = hex_digit(hi)?;
    let low_nibble = hex_digit(lo)?;
    Some((high_nibble << 4) | low_nibble)
}
/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its numeric value `0..=15`.
///
/// Returns `None` for every other byte, including all non-ASCII bytes.
fn hex_digit(c: u8) -> Option<u8> {
    // `to_digit(16)` yields at most 15, so narrowing to `u8` cannot lose information.
    char::from(c).to_digit(16).map(|value| value as u8)
}
/// Builds the `invalid_state` error reported when the lookup for `id` was cancelled before it
/// completed. Passed to `InMemoryCache::new` by `CachedGlueSchemaRegistry`.
fn glue_schema_lookup_cancelled_error(id: GlueSchemaVersionId) -> SchemaRegError {
    let message = format!("glue schema lookup cancelled before completion for id {id}");
    SchemaRegError::invalid_state(message)
}
/// Compression applied to the payload of a Glue wire-format message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub enum GlueCompression {
    /// Payload is stored as-is (the default).
    #[default]
    None,
    /// Payload is ZLIB-compressed.
    Zlib,
}

impl GlueCompression {
    /// Returns the upper-case name of the compression mode (`"NONE"` or `"ZLIB"`).
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::None => "NONE",
            Self::Zlib => "ZLIB",
        }
    }
}

impl fmt::Display for GlueCompression {
    /// Writes the same text as [`GlueCompression::as_str`].
    ///
    /// Uses `Formatter::pad` so width, fill, alignment and precision in the format spec
    /// (e.g. `{:>8}`) are honoured the same way they are for a plain `&str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
/// Data format of a schema stored in the Glue schema registry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum GlueDataFormat {
    /// Avro schema.
    Avro,
    /// JSON schema.
    Json,
    /// Protobuf schema.
    Protobuf,
}

impl GlueDataFormat {
    /// Returns the upper-case name of the format (`"AVRO"`, `"JSON"` or `"PROTOBUF"`).
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Avro => "AVRO",
            Self::Json => "JSON",
            Self::Protobuf => "PROTOBUF",
        }
    }
}

impl fmt::Display for GlueDataFormat {
    /// Writes the same text as [`GlueDataFormat::as_str`].
    ///
    /// Uses `Formatter::pad` so width, fill, alignment and precision in the format spec
    /// (e.g. `{:<10}`) are honoured the same way they are for a plain `&str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
impl FromStr for GlueDataFormat {
    type Err = SchemaRegError;

    /// Parses `"AVRO"`, `"JSON"` or `"PROTOBUF"`, ignoring ASCII case.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_state` error naming the input for any other string.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let known_formats = [
            ("AVRO", Self::Avro),
            ("JSON", Self::Json),
            ("PROTOBUF", Self::Protobuf),
        ];
        known_formats
            .iter()
            .find(|(name, _)| s.eq_ignore_ascii_case(name))
            .map(|&(_, format)| format)
            .ok_or_else(|| {
                SchemaRegError::invalid_state(format!("unknown Glue data format: '{s}'"))
            })
    }
}
/// A schema together with the identifiers and metadata this module tracks for it.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GlueSchema {
    /// ID of this schema version.
    pub schema_version_id: GlueSchemaVersionId,
    /// Data format of the schema definition.
    pub data_format: GlueDataFormat,
    /// Schema definition text, reference-counted so clones do not copy it.
    pub schema_definition: Arc<str>,
    /// Schema ARN, when supplied through [`GlueSchema::with_metadata`].
    pub schema_arn: Option<String>,
    /// Version number, when supplied through [`GlueSchema::with_metadata`].
    pub version_number: Option<i64>,
}

impl GlueSchema {
    /// Creates a schema with no ARN and no version number.
    pub fn new(
        schema_version_id: GlueSchemaVersionId,
        data_format: GlueDataFormat,
        schema_definition: impl Into<Arc<str>>,
    ) -> Self {
        let schema_definition: Arc<str> = schema_definition.into();
        Self {
            schema_version_id,
            data_format,
            schema_definition,
            schema_arn: None,
            version_number: None,
        }
    }

    /// Returns the schema with both its ARN and version number set.
    pub fn with_metadata(self, schema_arn: impl Into<String>, version_number: i64) -> Self {
        Self {
            schema_arn: Some(schema_arn.into()),
            version_number: Some(version_number),
            ..self
        }
    }
}
pub fn encode_glue_wire_format(
schema_version_id: GlueSchemaVersionId,
payload: &[u8],
compression: GlueCompression,
) -> Result<Bytes> {
let compressed;
let (compression_byte, payload_bytes): (u8, &[u8]) = match compression {
GlueCompression::None => (GLUE_COMPRESSION_NONE_BYTE, payload),
GlueCompression::Zlib => {
compressed = compress_zlib(payload)?;
(GLUE_COMPRESSION_ZLIB_BYTE, &compressed)
}
};
let mut buf = BytesMut::with_capacity(GLUE_HEADER_SIZE + payload_bytes.len());
buf.put_u8(GLUE_HEADER_VERSION_BYTE);
buf.put_u8(compression_byte);
buf.put_slice(schema_version_id.as_bytes());
buf.put_slice(payload_bytes);
Ok(buf.freeze())
}
/// Decodes a Glue wire-format message into its schema version ID and an owned payload.
///
/// A ZLIB-compressed payload is decompressed; an uncompressed one is copied.
///
/// # Errors
///
/// Propagates header validation errors and decompression errors.
pub fn decode_glue_wire_format(data: &[u8]) -> Result<(GlueSchemaVersionId, Vec<u8>)> {
    let (schema_version_id, compression) = validate_glue_wire_header(data)?;
    // Header validation guarantees at least `GLUE_HEADER_SIZE` bytes, so the split cannot panic.
    let (_header, body) = data.split_at(GLUE_HEADER_SIZE);
    Ok((
        schema_version_id,
        match compression {
            GlueCompression::Zlib => decompress_zlib(body)?,
            GlueCompression::None => body.to_vec(),
        },
    ))
}
/// Decodes a Glue wire-format message, borrowing the payload from `data` when possible.
///
/// An uncompressed payload is returned as `Cow::Borrowed`; a ZLIB-compressed payload is
/// decompressed into `Cow::Owned`.
///
/// # Errors
///
/// Propagates header validation errors and decompression errors.
pub fn decode_glue_wire_format_borrowed(
    data: &[u8],
) -> Result<(GlueSchemaVersionId, Cow<'_, [u8]>)> {
    let (schema_version_id, compression) = validate_glue_wire_header(data)?;
    // Header validation guarantees at least `GLUE_HEADER_SIZE` bytes, so the split cannot panic.
    let (_header, body) = data.split_at(GLUE_HEADER_SIZE);
    Ok((
        schema_version_id,
        match compression {
            GlueCompression::Zlib => Cow::Owned(decompress_zlib(body)?),
            GlueCompression::None => Cow::Borrowed(body),
        },
    ))
}
/// Decodes a Glue wire-format message held in a [`Bytes`] buffer.
///
/// An uncompressed payload is returned as a slice of `data` (via `Bytes::slice`); a
/// ZLIB-compressed payload is decompressed into a new buffer.
///
/// # Errors
///
/// Propagates header validation errors and decompression errors.
pub fn decode_glue_wire_format_bytes(data: &Bytes) -> Result<(GlueSchemaVersionId, Bytes)> {
    let (schema_version_id, compression) = validate_glue_wire_header(data)?;
    let payload = match compression {
        GlueCompression::Zlib => decompress_zlib(&data[GLUE_HEADER_SIZE..]).map(Bytes::from)?,
        GlueCompression::None => data.slice(GLUE_HEADER_SIZE..),
    };
    Ok((schema_version_id, payload))
}
/// Validates the 18-byte Glue header at the start of `data` and extracts its fields.
///
/// Returns the schema version ID (header bytes 2..18) and the compression mode (byte 1).
///
/// # Errors
///
/// Returns a `wire_format` error when `data` is shorter than the header, when the version
/// byte is not `GLUE_HEADER_VERSION_BYTE`, or when the compression byte is unknown.
pub(crate) fn validate_glue_wire_header(
    data: &[u8],
) -> Result<(GlueSchemaVersionId, GlueCompression)> {
    let Some(header) = data.get(..GLUE_HEADER_SIZE) else {
        return Err(SchemaRegError::wire_format(format!(
            "Glue wire format data too short: expected at least {GLUE_HEADER_SIZE} bytes, got {}",
            data.len()
        )));
    };
    let (version, compression_flag, id_bytes) = (header[0], header[1], &header[2..]);

    if version != GLUE_HEADER_VERSION_BYTE {
        return Err(SchemaRegError::wire_format(format!(
            "invalid Glue wire format header version byte: expected 0x{GLUE_HEADER_VERSION_BYTE:02X}, got 0x{:02X}",
            version
        )));
    }

    let compression = match compression_flag {
        GLUE_COMPRESSION_ZLIB_BYTE => GlueCompression::Zlib,
        GLUE_COMPRESSION_NONE_BYTE => GlueCompression::None,
        other => {
            return Err(SchemaRegError::wire_format(format!(
                "unknown Glue wire format compression byte: 0x{other:02X}"
            )));
        }
    };

    // `id_bytes` is exactly `UUID_SIZE` long: the header is 18 bytes and the ID starts at 2.
    let mut raw_id = [0u8; UUID_SIZE];
    raw_id.copy_from_slice(id_bytes);
    Ok((GlueSchemaVersionId(raw_id), compression))
}
/// ZLIB-compresses `data` at the default compression level.
///
/// # Errors
///
/// Returns a `wire_format` error if writing to or finishing the encoder fails.
#[cfg(feature = "glue")]
fn compress_zlib(data: &[u8]) -> Result<Vec<u8>> {
    // Both failure points report the same message.
    let to_wire_error =
        |e: std::io::Error| SchemaRegError::wire_format(format!("ZLIB compression failed: {e}"));

    let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
    encoder.write_all(data).map_err(to_wire_error)?;
    encoder.finish().map_err(to_wire_error)
}
/// Fallback used when the `glue` feature is disabled: always fails, because ZLIB support is
/// only compiled in with that feature.
#[cfg(not(feature = "glue"))]
fn compress_zlib(_data: &[u8]) -> Result<Vec<u8>> {
    let unsupported =
        SchemaRegError::wire_format("Glue ZLIB compression requires the `glue` Cargo feature");
    Err(unsupported)
}
/// Largest decompressed payload accepted by `decompress_zlib` (128 MiB).
#[cfg(feature = "glue")]
const MAX_DECOMPRESSED_SIZE: usize = 128 * 1024 * 1024;

/// Decompresses ZLIB `data`, refusing output larger than `MAX_DECOMPRESSED_SIZE`.
///
/// # Errors
///
/// Returns a `wire_format` error if decoding fails or the decompressed size exceeds the limit.
#[cfg(feature = "glue")]
pub(crate) fn decompress_zlib(data: &[u8]) -> Result<Vec<u8>> {
    // Read at most one byte past the limit: enough to detect an oversized stream without
    // inflating all of it.
    let read_cap = MAX_DECOMPRESSED_SIZE as u64 + 1;
    let mut output = Vec::new();
    ZlibDecoder::new(data)
        .take(read_cap)
        .read_to_end(&mut output)
        .map_err(|e| SchemaRegError::wire_format(format!("ZLIB decompression failed: {e}")))?;

    if output.len() <= MAX_DECOMPRESSED_SIZE {
        Ok(output)
    } else {
        Err(SchemaRegError::wire_format(format!(
            "ZLIB decompressed size {} exceeds maximum {} bytes (possible decompression bomb)",
            output.len(),
            MAX_DECOMPRESSED_SIZE
        )))
    }
}
/// Fallback used when the `glue` feature is disabled: always fails, because ZLIB support is
/// only compiled in with that feature.
#[cfg(not(feature = "glue"))]
pub(crate) fn decompress_zlib(_data: &[u8]) -> Result<Vec<u8>> {
    let unsupported =
        SchemaRegError::wire_format("Glue ZLIB decompression requires the `glue` Cargo feature");
    Err(unsupported)
}
/// Asynchronous client interface for a Glue schema registry.
///
/// Implementors must be `Send + Sync`, and both methods return futures that are `Send`.
pub trait GlueSchemaRegistryClient: Send + Sync {
/// Looks up the schema identified by the schema version ID `id`.
///
/// The returned future borrows `self` and resolves to the shared [`GlueSchema`] or an error.
fn get_schema_by_version_id(
&self,
id: GlueSchemaVersionId,
) -> impl Future<Output = Result<Arc<GlueSchema>>> + Send + '_;
/// Registers the definition `schema` under `schema_name` with the given `data_format`.
///
/// The returned future borrows `self` and both string arguments for `'a` and resolves to a
/// [`GlueSchemaVersionId`] or an error.
/// NOTE(review): presumably the ID of the registered schema version; the exact semantics
/// are defined by each implementor.
fn register_schema<'a>(
&'a self,
schema_name: &'a str,
schema: &'a str,
data_format: GlueDataFormat,
) -> impl Future<Output = Result<GlueSchemaVersionId>> + Send + 'a;
}
/// Forwards every call to the referenced client, so `&T` can be used wherever a client is
/// expected.
impl<T: GlueSchemaRegistryClient + ?Sized> GlueSchemaRegistryClient for &T {
    fn get_schema_by_version_id(
        &self,
        id: GlueSchemaVersionId,
    ) -> impl Future<Output = Result<Arc<GlueSchema>>> + Send + '_ {
        // `&**self` peels the extra reference layer to reach the underlying `T`.
        <T as GlueSchemaRegistryClient>::get_schema_by_version_id(&**self, id)
    }

    fn register_schema<'a>(
        &'a self,
        schema_name: &'a str,
        schema: &'a str,
        data_format: GlueDataFormat,
    ) -> impl Future<Output = Result<GlueSchemaVersionId>> + Send + 'a {
        <T as GlueSchemaRegistryClient>::register_schema(
            &**self,
            schema_name,
            schema,
            data_format,
        )
    }
}
/// Forwards every call to the client inside the `Arc`, so a shared handle can be used
/// wherever a client is expected.
impl<T: GlueSchemaRegistryClient + ?Sized> GlueSchemaRegistryClient for Arc<T> {
    fn get_schema_by_version_id(
        &self,
        id: GlueSchemaVersionId,
    ) -> impl Future<Output = Result<Arc<GlueSchema>>> + Send + '_ {
        // `&**self` dereferences through the `Arc` to the underlying `T`.
        <T as GlueSchemaRegistryClient>::get_schema_by_version_id(&**self, id)
    }

    fn register_schema<'a>(
        &'a self,
        schema_name: &'a str,
        schema: &'a str,
        data_format: GlueDataFormat,
    ) -> impl Future<Output = Result<GlueSchemaVersionId>> + Send + 'a {
        <T as GlueSchemaRegistryClient>::register_schema(
            &**self,
            schema_name,
            schema,
            data_format,
        )
    }
}
/// Boxed-future counterpart of [`GlueSchemaRegistryClient`].
///
/// Each method mirrors the method of the same name on [`GlueSchemaRegistryClient`] but
/// returns a pinned, boxed `dyn Future` instead of `impl Future`. A blanket impl in this file
/// provides this trait for every `T: GlueSchemaRegistryClient`.
/// NOTE(review): presumably intended for use behind `dyn` trait objects — confirm with callers.
pub trait DynGlueSchemaRegistryClient: Send + Sync {
/// Looks up the schema identified by `id`; the boxed future borrows `self` for `'a`.
fn get_schema_by_version_id<'a>(
&'a self,
id: GlueSchemaVersionId,
) -> std::pin::Pin<Box<dyn Future<Output = Result<Arc<GlueSchema>>> + Send + 'a>>;
/// Registers `schema` under `schema_name` with the given `data_format`; the boxed future
/// borrows `self` and both string arguments for `'a`.
fn register_schema<'a>(
&'a self,
schema_name: &'a str,
schema: &'a str,
data_format: GlueDataFormat,
) -> std::pin::Pin<Box<dyn Future<Output = Result<GlueSchemaVersionId>> + Send + 'a>>;
}
impl<T: GlueSchemaRegistryClient> DynGlueSchemaRegistryClient for T {
fn get_schema_by_version_id<'a>(
&'a self,
id: GlueSchemaVersionId,
) -> std::pin::Pin<Box<dyn Future<Output = Result<Arc<GlueSchema>>> + Send + 'a>> {
Box::pin(GlueSchemaRegistryClient::get_schema_by_version_id(self, id))
}
fn register_schema<'a>(
&'a self,
schema_name: &'a str,
schema: &'a str,
data_format: GlueDataFormat,
) -> std::pin::Pin<Box<dyn Future<Output = Result<GlueSchemaVersionId>> + Send + 'a>> {
Box::pin(GlueSchemaRegistryClient::register_schema(
self,
schema_name,
schema,
data_format,
))
}
}
/// Caching wrapper around a [`GlueSchemaRegistryClient`].
///
/// Lookups by schema version ID go through an in-memory cache; registrations are forwarded
/// to the wrapped client unchanged.
pub struct CachedGlueSchemaRegistry<C> {
    /// The wrapped client that performs the actual registry calls.
    inner: C,
    /// Cache of schemas keyed by schema version ID.
    cache: InMemoryCache<GlueSchemaVersionId, GlueSchema>,
}

/// Cache capacity used by [`CachedGlueSchemaRegistry::new`].
pub const DEFAULT_MAX_GLUE_CACHE_ENTRIES: usize = 1000;

impl<C: GlueSchemaRegistryClient> CachedGlueSchemaRegistry<C> {
    /// Wraps `inner` with a cache limited to [`DEFAULT_MAX_GLUE_CACHE_ENTRIES`] entries.
    pub fn new(inner: C) -> Self {
        Self::with_max_entries(inner, DEFAULT_MAX_GLUE_CACHE_ENTRIES)
    }

    /// Wraps `inner` with a cache configured for `max_entries` entries.
    ///
    /// A `max_entries` of `0` is raised to `1`.
    pub fn with_max_entries(inner: C, max_entries: usize) -> Self {
        let max_entries = max_entries.max(1);
        Self {
            inner,
            cache: InMemoryCache::new(Some(max_entries), glue_schema_lookup_cancelled_error),
        }
    }

    /// Returns a reference to the wrapped client.
    pub fn inner(&self) -> &C {
        &self.inner
    }

    /// Returns the number of entries reported by the cache.
    pub fn cache_len(&self) -> usize {
        self.cache.len()
    }

    /// Returns `true` when the cache reports no entries.
    pub fn cache_is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Removes every entry from the cache.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }

    /// Removes the entry for `version_id` from the cache.
    pub fn invalidate(&self, version_id: GlueSchemaVersionId) {
        self.cache.invalidate(version_id);
    }

    /// Removes every entry from the cache; equivalent to [`Self::clear_cache`].
    pub fn invalidate_all(&self) {
        self.cache.clear();
    }

    /// Fetches every distinct ID in `version_ids` through the cache, at most 16 at a time.
    ///
    /// Duplicate IDs are fetched once. IDs are processed in order of first occurrence, so
    /// the order of the reported failures is stable from run to run for a given input
    /// (it used to depend on `HashSet` iteration order).
    ///
    /// # Errors
    ///
    /// Returns a [`WarmGlueCacheError`] listing every ID whose lookup failed. All IDs are
    /// attempted even if some fail.
    pub async fn warm_cache(
        &self,
        version_ids: &[GlueSchemaVersionId],
    ) -> std::result::Result<(), WarmGlueCacheError> {
        const WARM_CONCURRENCY: usize = 16;

        // Deduplicate while keeping first-occurrence order. The set is scoped to this block
        // so it is not kept alive across the awaits below.
        let ids: Vec<GlueSchemaVersionId> = {
            let mut seen = HashSet::with_capacity(version_ids.len());
            version_ids
                .iter()
                .copied()
                .filter(|id| seen.insert(*id))
                .collect()
        };
        if ids.is_empty() {
            return Ok(());
        }

        let mut failures: Vec<(GlueSchemaVersionId, SchemaRegError)> = Vec::new();
        for chunk in ids.chunks(WARM_CONCURRENCY) {
            // Run one chunk concurrently, then move on to the next chunk.
            let futs = chunk.iter().map(|&id| async move {
                (
                    id,
                    self.cache
                        .get_or_fetch(id, || self.inner.get_schema_by_version_id(id))
                        .await,
                )
            });
            let results = futures::future::join_all(futs).await;
            failures.extend(
                results
                    .into_iter()
                    .filter_map(|(id, result)| result.err().map(|e| (id, e))),
            );
        }

        if failures.is_empty() {
            Ok(())
        } else {
            Err(WarmGlueCacheError { failures })
        }
    }

    /// Looks up the schema for `id` through the cache, passing the wrapped client's lookup
    /// as the fetch callback.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the cache's `get_or_fetch`.
    pub async fn get_schema_by_version_id(
        &self,
        id: GlueSchemaVersionId,
    ) -> Result<Arc<GlueSchema>> {
        self.cache
            .get_or_fetch(id, || self.inner.get_schema_by_version_id(id))
            .await
    }

    /// Registers a schema through the wrapped client; the cache is not touched.
    ///
    /// # Errors
    ///
    /// Propagates any error from the wrapped client.
    pub async fn register_schema(
        &self,
        schema_name: &str,
        schema: &str,
        data_format: GlueDataFormat,
    ) -> Result<GlueSchemaVersionId> {
        self.inner
            .register_schema(schema_name, schema, data_format)
            .await
    }
}
impl<C> fmt::Debug for CachedGlueSchemaRegistry<C> {
    /// Shows the cache length and the cache itself. The wrapped client is left out, so `C`
    /// does not need to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let entry_count = self.cache.len();
        let mut output = f.debug_struct("CachedGlueSchemaRegistry");
        output.field("cache_len", &entry_count);
        output.field("cache", &self.cache);
        output.finish()
    }
}
/// Lets the caching wrapper itself be used as a [`GlueSchemaRegistryClient`].
///
/// Both methods delegate to the inherent methods of the same name. `Self::name(self, ..)`
/// resolves to the inherent method before the trait method, so these calls do not recurse.
impl<C: GlueSchemaRegistryClient> GlueSchemaRegistryClient for CachedGlueSchemaRegistry<C> {
    async fn get_schema_by_version_id(&self, id: GlueSchemaVersionId) -> Result<Arc<GlueSchema>> {
        Self::get_schema_by_version_id(self, id).await
    }

    async fn register_schema(
        &self,
        schema_name: &str,
        schema: &str,
        data_format: GlueDataFormat,
    ) -> Result<GlueSchemaVersionId> {
        Self::register_schema(self, schema_name, schema, data_format).await
    }
}
impl<C: GlueSchemaRegistryClient> AnySchemaCache for CachedGlueSchemaRegistry<C> {
type Id = GlueSchemaVersionId;
fn cache_len(&self) -> usize {
Self::cache_len(self)
}
fn cache_is_empty(&self) -> bool {
Self::cache_is_empty(self)
}
fn clear_cache(&self) {
Self::clear_cache(self)
}
fn invalidate(&self, id: Self::Id) {
Self::invalidate(self, id)
}
fn invalidate_all(&self) {
Self::invalidate_all(self)
}
fn warm_cache<'a>(
&'a self,
ids: &'a [Self::Id],
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
Box::pin(async move {
Self::warm_cache(self, ids)
.await
.map_err(|e| SchemaRegError::invalid_state(e.to_string()))
})
}
}