1use std::{borrow::Cow, cmp::min};
2
3use thiserror::Error;
4
5use arrow::{
6 array::Array,
7 datatypes::{
8 DataType, Date32Type, Date64Type, Field, Float16Type, Float32Type, Float64Type, Int8Type,
9 Int16Type, Int32Type, Int64Type, Schema, Time32MillisecondType, Time32SecondType,
10 Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
11 TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt8Type,
12 },
13 error::ArrowError,
14 record_batch::{RecordBatch, RecordBatchReader},
15};
16use odbc_api::{
17 ColumnarBulkInserter, Connection, Prepared, StatementConnection,
18 buffers::{AnyBuffer, AnySliceMut, BufferDesc},
19 handles::{AsStatementRef, StatementImpl},
20};
21
22use crate::{
23 date_time::{
24 NullableTimeAsText, epoch_to_date, epoch_to_timestamp_ms, epoch_to_timestamp_ns,
25 epoch_to_timestamp_s, epoch_to_timestamp_us, sec_since_midnight_to_time,
26 },
27 decimal::{NullableDecimal128AsText, NullableDecimal256AsText},
28};
29
30use self::{
31 binary::VariadicBinary,
32 boolean::boolean_to_bit,
33 map_arrow_to_odbc::MapArrowToOdbc,
34 text::{LargeUtf8ToNativeText, Utf8ToNativeText},
35};
36
37mod binary;
38mod boolean;
39mod map_arrow_to_odbc;
40mod text;
41
42pub fn insert_into_table(
51 connection: &Connection,
52 batches: &mut impl RecordBatchReader,
53 table_name: &str,
54 batch_size: usize,
55) -> Result<(), WriterError> {
56 let schema = batches.schema();
57 let mut inserter =
58 OdbcWriter::with_connection(connection, schema.as_ref(), table_name, batch_size)?;
59 inserter.write_all(batches)
60}
61
62fn insert_statement_text(table: &str, column_names: &[&'_ str]) -> String {
66 let column_names = column_names
68 .iter()
69 .map(|cn| quote_column_name(cn))
70 .collect::<Vec<_>>();
71 let columns = column_names.join(", ");
72 let values = column_names
73 .iter()
74 .map(|_| "?")
75 .collect::<Vec<_>>()
76 .join(", ");
77 format!("INSERT INTO {table} ({columns}) VALUES ({values})")
81}
82
83fn quote_column_name(column_name: &str) -> Cow<'_, str> {
85 if column_name.contains(|c| !valid_in_column_name(c)) {
86 Cow::Owned(format!("\"{column_name}\""))
87 } else {
88 Cow::Borrowed(column_name)
89 }
90}
91
/// Tells whether `c` may appear in a column name which is emitted without quotes: any
/// alphanumeric character, or one of `@`, `$`, `#` and `_`.
fn valid_in_column_name(c: char) -> bool {
    matches!(c, '@' | '$' | '#' | '_') || c.is_alphanumeric()
}
98
99pub fn insert_statement_from_schema(schema: &Schema, table_name: &str) -> String {
127 let fields = schema.fields();
128 let num_columns = fields.len();
129 let column_names: Vec<_> = (0..num_columns)
130 .map(|i| fields[i].name().as_str())
131 .collect();
132 insert_statement_text(table_name, &column_names)
133}
134
135#[derive(Debug, Error)]
137pub enum WriterError {
138 #[error("Failure to bind the array parameter buffers to the statement.\n{0}")]
139 BindParameterBuffers(#[source] odbc_api::Error),
140 #[error("Failure to execute the sql statement, sending the data to the database.\n{0}")]
141 ExecuteStatment(#[source] odbc_api::Error),
142 #[error("An error occured rebinding a parameter buffer to the sql statement.\n{0}")]
143 RebindBuffer(#[source] odbc_api::Error),
144 #[error("The arrow data type {0} is not supported for insertion.")]
145 UnsupportedArrowDataType(DataType),
146 #[error("An error occured extracting a record batch from an error reader.\n{0}")]
147 ReadingRecordBatch(#[source] ArrowError),
148 #[error("An error occurred preparing SQL statement. SQL:\n{sql}\n{source}")]
149 PreparingInsertStatement {
150 #[source]
151 source: odbc_api::Error,
152 sql: String,
153 },
154}
155
/// Inserts arrow record batches into a database table through a prepared ODBC statement.
///
/// `S` is the statement handle type wrapped by the bulk inserter.
pub struct OdbcWriter<S> {
    /// Bulk inserter created from the prepared insert statement. It owns the column parameter
    /// buffers the rows are copied into before they are sent to the database.
    inserter: ColumnarBulkInserter<S, AnyBuffer>,
    /// One write strategy per field of the arrow schema, in schema order. Each strategy describes
    /// the buffer of its column and copies arrow array values into it.
    strategies: Vec<Box<dyn WriteStrategy>>,
}
167
168impl<S> OdbcWriter<S>
169where
170 S: AsStatementRef,
171{
172 pub fn new(
193 row_capacity: usize,
194 schema: &Schema,
195 statement: Prepared<S>,
196 ) -> Result<Self, WriterError> {
197 let strategies: Vec<_> = schema
198 .fields()
199 .iter()
200 .map(|field| field_to_write_strategy(field.as_ref()))
201 .collect::<Result<_, _>>()?;
202 let descriptions = strategies.iter().map(|cws| cws.buffer_desc());
203 let inserter = statement
204 .into_column_inserter(row_capacity, descriptions)
205 .map_err(WriterError::BindParameterBuffers)?;
206
207 Ok(Self {
208 inserter,
209 strategies,
210 })
211 }
212
213 pub fn write_all(
216 &mut self,
217 reader: impl Iterator<Item = Result<RecordBatch, ArrowError>>,
218 ) -> Result<(), WriterError> {
219 for result in reader {
220 let record_batch = result.map_err(WriterError::ReadingRecordBatch)?;
221 self.write_batch(&record_batch)?;
222 }
223 self.flush()?;
224 Ok(())
225 }
226
227 pub fn write_batch(&mut self, record_batch: &RecordBatch) -> Result<(), WriterError> {
230 let capacity = self.inserter.capacity();
231 let mut remanining_rows = record_batch.num_rows();
232 while remanining_rows != 0 {
236 let chunk_size = min(capacity - self.inserter.num_rows(), remanining_rows);
237 let param_offset = self.inserter.num_rows();
238 self.inserter.set_num_rows(param_offset + chunk_size);
239 let chunk = record_batch.slice(record_batch.num_rows() - remanining_rows, chunk_size);
240 for (index, (array, strategy)) in chunk
241 .columns()
242 .iter()
243 .zip(self.strategies.iter())
244 .enumerate()
245 {
246 strategy.write_rows(param_offset, self.inserter.column_mut(index), array)?
247 }
248
249 if self.inserter.num_rows() == capacity {
252 self.flush()?;
253 }
254 remanining_rows -= chunk_size;
255 }
256
257 Ok(())
258 }
259
260 pub fn flush(&mut self) -> Result<(), WriterError> {
266 self.inserter
267 .execute()
268 .map_err(WriterError::ExecuteStatment)?;
269 self.inserter.clear();
270 Ok(())
271 }
272}
273
274impl<'env> OdbcWriter<StatementConnection<'env>> {
275 pub fn from_connection(
284 connection: Connection<'env>,
285 schema: &Schema,
286 table_name: &str,
287 row_capacity: usize,
288 ) -> Result<Self, WriterError> {
289 let sql = insert_statement_from_schema(schema, table_name);
290 let statement = connection
291 .into_prepared(&sql)
292 .map_err(|source| WriterError::PreparingInsertStatement { source, sql })?;
293 Self::new(row_capacity, schema, statement)
294 }
295}
296
297impl<'o> OdbcWriter<StatementImpl<'o>> {
298 pub fn with_connection(
307 connection: &'o Connection<'o>,
308 schema: &Schema,
309 table_name: &str,
310 row_capacity: usize,
311 ) -> Result<Self, WriterError> {
312 let sql = insert_statement_from_schema(schema, table_name);
313 let statement = connection
314 .prepare(&sql)
315 .map_err(|source| WriterError::PreparingInsertStatement { source, sql })?;
316 Self::new(row_capacity, schema, statement)
317 }
318}
319
/// Knows how to transfer the values of one arrow array into the ODBC parameter buffer of the
/// corresponding column. The writer holds one strategy per field of the arrow schema.
pub trait WriteStrategy {
    /// Describes the buffer which holds the parameters of this column. The writer passes these
    /// descriptions on when it creates the column inserter.
    fn buffer_desc(&self) -> BufferDesc;

    /// Copies the values of `array` into `column_buf`.
    ///
    /// # Parameters
    ///
    /// * `param_offset`: Number of rows already in the parameter buffer before this chunk is
    ///   written, as passed by the writer.
    /// * `column_buf`: Mutable view on the parameter buffer of the column to write into.
    /// * `array`: Arrow array to read the values from.
    fn write_rows(
        &self,
        param_offset: usize,
        column_buf: AnySliceMut<'_>,
        array: &dyn Array,
    ) -> Result<(), WriterError>;
}
337
338fn field_to_write_strategy(field: &Field) -> Result<Box<dyn WriteStrategy>, WriterError> {
339 let is_nullable = field.is_nullable();
340 let strategy = match field.data_type() {
341 DataType::Utf8 => Box::new(Utf8ToNativeText {}),
342 DataType::Boolean => boolean_to_bit(is_nullable),
343 DataType::LargeUtf8 => Box::new(LargeUtf8ToNativeText {}),
344 DataType::Int8 => Int8Type::identical(is_nullable),
345 DataType::Int16 => Int16Type::identical(is_nullable),
346 DataType::Int32 => Int32Type::identical(is_nullable),
347 DataType::Int64 => Int64Type::identical(is_nullable),
348 DataType::UInt8 => UInt8Type::identical(is_nullable),
349 DataType::Float16 => Float16Type::map_with(is_nullable, |half| half.to_f32()),
350 DataType::Float32 => Float32Type::identical(is_nullable),
351 DataType::Float64 => Float64Type::identical(is_nullable),
352 DataType::Timestamp(TimeUnit::Second, None) => {
353 TimestampSecondType::map_with(is_nullable, epoch_to_timestamp_s)
354 }
355 DataType::Timestamp(TimeUnit::Millisecond, None) => {
356 TimestampMillisecondType::map_with(is_nullable, epoch_to_timestamp_ms)
357 }
358 DataType::Timestamp(TimeUnit::Microsecond, None) => {
359 TimestampMicrosecondType::map_with(is_nullable, epoch_to_timestamp_us)
360 }
361 DataType::Timestamp(TimeUnit::Nanosecond, None) => {
362 TimestampNanosecondType::map_with(is_nullable, |ns| {
363 epoch_to_timestamp_ns((ns / 100) * 100)
365 })
366 }
367 DataType::Date32 => Date32Type::map_with(is_nullable, epoch_to_date),
368 DataType::Date64 => Date64Type::map_with(is_nullable, |days_since_epoch| {
369 epoch_to_date(days_since_epoch.try_into().unwrap())
370 }),
371 DataType::Time32(TimeUnit::Second) => {
372 Time32SecondType::map_with(is_nullable, sec_since_midnight_to_time)
373 }
374 DataType::Time32(TimeUnit::Millisecond) => {
375 Box::new(NullableTimeAsText::<Time32MillisecondType>::new())
376 }
377 DataType::Time64(TimeUnit::Microsecond) => {
378 Box::new(NullableTimeAsText::<Time64MicrosecondType>::new())
379 }
380 DataType::Time64(TimeUnit::Nanosecond) => {
381 Box::new(NullableTimeAsText::<Time64NanosecondType>::new())
382 }
383 DataType::Binary => Box::new(VariadicBinary::new(1)),
384 DataType::FixedSizeBinary(length) => {
385 Box::new(VariadicBinary::new((*length).try_into().unwrap()))
386 }
387 DataType::Decimal128(precision, scale) => {
388 Box::new(NullableDecimal128AsText::new(*precision, *scale))
389 }
390 DataType::Decimal256(precision, scale) => {
391 Box::new(NullableDecimal256AsText::new(*precision, *scale))
392 }
393 unsupported => return Err(WriterError::UnsupportedArrowDataType(unsupported.clone())),
394 };
395 Ok(strategy)
396}