pub mod model;
use reqwest::Client;
use crate::error::SchemaSerdeError;
use crate::subject::SchemaKind;
use model::{
RegisterResponse, SchemaByIdResponse, SchemaPayload, SchemaReference, SubjectVersionResponse,
};
/// Registry media type, used for the `Content-Type` and `Accept` headers of registry requests.
const CONTENT_TYPE: &str = "application/vnd.schemaregistry.v1+json";
/// HTTP client for a schema registry REST API (endpoints under `/subjects` and `/schemas/ids`).
#[derive(Debug, Clone)]
pub struct RegistryClient {
    /// Registry root URL; `new` stores it with all trailing `/` removed.
    base_url: String,
    /// Underlying `reqwest` client used for every request.
    http: Client,
}
impl RegistryClient {
    /// Creates a client for the schema registry rooted at `base_url`.
    ///
    /// Every trailing `/` is stripped so endpoint paths (which start with `/`) can be appended
    /// without producing `//`. A new default [`Client`] is created for the requests.
    #[must_use]
    pub fn new(base_url: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into().trim_end_matches('/').to_string(),
            http: Client::new(),
        }
    }

    /// Registers `schema` under `subject` and returns the `id` from the registry response.
    ///
    /// Sends `POST {base_url}/subjects/{subject}/versions` with an empty reference list.
    /// `subject` is percent-encoded, so it may contain any characters.
    ///
    /// # Errors
    ///
    /// Returns [`SchemaSerdeError::RegistryStatus`] for a non-2xx response, and
    /// [`SchemaSerdeError::Registry`] if the request fails or the response body cannot be
    /// read or decoded.
    pub async fn register(
        &self,
        subject: &str,
        kind: SchemaKind,
        schema: &str,
    ) -> Result<u32, SchemaSerdeError> {
        let url = self.subject_url(subject, "/versions");
        let body = SchemaPayload {
            schema,
            schema_type: kind.wire_name(),
            references: &[] as &[SchemaReference],
        };
        let resp: RegisterResponse = self.post_json(&url, &body).await?;
        Ok(resp.id)
    }

    /// Looks up `schema` under `subject` and returns the `id` field of the registry response.
    ///
    /// Sends `POST {base_url}/subjects/{subject}` with an empty reference list.
    /// `subject` is percent-encoded, so it may contain any characters.
    ///
    /// # Errors
    ///
    /// Returns [`SchemaSerdeError::RegistryStatus`] for a non-2xx response (for example when
    /// the registry does not know the schema), and [`SchemaSerdeError::Registry`] if the
    /// request fails or the response body cannot be read or decoded.
    pub async fn lookup(
        &self,
        subject: &str,
        kind: SchemaKind,
        schema: &str,
    ) -> Result<u32, SchemaSerdeError> {
        let url = self.subject_url(subject, "");
        let body = SchemaPayload {
            schema,
            schema_type: kind.wire_name(),
            references: &[] as &[SchemaReference],
        };
        let resp: SubjectVersionResponse = self.post_json(&url, &body).await?;
        Ok(resp.id)
    }

    /// Returns the `id` of the latest version registered under `subject`.
    ///
    /// Sends `GET {base_url}/subjects/{subject}/versions/latest`; `subject` is percent-encoded.
    ///
    /// # Errors
    ///
    /// Returns [`SchemaSerdeError::RegistryStatus`] for a non-2xx response, and
    /// [`SchemaSerdeError::Registry`] if the request fails or the response body cannot be
    /// read or decoded.
    pub async fn latest_id(&self, subject: &str) -> Result<u32, SchemaSerdeError> {
        let url = self.subject_url(subject, "/versions/latest");
        let resp: SubjectVersionResponse = self.get_json(&url).await?;
        Ok(resp.id)
    }

    /// Fetches the schema string registered with `id`.
    ///
    /// Sends `GET {base_url}/schemas/ids/{id}`.
    ///
    /// # Errors
    ///
    /// Returns [`SchemaSerdeError::RegistryStatus`] for a non-2xx response, and
    /// [`SchemaSerdeError::Registry`] if the request fails or the response body cannot be
    /// read or decoded.
    pub async fn schema_by_id(&self, id: u32) -> Result<String, SchemaSerdeError> {
        // A `u32` formats as plain digits, so no escaping is needed here.
        let url = format!("{}/schemas/ids/{id}", self.base_url);
        let resp: SchemaByIdResponse = self.get_json(&url).await?;
        Ok(resp.schema)
    }

    /// POSTs `body` as JSON to `url` and decodes the response as `R`.
    ///
    /// The registry media type is sent as both `Content-Type` and `Accept`, matching
    /// the `Accept` header that `get_json` sends.
    async fn post_json<B: serde::Serialize, R: serde::de::DeserializeOwned>(
        &self,
        url: &str,
        body: &B,
    ) -> Result<R, SchemaSerdeError> {
        let resp = self
            .http
            .post(url)
            .header("Content-Type", CONTENT_TYPE)
            .header("Accept", CONTENT_TYPE)
            .json(body)
            .send()
            .await
            .map_err(|e| SchemaSerdeError::Registry(e.to_string()))?;
        Self::parse(resp).await
    }

    /// GETs `url`, advertising the registry media type, and decodes the response as `R`.
    async fn get_json<R: serde::de::DeserializeOwned>(
        &self,
        url: &str,
    ) -> Result<R, SchemaSerdeError> {
        let resp = self
            .http
            .get(url)
            .header("Accept", CONTENT_TYPE)
            .send()
            .await
            .map_err(|e| SchemaSerdeError::Registry(e.to_string()))?;
        Self::parse(resp).await
    }

    /// Reads the whole response body and decodes it as JSON into `R`.
    ///
    /// The body is read as text before the status is checked, so a non-2xx response can be
    /// returned verbatim in [`SchemaSerdeError::RegistryStatus`].
    async fn parse<R: serde::de::DeserializeOwned>(
        resp: reqwest::Response,
    ) -> Result<R, SchemaSerdeError> {
        let status = resp.status();
        let text = resp
            .text()
            .await
            .map_err(|e| SchemaSerdeError::Registry(e.to_string()))?;
        if !status.is_success() {
            return Err(SchemaSerdeError::RegistryStatus {
                status: status.as_u16(),
                body: text,
            });
        }
        serde_json::from_str(&text).map_err(|e| SchemaSerdeError::Registry(e.to_string()))
    }

    /// Builds `{base_url}/subjects/{subject}{suffix}` with `subject` percent-encoded.
    ///
    /// `suffix` is appended verbatim and must be empty or start with `/`.
    fn subject_url(&self, subject: &str, suffix: &str) -> String {
        format!("{}/subjects/{}{suffix}", self.base_url, Self::encode_path_segment(subject))
    }

    /// Percent-encodes `segment` for use as a single URL path segment.
    ///
    /// Every byte outside the RFC 3986 unreserved set (`A-Z a-z 0-9 - . _ ~`) becomes `%XX`
    /// (uppercase hex). Subjects containing `/`, `?`, `#`, `%`, spaces or non-ASCII text
    /// therefore stay inside one path segment instead of altering the URL structure.
    ///
    /// NOTE(review): a subject that is exactly `.` or `..` is left as-is. URL parsers treat it
    /// as a dot-segment even when escaped, so such names cannot be addressed reliably.
    fn encode_path_segment(segment: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut out = String::with_capacity(segment.len());
        for byte in segment.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                out.push(char::from(byte));
            } else {
                out.push('%');
                out.push(char::from(HEX[usize::from(byte >> 4)]));
                out.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
        out
    }
}
#[cfg(test)]
mod tests {
    use super::model::SchemaPayload;
    use super::*;
    use assert2::check;

    /// `new` strips every trailing `/` and leaves slash-free URLs untouched.
    #[test]
    fn base_url_trims_trailing_slash() {
        for input in [
            "http://localhost:8081",
            "http://localhost:8081/",
            "http://localhost:8081//",
        ] {
            let c = RegistryClient::new(input);
            check!(c.base_url == "http://localhost:8081");
        }
        // A path prefix on the base URL is kept; only the trailing slash goes.
        let c = RegistryClient::new("http://localhost:8081/registry/");
        check!(c.base_url == "http://localhost:8081/registry");
    }

    /// An Avro payload with no references serializes to just the schema field.
    #[test]
    fn payload_omits_avro_type_and_empty_refs() {
        let p = SchemaPayload {
            schema: "\"string\"",
            schema_type: SchemaKind::Avro.wire_name(),
            references: &[],
        };
        let j = serde_json::to_string(&p).unwrap();
        check!(j == r#"{"schema":"\"string\""}"#);
    }

    /// A Protobuf payload carries its schema type and schema text, and still omits empty refs.
    #[test]
    fn payload_includes_protobuf_type() {
        let p = SchemaPayload {
            schema: "syntax = \"proto3\";",
            schema_type: SchemaKind::Protobuf.wire_name(),
            references: &[],
        };
        let j = serde_json::to_string(&p).unwrap();
        check!(j.contains(r#""schemaType":"PROTOBUF""#));
        check!(j.contains(r#""schema":"syntax = \"proto3\";""#));
        check!(!j.contains("references"));
    }
}