use serde::{Deserialize, Serialize};
/// Structural description of a value: a scalar kind, a container, a union of
/// alternatives, or `Any`.
///
/// Variants stay in their original declaration order on purpose: the derived
/// serde impls are index-based for non-self-describing formats, so reordering
/// would change the encoded form.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub enum Schema {
    /// String value (type name `"string"`).
    String,
    /// Integer value (type name `"int"`).
    Int,
    /// Unsigned integer value (type name `"uint"`).
    UInt,
    /// Floating-point value (type name `"float"`).
    Float,
    /// Boolean value (type name `"bool"`).
    Bool,
    /// Null value; the only non-union variant that `is_nullable` reports `true` for.
    Null,
    /// Array whose element schema is boxed because `Schema` is recursive.
    Array(Box<Schema>),
    /// Object described by its field list.
    Object(ObjectSchema),
    /// Union of alternative schemas; nullable when any alternative is.
    Union(Vec<Schema>),
    /// Unconstrained schema (type name `"any"`).
    ///
    /// NOTE(review): presumably accepts any value, yet `is_nullable` reports
    /// `false` for it; confirm that is intended.
    Any,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ObjectSchema {
pub fields: Vec<Field>,
pub allow_additional_fields: bool,
}
/// A single named entry within an [`ObjectSchema`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Field {
    /// Primary name of the field; `ObjectSchema::get_field` matches against it.
    pub name: String,
    /// Schema describing the field's value.
    pub schema: Schema,
    /// `true` for fields built with `Field::required`, `false` for `Field::optional`.
    pub required: bool,
    /// Alternative names, listed after `name` by `ObjectSchema::all_field_names`.
    ///
    /// Optional in serialized input (defaults to empty), matching what the
    /// `Field` constructors produce; without this a missing `aliases` key was
    /// a hard deserialization error.
    #[serde(default)]
    pub aliases: Vec<String>,
    /// Default value for the field, if any.
    ///
    /// NOTE(review): with serde's `Option` handling, a missing key and an
    /// explicit JSON `null` both deserialize to `None`. That means
    /// `Some(serde_json::Value::Null)` does not survive a JSON round trip.
    /// Distinguishing them needs a custom `deserialize_with` plus
    /// `skip_serializing_if`, which would change the serialized form.
    #[serde(default)]
    pub default: Option<serde_json::Value>,
    /// Human-readable description, if any.
    #[serde(default)]
    pub description: Option<String>,
    /// Streaming hint; defaults to `StreamAnnotation::Normal` when absent.
    #[serde(default)]
    pub stream_annotation: StreamAnnotation,
}
/// Per-field streaming hint carried by `Field::stream_annotation`.
///
/// NOTE(review): the consumer that interprets these variants is not in this
/// file. The per-variant notes below are inferred from the names; confirm
/// them against the streaming code.
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub enum StreamAnnotation {
    /// No special treatment; this is the `Default` variant.
    #[default]
    Normal,
    /// Presumably: the field must never be surfaced as null while streaming.
    NotNull,
    /// Presumably: the field is surfaced only once it is complete.
    Done,
}
impl Schema {
pub fn object(fields: Vec<(String, Schema, bool)>) -> Self {
Schema::Object(ObjectSchema {
fields: fields
.into_iter()
.map(|(name, schema, required)| Field {
name,
schema,
required,
aliases: Vec::new(),
default: None,
description: None,
stream_annotation: StreamAnnotation::Normal,
})
.collect(),
allow_additional_fields: false,
})
}
pub fn array(element_schema: Schema) -> Self {
Schema::Array(Box::new(element_schema))
}
pub fn union(variants: Vec<Schema>) -> Self {
Schema::Union(variants)
}
pub fn is_primitive(&self) -> bool {
matches!(
self,
Schema::String
| Schema::Int
| Schema::UInt
| Schema::Float
| Schema::Bool
| Schema::Null
)
}
pub fn is_nullable(&self) -> bool {
match self {
Schema::Null => true,
Schema::Union(variants) => variants.iter().any(|v| v.is_nullable()),
_ => false,
}
}
pub fn type_name(&self) -> &'static str {
match self {
Schema::String => "string",
Schema::Int => "int",
Schema::UInt => "uint",
Schema::Float => "float",
Schema::Bool => "bool",
Schema::Null => "null",
Schema::Array(_) => "array",
Schema::Object(_) => "object",
Schema::Union(_) => "union",
Schema::Any => "any",
}
}
}
impl Field {
    /// Creates a required field with no aliases, default, or description, and
    /// `StreamAnnotation::Normal`.
    #[must_use]
    pub fn required(name: impl Into<String>, schema: Schema) -> Self {
        Self::bare(name.into(), schema, true)
    }

    /// Creates an optional field with no aliases, default, or description, and
    /// `StreamAnnotation::Normal`.
    #[must_use]
    pub fn optional(name: impl Into<String>, schema: Schema) -> Self {
        Self::bare(name.into(), schema, false)
    }

    /// Appends `alias` to the field's aliases.
    ///
    /// Existing aliases are kept and duplicates are not filtered.
    #[must_use]
    pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
        self.aliases.push(alias.into());
        self
    }

    /// Sets the default value, replacing any previous one.
    #[must_use]
    pub fn with_default(mut self, default: serde_json::Value) -> Self {
        self.default = Some(default);
        self
    }

    /// Sets the description, replacing any previous one.
    #[must_use]
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Sets the streaming annotation, replacing the current one.
    #[must_use]
    pub fn with_stream_annotation(mut self, annotation: StreamAnnotation) -> Self {
        self.stream_annotation = annotation;
        self
    }

    /// Shared constructor for `required`/`optional`: sets only name, schema,
    /// and requiredness, leaving every optional attribute at its empty value.
    fn bare(name: String, schema: Schema, required: bool) -> Self {
        Field {
            name,
            schema,
            required,
            aliases: Vec::new(),
            default: None,
            description: None,
            stream_annotation: StreamAnnotation::Normal,
        }
    }
}
impl ObjectSchema {
    /// Creates an object schema over `fields` that disallows additional fields.
    #[must_use]
    pub fn new(fields: Vec<Field>) -> Self {
        ObjectSchema {
            fields,
            allow_additional_fields: false,
        }
    }

    /// Builder-style setter that turns `allow_additional_fields` on.
    #[must_use]
    pub fn allow_additional(mut self) -> Self {
        self.allow_additional_fields = true;
        self
    }

    /// Returns the first field (in declaration order) whose `name` equals `name`.
    ///
    /// Aliases are not consulted; see `get_field_by_name_or_alias` for that.
    pub fn get_field(&self, name: &str) -> Option<&Field> {
        self.fields.iter().find(|f| f.name == name)
    }

    /// Returns the field known by `key`, as either its primary name or an alias.
    ///
    /// Primary names take precedence. All fields are checked by `name` first,
    /// and aliases are consulted only when no field has that name. Within each
    /// pass, the first match in declaration order wins.
    pub fn get_field_by_name_or_alias(&self, key: &str) -> Option<&Field> {
        self.get_field(key).or_else(|| {
            self.fields
                .iter()
                .find(|f| f.aliases.iter().any(|alias| alias == key))
        })
    }

    /// Returns every accepted name: each field's `name` followed by its aliases,
    /// in declaration order. Duplicates are not removed.
    pub fn all_field_names(&self) -> Vec<String> {
        let total: usize = self.fields.iter().map(|f| 1 + f.aliases.len()).sum();
        let mut names = Vec::with_capacity(total);
        for field in &self.fields {
            names.push(field.name.clone());
            // Clone alias elements straight into `names` rather than cloning the
            // whole alias Vec first and then moving its contents.
            names.extend(field.aliases.iter().cloned());
        }
        names
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_schemas() {
        assert!(Schema::String.is_primitive());
        assert!(Schema::Int.is_primitive());
        assert!(!Schema::array(Schema::String).is_primitive());
    }

    #[test]
    fn test_is_primitive_all_variants() {
        let primitives = vec![
            Schema::String,
            Schema::Int,
            Schema::UInt,
            Schema::Float,
            Schema::Bool,
            Schema::Null,
        ];
        for schema in &primitives {
            assert!(schema.is_primitive(), "{:?} should be primitive", schema);
        }
        let composites = vec![
            Schema::Any,
            Schema::array(Schema::Int),
            Schema::object(vec![]),
            Schema::union(vec![Schema::Int]),
        ];
        for schema in &composites {
            assert!(!schema.is_primitive(), "{:?} should not be primitive", schema);
        }
    }

    #[test]
    fn test_object_schema_creation() {
        let schema = Schema::object(vec![
            ("name".into(), Schema::String, true),
            ("age".into(), Schema::Int, false),
        ]);
        let obj = match schema {
            Schema::Object(obj) => obj,
            other => panic!("Expected Object schema, got {:?}", other),
        };
        assert_eq!(obj.fields.len(), 2);
        assert_eq!(obj.fields[0].name, "name");
        assert!(obj.fields[0].required);
        assert_eq!(obj.fields[1].name, "age");
        assert!(!obj.fields[1].required);
        assert!(!obj.allow_additional_fields);
        assert!(obj.fields.iter().all(|f| f.aliases.is_empty()
            && f.default.is_none()
            && f.description.is_none()
            && f.stream_annotation == StreamAnnotation::Normal));
    }

    #[test]
    fn test_field_builder() {
        let field = Field::required("username", Schema::String)
            .with_alias("user_name")
            .with_description("The user's login name");
        assert_eq!(field.name, "username");
        assert!(field.required);
        assert_eq!(field.aliases, vec!["user_name"]);
        assert!(field.description.is_some());
    }

    #[test]
    fn test_optional_field_builder() {
        let field = Field::optional("score", Schema::Float)
            .with_default(serde_json::json!(0.5))
            .with_stream_annotation(StreamAnnotation::Done);
        assert_eq!(field.name, "score");
        assert!(!field.required);
        assert!(field.aliases.is_empty());
        assert_eq!(field.default, Some(serde_json::json!(0.5)));
        assert_eq!(field.description, None);
        assert_eq!(field.stream_annotation, StreamAnnotation::Done);
    }

    #[test]
    fn test_object_schema_helpers() {
        let obj = ObjectSchema::new(vec![
            Field::required("id", Schema::UInt)
                .with_alias("ID")
                .with_alias("key"),
            Field::optional("tags", Schema::array(Schema::String)),
        ]);
        assert!(!obj.allow_additional_fields);
        assert_eq!(obj.get_field("id").map(|f| f.required), Some(true));
        assert_eq!(obj.get_field("tags").map(|f| f.required), Some(false));
        assert!(obj.get_field("missing").is_none());
        assert_eq!(obj.all_field_names(), vec!["id", "ID", "key", "tags"]);
        assert!(obj.allow_additional().allow_additional_fields);
        assert!(ObjectSchema::new(vec![]).all_field_names().is_empty());
    }

    #[test]
    fn test_nullable_schema() {
        assert!(Schema::Null.is_nullable());
        assert!(!Schema::String.is_nullable());
        assert!(Schema::union(vec![Schema::String, Schema::Null]).is_nullable());
    }

    #[test]
    fn test_nested_union_nullability() {
        let nested = Schema::union(vec![Schema::Int, Schema::union(vec![Schema::Null])]);
        assert!(nested.is_nullable());
        assert!(!Schema::union(vec![]).is_nullable());
        assert!(!Schema::array(Schema::Null).is_nullable());
    }

    #[test]
    fn test_type_names() {
        assert_eq!(Schema::String.type_name(), "string");
        assert_eq!(Schema::Int.type_name(), "int");
        assert_eq!(Schema::array(Schema::Bool).type_name(), "array");
        assert_eq!(
            Schema::object(vec![("x".into(), Schema::Float, true)]).type_name(),
            "object"
        );
        let remaining = vec![
            (Schema::UInt, "uint"),
            (Schema::Float, "float"),
            (Schema::Bool, "bool"),
            (Schema::Null, "null"),
            (Schema::union(vec![]), "union"),
            (Schema::Any, "any"),
        ];
        for (schema, expected) in &remaining {
            assert_eq!(schema.type_name(), *expected);
        }
    }

    #[test]
    fn test_field_serde_round_trip() {
        let field = Field::required("count", Schema::Int)
            .with_alias("n")
            .with_default(serde_json::json!(1))
            .with_description("How many")
            .with_stream_annotation(StreamAnnotation::NotNull);
        let encoded = serde_json::to_string(&field).expect("Field serializes to JSON");
        let decoded: Field = serde_json::from_str(&encoded).expect("Field deserializes from JSON");
        assert_eq!(decoded, field);
    }

    #[test]
    fn test_missing_stream_annotation_defaults_to_normal() {
        let field: Field = serde_json::from_value(serde_json::json!({
            "name": "flag",
            "schema": "Bool",
            "required": true,
            "aliases": [],
            "default": null,
            "description": null
        }))
        .expect("stream_annotation is optional in serialized input");
        assert_eq!(field.schema, Schema::Bool);
        assert_eq!(field.stream_annotation, StreamAnnotation::Normal);
    }
}