use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
/// Three-part version number attached to a [`DataFrameSchema`].
///
/// Versions are ordered lexicographically by `(major, minor, patch)` through the
/// manual `Ord` impl in this module.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SchemaVersion {
/// Leading component; `is_compatible_with` treats differing majors as incompatible.
pub major: u32,
/// Middle component; reset to 0 by `next_major`.
pub minor: u32,
/// Trailing component; reset to 0 by `next_major` and `next_minor`.
pub patch: u32,
}
impl SchemaVersion {
    /// Creates a version from its `major`, `minor` and `patch` components.
    pub fn new(major: u32, minor: u32, patch: u32) -> Self {
        SchemaVersion {
            major,
            minor,
            patch,
        }
    }

    /// The version given to a brand-new schema: `1.0.0`.
    pub fn initial() -> Self {
        SchemaVersion::new(1, 0, 0)
    }

    /// Returns `true` when both versions share the same `major` component.
    ///
    /// `minor` and `patch` are ignored.
    pub fn is_compatible_with(&self, other: &SchemaVersion) -> bool {
        self.major == other.major
    }

    /// Returns `true` if `self` orders strictly after `other`.
    ///
    /// Delegates to the `Ord` impl, so version ordering is defined in exactly one
    /// place rather than being re-implemented field by field here.
    pub fn is_greater_than(&self, other: &SchemaVersion) -> bool {
        self > other
    }

    /// Returns `(major + 1).0.0`.
    ///
    /// # Panics
    /// Panics if `major` is already `u32::MAX`.
    pub fn next_major(&self) -> Self {
        SchemaVersion::new(Self::bump(self.major, "major"), 0, 0)
    }

    /// Returns `major.(minor + 1).0`.
    ///
    /// # Panics
    /// Panics if `minor` is already `u32::MAX`.
    pub fn next_minor(&self) -> Self {
        SchemaVersion::new(self.major, Self::bump(self.minor, "minor"), 0)
    }

    /// Returns `major.minor.(patch + 1)`.
    ///
    /// # Panics
    /// Panics if `patch` is already `u32::MAX`.
    pub fn next_patch(&self) -> Self {
        SchemaVersion::new(self.major, self.minor, Self::bump(self.patch, "patch"))
    }

    /// Increments one version component, panicking on overflow.
    ///
    /// A plain `+ 1` only panics in debug builds. In release builds it silently
    /// wraps to 0, which would produce a version *older* than its predecessor.
    /// `checked_add` makes overflow a panic in every build profile.
    fn bump(component: u32, label: &str) -> u32 {
        component
            .checked_add(1)
            .unwrap_or_else(|| panic!("{} version component overflowed u32::MAX", label))
    }
}
impl Default for SchemaVersion {
fn default() -> Self {
SchemaVersion::initial()
}
}
impl fmt::Display for SchemaVersion {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
}
}
// Versions are totally ordered (see the `Ord` impl), so `partial_cmp` never yields `None`.
impl PartialOrd for SchemaVersion {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
// Lexicographic order: `major` decides first, then `minor`, then `patch`.
impl Ord for SchemaVersion {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.major
            .cmp(&other.major)
            .then_with(|| self.minor.cmp(&other.minor))
            .then_with(|| self.patch.cmp(&other.patch))
    }
}
/// Logical type of a column in a [`ColumnSchema`].
///
/// `type_name` returns each variant's name as a string; `Display` adds payload
/// details for `List` and `Categorical`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SchemaDataType {
/// Integer column.
Int64,
/// Floating-point column.
Float64,
/// Boolean column.
Boolean,
/// Text column.
String,
/// Date/time column (representation not specified in this module — TODO confirm).
DateTime,
/// Column restricted to a list of category labels.
Categorical {
// Spelled out as `std::string::String` to avoid confusion with the `String` variant.
categories: Vec<std::string::String>,
},
/// List column; every element has the same `element_type`.
List {
// Boxed because the type is recursive.
element_type: Box<SchemaDataType>,
},
}
impl SchemaDataType {
    /// Returns the bare variant name (`"Int64"`, `"List"`, ...), ignoring any payload.
    ///
    /// The names are string literals, so the result is `'static` and may outlive `self`.
    pub fn type_name(&self) -> &'static str {
        match self {
            SchemaDataType::Int64 => "Int64",
            SchemaDataType::Float64 => "Float64",
            SchemaDataType::Boolean => "Boolean",
            SchemaDataType::String => "String",
            SchemaDataType::DateTime => "DateTime",
            SchemaDataType::Categorical { .. } => "Categorical",
            SchemaDataType::List { .. } => "List",
        }
    }

    /// Returns `true` for `Int64` and `Float64`.
    pub fn is_numeric(&self) -> bool {
        matches!(self, SchemaDataType::Int64 | SchemaDataType::Float64)
    }

    /// Returns `true` if a column of this type may be cast to `target`.
    ///
    /// Rules:
    /// - identical types always cast (for `Categorical` this means identical category lists);
    /// - `Int64` casts to `Float64`;
    /// - every scalar type (`Int64`, `Float64`, `Boolean`, `DateTime`, `Categorical`)
    ///   casts to `String`, and `String` casts to each of them;
    /// - `List<A>` casts to `List<B>` when `A` casts to `B`.
    ///
    /// Everything else (e.g. `Boolean -> Int64`, `Float64 -> Int64`, a list to a
    /// scalar) is rejected.
    pub fn can_cast_to(&self, target: &SchemaDataType) -> bool {
        match (self, target) {
            (a, b) if a == b => true,
            (SchemaDataType::Int64, SchemaDataType::Float64) => true,
            // Scalar -> String. `DateTime` belongs here too: `String -> DateTime`
            // is allowed below, and every other scalar already renders to String.
            (SchemaDataType::Int64, SchemaDataType::String) => true,
            (SchemaDataType::Float64, SchemaDataType::String) => true,
            (SchemaDataType::Boolean, SchemaDataType::String) => true,
            (SchemaDataType::DateTime, SchemaDataType::String) => true,
            (SchemaDataType::Categorical { .. }, SchemaDataType::String) => true,
            // String -> scalar.
            (SchemaDataType::String, SchemaDataType::Int64) => true,
            (SchemaDataType::String, SchemaDataType::Float64) => true,
            (SchemaDataType::String, SchemaDataType::Boolean) => true,
            (SchemaDataType::String, SchemaDataType::DateTime) => true,
            (SchemaDataType::String, SchemaDataType::Categorical { .. }) => true,
            // Lists cast element-wise, so e.g. List<Int64> -> List<Float64> is allowed.
            (
                SchemaDataType::List { element_type: from },
                SchemaDataType::List { element_type: to },
            ) => from.can_cast_to(to),
            _ => false,
        }
    }
}
// Scalars print as their bare name; lists nest as `List<inner>`; categoricals
// report how many categories they carry rather than listing them.
impl fmt::Display for SchemaDataType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SchemaDataType::Categorical { categories } => {
                let count = categories.len();
                write!(f, "Categorical({} categories)", count)
            }
            SchemaDataType::List { element_type } => write!(f, "List<{}>", element_type),
            SchemaDataType::Int64
            | SchemaDataType::Float64
            | SchemaDataType::Boolean
            | SchemaDataType::String
            | SchemaDataType::DateTime => write!(f, "{}", self.type_name()),
        }
    }
}
/// Default value declared for a column (see `ColumnSchema::default_value`).
///
/// `Display` renders it in SQL-like literal form (`'text'`, `NULL`, `CURRENT_TIMESTAMP`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DefaultValue {
/// Integer literal.
Int(i64),
/// Floating-point literal.
Float(f64),
/// Boolean literal.
Bool(bool),
/// String literal; displayed wrapped in single quotes.
Str(std::string::String),
/// Explicit null default; displayed as `NULL`.
Null,
/// Displayed as `CURRENT_TIMESTAMP`; presumably resolved by the consumer at write time — TODO confirm.
CurrentTimestamp,
}
// Renders the default in SQL-like literal syntax: numbers and booleans verbatim,
// strings single-quoted, and the keywords `NULL` / `CURRENT_TIMESTAMP`.
impl fmt::Display for DefaultValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            DefaultValue::Int(v) => write!(f, "{}", v),
            DefaultValue::Float(v) => write!(f, "{}", v),
            DefaultValue::Bool(v) => write!(f, "{}", v),
            // Embedded quotes are doubled (SQL-style) so the literal stays well formed:
            // `O'Brien` renders as `'O''Brien'` instead of the ambiguous `'O'Brien'`.
            DefaultValue::Str(v) => write!(f, "'{}'", v.replace('\'', "''")),
            DefaultValue::Null => write!(f, "NULL"),
            DefaultValue::CurrentTimestamp => write!(f, "CURRENT_TIMESTAMP"),
        }
    }
}
/// Validation rule stored in `DataFrameSchema::constraints`.
///
/// Every variant names the column(s) it applies to; see `affected_columns`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SchemaConstraint {
/// Not-null rule on the named column (displayed as `NOT NULL(col)`).
NotNull(std::string::String),
/// Uniqueness rule over the listed column(s) (displayed as `UNIQUE(a, b)`).
Unique(Vec<std::string::String>),
/// Value bounds for `col`; presumably `None` leaves that side unbounded — TODO confirm
/// with whatever enforces the rule.
Range {
col: std::string::String,
min: Option<f64>,
max: Option<f64>,
},
/// Pattern rule on `col`; the regex dialect is not fixed by this module — TODO confirm.
Regex {
col: std::string::String,
pattern: std::string::String,
},
/// Reference from local `col` to `ref_col` in the schema named `ref_schema`
/// (displayed as `FK(col -> ref_schema.ref_col)`).
ForeignKey {
col: std::string::String,
ref_schema: std::string::String,
ref_col: std::string::String,
},
/// Allowed-values rule on `col`.
Enum {
col: std::string::String,
values: Vec<std::string::String>,
},
}
impl SchemaConstraint {
    /// Returns the variant name (`"NotNull"`, `"Range"`, ...) as a static string.
    pub fn constraint_type(&self) -> &'static str {
        match self {
            SchemaConstraint::NotNull(_) => "NotNull",
            SchemaConstraint::Unique(_) => "Unique",
            SchemaConstraint::Range { .. } => "Range",
            SchemaConstraint::Regex { .. } => "Regex",
            SchemaConstraint::ForeignKey { .. } => "ForeignKey",
            SchemaConstraint::Enum { .. } => "Enum",
        }
    }

    /// Returns the local column(s) this constraint applies to.
    ///
    /// Only `Unique` can name more than one column. For `ForeignKey` this is the
    /// local `col` only, not the referenced `ref_col`.
    pub fn affected_columns(&self) -> Vec<&str> {
        match self {
            SchemaConstraint::NotNull(col) => vec![col.as_str()],
            SchemaConstraint::Unique(cols) => cols.iter().map(|s| s.as_str()).collect(),
            SchemaConstraint::Range { col, .. } => vec![col.as_str()],
            SchemaConstraint::Regex { col, .. } => vec![col.as_str()],
            SchemaConstraint::ForeignKey { col, .. } => vec![col.as_str()],
            SchemaConstraint::Enum { col, .. } => vec![col.as_str()],
        }
    }

    /// Builds a textual identifier of the form `<kind>_<parameters...>`.
    ///
    /// Every variant embeds all of its parameters, so constraints that differ
    /// only in their parameters get different ids. `Enum` previously omitted its
    /// `values`, which made two enum constraints on the same column collide.
    /// Note that separators are not escaped: e.g. `Unique(["a_b"])` and
    /// `Unique(["a", "b"])` still produce the same id.
    pub fn generate_id(&self) -> std::string::String {
        match self {
            SchemaConstraint::NotNull(col) => format!("notnull_{}", col),
            SchemaConstraint::Unique(cols) => format!("unique_{}", cols.join("_")),
            SchemaConstraint::Range { col, min, max } => {
                format!("range_{}_{:?}_{:?}", col, min, max)
            }
            SchemaConstraint::Regex { col, pattern } => format!("regex_{}_{}", col, pattern),
            SchemaConstraint::ForeignKey {
                col,
                ref_schema,
                ref_col,
            } => format!("fk_{}_{}_{}", col, ref_schema, ref_col),
            SchemaConstraint::Enum { col, values } => {
                format!("enum_{}_{}", col, values.join("_"))
            }
        }
    }
}
// SQL-flavoured, human-readable rendering such as `NOT NULL(id)` or
// `FK(user_id -> users.id)`. Range bounds and enum values use their Debug form.
impl fmt::Display for SchemaConstraint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            SchemaConstraint::NotNull(ref column) => write!(f, "NOT NULL({})", column),
            SchemaConstraint::Unique(ref columns) => {
                let joined = columns.join(", ");
                write!(f, "UNIQUE({})", joined)
            }
            SchemaConstraint::Range {
                ref col,
                min: lower,
                max: upper,
            } => write!(f, "RANGE({}: {:?}..{:?})", col, lower, upper),
            SchemaConstraint::Regex {
                ref col,
                ref pattern,
            } => write!(f, "REGEX({}: {})", col, pattern),
            SchemaConstraint::ForeignKey {
                ref col,
                ref ref_schema,
                ref ref_col,
            } => write!(f, "FK({} -> {}.{})", col, ref_schema, ref_col),
            SchemaConstraint::Enum {
                ref col,
                ref values,
            } => write!(f, "ENUM({}: {:?})", col, values),
        }
    }
}
/// Definition of a single column: name, type, nullability and optional annotations.
///
/// Built with `ColumnSchema::new` plus the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ColumnSchema {
/// Column name; `DataFrameSchema` lookups match it exactly.
pub name: std::string::String,
/// Logical type of the column.
pub data_type: SchemaDataType,
/// Whether nulls are permitted; `ColumnSchema::new` sets this to `true`.
pub nullable: bool,
/// Optional default value; `None` unless set via `with_default`.
pub default_value: Option<DefaultValue>,
/// Optional free-text description.
pub description: Option<std::string::String>,
/// Arbitrary labels, appended in order by `with_tag`.
pub tags: Vec<std::string::String>,
}
impl ColumnSchema {
    /// Starts a nullable column with no default, description or tags.
    pub fn new(name: impl Into<std::string::String>, data_type: SchemaDataType) -> Self {
        Self {
            name: name.into(),
            data_type,
            nullable: true,
            default_value: None,
            description: None,
            tags: Vec::new(),
        }
    }

    /// Overrides whether the column accepts nulls.
    pub fn with_nullable(self, nullable: bool) -> Self {
        ColumnSchema { nullable, ..self }
    }

    /// Sets (or replaces) the column's default value.
    pub fn with_default(self, default: DefaultValue) -> Self {
        ColumnSchema {
            default_value: Some(default),
            ..self
        }
    }

    /// Sets (or replaces) the human-readable description.
    pub fn with_description(self, description: impl Into<std::string::String>) -> Self {
        ColumnSchema {
            description: Some(description.into()),
            ..self
        }
    }

    /// Appends one tag; repeated calls accumulate tags in call order.
    pub fn with_tag(self, tag: impl Into<std::string::String>) -> Self {
        let mut tags = self.tags;
        tags.push(tag.into());
        ColumnSchema { tags, ..self }
    }
}
// One-line summary such as `age: Int64 (not null)`.
impl fmt::Display for ColumnSchema {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let nullability = match self.nullable {
            true => " (nullable)",
            false => " (not null)",
        };
        write!(f, "{}: {}{}", self.name, self.data_type, nullability)
    }
}
/// Named, versioned description of a table: ordered columns, constraints and metadata.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DataFrameSchema {
/// Schema name, shown in the `Display` header.
pub name: std::string::String,
/// Version of this schema definition.
pub version: SchemaVersion,
/// Columns in declaration order; name lookups return the first match.
pub columns: Vec<ColumnSchema>,
/// Constraints in declaration order.
pub constraints: Vec<SchemaConstraint>,
/// Free-form string key/value annotations.
pub metadata: HashMap<std::string::String, std::string::String>,
}
impl DataFrameSchema {
pub fn new(name: impl Into<std::string::String>, version: SchemaVersion) -> Self {
DataFrameSchema {
name: name.into(),
version,
columns: Vec::new(),
constraints: Vec::new(),
metadata: HashMap::new(),
}
}
pub fn with_column(mut self, column: ColumnSchema) -> Self {
self.columns.push(column);
self
}
pub fn with_constraint(mut self, constraint: SchemaConstraint) -> Self {
self.constraints.push(constraint);
self
}
pub fn with_metadata(
mut self,
key: impl Into<std::string::String>,
value: impl Into<std::string::String>,
) -> Self {
self.metadata.insert(key.into(), value.into());
self
}
pub fn get_column(&self, name: &str) -> Option<&ColumnSchema> {
self.columns.iter().find(|c| c.name == name)
}
pub fn column_position(&self, name: &str) -> Option<usize> {
self.columns.iter().position(|c| c.name == name)
}
pub fn has_column(&self, name: &str) -> bool {
self.columns.iter().any(|c| c.name == name)
}
pub fn column_names(&self) -> Vec<&str> {
self.columns.iter().map(|c| c.name.as_str()).collect()
}
pub fn constraints_for_column(&self, col_name: &str) -> Vec<&SchemaConstraint> {
self.constraints
.iter()
.filter(|c| c.affected_columns().contains(&col_name))
.collect()
}
}
// Multi-line report: a header line, then each column, then (only when any exist)
// each constraint. Every line, including the last, ends with a newline.
impl fmt::Display for DataFrameSchema {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Schema: {} v{}", self.name, self.version)?;
        writeln!(f, "Columns ({}):", self.columns.len())?;
        self.columns
            .iter()
            .try_for_each(|column| writeln!(f, " - {}", column))?;
        if self.constraints.is_empty() {
            return Ok(());
        }
        writeln!(f, "Constraints ({}):", self.constraints.len())?;
        self.constraints
            .iter()
            .try_for_each(|constraint| writeln!(f, " - {}", constraint))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Versions order lexicographically over (major, minor, patch).
    #[test]
    fn test_schema_version_ordering() {
        let v1 = SchemaVersion::new(1, 0, 0);
        let v2 = SchemaVersion::new(1, 1, 0);
        let v3 = SchemaVersion::new(2, 0, 0);
        for &(lower, higher) in [(&v1, &v2), (&v2, &v3), (&v1, &v3)].iter() {
            assert!(lower < higher);
        }
    }

    // Only the major component decides compatibility.
    #[test]
    fn test_schema_version_compatibility() {
        let base = SchemaVersion::new(1, 0, 0);
        assert!(base.is_compatible_with(&SchemaVersion::new(1, 5, 3)));
        assert!(!base.is_compatible_with(&SchemaVersion::new(2, 0, 0)));
    }

    #[test]
    fn test_schema_version_display() {
        let rendered = format!("{}", SchemaVersion::new(1, 2, 3));
        assert_eq!(rendered, "1.2.3");
    }

    // Every builder call must be reflected in the resulting column.
    #[test]
    fn test_column_schema_builder() {
        let col = ColumnSchema::new("age", SchemaDataType::Int64)
            .with_nullable(false)
            .with_default(DefaultValue::Int(0))
            .with_description("User age in years")
            .with_tag("pii");
        assert_eq!(col.name, "age");
        assert!(!col.nullable);
        assert!(col.default_value.is_some());
        assert_eq!(col.tags, vec!["pii"]);
    }

    #[test]
    fn test_dataframe_schema_builder() {
        let schema = DataFrameSchema::new("users", SchemaVersion::initial())
            .with_column(ColumnSchema::new("id", SchemaDataType::Int64).with_nullable(false))
            .with_column(ColumnSchema::new("name", SchemaDataType::String))
            .with_constraint(SchemaConstraint::NotNull("id".to_string()));
        assert_eq!(schema.columns.len(), 2);
        assert_eq!(schema.constraints.len(), 1);
        for present in ["id", "name"].iter() {
            assert!(schema.has_column(present));
        }
        assert!(!schema.has_column("email"));
    }

    #[test]
    fn test_type_casting() {
        let int = SchemaDataType::Int64;
        assert!(int.can_cast_to(&SchemaDataType::Float64));
        assert!(int.can_cast_to(&SchemaDataType::String));
        assert!(!SchemaDataType::Boolean.can_cast_to(&int));
    }
}