use crate::{
DiagnosticCode, MssqlTypeLength, Result, SchemaMapping,
arrow::cell::ArrowCell,
conversion::arrow_to_mssql::{
fixed_size_binary::FixedSizeBinaryArrowToMssql, variable_width::VariableWidthArrowToMssql,
},
mssql::cell::MssqlCell,
};
use super::{row_mapping_diagnostic, value_conversion_error, value_too_long_error};
pub(super) fn nvar_char_cell<'a>(
mapping: &SchemaMapping,
row_index: usize,
length: MssqlTypeLength,
cell: ArrowCell<'a>,
) -> Result<MssqlCell<'a>> {
let classified = match VariableWidthArrowToMssql::classify(mapping, row_index)? {
VariableWidthArrowToMssql::Utf8ToNVarChar { length }
| VariableWidthArrowToMssql::LargeUtf8ToNVarChar { length } => length,
other => {
return Err(value_conversion_error(row_mapping_diagnostic(
mapping,
row_index,
DiagnosticCode::ValueConversionUnsupported,
format!("variable-width mapping {other:?} is not supported by nvarchar conversion"),
)));
}
};
debug_assert_eq!(length, classified);
let value = mssql_nvarchar_value(mapping, row_index, cell)?;
let code_units = value.encode_utf16().count();
if exceeds_length(length, code_units) {
return Err(value_too_long_error(
mapping,
row_index,
format!(
"string value has {code_units} UTF-16 code unit(s), exceeding planned {}",
mapping.mssql().ty().to_sql()
),
));
}
Ok(MssqlCell::NVarChar(Some(value)))
}
/// Converts an Arrow binary cell into an `MssqlCell::VarBinary` for the planned column.
///
/// `length` is the planned length of the target column; in debug builds it is asserted to
/// equal the length reported by `VariableWidthArrowToMssql::classify`.
///
/// # Errors
///
/// - Propagates any error returned by `VariableWidthArrowToMssql::classify`.
/// - `DiagnosticCode::ValueConversionUnsupported` when the mapping is not a binary or large
///   binary to `varbinary` conversion.
/// - `DiagnosticCode::ValueTypeMismatch` when `cell` is not `ArrowCell::Binary`.
/// - The error built by `value_too_long_error` when the payload has more bytes than a
///   bounded column allows.
pub(super) fn var_binary_cell<'a>(
    mapping: &SchemaMapping,
    row_index: usize,
    length: MssqlTypeLength,
    cell: ArrowCell<'a>,
) -> Result<MssqlCell<'a>> {
    let classification = VariableWidthArrowToMssql::classify(mapping, row_index)?;
    let planned = match classification {
        VariableWidthArrowToMssql::BinaryToVarBinary { length: planned }
        | VariableWidthArrowToMssql::LargeBinaryToVarBinary { length: planned } => planned,
        other => {
            let diagnostic = row_mapping_diagnostic(
                mapping,
                row_index,
                DiagnosticCode::ValueConversionUnsupported,
                format!(
                    "variable-width mapping {other:?} is not supported by varbinary conversion"
                ),
            );
            return Err(value_conversion_error(diagnostic));
        }
    };
    debug_assert_eq!(length, planned);

    let payload = mssql_varbinary_value(mapping, row_index, cell)?;
    let bytes = payload.len();

    // Happy path first: the payload fits, so hand the borrowed bytes straight through.
    if !exceeds_length(length, bytes) {
        return Ok(MssqlCell::VarBinary(Some(payload)));
    }

    let message = format!(
        "binary value has {bytes} byte(s), exceeding planned {}",
        mapping.mssql().ty().to_sql()
    );
    Err(value_too_long_error(mapping, row_index, message))
}
/// Converts an Arrow binary cell into an `MssqlCell::VarBinary` for a fixed-size binary
/// mapping, requiring the payload to be exactly `length` bytes long.
///
/// In debug builds `length` is asserted to equal the length reported by
/// `FixedSizeBinaryArrowToMssql::classify`.
///
/// # Errors
///
/// - Propagates any error returned by `FixedSizeBinaryArrowToMssql::classify`.
/// - `DiagnosticCode::ValueTypeMismatch` when `cell` is not `ArrowCell::Binary`, or when the
///   payload's byte count differs from `length`.
pub(super) fn binary_cell<'a>(
    mapping: &SchemaMapping,
    row_index: usize,
    length: usize,
    cell: ArrowCell<'a>,
) -> Result<MssqlCell<'a>> {
    // The classifier currently has a single variant, so the pattern is irrefutable.
    let FixedSizeBinaryArrowToMssql::FixedSizeBinaryToBinary { length: planned } =
        FixedSizeBinaryArrowToMssql::classify(mapping, row_index)?;
    debug_assert_eq!(length, planned);

    let payload = mssql_varbinary_value(mapping, row_index, cell)?;
    let bytes = payload.len();
    if bytes == length {
        return Ok(MssqlCell::VarBinary(Some(payload)));
    }

    let diagnostic = row_mapping_diagnostic(
        mapping,
        row_index,
        DiagnosticCode::ValueTypeMismatch,
        format!(
            "binary value has {bytes} byte(s), but planned {} requires exactly {length}",
            mapping.mssql().ty().to_sql()
        ),
    );
    Err(value_conversion_error(diagnostic))
}
/// Extracts the borrowed string payload from an Arrow cell bound for an `nvarchar` column.
///
/// Returns the string carried by `ArrowCell::Utf8`. Any other cell produces a
/// `DiagnosticCode::ValueTypeMismatch` conversion error attributed to `row_index`.
fn mssql_nvarchar_value<'a>(
    mapping: &SchemaMapping,
    row_index: usize,
    cell: ArrowCell<'a>,
) -> Result<&'a str> {
    let other = match cell {
        ArrowCell::Utf8(text) => return Ok(text),
        unexpected => unexpected,
    };

    let diagnostic = row_mapping_diagnostic(
        mapping,
        row_index,
        DiagnosticCode::ValueTypeMismatch,
        format!("expected Arrow UTF-8 payload, got {other:?}"),
    );
    Err(value_conversion_error(diagnostic))
}
/// Extracts the borrowed byte payload from an Arrow cell bound for a binary column.
///
/// Returns the bytes carried by `ArrowCell::Binary`. Any other cell produces a
/// `DiagnosticCode::ValueTypeMismatch` conversion error attributed to `row_index`.
fn mssql_varbinary_value<'a>(
    mapping: &SchemaMapping,
    row_index: usize,
    cell: ArrowCell<'a>,
) -> Result<&'a [u8]> {
    let other = match cell {
        ArrowCell::Binary(bytes) => return Ok(bytes),
        unexpected => unexpected,
    };

    let diagnostic = row_mapping_diagnostic(
        mapping,
        row_index,
        DiagnosticCode::ValueTypeMismatch,
        format!("expected Arrow binary payload, got {other:?}"),
    );
    Err(value_conversion_error(diagnostic))
}
/// Reports whether `actual` is strictly greater than the limit of a bounded `length`.
///
/// A `Bounded` limit is inclusive (a value exactly at the limit does not exceed it), and
/// `Max` is never exceeded.
fn exceeds_length(length: MssqlTypeLength, actual: usize) -> bool {
    matches!(length, MssqlTypeLength::Bounded(limit) if actual > limit)
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::{DataType, Field, Schema};

    use super::super::{ArrowToMssqlRuntimeMapping, mssql_cell_from_arrow_cell};
    use crate::{
        BinaryPolicy, DiagnosticCode, MssqlProfile, MssqlType, PlanOptions, SchemaMapping,
        StringPolicy, arrow::cell::ArrowCell, mssql::cell::MssqlCell,
        plan_arrow_schema_to_mssql_mappings,
    };

    /// Empty, ASCII, and genuinely non-ASCII strings pass through unchanged for both `Utf8`
    /// and `LargeUtf8` fields.
    #[test]
    fn converts_empty_ascii_and_non_ascii_strings() {
        let mappings = mappings_for_schema(Schema::new(vec![
            Field::new("text", DataType::Utf8, true),
            Field::new("large_text", DataType::LargeUtf8, true),
        ]));

        assert_eq!(
            convert_cell(&mappings[0], ArrowCell::Utf8(""), 0).unwrap(),
            MssqlCell::NVarChar(Some(""))
        );
        assert_eq!(
            convert_cell(&mappings[0], ArrowCell::Utf8("ascii"), 1).unwrap(),
            MssqlCell::NVarChar(Some("ascii"))
        );
        // Multi-byte UTF-8 inside the Basic Multilingual Plane.
        assert_eq!(
            convert_cell(&mappings[0], ArrowCell::Utf8("東京"), 2).unwrap(),
            MssqlCell::NVarChar(Some("東京"))
        );

        assert_eq!(
            convert_cell(&mappings[1], ArrowCell::Utf8(""), 0).unwrap(),
            MssqlCell::NVarChar(Some(""))
        );
        assert_eq!(
            convert_cell(&mappings[1], ArrowCell::Utf8("ascii"), 1).unwrap(),
            MssqlCell::NVarChar(Some("ascii"))
        );
        // A scalar outside the Basic Multilingual Plane (a surrogate pair in UTF-16).
        assert_eq!(
            convert_cell(&mappings[1], ArrowCell::Utf8("🙂"), 2).unwrap(),
            MssqlCell::NVarChar(Some("🙂"))
        );
    }

    /// Empty and non-empty byte payloads pass through unchanged for both `Binary` and
    /// `LargeBinary` fields.
    #[test]
    fn converts_empty_and_non_empty_binary_values() {
        let mappings = mappings_for_schema(Schema::new(vec![
            Field::new("bytes", DataType::Binary, true),
            Field::new("large_bytes", DataType::LargeBinary, true),
        ]));

        assert_eq!(
            convert_cell(&mappings[0], ArrowCell::Binary(b""), 0).unwrap(),
            MssqlCell::VarBinary(Some(b""))
        );
        assert_eq!(
            convert_cell(&mappings[0], ArrowCell::Binary(b"abc"), 1).unwrap(),
            MssqlCell::VarBinary(Some(b"abc"))
        );
        assert_eq!(
            convert_cell(&mappings[1], ArrowCell::Binary(b""), 0).unwrap(),
            MssqlCell::VarBinary(Some(b""))
        );
        assert_eq!(
            convert_cell(&mappings[1], ArrowCell::Binary(b"large"), 1).unwrap(),
            MssqlCell::VarBinary(Some(b"large"))
        );
    }

    /// The large Arrow variants are accepted by the same scalar conversion path.
    #[test]
    fn scalar_variable_conversion_reuses_classifier_for_large_values() {
        let mappings = mappings_for_schema(Schema::new(vec![
            Field::new("large_text", DataType::LargeUtf8, true),
            Field::new("large_bytes", DataType::LargeBinary, true),
        ]));

        assert_eq!(
            convert_cell(&mappings[0], ArrowCell::Utf8("large text"), 0).unwrap(),
            MssqlCell::NVarChar(Some("large text"))
        );
        assert_eq!(
            convert_cell(&mappings[1], ArrowCell::Binary(b"large bytes"), 1).unwrap(),
            MssqlCell::VarBinary(Some(b"large bytes"))
        );
    }

    /// A bounded `nvarchar` limit is measured in UTF-16 code units, not bytes or scalars.
    #[test]
    fn rejects_bounded_nvarchar_by_utf16_code_units() {
        let mappings = mappings_for_schema_with_options(
            Schema::new(vec![Field::new("text", DataType::Utf8, true)]),
            PlanOptions {
                string_policy: StringPolicy::NVarChar(2),
                ..PlanOptions::default()
            },
        );

        assert_eq!(
            convert_cell(&mappings[0], ArrowCell::Utf8("ab"), 0).unwrap(),
            MssqlCell::NVarChar(Some("ab"))
        );
        // One scalar, four UTF-8 bytes, but exactly two UTF-16 code units: it still fits.
        assert_eq!(
            convert_cell(&mappings[0], ArrowCell::Utf8("🙂"), 1).unwrap(),
            MssqlCell::NVarChar(Some("🙂"))
        );

        let err = convert_cell(&mappings[0], ArrowCell::Utf8("abc"), 2).unwrap_err();
        assert_single_diagnostic(
            err,
            DiagnosticCode::ValueTooLong,
            Some(2),
            Some((0, "text")),
        );
    }

    /// A bounded `varbinary` limit is measured in bytes and is inclusive.
    #[test]
    fn rejects_bounded_varbinary_by_byte_count() {
        let mappings = mappings_for_schema_with_options(
            Schema::new(vec![Field::new("bytes", DataType::Binary, true)]),
            PlanOptions {
                binary_policy: BinaryPolicy::VarBinary(2),
                ..PlanOptions::default()
            },
        );

        assert_eq!(
            convert_cell(&mappings[0], ArrowCell::Binary(b""), 0).unwrap(),
            MssqlCell::VarBinary(Some(b""))
        );
        assert_eq!(
            convert_cell(&mappings[0], ArrowCell::Binary(b"ab"), 1).unwrap(),
            MssqlCell::VarBinary(Some(b"ab"))
        );

        let err = convert_cell(&mappings[0], ArrowCell::Binary(b"abc"), 2).unwrap_err();
        assert_single_diagnostic(
            err,
            DiagnosticCode::ValueTooLong,
            Some(2),
            Some((0, "bytes")),
        );
    }

    /// `FixedSizeBinary(3)` plans to `MssqlType::Binary(3)` and accepts a 3-byte payload.
    #[test]
    fn converts_fixed_size_binary_values() {
        let mappings = mappings_for_schema(Schema::new(vec![Field::new(
            "digest",
            DataType::FixedSizeBinary(3),
            true,
        )]));

        assert_eq!(mappings[0].mssql().ty(), &MssqlType::Binary(3));
        assert_eq!(
            convert_cell(&mappings[0], ArrowCell::Binary(b"abc"), 0).unwrap(),
            MssqlCell::VarBinary(Some(b"abc"))
        );
    }

    /// Payloads shorter or longer than the fixed width are rejected as type mismatches.
    #[test]
    fn rejects_fixed_size_binary_values_with_wrong_runtime_length() {
        let mappings = mappings_for_schema(Schema::new(vec![Field::new(
            "digest",
            DataType::FixedSizeBinary(3),
            true,
        )]));

        for (row_index, value) in [(0, &b""[..]), (1, &b"ab"[..]), (2, &b"abcd"[..])] {
            let err = convert_cell(&mappings[0], ArrowCell::Binary(value), row_index).unwrap_err();
            assert_single_diagnostic(
                err,
                DiagnosticCode::ValueTypeMismatch,
                Some(row_index),
                Some((0, "digest")),
            );
        }
    }

    /// Runs one cell through `mssql_cell_from_arrow_cell`, using a runtime mapping built from
    /// `mapping` and default `PlanOptions`.
    fn convert_cell<'a>(
        mapping: &SchemaMapping,
        cell: ArrowCell<'a>,
        row_index: usize,
    ) -> crate::Result<MssqlCell<'a>> {
        let options = PlanOptions::default();
        let runtime_mapping = ArrowToMssqlRuntimeMapping::new(mapping, &options);
        mssql_cell_from_arrow_cell(runtime_mapping, cell, row_index)
    }

    /// Plans `schema` with default options and returns the resulting mappings.
    fn mappings_for_schema(schema: Schema) -> Vec<SchemaMapping> {
        mappings_for_schema_with_options(schema, PlanOptions::default())
    }

    /// Plans `schema` against the SQL Server 2016 (compat 100) profile with `options` and
    /// returns the mappings, panicking if planning fails.
    fn mappings_for_schema_with_options(
        schema: Schema,
        options: PlanOptions,
    ) -> Vec<SchemaMapping> {
        plan_arrow_schema_to_mssql_mappings(
            Arc::new(schema),
            MssqlProfile::sql_server_2016_compat_100(),
            options,
        )
        .unwrap()
        .into_parts()
        .0
    }

    /// Asserts that `err` is a value-conversion error carrying exactly one diagnostic with
    /// the expected code, row, and `(index, name)` field identity.
    fn assert_single_diagnostic(
        err: crate::Error,
        expected_code: DiagnosticCode,
        expected_row: Option<usize>,
        expected_field: Option<(usize, &str)>,
    ) {
        let crate::Error::ValueConversion { diagnostics } = err else {
            panic!("expected value conversion error");
        };
        assert_eq!(diagnostics.len(), 1);

        let diagnostic = &diagnostics.all()[0];
        assert_eq!(diagnostic.code(), expected_code);
        assert_eq!(diagnostic.row(), expected_row);
        assert_eq!(
            diagnostic
                .field()
                .map(|field| (field.index(), field.name())),
            expected_field
        );
    }
}