1use std::{borrow::Cow, cmp::min, sync::Arc};
2
3use thiserror::Error;
4
5use arrow::{
6 array::Array,
7 datatypes::{
8 DataType, Date32Type, Date64Type, Field, Float16Type, Float32Type, Float64Type, Int8Type,
9 Int16Type, Int32Type, Int64Type, Schema, Time32MillisecondType, Time32SecondType,
10 Time64MicrosecondType, Time64NanosecondType, TimeUnit, UInt8Type,
11 },
12 error::ArrowError,
13 record_batch::{RecordBatch, RecordBatchReader},
14};
15use odbc_api::{
16 BindParamDesc, ColumnarBulkInserter, Connection, ConnectionTransitions, Prepared,
17 buffers::{AnyBuffer, AnySliceMut},
18 handles::{AsStatementRef, StatementConnection, StatementImpl, StatementParent},
19 parameter::WithDataType,
20};
21
22use crate::{
23 date_time::{NullableTimeAsText, epoch_to_date, sec_since_midnight_to_time},
24 decimal::{NullableDecimal128AsText, NullableDecimal256AsText},
25 odbc_writer::timestamp::insert_timestamp_strategy,
26};
27
28use self::{
29 binary::VariadicBinary,
30 boolean::boolean_to_bit,
31 map_arrow_to_odbc::MapArrowToOdbc,
32 text::{LargeUtf8ToNativeText, Utf8ToNativeText},
33};
34
35mod binary;
36mod boolean;
37mod map_arrow_to_odbc;
38mod text;
39mod timestamp;
40
41pub fn insert_into_table(
50 connection: &Connection,
51 batches: &mut impl RecordBatchReader,
52 table_name: &str,
53 batch_size: usize,
54) -> Result<(), WriterError> {
55 let schema = batches.schema();
56 let mut inserter =
57 OdbcWriter::with_connection(connection, schema.as_ref(), table_name, batch_size)?;
58 inserter.write_all(batches)
59}
60
61fn insert_statement_text(table: &str, column_names: &[&'_ str]) -> String {
65 let column_names = column_names
67 .iter()
68 .map(|cn| quote_column_name(cn))
69 .collect::<Vec<_>>();
70 let columns = column_names.join(", ");
71 let values = column_names
72 .iter()
73 .map(|_| "?")
74 .collect::<Vec<_>>()
75 .join(", ");
76 format!("INSERT INTO {table} ({columns}) VALUES ({values})")
80}
81
82fn quote_column_name(column_name: &str) -> Cow<'_, str> {
84 let is_already_quoted = || {
93 (column_name.starts_with('"') && column_name.ends_with('"'))
94 || column_name.starts_with('[') && column_name.ends_with(']')
95 || column_name.starts_with('`') && column_name.ends_with('`')
96 };
97 let contains_invalid_characters = || column_name.contains(|c| !valid_in_column_name(c));
98 let needs_quotes = contains_invalid_characters() && !is_already_quoted();
99 if needs_quotes {
100 Cow::Owned(format!("\"{column_name}\""))
101 } else {
102 Cow::Borrowed(column_name)
103 }
104}
105
/// `true` if `c` may appear in a column name without forcing the name to be quoted.
///
/// Alphanumeric characters plus `@`, `$`, `#` and `_` are accepted.
fn valid_in_column_name(c: char) -> bool {
    matches!(c, '@' | '$' | '#' | '_') || c.is_alphanumeric()
}
112
113pub fn insert_statement_from_schema(schema: &Schema, table_name: &str) -> String {
141 let fields = schema.fields();
142 let num_columns = fields.len();
143 let column_names: Vec<_> = (0..num_columns)
144 .map(|i| fields[i].name().as_str())
145 .collect();
146 insert_statement_text(table_name, &column_names)
147}
148
149#[derive(Debug, Error)]
151pub enum WriterError {
152 #[error("Failure to bind the array parameter buffers to the statement.\n{0}")]
153 BindParameterBuffers(#[source] odbc_api::Error),
154 #[error("Failure to execute the sql statement, sending the data to the database.\n{0}")]
155 ExecuteStatment(#[source] odbc_api::Error),
156 #[error("An error occured rebinding a parameter buffer to the sql statement.\n{0}")]
157 RebindBuffer(#[source] odbc_api::Error),
158 #[error("The arrow data type {0} is not supported for insertion.")]
159 UnsupportedArrowDataType(DataType),
160 #[error("An error occured extracting a record batch from an error reader.\n{0}")]
161 ReadingRecordBatch(#[source] ArrowError),
162 #[error("Unable to parse '{time_zone}' into a valid IANA time zone.")]
163 InvalidTimeZone { time_zone: Arc<str> },
164 #[error("An error occurred preparing SQL statement. SQL:\n{sql}\n{source}")]
165 PreparingInsertStatement {
166 #[source]
167 source: odbc_api::Error,
168 sql: String,
169 },
170}
171
/// Writes arrow record batches into a database table through a prepared ODBC `INSERT`
/// statement with bound columnar parameter buffers.
///
/// `S` is the underlying statement type; see the constructors for concrete instantiations.
pub struct OdbcWriter<S> {
    // Prepared statement with one bound parameter buffer per column.
    inserter: ColumnarBulkInserter<S, WithDataType<AnyBuffer>>,
    // One strategy per column, in schema order; each knows how to copy values from an arrow
    // array into its matching parameter buffer.
    strategies: Vec<Box<dyn WriteStrategy>>,
}
183
184impl<S> OdbcWriter<S>
185where
186 S: AsStatementRef,
187{
188 pub fn new(
209 row_capacity: usize,
210 schema: &Schema,
211 statement: Prepared<S>,
212 ) -> Result<Self, WriterError> {
213 let strategies: Vec<_> = schema
214 .fields()
215 .iter()
216 .map(|field| field_to_write_strategy(field.as_ref()))
217 .collect::<Result<_, _>>()?;
218 let descriptions = strategies.iter().map(|cws| cws.buffer_desc());
219 let inserter = statement
220 .into_column_inserter(row_capacity, descriptions)
221 .map_err(WriterError::BindParameterBuffers)?;
222
223 Ok(Self {
224 inserter,
225 strategies,
226 })
227 }
228
229 pub fn write_all(
232 &mut self,
233 reader: impl Iterator<Item = Result<RecordBatch, ArrowError>>,
234 ) -> Result<(), WriterError> {
235 for result in reader {
236 let record_batch = result.map_err(WriterError::ReadingRecordBatch)?;
237 self.write_batch(&record_batch)?;
238 }
239 self.flush()?;
240 Ok(())
241 }
242
243 pub fn write_batch(&mut self, record_batch: &RecordBatch) -> Result<(), WriterError> {
246 let capacity = self.inserter.capacity();
247 let mut remanining_rows = record_batch.num_rows();
248 while remanining_rows != 0 {
252 let chunk_size = min(capacity - self.inserter.num_rows(), remanining_rows);
253 let param_offset = self.inserter.num_rows();
254 self.inserter.set_num_rows(param_offset + chunk_size);
255 let chunk = record_batch.slice(record_batch.num_rows() - remanining_rows, chunk_size);
256 for (index, (array, strategy)) in chunk
257 .columns()
258 .iter()
259 .zip(self.strategies.iter())
260 .enumerate()
261 {
262 strategy.write_rows(param_offset, self.inserter.column_mut(index), array)?
263 }
264
265 if self.inserter.num_rows() == capacity {
268 self.flush()?;
269 }
270 remanining_rows -= chunk_size;
271 }
272
273 Ok(())
274 }
275
276 pub fn flush(&mut self) -> Result<(), WriterError> {
282 self.inserter
283 .execute()
284 .map_err(WriterError::ExecuteStatment)?;
285 self.inserter.clear();
286 Ok(())
287 }
288}
289
290impl<C> OdbcWriter<StatementConnection<C>>
291where
292 C: StatementParent,
293{
294 pub fn from_connection<C2>(
303 connection: C2,
304 schema: &Schema,
305 table_name: &str,
306 row_capacity: usize,
307 ) -> Result<Self, WriterError>
308 where
309 C2: ConnectionTransitions<StatementParent = C>,
310 {
311 let sql = insert_statement_from_schema(schema, table_name);
312 let statement = connection
313 .into_prepared(&sql)
314 .map_err(|source| WriterError::PreparingInsertStatement { source, sql })?;
315 Self::new(row_capacity, schema, statement)
316 }
317}
318
319impl<'o> OdbcWriter<StatementImpl<'o>> {
320 pub fn with_connection(
329 connection: &'o Connection<'o>,
330 schema: &Schema,
331 table_name: &str,
332 row_capacity: usize,
333 ) -> Result<Self, WriterError> {
334 let sql = insert_statement_from_schema(schema, table_name);
335 let statement = connection
336 .prepare(&sql)
337 .map_err(|source| WriterError::PreparingInsertStatement { source, sql })?;
338 Self::new(row_capacity, schema, statement)
339 }
340}
341
/// Strategy for transferring the values of one arrow column into an ODBC parameter buffer.
pub trait WriteStrategy {
    /// Describes the parameter buffer which must be bound to the statement for this column.
    fn buffer_desc(&self) -> BindParamDesc;

    /// Copies the values of `array` into `column_buf`.
    ///
    /// * `param_offset`: Row position within the buffer at which to start writing; the number
    ///   of rows already placed in the buffer by previous chunks.
    /// * `column_buf`: The bound parameter buffer to write into.
    /// * `array`: Arrow array holding this column's values for the current chunk.
    fn write_rows(
        &self,
        param_offset: usize,
        column_buf: AnySliceMut<'_>,
        array: &dyn Array,
    ) -> Result<(), WriterError>;
}
359
/// Chooses a [`WriteStrategy`] for `field` based on its arrow data type and nullability.
///
/// Returns [`WriterError::UnsupportedArrowDataType`] for arrow types without a mapping.
fn field_to_write_strategy(field: &Field) -> Result<Box<dyn WriteStrategy>, WriterError> {
    let is_nullable = field.is_nullable();
    let strategy = match field.data_type() {
        DataType::Utf8 => Box::new(Utf8ToNativeText {}),
        DataType::Boolean => boolean_to_bit(is_nullable),
        DataType::LargeUtf8 => Box::new(LargeUtf8ToNativeText {}),
        // Integer and float types map to buffers of the same representation.
        DataType::Int8 => Int8Type::identical(is_nullable),
        DataType::Int16 => Int16Type::identical(is_nullable),
        DataType::Int32 => Int32Type::identical(is_nullable),
        DataType::Int64 => Int64Type::identical(is_nullable),
        DataType::UInt8 => UInt8Type::identical(is_nullable),
        // Half precision is widened to f32 element-wise.
        DataType::Float16 => Float16Type::map_with(is_nullable, |half| half.to_f32()),
        DataType::Float32 => Float32Type::identical(is_nullable),
        DataType::Float64 => Float64Type::identical(is_nullable),
        DataType::Timestamp(time_unit, time_zone) => {
            insert_timestamp_strategy(is_nullable, &time_unit, time_zone.clone())?
        }
        DataType::Date32 => Date32Type::map_with(is_nullable, epoch_to_date),
        // NOTE(review): Arrow defines Date64 as *milliseconds* since epoch, but the closure
        // parameter is named `days_since_epoch` and is handed to `epoch_to_date` unscaled —
        // confirm `epoch_to_date` really expects the Date64 native value here.
        DataType::Date64 => Date64Type::map_with(is_nullable, |days_since_epoch| {
            epoch_to_date(days_since_epoch.try_into().unwrap())
        }),
        DataType::Time32(TimeUnit::Second) => {
            Time32SecondType::map_with(is_nullable, sec_since_midnight_to_time)
        }
        // Sub-second time resolutions are transferred as text.
        DataType::Time32(TimeUnit::Millisecond) => {
            Box::new(NullableTimeAsText::<Time32MillisecondType>::new())
        }
        DataType::Time64(TimeUnit::Microsecond) => {
            Box::new(NullableTimeAsText::<Time64MicrosecondType>::new())
        }
        DataType::Time64(TimeUnit::Nanosecond) => {
            Box::new(NullableTimeAsText::<Time64NanosecondType>::new())
        }
        DataType::Binary => Box::new(VariadicBinary::new(1)),
        DataType::FixedSizeBinary(length) => {
            Box::new(VariadicBinary::new((*length).try_into().unwrap()))
        }
        // Decimals are rendered as text to preserve precision and scale.
        DataType::Decimal128(precision, scale) => {
            Box::new(NullableDecimal128AsText::new(*precision, *scale))
        }
        DataType::Decimal256(precision, scale) => {
            Box::new(NullableDecimal256AsText::new(*precision, *scale))
        }
        unsupported => return Err(WriterError::UnsupportedArrowDataType(unsupported.clone())),
    };
    Ok(strategy)
}