use super::dictionary::StringDictionary;
use super::varint::{decode_varint, encode_varint};
use crate::error::{InterpretError, TauqError};
/// Type tag attached to every schema field.
///
/// The explicit discriminants are the byte values written by
/// `Schema::encode` and read back through `SchemaType::from_u8`, so they
/// must stay stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SchemaType {
    /// Boolean value.
    Bool = 1,
    /// Integer; JSON inference picks this for numbers that fit in `i64`.
    Int = 2,
    /// Integer; JSON inference picks this for numbers that only fit in `u64`.
    UInt = 3,
    /// 32-bit float tag (never produced by JSON inference in this file).
    F32 = 4,
    /// 64-bit float tag; JSON inference falls back to this for other numbers.
    F64 = 5,
    /// Text value.
    String = 6,
    /// Byte-string tag.
    Bytes = 7,
    /// Optional value; encoded with one extra inner-type byte.
    Option = 8,
    /// Sequence; encoded with one extra inner-type byte.
    Seq = 9,
    /// Map tag; JSON inference picks this for nested objects.
    Map = 10,
    /// Reference to another schema; encoded with a trailing varint index.
    SchemaRef = 11,
}
impl SchemaType {
    /// Maps a raw tag byte back to its variant.
    ///
    /// Returns `None` for any byte that is not the discriminant of a
    /// variant (for example `0` or anything above `11`).
    pub fn from_u8(v: u8) -> Option<Self> {
        // Every variant exactly once; lookup is by discriminant, so the
        // order of this table does not matter.
        const VARIANTS: [SchemaType; 11] = [
            SchemaType::Bool,
            SchemaType::Int,
            SchemaType::UInt,
            SchemaType::F32,
            SchemaType::F64,
            SchemaType::String,
            SchemaType::Bytes,
            SchemaType::Option,
            SchemaType::Seq,
            SchemaType::Map,
            SchemaType::SchemaRef,
        ];
        VARIANTS
            .iter()
            .copied()
            .find(|candidate| *candidate as u8 == v)
    }
}
/// One named, typed field of a [`Schema`].
#[derive(Debug, Clone)]
pub struct SchemaField {
    /// Field name; interned into the string dictionary when encoded.
    pub name: String,
    /// Type tag of the field.
    pub typ: SchemaType,
    /// Element type, written to the encoding only when `typ` is
    /// `SchemaType::Option` or `SchemaType::Seq`.
    pub inner_type: Option<SchemaType>,
    /// Index of the referenced schema, written to the encoding only when
    /// `typ` is `SchemaType::SchemaRef`.
    pub schema_ref: Option<u32>,
}
impl SchemaField {
    /// Creates a field of type `typ` with no inner type and no schema
    /// reference.
    pub fn new(name: impl Into<String>, typ: SchemaType) -> Self {
        let name = name.into();
        Self {
            name,
            typ,
            inner_type: None,
            schema_ref: None,
        }
    }

    /// Creates a `SchemaType::Option` field whose wrapped type is `inner`.
    pub fn optional(name: impl Into<String>, inner: SchemaType) -> Self {
        Self {
            inner_type: Some(inner),
            ..Self::new(name, SchemaType::Option)
        }
    }

    /// Creates a `SchemaType::Seq` field whose element type is `inner`.
    pub fn seq(name: impl Into<String>, inner: SchemaType) -> Self {
        Self {
            inner_type: Some(inner),
            ..Self::new(name, SchemaType::Seq)
        }
    }

    /// Creates a `SchemaType::SchemaRef` field pointing at schema number
    /// `schema_idx` (presumably an index into a `SchemaRegistry`; the
    /// constructor itself does not check it).
    pub fn schema_ref(name: impl Into<String>, schema_idx: u32) -> Self {
        Self {
            schema_ref: Some(schema_idx),
            ..Self::new(name, SchemaType::SchemaRef)
        }
    }
}
/// A named, ordered list of fields.
#[derive(Debug, Clone)]
pub struct Schema {
    /// Schema name; interned into the string dictionary when encoded.
    pub name: String,
    /// Fields in declaration order; the same order is used on the wire.
    pub fields: Vec<SchemaField>,
}
impl Schema {
    /// Creates a schema called `name` with no fields.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            fields: Vec::new(),
        }
    }

    /// Appends `field` and hands the schema back, so calls can be chained.
    pub fn field(mut self, field: SchemaField) -> Self {
        self.fields.push(field);
        self
    }

    /// Appends a plain field (no inner type, no schema reference) in place.
    pub fn add_field(&mut self, name: impl Into<String>, typ: SchemaType) {
        self.fields.push(SchemaField::new(name, typ));
    }

    /// Appends the binary form of this schema to `buf`, interning the schema
    /// name and every field name in `dict`.
    ///
    /// Layout: varint name index, varint field count, then for each field a
    /// varint name index and one type-tag byte. `Option`/`Seq` fields are
    /// followed by one inner-type byte; `SchemaRef` fields by a varint
    /// schema index.
    pub fn encode(&self, buf: &mut Vec<u8>, dict: &mut StringDictionary) {
        let name_idx = dict.intern(&self.name);
        encode_varint(name_idx as u64, buf);
        encode_varint(self.fields.len() as u64, buf);

        for field in &self.fields {
            let field_idx = dict.intern(&field.name);
            encode_varint(field_idx as u64, buf);
            buf.push(field.typ as u8);

            if matches!(field.typ, SchemaType::Option | SchemaType::Seq) {
                // A missing inner type is written as `Int` so the byte that
                // `decode` expects is always present.
                buf.push(field.inner_type.unwrap_or(SchemaType::Int) as u8);
            }
            if field.typ == SchemaType::SchemaRef {
                // A missing reference is written as index 0.
                encode_varint(u64::from(field.schema_ref.unwrap_or(0)), buf);
            }
        }
    }

    /// Parses one schema from the start of `bytes`, resolving name indices
    /// through `dict`.
    ///
    /// Returns the schema and the number of bytes consumed.
    ///
    /// # Errors
    ///
    /// Fails when a varint cannot be decoded, when a name index does not
    /// resolve in `dict` (including indices that do not fit in `u32`), when
    /// the field count exceeds 10000, when a type tag is unknown, when a
    /// schema reference does not fit in `u32`, or when the input ends before
    /// a type or inner-type byte. Truncated input is reported as an error
    /// rather than a panic.
    pub fn decode(
        bytes: &[u8],
        dict: &super::dictionary::BorrowedDictionary,
    ) -> Result<(Self, usize), TauqError> {
        // Wraps a fixed message in the error type used for malformed input.
        fn malformed(msg: &'static str) -> TauqError {
            TauqError::Interpret(InterpretError::new(msg))
        }

        // Unread remainder of the input; empty when `pos` is at or past the
        // end, so slicing can never panic.
        fn tail(bytes: &[u8], pos: usize) -> &[u8] {
            bytes.get(pos..).unwrap_or_default()
        }

        // Checked narrowing to `u32`: an oversized varint must be rejected,
        // not wrapped onto some unrelated valid index. Written with a
        // fully-qualified path so it does not rely on the edition prelude.
        fn narrow<T>(raw: T) -> Option<u32>
        where
            u32: std::convert::TryFrom<T>,
        {
            <u32 as std::convert::TryFrom<T>>::try_from(raw).ok()
        }

        let mut pos = 0usize;

        let (name_idx, len) = decode_varint(bytes)?;
        pos += len;
        let name = narrow(name_idx)
            .and_then(|idx| dict.get(idx))
            .ok_or_else(|| malformed("Invalid schema name index"))?
            .to_string();

        let (field_count, len) = decode_varint(tail(bytes, pos))?;
        pos += len;
        // Bound the count before allocating: it comes straight from the input.
        if field_count > 10_000 {
            return Err(TauqError::Interpret(InterpretError::new(format!(
                "Schema field count {} exceeds maximum 10000",
                field_count
            ))));
        }

        let mut fields = Vec::with_capacity(field_count as usize);
        for _ in 0..field_count {
            let (field_idx, len) = decode_varint(tail(bytes, pos))?;
            pos += len;
            let field_name = narrow(field_idx)
                .and_then(|idx| dict.get(idx))
                .ok_or_else(|| malformed("Invalid field name index"))?
                .to_string();

            let type_byte = *bytes
                .get(pos)
                .ok_or_else(|| malformed("Unexpected end of schema data"))?;
            let typ = SchemaType::from_u8(type_byte)
                .ok_or_else(|| malformed("Invalid schema type"))?;
            pos += 1;

            let mut field = SchemaField::new(field_name, typ);
            if matches!(typ, SchemaType::Option | SchemaType::Seq) {
                let inner_byte = *bytes
                    .get(pos)
                    .ok_or_else(|| malformed("Unexpected end of schema data"))?;
                // NOTE(review): an unrecognised inner-type byte is kept as
                // `None` rather than rejected, matching the previous
                // behaviour; confirm whether it should be an error.
                field.inner_type = SchemaType::from_u8(inner_byte);
                pos += 1;
            }
            if typ == SchemaType::SchemaRef {
                let (ref_idx, len) = decode_varint(tail(bytes, pos))?;
                pos += len;
                let ref_idx = narrow(ref_idx)
                    .ok_or_else(|| malformed("Schema reference index out of range"))?;
                field.schema_ref = Some(ref_idx);
            }
            fields.push(field);
        }

        Ok((Schema { name, fields }, pos))
    }
}
/// Ordered collection of schemas, addressed by insertion index.
#[derive(Debug, Default)]
pub struct SchemaRegistry {
    /// Registered schemas; a schema's position is the index handed out by
    /// `register`.
    schemas: Vec<Schema>,
}
impl SchemaRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `schema` at the end of the registry and returns its index.
    pub fn register(&mut self, schema: Schema) -> u32 {
        // The index of the new entry is the length before the push.
        let assigned = self.schemas.len() as u32;
        self.schemas.push(schema);
        assigned
    }

    /// Looks up the schema stored at `idx`, if there is one.
    pub fn get(&self, idx: u32) -> Option<&Schema> {
        let slot = idx as usize;
        self.schemas.get(slot)
    }

    /// Number of registered schemas.
    pub fn len(&self) -> usize {
        self.schemas.len()
    }

    /// `true` when no schema has been registered.
    pub fn is_empty(&self) -> bool {
        self.schemas.is_empty()
    }

    /// Appends the registry to `buf`: a varint schema count followed by each
    /// schema's encoding, in index order. Names are interned in `dict`.
    pub fn encode(&self, buf: &mut Vec<u8>, dict: &mut StringDictionary) {
        let total = self.schemas.len() as u64;
        encode_varint(total, buf);
        for entry in self.schemas.iter() {
            entry.encode(buf, dict);
        }
    }

    /// Parses a registry from the start of `bytes`.
    ///
    /// Reads a varint schema count, then that many schemas back to back.
    /// Returns the registry together with the number of bytes consumed; any
    /// error from the varint or schema decoders is passed through.
    pub fn decode(
        bytes: &[u8],
        dict: &super::dictionary::BorrowedDictionary,
    ) -> Result<(Self, usize), TauqError> {
        // The bytes used by the count are the starting offset for the schemas.
        let (count, mut offset) = decode_varint(bytes)?;
        let mut registry = Self::new();
        for _ in 0..count {
            let (schema, consumed) = Schema::decode(&bytes[offset..], dict)?;
            offset += consumed;
            registry.register(schema);
        }
        Ok((registry, offset))
    }
}
pub fn infer_schema_from_json(value: &serde_json::Value, name: &str) -> Option<Schema> {
match value {
serde_json::Value::Array(arr) => {
if let Some(serde_json::Value::Object(first)) = arr.first() {
let first_keys: Vec<&String> = first.keys().collect();
let all_same = arr.iter().all(|item| {
if let serde_json::Value::Object(obj) = item {
let keys: Vec<&String> = obj.keys().collect();
keys == first_keys
} else {
false
}
});
if all_same {
let mut schema = Schema::new(name);
for (key, value) in first {
let typ = json_value_to_schema_type(value);
schema.add_field(key, typ);
}
return Some(schema);
}
}
None
}
serde_json::Value::Object(obj) => {
let mut schema = Schema::new(name);
for (key, value) in obj {
let typ = json_value_to_schema_type(value);
schema.add_field(key, typ);
}
Some(schema)
}
_ => None,
}
}
fn json_value_to_schema_type(value: &serde_json::Value) -> SchemaType {
match value {
serde_json::Value::Null => SchemaType::Option,
serde_json::Value::Bool(_) => SchemaType::Bool,
serde_json::Value::Number(n) => {
if n.is_i64() {
SchemaType::Int
} else if n.is_u64() {
SchemaType::UInt
} else {
SchemaType::F64
}
}
serde_json::Value::String(_) => SchemaType::String,
serde_json::Value::Array(_) => SchemaType::Seq,
serde_json::Value::Object(_) => SchemaType::Map,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder API keeps fields in insertion order and `optional`
    /// produces an `Option`-typed field.
    #[test]
    fn test_schema_creation() {
        let employee = Schema::new("Employee")
            .field(SchemaField::new("id", SchemaType::UInt))
            .field(SchemaField::new("name", SchemaType::String))
            .field(SchemaField::new("age", SchemaType::UInt))
            .field(SchemaField::optional("email", SchemaType::String));

        assert_eq!(employee.name, "Employee");
        assert_eq!(employee.fields.len(), 4);

        let first = &employee.fields[0];
        let last = &employee.fields[3];
        assert_eq!(first.name, "id");
        assert_eq!(last.typ, SchemaType::Option);
    }

    /// Encoding a schema and decoding it through a re-read dictionary
    /// restores the schema name and the field names.
    #[test]
    fn test_schema_roundtrip() {
        let original = Schema::new("User")
            .field(SchemaField::new("id", SchemaType::UInt))
            .field(SchemaField::new("name", SchemaType::String));

        // Encode the schema; this also fills the dictionary with its names.
        let mut dict = StringDictionary::new();
        let mut schema_bytes = Vec::new();
        original.encode(&mut schema_bytes, &mut dict);

        // Serialize the dictionary and read it back in its borrowed form.
        let mut dict_bytes = Vec::new();
        dict.encode(&mut dict_bytes);
        let (borrowed_dict, _) =
            super::super::dictionary::BorrowedDictionary::decode(&dict_bytes).unwrap();

        let (decoded, _) = Schema::decode(&schema_bytes, &borrowed_dict).unwrap();
        assert_eq!(decoded.name, "User");

        let decoded_names: Vec<&str> = decoded.fields.iter().map(|f| f.name.as_str()).collect();
        assert_eq!(decoded_names, ["id", "name"]);
    }

    /// An array of objects sharing the same keys yields one field per key.
    #[test]
    fn test_infer_schema_from_json() {
        let rows = serde_json::json!([
            {"id": 1, "name": "Alice", "active": true},
            {"id": 2, "name": "Bob", "active": false},
        ]);

        let inferred = infer_schema_from_json(&rows, "Users").unwrap();
        assert_eq!(inferred.name, "Users");
        assert_eq!(inferred.fields.len(), 3);
    }
}