use std::fmt;
use std::time::Duration;
use reqwest::header::{ACCEPT, CONTENT_TYPE};
use serde::{Deserialize, Serialize};
use super::{Schema, SchemaId, SchemaReference, SchemaRegistryClient, SchemaType, SchemaVersion};
use crate::error::{KrafkaError, Result};
/// Media type used for the `Content-Type` and `Accept` headers on registry requests in this file.
const SCHEMA_REGISTRY_CONTENT_TYPE: &str = "application/vnd.schemaregistry.v1+json";
/// Maximum byte length of a sanitized error-body preview, not counting the truncation marker.
const ERROR_BODY_PREVIEW_LIMIT: usize = 512;
/// JSON body deserialized from the `/schemas/ids/{id}` lookup.
#[derive(Deserialize)]
struct SchemaByIdResponse {
/// Schema definition text.
schema: String,
/// Schema type name read from the `schemaType` key; defaults to `"AVRO"` when the key is absent.
#[serde(rename = "schemaType", default = "default_avro_type")]
schema_type: String,
/// References to other schemas; `None` when the key is missing or null.
references: Option<Vec<ReferenceJson>>,
}
/// JSON body deserialized from the `/subjects/{subject}/versions/{version}` lookups.
#[derive(Deserialize)]
struct SchemaBySubjectResponse {
/// Registry-assigned schema id.
id: SchemaId,
/// Schema definition text.
schema: String,
/// Version of the schema under `subject`.
version: SchemaVersion,
/// Subject the schema is registered under.
subject: String,
/// Schema type name read from the `schemaType` key; defaults to `"AVRO"` when the key is absent.
#[serde(rename = "schemaType", default = "default_avro_type")]
schema_type: String,
/// References to other schemas; `None` when the key is missing or null.
references: Option<Vec<ReferenceJson>>,
}
/// JSON body deserialized from a schema registration request.
#[derive(Deserialize)]
struct RegisterSchemaResponse {
/// Id reported by the registry for the registered schema.
id: SchemaId,
}
/// JSON body deserialized from a compatibility check.
#[derive(Deserialize)]
struct CompatibilityResponse {
/// Compatibility verdict reported by the registry.
is_compatible: bool,
}
/// Structured error body that non-success responses are tentatively parsed as.
#[derive(Deserialize)]
struct ErrorResponse {
/// Numeric error code supplied by the registry.
error_code: i32,
/// Human-readable error message supplied by the registry.
message: String,
}
/// Wire representation of a schema reference, used in both requests and responses.
#[derive(Serialize, Deserialize)]
struct ReferenceJson {
/// Name of the reference.
name: String,
/// Subject of the referenced schema.
subject: String,
/// Version of the referenced schema.
version: SchemaVersion,
}
/// JSON request body used for both schema registration and compatibility checks.
#[derive(Serialize)]
struct RegisterSchemaRequest<'a> {
/// Schema definition text, borrowed from the caller.
schema: &'a str,
/// Schema type name, serialized under the `schemaType` key.
#[serde(rename = "schemaType")]
schema_type: &'a str,
/// Schema references; the key is left out of the JSON entirely when the list is empty.
#[serde(skip_serializing_if = "Vec::is_empty")]
references: Vec<ReferenceJson>,
}
/// Serde default for the `schemaType` field: responses that omit it are treated as `"AVRO"`.
fn default_avro_type() -> String {
    String::from("AVRO")
}
/// Builds a single-line, size-capped preview of an HTTP error body for use in error messages.
///
/// * An empty body yields `"<empty>"`.
/// * `\n`, `\r` and `\t` are rendered as the two-character escapes `\\n`, `\\r`, `\\t`.
/// * Any other control character is replaced by `?`.
/// * Output stops before it would exceed `ERROR_BODY_PREVIEW_LIMIT` bytes, in which case
///   `"...[truncated]"` is appended after the cut.
fn sanitized_error_body_preview(body: &str) -> String {
    if body.is_empty() {
        return "<empty>".to_string();
    }

    let mut preview = String::new();
    // Scratch space for encoding one ordinary character as UTF-8.
    let mut utf8_scratch = [0u8; 4];

    for ch in body.chars() {
        let piece: &str = match ch {
            '\n' => "\\n",
            '\r' => "\\r",
            '\t' => "\\t",
            other if other.is_control() => "?",
            other => other.encode_utf8(&mut utf8_scratch),
        };

        // The cap is measured in bytes and a piece is never split.
        if preview.len() + piece.len() > ERROR_BODY_PREVIEW_LIMIT {
            preview.push_str("...[truncated]");
            return preview;
        }
        preview.push_str(piece);
    }

    preview
}
/// Credentials attached to outgoing registry requests.
///
/// Secret-bearing fields are wrapped in `zeroize::Zeroizing`.
#[derive(Default)]
enum RegistryAuth {
/// No authentication is applied to requests (the default).
#[default]
None,
/// HTTP basic authentication.
Basic {
username: zeroize::Zeroizing<String>,
password: zeroize::Zeroizing<String>,
},
/// Bearer-token authentication.
Bearer {
token: zeroize::Zeroizing<String>,
},
}
/// HTTP client for a Confluent-style schema registry.
pub struct ConfluentSchemaRegistry {
/// Underlying `reqwest` client used for every request.
client: reqwest::Client,
/// Registry base URL with trailing slashes removed; endpoint paths are appended to it.
base_url: String,
/// Credentials applied to each request before it is sent.
auth: RegistryAuth,
}
impl ConfluentSchemaRegistry {
    /// Creates an unauthenticated client for the registry at `url`.
    ///
    /// Trailing slashes are stripped from `url`. No request timeout is configured on this
    /// path; use [`ConfluentSchemaRegistry::builder`] to set one or to add credentials.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when the URL authority contains userinfo
    /// (`user:pass@host`), and a schema-registry error when the HTTP client cannot be built.
    pub fn new(url: impl Into<String>) -> Result<Self> {
        let url = normalize_url(url.into());
        reject_embedded_credentials(&url)?;
        // `reqwest::Client::new()` panics when the TLS backend or resolver cannot be
        // initialised. This function already returns `Result`, so go through the builder
        // and report that failure as an error instead.
        let client = reqwest::Client::builder().build().map_err(|e| {
            KrafkaError::schema_registry(format!("failed to build HTTP client: {e}"))
        })?;
        Ok(Self {
            client,
            base_url: url,
            auth: RegistryAuth::None,
        })
    }

    /// Returns a builder with default settings.
    pub fn builder() -> ConfluentSchemaRegistryBuilder {
        ConfluentSchemaRegistryBuilder::default()
    }

    /// Asks the registry whether `schema` is compatible with the latest version of `subject`.
    ///
    /// Posts to `/compatibility/subjects/{subject}/versions/latest` and returns the
    /// `is_compatible` flag from the response.
    ///
    /// # Errors
    ///
    /// Fails when the request cannot be sent, the registry answers with a non-success
    /// status, or the response body cannot be parsed.
    pub async fn check_compatibility(
        &self,
        subject: &str,
        schema: &str,
        schema_type: SchemaType,
        references: &[SchemaReference],
    ) -> Result<bool> {
        let url = format!(
            "{}/compatibility/subjects/{}/versions/latest",
            self.base_url,
            percent_encode(subject)
        );
        let body = RegisterSchemaRequest {
            schema,
            schema_type: schema_type.as_str(),
            references: Self::to_reference_json(references),
        };
        let result: CompatibilityResponse = self
            .send_request(
                self.client
                    .post(&url)
                    .header(CONTENT_TYPE, SCHEMA_REGISTRY_CONTENT_TYPE)
                    .json(&body),
            )
            .await?;
        Ok(result.is_compatible)
    }

    /// Fetches the list returned by `GET /subjects`.
    ///
    /// # Errors
    ///
    /// Fails on transport errors, non-success statuses, or unparsable bodies.
    pub async fn get_subjects(&self) -> Result<Vec<String>> {
        let url = format!("{}/subjects", self.base_url);
        self.send_request(
            self.client
                .get(&url)
                .header(ACCEPT, SCHEMA_REGISTRY_CONTENT_TYPE),
        )
        .await
    }

    /// Fetches the list returned by `GET /subjects/{subject}/versions`.
    ///
    /// # Errors
    ///
    /// Fails on transport errors, non-success statuses, or unparsable bodies.
    pub async fn get_versions(&self, subject: &str) -> Result<Vec<SchemaVersion>> {
        let url = format!(
            "{}/subjects/{}/versions",
            self.base_url,
            percent_encode(subject)
        );
        self.send_request(
            self.client
                .get(&url)
                .header(ACCEPT, SCHEMA_REGISTRY_CONTENT_TYPE),
        )
        .await
    }

    /// Issues `DELETE /subjects/{subject}` and returns the version list from the response.
    ///
    /// When `permanent` is true the query string `?permanent=true` is appended.
    ///
    /// # Errors
    ///
    /// Fails on transport errors, non-success statuses, or unparsable bodies.
    pub async fn delete_subject(
        &self,
        subject: &str,
        permanent: bool,
    ) -> Result<Vec<SchemaVersion>> {
        let mut url = format!("{}/subjects/{}", self.base_url, percent_encode(subject));
        if permanent {
            url.push_str("?permanent=true");
        }
        self.send_request(
            self.client
                .delete(&url)
                .header(ACCEPT, SCHEMA_REGISTRY_CONTENT_TYPE),
        )
        .await
    }

    /// Attaches the configured credentials (if any) to `builder`.
    fn apply_auth(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
        match &self.auth {
            RegistryAuth::None => builder,
            RegistryAuth::Basic { username, password } => {
                builder.basic_auth(username.as_str(), Some(password.as_str()))
            }
            RegistryAuth::Bearer { token } => builder.bearer_auth(token.as_str()),
        }
    }

    /// Converts an HTTP response into `T` or a schema-registry error.
    ///
    /// Success statuses are parsed as JSON into `T`. For any other status the body is read
    /// best-effort (a read failure is treated as an empty body). A body that parses as an
    /// [`ErrorResponse`] produces `"<message> (error code <code>)"`; anything else produces
    /// `"HTTP <status>: <preview>"`.
    async fn handle_response<T: serde::de::DeserializeOwned>(
        response: reqwest::Response,
    ) -> Result<T> {
        let status = response.status();
        if status.is_success() {
            response.json::<T>().await.map_err(|e| {
                KrafkaError::schema_registry_with_source("failed to parse response", e)
            })
        } else {
            let body = response.text().await.unwrap_or_default();
            if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
                // The message is server-controlled text: escape control characters and cap
                // its length, exactly as is done for unstructured bodies below.
                let message = sanitized_error_body_preview(&err.message);
                Err(KrafkaError::schema_registry(format!(
                    "{} (error code {})",
                    message, err.error_code
                )))
            } else {
                let body = sanitized_error_body_preview(&body);
                Err(KrafkaError::schema_registry(format!(
                    "HTTP {status}: {body}"
                )))
            }
        }
    }

    /// Applies authentication, sends `request`, and decodes the response as `T`.
    ///
    /// Transport failures are reported as a schema-registry error with the message
    /// `"request failed"` and the underlying error as its source.
    async fn send_request<T: serde::de::DeserializeOwned>(
        &self,
        request: reqwest::RequestBuilder,
    ) -> Result<T> {
        let response = self
            .apply_auth(request)
            .send()
            .await
            .map_err(|e| KrafkaError::schema_registry_with_source("request failed", e))?;
        Self::handle_response(response).await
    }

    /// Converts caller-facing references into their wire representation.
    fn to_reference_json(refs: &[SchemaReference]) -> Vec<ReferenceJson> {
        refs.iter()
            .map(|r| ReferenceJson {
                name: r.name.clone(),
                subject: r.subject.clone(),
                version: r.version,
            })
            .collect()
    }

    /// Converts wire references into caller-facing ones; a missing list becomes empty.
    fn parse_references(refs: Option<Vec<ReferenceJson>>) -> Vec<SchemaReference> {
        refs.unwrap_or_default()
            .into_iter()
            .map(|r| SchemaReference {
                name: r.name,
                subject: r.subject,
                version: r.version,
            })
            .collect()
    }

    /// Builds a [`Schema`] from a by-subject response.
    ///
    /// # Errors
    ///
    /// Fails when the response's schema type string cannot be parsed into a [`SchemaType`].
    fn schema_from_subject_response(body: SchemaBySubjectResponse) -> Result<Schema> {
        let schema_type: SchemaType = body.schema_type.parse()?;
        Ok(Schema {
            id: body.id,
            schema_type,
            schema: body.schema,
            version: Some(body.version),
            subject: Some(body.subject),
            references: Self::parse_references(body.references),
        })
    }
}
impl fmt::Debug for ConfluentSchemaRegistry {
    /// Prints the base URL and the kind of authentication in use.
    ///
    /// Credentials are never written: each auth variant is rendered as a fixed placeholder.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let redacted_auth: &str = match self.auth {
            RegistryAuth::Bearer { .. } => "bearer(***)",
            RegistryAuth::Basic { .. } => "basic(***)",
            RegistryAuth::None => "none",
        };

        let mut out = f.debug_struct("ConfluentSchemaRegistry");
        out.field("base_url", &self.base_url);
        out.field("auth", &redacted_auth);
        out.finish()
    }
}
impl SchemaRegistryClient for ConfluentSchemaRegistry {
async fn get_schema_by_id(&self, id: SchemaId) -> Result<Schema> {
let url = format!("{}/schemas/ids/{id}", self.base_url);
let body: SchemaByIdResponse = self
.send_request(
self.client
.get(&url)
.header(ACCEPT, SCHEMA_REGISTRY_CONTENT_TYPE),
)
.await?;
let schema_type: SchemaType = body.schema_type.parse()?;
Ok(Schema {
id,
schema_type,
schema: body.schema,
version: None,
subject: None,
references: Self::parse_references(body.references),
})
}
async fn get_latest_schema(&self, subject: &str) -> Result<Schema> {
let url = format!(
"{}/subjects/{}/versions/latest",
self.base_url,
percent_encode(subject)
);
let body: SchemaBySubjectResponse = self
.send_request(
self.client
.get(&url)
.header(ACCEPT, SCHEMA_REGISTRY_CONTENT_TYPE),
)
.await?;
Self::schema_from_subject_response(body)
}
async fn get_schema_by_version(&self, subject: &str, version: SchemaVersion) -> Result<Schema> {
let url = format!(
"{}/subjects/{}/versions/{version}",
self.base_url,
percent_encode(subject)
);
let body: SchemaBySubjectResponse = self
.send_request(
self.client
.get(&url)
.header(ACCEPT, SCHEMA_REGISTRY_CONTENT_TYPE),
)
.await?;
Self::schema_from_subject_response(body)
}
async fn register_schema(
&self,
subject: &str,
schema: &str,
schema_type: SchemaType,
references: &[SchemaReference],
) -> Result<SchemaId> {
let refs = Self::to_reference_json(references);
let url = format!(
"{}/subjects/{}/versions",
self.base_url,
percent_encode(subject)
);
let body = RegisterSchemaRequest {
schema,
schema_type: schema_type.as_str(),
references: refs,
};
let result: RegisterSchemaResponse = self
.send_request(
self.client
.post(&url)
.header(CONTENT_TYPE, SCHEMA_REGISTRY_CONTENT_TYPE)
.json(&body),
)
.await?;
Ok(result.id)
}
async fn check_compatibility(
&self,
subject: &str,
schema: &str,
schema_type: SchemaType,
references: &[SchemaReference],
) -> Result<bool> {
ConfluentSchemaRegistry::check_compatibility(self, subject, schema, schema_type, references)
.await
}
async fn delete_subject(&self, subject: &str, permanent: bool) -> Result<Vec<SchemaVersion>> {
ConfluentSchemaRegistry::delete_subject(self, subject, permanent).await
}
async fn get_subjects(&self) -> Result<Vec<String>> {
ConfluentSchemaRegistry::get_subjects(self).await
}
async fn get_versions(&self, subject: &str) -> Result<Vec<SchemaVersion>> {
ConfluentSchemaRegistry::get_versions(self, subject).await
}
}
/// Removes every trailing `/` from `url`, reusing the incoming allocation.
fn normalize_url(mut url: String) -> String {
    while url.ends_with('/') {
        url.pop();
    }
    url
}
/// Rejects URLs whose authority component contains userinfo (`user@` or `user:pass@`).
///
/// The authority is the text between the first `://` and the next `/`, `?` or `#` (or the
/// end of the string). URLs without `://` are accepted unchanged.
///
/// # Errors
///
/// Returns a configuration error when the authority contains `@`.
fn reject_embedded_credentials(url: &str) -> Result<()> {
    let has_userinfo = url
        .split_once("://")
        .map(|(_scheme, rest)| rest.split(['/', '?', '#']).next().unwrap_or(rest))
        .is_some_and(|authority| authority.contains('@'));

    if !has_userinfo {
        return Ok(());
    }

    Err(KrafkaError::config(
        "schema registry URL must not contain embedded credentials (user:pass@host); \
         use ConfluentSchemaRegistryBuilder::basic_auth() instead",
    ))
}
/// Test-only helper that returns the fixed mask `"<***@>"`.
///
/// The `_userinfo` argument is ignored, so the result can never reveal any part of it.
#[cfg(test)]
fn masked_userinfo_indicator(_userinfo: &str) -> &'static str {
"<***@>"
}
/// Test-only alias for `normalize_url`: strips trailing slashes from `url`.
#[cfg(test)]
fn sanitize_url(url: String) -> String {
normalize_url(url)
}
/// Percent-encodes `input` so it can be embedded as a single URL path segment.
///
/// Every byte of the UTF-8 encoding is written as `%XX` (uppercase hex) unless it is one of
/// the characters RFC 3986 allows literally inside a path segment: ASCII letters and
/// digits, `-._~`, the sub-delimiters `!$&'()*+,;=`, and `:` / `@`.
///
/// Escaping only a handful of characters is not enough. For `http`/`https` URLs a URL
/// parser treats `\` as a path separator and silently drops tabs and newlines, so such
/// input would otherwise address a different subject than the caller asked for.
fn percent_encode(input: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(input.len());
    for &byte in input.as_bytes() {
        let is_segment_safe = byte.is_ascii_alphanumeric()
            || matches!(
                byte,
                b'-' | b'.'
                    | b'_'
                    | b'~'
                    | b'!'
                    | b'$'
                    | b'&'
                    | b'\''
                    | b'('
                    | b')'
                    | b'*'
                    | b'+'
                    | b','
                    | b';'
                    | b'='
                    | b':'
                    | b'@'
            );

        if is_segment_safe {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
/// Builder for [`ConfluentSchemaRegistry`] with optional authentication and request timeout.
pub struct ConfluentSchemaRegistryBuilder {
/// Registry base URL; `build` fails when it was never set.
url: Option<String>,
/// Credentials the built client will attach to requests.
auth: RegistryAuth,
/// Timeout passed to the HTTP client builder; `None` means no timeout is configured.
request_timeout: Option<Duration>,
}
impl Default for ConfluentSchemaRegistryBuilder {
fn default() -> Self {
Self {
url: None,
auth: RegistryAuth::None,
request_timeout: Some(Duration::from_secs(30)),
}
}
}
impl ConfluentSchemaRegistryBuilder {
    /// Sets the registry base URL. Trailing slashes are removed when the client is built.
    pub fn url(mut self, url: impl Into<String>) -> Self {
        self.url = Some(url.into());
        self
    }

    /// Uses HTTP basic authentication, replacing any previously configured credentials.
    pub fn basic_auth(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
        self.auth = RegistryAuth::Basic {
            username: zeroize::Zeroizing::new(username.into()),
            password: zeroize::Zeroizing::new(password.into()),
        };
        self
    }

    /// Uses bearer-token authentication, replacing any previously configured credentials.
    pub fn bearer_token(mut self, token: impl Into<String>) -> Self {
        self.auth = RegistryAuth::Bearer {
            token: zeroize::Zeroizing::new(token.into()),
        };
        self
    }

    /// Sets the timeout handed to the HTTP client builder.
    pub fn request_timeout(mut self, timeout: Duration) -> Self {
        self.request_timeout = Some(timeout);
        self
    }

    /// Removes the request timeout, including the default one.
    pub fn clear_request_timeout(mut self) -> Self {
        self.request_timeout = None;
        self
    }

    /// Reports whether `url` would be resolved with the plain `http` scheme.
    ///
    /// Schemes are case-insensitive, and a URL parser skips leading whitespace/control
    /// characters and accepts `http:` without the `//`, so a literal `"http://"` prefix
    /// comparison is not sufficient.
    fn is_cleartext_http(url: &str) -> bool {
        url.trim_start_matches(|c: char| c <= ' ')
            .split_once(':')
            .is_some_and(|(scheme, _)| scheme.eq_ignore_ascii_case("http"))
    }

    /// Validates the configuration and constructs the client.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when no URL was set, when the URL embeds credentials,
    /// or when credentials are configured together with a plain-HTTP URL. Returns a
    /// schema-registry error when the HTTP client cannot be built.
    pub fn build(self) -> Result<ConfluentSchemaRegistry> {
        let url = self
            .url
            .ok_or_else(|| KrafkaError::config("schema registry URL is required"))?;
        reject_embedded_credentials(&url)?;

        let sends_credentials = matches!(
            self.auth,
            RegistryAuth::Basic { .. } | RegistryAuth::Bearer { .. }
        );
        // Never let credentials travel unencrypted.
        if sends_credentials && Self::is_cleartext_http(&url) {
            return Err(KrafkaError::config(
                "schema registry auth requires HTTPS — credentials would be sent in cleartext over HTTP",
            ));
        }

        let mut http_builder = reqwest::Client::builder();
        if let Some(timeout) = self.request_timeout {
            http_builder = http_builder.timeout(timeout);
        }
        let client = http_builder.build().map_err(|e| {
            KrafkaError::schema_registry(format!("failed to build HTTP client: {e}"))
        })?;

        Ok(ConfluentSchemaRegistry {
            client,
            base_url: normalize_url(url),
            auth: self.auth,
        })
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;

    /// Plain-HTTP registry URL shared by the builder tests.
    const LOCAL_HTTP: &str = "http://localhost:8081";
    /// HTTPS registry URL used where credentials are configured.
    const LOCAL_HTTPS: &str = "https://localhost:8081";

    /// Binds a loopback listener whose task accepts one connection and then stays silent
    /// for 500 ms, so a request to it never gets a response in time.
    async fn spawn_silent_server() -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = tokio::spawn(async move {
            let (_socket, _) = listener.accept().await.unwrap();
            tokio::time::sleep(Duration::from_millis(500)).await;
        });
        (addr, handle)
    }

    // ----- builder basics -----

    #[test]
    fn test_builder_missing_url() {
        let outcome = ConfluentSchemaRegistry::builder().build();
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("URL"));
    }

    #[test]
    fn test_builder_with_url() {
        let registry = ConfluentSchemaRegistry::builder()
            .url(LOCAL_HTTP)
            .build()
            .unwrap();
        assert_eq!(registry.base_url, "http://localhost:8081");
    }

    #[test]
    fn test_builder_with_request_timeout() {
        let two_seconds = Duration::from_secs(2);
        let configured = ConfluentSchemaRegistry::builder()
            .url(LOCAL_HTTP)
            .request_timeout(two_seconds);
        assert_eq!(configured.request_timeout, Some(Duration::from_secs(2)));

        let registry = configured.build().unwrap();
        assert_eq!(registry.base_url, "http://localhost:8081");
    }

    #[test]
    fn test_builder_clear_request_timeout() {
        let configured = ConfluentSchemaRegistry::builder()
            .url(LOCAL_HTTP)
            .request_timeout(Duration::from_secs(2))
            .clear_request_timeout();
        assert_eq!(configured.request_timeout, None);

        let registry = configured.build().unwrap();
        assert_eq!(registry.base_url, "http://localhost:8081");
    }

    // ----- timeout behaviour against a stalled server -----

    #[tokio::test]
    async fn test_builder_request_timeout_applies_to_built_client() {
        let (addr, server) = spawn_silent_server().await;
        let registry = ConfluentSchemaRegistry::builder()
            .url(format!("http://{addr}"))
            .request_timeout(Duration::from_millis(30))
            .build()
            .unwrap();

        let outcome = tokio::time::timeout(Duration::from_secs(2), registry.get_schema_by_id(1))
            .await
            .expect("request_timeout should complete the request with an error");
        let failure = outcome.unwrap_err();
        assert!(failure.to_string().contains("request failed"));
        server.abort();
    }

    #[tokio::test]
    async fn test_builder_clear_request_timeout_removes_client_deadline() {
        let (addr, server) = spawn_silent_server().await;
        let registry = ConfluentSchemaRegistry::builder()
            .url(format!("http://{addr}"))
            .request_timeout(Duration::from_millis(20))
            .clear_request_timeout()
            .build()
            .unwrap();

        // With no client deadline, only the outer 150 ms guard can end the wait.
        let guarded =
            tokio::time::timeout(Duration::from_millis(150), registry.get_schema_by_id(1)).await;
        assert!(
            guarded.is_err(),
            "request unexpectedly completed with cleared timeout"
        );
        server.abort();
    }

    // ----- trailing slash handling -----

    #[test]
    fn test_builder_strips_trailing_slash() {
        let registry = ConfluentSchemaRegistry::builder()
            .url("http://localhost:8081/")
            .build()
            .unwrap();
        assert_eq!(registry.base_url, "http://localhost:8081");
    }

    #[test]
    fn test_new_strips_trailing_slash() {
        let registry = ConfluentSchemaRegistry::new("http://localhost:8081/").unwrap();
        assert_eq!(registry.base_url, "http://localhost:8081");
    }

    // ----- Debug output and auth guards -----

    #[test]
    fn test_debug_redacts_basic_auth() {
        let registry = ConfluentSchemaRegistry::builder()
            .url(LOCAL_HTTPS)
            .basic_auth("admin", "s3cret")
            .build()
            .unwrap();
        let rendered = format!("{registry:?}");
        assert!(rendered.contains("basic(***)"));
        assert!(!rendered.contains("s3cret"));
        assert!(!rendered.contains("admin"));
    }

    #[test]
    fn test_debug_redacts_bearer_token() {
        let registry = ConfluentSchemaRegistry::builder()
            .url(LOCAL_HTTPS)
            .bearer_token("my-secret-token")
            .build()
            .unwrap();
        let rendered = format!("{registry:?}");
        assert!(rendered.contains("bearer(***)"));
        assert!(!rendered.contains("my-secret-token"));
    }

    #[test]
    fn test_builder_rejects_bearer_token_over_http() {
        let err = ConfluentSchemaRegistry::builder()
            .url(LOCAL_HTTP)
            .bearer_token("my-secret-token")
            .build()
            .unwrap_err();
        assert!(
            err.to_string().contains("requires HTTPS"),
            "expected HTTPS auth guard, got: {err}"
        );
    }

    #[test]
    fn test_debug_no_auth() {
        let registry = ConfluentSchemaRegistry::new(LOCAL_HTTP).unwrap();
        let rendered = format!("{registry:?}");
        assert!(rendered.contains("none"));
    }

    // ----- normalize_url -----

    #[test]
    fn test_normalize_url_no_userinfo() {
        let input = "https://registry.example.com:8081".to_string();
        assert_eq!(normalize_url(input), "https://registry.example.com:8081");
    }

    #[test]
    fn test_normalize_url_no_scheme() {
        let input = "localhost:8081".to_string();
        assert_eq!(normalize_url(input), "localhost:8081");
    }

    #[test]
    fn test_normalize_url_strips_trailing_slashes() {
        let input = "https://registry.example.com:8081/".to_string();
        assert_eq!(normalize_url(input), "https://registry.example.com:8081");
    }

    // ----- reject_embedded_credentials -----

    #[test]
    fn test_reject_embedded_credentials_errors_on_user_pass() {
        let with_password = "https://admin:s3cret@registry.example.com:8081/path";
        let err = reject_embedded_credentials(with_password).unwrap_err();
        assert!(err.to_string().contains("embedded credentials"));
    }

    #[test]
    fn test_reject_embedded_credentials_errors_on_user_only() {
        let user_only = "https://admin@registry.example.com";
        let err = reject_embedded_credentials(user_only).unwrap_err();
        assert!(err.to_string().contains("embedded credentials"));
    }

    #[test]
    fn test_reject_embedded_credentials_ok_no_userinfo() {
        let verdict = reject_embedded_credentials("https://registry.example.com:8081");
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_reject_embedded_credentials_ok_no_scheme() {
        let verdict = reject_embedded_credentials("localhost:8081");
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_new_rejects_embedded_credentials() {
        let err = ConfluentSchemaRegistry::new("https://user:pass@host:8081").unwrap_err();
        assert!(err.to_string().contains("embedded credentials"));
    }

    #[test]
    fn test_new_accepts_clean_url() {
        let registry = ConfluentSchemaRegistry::new("https://host:8081").unwrap();
        assert_eq!(registry.base_url, "https://host:8081");
    }

    #[test]
    fn test_builder_rejects_embedded_credentials() {
        let err = ConfluentSchemaRegistry::builder()
            .url("https://user:pass@host:8081/")
            .build()
            .unwrap_err();
        assert!(err.to_string().contains("embedded credentials"));
    }

    // ----- test-only helpers -----

    #[test]
    fn test_masked_userinfo_indicator_never_reveals_userinfo() {
        for userinfo in ["admin:s3cret", "admin", "opaque-token", ":s3cret"] {
            assert_eq!(masked_userinfo_indicator(userinfo), "<***@>");
        }
    }

    #[test]
    fn test_sanitize_url_strips_trailing_slashes() {
        let input = "https://registry.example.com:8081/".to_string();
        assert_eq!(sanitize_url(input), "https://registry.example.com:8081");
    }

    // ----- encoding and error previews -----

    #[test]
    fn test_percent_encode() {
        let cases = [
            ("simple", "simple"),
            ("has/slash", "has%2Fslash"),
            ("has space", "has%20space"),
            ("100%", "100%25"),
            ("a?b#c", "a%3Fb%23c"),
        ];
        for (raw, expected) in cases {
            assert_eq!(percent_encode(raw), expected);
        }
    }

    #[test]
    fn test_sanitized_error_body_preview_caps_and_escapes() {
        let filler = "x".repeat(ERROR_BODY_PREVIEW_LIMIT);
        let body = format!("line1\nline2\r\t{filler}");
        let preview = sanitized_error_body_preview(&body);
        assert!(preview.contains("line1\\nline2\\r\\t"));
        assert!(preview.ends_with("...[truncated]"));
        assert!(preview.len() <= ERROR_BODY_PREVIEW_LIMIT + "...[truncated]".len());
    }

    #[test]
    fn test_sanitized_error_body_preview_handles_empty_body() {
        assert_eq!(sanitized_error_body_preview(""), "<empty>");
    }

    // ----- thread-safety markers -----

    #[test]
    fn test_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ConfluentSchemaRegistry>();
        assert_send_sync::<ConfluentSchemaRegistryBuilder>();
    }
}