use std::collections::HashMap;
use std::sync::Arc;
use bytes::Bytes;
use jsonschema::Validator;
use parking_lot::{Mutex, RwLock};
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde_json::Value;
use tokio::sync::oneshot;
use crate::error::{Result, SchemaRegError};
use crate::subject::SubjectNameStrategy;
use crate::traits::SchemaRegistryClient;
use crate::types::{EncodeTarget, SchemaId, SchemaReference, SchemaType};
use crate::wire::{decode_wire_format_bytes, encode_wire_format};
/// Sender half held for a caller that is waiting on another caller's
/// in-progress registration of the same subject.
type WaiterTx = oneshot::Sender<Result<Arc<EncoderEntry>>>;
/// Subjects whose registration is currently in progress, each mapped to the
/// waiters to notify once the caller performing the registration finishes.
type InFlightMap = Mutex<HashMap<String, Vec<WaiterTx>>>;
/// Parses `schema_str` as JSON and compiles it into a JSON Schema validator.
///
/// # Errors
/// Returns a config error tagged "(parse)" when the text is not valid JSON, or
/// "(compile)" when the parsed document is not a compilable JSON Schema.
fn compile_schema(schema_str: &str) -> Result<Validator> {
    serde_json::from_str::<Value>(schema_str)
        .map_err(|err| SchemaRegError::config(format!("invalid JSON Schema (parse): {err}")))
        .and_then(|parsed| {
            jsonschema::validator_for(&parsed).map_err(|err| {
                SchemaRegError::config(format!("invalid JSON Schema (compile): {err}"))
            })
        })
}
/// Checks `value` against `validator`, collecting every violation rather than
/// stopping at the first one.
///
/// # Errors
/// Returns a wire-format error whose message lists all violations joined by `"; "`.
fn validate(validator: &Validator, value: &Value) -> Result<()> {
    let messages: Vec<String> = validator.iter_errors(value).map(|err| err.to_string()).collect();
    if messages.is_empty() {
        return Ok(());
    }
    let joined = messages.join("; ");
    Err(SchemaRegError::wire_format(format!(
        "JSON Schema validation failed: {joined}"
    )))
}
/// Result of resolving one subject: the schema ID returned by the registry when
/// the encoder's schema was registered under that subject, plus the validator
/// used for encode-time validation.
struct EncoderEntry {
    /// Schema ID passed to `encode_wire_format` when framing encoded payloads.
    schema_id: SchemaId,
    /// Compiled validator; entries are built from the encoder's single shared validator.
    validator: Arc<Validator>,
}
/// Encodes JSON values into the schema-registry wire format, registering one
/// JSON Schema under whichever subject each (topic, target) maps to.
///
/// Construct with `JsonSchemaEncoder::builder()`.
pub struct JsonSchemaEncoder<C> {
    /// Registry client used to register `schema_str`.
    registry: C,
    /// Raw JSON Schema text sent to the registry on registration.
    schema_str: String,
    /// Validator compiled from `schema_str` when the encoder is built.
    validator: Arc<Validator>,
    /// Optional record name handed to the subject-name strategy.
    record_name: Option<String>,
    /// Strategy mapping (topic, record name, key/value target) to a subject name.
    strategy: SubjectNameStrategy,
    /// Schema references sent alongside `schema_str` on registration.
    references: Vec<SchemaReference>,
    /// Whether `encode` validates values against the schema before serialising them.
    validate_on_encode: bool,
    /// Resolved entries keyed by subject name.
    cache: RwLock<HashMap<String, Arc<EncoderEntry>>>,
    /// Subjects with a registration in progress, plus the callers waiting on it.
    in_flight: InFlightMap,
}
impl<C: std::fmt::Debug> std::fmt::Debug for JsonSchemaEncoder<C> {
    /// Shows the encoder's configuration (registry, record name, strategy and the
    /// validation flag); the remaining fields are elided with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("JsonSchemaEncoder");
        out.field("registry", &self.registry);
        out.field("record_name", &self.record_name);
        out.field("strategy", &self.strategy);
        out.field("validate_on_encode", &self.validate_on_encode);
        out.finish_non_exhaustive()
    }
}
impl<C: SchemaRegistryClient> JsonSchemaEncoder<C> {
    /// Returns a builder for configuring a `JsonSchemaEncoder`.
    #[must_use]
    pub fn builder() -> JsonSchemaEncoderBuilder<C> {
        JsonSchemaEncoderBuilder::new()
    }

    /// Returns the schema ID cached for `subject`, or `None` if this encoder has
    /// not yet successfully resolved that subject.
    #[must_use]
    pub fn cached_schema_id(&self, subject: &str) -> Option<SchemaId> {
        self.cache.read().get(subject).map(|e| e.schema_id)
    }

    /// Returns the cache entry for `subject`, registering the encoder's schema
    /// with the registry on a cache miss.
    ///
    /// Concurrent misses for the same subject are coalesced. The first caller
    /// (the leader) performs the registry call. Later callers park on a oneshot
    /// channel and receive the leader's result.
    ///
    /// # Errors
    /// Returns the registry error from `register_schema`. Waiters receive an
    /// `invalid_state` error if the leader's future is dropped before completing.
    async fn resolve_subject(&self, subject: &str) -> Result<Arc<EncoderEntry>> {
        if let Some(entry) = self.cache.read().get(subject) {
            return Ok(Arc::clone(entry));
        }
        let waiter_rx = {
            let mut in_flight = self.in_flight.lock();
            // Re-check under the in-flight lock. The leader publishes to the cache
            // and clears its in-flight slot inside this same lock, so past this
            // point the subject is either in flight (we wait) or ours to resolve.
            if let Some(entry) = self.cache.read().get(subject) {
                return Ok(Arc::clone(entry));
            }
            if let Some(waiters) = in_flight.get_mut(subject) {
                let (tx, rx) = oneshot::channel();
                waiters.push(tx);
                Some(rx)
            } else {
                in_flight.insert(subject.to_string(), Vec::new());
                None
            }
        };
        if let Some(rx) = waiter_rx {
            return rx.await.map_err(|_| {
                SchemaRegError::invalid_state(
                    "JSON schema entry resolution cancelled by the leader",
                )
            })?;
        }

        // If the leader's future is dropped (or panics) during the registry call,
        // free the in-flight slot and fail any parked waiters so none hang forever.
        struct ResolveGuard<'a> {
            in_flight: &'a InFlightMap,
            subject: String,
            done: bool,
        }
        impl Drop for ResolveGuard<'_> {
            fn drop(&mut self) {
                if !self.done {
                    let waiters = self
                        .in_flight
                        .lock()
                        .remove(&self.subject)
                        .unwrap_or_default();
                    for tx in waiters {
                        let _ = tx.send(Err(SchemaRegError::invalid_state(
                            "JSON schema entry resolution cancelled",
                        )));
                    }
                }
            }
        }
        let mut guard = ResolveGuard {
            in_flight: &self.in_flight,
            subject: subject.to_string(),
            done: false,
        };

        let result = self
            .registry
            .register_schema(
                subject,
                &self.schema_str,
                SchemaType::Json,
                &self.references,
            )
            .await
            .map(|schema_id| {
                Arc::new(EncoderEntry {
                    schema_id,
                    validator: Arc::clone(&self.validator),
                })
            });

        // Publish to the cache and take the waiters in ONE critical section.
        // Otherwise there is a window where the subject is neither cached nor in
        // flight, and a late caller would start a duplicate registration.
        // Lock order is always in_flight -> cache, matching the re-check above.
        let waiters = {
            let mut in_flight = self.in_flight.lock();
            if let Ok(entry) = &result {
                self.cache
                    .write()
                    .insert(subject.to_string(), Arc::clone(entry));
            }
            in_flight.remove(subject).unwrap_or_default()
        };
        // The slot is cleared and the waiters are ours; the guard must not fail them.
        guard.done = true;

        match &result {
            Ok(entry) => {
                for tx in waiters {
                    // A waiter whose future was dropped has closed its receiver; ignore.
                    let _ = tx.send(Ok(Arc::clone(entry)));
                }
            }
            Err(e) => {
                for tx in waiters {
                    let _ = tx.send(Err(e.clone()));
                }
            }
        }
        result
    }

    /// Encodes `value` into the wire format for the subject derived from `topic`,
    /// the configured record name and `target`.
    ///
    /// Steps, in order:
    /// 1. Resolve (registering if needed) the schema ID for that subject.
    /// 2. If enabled, validate `value` against the schema.
    /// 3. Serialise `value` and frame it with the schema ID.
    ///
    /// # Errors
    /// Returns an error if subject naming fails, registration fails, validation
    /// fails (wire-format error), or JSON serialisation fails.
    pub async fn encode(&self, value: &Value, topic: &str, target: EncodeTarget) -> Result<Bytes> {
        let subject = self
            .strategy
            .subject_name(topic, self.record_name.as_deref(), target)?;
        let entry = self.resolve_subject(&subject).await?;
        if self.validate_on_encode {
            validate(&entry.validator, value)?;
        }
        let raw = serde_json::to_vec(value)
            .map_err(|e| SchemaRegError::wire_format(format!("JSON serialization failed: {e}")))?;
        Ok(encode_wire_format(entry.schema_id, &raw))
    }

    /// Converts any `Serialize` value to a JSON `Value` and encodes it via `encode`.
    ///
    /// # Errors
    /// Returns a wire-format error if the value cannot be converted to JSON,
    /// plus any error from `encode`.
    pub async fn encode_ser<T: Serialize>(
        &self,
        value: &T,
        topic: &str,
        target: EncodeTarget,
    ) -> Result<Bytes> {
        let json_value = serde_json::to_value(value).map_err(|e| {
            SchemaRegError::wire_format(format!("failed to convert value to JSON: {e}"))
        })?;
        self.encode(&json_value, topic, target).await
    }
}
/// Builder for `JsonSchemaEncoder`; `registry` and `schema` must be set before `build`.
pub struct JsonSchemaEncoderBuilder<C> {
    /// Registry client (required).
    registry: Option<C>,
    /// JSON Schema text (required); compiled when `build` runs.
    schema: Option<String>,
    /// Optional record name handed to the subject-name strategy.
    record_name: Option<String>,
    /// Subject-name strategy; the builder starts with `TopicName`.
    strategy: SubjectNameStrategy,
    /// Schema references sent on registration; empty by default.
    references: Vec<SchemaReference>,
    /// Whether the built encoder validates on encode; the builder starts with `true`.
    validate_on_encode: bool,
}
impl<C: SchemaRegistryClient> JsonSchemaEncoderBuilder<C> {
    /// Creates an empty builder with these defaults:
    /// - no registry, schema or record name;
    /// - the `TopicName` subject strategy;
    /// - no references;
    /// - encode-time validation enabled.
    fn new() -> Self {
        Self {
            registry: None,
            schema: None,
            record_name: None,
            strategy: SubjectNameStrategy::TopicName,
            references: Vec::new(),
            validate_on_encode: true,
        }
    }

    /// Sets the registry client the encoder registers its schema with (required).
    #[must_use = "builder methods consume the builder and return the updated one"]
    pub fn registry(mut self, registry: C) -> Self {
        self.registry = Some(registry);
        self
    }

    /// Sets the JSON Schema text (required); it is compiled by `build`.
    #[must_use = "builder methods consume the builder and return the updated one"]
    pub fn schema(mut self, schema: impl Into<String>) -> Self {
        self.schema = Some(schema.into());
        self
    }

    /// Sets the record name passed to the subject-name strategy.
    #[must_use = "builder methods consume the builder and return the updated one"]
    pub fn record_name(mut self, name: impl Into<String>) -> Self {
        self.record_name = Some(name.into());
        self
    }

    /// Sets the subject-name strategy (default: `TopicName`).
    #[must_use = "builder methods consume the builder and return the updated one"]
    pub fn strategy(mut self, strategy: SubjectNameStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Replaces the schema references sent on registration.
    #[must_use = "builder methods consume the builder and return the updated one"]
    pub fn references(mut self, references: Vec<SchemaReference>) -> Self {
        self.references = references;
        self
    }

    /// Enables or disables validation of values on encode (default: enabled).
    #[must_use = "builder methods consume the builder and return the updated one"]
    pub fn validate_on_encode(mut self, validate: bool) -> Self {
        self.validate_on_encode = validate;
        self
    }

    /// Builds the encoder, compiling the schema up front so invalid schemas fail fast.
    ///
    /// # Errors
    /// Returns a config error in any of these cases:
    /// - the registry is unset;
    /// - the schema is unset;
    /// - the schema is not valid JSON or not a compilable JSON Schema.
    pub fn build(self) -> Result<JsonSchemaEncoder<C>> {
        let registry = self
            .registry
            .ok_or_else(|| SchemaRegError::config("JsonSchemaEncoder: registry must be set"))?;
        let schema_str = self
            .schema
            .ok_or_else(|| SchemaRegError::config("JsonSchemaEncoder: schema must be set"))?;
        let validator = Arc::new(compile_schema(&schema_str)?);
        Ok(JsonSchemaEncoder {
            registry,
            schema_str,
            validator,
            record_name: self.record_name,
            strategy: self.strategy,
            references: self.references,
            validate_on_encode: self.validate_on_encode,
            cache: RwLock::new(HashMap::new()),
            in_flight: Mutex::new(HashMap::new()),
        })
    }
}
/// Decodes schema-registry wire-format payloads back into JSON values,
/// optionally validating them against the schema named by the payload's schema ID.
pub struct JsonSchemaDecoder<C> {
    /// Registry client used to fetch schemas by ID when validating.
    registry: C,
    /// Whether `decode` validates payloads; `new` sets false, `with_validation` sets true.
    validate_on_decode: bool,
    /// Compiled validators keyed by schema ID, filled on first use.
    schema_cache: RwLock<HashMap<SchemaId, Arc<Validator>>>,
}
impl<C: SchemaRegistryClient> JsonSchemaDecoder<C> {
    /// Creates a decoder that parses payloads without validating them (and
    /// therefore never contacts the registry).
    pub fn new(registry: C) -> Self {
        Self {
            registry,
            validate_on_decode: false,
            schema_cache: RwLock::new(HashMap::new()),
        }
    }

    /// Creates a decoder that validates each payload against the schema
    /// referenced by its wire-format schema ID.
    pub fn with_validation(registry: C) -> Self {
        Self {
            validate_on_decode: true,
            ..Self::new(registry)
        }
    }

    /// Returns the compiled validator for schema `id`, fetching it from the
    /// registry and compiling it on first use.
    ///
    /// Concurrent first lookups for the same ID may each fetch and compile; the
    /// last insert wins, which is harmless because the results are equivalent.
    ///
    /// # Errors
    /// Propagates registry lookup errors, and config errors from compiling the schema.
    async fn get_validator(&self, id: SchemaId) -> Result<Arc<Validator>> {
        if let Some(v) = self.schema_cache.read().get(&id) {
            return Ok(Arc::clone(v));
        }
        let registry_schema = self.registry.get_schema_by_id(id).await?;
        let validator = Arc::new(compile_schema(&registry_schema.schema)?);
        self.schema_cache.write().insert(id, Arc::clone(&validator));
        Ok(validator)
    }

    /// Strips the wire-format header from `data` and parses the payload as JSON.
    /// When validation is enabled, it also checks the payload against its schema.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - the header is malformed;
    /// - the payload is not valid JSON;
    /// - (with validation) the schema lookup or compilation fails;
    /// - (with validation) the payload violates the schema.
    pub async fn decode(&self, data: Bytes) -> Result<Value> {
        let (schema_id, payload) = decode_wire_format_bytes(&data)?;
        let value: Value = serde_json::from_slice(&payload).map_err(|e| {
            SchemaRegError::wire_format(format!("JSON deserialisation failed: {e}"))
        })?;
        if self.validate_on_decode {
            let validator = self.get_validator(schema_id).await?;
            validate(&validator, &value)?;
        }
        Ok(value)
    }

    /// Decodes `data` via `decode`, then deserialises the JSON value into `T`.
    ///
    /// # Errors
    /// Returns any error from `decode`, or a wire-format error if the value does
    /// not deserialise into `T`.
    pub async fn decode_de<T: DeserializeOwned>(&self, data: Bytes) -> Result<T> {
        let value = self.decode(data).await?;
        serde_json::from_value(value).map_err(|e| {
            SchemaRegError::wire_format(format!(
                "failed to deserialise JSON value into target type: {e}"
            ))
        })
    }
}
#[cfg(test)]
// Tests use unwrap deliberately: a panic is the failure signal.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::{Arc as StdArc, Mutex as StdMutex};
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use crate::types::{Schema, SchemaType, SchemaVersion};

    /// In-memory registry stub:
    /// - `register_schema` assigns a fresh sequential ID on every call;
    /// - `get_schema_by_id` returns what was stored;
    /// - the other lookups are unsupported.
    #[derive(Clone, Debug)]
    struct MockRegistry {
        inner: StdArc<MockRegistryInner>,
    }

    struct MockRegistryInner {
        schemas: StdMutex<HashMap<SchemaId, Schema>>,
        next_id: StdMutex<u32>,
    }

    impl std::fmt::Debug for MockRegistryInner {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("MockRegistryInner").finish_non_exhaustive()
        }
    }

    impl MockRegistry {
        fn new() -> Self {
            Self {
                inner: StdArc::new(MockRegistryInner {
                    schemas: StdMutex::new(HashMap::new()),
                    next_id: StdMutex::new(1),
                }),
            }
        }
    }

    impl SchemaRegistryClient for MockRegistry {
        async fn get_schema_by_id(&self, id: SchemaId) -> crate::error::Result<StdArc<Schema>> {
            self.inner
                .schemas
                .lock()
                .unwrap()
                .get(&id)
                .map(|s| StdArc::new(s.clone()))
                .ok_or_else(|| SchemaRegError::api(40403, format!("schema {id} not found")))
        }

        async fn get_latest_schema(&self, _subject: &str) -> crate::error::Result<StdArc<Schema>> {
            Err(SchemaRegError::not_supported("not implemented"))
        }

        async fn get_schema_by_version(
            &self,
            _subject: &str,
            _version: SchemaVersion,
        ) -> crate::error::Result<StdArc<Schema>> {
            Err(SchemaRegError::not_supported("not implemented"))
        }

        async fn register_schema(
            &self,
            _subject: &str,
            schema: &str,
            schema_type: SchemaType,
            _references: &[SchemaReference],
        ) -> crate::error::Result<SchemaId> {
            let mut next_id = self.inner.next_id.lock().unwrap();
            let id = SchemaId::from(*next_id);
            *next_id += 1;
            let schema_obj = Schema::new(id, schema_type, schema);
            self.inner.schemas.lock().unwrap().insert(id, schema_obj);
            Ok(id)
        }
    }

    const ORDER_SCHEMA: &str = r#"{
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "type": "object",
        "properties": {
            "id": { "type": "integer" },
            "item": { "type": "string" },
            "price": { "type": "number" }
        },
        "required": ["id", "item", "price"],
        "additionalProperties": false
    }"#;

    #[tokio::test]
    async fn encode_valid_value() {
        let reg = MockRegistry::new();
        let enc = JsonSchemaEncoder::builder()
            .registry(reg)
            .schema(ORDER_SCHEMA)
            .build()
            .unwrap();
        let v = json!({"id": 1, "item": "Widget", "price": 9.99});
        let framed = enc.encode(&v, "orders", EncodeTarget::Value).await.unwrap();
        assert_eq!(framed[0], 0x00, "magic byte must be 0x00");
        assert!(framed.len() > 5, "framed must include payload");
    }

    #[tokio::test]
    async fn encode_invalid_value_rejected() {
        let reg = MockRegistry::new();
        let enc = JsonSchemaEncoder::builder()
            .registry(reg)
            .schema(ORDER_SCHEMA)
            .build()
            .unwrap();
        let v = json!({"id": 1, "price": 9.99});
        let err = enc
            .encode(&v, "orders", EncodeTarget::Value)
            .await
            .unwrap_err();
        assert!(err.is_wire_format_error(), "should be a wire format error");
        assert!(
            err.to_string().contains("validation"),
            "should mention validation"
        );
    }

    #[tokio::test]
    async fn encode_no_validation() {
        let reg = MockRegistry::new();
        let enc = JsonSchemaEncoder::builder()
            .registry(reg)
            .schema(ORDER_SCHEMA)
            .validate_on_encode(false)
            .build()
            .unwrap();
        let v = json!({"id": "not-an-integer"});
        let framed = enc.encode(&v, "orders", EncodeTarget::Value).await.unwrap();
        assert_eq!(framed[0], 0x00);
    }

    #[tokio::test]
    async fn encode_caches_schema_id() {
        let reg = MockRegistry::new();
        let enc = JsonSchemaEncoder::builder()
            .registry(reg)
            .schema(ORDER_SCHEMA)
            .build()
            .unwrap();
        let v = json!({"id": 1, "item": "A", "price": 1.0});
        let f1 = enc.encode(&v, "orders", EncodeTarget::Value).await.unwrap();
        let f2 = enc.encode(&v, "orders", EncodeTarget::Value).await.unwrap();
        // The mock hands out a new ID per registration, so equal IDs prove caching.
        assert_eq!(
            &f1[1..5],
            &f2[1..5],
            "schema ID must be cached across calls"
        );
    }

    #[tokio::test]
    async fn encode_key_and_value_subjects() {
        let reg = MockRegistry::new();
        let enc = JsonSchemaEncoder::builder()
            .registry(reg)
            .schema(ORDER_SCHEMA)
            .validate_on_encode(false)
            .build()
            .unwrap();
        let v = json!({"id": 1, "item": "A", "price": 1.0});
        let fv = enc.encode(&v, "orders", EncodeTarget::Value).await.unwrap();
        let fk = enc.encode(&v, "orders", EncodeTarget::Key).await.unwrap();
        assert_ne!(
            &fv[1..5],
            &fk[1..5],
            "key and value must use different schema IDs"
        );
        assert!(enc.cached_schema_id("orders-value").is_some());
        assert!(enc.cached_schema_id("orders-key").is_some());
    }

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct Order {
        id: i64,
        item: String,
        price: f64,
    }

    #[tokio::test]
    async fn encode_ser_roundtrip() {
        let reg = MockRegistry::new();
        let enc = JsonSchemaEncoder::builder()
            .registry(reg.clone())
            .schema(ORDER_SCHEMA)
            .build()
            .unwrap();
        let dec = JsonSchemaDecoder::new(reg);
        let original = Order {
            id: 42,
            item: "Gadget".into(),
            price: 19.99,
        };
        let framed = enc
            .encode_ser(&original, "orders", EncodeTarget::Value)
            .await
            .unwrap();
        let decoded: Order = dec.decode_de(framed).await.unwrap();
        assert_eq!(original, decoded);
    }

    #[tokio::test]
    async fn decode_valid_payload() {
        let reg = MockRegistry::new();
        let enc = JsonSchemaEncoder::builder()
            .registry(reg.clone())
            .schema(ORDER_SCHEMA)
            .build()
            .unwrap();
        let dec = JsonSchemaDecoder::new(reg);
        let v = json!({"id": 7, "item": "Sprocket", "price": 3.50});
        let framed = enc.encode(&v, "orders", EncodeTarget::Value).await.unwrap();
        let decoded = dec.decode(framed).await.unwrap();
        assert_eq!(decoded, v);
    }

    #[tokio::test]
    async fn decode_with_validation_valid() {
        let reg = MockRegistry::new();
        let enc = JsonSchemaEncoder::builder()
            .registry(reg.clone())
            .schema(ORDER_SCHEMA)
            .validate_on_encode(false)
            .build()
            .unwrap();
        let dec = JsonSchemaDecoder::with_validation(reg);
        let valid = json!({"id": 1, "item": "Valid", "price": 1.0});
        let framed = enc
            .encode(&valid, "orders", EncodeTarget::Value)
            .await
            .unwrap();
        let result = dec.decode(framed).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn decode_with_validation_invalid() {
        let reg = MockRegistry::new();
        // Encode-side validation is off so an invalid payload reaches the decoder.
        let enc = JsonSchemaEncoder::builder()
            .registry(reg.clone())
            .schema(ORDER_SCHEMA)
            .validate_on_encode(false)
            .build()
            .unwrap();
        let dec = JsonSchemaDecoder::with_validation(reg);
        let invalid = json!({"id": 1});
        let framed = enc
            .encode(&invalid, "orders", EncodeTarget::Value)
            .await
            .unwrap();
        let err = dec.decode(framed).await.unwrap_err();
        assert!(err.is_wire_format_error());
    }

    #[tokio::test]
    async fn build_with_invalid_schema_returns_config_error() {
        let reg = MockRegistry::new();
        let result = JsonSchemaEncoder::builder()
            .registry(reg)
            .schema("not valid JSON")
            .build();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.is_config_error(),
            "should be a config error, got: {err}"
        );
    }

    #[test]
    fn build_without_schema_returns_config_error() {
        let err = JsonSchemaEncoder::builder()
            .registry(MockRegistry::new())
            .build()
            .unwrap_err();
        assert!(
            err.is_config_error(),
            "should be a config error, got: {err}"
        );
    }

    #[test]
    fn build_without_registry_returns_config_error() {
        let err = JsonSchemaEncoder::<MockRegistry>::builder()
            .schema(ORDER_SCHEMA)
            .build()
            .unwrap_err();
        assert!(
            err.is_config_error(),
            "should be a config error, got: {err}"
        );
    }

    #[tokio::test]
    async fn decode_without_validation_skips_registry_lookup() {
        // Schema ID 99 was never registered; a non-validating decoder must not care.
        let dec = JsonSchemaDecoder::new(MockRegistry::new());
        let raw: Vec<u8> = br#"{"id": 1}"#.to_vec();
        let framed = encode_wire_format(SchemaId::from(99_u32), &raw);
        let decoded = dec.decode(framed).await.unwrap();
        assert_eq!(decoded, json!({"id": 1}));
    }

    #[tokio::test]
    async fn decode_with_validation_unknown_schema_id_fails() {
        let dec = JsonSchemaDecoder::with_validation(MockRegistry::new());
        let raw: Vec<u8> = br#"{"id": 1}"#.to_vec();
        let framed = encode_wire_format(SchemaId::from(99_u32), &raw);
        assert!(dec.decode(framed).await.is_err());
    }
}