use std::collections::HashMap;
use arrow::datatypes::{DataType, Field, Schema};
use indexmap::IndexMap;
use serde::{Deserialize, Serialize};
use crate::core::MurrError;
/// Element type of a table column.
///
/// Serialized by serde using the lowercase variant name
/// (`utf8`, `float32`, `float64`) because of `rename_all = "lowercase"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum DType {
/// UTF-8 string column; converts to Arrow `DataType::Utf8`.
Utf8,
/// 32-bit float column; converts to Arrow `DataType::Float32`.
Float32,
/// 64-bit float column; converts to Arrow `DataType::Float64`.
Float64,
}
impl DType {
    /// Returns the size associated with this dtype: `4` for [`DType::Utf8`]
    /// and [`DType::Float32`], `8` for [`DType::Float64`].
    ///
    /// NOTE(review): the unit is presumably bytes per element, and the `4`
    /// for `Utf8` presumably refers to a fixed-width offset entry rather than
    /// the string payload — confirm against callers.
    pub fn size(&self) -> usize {
        match *self {
            DType::Utf8 | DType::Float32 => 4,
            DType::Float64 => 8,
        }
    }
}
/// Schema of a single column: its element type and nullability.
///
/// `deny_unknown_fields` makes deserialization fail when the input contains
/// any field other than `dtype` and `nullable`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct ColumnSchema {
/// Element type of the column.
pub dtype: DType,
/// Whether the column may contain nulls. When the field is absent from the
/// deserialized input, serde fills it in via `ColumnSchema::default_nullable`.
#[serde(default = "ColumnSchema::default_nullable")]
pub nullable: bool,
}
impl ColumnSchema {
/// Default value for the `nullable` field, referenced by its
/// `#[serde(default = ...)]` attribute: columns are nullable unless the
/// input says otherwise.
pub fn default_nullable() -> bool {
true
}
}
/// Schema of a table: a key plus an ordered set of named columns.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TableSchema {
/// Table key; written into the Arrow schema metadata under `"key"` when
/// converting to an Arrow `Schema`.
/// NOTE(review): presumably the name of the key column — confirm.
pub key: String,
/// Column schemas by column name. `IndexMap` keeps insertion order, which
/// is the order in which fields are emitted when converting to Arrow.
pub columns: IndexMap<String, ColumnSchema>,
}
/// Maps each [`DType`] onto the Arrow `DataType` variant of the same name.
impl From<&DType> for DataType {
    fn from(dtype: &DType) -> Self {
        // `DType` is `Copy`, so match on the value itself.
        match *dtype {
            DType::Float64 => DataType::Float64,
            DType::Float32 => DataType::Float32,
            DType::Utf8 => DataType::Utf8,
        }
    }
}
/// Fallible conversion from an Arrow `DataType` to the supported [`DType`]s.
impl TryFrom<&DataType> for DType {
    type Error = MurrError;

    /// Converts `Utf8`, `Float32` and `Float64` to the matching [`DType`].
    ///
    /// # Errors
    ///
    /// Returns `MurrError::SegmentError` carrying the debug rendering of the
    /// Arrow type for every other `DataType`.
    fn try_from(dt: &DataType) -> Result<Self, Self::Error> {
        let dtype = match dt {
            DataType::Utf8 => DType::Utf8,
            DataType::Float32 => DType::Float32,
            DataType::Float64 => DType::Float64,
            other => {
                return Err(MurrError::SegmentError(format!(
                    "unsupported dtype {other:?}"
                )));
            }
        };
        Ok(dtype)
    }
}
/// Builds an Arrow `Schema` from a [`TableSchema`].
///
/// Each column becomes one `Field` (same name, converted dtype, same
/// nullability) in the iteration order of `columns`; the table key is stored
/// in the schema metadata under `"key"`.
impl From<&TableSchema> for Schema {
    fn from(schema: &TableSchema) -> Self {
        let mut fields: Vec<Field> = Vec::with_capacity(schema.columns.len());
        for (name, column) in &schema.columns {
            let data_type = DataType::from(&column.dtype);
            fields.push(Field::new(name, data_type, column.nullable));
        }

        let mut metadata = HashMap::new();
        metadata.insert("key".to_string(), schema.key.clone());

        Schema::new_with_metadata(fields, metadata)
    }
}