use arrow::datatypes::{FieldRef, SchemaBuilder, SchemaRef};
use std::sync::Arc;
/// Schema of a table built from a file schema plus extra partition columns.
///
/// `table_schema` is kept equal to the fields of `file_schema` followed by the
/// fields of `table_partition_cols`; the constructors and builder methods on
/// this type recompute it whenever either part changes.
#[derive(Debug, Clone)]
pub struct TableSchema {
    /// The file schema supplied at construction (contains no partition columns).
    file_schema: SchemaRef,
    /// Partition columns, in the order they were supplied. Wrapped in an `Arc`,
    /// so cloning a `TableSchema` shares this vector rather than copying it.
    table_partition_cols: Arc<Vec<FieldRef>>,
    /// Cached concatenation of `file_schema`'s fields and `table_partition_cols`.
    table_schema: SchemaRef,
}
impl TableSchema {
pub fn new(file_schema: SchemaRef, table_partition_cols: Vec<FieldRef>) -> Self {
let mut builder = SchemaBuilder::from(file_schema.as_ref());
builder.extend(table_partition_cols.iter().cloned());
Self {
file_schema,
table_partition_cols: Arc::new(table_partition_cols),
table_schema: Arc::new(builder.finish()),
}
}
pub fn from_file_schema(file_schema: SchemaRef) -> Self {
Self::new(file_schema, vec![])
}
pub fn with_table_partition_cols(mut self, partition_cols: Vec<FieldRef>) -> Self {
if self.table_partition_cols.is_empty() {
self.table_partition_cols = Arc::new(partition_cols);
} else {
let table_partition_cols = Arc::get_mut(&mut self.table_partition_cols).expect(
"Expected to be the sole owner of table_partition_cols since this function accepts mut self",
);
table_partition_cols.extend(partition_cols);
}
let mut builder = SchemaBuilder::from(self.file_schema.as_ref());
builder.extend(self.table_partition_cols.iter().cloned());
self.table_schema = Arc::new(builder.finish());
self
}
pub fn file_schema(&self) -> &SchemaRef {
&self.file_schema
}
pub fn table_partition_cols(&self) -> &Vec<FieldRef> {
&self.table_partition_cols
}
pub fn table_schema(&self) -> &SchemaRef {
&self.table_schema
}
}
/// Treats a bare schema as a file schema with no partition columns.
impl From<SchemaRef> for TableSchema {
    fn from(file_schema: SchemaRef) -> Self {
        // Equivalent to `TableSchema::from_file_schema`: no partition columns.
        Self::new(file_schema, Vec::new())
    }
}
#[cfg(test)]
mod tests {
    use super::TableSchema;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Shorthand for a non-nullable field; every field in these tests is non-nullable.
    fn non_null(name: &str, data_type: DataType) -> Field {
        Field::new(name, data_type, false)
    }

    #[test]
    fn test_table_schema_creation() {
        let file_schema = Arc::new(Schema::new(vec![
            non_null("user_id", DataType::Int64),
            non_null("amount", DataType::Float64),
        ]));
        let partition_cols = vec![
            Arc::new(non_null("date", DataType::Utf8)),
            Arc::new(non_null("region", DataType::Utf8)),
        ];

        let table_schema = TableSchema::new(Arc::clone(&file_schema), partition_cols.clone());

        // The file schema is stored unchanged.
        assert_eq!(table_schema.file_schema().as_ref(), file_schema.as_ref());
        // Partition columns are kept as given, in order.
        assert_eq!(table_schema.table_partition_cols(), &partition_cols);

        // The table schema is the file columns followed by the partition columns.
        let expected = Schema::new(vec![
            non_null("user_id", DataType::Int64),
            non_null("amount", DataType::Float64),
            non_null("date", DataType::Utf8),
            non_null("region", DataType::Utf8),
        ]);
        assert_eq!(table_schema.table_schema().as_ref(), &expected);
    }

    #[test]
    fn test_add_multiple_partition_columns() {
        let file_schema = Arc::new(Schema::new(vec![non_null("id", DataType::Int32)]));
        let table_schema = TableSchema::new(
            Arc::clone(&file_schema),
            vec![Arc::new(non_null("country", DataType::Utf8))],
        );

        let updated = table_schema.with_table_partition_cols(vec![
            Arc::new(non_null("city", DataType::Utf8)),
            Arc::new(non_null("year", DataType::Int32)),
        ]);

        assert_eq!(updated.file_schema().as_ref(), file_schema.as_ref());

        // New partition columns are appended after the existing ones.
        let partition_names: Vec<&str> = updated
            .table_partition_cols()
            .iter()
            .map(|field| field.name().as_str())
            .collect();
        assert_eq!(partition_names, ["country", "city", "year"]);

        let expected = Schema::new(vec![
            non_null("id", DataType::Int32),
            non_null("country", DataType::Utf8),
            non_null("city", DataType::Utf8),
            non_null("year", DataType::Int32),
        ]);
        assert_eq!(updated.table_schema().as_ref(), &expected);
    }
}