use std::cmp::min;
use arrow::{
datatypes::SchemaRef,
error::ArrowError,
record_batch::{RecordBatch, RecordBatchReader},
};
use odbc_api::{BlockCursor, Cursor, buffers::ColumnarAnyBuffer};
use crate::{BufferAllocationOptions, ConcurrentOdbcReader, Error};
use super::{TextEncoding, to_record_batch::ToRecordBatch};
/// Reads rows from an ODBC data source and yields them as Arrow record batches.
///
/// Implements [`Iterator`] over `Result<RecordBatch, ArrowError>` (see the `Iterator` impl
/// below) as well as [`RecordBatchReader`].
pub struct OdbcReader<C: Cursor> {
    /// Converts fetched ODBC buffer contents into Arrow record batches and owns the schema.
    converter: ToRecordBatch,
    /// ODBC cursor with bound columnar fetch buffers; each fetch fills one batch worth of rows.
    batch_stream: BlockCursor<C, ColumnarAnyBuffer>,
    // Presumably controls whether buffer allocation failures are reported as recoverable
    // errors rather than aborting the process — confirm against `allocate_buffer`.
    // NOTE(review): "fallibale" is a typo, but the name is shared with sibling APIs
    // (e.g. `with_fallibale_allocations`), so it is kept for consistency.
    fallibale_allocations: bool,
}
impl<C: Cursor> OdbcReader<C> {
pub fn into_concurrent(self) -> Result<ConcurrentOdbcReader<C>, Error>
where
C: Send + 'static,
{
ConcurrentOdbcReader::from_block_cursor(
self.batch_stream,
self.converter,
self.fallibale_allocations,
)
}
pub fn into_cursor(self) -> Result<C, odbc_api::Error> {
let (cursor, _buffer) = self.batch_stream.unbind()?;
Ok(cursor)
}
pub fn max_rows_per_batch(&self) -> usize {
self.batch_stream.row_array_size()
}
}
impl<C> Iterator for OdbcReader<C>
where
    C: Cursor,
{
    type Item = Result<RecordBatch, ArrowError>;

    /// Fetch the next row set from the cursor and convert it into a record batch.
    ///
    /// Returns `None` once the result set is exhausted. Both ODBC errors (including value
    /// truncation, which is checked for explicitly) and buffer-to-batch conversion errors
    /// are surfaced as [`ArrowError`]s.
    fn next(&mut self) -> Option<Self::Item> {
        let fetch_result = self.batch_stream.fetch_with_truncation_check(true);
        match fetch_result {
            // ODBC reported a diagnostic while fetching (e.g. truncated values).
            Err(odbc_error) => Some(Err(odbc_to_arrow_error(odbc_error))),
            // Result set exhausted; end iteration.
            Ok(None) => None,
            // A row set has been fetched into the bound buffers; convert it.
            Ok(Some(buffer)) => {
                let converted = self.converter.buffer_to_record_batch(buffer);
                Some(converted.map_err(|mapping_error| {
                    ArrowError::ExternalError(Box::new(mapping_error))
                }))
            }
        }
    }
}
impl<C> RecordBatchReader for OdbcReader<C>
where
    C: Cursor,
{
    /// Arrow schema shared by every batch this reader yields.
    fn schema(&self) -> SchemaRef {
        let schema = self.converter.schema();
        schema.clone()
    }
}
/// Default cap on the number of rows fetched per batch (`u16::MAX`).
const DEFAULT_MAX_ROWS_PER_BATCH: usize = u16::MAX as usize;
/// Default cap on the memory used by the fetch buffers of one batch: 512 MiB.
const DEFAULT_MAX_BYTES_PER_BATCH: usize = 512 * 1024 * 1024;
/// Builder for [`OdbcReader`]. Configure batch size limits, schema overrides and text/binary
/// handling, then call `build` with a cursor to bind fetch buffers and obtain a reader.
#[derive(Clone)]
pub struct OdbcReaderBuilder {
    /// Upper limit on rows per batch. Defaults to `DEFAULT_MAX_ROWS_PER_BATCH`.
    max_num_rows_per_batch: usize,
    /// Upper limit on fetch-buffer bytes per batch. Defaults to `DEFAULT_MAX_BYTES_PER_BATCH`.
    max_bytes_per_batch: usize,
    /// Explicit Arrow schema; `None` means the schema is derived from the result set metadata.
    schema: Option<SchemaRef>,
    /// Optional upper bound (in characters/bytes — confirm unit in `ToRecordBatch`) for
    /// variadic text columns; `None` relies on the size reported by the data source.
    max_text_size: Option<usize>,
    /// Optional upper bound for variadic binary columns; `None` relies on reported sizes.
    max_binary_size: Option<usize>,
    /// If `true`, values that fail to convert are mapped to NULL instead of producing errors.
    map_value_errors_to_null: bool,
    /// Database system name used for driver-specific mapping decisions — see `ToRecordBatch`.
    dbms_name: Option<String>,
    // Presumably makes buffer allocation report failures as errors instead of aborting —
    // confirm against `allocate_buffer`. Name kept as-is ("fallibale") for API consistency.
    fallibale_allocations: bool,
    /// If `true`, trailing padding of fixed-size character columns (e.g. CHAR) is trimmed.
    trim_fixed_sized_character_strings: bool,
    /// Encoding used to transfer text payloads between driver and application.
    text_encoding: TextEncoding,
}
impl OdbcReaderBuilder {
    /// Create a builder with defaults: at most `u16::MAX` rows and 512 MiB per batch, no
    /// explicit schema, no size bounds for variadic text/binary columns, infallible
    /// allocations, value errors reported as errors, no trimming and automatic text encoding.
    pub fn new() -> Self {
        OdbcReaderBuilder {
            max_num_rows_per_batch: DEFAULT_MAX_ROWS_PER_BATCH,
            max_bytes_per_batch: DEFAULT_MAX_BYTES_PER_BATCH,
            schema: None,
            max_text_size: None,
            max_binary_size: None,
            fallibale_allocations: false,
            map_value_errors_to_null: false,
            dbms_name: None,
            trim_fixed_sized_character_strings: false,
            text_encoding: TextEncoding::Auto,
        }
    }

    /// Limit the number of rows fetched per batch.
    pub fn with_max_num_rows_per_batch(&mut self, max_num_rows_per_batch: usize) -> &mut Self {
        self.max_num_rows_per_batch = max_num_rows_per_batch;
        self
    }

    /// Limit the memory (in bytes) used by the fetch buffers of one batch.
    pub fn with_max_bytes_per_batch(&mut self, max_bytes_per_batch: usize) -> &mut Self {
        self.max_bytes_per_batch = max_bytes_per_batch;
        self
    }

    /// Use an explicit Arrow schema instead of deriving one from the result set metadata.
    pub fn with_schema(&mut self, schema: SchemaRef) -> &mut Self {
        self.schema = Some(schema);
        self
    }

    /// Upper bound for the element size of variadic text columns.
    pub fn with_max_text_size(&mut self, max_text_size: usize) -> &mut Self {
        self.max_text_size = Some(max_text_size);
        self
    }

    /// Upper bound for the element size of variadic binary columns.
    pub fn with_max_binary_size(&mut self, max_binary_size: usize) -> &mut Self {
        self.max_binary_size = Some(max_binary_size);
        self
    }

    /// If `true`, buffer allocation failures are reported as errors rather than aborting.
    /// (The name "fallibale" is a historic typo kept for backwards compatibility.)
    pub fn with_fallibale_allocations(&mut self, fallibale_allocations: bool) -> &mut Self {
        self.fallibale_allocations = fallibale_allocations;
        self
    }

    /// If `true`, values that fail conversion are mapped to NULL instead of raising errors.
    pub fn value_errors_as_null(&mut self, map_value_errors_to_null: bool) -> &mut Self {
        self.map_value_errors_to_null = map_value_errors_to_null;
        self
    }

    /// If `true`, trailing padding of fixed-size character columns (e.g. CHAR) is trimmed.
    pub fn trim_fixed_sized_characters(
        &mut self,
        fixed_sized_character_strings_are_trimmed: bool,
    ) -> &mut Self {
        self.trim_fixed_sized_character_strings = fixed_sized_character_strings_are_trimmed;
        self
    }

    /// Choose the text encoding used to transfer payloads from the data source.
    pub fn with_payload_text_encoding(&mut self, text_encoding: TextEncoding) -> &mut Self {
        self.text_encoding = text_encoding;
        self
    }

    /// Name of the database system, enabling driver-specific type mapping decisions.
    pub fn with_dbms_name(&mut self, dbms_name: String) -> &mut Self {
        self.dbms_name = Some(dbms_name);
        self
    }

    /// Number of rows the fetch buffer should hold, honoring both the row limit and the byte
    /// limit.
    ///
    /// # Errors
    ///
    /// `Error::OdbcBufferTooSmall` if even a single row would exceed `max_bytes_per_batch`.
    fn buffer_size_in_rows(&self, bytes_per_row: usize) -> Result<usize, Error> {
        if bytes_per_row == 0 {
            // A row occupying zero bytes (e.g. a result set without columns) imposes no
            // memory constraint, so only the configured row limit applies. Returning the
            // byte limit here would wrongly treat a byte count as a row count.
            return Ok(self.max_num_rows_per_batch);
        }
        let rows_per_batch = self.max_bytes_per_batch / bytes_per_row;
        if rows_per_batch == 0 {
            // Not even one row fits within the memory limit.
            Err(Error::OdbcBufferTooSmall {
                max_bytes_per_batch: self.max_bytes_per_batch,
                bytes_per_row,
            })
        } else {
            Ok(min(self.max_num_rows_per_batch, rows_per_batch))
        }
    }

    /// Determine the schema, allocate fetch buffers sized from the configured limits, bind
    /// them to `cursor` and construct an [`OdbcReader`].
    ///
    /// # Errors
    ///
    /// Fails if the schema cannot be determined or mapped, if a single row would exceed
    /// `max_bytes_per_batch`, or if (fallible) buffer allocation fails.
    pub fn build<C>(&self, mut cursor: C) -> Result<OdbcReader<C>, Error>
    where
        C: Cursor,
    {
        let buffer_allocation_options = BufferAllocationOptions {
            max_text_size: self.max_text_size,
            max_binary_size: self.max_binary_size,
            fallibale_allocations: self.fallibale_allocations,
        };
        let converter = ToRecordBatch::new(
            &mut cursor,
            self.schema.clone(),
            buffer_allocation_options,
            self.map_value_errors_to_null,
            self.dbms_name.as_deref(),
            self.trim_fixed_sized_character_strings,
            self.text_encoding,
        )?;
        let bytes_per_row = converter.row_size_in_bytes();
        let buffer_size_in_rows = self.buffer_size_in_rows(bytes_per_row)?;
        let row_set_buffer =
            converter.allocate_buffer(buffer_size_in_rows, self.fallibale_allocations)?;
        // NOTE(review): binding a buffer whose description was derived from this very cursor
        // is not expected to fail; a panic here would indicate a bug in buffer allocation.
        // Confirm this invariant holds across drivers.
        let batch_stream = cursor
            .bind_buffer(row_set_buffer)
            .expect("buffer derived from cursor metadata must bind to the cursor");
        Ok(OdbcReader {
            converter,
            batch_stream,
            fallibale_allocations: self.fallibale_allocations,
        })
    }
}
/// Wrap an ODBC diagnostic into an [`ArrowError`] so it can travel through the
/// [`RecordBatchReader`] interface.
pub fn odbc_to_arrow_error(odbc_error: odbc_api::Error) -> ArrowError {
    ArrowError::ExternalError(Box::new(odbc_error))
}
impl Default for OdbcReaderBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::{DEFAULT_MAX_BYTES_PER_BATCH, DEFAULT_MAX_ROWS_PER_BATCH, OdbcReaderBuilder};

    /// A default-constructed builder must start from the documented batch limits.
    #[test]
    fn default_constructed_builder() {
        let def = OdbcReaderBuilder::default();
        assert_eq!(def.max_num_rows_per_batch, DEFAULT_MAX_ROWS_PER_BATCH);
        assert_eq!(def.max_bytes_per_batch, DEFAULT_MAX_BYTES_PER_BATCH);
    }
}