use super::SchemaId;
use super::enum_support::{EnumInfo, EnumVariantInfo};
use super::field_types::{FieldDef, FieldType, semantic_to_field_type};
use arrow_schema::{DataType, Schema as ArrowSchema};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
/// Allocates a fresh [`SchemaId`] by calling `allocate_id()` on the registry
/// returned from `super::current_registry()`.
///
/// Used by the `TypeSchema` constructors that do not take an explicit id.
#[inline]
fn allocate_current_id() -> SchemaId {
super::current_registry().allocate_id()
}
/// Describes the layout of a named type: its fields, their byte offsets and
/// the total data size, plus optional enum metadata.
///
/// Instances are built through the constructors on `impl TypeSchema`
/// (`new`, `with_id`, `new_enum`, `new_enum_with_id`, `from_canonical`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TypeSchema {
/// Identifier of this schema; either supplied by the caller or allocated
/// from the current registry.
pub id: SchemaId,
/// Name of the described type.
pub name: String,
/// Field definitions, stored in declaration order.
pub fields: Vec<FieldDef>,
/// Maps a field name to its position in `fields`.
pub(crate) field_map: HashMap<String, usize>,
/// Total size of the type's data in bytes, as computed by the constructor.
pub data_size: usize,
/// NOTE(review): every constructor in this file sets this to `None`;
/// presumably the names of component types for composed schemas - confirm
/// against the code that populates it.
pub component_types: Option<Vec<String>>,
/// NOTE(review): left empty by every constructor in this file; presumably
/// maps a field name to the type it originated from - confirm.
pub(crate) field_sources: HashMap<String, String>,
/// Variant metadata; `Some` only for schemas built by the enum constructors.
pub enum_info: Option<EnumInfo>,
/// Cached result of `compute_content_hash`, filled lazily by
/// `content_hash()`. Skipped during (de)serialization, so it is `None`
/// after deserializing.
#[serde(skip)]
pub content_hash: Option<[u8; 32]>,
}
impl TypeSchema {
    /// Returns the native kind of the field at position `idx`.
    ///
    /// Yields `None` when `idx` is out of range or when the field type's
    /// `to_native_kind()` conversion fails.
    pub fn field_kind(&self, idx: usize) -> Option<shape_value::NativeKind> {
        let field = self.fields.get(idx)?;
        field.field_type.to_native_kind().ok()
    }

    /// Creates a schema with a freshly allocated id.
    ///
    /// See [`TypeSchema::with_id`] for how the layout is computed.
    pub fn new(name: impl Into<String>, field_defs: Vec<(String, FieldType)>) -> Self {
        Self::with_id(allocate_current_id(), name, field_defs)
    }

    /// Creates a schema with an explicit id, laying the fields out in order.
    ///
    /// Each field is placed at the next offset that satisfies its type's
    /// alignment; the final `data_size` is rounded up to a multiple of 8.
    /// If two definitions share a name, the name lookup resolves to the later
    /// one (both remain in `fields`).
    pub fn with_id(
        id: SchemaId,
        name: impl Into<String>,
        field_defs: Vec<(String, FieldType)>,
    ) -> Self {
        let name = name.into();
        let mut fields = Vec::with_capacity(field_defs.len());
        let mut field_map = HashMap::with_capacity(field_defs.len());
        let mut offset = 0;
        for (index, (field_name, field_type)) in field_defs.into_iter().enumerate() {
            // Round `offset` up to the field's alignment.
            // NOTE(review): the mask trick assumes `alignment()` is a non-zero
            // power of two - confirm against `FieldType::alignment`.
            let alignment = field_type.alignment();
            offset = (offset + alignment - 1) & !(alignment - 1);
            // Read the size before the type is moved into the definition so
            // no clone of the field type is needed.
            let size = field_type.size();
            fields.push(FieldDef::new(&field_name, field_type, offset, index as u16));
            field_map.insert(field_name, index);
            offset += size;
        }
        // Pad the total size to an 8-byte boundary.
        let data_size = (offset + 7) & !7;
        Self {
            id,
            name,
            fields,
            field_map,
            data_size,
            component_types: None,
            field_sources: HashMap::new(),
            enum_info: None,
            content_hash: None,
        }
    }

    /// Looks up a field definition by name.
    ///
    /// Returns `None` when the name is unknown, and also when the name map
    /// points at a position that no longer exists in `fields` (for example
    /// after `fields` was edited directly or after deserializing inconsistent
    /// data) instead of panicking.
    pub fn get_field(&self, name: &str) -> Option<&FieldDef> {
        self.field_map
            .get(name)
            .and_then(|&idx| self.fields.get(idx))
    }

    /// Returns the byte offset of the named field, if it exists.
    pub fn field_offset(&self, name: &str) -> Option<usize> {
        self.get_field(name).map(|f| f.offset)
    }

    /// Returns the index stored on the named field, if it exists.
    pub fn field_index(&self, name: &str) -> Option<u16> {
        self.get_field(name).map(|f| f.index)
    }

    /// Returns the field at position `index` in `fields`, if in range.
    pub fn field_by_index(&self, index: u16) -> Option<&FieldDef> {
        self.fields.get(usize::from(index))
    }

    /// Returns the number of field definitions.
    pub fn field_count(&self) -> usize {
        self.fields.len()
    }

    /// Reports whether a field with this name is registered in the name map.
    pub fn has_field(&self, name: &str) -> bool {
        self.field_map.contains_key(name)
    }

    /// Iterates over the field names in declaration order.
    pub fn field_names(&self) -> impl Iterator<Item = &str> {
        self.fields.iter().map(|f| f.name.as_str())
    }

    /// Reports whether this schema carries enum metadata.
    pub fn is_enum(&self) -> bool {
        self.enum_info.is_some()
    }

    /// Returns the enum metadata, if this schema describes an enum.
    pub fn get_enum_info(&self) -> Option<&EnumInfo> {
        self.enum_info.as_ref()
    }

    /// Returns the id of the named variant.
    ///
    /// `None` when this is not an enum schema or the variant is unknown.
    pub fn variant_id(&self, variant_name: &str) -> Option<u16> {
        self.enum_info.as_ref()?.variant_id(variant_name)
    }

    /// Creates an enum schema with a freshly allocated id.
    ///
    /// See [`TypeSchema::new_enum_with_id`] for the resulting layout.
    pub fn new_enum(name: impl Into<String>, variants: Vec<EnumVariantInfo>) -> Self {
        Self::new_enum_with_id(allocate_current_id(), name, variants)
    }

    /// Creates an enum schema with an explicit id.
    ///
    /// Layout: an `I64` field `__variant` at offset 0, followed by one `Any`
    /// field `__payload_<i>` per payload slot at offset `8 + 8 * i`. The
    /// number of slots is the largest payload count over all variants, so
    /// `data_size` is `8 + 8 * max_payload`.
    pub fn new_enum_with_id(
        id: SchemaId,
        name: impl Into<String>,
        variants: Vec<EnumVariantInfo>,
    ) -> Self {
        let name = name.into();
        let enum_info = EnumInfo::new(variants);
        let max_payload = enum_info.max_payload_fields();
        // One slot for the discriminant plus one per payload value.
        let slot_count = 1 + max_payload as usize;
        let mut fields = Vec::with_capacity(slot_count);
        let mut field_map = HashMap::with_capacity(slot_count);
        fields.push(FieldDef::new("__variant", FieldType::I64, 0, 0));
        field_map.insert("__variant".to_string(), 0);
        for i in 0..max_payload {
            let field_name = format!("__payload_{}", i);
            let offset = 8 + (i as usize * 8);
            fields.push(FieldDef::new(&field_name, FieldType::Any, offset, i + 1));
            field_map.insert(field_name, i as usize + 1);
        }
        let data_size = 8 + (max_payload as usize * 8);
        Self {
            id,
            name,
            fields,
            field_map,
            data_size,
            component_types: None,
            field_sources: HashMap::new(),
            enum_info: Some(enum_info),
            content_hash: None,
        }
    }

    /// Computes a SHA-256 digest over the schema's structural content.
    ///
    /// The digest covers the type name, every field's name and the `Display`
    /// form of its type (fields sorted by name), and - for enum schemas -
    /// every variant's name and payload field count (variants sorted by
    /// name). The schema id, offsets and sizes are not part of the digest.
    pub fn compute_content_hash(&self) -> [u8; 32] {
        let mut hasher = Sha256::new();
        hasher.update(b"name:");
        hasher.update(self.name.as_bytes());

        // Sort by name so declaration order does not influence the digest.
        let mut sorted_fields: Vec<&FieldDef> = self.fields.iter().collect();
        sorted_fields.sort_by(|a, b| a.name.cmp(&b.name));
        hasher.update(b"|fields:");
        for field in &sorted_fields {
            hasher.update(b"(");
            hasher.update(field.name.as_bytes());
            hasher.update(b":");
            hasher.update(field.field_type.to_string().as_bytes());
            hasher.update(b")");
        }

        if let Some(enum_info) = &self.enum_info {
            let mut sorted_variants: Vec<&EnumVariantInfo> = enum_info.variants.iter().collect();
            sorted_variants.sort_by(|a, b| a.name.cmp(&b.name));
            hasher.update(b"|variants:");
            for variant in &sorted_variants {
                hasher.update(b"(");
                hasher.update(variant.name.as_bytes());
                hasher.update(b":");
                hasher.update(variant.payload_fields.to_string().as_bytes());
                hasher.update(b")");
            }
        }

        let digest = hasher.finalize();
        let mut hash = [0u8; 32];
        hash.copy_from_slice(&digest);
        hash
    }

    /// Returns the content hash, computing and caching it on first use.
    ///
    /// The cached value is not invalidated when the schema is modified
    /// afterwards; call [`TypeSchema::compute_content_hash`] for a fresh
    /// digest.
    pub fn content_hash(&mut self) -> [u8; 32] {
        match self.content_hash {
            Some(cached) => cached,
            None => {
                let fresh = self.compute_content_hash();
                self.content_hash = Some(fresh);
                fresh
            }
        }
    }

    /// Resolves every schema field to a column of `arrow_schema`.
    ///
    /// The result holds one column index per field, in field order. Fields
    /// whose name starts with `__` are not looked up; they get the
    /// placeholder column index `0`.
    ///
    /// # Errors
    ///
    /// * [`TypeBindingError::MissingColumn`] when a field's wire name is not
    ///   a column of `arrow_schema`.
    /// * [`TypeBindingError::TypeMismatch`] when the column's data type is
    ///   not accepted for the field's type.
    pub fn bind_to_arrow_schema(
        &self,
        arrow_schema: &ArrowSchema,
    ) -> Result<TypeBinding, TypeBindingError> {
        let mut field_to_column = Vec::with_capacity(self.fields.len());
        for field in &self.fields {
            if field.name.starts_with("__") {
                // Internal field: keep the positions aligned with `fields`
                // by recording a placeholder instead of a real column.
                field_to_column.push(0);
                continue;
            }
            let col_name = field.wire_name();
            let col_idx =
                arrow_schema
                    .index_of(col_name)
                    .map_err(|_| TypeBindingError::MissingColumn {
                        field_name: col_name.to_string(),
                        type_name: self.name.clone(),
                    })?;
            let arrow_field = &arrow_schema.fields()[col_idx];
            if !is_compatible(&field.field_type, arrow_field.data_type()) {
                return Err(TypeBindingError::TypeMismatch {
                    field_name: field.name.clone(),
                    expected: format!("{:?}", field.field_type),
                    actual: format!("{:?}", arrow_field.data_type()),
                });
            }
            field_to_column.push(col_idx);
        }
        Ok(TypeBinding {
            schema_name: self.name.clone(),
            field_to_column,
        })
    }

    /// Builds a schema from a canonical type description.
    ///
    /// A fresh id is allocated. Field offsets and the total data size are
    /// copied from `canonical` rather than recomputed; each field type is
    /// derived with `semantic_to_field_type` from the canonical field's type
    /// and its `optional` flag.
    pub fn from_canonical(canonical: &crate::type_system::environment::CanonicalType) -> Self {
        let id = allocate_current_id();
        let name = canonical.name.clone();
        let mut fields = Vec::with_capacity(canonical.fields.len());
        let mut field_map = HashMap::with_capacity(canonical.fields.len());
        for (index, cf) in canonical.fields.iter().enumerate() {
            let field_type = semantic_to_field_type(&cf.field_type, cf.optional);
            fields.push(FieldDef::new(&cf.name, field_type, cf.offset, index as u16));
            field_map.insert(cf.name.clone(), index);
        }
        Self {
            id,
            name,
            fields,
            field_map,
            data_size: canonical.data_size,
            component_types: None,
            field_sources: HashMap::new(),
            enum_info: None,
            content_hash: None,
        }
    }
}
/// Result of binding a type schema to an Arrow schema: for every schema
/// field, the index of the Arrow column it maps to.
#[derive(Debug, Clone)]
pub struct TypeBinding {
    /// Name of the schema this binding was produced for.
    pub schema_name: String,
    /// Column index per schema field, indexed by field position.
    pub field_to_column: Vec<usize>,
}

impl TypeBinding {
    /// Returns the column index recorded for the field at `field_index`,
    /// or `None` when `field_index` is out of range.
    pub fn column_index(&self, field_index: usize) -> Option<usize> {
        let column = *self.field_to_column.get(field_index)?;
        Some(column)
    }
}
/// Errors reported by `TypeSchema::bind_to_arrow_schema`.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum TypeBindingError {
/// A field's wire name was not found among the Arrow schema's columns.
#[error("Type '{type_name}' requires column '{field_name}' which is not in the DataTable")]
MissingColumn {
/// Wire (column) name that was looked up.
field_name: String,
/// Name of the schema being bound.
type_name: String,
},
/// The column exists, but its Arrow data type is not accepted for the
/// field's declared type.
#[error("Column '{field_name}' has type {actual} but expected {expected}")]
TypeMismatch {
/// Name of the schema field whose type did not match.
field_name: String,
/// `Debug` rendering of the field's declared type.
expected: String,
/// `Debug` rendering of the Arrow column's data type.
actual: String,
},
}
/// Decides whether an Arrow column of type `arrow_type` may back a schema
/// field of type `field_type`.
///
/// Accepted pairings:
/// * `F64` - `Float64`, `Float32`, `Int64`
/// * `I64` - `Int64`, `Int32`
/// * `Bool` - `Boolean`
/// * `String` - `Utf8`, `LargeUtf8`
/// * `Timestamp` - any `Timestamp(..)`, `Int64`
/// * `Decimal` - `Float64`, `Int64`
/// * `Any` - every Arrow type
///
/// Every other combination is rejected.
fn is_compatible(field_type: &FieldType, arrow_type: &DataType) -> bool {
    match field_type {
        FieldType::Any => true,
        FieldType::F64 => matches!(
            arrow_type,
            DataType::Float64 | DataType::Float32 | DataType::Int64
        ),
        FieldType::I64 => matches!(arrow_type, DataType::Int64 | DataType::Int32),
        FieldType::Bool => matches!(arrow_type, DataType::Boolean),
        FieldType::String => matches!(arrow_type, DataType::Utf8 | DataType::LargeUtf8),
        FieldType::Timestamp => {
            matches!(arrow_type, DataType::Timestamp(_, _) | DataType::Int64)
        }
        FieldType::Decimal => matches!(arrow_type, DataType::Float64 | DataType::Int64),
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{Field, Schema as ArrowSchema, TimeUnit};

    /// Turns `(&str, FieldType)` pairs into the owned form the constructors take.
    fn defs(pairs: &[(&str, FieldType)]) -> Vec<(String, FieldType)> {
        pairs
            .iter()
            .map(|(name, ty)| ((*name).to_owned(), ty.clone()))
            .collect()
    }

    /// Name, field count and padded size of a three-field schema.
    #[test]
    fn test_type_schema_creation() {
        let field_defs = defs(&[
            ("a", FieldType::F64),
            ("b", FieldType::I64),
            ("c", FieldType::String),
        ]);
        let schema = TypeSchema::new("TestType", field_defs);
        assert_eq!(schema.name, "TestType");
        assert_eq!(schema.field_count(), 3);
        assert_eq!(schema.data_size, 24);
    }

    /// Offsets follow declaration order; unknown names yield `None`.
    #[test]
    fn test_field_offsets() {
        let field_defs = defs(&[
            ("first", FieldType::F64),
            ("second", FieldType::I64),
            ("third", FieldType::Bool),
        ]);
        let schema = TypeSchema::new("OffsetTest", field_defs);
        assert_eq!(schema.field_offset("first"), Some(0));
        assert_eq!(schema.field_offset("second"), Some(8));
        assert_eq!(schema.field_offset("third"), Some(16));
        assert_eq!(schema.field_offset("nonexistent"), None);
    }

    /// Field indices are assigned sequentially from zero.
    #[test]
    fn test_field_index() {
        let field_defs = defs(&[
            ("a", FieldType::F64),
            ("b", FieldType::F64),
            ("c", FieldType::F64),
        ]);
        let schema = TypeSchema::new("IndexTest", field_defs);
        assert_eq!(schema.field_index("a"), Some(0));
        assert_eq!(schema.field_index("b"), Some(1));
        assert_eq!(schema.field_index("c"), Some(2));
    }

    /// Every schema created through `new` receives a distinct id.
    #[test]
    fn test_unique_schema_ids() {
        let first = TypeSchema::new("Type1", vec![]);
        let second = TypeSchema::new("Type2", vec![]);
        let third = TypeSchema::new("Type3", vec![]);
        assert_ne!(first.id, second.id);
        assert_ne!(second.id, third.id);
        assert_ne!(first.id, third.id);
    }

    /// Enum schemas expose their variant metadata.
    #[test]
    fn test_enum_schema_creation() {
        let variants = vec![
            EnumVariantInfo::new("Some", 0, 1),
            EnumVariantInfo::new("None", 1, 0),
        ];
        let schema = TypeSchema::new_enum("Option", variants);
        assert_eq!(schema.name, "Option");
        assert!(schema.is_enum());

        let info = schema.get_enum_info().unwrap();
        assert_eq!(info.variants.len(), 2);
        assert_eq!(info.variant_id("Some"), Some(0));
        assert_eq!(info.variant_id("None"), Some(1));
        assert_eq!(info.max_payload_fields(), 1);
    }

    /// One payload slot: discriminant at 0, payload at 8, 16 bytes total.
    #[test]
    fn test_enum_schema_layout() {
        let variants = vec![
            EnumVariantInfo::new("Ok", 0, 1),
            EnumVariantInfo::new("Err", 1, 1),
        ];
        let schema = TypeSchema::new_enum("Result", variants);
        assert_eq!(schema.data_size, 16);
        assert_eq!(schema.field_count(), 2);
        assert_eq!(schema.field_offset("__variant"), Some(0));
        assert_eq!(schema.field_offset("__payload_0"), Some(8));
    }

    /// The slot count is driven by the variant with the most payload fields.
    #[test]
    fn test_enum_schema_multiple_payloads() {
        let variants = vec![
            EnumVariantInfo::new("Circle", 0, 1),
            EnumVariantInfo::new("Rectangle", 1, 2),
            EnumVariantInfo::new("Point", 2, 0),
        ];
        let schema = TypeSchema::new_enum("Shape", variants);
        assert_eq!(schema.data_size, 24);
        assert_eq!(schema.field_count(), 3);
        assert_eq!(schema.field_offset("__variant"), Some(0));
        assert_eq!(schema.field_offset("__payload_0"), Some(8));
        assert_eq!(schema.field_offset("__payload_1"), Some(16));
    }

    /// Variants can be found by id and by name; misses yield `None`.
    #[test]
    fn test_enum_variant_lookup() {
        let variants = vec![
            EnumVariantInfo::new("Pending", 0, 0),
            EnumVariantInfo::new("Running", 1, 1),
            EnumVariantInfo::new("Complete", 2, 1),
            EnumVariantInfo::new("Failed", 3, 1),
        ];
        let schema = TypeSchema::new_enum("Status", variants);
        let info = schema.get_enum_info().unwrap();

        let running = info.variant_by_id(1).unwrap();
        assert_eq!(running.name, "Running");
        assert_eq!(running.payload_fields, 1);

        let complete = info.variant_by_name("Complete").unwrap();
        assert_eq!(complete.id, 2);

        assert!(info.variant_by_id(99).is_none());
        assert!(info.variant_by_name("Unknown").is_none());
    }

    /// Fields bind to the matching columns even when the table has extras.
    #[test]
    fn test_bind_to_arrow_schema_success() {
        let field_defs = defs(&[
            ("open", FieldType::F64),
            ("close", FieldType::F64),
            ("volume", FieldType::I64),
        ]);
        let type_schema = TypeSchema::new("Candle", field_defs);
        let columns = vec![
            Field::new("date", DataType::Utf8, false),
            Field::new("open", DataType::Float64, false),
            Field::new("close", DataType::Float64, false),
            Field::new("volume", DataType::Int64, false),
        ];
        let arrow_schema = ArrowSchema::new(columns);

        let binding = type_schema.bind_to_arrow_schema(&arrow_schema).unwrap();
        assert_eq!(binding.schema_name, "Candle");
        assert_eq!(binding.column_index(0), Some(1));
        assert_eq!(binding.column_index(1), Some(2));
        assert_eq!(binding.column_index(2), Some(3));
    }

    /// A field without a matching column is reported as `MissingColumn`.
    #[test]
    fn test_bind_missing_column() {
        let field_defs = defs(&[
            ("open", FieldType::F64),
            ("missing_field", FieldType::F64),
        ]);
        let type_schema = TypeSchema::new("Candle", field_defs);
        let arrow_schema = ArrowSchema::new(vec![Field::new("open", DataType::Float64, false)]);

        let err = type_schema.bind_to_arrow_schema(&arrow_schema).unwrap_err();
        assert!(matches!(err, TypeBindingError::MissingColumn { .. }));
    }

    /// An incompatible column type is reported as `TypeMismatch`.
    #[test]
    fn test_bind_type_mismatch() {
        let type_schema = TypeSchema::new("Test", defs(&[("name", FieldType::F64)]));
        let arrow_schema = ArrowSchema::new(vec![Field::new("name", DataType::Utf8, false)]);

        let err = type_schema.bind_to_arrow_schema(&arrow_schema).unwrap_err();
        assert!(matches!(err, TypeBindingError::TypeMismatch { .. }));
    }

    /// Widening pairings and `Any` are all accepted.
    #[test]
    fn test_bind_compatible_types() {
        let field_defs = defs(&[
            ("f32_as_f64", FieldType::F64),
            ("i32_as_i64", FieldType::I64),
            ("ts", FieldType::Timestamp),
            ("any_field", FieldType::Any),
        ]);
        let type_schema = TypeSchema::new("Wide", field_defs);
        let columns = vec![
            Field::new("f32_as_f64", DataType::Float32, false),
            Field::new("i32_as_i64", DataType::Int32, false),
            Field::new(
                "ts",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                false,
            ),
            Field::new("any_field", DataType::Boolean, false),
        ];
        let arrow_schema = ArrowSchema::new(columns);

        let binding = type_schema.bind_to_arrow_schema(&arrow_schema);
        assert!(binding.is_ok());
    }
}