use serde::{Deserialize, Serialize};
use std::fmt;
/// Element type tag stored in the `dtype` field of a [`Schema`].
///
/// The enum is `#[non_exhaustive]`, so code outside this crate has to keep a
/// wildcard arm when matching on it. Each variant has a short textual tag
/// produced by its `Display` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum DataType {
/// Displayed as `f64`.
Float64,
/// Displayed as `f32`.
Float32,
/// Displayed as `i64`.
Int64,
/// Displayed as `bool`.
Bool,
/// Displayed as `str`.
Utf8,
/// Displayed as `bytes`; used by `Schema::bytes`.
Bytes,
/// Displayed as `json`; used by `Schema::json`.
Json,
}
impl fmt::Display for DataType {
    /// Writes the short lowercase tag of this data type (`f64`, `str`, `json`, ...).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the static tag first, then emit it with a single write.
        let tag = match self {
            Self::Float64 => "f64",
            Self::Float32 => "f32",
            Self::Int64 => "i64",
            Self::Bool => "bool",
            Self::Utf8 => "str",
            Self::Bytes => "bytes",
            Self::Json => "json",
        };
        f.write_str(tag)
    }
}
/// An element type paired with an optional shape.
///
/// Instances are usually built through the associated constructors
/// (`Schema::vector`, `Schema::matrix`, `Schema::batched`, `Schema::scalar`,
/// `Schema::json`, `Schema::bytes`, `Schema::dynamic`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Schema {
/// Element type described by this schema.
pub dtype: DataType,
/// Ordered dimensions of the value.
///
/// `None` means no shape is specified (as produced by `Schema::dynamic`,
/// `Schema::json` and `Schema::bytes`); `Some(vec![])` is the rank-0 shape
/// produced by `Schema::scalar`.
pub shape: Option<Vec<Dimension>>,
}
/// A single entry in the `shape` of a [`Schema`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Dimension {
/// Dimension with a concrete size; displayed as the number.
Fixed(usize),
/// Named dimension without a concrete size; displayed as its name.
///
/// `Schema::is_compatible_with` accepts it against any dimension in the
/// same position. `Schema::batched` uses the name `"batch"`.
Dynamic(String),
}
impl fmt::Display for Dimension {
    /// Writes a fixed dimension as its number and a dynamic dimension as its name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // The name is emitted verbatim, without quoting or decoration.
            Self::Dynamic(label) => f.write_str(label),
            Self::Fixed(extent) => write!(f, "{extent}"),
        }
    }
}
impl Schema {
    /// Builds a rank-1 schema with a single fixed dimension of `len`.
    pub fn vector(dtype: DataType, len: usize) -> Self {
        let shape = Some(vec![Dimension::Fixed(len)]);
        Self { dtype, shape }
    }

    /// Builds a rank-2 schema with fixed dimensions `rows` then `cols`.
    pub fn matrix(dtype: DataType, rows: usize, cols: usize) -> Self {
        let shape = Some(vec![Dimension::Fixed(rows), Dimension::Fixed(cols)]);
        Self { dtype, shape }
    }

    /// Builds a schema whose first dimension is the dynamic dimension named
    /// `"batch"`, followed by one fixed dimension per entry of `feature_dims`.
    pub fn batched(dtype: DataType, feature_dims: &[usize]) -> Self {
        let leading = std::iter::once(Dimension::Dynamic("batch".into()));
        let trailing = feature_dims.iter().copied().map(Dimension::Fixed);
        Self {
            dtype,
            shape: Some(leading.chain(trailing).collect()),
        }
    }

    /// Builds a rank-0 schema: the shape is present but has no dimensions.
    pub fn scalar(dtype: DataType) -> Self {
        Self {
            dtype,
            shape: Some(Vec::new()),
        }
    }

    /// Builds a [`DataType::Json`] schema with no shape.
    pub fn json() -> Self {
        Self::dynamic(DataType::Json)
    }

    /// Builds a [`DataType::Bytes`] schema with no shape.
    pub fn bytes() -> Self {
        Self::dynamic(DataType::Bytes)
    }

    /// Builds a schema of the given `dtype` with no shape (`shape` is `None`).
    pub fn dynamic(dtype: DataType) -> Self {
        Self { dtype, shape: None }
    }

    /// Reports whether `self` and `other` are compatible.
    ///
    /// The data types must be equal. If either side has no shape the schemas
    /// are compatible; otherwise the ranks must match and every pair of
    /// dimensions that are both [`Dimension::Fixed`] must hold the same size.
    /// A [`Dimension::Dynamic`] on either side matches anything.
    pub fn is_compatible_with(&self, other: &Schema) -> bool {
        self.dtype == other.dtype
            && match (self.shape.as_deref(), other.shape.as_deref()) {
                (Some(lhs), Some(rhs)) => {
                    lhs.len() == rhs.len()
                        && lhs.iter().zip(rhs).all(|(a, b)| {
                            // Only two fixed sizes that disagree are a conflict.
                            !matches!(
                                (a, b),
                                (Dimension::Fixed(x), Dimension::Fixed(y)) if x != y
                            )
                        })
                }
                // A missing shape on either side places no constraint.
                _ => true,
            }
    }

    /// Returns the number of dimensions, or `None` when no shape is set.
    pub fn rank(&self) -> Option<usize> {
        self.shape.as_ref().map(Vec::len)
    }
}
impl fmt::Display for Schema {
    /// Writes the data type followed by a rendering of the shape.
    ///
    /// No shape adds nothing, an empty shape adds ` (scalar)`, and any other
    /// shape adds the dimensions as a bracketed, comma-separated list.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.dtype)?;
        match self.shape.as_deref() {
            None => Ok(()),
            Some([]) => f.write_str(" (scalar)"),
            Some(dims) => {
                f.write_str("[")?;
                for (idx, dim) in dims.iter().enumerate() {
                    // Separator goes between entries, never before the first.
                    if idx > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{dim}")?;
                }
                f.write_str("]")
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Rendering of scalar, fixed, batched and shapeless schemas.
    #[test]
    fn schema_display() {
        let cases = vec![
            (Schema::scalar(DataType::Float64), "f64 (scalar)"),
            (Schema::vector(DataType::Float64, 128), "f64[128]"),
            (Schema::matrix(DataType::Float64, 100, 50), "f64[100, 50]"),
            (Schema::batched(DataType::Float32, &[128]), "f32[batch, 128]"),
            (Schema::json(), "json"),
        ];
        for (schema, expected) in &cases {
            assert_eq!(schema.to_string(), *expected);
        }
    }

    /// A schema is compatible with itself.
    #[test]
    fn compatible_same_schema() {
        let schema = Schema::vector(DataType::Float64, 128);
        assert!(schema.is_compatible_with(&schema));
    }

    /// A dynamic dimension matches a fixed one, in both directions.
    #[test]
    fn compatible_dynamic_with_fixed() {
        let batched = Schema::batched(DataType::Float64, &[128]);
        let concrete = Schema::matrix(DataType::Float64, 32, 128);
        assert!(batched.is_compatible_with(&concrete));
        assert!(concrete.is_compatible_with(&batched));
    }

    /// A schema without a shape matches any shape of the same data type.
    #[test]
    fn compatible_unknown_shape() {
        let shapeless = Schema::dynamic(DataType::Float64);
        let shaped = Schema::vector(DataType::Float64, 128);
        assert!(shapeless.is_compatible_with(&shaped));
        assert!(shaped.is_compatible_with(&shapeless));
    }

    /// Differing data types are never compatible, even with equal shapes.
    #[test]
    fn incompatible_different_dtype() {
        let floats = Schema::vector(DataType::Float64, 128);
        let ints = Schema::vector(DataType::Int64, 128);
        assert!(!floats.is_compatible_with(&ints));
    }

    /// Two fixed dimensions of different size conflict.
    #[test]
    fn incompatible_different_fixed_dims() {
        let short = Schema::vector(DataType::Float64, 128);
        let long = Schema::vector(DataType::Float64, 256);
        assert!(!short.is_compatible_with(&long));
    }

    /// Shapes of different rank conflict.
    #[test]
    fn incompatible_different_rank() {
        let rank_one = Schema::vector(DataType::Float64, 128);
        let rank_two = Schema::matrix(DataType::Float64, 128, 64);
        assert!(!rank_one.is_compatible_with(&rank_two));
    }

    #[test]
    fn json_compatible_with_json() {
        let (left, right) = (Schema::json(), Schema::json());
        assert!(left.is_compatible_with(&right));
    }

    #[test]
    fn json_incompatible_with_tensor() {
        let tensor = Schema::vector(DataType::Float64, 10);
        assert!(!Schema::json().is_compatible_with(&tensor));
    }

    /// Every kind of schema survives a JSON encode/decode cycle unchanged.
    #[test]
    fn serde_roundtrip() {
        let originals = vec![
            Schema::scalar(DataType::Float64),
            Schema::vector(DataType::Float32, 100),
            Schema::batched(DataType::Float64, &[128, 64]),
            Schema::json(),
            Schema::dynamic(DataType::Int64),
        ];
        for original in originals {
            let encoded = serde_json::to_string(&original).unwrap();
            let decoded: Schema = serde_json::from_str(&encoded).unwrap();
            assert_eq!(original, decoded);
        }
    }

    /// `rank` counts dimensions and is `None` when no shape is set.
    #[test]
    fn rank() {
        let cases = vec![
            (Schema::scalar(DataType::Float64), Some(0)),
            (Schema::vector(DataType::Float64, 10), Some(1)),
            (Schema::matrix(DataType::Float64, 10, 5), Some(2)),
            (Schema::json(), None),
        ];
        for (schema, expected) in &cases {
            assert_eq!(schema.rank(), *expected);
        }
    }
}