use std::cmp::min;
use thiserror::Error;
use arrow::{
array::Array,
datatypes::{
DataType, Date32Type, Date64Type, Field, Float16Type, Float32Type, Float64Type, Int16Type,
Int32Type, Int64Type, Int8Type, Schema, Time32MillisecondType, Time32SecondType,
Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt8Type,
},
error::ArrowError,
record_batch::{RecordBatch, RecordBatchReader},
};
use odbc_api::{
buffers::{AnyBuffer, AnySliceMut, BufferDesc},
handles::{AsStatementRef, StatementImpl},
ColumnarBulkInserter, Connection, Prepared, StatementConnection,
};
use crate::{
date_time::{
epoch_to_date, epoch_to_timestamp, sec_since_midnight_to_time, NullableTimeAsText,
},
decimal::{NullableDecimal128AsText, NullableDecimal256AsText},
};
use self::{
binary::VariadicBinary,
boolean::boolean_to_bit,
map_arrow_to_odbc::MapArrowToOdbc,
text::{LargeUtf8ToNativeText, Utf8ToNativeText},
};
mod binary;
mod boolean;
mod map_arrow_to_odbc;
mod text;
/// Fastest and most convenient way to stream the contents of arrow record batches into a database
/// table. For use cases where you want to insert repeatedly into the same table from different
/// streams, it is more efficient to create an instance of [`self::OdbcWriter`] and reuse it.
///
/// **Note:**
///
/// If table or column names are derived from user input, be sure to sanitize the input in order to
/// prevent SQL injection attacks.
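///
/// # Example
///
/// A minimal sketch of typical usage; the table name `MyTable` and the chunk size of `1000` rows
/// are illustrative assumptions, and any implementation of `RecordBatchReader` works as input.
///
/// ```no_run
/// use arrow_odbc::{insert_into_table, WriterError};
/// use arrow_odbc::{arrow::record_batch::RecordBatchReader, odbc_api::Connection};
///
/// fn copy_into_my_table(
///     connection: &Connection<'_>,
///     batches: &mut impl RecordBatchReader,
/// ) -> Result<(), WriterError> {
///     // Rows are buffered and sent to the database in chunks of up to 1000 rows each.
///     insert_into_table(connection, batches, "MyTable", 1000)
/// }
/// ```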
pub fn insert_into_table(
connection: &Connection,
batches: &mut impl RecordBatchReader,
table_name: &str,
batch_size: usize,
) -> Result<(), WriterError> {
let schema = batches.schema();
let mut inserter =
OdbcWriter::with_connection(connection, schema.as_ref(), table_name, batch_size)?;
inserter.write_all(batches)
}
/// Generates an insert statement using the table and column names.
///
/// `INSERT INTO <table> (<column 0>, <column 1>, ...) VALUES (?, ?, ...);`
fn insert_statement_text(table: &str, column_names: &[&'_ str]) -> String {
    // Generate the statement text from the table name and the column names.
let columns = column_names.join(", ");
let values = column_names
.iter()
.map(|_| "?")
.collect::<Vec<_>>()
.join(", ");
format!("INSERT INTO {table} ({columns}) VALUES ({values});")
}
/// Creates an SQL insert statement from an arrow schema. The resulting statement will have one
/// placeholder (`?`) for each column in the schema.
///
/// **Note:**
///
/// If table or column names are derived from user input, be sure to sanitize the input in order to
/// prevent SQL injection attacks.
///
/// # Example
///
/// ```
/// use arrow_odbc::{
/// insert_statement_from_schema,
/// arrow::datatypes::{Field, DataType, Schema},
/// };
///
/// let field_a = Field::new("a", DataType::Int64, false);
/// let field_b = Field::new("b", DataType::Boolean, false);
///
/// let schema = Schema::new(vec![field_a, field_b]);
/// let sql = insert_statement_from_schema(&schema, "MyTable");
///
/// assert_eq!("INSERT INTO MyTable (a, b) VALUES (?, ?);", sql)
/// ```
///
/// This function is automatically invoked by [`crate::OdbcWriter::with_connection`] and
/// [`crate::OdbcWriter::from_connection`].
pub fn insert_statement_from_schema(schema: &Schema, table_name: &str) -> String {
let fields = schema.fields();
let num_columns = fields.len();
let column_names: Vec<_> = (0..num_columns)
.map(|i| fields[i].name().as_str())
.collect();
insert_statement_text(table_name, &column_names)
}
/// Errors which can be emitted while writing values from arrow arrays into a database table.
#[derive(Debug, Error)]
pub enum WriterError {
#[error("Failure to bind the array parameter buffers to the statement.")]
BindParameterBuffers(#[source] odbc_api::Error),
#[error("Failure to execute the sql statement, sending the data to the database.")]
ExecuteStatment(#[source] odbc_api::Error),
#[error("An error occured rebinding a parameter buffer to the sql statement.")]
RebindBuffer(#[source] odbc_api::Error),
#[error("The arrow data type {0} is not supported for insertion.")]
UnsupportedArrowDataType(DataType),
#[error("An error occured extracting a record batch from an error reader.")]
ReadingRecordBatch(#[source] ArrowError),
#[error("An error occurred preparing SQL statement: {0}", sql)]
PreparingInsertStatement {
#[source]
source: odbc_api::Error,
sql: String,
},
    #[deprecated(note = "Use variant `UnsupportedArrowDataType` instead")]
    #[error("Inserting arrays with time zone information is currently not supported.")]
TimeZonesNotSupported,
}
/// Inserts batches from an [`arrow::record_batch::RecordBatchReader`] into a database.
pub struct OdbcWriter<S> {
/// Prepared statement with bound array parameter buffers. Data is copied into these buffers
/// until they are full. Then we execute the statement. This is repeated until we run out of
/// data.
inserter: ColumnarBulkInserter<S, AnyBuffer>,
    /// For each field in the arrow schema we decide which buffer to use to send the parameters to
    /// the database, and remember how to copy the data from an arrow array into an ODBC mutable
    /// buffer slice for each column.
strategies: Vec<Box<dyn WriteStrategy>>,
}
impl<S> OdbcWriter<S>
where
S: AsStatementRef,
{
    /// Construct a new ODBC writer using an already existing prepared statement. Usually you want
    /// to call a higher level constructor like [`Self::with_connection`]. Yet, this constructor is
    /// useful in two scenarios.
    ///
    /// 1. The prepared statement is already constructed and you do not want to spend the time to
    ///    prepare it again.
    /// 2. You want to use the arrow arrays as array parameters for a statement, but that statement
    ///    is not necessarily an INSERT statement with a simple one-to-one mapping of columns
    ///    between table and arrow schema.
    ///
    /// # Parameters
    ///
    /// * `row_capacity`: The number of rows sent to the database in each chunk. With the exception
    ///   of the last chunk, which may be smaller.
    /// * `schema`: Schema needs to have one column for each positional parameter of the statement
    ///   and match the data which will be supplied to the instance later. Otherwise your code will
    ///   panic.
    /// * `statement`: A prepared statement whose SQL text representation contains one placeholder
    ///   for each column. The order of the placeholders must correspond to the order of the
    ///   columns in the `schema`.
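    ///
    /// # Example
    ///
    /// A minimal sketch, assuming an open connection; the statement text, the table `MyTable`,
    /// and the row capacity of `1000` are illustrative assumptions.
    ///
    /// ```no_run
    /// use arrow_odbc::{OdbcWriter, arrow::datatypes::Schema, odbc_api::Connection};
    ///
    /// fn writer_for_custom_statement(
    ///     connection: &Connection<'_>,
    ///     schema: &Schema,
    /// ) -> Result<(), Box<dyn std::error::Error>> {
    ///     // Any statement with one placeholder per schema column works, not only a plain INSERT.
    ///     let statement = connection.prepare("INSERT INTO MyTable (a, b) VALUES (?, ?)")?;
    ///     let mut writer = OdbcWriter::new(1000, schema, statement)?;
    ///     // ... write record batches here, then send any buffered rows.
    ///     writer.flush()?;
    ///     Ok(())
    /// }
    /// ```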
pub fn new(
row_capacity: usize,
schema: &Schema,
statement: Prepared<S>,
) -> Result<Self, WriterError> {
let strategies: Vec<_> = schema
.fields()
.iter()
.map(|field| field_to_write_strategy(field.as_ref()))
.collect::<Result<_, _>>()?;
let descriptions = strategies.iter().map(|cws| cws.buffer_desc());
let inserter = statement
.into_column_inserter(row_capacity, descriptions)
.map_err(WriterError::BindParameterBuffers)?;
Ok(Self {
inserter,
strategies,
})
}
/// Consumes all the batches in the record batch reader and sends them chunk by chunk to the
/// database.
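    ///
    /// # Example
    ///
    /// A sketch of reusing one writer for several readers, which avoids preparing the insert
    /// statement and allocating the parameter buffers a second time; the reader arguments are
    /// placeholders for any batch source.
    ///
    /// ```no_run
    /// use arrow_odbc::{OdbcWriter, WriterError};
    /// use arrow_odbc::{arrow::record_batch::RecordBatchReader, odbc_api::handles::StatementImpl};
    ///
    /// fn insert_from_two_streams(
    ///     writer: &mut OdbcWriter<StatementImpl<'_>>,
    ///     first: &mut impl RecordBatchReader,
    ///     second: &mut impl RecordBatchReader,
    /// ) -> Result<(), WriterError> {
    ///     // `write_all` flushes at the end, so both calls leave the buffers empty.
    ///     writer.write_all(first)?;
    ///     writer.write_all(second)
    /// }
    /// ```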
pub fn write_all(
&mut self,
reader: impl Iterator<Item = Result<RecordBatch, ArrowError>>,
) -> Result<(), WriterError> {
for result in reader {
let record_batch = result.map_err(WriterError::ReadingRecordBatch)?;
self.write_batch(&record_batch)?;
}
self.flush()?;
Ok(())
}
    /// Consumes a single batch and sends it chunk by chunk to the database. The last chunk may not
    /// be sent until [`Self::flush`] is called.
pub fn write_batch(&mut self, record_batch: &RecordBatch) -> Result<(), WriterError> {
let capacity = self.inserter.capacity();
        let mut remaining_rows = record_batch.num_rows();
        // The record batch may contain more rows than the capacity of our writer can hold. So we
        // need to be able to fill the buffers multiple times and send them to the database in
        // between.
        while remaining_rows != 0 {
            let chunk_size = min(capacity - self.inserter.num_rows(), remaining_rows);
            let param_offset = self.inserter.num_rows();
            self.inserter.set_num_rows(param_offset + chunk_size);
            let chunk = record_batch.slice(record_batch.num_rows() - remaining_rows, chunk_size);
            for (index, (array, strategy)) in chunk
                .columns()
                .iter()
                .zip(self.strategies.iter())
                .enumerate()
            {
                strategy.write_rows(param_offset, self.inserter.column_mut(index), array)?;
            }
            // If we used up all capacity we send the parameters to the database and reset the
            // parameter buffers.
            if self.inserter.num_rows() == capacity {
                self.flush()?;
            }
            remaining_rows -= chunk_size;
        }
Ok(())
}
    /// The number of rows in an individual record batch does not necessarily match the capacity of
    /// the buffers owned by this writer. Therefore records are sometimes not sent to the database
    /// immediately; instead we wait for the buffers to be filled by reading the next batch. Once we
    /// reach the last batch, however, there is no "next batch" anymore. In that case we call this
    /// method in order to send the remaining records to the database as well.
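    ///
    /// # Example
    ///
    /// A sketch of feeding batches one at a time and flushing once at the end; the slice of
    /// batches stands in for whatever source produces them.
    ///
    /// ```no_run
    /// use arrow_odbc::{OdbcWriter, WriterError};
    /// use arrow_odbc::{arrow::record_batch::RecordBatch, odbc_api::handles::StatementImpl};
    ///
    /// fn send_batches(
    ///     writer: &mut OdbcWriter<StatementImpl<'_>>,
    ///     batches: &[RecordBatch],
    /// ) -> Result<(), WriterError> {
    ///     for batch in batches {
    ///         // Executes the statement whenever the parameter buffers run full.
    ///         writer.write_batch(batch)?;
    ///     }
    ///     // The buffers may still hold a partially filled chunk; send it now.
    ///     writer.flush()
    /// }
    /// ```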
pub fn flush(&mut self) -> Result<(), WriterError> {
self.inserter
.execute()
.map_err(WriterError::ExecuteStatment)?;
self.inserter.clear();
Ok(())
}
}
impl<'env> OdbcWriter<StatementConnection<'env>> {
    /// A writer which takes ownership of the connection and inserts record batches matching the
    /// given schema into a table with matching column names.
    ///
    /// **Note:**
    ///
    /// If table or column names are derived from user input, be sure to sanitize the input in
    /// order to prevent SQL injection attacks.
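    ///
    /// # Example
    ///
    /// A minimal sketch; the table name `MyTable` and the row capacity of `1000` are illustrative
    /// assumptions. Since the writer owns the connection, it can be returned from the function
    /// which created it.
    ///
    /// ```no_run
    /// use arrow_odbc::{OdbcWriter, WriterError};
    /// use arrow_odbc::{arrow::datatypes::Schema, odbc_api::{Connection, StatementConnection}};
    ///
    /// fn owned_writer<'env>(
    ///     connection: Connection<'env>,
    ///     schema: &Schema,
    /// ) -> Result<OdbcWriter<StatementConnection<'env>>, WriterError> {
    ///     OdbcWriter::from_connection(connection, schema, "MyTable", 1000)
    /// }
    /// ```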
pub fn from_connection(
connection: Connection<'env>,
schema: &Schema,
table_name: &str,
row_capacity: usize,
) -> Result<Self, WriterError> {
let sql = insert_statement_from_schema(schema, table_name);
let statement = connection
.into_prepared(&sql)
.map_err(|source| WriterError::PreparingInsertStatement { source, sql })?;
Self::new(row_capacity, schema, statement)
}
}
impl<'o> OdbcWriter<StatementImpl<'o>> {
    /// A writer which borrows the connection and inserts record batches matching the given schema
    /// into a table with matching column names.
    ///
    /// **Note:**
    ///
    /// If table or column names are derived from user input, be sure to sanitize the input in
    /// order to prevent SQL injection attacks.
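    ///
    /// # Example
    ///
    /// A minimal sketch wiring a reader to a borrowed connection; the table name `MyTable` and
    /// the row capacity of `1000` are illustrative assumptions.
    ///
    /// ```no_run
    /// use arrow_odbc::OdbcWriter;
    /// use arrow_odbc::{arrow::record_batch::RecordBatchReader, odbc_api::Connection};
    ///
    /// fn insert_all<'o>(
    ///     connection: &'o Connection<'o>,
    ///     reader: &mut impl RecordBatchReader,
    /// ) -> Result<(), Box<dyn std::error::Error>> {
    ///     let schema = reader.schema();
    ///     let mut writer =
    ///         OdbcWriter::with_connection(connection, schema.as_ref(), "MyTable", 1000)?;
    ///     writer.write_all(reader)?;
    ///     Ok(())
    /// }
    /// ```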
pub fn with_connection(
connection: &'o Connection<'o>,
schema: &Schema,
table_name: &str,
row_capacity: usize,
) -> Result<Self, WriterError> {
let sql = insert_statement_from_schema(schema, table_name);
let statement = connection
.prepare(&sql)
.map_err(|source| WriterError::PreparingInsertStatement { source, sql })?;
Self::new(row_capacity, schema, statement)
}
}
pub trait WriteStrategy {
/// Describe the buffer used to hold the array parameters for the column
fn buffer_desc(&self) -> BufferDesc;
/// # Parameters
///
/// * `param_offset`: Start writing parameters at that position. Number of rows in the parameter
/// buffer before inserting the current chunk.
/// * `column_buf`: Buffer to write the data into
/// * `array`: Buffer to read the data from
fn write_rows(
&self,
param_offset: usize,
column_buf: AnySliceMut<'_>,
array: &dyn Array,
) -> Result<(), WriterError>;
}
fn field_to_write_strategy(field: &Field) -> Result<Box<dyn WriteStrategy>, WriterError> {
let is_nullable = field.is_nullable();
let strategy = match field.data_type() {
DataType::Utf8 => Box::new(Utf8ToNativeText {}),
DataType::Boolean => boolean_to_bit(is_nullable),
DataType::LargeUtf8 => Box::new(LargeUtf8ToNativeText {}),
DataType::Int8 => Int8Type::identical(is_nullable),
DataType::Int16 => Int16Type::identical(is_nullable),
DataType::Int32 => Int32Type::identical(is_nullable),
DataType::Int64 => Int64Type::identical(is_nullable),
DataType::UInt8 => UInt8Type::identical(is_nullable),
DataType::Float16 => Float16Type::map_with(is_nullable, |half| half.to_f32()),
DataType::Float32 => Float32Type::identical(is_nullable),
DataType::Float64 => Float64Type::identical(is_nullable),
DataType::Timestamp(TimeUnit::Second, None) => {
TimestampSecondType::map_with(is_nullable, epoch_to_timestamp::<1>)
}
DataType::Timestamp(TimeUnit::Millisecond, None) => {
TimestampMillisecondType::map_with(is_nullable, epoch_to_timestamp::<1_000>)
}
DataType::Timestamp(TimeUnit::Microsecond, None) => {
TimestampMicrosecondType::map_with(is_nullable, epoch_to_timestamp::<1_000_000>)
}
DataType::Timestamp(TimeUnit::Nanosecond, None) => {
TimestampNanosecondType::map_with(is_nullable, |ns| {
                // Drop the last two digits of precision, since we bind it with precision 7 and not 9.
epoch_to_timestamp::<10_000_000>(ns / 100)
})
}
DataType::Date32 => Date32Type::map_with(is_nullable, epoch_to_date),
        DataType::Date64 => Date64Type::map_with(is_nullable, |ms_since_epoch| {
            // Date64 stores milliseconds since the UNIX epoch. Convert to days before mapping to
            // an ODBC date, so the value fits an `i32` and matches what `epoch_to_date` expects.
            epoch_to_date((ms_since_epoch / 86_400_000).try_into().unwrap())
        }),
DataType::Time32(TimeUnit::Second) => {
Time32SecondType::map_with(is_nullable, sec_since_midnight_to_time)
}
DataType::Time32(TimeUnit::Millisecond) => {
Box::new(NullableTimeAsText::<Time32MillisecondType>::new())
}
DataType::Time64(TimeUnit::Microsecond) => {
Box::new(NullableTimeAsText::<Time64MicrosecondType>::new())
}
DataType::Time64(TimeUnit::Nanosecond) => {
Box::new(NullableTimeAsText::<Time64NanosecondType>::new())
}
DataType::Binary => Box::new(VariadicBinary::new(1)),
DataType::FixedSizeBinary(length) => {
Box::new(VariadicBinary::new((*length).try_into().unwrap()))
}
DataType::Decimal128(precision, scale) => {
Box::new(NullableDecimal128AsText::new(*precision, *scale))
}
DataType::Decimal256(precision, scale) => {
Box::new(NullableDecimal256AsText::new(*precision, *scale))
}
unsupported => return Err(WriterError::UnsupportedArrowDataType(unsupported.clone())),
};
Ok(strategy)
}