use std::collections::HashMap;
use std::sync::Arc;
use apache_avro::Schema as AvroSchema;
use apache_avro::types::Value;
use bytes::Bytes;
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use tokio::sync::oneshot;
use crate::error::{Result, SchemaRegError};
use crate::subject::SubjectNameStrategy;
use crate::traits::SchemaRegistryClient;
use crate::types::{EncodeTarget, SchemaId, SchemaReference, SchemaType};
use crate::wire::{decode_wire_format_bytes, encode_wire_format};
/// Channel end used to hand a subject's resolution result to one waiting caller.
type InFlightWaiter = oneshot::Sender<Result<Arc<EncoderEntry>>>;
/// Subjects whose registration is currently in progress, each mapped to the
/// callers waiting for that registration's result.
type InFlightMap = Mutex<HashMap<String, Vec<InFlightWaiter>>>;
/// Parses `schema_str` as an Avro schema.
///
/// Any parse failure is reported as a config error whose message embeds the
/// underlying apache_avro error.
fn parse_avro_schema(schema_str: &str) -> Result<AvroSchema> {
    match AvroSchema::parse_str(schema_str) {
        Ok(parsed) => Ok(parsed),
        Err(err) => Err(SchemaRegError::config(format!("invalid Avro schema: {err}"))),
    }
}
/// Returns the full name of a named top-level schema (record, enum or fixed),
/// or `None` for any other schema kind.
fn schema_fullname(schema: &AvroSchema) -> Option<String> {
    let name = match schema {
        AvroSchema::Record(record) => &record.name,
        AvroSchema::Enum(enumeration) => &enumeration.name,
        AvroSchema::Fixed(fixed) => &fixed.name,
        _ => return None,
    };
    // The name's own namespace is passed as the fallback namespace, exactly
    // as the per-variant calls did.
    Some(name.fullname(name.namespace.clone()))
}
/// A resolved subject: the schema id returned by the registry for that
/// subject, paired with the parsed schema used to serialize values.
struct EncoderEntry {
    /// Id returned by `register_schema` for the subject.
    schema_id: SchemaId,
    /// Parsed schema, shared (via `Arc`) with the owning encoder.
    avro_schema: Arc<AvroSchema>,
}
/// Serializes Avro values with a fixed schema and wraps the result with the
/// schema id obtained by registering that schema (via `encode_wire_format`).
///
/// The subject used for registration is chosen per call by `strategy`; the
/// resulting ids are cached per subject. Construct with
/// [`AvroSchemaEncoder::builder`].
pub struct AvroSchemaEncoder<C> {
    /// Registry client used to register the schema under each subject.
    registry: C,
    /// Raw schema text sent to the registry.
    schema_str: String,
    /// Parsed form of `schema_str`, used for serialization.
    avro_schema: Arc<AvroSchema>,
    /// Full name of the top-level schema when it is a record, enum or fixed;
    /// passed to the subject naming strategy.
    schema_fullname: Option<String>,
    /// Strategy that derives the registry subject from topic, name and target.
    strategy: SubjectNameStrategy,
    /// Schema references sent along with the schema when registering.
    references: Vec<SchemaReference>,
    /// Resolved entries keyed by subject.
    cache: RwLock<HashMap<String, Arc<EncoderEntry>>>,
    /// Subjects currently being registered, with the callers waiting on each.
    in_flight: InFlightMap,
}
impl<C: SchemaRegistryClient> AvroSchemaEncoder<C> {
    /// Returns a builder for configuring an [`AvroSchemaEncoder`].
    pub fn builder() -> AvroSchemaEncoderBuilder<C> {
        AvroSchemaEncoderBuilder::new()
    }

    /// Resolves the entry (schema id + parsed schema) for `subject`,
    /// registering the schema with the registry on first use.
    ///
    /// Concurrent calls for the same uncached subject are coalesced: the
    /// first caller (the leader) performs the registration, and every other
    /// caller waits on a oneshot channel for the leader's result. Successful
    /// results are cached. Failures are not cached, so a later call retries.
    ///
    /// # Errors
    ///
    /// - Returns the error from `register_schema`; the same error is forwarded
    ///   to all waiters.
    /// - Returns an `invalid_state` error to waiters whose leader was dropped
    ///   before it finished.
    async fn resolve_subject(&self, subject: &str) -> Result<Arc<EncoderEntry>> {
        // Fast path: already resolved.
        if let Some(entry) = self.cache.read().get(subject) {
            return Ok(Arc::clone(entry));
        }

        let waiter_rx = {
            let mut in_flight = self.in_flight.lock();
            // Re-check under the in-flight lock. The leader publishes to the
            // cache *before* removing its in-flight entry. So if both the
            // cache and the in-flight map miss here, nobody else is resolving
            // this subject and this caller becomes the leader.
            if let Some(entry) = self.cache.read().get(subject) {
                return Ok(Arc::clone(entry));
            }
            if let Some(waiters) = in_flight.get_mut(subject) {
                let (tx, rx) = oneshot::channel();
                waiters.push(tx);
                Some(rx)
            } else {
                in_flight.insert(subject.to_string(), Vec::new());
                None
            }
        };

        // Follower: wait for the leader's result.
        if let Some(rx) = waiter_rx {
            return rx.await.map_err(|_| {
                SchemaRegError::invalid_state("schema entry resolution cancelled by the leader")
            })?;
        }

        // If the leader is dropped before finishing (e.g. its future is
        // cancelled during the registry call), clear the in-flight entry and
        // fail the waiters so they do not hang.
        struct ResolveGuard<'a> {
            in_flight: &'a InFlightMap,
            subject: &'a str,
            done: bool,
        }
        impl Drop for ResolveGuard<'_> {
            fn drop(&mut self) {
                if !self.done {
                    let waiters = self
                        .in_flight
                        .lock()
                        .remove(self.subject)
                        .unwrap_or_default();
                    for tx in waiters {
                        let _ = tx.send(Err(SchemaRegError::invalid_state(
                            "Avro schema entry resolution cancelled",
                        )));
                    }
                }
            }
        }
        let mut guard = ResolveGuard {
            in_flight: &self.in_flight,
            subject,
            done: false,
        };

        let result = self
            .registry
            .register_schema(subject, &self.schema_str, SchemaType::Avro, &self.references)
            .await
            .map(|schema_id| {
                Arc::new(EncoderEntry {
                    schema_id,
                    avro_schema: Arc::clone(&self.avro_schema),
                })
            });

        // Publish to the cache BEFORE removing the in-flight entry. Otherwise
        // a caller arriving between the two steps would miss both and start a
        // redundant registration as a second leader.
        if let Ok(entry) = &result {
            self.cache
                .write()
                .insert(subject.to_string(), Arc::clone(entry));
        }
        let waiters = self.in_flight.lock().remove(subject).unwrap_or_default();
        // The waiters are now owned here; the guard must not fail them on drop.
        guard.done = true;
        for tx in waiters {
            // A waiter that gave up has dropped its receiver; nothing to do.
            let _ = tx.send(result.clone());
        }
        result
    }

    /// Serializes `value` with this encoder's schema and wraps it with the
    /// schema id registered for the subject derived from `topic` and `target`.
    ///
    /// # Errors
    ///
    /// - Returns the subject strategy's error if it cannot produce a subject.
    /// - Returns any error from resolving/registering the subject.
    /// - Returns a wire-format error if `apache_avro::to_avro_datum` rejects
    ///   the value.
    pub async fn encode(&self, value: Value, topic: &str, target: EncodeTarget) -> Result<Bytes> {
        let subject = self
            .strategy
            .subject_name(topic, self.schema_fullname.as_deref(), target)?;
        let entry = self.resolve_subject(&subject).await?;
        let raw = apache_avro::to_avro_datum(&entry.avro_schema, value)
            .map_err(|e| SchemaRegError::wire_format(format!("Avro serialization failed: {e}")))?;
        Ok(encode_wire_format(entry.schema_id, &raw))
    }

    /// Converts a `Serialize` value into an Avro [`Value`] with
    /// `apache_avro::to_value`, then encodes it like [`encode`](Self::encode).
    ///
    /// # Errors
    ///
    /// - Returns a wire-format error if the conversion fails.
    /// - Otherwise returns any error from [`encode`](Self::encode).
    pub async fn encode_ser<T: Serialize>(
        &self,
        value: &T,
        topic: &str,
        target: EncodeTarget,
    ) -> Result<Bytes> {
        let av_value = apache_avro::to_value(value).map_err(|e| {
            SchemaRegError::wire_format(format!("failed to convert value to Avro: {e}"))
        })?;
        self.encode(av_value, topic, target).await
    }
}
/// Builder for [`AvroSchemaEncoder`]; obtain one via [`AvroSchemaEncoder::builder`].
pub struct AvroSchemaEncoderBuilder<C> {
    /// Registry client; `build` fails if it was never set.
    registry: Option<C>,
    /// Avro schema text; `build` fails if it was never set.
    schema: Option<String>,
    /// Subject naming strategy; starts as `SubjectNameStrategy::TopicName`.
    strategy: SubjectNameStrategy,
    /// References registered alongside the schema; starts empty.
    references: Vec<SchemaReference>,
}
impl<C: SchemaRegistryClient> AvroSchemaEncoderBuilder<C> {
    /// Creates a builder with no registry or schema, the
    /// `SubjectNameStrategy::TopicName` strategy, and no references.
    fn new() -> Self {
        Self {
            registry: None,
            schema: None,
            strategy: SubjectNameStrategy::TopicName,
            references: Vec::new(),
        }
    }

    /// Sets the registry client the encoder registers its schema with (required).
    #[must_use = "builder methods consume the builder and return the updated one"]
    pub fn registry(mut self, registry: C) -> Self {
        self.registry = Some(registry);
        self
    }

    /// Sets the Avro schema text (required); it is parsed by [`build`](Self::build).
    #[must_use = "builder methods consume the builder and return the updated one"]
    pub fn schema(mut self, schema: impl Into<String>) -> Self {
        self.schema = Some(schema.into());
        self
    }

    /// Sets the subject naming strategy (defaults to `SubjectNameStrategy::TopicName`).
    #[must_use = "builder methods consume the builder and return the updated one"]
    pub fn strategy(mut self, strategy: SubjectNameStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Sets the schema references sent to the registry with the schema,
    /// replacing any previously set.
    #[must_use = "builder methods consume the builder and return the updated one"]
    pub fn references(mut self, references: Vec<SchemaReference>) -> Self {
        self.references = references;
        self
    }

    /// Parses the configured schema and constructs the encoder with empty caches.
    ///
    /// # Errors
    ///
    /// Returns a config error in any of these cases:
    /// - no registry was set;
    /// - no schema was set;
    /// - the schema text is not valid Avro.
    pub fn build(self) -> Result<AvroSchemaEncoder<C>> {
        let registry = self
            .registry
            .ok_or_else(|| SchemaRegError::config("AvroSchemaEncoder: registry must be set"))?;
        let schema_str = self
            .schema
            .ok_or_else(|| SchemaRegError::config("AvroSchemaEncoder: schema must be set"))?;
        let avro_schema = parse_avro_schema(&schema_str)?;
        let fullname = schema_fullname(&avro_schema);
        Ok(AvroSchemaEncoder {
            registry,
            schema_str,
            avro_schema: Arc::new(avro_schema),
            schema_fullname: fullname,
            strategy: self.strategy,
            references: self.references,
            cache: RwLock::new(HashMap::new()),
            in_flight: Mutex::new(HashMap::new()),
        })
    }
}
/// Decodes wire-format Avro payloads using the schema identified by the id
/// embedded in each payload, fetched from a schema registry.
pub struct AvroSchemaDecoder<C> {
    /// Registry client used to fetch schemas by id.
    registry: C,
    /// Parsed schemas keyed by schema id.
    schema_cache: RwLock<HashMap<SchemaId, Arc<AvroSchema>>>,
}
impl<C: SchemaRegistryClient> AvroSchemaDecoder<C> {
    /// Creates a decoder backed by `registry`, with an empty schema cache.
    pub fn new(registry: C) -> Self {
        Self {
            registry,
            schema_cache: RwLock::new(HashMap::new()),
        }
    }

    /// Returns the parsed Avro schema registered under `id`, fetching it from
    /// the registry and caching it on first use.
    ///
    /// # Errors
    ///
    /// - Returns the registry's error if the lookup fails.
    /// - Returns a config error if the returned schema text is not valid Avro.
    async fn get_avro_schema(&self, id: SchemaId) -> Result<Arc<AvroSchema>> {
        if let Some(schema) = self.schema_cache.read().get(&id) {
            return Ok(Arc::clone(schema));
        }
        let registry_schema = self.registry.get_schema_by_id(id).await?;
        let parsed = Arc::new(parse_avro_schema(&registry_schema.schema)?);
        // Concurrent misses may each fetch the schema. Keep whichever copy
        // was cached first, so every caller shares the same `Arc`.
        let cached = {
            let mut cache = self.schema_cache.write();
            Arc::clone(cache.entry(id).or_insert(parsed))
        };
        Ok(cached)
    }

    /// Splits `data` into schema id and payload, then decodes the payload
    /// with the schema registered under that id.
    ///
    /// # Errors
    ///
    /// - Returns any error from parsing the wire-format header.
    /// - Returns any error from fetching the schema.
    /// - Returns a wire-format error if `apache_avro::from_avro_datum` fails.
    pub async fn decode(&self, data: Bytes) -> Result<Value> {
        let (schema_id, payload) = decode_wire_format_bytes(&data)?;
        let avro_schema = self.get_avro_schema(schema_id).await?;
        let value = apache_avro::from_avro_datum(&avro_schema, &mut payload.as_ref(), None)
            .map_err(|e| {
                SchemaRegError::wire_format(format!("Avro deserialization failed: {e}"))
            })?;
        Ok(value)
    }

    /// Decodes `data` like [`decode`](Self::decode), then converts the Avro
    /// [`Value`] into `T` with `apache_avro::from_value`.
    ///
    /// # Errors
    ///
    /// - Returns any error from [`decode`](Self::decode).
    /// - Returns a wire-format error if the value cannot be converted into `T`.
    pub async fn decode_de<T: for<'de> Deserialize<'de>>(&self, data: Bytes) -> Result<T> {
        let value = self.decode(data).await?;
        apache_avro::from_value::<T>(&value).map_err(|e| {
            SchemaRegError::wire_format(format!(
                "failed to deserialize Avro value into target type: {e}"
            ))
        })
    }
}