1use std::{borrow::Cow, cmp::min, sync::Arc};
2
3use thiserror::Error;
4
5use arrow::{
6 array::Array,
7 datatypes::{
8 DataType, Date32Type, Date64Type, Field, Float16Type, Float32Type, Float64Type, Int8Type,
9 Int16Type, Int32Type, Int64Type, Schema, Time32MillisecondType, Time32SecondType,
10 Time64MicrosecondType, Time64NanosecondType, TimeUnit, UInt8Type,
11 },
12 error::ArrowError,
13 record_batch::{RecordBatch, RecordBatchReader},
14};
15use odbc_api::{
16 ColumnarBulkInserter, Connection, ConnectionTransitions, Prepared,
17 buffers::{AnyBuffer, AnySliceMut, BufferDesc},
18 handles::{AsStatementRef, StatementConnection, StatementImpl, StatementParent},
19};
20
21use crate::{
22 date_time::{NullableTimeAsText, epoch_to_date, sec_since_midnight_to_time},
23 decimal::{NullableDecimal128AsText, NullableDecimal256AsText},
24 odbc_writer::timestamp::insert_timestamp_strategy,
25};
26
27use self::{
28 binary::VariadicBinary,
29 boolean::boolean_to_bit,
30 map_arrow_to_odbc::MapArrowToOdbc,
31 text::{LargeUtf8ToNativeText, Utf8ToNativeText},
32};
33
34mod binary;
35mod boolean;
36mod map_arrow_to_odbc;
37mod text;
38mod timestamp;
39
40pub fn insert_into_table(
49 connection: &Connection,
50 batches: &mut impl RecordBatchReader,
51 table_name: &str,
52 batch_size: usize,
53) -> Result<(), WriterError> {
54 let schema = batches.schema();
55 let mut inserter =
56 OdbcWriter::with_connection(connection, schema.as_ref(), table_name, batch_size)?;
57 inserter.write_all(batches)
58}
59
60fn insert_statement_text(table: &str, column_names: &[&'_ str]) -> String {
64 let column_names = column_names
66 .iter()
67 .map(|cn| quote_column_name(cn))
68 .collect::<Vec<_>>();
69 let columns = column_names.join(", ");
70 let values = column_names
71 .iter()
72 .map(|_| "?")
73 .collect::<Vec<_>>()
74 .join(", ");
75 format!("INSERT INTO {table} ({columns}) VALUES ({values})")
79}
80
/// Quotes `column_name` for use as an identifier in an INSERT statement.
///
/// Names consisting only of characters valid in an unquoted identifier are
/// returned as-is (borrowed). Names already wrapped in `"…"`, `[…]` or
/// `` `…` `` are passed through untouched, on the assumption the caller
/// quoted them deliberately. Anything else is wrapped in double quotes, with
/// embedded double quotes doubled so the resulting delimited identifier stays
/// well formed.
fn quote_column_name(column_name: &str) -> Cow<'_, str> {
    let is_already_quoted = || {
        (column_name.starts_with('"') && column_name.ends_with('"'))
            || column_name.starts_with('[') && column_name.ends_with(']')
            || column_name.starts_with('`') && column_name.ends_with('`')
    };
    let contains_invalid_characters = || column_name.contains(|c| !valid_in_column_name(c));
    let needs_quotes = contains_invalid_characters() && !is_already_quoted();
    if needs_quotes {
        // Double any embedded quote character, otherwise a name like
        // `my"col` would render as `"my"col"`, which is malformed SQL.
        Cow::Owned(format!("\"{}\"", column_name.replace('"', "\"\"")))
    } else {
        Cow::Borrowed(column_name)
    }
}

/// `true` for characters which may appear in an identifier without forcing
/// the whole name to be quoted: alphanumerics plus `@`, `$`, `#` and `_`.
fn valid_in_column_name(c: char) -> bool {
    c.is_alphanumeric() || c == '@' || c == '$' || c == '#' || c == '_'
}
111
112pub fn insert_statement_from_schema(schema: &Schema, table_name: &str) -> String {
140 let fields = schema.fields();
141 let num_columns = fields.len();
142 let column_names: Vec<_> = (0..num_columns)
143 .map(|i| fields[i].name().as_str())
144 .collect();
145 insert_statement_text(table_name, &column_names)
146}
147
148#[derive(Debug, Error)]
150pub enum WriterError {
151 #[error("Failure to bind the array parameter buffers to the statement.\n{0}")]
152 BindParameterBuffers(#[source] odbc_api::Error),
153 #[error("Failure to execute the sql statement, sending the data to the database.\n{0}")]
154 ExecuteStatment(#[source] odbc_api::Error),
155 #[error("An error occured rebinding a parameter buffer to the sql statement.\n{0}")]
156 RebindBuffer(#[source] odbc_api::Error),
157 #[error("The arrow data type {0} is not supported for insertion.")]
158 UnsupportedArrowDataType(DataType),
159 #[error("An error occured extracting a record batch from an error reader.\n{0}")]
160 ReadingRecordBatch(#[source] ArrowError),
161 #[error("Unable to parse '{time_zone}' into a valid IANA time zone.")]
162 InvalidTimeZone { time_zone: Arc<str> },
163 #[error("An error occurred preparing SQL statement. SQL:\n{sql}\n{source}")]
164 PreparingInsertStatement {
165 #[source]
166 source: odbc_api::Error,
167 sql: String,
168 },
169}
170
/// Writes arrow record batches to a database table through a prepared ODBC
/// INSERT statement with bound columnar transfer buffers.
pub struct OdbcWriter<S> {
    // Prepared statement plus its bound parameter buffers. `S` is the
    // statement handle type, which may or may not own its connection.
    inserter: ColumnarBulkInserter<S, AnyBuffer>,
    // One strategy per column, in the same order as the bound buffers,
    // translating arrow arrays into the matching ODBC buffer column.
    strategies: Vec<Box<dyn WriteStrategy>>,
}
182
183impl<S> OdbcWriter<S>
184where
185 S: AsStatementRef,
186{
187 pub fn new(
208 row_capacity: usize,
209 schema: &Schema,
210 statement: Prepared<S>,
211 ) -> Result<Self, WriterError> {
212 let strategies: Vec<_> = schema
213 .fields()
214 .iter()
215 .map(|field| field_to_write_strategy(field.as_ref()))
216 .collect::<Result<_, _>>()?;
217 let descriptions = strategies.iter().map(|cws| cws.buffer_desc());
218 let inserter = statement
219 .into_column_inserter(row_capacity, descriptions)
220 .map_err(WriterError::BindParameterBuffers)?;
221
222 Ok(Self {
223 inserter,
224 strategies,
225 })
226 }
227
228 pub fn write_all(
231 &mut self,
232 reader: impl Iterator<Item = Result<RecordBatch, ArrowError>>,
233 ) -> Result<(), WriterError> {
234 for result in reader {
235 let record_batch = result.map_err(WriterError::ReadingRecordBatch)?;
236 self.write_batch(&record_batch)?;
237 }
238 self.flush()?;
239 Ok(())
240 }
241
242 pub fn write_batch(&mut self, record_batch: &RecordBatch) -> Result<(), WriterError> {
245 let capacity = self.inserter.capacity();
246 let mut remanining_rows = record_batch.num_rows();
247 while remanining_rows != 0 {
251 let chunk_size = min(capacity - self.inserter.num_rows(), remanining_rows);
252 let param_offset = self.inserter.num_rows();
253 self.inserter.set_num_rows(param_offset + chunk_size);
254 let chunk = record_batch.slice(record_batch.num_rows() - remanining_rows, chunk_size);
255 for (index, (array, strategy)) in chunk
256 .columns()
257 .iter()
258 .zip(self.strategies.iter())
259 .enumerate()
260 {
261 strategy.write_rows(param_offset, self.inserter.column_mut(index), array)?
262 }
263
264 if self.inserter.num_rows() == capacity {
267 self.flush()?;
268 }
269 remanining_rows -= chunk_size;
270 }
271
272 Ok(())
273 }
274
275 pub fn flush(&mut self) -> Result<(), WriterError> {
281 self.inserter
282 .execute()
283 .map_err(WriterError::ExecuteStatment)?;
284 self.inserter.clear();
285 Ok(())
286 }
287}
288
289impl<C> OdbcWriter<StatementConnection<C>>
290where
291 C: StatementParent,
292{
293 pub fn from_connection<C2>(
302 connection: C2,
303 schema: &Schema,
304 table_name: &str,
305 row_capacity: usize,
306 ) -> Result<Self, WriterError>
307 where
308 C2: ConnectionTransitions<StatementParent = C>,
309 {
310 let sql = insert_statement_from_schema(schema, table_name);
311 let statement = connection
312 .into_prepared(&sql)
313 .map_err(|source| WriterError::PreparingInsertStatement { source, sql })?;
314 Self::new(row_capacity, schema, statement)
315 }
316}
317
318impl<'o> OdbcWriter<StatementImpl<'o>> {
319 pub fn with_connection(
328 connection: &'o Connection<'o>,
329 schema: &Schema,
330 table_name: &str,
331 row_capacity: usize,
332 ) -> Result<Self, WriterError> {
333 let sql = insert_statement_from_schema(schema, table_name);
334 let statement = connection
335 .prepare(&sql)
336 .map_err(|source| WriterError::PreparingInsertStatement { source, sql })?;
337 Self::new(row_capacity, schema, statement)
338 }
339}
340
/// Translates the values of one arrow array column into an ODBC parameter
/// buffer column.
pub trait WriteStrategy {
    /// Description of the ODBC buffer column this strategy writes into; used
    /// to allocate and bind the transfer buffer.
    fn buffer_desc(&self) -> BufferDesc;

    /// Copies all rows of `array` into `column_buf`, starting at row
    /// `param_offset` within the buffer (rows before that offset already hold
    /// values from previously written chunks).
    fn write_rows(
        &self,
        param_offset: usize,
        column_buf: AnySliceMut<'_>,
        array: &dyn Array,
    ) -> Result<(), WriterError>;
}
358
359fn field_to_write_strategy(field: &Field) -> Result<Box<dyn WriteStrategy>, WriterError> {
360 let is_nullable = field.is_nullable();
361 let strategy = match field.data_type() {
362 DataType::Utf8 => Box::new(Utf8ToNativeText {}),
363 DataType::Boolean => boolean_to_bit(is_nullable),
364 DataType::LargeUtf8 => Box::new(LargeUtf8ToNativeText {}),
365 DataType::Int8 => Int8Type::identical(is_nullable),
366 DataType::Int16 => Int16Type::identical(is_nullable),
367 DataType::Int32 => Int32Type::identical(is_nullable),
368 DataType::Int64 => Int64Type::identical(is_nullable),
369 DataType::UInt8 => UInt8Type::identical(is_nullable),
370 DataType::Float16 => Float16Type::map_with(is_nullable, |half| half.to_f32()),
371 DataType::Float32 => Float32Type::identical(is_nullable),
372 DataType::Float64 => Float64Type::identical(is_nullable),
373 DataType::Timestamp(time_unit, time_zone) => {
374 insert_timestamp_strategy(is_nullable, &time_unit, time_zone.clone())?
375 }
376 DataType::Date32 => Date32Type::map_with(is_nullable, epoch_to_date),
377 DataType::Date64 => Date64Type::map_with(is_nullable, |days_since_epoch| {
378 epoch_to_date(days_since_epoch.try_into().unwrap())
379 }),
380 DataType::Time32(TimeUnit::Second) => {
381 Time32SecondType::map_with(is_nullable, sec_since_midnight_to_time)
382 }
383 DataType::Time32(TimeUnit::Millisecond) => {
384 Box::new(NullableTimeAsText::<Time32MillisecondType>::new())
385 }
386 DataType::Time64(TimeUnit::Microsecond) => {
387 Box::new(NullableTimeAsText::<Time64MicrosecondType>::new())
388 }
389 DataType::Time64(TimeUnit::Nanosecond) => {
390 Box::new(NullableTimeAsText::<Time64NanosecondType>::new())
391 }
392 DataType::Binary => Box::new(VariadicBinary::new(1)),
393 DataType::FixedSizeBinary(length) => {
394 Box::new(VariadicBinary::new((*length).try_into().unwrap()))
395 }
396 DataType::Decimal128(precision, scale) => {
397 Box::new(NullableDecimal128AsText::new(*precision, *scale))
398 }
399 DataType::Decimal256(precision, scale) => {
400 Box::new(NullableDecimal256AsText::new(*precision, *scale))
401 }
402 unsupported => return Err(WriterError::UnsupportedArrowDataType(unsupported.clone())),
403 };
404 Ok(strategy)
405}