use std::{cmp::min, num::NonZeroUsize, sync::Arc};
use arrow::array::{ArrayRef, StringBuilder};
use encoding_rs::mem::convert_utf16_to_str;
use odbc_api::{
DataType as OdbcDataType,
buffers::{AnySlice, BufferDesc},
};
use super::{ColumnFailure, MappingError, ReadStrategy};
/// Selects the [`ReadStrategy`] used to fetch a text column and turn it into an Arrow string
/// array.
///
/// * `sql_type` - Relational type of the column. Its `utf16_len` / `utf8_len` is the preferred
///   source for the buffer length.
/// * `lazy_display_size` - Only invoked if `sql_type` yields no length. Supplies the length
///   instead.
/// * `max_text_size` - Optional upper bound applied to the element length of the buffer. Also
///   used as the length if none could be determined otherwise.
/// * `trim_fixed_sized_character_strings` - If `true`, values of `Char` and `WChar` columns are
///   trimmed. Other types are never trimmed.
/// * `text_encoding` - Decides between the wide (UTF-16) and the narrow strategy.
///
/// # Errors
///
/// * [`ColumnFailure::UnknownStringLength`] if `lazy_display_size` had to be called and failed.
/// * [`ColumnFailure::ZeroSizedColumn`] if neither a length nor `max_text_size` is available.
pub fn choose_text_strategy(
    sql_type: OdbcDataType,
    lazy_display_size: impl FnOnce() -> Result<Option<NonZeroUsize>, odbc_api::Error>,
    max_text_size: Option<usize>,
    trim_fixed_sized_character_strings: bool,
    text_encoding: TextEncoding,
) -> Result<Box<dyn ReadStrategy + Send>, ColumnFailure> {
    let use_utf16 = text_encoding.use_utf16();

    // Length as implied by the relational type, queried for the encoding we are going to bind.
    let len_from_type = if use_utf16 {
        sql_type.utf16_len()
    } else {
        sql_type.utf8_len()
    };

    // Fall back to the display size only if the type itself did not tell us a length.
    let reported_len = match len_from_type {
        Some(len) => Some(len),
        None => lazy_display_size()
            .map_err(|source| ColumnFailure::UnknownStringLength { sql_type, source })?,
    };

    // Apply the user supplied upper bound. Without any length information and without a limit we
    // have no way to size the buffer.
    let max_str_len = match (reported_len.map(NonZeroUsize::get), max_text_size) {
        (None, None) => return Err(ColumnFailure::ZeroSizedColumn { sql_type }),
        (None, Some(limit)) => limit,
        (Some(len), None) => len,
        (Some(len), Some(limit)) => min(len, limit),
    };

    // Trimming is restricted to fixed sized character types.
    let trim = trim_fixed_sized_character_strings
        && matches!(
            sql_type,
            OdbcDataType::Char { .. } | OdbcDataType::WChar { .. }
        );

    Ok(if use_utf16 {
        wide_text_strategy(max_str_len, trim)
    } else {
        narrow_text_strategy(max_str_len, trim)
    })
}
/// Encoding of the buffers text columns are fetched into.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum TextEncoding {
    /// Decide based on the compilation target: UTF-16 on Windows, UTF-8 everywhere else. This is
    /// the default.
    #[default]
    Auto,
    /// Never use UTF-16 buffers, independent of the compilation target.
    Utf8,
    /// Always use UTF-16 buffers, independent of the compilation target.
    Utf16,
}

impl TextEncoding {
    /// `true` if text should be fetched using UTF-16 buffers, `false` otherwise.
    ///
    /// For [`TextEncoding::Auto`] the answer is fixed at compile time by the target operating
    /// system.
    pub fn use_utf16(&self) -> bool {
        match self {
            Self::Auto => cfg!(target_os = "windows"),
            Self::Utf8 => false,
            Self::Utf16 => true,
        }
    }
}
/// Creates a [`WideText`] strategy with the given maximum element length and trim setting and
/// returns it as a boxed [`ReadStrategy`] trait object.
fn wide_text_strategy(u16_len: usize, trim: bool) -> Box<dyn ReadStrategy + Send> {
    let strategy = WideText::new(u16_len, trim);
    Box::new(strategy)
}
/// Creates a [`NarrowText`] strategy with the given maximum element length and trim setting and
/// returns it as a boxed [`ReadStrategy`] trait object.
fn narrow_text_strategy(octet_len: usize, trim: bool) -> Box<dyn ReadStrategy + Send> {
    let strategy = NarrowText::new(octet_len, trim);
    Box::new(strategy)
}
/// Strategy for text columns which are fetched through a wide (`BufferDesc::WText`) buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WideText {
    /// Maximum length of a single value. Passed on unchanged as `max_str_len` of the requested
    /// buffer.
    max_str_len: usize,
    /// If `true`, leading and trailing whitespace is removed from every value.
    trim: bool,
}

impl WideText {
    /// Creates a strategy for values of up to `max_str_len` elements. `trim` controls whether
    /// surrounding whitespace is stripped from the values.
    pub fn new(max_str_len: usize, trim: bool) -> Self {
        Self { max_str_len, trim }
    }
}
impl ReadStrategy for WideText {
    /// Requests a wide text buffer whose elements hold up to `max_str_len` code units.
    fn buffer_desc(&self) -> BufferDesc {
        let max_str_len = self.max_str_len;
        BufferDesc::WText { max_str_len }
    }

    /// Transcodes every UTF-16 value of `column_view` to UTF-8 and collects the results in an
    /// Arrow string array. NULL values stay NULL. Values are trimmed if the strategy has been
    /// created with `trim` set.
    ///
    /// # Panics
    ///
    /// Panics if `column_view` is not a wide text view.
    fn fill_arrow_array(&self, column_view: AnySlice) -> Result<ArrayRef, MappingError> {
        // `buffer_desc` asks for a wide text buffer. Presumably callers only ever pass a view on
        // such a buffer, which would make any other variant a bug.
        let view = column_view.as_w_text_view().unwrap();
        let num_rows = view.len();
        // Only an estimate: `max_str_len` counts UTF-16 code units while the builder capacity
        // is in bytes of UTF-8.
        let num_bytes = self.max_str_len * num_rows;
        let mut builder = StringBuilder::with_capacity(num_rows, num_bytes);
        // One converter for the entire batch, so its internal buffer is reused for each value.
        let mut converter = Utf16ToUtf8Converter::new();
        for value in view.iter() {
            match value {
                None => builder.append_null(),
                Some(utf16) => {
                    let text = converter.utf16_to_utf8(utf16.as_slice());
                    builder.append_value(if self.trim { text.trim() } else { text });
                }
            }
        }
        Ok(Arc::new(builder.finish()))
    }
}
/// Transcodes UTF-16 values to UTF-8 one at a time, reusing a single allocation for the output.
struct Utf16ToUtf8Converter {
    /// Scratch buffer holding the UTF-8 representation of the value converted last.
    buf_utf8: String,
}

impl Utf16ToUtf8Converter {
    /// Creates a converter with an empty buffer. Memory is allocated on first use.
    fn new() -> Self {
        Self {
            buf_utf8: String::new(),
        }
    }

    /// Converts `utf16` to UTF-8 and returns the result as a slice into the internal buffer.
    ///
    /// The returned slice is only valid until the next conversion, which overwrites the buffer.
    fn utf16_to_utf8(&mut self, utf16: &[u16]) -> &str {
        // `convert_utf16_to_str` demands a destination of at least three times the length of
        // the source.
        let max_utf8_len = utf16.len() * 3;
        // Start from an all NUL string of exactly the required size instead of reusing the text
        // of the previous value. `extend` reserves the memory once rather than growing the
        // string character by character.
        self.buf_utf8.clear();
        self.buf_utf8
            .extend(std::iter::repeat('\0').take(max_utf8_len));
        let written = convert_utf16_to_str(utf16, &mut self.buf_utf8);
        // Everything past `written` is padding and not part of the value.
        &self.buf_utf8[..written]
    }
}
/// Strategy for text columns which are fetched through a narrow (`BufferDesc::Text`) buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NarrowText {
    /// Maximum length of a single value. Passed on unchanged as `max_str_len` of the requested
    /// buffer.
    max_str_len: usize,
    /// If `true`, leading and trailing whitespace is removed from every value.
    trim: bool,
}

impl NarrowText {
    /// Creates a strategy for values of up to `max_str_len` elements. `trim` controls whether
    /// surrounding whitespace is stripped from the values.
    pub fn new(max_str_len: usize, trim: bool) -> Self {
        Self { max_str_len, trim }
    }
}
impl ReadStrategy for NarrowText {
    /// Requests a narrow text buffer whose elements hold up to `max_str_len` bytes.
    fn buffer_desc(&self) -> BufferDesc {
        let max_str_len = self.max_str_len;
        BufferDesc::Text { max_str_len }
    }

    /// Validates every value of `column_view` as UTF-8 and collects them in an Arrow string
    /// array. NULL values stay NULL. Values are trimmed if the strategy has been created with
    /// `trim` set.
    ///
    /// # Errors
    ///
    /// Returns [`MappingError::InvalidUtf8`] for the first value which is not valid UTF-8. The
    /// error carries a lossy conversion of the offending value.
    ///
    /// # Panics
    ///
    /// Panics if `column_view` is not a narrow text view.
    fn fill_arrow_array(&self, column_view: AnySlice) -> Result<ArrayRef, MappingError> {
        // `buffer_desc` asks for a narrow text buffer. Presumably callers only ever pass a view
        // on such a buffer, which would make any other variant a bug.
        let view = column_view.as_text_view().unwrap();
        let num_rows = view.len();
        let mut builder = StringBuilder::with_capacity(num_rows, self.max_str_len * num_rows);
        for value in view.iter() {
            let Some(bytes) = value else {
                builder.append_null();
                continue;
            };
            let text = match simdutf8::basic::from_utf8(bytes) {
                Ok(text) => text,
                Err(_) => {
                    // The validation error is dropped; the lossy rendition of the value is
                    // reported instead.
                    return Err(MappingError::InvalidUtf8 {
                        lossy_value: String::from_utf8_lossy(bytes).into_owned(),
                    });
                }
            };
            builder.append_value(if self.trim { text.trim() } else { text });
        }
        Ok(Arc::new(builder.finish()))
    }
}
#[cfg(test)]
mod tests {
    use odbc_api::buffers::{AnySlice, ColumnBuffer, TextColumn};

    use crate::reader::{MappingError, ReadStrategy as _, text::Utf16ToUtf8Converter};

    use super::NarrowText;

    /// Converts a value containing a multi-byte character and afterwards a shorter ASCII value
    /// with the same converter. Reusing the buffer must neither panic nor leak parts of the first
    /// value into the second.
    #[test]
    fn do_not_split_buffer_accross_char_boundaries() {
        let mut converter = Utf16ToUtf8Converter::new();

        // `ñ` takes more than one byte in UTF-8.
        let with_multibyte: Vec<u16> = "Colt Telecom España S.A.".encode_utf16().collect();
        // Owned copy, since the next conversion overwrites the buffer the slice points into.
        let first = converter.utf16_to_utf8(&with_multibyte).to_owned();

        let ascii_only: Vec<u16> = "123456".encode_utf16().collect();
        let second = converter.utf16_to_utf8(&ascii_only);

        assert_eq!(first, "Colt Telecom España S.A.");
        assert_eq!(second, "123456");
    }

    /// A narrow value which is not valid UTF-8 must be reported as an error carrying the lossy
    /// conversion of the value.
    #[test]
    fn must_return_error_for_invalid_utf8() {
        // "Hello" followed by 0xC3, the lead byte of a two byte sequence without its
        // continuation byte.
        let invalid_utf8 = [b'H', b'e', b'l', b'l', b'o', 0xc3];
        let mut column = TextColumn::new(1, 10);
        column.set_value(0, Some(invalid_utf8.as_slice()));
        let column_view = AnySlice::Text(column.view(1));
        let strategy = NarrowText::new(5, false);

        let error = strategy.fill_arrow_array(column_view).unwrap_err();

        let MappingError::InvalidUtf8 { lossy_value } = error else {
            panic!("Not an InvalidUtf8 error")
        };
        assert_eq!(lossy_value, "Hello�");
    }
}