use arrow_schema::DataType;
use crate::{
Diagnostic, DiagnosticCode, DiagnosticSet, FieldRef, MssqlType, MssqlTypeLength, Result,
SchemaMapping,
};
/// Supported pairings of a variable-width Arrow data type with a SQL Server column type.
///
/// Every variant carries the `length` taken from the SQL Server column type it was
/// classified from.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(crate) enum VariableWidthArrowToMssql {
    /// Arrow `Utf8` paired with a SQL Server `NVarChar` column.
    Utf8ToNVarChar { length: MssqlTypeLength },
    /// Arrow `LargeUtf8` paired with a SQL Server `NVarChar` column.
    LargeUtf8ToNVarChar { length: MssqlTypeLength },
    /// Arrow `Binary` paired with a SQL Server `VarBinary` column.
    BinaryToVarBinary { length: MssqlTypeLength },
    /// Arrow `LargeBinary` paired with a SQL Server `VarBinary` column.
    LargeBinaryToVarBinary { length: MssqlTypeLength },
}
impl VariableWidthArrowToMssql {
    /// Determines which supported variable-width conversion `mapping` describes.
    ///
    /// The Arrow data type and the SQL Server column type are inspected together; for a
    /// recognised pair the length declared on the column type is copied into the returned
    /// variant.
    ///
    /// # Errors
    ///
    /// Any other pairing yields a value-conversion error holding a single
    /// [`DiagnosticCode::ValueConversionUnsupported`] diagnostic tagged with the mapping's
    /// Arrow field and `row_index`.
    pub(crate) fn classify(mapping: &SchemaMapping, row_index: usize) -> Result<Self> {
        match (mapping.arrow().data_type(), mapping.mssql().ty()) {
            (DataType::Utf8, MssqlType::NVarChar(length)) => {
                Ok(Self::Utf8ToNVarChar { length: *length })
            }
            (DataType::LargeUtf8, MssqlType::NVarChar(length)) => {
                Ok(Self::LargeUtf8ToNVarChar { length: *length })
            }
            (DataType::Binary, MssqlType::VarBinary(length)) => {
                Ok(Self::BinaryToVarBinary { length: *length })
            }
            (DataType::LargeBinary, MssqlType::VarBinary(length)) => {
                Ok(Self::LargeBinaryToVarBinary { length: *length })
            }
            // Everything else is reported against the offending field and row.
            (arrow_type, mssql_type) => {
                let message = format!(
                    "variable-width conversion from Arrow {} to SQL Server {} is not supported",
                    arrow_type,
                    mssql_type.to_sql()
                );
                let diagnostic = row_mapping_diagnostic(
                    mapping,
                    row_index,
                    DiagnosticCode::ValueConversionUnsupported,
                    message,
                );
                Err(value_conversion_error(diagnostic))
            }
        }
    }
}
fn row_mapping_diagnostic(
mapping: &SchemaMapping,
row_index: usize,
code: DiagnosticCode,
message: impl Into<String>,
) -> Diagnostic {
Diagnostic::error(code, message)
.with_field(FieldRef::new(
mapping.arrow().index(),
mapping.arrow().name(),
))
.with_row(row_index)
}
fn value_conversion_error(diagnostic: Diagnostic) -> crate::Error {
crate::Error::ValueConversion {
diagnostics: DiagnosticSet::from(vec![diagnostic]),
}
}
#[cfg(test)]
mod tests {
    use arrow_schema::DataType;

    use crate::{
        ArrowFieldRef, DiagnosticCode, Identifier, MssqlColumn, MssqlType, MssqlTypeLength,
        SchemaMapping,
    };

    use super::VariableWidthArrowToMssql;

    /// Every supported Arrow/SQL Server pair classifies to its variant and keeps the
    /// column's length.
    #[test]
    fn classifies_variable_width_mappings() {
        let max = MssqlTypeLength::Max;
        let bounded_32 = MssqlTypeLength::Bounded(32);
        let bounded_16 = MssqlTypeLength::Bounded(16);

        let cases = [
            (
                DataType::Utf8,
                MssqlType::NVarChar(max),
                VariableWidthArrowToMssql::Utf8ToNVarChar { length: max },
            ),
            (
                DataType::Utf8,
                MssqlType::NVarChar(bounded_32),
                VariableWidthArrowToMssql::Utf8ToNVarChar { length: bounded_32 },
            ),
            (
                DataType::LargeUtf8,
                MssqlType::NVarChar(max),
                VariableWidthArrowToMssql::LargeUtf8ToNVarChar { length: max },
            ),
            (
                DataType::Binary,
                MssqlType::VarBinary(max),
                VariableWidthArrowToMssql::BinaryToVarBinary { length: max },
            ),
            (
                DataType::Binary,
                MssqlType::VarBinary(bounded_16),
                VariableWidthArrowToMssql::BinaryToVarBinary { length: bounded_16 },
            ),
            (
                DataType::LargeBinary,
                MssqlType::VarBinary(max),
                VariableWidthArrowToMssql::LargeBinaryToVarBinary { length: max },
            ),
        ];

        // The case position doubles as both the field index and the row index.
        for (position, (arrow_type, mssql_type, expected)) in cases.into_iter().enumerate() {
            let schema_mapping = mapping(position, "value", arrow_type, mssql_type);
            let actual = VariableWidthArrowToMssql::classify(&schema_mapping, position).unwrap();
            assert_eq!(actual, expected);
        }
    }

    /// The `Large*` Arrow types classify to their dedicated `Large*` variants.
    #[test]
    fn classifies_large_variable_width_mappings_for_scalar_reuse() {
        let max = MssqlTypeLength::Max;

        let cases = [
            (
                DataType::LargeUtf8,
                MssqlType::NVarChar(max),
                "large_text",
                VariableWidthArrowToMssql::LargeUtf8ToNVarChar { length: max },
            ),
            (
                DataType::LargeBinary,
                MssqlType::VarBinary(max),
                "large_bytes",
                VariableWidthArrowToMssql::LargeBinaryToVarBinary { length: max },
            ),
        ];

        for (position, (arrow_type, mssql_type, name, expected)) in cases.into_iter().enumerate() {
            let schema_mapping = mapping(position, name, arrow_type, mssql_type);
            let actual = VariableWidthArrowToMssql::classify(&schema_mapping, 7).unwrap();
            assert_eq!(actual, expected);

            let is_large_variant = matches!(
                actual,
                VariableWidthArrowToMssql::LargeUtf8ToNVarChar { .. }
                    | VariableWidthArrowToMssql::LargeBinaryToVarBinary { .. }
            );
            assert!(is_large_variant);
        }
    }

    /// A text Arrow type against a binary SQL Server column is rejected with a single
    /// diagnostic that names the field and row.
    #[test]
    fn classifier_rejects_mismatched_variable_width_pairs() {
        let mismatched = mapping(
            3,
            "text",
            DataType::Utf8,
            MssqlType::VarBinary(MssqlTypeLength::Max),
        );

        let error = VariableWidthArrowToMssql::classify(&mismatched, 9).unwrap_err();

        assert_single_diagnostic(
            error,
            DiagnosticCode::ValueConversionUnsupported,
            Some(9),
            Some((3, "text")),
        );
    }

    /// Builds a mapping whose Arrow field and SQL Server column share `name`.
    ///
    /// The boolean flag on both sides is `false` (presumably "not nullable" — confirm
    /// against the constructors).
    fn mapping(
        index: usize,
        name: &str,
        arrow_type: DataType,
        mssql_type: MssqlType,
    ) -> SchemaMapping {
        let arrow_field = ArrowFieldRef::new(index, name.to_owned(), false, arrow_type);
        let column = MssqlColumn::new(Identifier::new(name).unwrap(), mssql_type, false);
        SchemaMapping::new(arrow_field, column)
    }

    /// Asserts that `err` is a value-conversion error carrying exactly one diagnostic with
    /// the expected code, row, and `(index, name)` field reference.
    fn assert_single_diagnostic(
        err: crate::Error,
        expected_code: DiagnosticCode,
        expected_row: Option<usize>,
        expected_field: Option<(usize, &str)>,
    ) {
        let diagnostics = match err {
            crate::Error::ValueConversion { diagnostics } => diagnostics,
            _ => panic!("expected value conversion error"),
        };
        assert_eq!(diagnostics.len(), 1);

        let diagnostic = &diagnostics.all()[0];
        assert_eq!(diagnostic.code(), expected_code);
        assert_eq!(diagnostic.row(), expected_row);

        let actual_field = diagnostic
            .field()
            .map(|field| (field.index(), field.name()));
        assert_eq!(actual_field, expected_field);
    }
}