use std::collections::HashMap;
use std::sync::Arc;
use apache_avro::Schema as AvroSchema;
use apache_avro::types::Value;
use bytes::Bytes;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use crate::error::{Result, SchemaRegError};
use crate::subject::SubjectNameStrategy;
use crate::traits::SchemaRegistryClient;
use crate::types::{SchemaId, SchemaReference, SchemaType};
use crate::wire::{decode_wire_format_bytes, encode_wire_format};
/// Parses an Avro schema definition from its textual (JSON) form.
///
/// # Errors
///
/// Returns a configuration error carrying the parser's message when
/// `schema_str` is not a valid Avro schema.
fn parse_avro_schema(schema_str: &str) -> Result<AvroSchema> {
    match AvroSchema::parse_str(schema_str) {
        Ok(parsed) => Ok(parsed),
        Err(e) => Err(SchemaRegError::config(format!("invalid Avro schema: {e}"))),
    }
}
/// Returns the fully-qualified name of a named Avro schema.
///
/// Only record, enum and fixed schemas carry a name. Every other schema kind
/// yields `None`.
fn schema_fullname(schema: &AvroSchema) -> Option<String> {
    let name = match schema {
        AvroSchema::Record(record) => &record.name,
        AvroSchema::Enum(enumeration) => &enumeration.name,
        AvroSchema::Fixed(fixed) => &fixed.name,
        _ => return None,
    };
    // The schema's own namespace doubles as the default namespace.
    Some(name.fullname(name.namespace.clone()))
}
/// Cached outcome of registering the encoder's schema under one subject.
struct EncoderEntry {
    /// Parsed schema used to serialize values written under this subject.
    avro_schema: Arc<AvroSchema>,
    /// Identifier returned by the registry when the schema was registered.
    schema_id: SchemaId,
}
/// Serializes Avro values into registry wire-format payloads.
///
/// Built through [`AvroSchemaEncoder::builder`]. The configured schema is
/// registered lazily, once per subject, and the result is cached.
pub struct AvroSchemaEncoder<C> {
    /// Client used to register the schema.
    registry: C,
    /// Original schema text, sent verbatim to the registry.
    schema_str: String,
    /// Parsed form of `schema_str`.
    avro_schema: Arc<AvroSchema>,
    /// Fully-qualified name of the schema, if it is a named type.
    schema_fullname: Option<String>,
    /// Schema references forwarded to the registry on registration.
    references: Vec<SchemaReference>,
    /// Derives the subject from topic, record name and key/value flag.
    strategy: SubjectNameStrategy,
    /// Registration results keyed by subject name.
    cache: RwLock<HashMap<String, Arc<EncoderEntry>>>,
}
impl<C: SchemaRegistryClient> AvroSchemaEncoder<C> {
    /// Starts configuring a new encoder.
    pub fn builder() -> AvroSchemaEncoderBuilder<C> {
        AvroSchemaEncoderBuilder::<C>::new()
    }

    /// Returns the registration for `subject`, registering the schema with
    /// the registry on a cache miss.
    ///
    /// No lock is held while awaiting the registry. Two concurrent misses for
    /// the same subject therefore each call `register_schema`, and the later
    /// insert replaces the earlier cache entry.
    async fn resolve_subject(&self, subject: &str) -> Result<Arc<EncoderEntry>> {
        // The read guard is a temporary of this statement, so it is released here.
        let cached = self.cache.read().get(subject).cloned();
        if let Some(hit) = cached {
            return Ok(hit);
        }

        let schema_id = self
            .registry
            .register_schema(subject, &self.schema_str, SchemaType::Avro, &self.references)
            .await?;

        let fresh = Arc::new(EncoderEntry {
            avro_schema: Arc::clone(&self.avro_schema),
            schema_id,
        });
        {
            let mut guard = self.cache.write();
            guard.insert(subject.to_owned(), Arc::clone(&fresh));
        }
        Ok(fresh)
    }

    /// Serializes `value` and wraps it in the registry wire format.
    ///
    /// The subject is derived from `topic`, the schema's full name and
    /// `is_key` using the configured strategy.
    ///
    /// # Errors
    ///
    /// Fails if the subject cannot be derived, if registration fails, or if
    /// `apache_avro::to_avro_datum` rejects `value` for the schema.
    pub async fn encode(&self, value: Value, topic: &str, is_key: bool) -> Result<Bytes> {
        let record_name = self.schema_fullname.as_deref();
        let subject = self.strategy.subject_name(topic, record_name, is_key)?;
        let entry = self.resolve_subject(&subject).await?;

        match apache_avro::to_avro_datum(&entry.avro_schema, value) {
            Ok(datum) => Ok(encode_wire_format(entry.schema_id, &datum)),
            Err(e) => Err(SchemaRegError::registry(format!("Avro serialization failed: {e}"))),
        }
    }

    /// Converts a `Serialize` value to an Avro [`Value`] and then encodes it
    /// like [`encode`](Self::encode).
    ///
    /// # Errors
    ///
    /// Fails if `apache_avro::to_value` cannot convert `value`, or for any
    /// reason `encode` can fail.
    pub async fn encode_ser<T: Serialize>(
        &self,
        value: &T,
        topic: &str,
        is_key: bool,
    ) -> Result<Bytes> {
        let converted = match apache_avro::to_value(value) {
            Ok(v) => v,
            Err(e) => {
                return Err(SchemaRegError::registry(format!(
                    "failed to convert value to Avro: {e}"
                )))
            }
        };
        self.encode(converted, topic, is_key).await
    }
}
/// Step-by-step configuration for an [`AvroSchemaEncoder`].
///
/// `registry` and `schema` are mandatory; `build` fails without them.
pub struct AvroSchemaEncoderBuilder<C> {
    /// Registry client; required.
    registry: Option<C>,
    /// Avro schema text; required.
    schema: Option<String>,
    /// References forwarded on registration; empty unless set.
    references: Vec<SchemaReference>,
    /// Subject naming strategy; `TopicName` unless set.
    strategy: SubjectNameStrategy,
}
impl<C: SchemaRegistryClient> AvroSchemaEncoderBuilder<C> {
    /// Empty builder: no registry, no schema, no references, topic-name strategy.
    fn new() -> Self {
        Self {
            registry: None,
            schema: None,
            references: Vec::new(),
            strategy: SubjectNameStrategy::TopicName,
        }
    }

    /// Sets the registry client used for schema registration.
    pub fn registry(self, registry: C) -> Self {
        Self {
            registry: Some(registry),
            ..self
        }
    }

    /// Sets the Avro schema text to encode with.
    pub fn schema(self, schema: impl Into<String>) -> Self {
        Self {
            schema: Some(schema.into()),
            ..self
        }
    }

    /// Overrides the subject naming strategy.
    pub fn strategy(self, strategy: SubjectNameStrategy) -> Self {
        Self { strategy, ..self }
    }

    /// Replaces the schema references sent along with registration.
    pub fn references(self, references: Vec<SchemaReference>) -> Self {
        Self { references, ..self }
    }

    /// Validates the configuration and produces the encoder.
    ///
    /// # Errors
    ///
    /// Returns a configuration error if the registry or the schema is missing
    /// (checked in that order), or if the schema text does not parse.
    pub fn build(self) -> Result<AvroSchemaEncoder<C>> {
        let Self {
            registry,
            schema,
            references,
            strategy,
        } = self;

        let registry = registry
            .ok_or_else(|| SchemaRegError::config("AvroSchemaEncoder: registry must be set"))?;
        let schema_str = schema
            .ok_or_else(|| SchemaRegError::config("AvroSchemaEncoder: schema must be set"))?;

        let parsed = parse_avro_schema(&schema_str)?;
        let fullname = schema_fullname(&parsed);

        Ok(AvroSchemaEncoder {
            registry,
            schema_str,
            avro_schema: Arc::new(parsed),
            schema_fullname: fullname,
            references,
            strategy,
            cache: RwLock::new(HashMap::new()),
        })
    }
}
/// Decodes registry wire-format payloads back into Avro values.
///
/// Writer schemas are fetched from the registry by id and cached after parsing.
pub struct AvroSchemaDecoder<C> {
    /// Client used to look up schemas by id.
    registry: C,
    /// Parsed schemas keyed by registry schema id.
    schema_cache: RwLock<HashMap<SchemaId, Arc<AvroSchema>>>,
}
impl<C: SchemaRegistryClient> AvroSchemaDecoder<C> {
    /// Creates a decoder with an empty schema cache.
    pub fn new(registry: C) -> Self {
        Self {
            registry,
            schema_cache: RwLock::new(HashMap::new()),
        }
    }

    /// Returns the parsed Avro schema for `id`, using the local cache before
    /// asking the registry.
    ///
    /// # Errors
    ///
    /// Propagates registry lookup failures. Returns a configuration error if
    /// the fetched schema text is not valid Avro.
    async fn get_avro_schema(&self, id: SchemaId) -> Result<Arc<AvroSchema>> {
        if let Some(schema) = self.schema_cache.read().get(&id) {
            return Ok(Arc::clone(schema));
        }

        let registry_schema = self.registry.get_schema_by_id(id).await?;
        let avro_schema = Arc::new(parse_avro_schema(&registry_schema.schema)?);

        // No lock is held across the await above, so another task may have
        // cached this id in the meantime. Keep the first entry so all callers
        // share one parsed schema.
        let mut cache = self.schema_cache.write();
        let cached = cache.entry(id).or_insert(avro_schema);
        Ok(Arc::clone(cached))
    }

    /// Decodes a wire-format payload into an Avro [`Value`].
    ///
    /// The schema id is taken from the payload header and resolved via the
    /// registry. No reader schema is supplied, so the datum is read with the
    /// writer schema as-is.
    ///
    /// # Errors
    ///
    /// Fails if the wire-format header is invalid, if the schema cannot be
    /// obtained or parsed, or if the payload does not match the schema.
    pub async fn decode(&self, data: Bytes) -> Result<Value> {
        let (schema_id, payload) = decode_wire_format_bytes(&data)?;
        let avro_schema = self.get_avro_schema(schema_id).await?;
        let value = apache_avro::from_avro_datum(&avro_schema, &mut payload.as_ref(), None)
            .map_err(|e| SchemaRegError::registry(format!("Avro deserialization failed: {e}")))?;
        Ok(value)
    }

    /// Decodes a payload like [`decode`](Self::decode), then converts the
    /// resulting [`Value`] into `T` via serde.
    ///
    /// # Errors
    ///
    /// Fails for any reason `decode` can fail, or if the value cannot be
    /// deserialized into `T`.
    pub async fn decode_de<T: for<'de> Deserialize<'de>>(&self, data: Bytes) -> Result<T> {
        let value = self.decode(data).await?;
        apache_avro::from_value::<T>(&value).map_err(|e| {
            SchemaRegError::registry(format!(
                "failed to deserialize Avro value into target type: {e}"
            ))
        })
    }
}