use std::fmt;
use std::sync::Arc;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::error::{Result, SchemaRegError};
use crate::http::{HttpClient, HttpClientConfig};
use crate::traits::SchemaRegistryClient;
use crate::types::{ArtifactId, Schema, SchemaId, SchemaReference, SchemaType, SchemaVersion};
/// Request timeout used by `ApicurioSchemaRegistry::new` and as the builder's default.
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// Group name substituted when a response header or search result carries no group id.
const DEFAULT_GROUP: &str = "default";
/// Path prefix inserted between the base URL and every endpoint path.
const API_PREFIX: &str = "/apis/registry/v3";
// Response header names read when a content response is turned into a `Schema`.
/// Header parsed into the schema type.
const HDR_ARTIFACT_TYPE: &str = "x-registry-artifacttype";
/// Header parsed into the schema id.
const HDR_GLOBAL_ID: &str = "x-registry-globalid";
/// Header parsed into the schema version.
const HDR_VERSION: &str = "x-registry-version";
/// Header supplying the group half of the subject.
const HDR_GROUP_ID: &str = "x-registry-groupid";
/// Header supplying the artifact half of the subject.
const HDR_ARTIFACT_ID: &str = "x-registry-artifactid";
/// JSON error payload decoded from the body of a non-success response.
///
/// Every field is optional; `into_message` selects the first one present.
#[derive(Deserialize)]
struct ApicurioError {
/// Preferred text: checked first when building the message.
detail: Option<String>,
/// Checked last when building the message.
title: Option<String>,
/// Checked second when building the message.
message: Option<String>,
}
impl ApicurioError {
    /// Consumes the payload and returns the most specific text available.
    ///
    /// Preference order is `detail`, then `message`, then `title`; when all
    /// three are absent the literal `"unknown error"` is returned.
    fn into_message(self) -> String {
        let Self {
            detail,
            title,
            message,
        } = self;
        match detail.or(message).or(title) {
            Some(text) => text,
            None => "unknown error".to_string(),
        }
    }
}
/// JSON body of the create-artifact request built by `register_schema_impl`.
///
/// Keys are serialized in camelCase.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct CreateArtifactRequest<'a> {
/// Artifact name; the group is carried by the request URL instead.
artifact_id: &'a str,
/// Artifact type; omitted from the JSON entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
artifact_type: Option<&'a str>,
/// The version submitted together with the artifact.
first_version: CreateVersionWrapper<'a>,
}
/// Wrapper that nests the version payload under a `content` key.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct CreateVersionWrapper<'a> {
/// Schema text, its content type and any references.
content: VersionContentRequest<'a>,
}
/// Schema content as sent to the registry; keys are serialized in camelCase.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct VersionContentRequest<'a> {
/// Raw schema text.
content: &'a str,
/// MIME type describing `content` (see `schema_content_type`).
content_type: &'a str,
/// References to other artifacts; the key is omitted when the list is empty.
#[serde(skip_serializing_if = "<[_]>::is_empty")]
references: Vec<ArtifactReferenceJson>,
}
/// Wire representation of one schema reference (camelCase keys).
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ArtifactReferenceJson {
/// Group of the referenced artifact; has no skip rule, so `None` is serialized as `null`.
group_id: Option<String>,
/// Name of the referenced artifact.
artifact_id: String,
/// Referenced version, carried as a string.
version: String,
/// Name under which the referencing schema refers to this artifact.
name: String,
}
/// Subset of the create-artifact response that this client reads.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct CreateArtifactResponse {
/// Metadata of the version reported by the registry.
version: CreateArtifactVersionInfo,
}
/// Version metadata nested inside `CreateArtifactResponse`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct CreateArtifactVersionInfo {
/// Decoded from the `globalId` key as a 64-bit integer; converted to a `SchemaId` by the caller.
global_id: i64,
}
/// Response body of the version-listing endpoint used by `list_versions`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct ArtifactVersionList {
/// One summary per version in the response.
versions: Vec<ArtifactVersionSummary>,
}
/// One entry of `ArtifactVersionList`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct ArtifactVersionSummary {
/// Kept as raw JSON because `parse_version` accepts both string and numeric forms.
version: Option<serde_json::Value>,
}
/// Response body of the artifact search endpoint used by `list_subjects`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct ArtifactSearchResults {
/// Artifacts contained in the response.
artifacts: Vec<SearchedArtifact>,
}
/// One artifact returned by the search endpoint.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct SearchedArtifact {
/// Artifact name.
artifact_id: String,
/// Owning group; `list_subjects` substitutes `DEFAULT_GROUP` when absent.
group_id: Option<String>,
}
/// Response body decoded by `check_compatibility_impl`.
#[derive(Deserialize)]
struct CompatibilityTestResult {
/// Verdict returned to the caller unchanged.
compatible: bool,
}
/// Credentials attached to each request.
///
/// Secrets are held in `zeroize::Zeroizing` wrappers.
enum ApicurioAuth {
/// No `Authorization` value is produced.
None,
/// Username and password, rendered by `auth_header` as `Basic <base64(user:pass)>`.
Basic {
username: zeroize::Zeroizing<String>,
password: zeroize::Zeroizing<String>,
},
/// Token rendered by `auth_header` as `Bearer <token>`.
Bearer {
token: zeroize::Zeroizing<String>,
},
}
/// Client for an Apicurio Registry instance reached through the `/apis/registry/v3` API prefix.
///
/// Create one with [`ApicurioSchemaRegistry::new`] (no authentication) or
/// through [`ApicurioSchemaRegistry::builder`].
pub struct ApicurioSchemaRegistry {
/// HTTP transport used for every request.
client: HttpClient,
/// Registry base URL with trailing slashes removed.
base_url: String,
/// Credentials used to build the per-request authorization value.
auth: ApicurioAuth,
}
impl ApicurioSchemaRegistry {
/// Creates an unauthenticated client for the registry at `url`.
///
/// Trailing slashes are stripped from the URL and the default request
/// timeout is applied. A warning is logged when the URL starts with
/// `http://`.
///
/// # Errors
///
/// Fails when the URL embeds credentials in its authority, or when the
/// underlying HTTP client cannot be constructed.
pub fn new(url: impl Into<String>) -> Result<Self> {
    let base_url = normalize_url(url.into());
    reject_embedded_credentials(&base_url)?;

    let is_plain_http = base_url.starts_with("http://");
    if is_plain_http {
        tracing::warn!(
            url = %base_url,
            "Apicurio Registry URL uses plain HTTP — credentials and schema data will be \
             transmitted in cleartext; use HTTPS in production"
        );
    }

    Ok(Self {
        client: HttpClient::with_webpki_roots(Some(DEFAULT_REQUEST_TIMEOUT))?,
        base_url,
        auth: ApicurioAuth::None,
    })
}
pub fn builder() -> ApicurioSchemaRegistryBuilder {
ApicurioSchemaRegistryBuilder::default()
}
/// Builds a full request URL: base URL, then `API_PREFIX`, then `path`.
///
/// `path` is appended verbatim, so it must already start with `/` and be
/// percent-encoded where needed.
fn api_url(&self, path: &str) -> String {
    [self.base_url.as_str(), API_PREFIX, path].concat()
}
/// Renders the configured credentials as an `Authorization` header value.
///
/// Returns `None` when no authentication is configured. Intermediate
/// strings that contain the secret are wrapped in `Zeroizing`, as is the
/// returned value.
fn auth_header(&self) -> Option<zeroize::Zeroizing<String>> {
    use base64::Engine as _;

    let header_value = match &self.auth {
        ApicurioAuth::None => return None,
        ApicurioAuth::Basic { username, password } => {
            let user_pass =
                zeroize::Zeroizing::new(format!("{}:{}", username.as_str(), password.as_str()));
            let b64 = zeroize::Zeroizing::new(
                base64::engine::general_purpose::STANDARD.encode(user_pass.as_bytes()),
            );
            zeroize::Zeroizing::new(format!("Basic {}", b64.as_str()))
        }
        ApicurioAuth::Bearer { token } => {
            zeroize::Zeroizing::new(format!("Bearer {}", token.as_str()))
        }
    };
    Some(header_value)
}
/// Returns the authorization header value as a plain `String`, if any.
///
/// NOTE(review): the returned `String` is a copy that is not wrapped in
/// `Zeroizing`, so this copy of the secret is not covered by the wrapper
/// used in `auth_header` — confirm whether that is acceptable.
fn auth_str(&self) -> Option<String> {
    let header = self.auth_header()?;
    Some(header.as_str().to_owned())
}
/// Issues a `GET` with `Accept: application/json` and decodes the reply as `T`.
///
/// # Errors
///
/// Propagates transport errors, and whatever `handle_json_response`
/// reports for non-success statuses or undecodable bodies.
async fn get_json<T>(&self, url: &str) -> Result<T>
where
    T: serde::de::DeserializeOwned,
{
    let credentials = self.auth_str();
    let request = self.client.request(
        "GET",
        url,
        &[("Accept", "application/json")],
        None,
        credentials.as_deref(),
    );
    let response = request.await?;
    self.handle_json_response(
        response.status,
        response.content_type.as_deref(),
        &response.body,
    )
}
/// Issues a `GET` accepting any content type and returns the raw response.
///
/// # Errors
///
/// * 404 becomes an API error with code `40401`.
/// * 401 and 403 become authentication errors.
/// * Any other status outside 200–299 becomes an HTTP error.
///
/// In each case the message is extracted from the response body.
async fn get_content(&self, url: &str) -> Result<crate::http::HttpResponse> {
    let credentials = self.auth_str();
    let response = self
        .client
        .request("GET", url, &[("Accept", "*/*")], None, credentials.as_deref())
        .await?;

    let status = response.status;
    if (200..300).contains(&status) {
        return Ok(response);
    }

    let reason = self.parse_error_body(&response.body);
    Err(match status {
        404 => SchemaRegError::api(40401, reason),
        401 | 403 => SchemaRegError::auth(status, reason),
        _ => SchemaRegError::http(status, reason),
    })
}
/// Issues a `POST` with a JSON `body` and decodes the JSON reply as `T`.
///
/// Both `Accept` and `Content-Type` are set to `application/json`.
///
/// # Errors
///
/// Propagates transport errors, and whatever `handle_json_response`
/// reports for non-success statuses or undecodable bodies.
async fn post_json<T>(&self, url: &str, body: &[u8]) -> Result<T>
where
    T: serde::de::DeserializeOwned,
{
    let credentials = self.auth_str();
    let request = self.client.request(
        "POST",
        url,
        &[
            ("Accept", "application/json"),
            ("Content-Type", "application/json"),
        ],
        Some(body),
        credentials.as_deref(),
    );
    let response = request.await?;
    self.handle_json_response(
        response.status,
        response.content_type.as_deref(),
        &response.body,
    )
}
/// Issues a `DELETE` and succeeds on any 2xx status, ignoring the body.
///
/// # Errors
///
/// * 404 becomes an API error with code `40401`.
/// * 401 and 403 become authentication errors.
/// * Any other status outside 200–299 becomes an HTTP error.
async fn delete_no_content(&self, url: &str) -> Result<()> {
    let credentials = self.auth_str();
    let response = self
        .client
        .request("DELETE", url, &[], None, credentials.as_deref())
        .await?;

    let status = response.status;
    // 204 is the expected reply, but every 2xx status counts as success.
    if (200..300).contains(&status) {
        return Ok(());
    }

    let reason = self.parse_error_body(&response.body);
    Err(match status {
        404 => SchemaRegError::api(40401, reason),
        401 | 403 => SchemaRegError::auth(status, reason),
        _ => SchemaRegError::http(status, reason),
    })
}
/// Turns a status code and body into either a decoded `T` or an error.
///
/// A 2xx status means the body is decoded as JSON. Otherwise the body is
/// read as an error payload and the status selects the error kind:
/// 401/403 → auth, 404 → API `40401`, 409 → API `40902`, anything else →
/// HTTP. The content type is currently not consulted.
fn handle_json_response<T: serde::de::DeserializeOwned>(
    &self,
    status: u16,
    _content_type: Option<&str>,
    body: &[u8],
) -> Result<T> {
    let is_success = (200..300).contains(&status);
    if !is_success {
        let reason = self.parse_error_body(body);
        return Err(match status {
            401 | 403 => SchemaRegError::auth(status, reason),
            404 => SchemaRegError::api(40401, reason),
            409 => SchemaRegError::api(40902, reason),
            _ => SchemaRegError::http(status, reason),
        });
    }

    match serde_json::from_slice(body) {
        Ok(decoded) => Ok(decoded),
        Err(e) => Err(SchemaRegError::invalid_state(format!(
            "failed to parse Apicurio response: {e}"
        ))),
    }
}
/// Extracts a human-readable message from an error response body.
///
/// A body that decodes as an `ApicurioError` yields that payload's
/// message. Otherwise the body is read as lossy UTF-8 and its first 256
/// characters are returned, or `"<empty>"` when there is nothing to show.
fn parse_error_body(&self, body: &[u8]) -> String {
    if let Ok(payload) = serde_json::from_slice::<ApicurioError>(body) {
        return payload.into_message();
    }

    let text = String::from_utf8_lossy(body);
    if text.is_empty() {
        return "<empty>".to_string();
    }
    text.chars().take(256).collect()
}
/// Builds a [`Schema`] from a content response and its registry headers.
///
/// * Schema type: parsed from the artifact-type header, `Avro` when the
///   header is missing or unparseable.
/// * Id: parsed from the global-id header; when the header is missing or
///   not an integer, `fallback_id` is used, and failing that id `0`.
/// * Version: parsed from the version header when it is a 32-bit integer.
/// * Subject: `group/artifact` from the headers (group defaults to
///   `DEFAULT_GROUP`), or `fallback_subject` when the artifact-id header
///   is missing.
/// * References are always returned empty.
///
/// # Errors
///
/// * The global-id header holds an integer that does not fit in a `u32`.
///   It is rejected rather than truncated with `as`, which would silently
///   produce a different, valid-looking schema id.
/// * The body is not valid UTF-8.
fn schema_from_content_response(
    &self,
    resp: crate::http::HttpResponse,
    fallback_id: Option<SchemaId>,
    fallback_subject: Option<&str>,
) -> Result<Arc<Schema>> {
    let schema_type = resp
        .headers
        .get(HDR_ARTIFACT_TYPE)
        .and_then(|s| s.parse::<SchemaType>().ok())
        .unwrap_or(SchemaType::Avro);

    let header_global_id = resp
        .headers
        .get(HDR_GLOBAL_ID)
        .and_then(|s| s.parse::<i64>().ok());
    let id = match header_global_id {
        Some(global_id) => {
            // Checked narrowing: negative or > u32::MAX ids must not wrap around.
            let narrowed = u32::try_from(global_id).map_err(|_| {
                SchemaRegError::invalid_state(format!(
                    "Apicurio global id {global_id} does not fit in a 32-bit schema id"
                ))
            })?;
            SchemaId::from(narrowed)
        }
        None => fallback_id.unwrap_or_else(|| SchemaId::from(0u32)),
    };

    let version = resp
        .headers
        .get(HDR_VERSION)
        .and_then(|s| s.parse::<i32>().ok())
        .map(SchemaVersion::new);

    let group_id = resp
        .headers
        .get(HDR_GROUP_ID)
        .cloned()
        .unwrap_or_else(|| DEFAULT_GROUP.to_string());
    let artifact_id_header = resp.headers.get(HDR_ARTIFACT_ID).cloned();
    let subject = artifact_id_header
        .map(|a| format!("{group_id}/{a}"))
        .or_else(|| fallback_subject.map(str::to_string));

    let schema_str = String::from_utf8(resp.body.to_vec()).map_err(|e| {
        SchemaRegError::wire_format(format!("invalid UTF-8 in Apicurio schema content: {e}"))
    })?;

    Ok(Arc::new(Schema {
        id,
        schema_type,
        schema: schema_str.into(),
        version,
        subject,
        references: Vec::new(),
    }))
}
/// Interprets a JSON value as a schema version.
///
/// Accepts a string holding a decimal 32-bit integer, or a JSON number.
/// Returns `None` for every other JSON type, for unparseable strings, and
/// for numbers that are not integers or do not fit in an `i32`. Numbers
/// outside that range used to be truncated with `as`, which produced an
/// unrelated version number.
fn parse_version(v: &serde_json::Value) -> Option<SchemaVersion> {
    let number = match v {
        serde_json::Value::String(s) => s.parse::<i32>().ok()?,
        // Checked narrowing instead of `as i32`, which would wrap around.
        serde_json::Value::Number(n) => i32::try_from(n.as_i64()?).ok()?,
        _ => return None,
    };
    Some(SchemaVersion::new(number))
}
/// Converts schema references into their wire representation.
///
/// Each reference's subject is split into group and artifact with
/// `ArtifactId::from_subject`; the version is rendered as a decimal string.
fn to_reference_json(refs: &[SchemaReference]) -> Vec<ArtifactReferenceJson> {
    let mut wire_refs = Vec::with_capacity(refs.len());
    for reference in refs {
        let target = ArtifactId::from_subject(&reference.subject);
        wire_refs.push(ArtifactReferenceJson {
            group_id: Some(target.group),
            artifact_id: target.artifact,
            version: reference.version.as_i32().to_string(),
            name: reference.name.clone(),
        });
    }
    wire_refs
}
/// Fetches the content addressed by `branch=latest` for `subject`.
///
/// `subject` is split into group and artifact with
/// `ArtifactId::from_subject`; both parts are percent-encoded before being
/// placed in the URL. `subject` is also used as the fallback subject when
/// the response lacks an artifact-id header.
///
/// # Errors
///
/// Propagates errors from the content request and from building the schema.
pub async fn get_latest_schema_impl(&self, subject: &str) -> Result<Arc<Schema>> {
    let target = ArtifactId::from_subject(subject);
    let group = percent_encode(&target.group);
    let artifact = percent_encode(&target.artifact);
    let path = format!(
        "/groups/{}/artifacts/{}/versions/branch=latest/content?returnArtifactType=true",
        group, artifact,
    );

    let response = self.get_content(&self.api_url(&path)).await?;
    self.schema_from_content_response(response, None, Some(subject))
}
/// Fetches the content of one specific `version` of `subject`.
///
/// Group, artifact and the decimal version number are each percent-encoded
/// before being placed in the URL. `subject` is also used as the fallback
/// subject when the response lacks an artifact-id header.
///
/// # Errors
///
/// Propagates errors from the content request and from building the schema.
pub async fn get_schema_by_version_impl(
    &self,
    subject: &str,
    version: SchemaVersion,
) -> Result<Arc<Schema>> {
    let target = ArtifactId::from_subject(subject);
    let group = percent_encode(&target.group);
    let artifact = percent_encode(&target.artifact);
    let version_segment = percent_encode(&version.as_i32().to_string());
    let path = format!(
        "/groups/{}/artifacts/{}/versions/{}/content?returnArtifactType=true",
        group, artifact, version_segment,
    );

    let response = self.get_content(&self.api_url(&path)).await?;
    self.schema_from_content_response(response, None, Some(subject))
}
/// Registers `schema` under `subject` and returns the resulting schema id.
///
/// The request is sent to the group's artifact collection with
/// `ifExists=FIND_OR_CREATE_VERSION`. The artifact type is taken from
/// `schema_type` and the content type from `schema_content_type`.
///
/// # Errors
///
/// * The request body cannot be serialised.
/// * The request fails (see `post_json`).
/// * The registry returns a global id that does not fit in a `u32`. It is
///   rejected rather than truncated with `as`, which would hand the caller
///   an id belonging to a different schema.
pub async fn register_schema_impl(
    &self,
    subject: &str,
    schema: &str,
    schema_type: SchemaType,
    references: &[SchemaReference],
) -> Result<SchemaId> {
    let artifact_id = ArtifactId::from_subject(subject);
    let url = self.api_url(&format!(
        "/groups/{}/artifacts?ifExists=FIND_OR_CREATE_VERSION",
        percent_encode(&artifact_id.group),
    ));
    let refs = Self::to_reference_json(references);
    let content_type = schema_content_type(schema_type);
    let req = CreateArtifactRequest {
        artifact_id: &artifact_id.artifact,
        artifact_type: Some(schema_type.as_str()),
        first_version: CreateVersionWrapper {
            content: VersionContentRequest {
                content: schema,
                content_type,
                references: refs,
            },
        },
    };
    let body = serde_json::to_vec(&req).map_err(|e| {
        SchemaRegError::invalid_state(format!("failed to serialise Apicurio request: {e}"))
    })?;
    let result: CreateArtifactResponse = self.post_json(&url, &body).await?;

    // Checked narrowing: the registry reports a 64-bit id, `SchemaId` holds 32 bits.
    let global_id = result.version.global_id;
    let narrowed = u32::try_from(global_id).map_err(|_| {
        SchemaRegError::invalid_state(format!(
            "Apicurio global id {global_id} does not fit in a 32-bit schema id"
        ))
    })?;
    Ok(SchemaId::from(narrowed))
}
/// Asks the registry whether `schema` is compatible with `subject`.
///
/// The schema text, its content type and the converted references are
/// posted to the `branch=latest/compatibility` endpoint of the artifact;
/// the `compatible` flag of the reply is returned unchanged. Unlike the
/// registration request, the `references` key is always present here,
/// even when the list is empty.
///
/// # Errors
///
/// Fails when the request body cannot be serialised or the request fails
/// (see `post_json`).
pub async fn check_compatibility_impl(
    &self,
    subject: &str,
    schema: &str,
    schema_type: SchemaType,
    references: &[SchemaReference],
) -> Result<bool> {
    let target = ArtifactId::from_subject(subject);
    let group = percent_encode(&target.group);
    let artifact = percent_encode(&target.artifact);
    let path = format!(
        "/groups/{}/artifacts/{}/versions/branch=latest/compatibility",
        group, artifact,
    );
    let url = self.api_url(&path);

    let wire_refs = Self::to_reference_json(references);
    let mime = schema_content_type(schema_type);
    let payload = serde_json::json!({
        "content": {
            "content": schema,
            "contentType": mime,
            "references": wire_refs,
        }
    });
    let body = match serde_json::to_vec(&payload) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(SchemaRegError::invalid_state(format!(
                "failed to serialise compatibility request: {e}"
            )));
        }
    };

    let verdict: CompatibilityTestResult = self.post_json(&url, &body).await?;
    Ok(verdict.compatible)
}
/// Deletes the artifact addressed by `subject`.
///
/// # Errors
///
/// Propagates the errors of `delete_no_content`, e.g. an API error when
/// the registry answers 404.
pub async fn delete_artifact(&self, subject: &str) -> Result<()> {
    let target = ArtifactId::from_subject(subject);
    let group = percent_encode(&target.group);
    let artifact = percent_encode(&target.artifact);
    let url = self.api_url(&format!("/groups/{}/artifacts/{}", group, artifact));
    self.delete_no_content(&url).await
}
/// Lists artifacts as `group/artifact` subjects.
///
/// `limit` is passed to the search endpoint as its `limit` query
/// parameter; only that single request is made. Artifacts without a group
/// id are reported under `DEFAULT_GROUP`.
///
/// # Errors
///
/// Propagates the errors of `get_json`.
pub async fn list_subjects(&self, limit: usize) -> Result<Vec<String>> {
    let url = self.api_url(&format!("/search/artifacts?limit={limit}"));
    let found: ArtifactSearchResults = self.get_json(&url).await?;

    let mut subjects = Vec::with_capacity(found.artifacts.len());
    for hit in found.artifacts {
        let group = match hit.group_id {
            Some(name) => name,
            None => DEFAULT_GROUP.to_string(),
        };
        subjects.push(format!("{group}/{}", hit.artifact_id));
    }
    Ok(subjects)
}
/// Lists the versions of `subject` that can be read as integers.
///
/// A single request with `limit=500` is made. Entries without a version,
/// or whose version `parse_version` rejects, are skipped silently.
///
/// # Errors
///
/// Propagates the errors of `get_json`.
pub async fn list_versions(&self, subject: &str) -> Result<Vec<SchemaVersion>> {
    let target = ArtifactId::from_subject(subject);
    let group = percent_encode(&target.group);
    let artifact = percent_encode(&target.artifact);
    let url = self.api_url(&format!(
        "/groups/{}/artifacts/{}/versions?limit=500",
        group, artifact,
    ));

    let listing: ArtifactVersionList = self.get_json(&url).await?;
    let entries = listing.versions;
    let mut parsed = Vec::with_capacity(entries.len());
    parsed.extend(
        entries
            .into_iter()
            .filter_map(|entry| entry.version)
            .filter_map(|raw| Self::parse_version(&raw)),
    );
    Ok(parsed)
}
}
/// Debug output shows the base URL and the kind of authentication only;
/// secrets are replaced by a fixed mask and the HTTP client is omitted.
impl fmt::Debug for ApicurioSchemaRegistry {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let masked_auth: &str = match &self.auth {
            ApicurioAuth::Bearer { .. } => "bearer(***)",
            ApicurioAuth::Basic { .. } => "basic(***)",
            ApicurioAuth::None => "none",
        };

        let mut out = f.debug_struct("ApicurioSchemaRegistry");
        out.field("base_url", &self.base_url);
        out.field("auth", &masked_auth);
        out.finish()
    }
}
/// `SchemaRegistryClient` implementation; most methods delegate to the
/// inherent `*_impl` / `list_*` methods of `ApicurioSchemaRegistry`.
impl SchemaRegistryClient for ApicurioSchemaRegistry {
/// Fetches schema content by global id. The requested `id` is passed on as
/// the fallback id for responses without a global-id header; no fallback
/// subject is supplied.
async fn get_schema_by_id(&self, id: SchemaId) -> Result<Arc<Schema>> {
let url = self.api_url(&format!(
"/ids/globalIds/{}?returnArtifactType=true",
id.as_u32()
));
let resp = self.get_content(&url).await?;
self.schema_from_content_response(resp, Some(id), None)
}
/// Delegates to `get_latest_schema_impl`.
async fn get_latest_schema(&self, subject: &str) -> Result<Arc<Schema>> {
self.get_latest_schema_impl(subject).await
}
/// Delegates to `get_schema_by_version_impl`.
async fn get_schema_by_version(
&self,
subject: &str,
version: SchemaVersion,
) -> Result<Arc<Schema>> {
self.get_schema_by_version_impl(subject, version).await
}
/// Delegates to `register_schema_impl`.
async fn register_schema(
&self,
subject: &str,
schema: &str,
schema_type: SchemaType,
references: &[SchemaReference],
) -> Result<SchemaId> {
self.register_schema_impl(subject, schema, schema_type, references)
.await
}
/// Returns the future of `check_compatibility_impl` without awaiting it here.
fn check_compatibility<'a>(
&'a self,
subject: &'a str,
schema: &'a str,
schema_type: SchemaType,
references: &'a [SchemaReference],
) -> impl std::future::Future<Output = Result<bool>> + Send + 'a {
self.check_compatibility_impl(subject, schema, schema_type, references)
}
/// Deletes the whole artifact. `_permanent` is ignored, and the returned
/// list is always empty rather than naming the deleted versions.
async fn delete_subject<'a>(
&'a self,
subject: &'a str,
_permanent: bool,
) -> Result<Vec<SchemaVersion>> {
self.delete_artifact(subject).await?;
Ok(Vec::new())
}
/// Lists subjects with a fixed search limit of 500.
fn get_subjects(&self) -> impl std::future::Future<Output = Result<Vec<String>>> + Send + '_ {
self.list_subjects(500)
}
/// Returns the future of `list_versions` without awaiting it here.
fn get_versions<'a>(
&'a self,
subject: &'a str,
) -> impl std::future::Future<Output = Result<Vec<SchemaVersion>>> + Send + 'a {
self.list_versions(subject)
}
}
/// Builder for [`ApicurioSchemaRegistry`] that allows authentication, a
/// request timeout and TLS material to be configured.
pub struct ApicurioSchemaRegistryBuilder {
/// Registry base URL; `build` fails when it was never set.
url: Option<String>,
/// Credentials handed to the built client.
auth: ApicurioAuth,
/// Timeout forwarded to the HTTP client configuration.
request_timeout: Option<Duration>,
/// Root certificates forwarded to the HTTP client configuration.
root_certificates: Vec<reqwest::Certificate>,
/// Client identity forwarded to the HTTP client configuration.
identity: Option<reqwest::Identity>,
}
/// Defaults: no URL, no authentication, the default request timeout, no
/// extra root certificates and no client identity.
impl Default for ApicurioSchemaRegistryBuilder {
    fn default() -> Self {
        let request_timeout = Some(DEFAULT_REQUEST_TIMEOUT);
        Self {
            auth: ApicurioAuth::None,
            identity: None,
            request_timeout,
            root_certificates: Vec::new(),
            url: None,
        }
    }
}
impl ApicurioSchemaRegistryBuilder {
    /// Sets the registry base URL (required).
    pub fn url(self, url: impl Into<String>) -> Self {
        Self {
            url: Some(url.into()),
            ..self
        }
    }

    /// Uses HTTP Basic authentication, replacing any earlier auth setting.
    pub fn basic_auth(self, username: impl Into<String>, password: impl Into<String>) -> Self {
        let username = zeroize::Zeroizing::new(username.into());
        let password = zeroize::Zeroizing::new(password.into());
        Self {
            auth: ApicurioAuth::Basic { username, password },
            ..self
        }
    }

    /// Uses bearer-token authentication, replacing any earlier auth setting.
    pub fn bearer_token(self, token: impl Into<String>) -> Self {
        let token = zeroize::Zeroizing::new(token.into());
        Self {
            auth: ApicurioAuth::Bearer { token },
            ..self
        }
    }

    /// Overrides the request timeout.
    pub fn request_timeout(self, timeout: Duration) -> Self {
        Self {
            request_timeout: Some(timeout),
            ..self
        }
    }

    /// Adds one more root certificate; earlier ones are kept.
    pub fn add_root_certificate(mut self, cert: reqwest::Certificate) -> Self {
        self.root_certificates.push(cert);
        self
    }

    /// Sets the client identity, replacing any earlier one.
    pub fn identity(self, identity: reqwest::Identity) -> Self {
        Self {
            identity: Some(identity),
            ..self
        }
    }

    /// Validates the configuration and creates the client.
    ///
    /// Trailing slashes are stripped from the URL. A warning is logged when
    /// the URL starts with `http://` and authentication is configured.
    ///
    /// # Errors
    ///
    /// Fails when no URL was set, when the URL embeds credentials in its
    /// authority, or when the HTTP client cannot be constructed.
    pub fn build(self) -> Result<ApicurioSchemaRegistry> {
        let Self {
            url,
            auth,
            request_timeout,
            root_certificates,
            identity,
        } = self;

        let Some(raw_url) = url else {
            return Err(SchemaRegError::config(
                "ApicurioSchemaRegistryBuilder: URL is required",
            ));
        };
        let base_url = normalize_url(raw_url);
        reject_embedded_credentials(&base_url)?;

        let sends_credentials = !matches!(&auth, ApicurioAuth::None);
        if base_url.starts_with("http://") && sends_credentials {
            tracing::warn!(
                url = %base_url,
                "Apicurio Registry URL uses plain HTTP with authentication — credentials will \
                 be sent in cleartext; use HTTPS in production"
            );
        }

        let config = HttpClientConfig {
            timeout: request_timeout,
            root_certificates,
            identity,
        };
        let client = HttpClient::with_config(config)?;

        Ok(ApicurioSchemaRegistry {
            client,
            base_url,
            auth,
        })
    }
}
/// Returns the content type sent alongside schema text of the given type:
/// `application/json` for Avro and JSON, `application/x-protobuf` for
/// Protobuf.
fn schema_content_type(schema_type: SchemaType) -> &'static str {
    const JSON: &str = "application/json";
    const PROTOBUF: &str = "application/x-protobuf";

    match schema_type {
        SchemaType::Protobuf => PROTOBUF,
        SchemaType::Avro => JSON,
        SchemaType::Json => JSON,
    }
}
/// Removes every trailing `/` from `url`, reusing its allocation.
///
/// A string made only of slashes becomes empty; nothing else is changed.
fn normalize_url(mut url: String) -> String {
    while url.ends_with('/') {
        url.pop();
    }
    url
}
/// Rejects URLs whose authority part contains an `@`.
///
/// The authority is the text between the first `://` and the next `/`,
/// `?` or `#` (or the end of the string). URLs without `://` are accepted
/// unchecked.
///
/// # Errors
///
/// Returns a configuration error when the authority contains `@`.
fn reject_embedded_credentials(url: &str) -> Result<()> {
    let Some((_scheme, after_scheme)) = url.split_once("://") else {
        return Ok(());
    };

    // The first piece of the split is everything up to the first delimiter.
    let authority = after_scheme
        .split(['/', '?', '#'])
        .next()
        .unwrap_or(after_scheme);
    if !authority.contains('@') {
        return Ok(());
    }

    Err(SchemaRegError::config(
        "Apicurio Registry URL must not contain embedded credentials (user:pass@host); \
         use ApicurioSchemaRegistryBuilder::basic_auth() instead",
    ))
}
/// Bytes that `percent_encode` escapes when a value is placed into a single
/// URL path segment: the control set plus the ASCII characters listed below.
///
/// `/` and `%` are part of the set, so a value containing them is escaped
/// instead of adding path segments or passing through existing escapes.
static PATH_SEGMENT_ENCODE_SET: percent_encoding::AsciiSet = percent_encoding::CONTROLS
.add(b' ')
.add(b'"')
.add(b'#')
.add(b'<')
.add(b'>')
.add(b'?')
.add(b'`')
.add(b'{')
.add(b'}')
.add(b'/')
.add(b'%')
.add(b'[')
.add(b']')
.add(b'\\')
.add(b'^')
.add(b'@');
/// Percent-encodes `input` for use as one URL path segment, escaping the
/// bytes in `PATH_SEGMENT_ENCODE_SET` (non-ASCII bytes are escaped by
/// `utf8_percent_encode` itself).
fn percent_encode(input: &str) -> String {
    let encoded = percent_encoding::utf8_percent_encode(input, &PATH_SEGMENT_ENCODE_SET);
    encoded.collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `group/artifact` subject splits into its parts and joins back unchanged.
    #[test]
    fn test_artifact_id_subject_roundtrip() {
        let subject = "default/orders-value";
        let parsed = ArtifactId::from_subject(subject);
        assert_eq!(
            (parsed.group.as_str(), parsed.artifact.as_str()),
            ("default", "orders-value")
        );
        assert_eq!(parsed.to_subject(), subject);
    }

    /// A subject without a group prefix lands in the `default` group.
    #[test]
    fn test_artifact_id_bare_subject_uses_default_group() {
        let parsed = ArtifactId::from_subject("orders-value");
        assert_eq!(
            (parsed.group.as_str(), parsed.artifact.as_str()),
            ("default", "orders-value")
        );
    }

    /// An explicit group is preserved through parsing and formatting.
    #[test]
    fn test_artifact_id_custom_group() {
        let subject = "production/payments-key";
        let parsed = ArtifactId::from_subject(subject);
        assert_eq!(
            (parsed.group.as_str(), parsed.artifact.as_str()),
            ("production", "payments-key")
        );
        assert_eq!(parsed.to_subject(), subject);
    }

    /// Every schema type maps to its expected content type.
    #[test]
    fn test_schema_content_type() {
        let expectations = [
            (SchemaType::Avro, "application/json"),
            (SchemaType::Json, "application/json"),
            (SchemaType::Protobuf, "application/x-protobuf"),
        ];
        for (schema_type, expected) in expectations {
            assert_eq!(schema_content_type(schema_type), expected);
        }
    }

    /// Debug output shows the auth kind but never the token itself.
    #[test]
    fn test_debug_masks_credentials() {
        let client = HttpClient::with_webpki_roots(None).unwrap();
        let token = zeroize::Zeroizing::new("secret".to_string());
        let registry = ApicurioSchemaRegistry {
            client,
            base_url: "http://localhost:8080".to_string(),
            auth: ApicurioAuth::Bearer { token },
        };

        let rendered = format!("{registry:?}");
        assert!(rendered.contains("bearer(***)"));
        assert!(!rendered.contains("secret"));
    }

    /// URLs with and without a trailing slash normalise to the same string.
    #[test]
    fn test_normalize_url_strips_trailing_slash() {
        for input in ["http://localhost:8080/", "http://localhost:8080"] {
            assert_eq!(normalize_url(input.to_string()), "http://localhost:8080");
        }
    }

    /// Userinfo in the authority is rejected; a plain host is accepted.
    #[test]
    fn test_reject_embedded_credentials() {
        let with_userinfo = reject_embedded_credentials("http://user:pass@localhost:8080");
        assert!(matches!(with_userinfo, Err(_)));

        let plain = reject_embedded_credentials("http://localhost:8080");
        assert!(matches!(plain, Ok(())));
    }

    /// Compile-time check that the client can be shared across threads.
    #[test]
    fn test_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<ApicurioSchemaRegistry>();
    }
}