use std::{
collections::{HashMap, HashSet},
fmt::{self, Debug, Formatter},
};
use arrow_array::RecordBatch;
use arrow_schema::{Field as ArrowField, Schema as ArrowSchema};
use deepsize::DeepSizeOf;
use lance_arrow::*;
use snafu::{location, Location};
use super::field::{Field, SchemaCompareOptions};
use crate::{Error, Result};
/// A schema: an ordered list of top-level [`Field`]s plus string key/value metadata.
///
/// Nested fields are reached through each field's `children`.
/// Note that the `PartialEq` implementation for this type compares `fields`
/// only; `metadata` does not take part in equality.
#[derive(Default, Debug, Clone, DeepSizeOf)]
pub struct Schema {
/// Top-level fields of the schema, in order.
pub fields: Vec<Field>,
/// Key/value metadata attached to the schema.
pub metadata: HashMap<String, String>,
}
/// State of a pre-order walk over the fields of a [`Schema`]: every field is
/// yielded before its children.
struct SchemaFieldIterPreOrder<'a> {
/// Fields still waiting to be yielded; the next one to yield sits at the end.
field_stack: Vec<&'a Field>,
}
impl<'a> SchemaFieldIterPreOrder<'a> {
    /// Creates an iterator positioned before the first top-level field of `schema`.
    ///
    /// The top-level fields are stacked in reverse so that popping from the
    /// end of the stack yields them in their original order.
    #[allow(dead_code)]
    fn new(schema: &'a Schema) -> Self {
        let top_level = &schema.fields;
        let mut field_stack = Vec::with_capacity(top_level.len() * 2);
        field_stack.extend(top_level.iter().rev());
        Self { field_stack }
    }
}
impl<'a> Iterator for SchemaFieldIterPreOrder<'a> {
    type Item = &'a Field;

    /// Yields the next field and schedules its children to be visited right
    /// after it (children are stacked in reverse so they pop in order).
    fn next(&mut self) -> Option<Self::Item> {
        let current = self.field_stack.pop()?;
        self.field_stack.extend(current.children.iter().rev());
        Some(current)
    }
}
impl Schema {
/// Returns `true` when this schema matches `expected` under `options`.
///
/// The schemas must have the same number of top-level fields, each pair of
/// fields (taken positionally) must match according to
/// `Field::compare_with_options`, and, when `options.compare_metadata` is
/// set, the metadata maps must be equal.
pub fn compare_with_options(&self, expected: &Self, options: &SchemaCompareOptions) -> bool {
    if self.fields.len() != expected.fields.len() {
        return false;
    }
    let fields_match = self
        .fields
        .iter()
        .zip(expected.fields.iter())
        .all(|(actual, wanted)| actual.compare_with_options(wanted, options));
    if !fields_match {
        return false;
    }
    // Metadata only participates when the caller asked for it.
    !options.compare_metadata || self.metadata == expected.metadata
}
/// Describes how this schema differs from `expected`.
///
/// Returns `None` when no difference is found. Otherwise returns a
/// human-readable message:
///
/// * if the top-level field names differ (in count, content or order), the
///   message lists the missing/unexpected names, or both name lists when
///   only the order differs;
/// * otherwise the per-field differences reported by
///   `Field::explain_difference` are joined with `", "`;
/// * if the fields match but `options.compare_metadata` is set and the
///   metadata differs, a metadata mismatch message is returned.
///
/// Missing and unexpected names are sorted so the message is deterministic
/// (iteration order of a `HashSet` is not).
pub fn explain_difference(
    &self,
    expected: &Self,
    options: &SchemaCompareOptions,
) -> Option<String> {
    let same_names_in_order = self.fields.len() == expected.fields.len()
        && self
            .fields
            .iter()
            .zip(expected.fields.iter())
            .all(|(field, expected)| field.name == expected.name);

    if !same_names_in_order {
        let self_names = self
            .fields
            .iter()
            .map(|f| f.name.as_str())
            .collect::<HashSet<_>>();
        let expected_names = expected
            .fields
            .iter()
            .map(|f| f.name.as_str())
            .collect::<HashSet<_>>();

        let mut missing = expected_names
            .difference(&self_names)
            .copied()
            .collect::<Vec<_>>();
        let mut unexpected = self_names
            .difference(&expected_names)
            .copied()
            .collect::<Vec<_>>();
        // Set iteration order is unspecified; sort for a stable message.
        missing.sort_unstable();
        unexpected.sort_unstable();

        if missing.is_empty() && unexpected.is_empty() {
            // Same set of names, so the names must be arranged differently.
            let names_in_order = |schema: &Self| {
                schema
                    .fields
                    .iter()
                    .map(|f| f.name.as_str())
                    .collect::<Vec<_>>()
                    .join(", ")
            };
            return Some(format!(
                "fields in different order, expected: [{}], actual: [{}]",
                names_in_order(expected),
                names_in_order(self),
            ));
        }
        return Some(format!(
            "fields did not match, missing=[{}], unexpected=[{}]",
            missing.join(", "),
            unexpected.join(", ")
        ));
    }

    // Names line up positionally: collect the per-field explanations.
    let differences = self
        .fields
        .iter()
        .zip(expected.fields.iter())
        .flat_map(|(field, expected)| field.explain_difference(expected, options))
        .collect::<Vec<_>>();
    if !differences.is_empty() {
        return Some(differences.join(", "));
    }

    if options.compare_metadata && self.metadata != expected.metadata {
        Some("schema metadata did not match expected schema metadata".to_string())
    } else {
        None
    }
}
/// Returns `true` if any top-level field reports dictionary types through
/// `Field::has_dictionary_types`.
pub fn has_dictionary_types(&self) -> bool {
    for field in &self.fields {
        if field.has_dictionary_types() {
            return true;
        }
    }
    false
}
/// Checks that this schema matches `expected` under `options`.
///
/// # Errors
///
/// Returns [`Error::SchemaMismatch`] when [`Schema::compare_with_options`]
/// reports a mismatch. The error carries the text produced by
/// [`Schema::explain_difference`], or `"unknown reason"` if no explanation
/// is available.
pub fn check_compatible(&self, expected: &Self, options: &SchemaCompareOptions) -> Result<()> {
    if self.compare_with_options(expected, options) {
        return Ok(());
    }
    let difference = self
        .explain_difference(expected, options)
        .unwrap_or_else(|| "unknown reason".to_string());
    Err(Error::SchemaMismatch {
        difference,
        location: location!(),
    })
}
/// Renders the schema as a compact string by converting it to an Arrow
/// schema and delegating to that schema's `to_compact_string`.
pub fn to_compact_string(&self, indent: Indentation) -> String {
    let arrow_schema = ArrowSchema::from(self);
    arrow_schema.to_compact_string(indent)
}
/// Builds a new schema containing only the requested `columns`.
///
/// Each column is a dot-separated path (for example `"b.f1"`). Paths that
/// share the same top-level field are merged into a single projected field.
/// The metadata of `self` is copied to the result.
///
/// # Errors
///
/// Returns [`Error::Schema`] if the top-level component of a column does
/// not exist, and propagates errors from projecting or merging fields.
pub fn project<T: AsRef<str>>(&self, columns: &[T]) -> Result<Self> {
    let mut candidates: Vec<Field> = Vec::new();
    for col in columns {
        let path: Vec<&str> = col.as_ref().split('.').collect();
        // `split` always yields at least one component, so index 0 exists.
        let top_name = path[0];
        let field = match self.field(top_name) {
            Some(field) => field,
            None => {
                return Err(Error::Schema {
                    message: format!("Column {} does not exist", col.as_ref()),
                    location: location!(),
                })
            }
        };
        let projected_field = field.project(&path[1..])?;
        match candidates.iter_mut().find(|f| f.name == top_name) {
            // The top-level field was already selected: fold this path into it.
            Some(existing) => {
                existing.merge(&projected_field)?;
            }
            None => candidates.push(projected_field),
        }
    }
    Ok(Self {
        fields: candidates,
        metadata: self.metadata.clone(),
    })
}
/// Validates the structural invariants of the schema.
///
/// Checks, in this order:
///
/// 1. no field name, at any nesting level, contains `.` (the character is
///    used as the path separator by lookups such as [`Schema::field`]);
/// 2. no two fields share the same dotted path (for example two `b.f1`);
/// 3. every field id is non-negative and unique across the whole tree.
///
/// # Errors
///
/// Returns [`Error::Schema`] describing the first violation found.
pub fn validate(&self) -> Result<()> {
    let mut seen_names = HashSet::new();
    // Walk every field, not just the top-level ones, so that duplicate
    // names inside nested fields are detected as well.
    for field in self.fields_pre_order() {
        if field.name.contains('.') {
            return Err(Error::Schema {
                message: format!(
                    "Top level field {} cannot contain `.`. Maybe you meant to create a struct field?",
                    field.name.clone()
                ),
                location: location!(),
            });
        }
        // Full dotted path of the field, e.g. `b.f1`.
        let column_path = self
            .field_ancestry_by_id(field.id)
            .expect("a field yielded by this schema's traversal must be found by its id")
            .iter()
            .map(|f| f.name.as_str())
            .collect::<Vec<_>>()
            .join(".");
        if !seen_names.insert(column_path.clone()) {
            return Err(Error::Schema {
                message: format!(
                    "Duplicate field name \"{}\" in schema:\n {:#?}",
                    column_path, self
                ),
                location: location!(),
            });
        }
    }

    let mut seen_ids = HashSet::new();
    for field in self.fields_pre_order() {
        if field.id < 0 {
            return Err(Error::Schema {
                message: format!("Field {} has a negative id {}", field.name, field.id),
                location: location!(),
            });
        }
        if !seen_ids.insert(field.id) {
            return Err(Error::Schema {
                message: format!("Duplicate field id {} in schema {:?}", field.id, self),
                location: location!(),
            });
        }
    }
    Ok(())
}
/// Computes the fields present in both `self` and `other`.
///
/// The result follows the field order of `other`; each shared field is the
/// result of `Field::intersection` between this schema's field and the
/// other one. Fields of `other` that are not found in `self` are skipped.
/// The metadata of `self` is copied to the result.
///
/// # Errors
///
/// Propagates the first error returned by `Field::intersection`.
pub fn intersection(&self, other: &Self) -> Result<Self> {
    let fields = other
        .fields
        .iter()
        .filter_map(|theirs| {
            self.field(&theirs.name)
                .map(|ours| ours.intersection(theirs))
        })
        .collect::<Result<Vec<_>>>()?;
    Ok(Self {
        fields,
        metadata: self.metadata.clone(),
    })
}
/// Returns an iterator over every field of the schema, nested ones included,
/// in pre-order: each field is yielded before its children.
pub fn fields_pre_order(&self) -> impl Iterator<Item = &Field> {
SchemaFieldIterPreOrder::new(self)
}
/// Builds a new schema from the fields selected by `column_ids`.
///
/// Each top-level field is asked to project itself via
/// `Field::project_by_ids`; fields for which that returns `None` are
/// dropped. The metadata of `self` is copied to the result.
pub fn project_by_ids(&self, column_ids: &[i32]) -> Self {
    let mut kept = Vec::new();
    for field in &self.fields {
        if let Some(projected) = field.project_by_ids(column_ids) {
            kept.push(projected);
        }
    }
    Self {
        fields: kept,
        metadata: self.metadata.clone(),
    }
}
pub fn project_by_schema<S: TryInto<Self, Error = Error>>(
&self,
projection: S,
) -> Result<Self> {
let projection = projection.try_into()?;
let mut new_fields = vec![];
for field in projection.fields.iter() {
if let Some(self_field) = self.field(&field.name) {
new_fields.push(self_field.project_by_field(field)?);
} else {
return Err(Error::Schema {
message: format!("Field {} not found", field.name),
location: location!(),
});
}
}
Ok(Self {
fields: new_fields,
metadata: self.metadata.clone(),
})
}
pub fn exclude<T: TryInto<Self> + Debug>(&self, schema: T) -> Result<Self> {
let other = schema.try_into().map_err(|_| Error::Schema {
message: "The other schema is not compatible with this schema".to_string(),
location: location!(),
})?;
let mut fields = vec![];
for field in self.fields.iter() {
if let Some(other_field) = other.field(&field.name) {
if field.data_type().is_struct() {
if let Some(f) = field.exclude(other_field) {
fields.push(f)
}
}
} else {
fields.push(field.clone());
}
}
Ok(Self {
fields,
metadata: self.metadata.clone(),
})
}
/// Looks up a field by its dot-separated path, e.g. `"b"` or `"b.f2"`.
///
/// The first path component selects a top-level field by name; the
/// remaining components are resolved through `Field::sub_field`.
/// Returns `None` if any component cannot be resolved.
pub fn field(&self, name: &str) -> Option<&Field> {
    let mut components = name.split('.');
    // `split` yields at least one (possibly empty) component, so this
    // never bails out early.
    let top_level_name = components.next()?;
    let nested_path: Vec<&str> = components.collect();
    let top_level = self.fields.iter().find(|f| f.name == top_level_name)?;
    top_level.sub_field(nested_path.as_slice())
}
pub fn field_id(&self, column: &str) -> Result<i32> {
self.field(column)
.map(|f| f.id)
.ok_or_else(|| Error::Schema {
message: "Vector column not in schema".to_string(),
location: location!(),
})
}
/// Returns the ids of the top-level fields, in schema order.
pub fn top_level_field_ids(&self) -> Vec<i32> {
    let mut ids = Vec::with_capacity(self.fields.len());
    ids.extend(self.fields.iter().map(|field| field.id));
    ids
}
/// Returns the ids of all fields, nested ones included, in pre-order
/// (each field before its children).
pub fn field_ids(&self) -> Vec<i32> {
    let mut ids = Vec::new();
    for field in self.fields_pre_order() {
        ids.push(field.id);
    }
    ids
}
/// Finds the field with the given id, searching top-level fields in order
/// and descending into each one via `Field::field_by_id` before moving on.
///
/// Returns `None` when no field has that id.
pub fn field_by_id(&self, id: impl Into<i32>) -> Option<&Field> {
    let id = id.into();
    self.fields.iter().find_map(|field| {
        if field.id == id {
            Some(field)
        } else {
            field.field_by_id(id)
        }
    })
}
/// Returns the chain of fields leading to the field with the given id,
/// starting at a top-level field and ending with the matching field itself.
///
/// The search is a stack-based depth-first walk; the first match found is
/// returned. Returns `None` when no field has that id.
pub fn field_ancestry_by_id(&self, id: i32) -> Option<Vec<&Field>> {
    // Every stack entry is a complete path from a top-level field.
    let mut pending: Vec<Vec<&Field>> = self.fields.iter().map(|f| vec![f]).collect();
    while let Some(path) = pending.pop() {
        // Paths are never empty: they start with one field and only grow.
        let tail = *path.last().unwrap();
        if tail.id == id {
            return Some(path);
        }
        pending.extend(tail.children.iter().map(|child| {
            let mut longer = path.clone();
            longer.push(child);
            longer
        }));
    }
    None
}
/// Mutable counterpart of [`Schema::field_by_id`]: finds the field with the
/// given id, searching top-level fields in order and descending into each
/// one via `Field::mut_field_by_id`.
///
/// Returns `None` when no field has that id.
pub fn mut_field_by_id(&mut self, id: impl Into<i32>) -> Option<&mut Field> {
    let id = id.into();
    self.fields.iter_mut().find_map(|field| {
        if field.id == id {
            Some(field)
        } else {
            field.mut_field_by_id(id)
        }
    })
}
/// Returns the largest value reported by `Field::max_id` across the
/// top-level fields, or `None` if the schema has no fields.
pub fn max_field_id(&self) -> Option<i32> {
    self.fields
        .iter()
        .map(|field| field.max_id())
        .reduce(i32::max)
}
pub fn set_dictionary(&mut self, batch: &RecordBatch) -> Result<()> {
for field in self.fields.as_mut_slice() {
let column = batch
.column_by_name(&field.name)
.ok_or_else(|| Error::Schema {
message: format!("column '{}' does not exist in the record batch", field.name),
location: location!(),
})?;
field.set_dictionary(column);
}
Ok(())
}
/// Runs `Field::set_id` on every top-level field with a shared id counter.
///
/// The counter starts one past the larger of the schema's current maximum
/// id and `max_existing_id` (each treated as `-1` when absent). Every
/// top-level field is given `-1` as its parent id.
pub fn set_field_id(&mut self, max_existing_id: Option<i32>) {
    let highest_in_schema = self.max_field_id().unwrap_or(-1);
    let highest_external = max_existing_id.unwrap_or(-1);
    let mut current_id = i32::max(highest_in_schema, highest_external) + 1;
    for field in self.fields.iter_mut() {
        field.set_id(-1, &mut current_id);
    }
}
/// Calls `Field::reset_id` on every top-level field.
fn reset_id(&mut self) {
    for field in self.fields.iter_mut() {
        field.reset_id();
    }
}
pub fn extend(&mut self, fields: &[ArrowField]) -> Result<()> {
let new_fields = fields
.iter()
.map(Field::try_from)
.collect::<Result<Vec<_>>>()?;
self.fields.extend(new_fields);
let field_names = self.fields.iter().map(|f| &f.name).collect::<HashSet<_>>();
if field_names.len() != self.fields.len() {
Err(Error::Internal {
message: format!(
"Attempt to add fields [{:?}] would lead to duplicate field names",
fields.iter().map(|f| f.name()).collect::<Vec<_>>()
),
location: location!(),
})
} else {
Ok(())
}
}
/// Merges `other` into a copy of this schema and returns the result.
///
/// * The ids of `other` are reset first (via `reset_id`).
/// * Every top-level field of `self` is kept, in order; when `other` has a
///   field of the same name, the two are combined with `Field::merge`.
/// * Top-level fields that only exist in `other` are appended afterwards.
/// * Metadata maps are combined; on a key clash the value from `other` wins.
///
/// # Errors
///
/// Returns an error if `other` cannot be converted into a [`Schema`] or if
/// merging a pair of fields fails.
pub fn merge<S: TryInto<Self, Error = Error>>(&self, other: S) -> Result<Self> {
    let mut other: Self = other.try_into()?;
    other.reset_id();

    let mut merged_fields = self
        .fields
        .iter()
        .cloned()
        .map(|mut field| -> Result<Field> {
            if let Some(other_field) = other.field(&field.name) {
                field.merge(other_field)?;
            }
            Ok(field)
        })
        .collect::<Result<Vec<_>>>()?;

    // Append the top-level fields that only the incoming schema has.
    for extra in &other.fields {
        let already_present = merged_fields.iter().any(|f| f.name == extra.name);
        if !already_present {
            merged_fields.push(extra.clone());
        }
    }

    // Later inserts overwrite earlier ones, so `other` wins on key clashes.
    let mut metadata = self.metadata.clone();
    metadata.extend(
        other
            .metadata
            .iter()
            .map(|(key, value)| (key.clone(), value.clone())),
    );

    Ok(Self {
        fields: merged_fields,
        metadata,
    })
}
}
impl PartialEq for Schema {
    /// Two schemas are equal when their `fields` are equal; `metadata` is
    /// not taken into account.
    fn eq(&self, other: &Self) -> bool {
        let Self { fields: lhs, .. } = self;
        let Self { fields: rhs, .. } = other;
        lhs == rhs
    }
}
impl fmt::Display for Schema {
    /// Writes every top-level field on its own line, stopping at the first
    /// formatting error.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        self.fields
            .iter()
            .try_for_each(|field| writeln!(f, "{field}"))
    }
}
impl TryFrom<&ArrowSchema> for Schema {
    type Error = Error;

    /// Converts an Arrow schema: every Arrow field is converted with
    /// `Field::try_from`, the metadata is copied, and field ids are then
    /// assigned with `set_field_id(None)`.
    fn try_from(schema: &ArrowSchema) -> Result<Self> {
        let mut fields = Vec::with_capacity(schema.fields.len());
        for arrow_field in schema.fields.iter() {
            fields.push(Field::try_from(arrow_field.as_ref())?);
        }
        let mut converted = Self {
            fields,
            metadata: schema.metadata.clone(),
        };
        converted.set_field_id(None);
        Ok(converted)
    }
}
impl From<&Schema> for ArrowSchema {
    /// Builds an Arrow schema from the top-level fields (each converted with
    /// `ArrowField::from`) and a copy of the metadata.
    fn from(schema: &Schema) -> Self {
        let fields = schema.fields.iter().map(ArrowField::from).collect();
        let metadata = schema.metadata.clone();
        Self { fields, metadata }
    }
}
impl TryFrom<&Self> for Schema {
    type Error = Error;

    /// Clones the borrowed schema; this conversion never fails.
    fn try_from(schema: &Self) -> Result<Self> {
        let owned = Self::clone(schema);
        Ok(owned)
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use arrow_schema::{
        DataType, Field as ArrowField, Fields as ArrowFields, Schema as ArrowSchema,
    };

    // ---- Shared fixtures -------------------------------------------------

    /// Struct data type made of the given child fields.
    fn struct_of(children: Vec<ArrowField>) -> DataType {
        DataType::Struct(ArrowFields::from(children))
    }

    /// `a: Int32`, non-nullable.
    fn field_a() -> ArrowField {
        ArrowField::new("a", DataType::Int32, false)
    }

    /// `b: Struct<children>`, nullable.
    fn struct_b(children: Vec<ArrowField>) -> ArrowField {
        ArrowField::new("b", struct_of(children), true)
    }

    /// `c: Float64`, non-nullable.
    fn field_c() -> ArrowField {
        ArrowField::new("c", DataType::Float64, false)
    }

    /// `f1: Utf8`, nullable.
    fn child_f1() -> ArrowField {
        ArrowField::new("f1", DataType::Utf8, true)
    }

    /// `f2: Boolean`, non-nullable.
    fn child_f2() -> ArrowField {
        ArrowField::new("f2", DataType::Boolean, false)
    }

    /// `f3: Float32`, non-nullable.
    fn child_f3() -> ArrowField {
        ArrowField::new("f3", DataType::Float32, false)
    }

    /// Arrow schema `a, b{f1, f2, f3}, c` used by most tests.
    fn abc_arrow_schema() -> ArrowSchema {
        ArrowSchema::new(vec![
            field_a(),
            struct_b(vec![child_f1(), child_f2(), child_f3()]),
            field_c(),
        ])
    }

    /// [`abc_arrow_schema`] converted into a [`Schema`].
    fn abc_schema() -> Schema {
        Schema::try_from(&abc_arrow_schema()).unwrap()
    }

    /// Arrow schema `d: Int32, e: Binary` used by the merge tests.
    fn de_arrow_schema() -> ArrowSchema {
        ArrowSchema::new(vec![
            ArrowField::new("d", DataType::Int32, false),
            ArrowField::new("e", DataType::Binary, false),
        ])
    }

    // ---- Tests -----------------------------------------------------------

    #[test]
    fn test_schema_projection() {
        let schema = abc_schema();
        let projected = schema.project(&["b.f1", "b.f3", "c"]).unwrap();

        let expected_arrow_schema = ArrowSchema::new(vec![
            struct_b(vec![child_f1(), child_f3()]),
            field_c(),
        ]);
        assert_eq!(ArrowSchema::from(&projected), expected_arrow_schema);
    }

    #[test]
    fn test_schema_project_by_ids() {
        let mut schema = abc_schema();
        schema.set_field_id(None);

        let check = |ids: &[i32], expected_arrow_schema: ArrowSchema| {
            let projected = schema.project_by_ids(ids);
            assert_eq!(ArrowSchema::from(&projected), expected_arrow_schema);
        };

        check(
            &[2, 4, 5],
            ArrowSchema::new(vec![
                struct_b(vec![child_f1(), child_f3()]),
                field_c(),
            ]),
        );
        check(&[2], ArrowSchema::new(vec![struct_b(vec![child_f1()])]));
        check(
            &[1],
            ArrowSchema::new(vec![struct_b(vec![
                child_f1(),
                child_f2(),
                child_f3(),
            ])]),
        );
    }

    #[test]
    fn test_schema_project_by_schema() {
        // Fields that appear unchanged in both the full schema and the projection.
        let trailing_fields = || {
            vec![
                ArrowField::new("s", DataType::Utf8, false),
                ArrowField::new(
                    "l",
                    DataType::List(Arc::new(ArrowField::new("le", DataType::Int32, false))),
                    false,
                ),
                ArrowField::new(
                    "fixed_l",
                    DataType::List(Arc::new(ArrowField::new("elem", DataType::Float32, false))),
                    false,
                ),
                ArrowField::new(
                    "d",
                    DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
                    false,
                ),
            ]
        };

        let mut all_fields = vec![
            field_a(),
            struct_b(vec![child_f1(), child_f2(), child_f3()]),
            field_c(),
        ];
        all_fields.extend(trailing_fields());
        let schema = Schema::try_from(&ArrowSchema::new(all_fields)).unwrap();

        let mut projected_fields = vec![struct_b(vec![child_f1()])];
        projected_fields.extend(trailing_fields());
        let projection = ArrowSchema::new(projected_fields);

        let projected = schema.project_by_schema(&projection).unwrap();
        assert_eq!(ArrowSchema::from(&projected), projection);
    }

    #[test]
    fn test_get_nested_field() {
        let arrow_schema = ArrowSchema::new(vec![struct_b(vec![
            child_f1(),
            child_f2(),
            child_f3(),
        ])]);
        let schema = Schema::try_from(&arrow_schema).unwrap();

        let field = schema.field("b.f2").unwrap();
        assert_eq!(field.data_type(), DataType::Boolean);
    }

    #[test]
    fn test_exclude_fields() {
        let schema = abc_schema();
        let projection = schema.project(&["a", "b.f2", "b.f3"]).unwrap();
        let excluded = schema.exclude(&projection).unwrap();

        let expected_arrow_schema =
            ArrowSchema::new(vec![struct_b(vec![child_f1()]), field_c()]);
        assert_eq!(ArrowSchema::from(&excluded), expected_arrow_schema);
    }

    #[test]
    fn test_intersection() {
        let schema = abc_schema();

        let other_arrow_schema = ArrowSchema::new(vec![
            struct_b(vec![child_f1(), child_f2()]),
            field_c(),
            ArrowField::new("d", DataType::Utf8, false),
        ]);
        let other = Schema::try_from(&other_arrow_schema).unwrap();

        let actual: ArrowSchema = (&schema.intersection(&other).unwrap()).into();
        let expected = ArrowSchema::new(vec![
            struct_b(vec![child_f1(), child_f2()]),
            field_c(),
        ]);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_merge_schemas_and_assign_field_ids() {
        let schema = abc_schema();
        assert_eq!(schema.max_field_id(), Some(5));

        let to_merged = Schema::try_from(&de_arrow_schema()).unwrap();
        assert_eq!(to_merged.max_field_id(), Some(1));

        let mut merged = schema.merge(&to_merged).unwrap();
        assert_eq!(merged.max_field_id(), Some(5));

        // Freshly merged fields carry no id until ids are assigned.
        assert_eq!(merged.field("d").unwrap().id, -1);
        assert_eq!(merged.field("e").unwrap().id, -1);

        merged.set_field_id(Some(7));
        assert_eq!(merged.field("d").unwrap().id, 8);
        assert_eq!(merged.field("e").unwrap().id, 9);
        assert_eq!(merged.max_field_id(), Some(9));
    }

    #[test]
    fn test_merge_arrow_schema() {
        let schema = abc_schema();
        assert_eq!(schema.max_field_id(), Some(5));

        let mut merged = schema.merge(&de_arrow_schema()).unwrap();
        merged.set_field_id(None);
        assert_eq!(merged.max_field_id(), Some(7));

        assert_eq!(merged.field("d").unwrap().id, 6);
        assert_eq!(merged.field("e").unwrap().id, 7);
    }

    #[test]
    fn test_merge_nested_field() {
        let nullable_utf8 = |name: &str| ArrowField::new(name, DataType::Utf8, true);
        let float32 = |name: &str| ArrowField::new(name, DataType::Float32, false);
        let nested_f1 =
            |children: Vec<ArrowField>| ArrowField::new("f1", struct_of(children), true);

        let arrow_schema1 = ArrowSchema::new(vec![struct_b(vec![
            nested_f1(vec![nullable_utf8("f11")]),
            float32("f2"),
        ])]);
        let schema1 = Schema::try_from(&arrow_schema1).unwrap();

        let arrow_schema2 = ArrowSchema::new(vec![struct_b(vec![
            nested_f1(vec![nullable_utf8("f22")]),
            float32("f3"),
        ])]);
        let schema2 = Schema::try_from(&arrow_schema2).unwrap();

        let expected_arrow_schema = ArrowSchema::new(vec![struct_b(vec![
            nested_f1(vec![nullable_utf8("f11"), nullable_utf8("f22")]),
            float32("f2"),
            float32("f3"),
        ])]);
        let mut expected_schema = Schema::try_from(&expected_arrow_schema).unwrap();
        expected_schema.fields[0]
            .child_mut("f1")
            .unwrap()
            .child_mut("f22")
            .unwrap()
            .id = 4;
        expected_schema.fields[0].child_mut("f2").unwrap().id = 3;

        let mut result = schema1.merge(&schema2).unwrap();
        result.set_field_id(None);
        assert_eq!(result, expected_schema);
    }

    #[test]
    fn test_field_by_id() {
        let schema = abc_schema();

        let field = schema.field_by_id(1).unwrap();
        assert_eq!(field.name, "b");

        let field = schema.field_by_id(3).unwrap();
        assert_eq!(field.name, "f2");
    }

    #[test]
    fn test_explain_difference() {
        let expected = abc_schema();

        let mismatched_arrow_schema = ArrowSchema::new(vec![
            field_a(),
            struct_b(vec![child_f1(), child_f3()]),
            ArrowField::new("c", DataType::Float64, true),
        ]);
        let mismatched = Schema::try_from(&mismatched_arrow_schema).unwrap();

        let explanation =
            mismatched.explain_difference(&expected, &SchemaCompareOptions::default());
        assert_eq!(
            explanation,
            Some("`b` had mismatched children, missing=[f2] unexpected=[], `c` should have nullable=false but nullable=true".to_string())
        );
    }
}