use std::{borrow::Cow, cmp::min, sync::Arc};

use thiserror::Error;

use arrow::{
    array::Array,
    datatypes::{
        DataType, Date32Type, Date64Type, Field, Float16Type, Float32Type, Float64Type, Int8Type,
        Int16Type, Int32Type, Int64Type, Schema, Time32MillisecondType, Time32SecondType,
        Time64MicrosecondType, Time64NanosecondType, TimeUnit, UInt8Type,
    },
    error::ArrowError,
    record_batch::{RecordBatch, RecordBatchReader},
};
use odbc_api::{
    ColumnarBulkInserter, Connection, Prepared,
    buffers::{AnyBuffer, AnySliceMut, BufferDesc},
    handles::{AsStatementRef, StatementConnection, StatementImpl},
};

use crate::{
    date_time::{NullableTimeAsText, epoch_to_date, sec_since_midnight_to_time},
    decimal::{NullableDecimal128AsText, NullableDecimal256AsText},
    odbc_writer::timestamp::insert_timestamp_strategy,
};

use self::{
    binary::VariadicBinary,
    boolean::boolean_to_bit,
    map_arrow_to_odbc::MapArrowToOdbc,
    text::{LargeUtf8ToNativeText, Utf8ToNativeText},
};

mod binary;
mod boolean;
mod map_arrow_to_odbc;
mod text;
mod timestamp;

pub fn insert_into_table(
    connection: &Connection,
    batches: &mut impl RecordBatchReader,
    table_name: &str,
    batch_size: usize,
) -> Result<(), WriterError> {
    let schema = batches.schema();
    let mut inserter =
        OdbcWriter::with_connection(connection, schema.as_ref(), table_name, batch_size)?;
    inserter.write_all(batches)
}

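/// Generates the text of the parameterized insert statement, e.g. for table `Birthdays` with
/// columns `name` and `year`: `INSERT INTO Birthdays (name, year) VALUES (?, ?)`.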
fn insert_statement_text(table: &str, column_names: &[&'_ str]) -> String {
    let column_names = column_names
        .iter()
        .map(|cn| quote_column_name(cn))
        .collect::<Vec<_>>();
    let columns = column_names.join(", ");
    let values = column_names
        .iter()
        .map(|_| "?")
        .collect::<Vec<_>>()
        .join(", ");
    format!("INSERT INTO {table} ({columns}) VALUES ({values})")
}

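/// Wraps the column name in double quotes if it contains characters which are not valid in an
/// unquoted column name (e.g. spaces); otherwise the name is passed through unchanged.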
fn quote_column_name(column_name: &str) -> Cow<'_, str> {
    if column_name.contains(|c| !valid_in_column_name(c)) {
        Cow::Owned(format!("\"{column_name}\""))
    } else {
        Cow::Borrowed(column_name)
    }
}

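/// Characters which may appear in a column name without forcing the name to be quoted.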
fn valid_in_column_name(c: char) -> bool {
    c.is_alphanumeric() || c == '@' || c == '$' || c == '#' || c == '_'
}

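/// Generates an insert statement for `table_name` with one parameter placeholder for every field
/// in `schema`.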
pub fn insert_statement_from_schema(schema: &Schema, table_name: &str) -> String {
    let column_names: Vec<_> = schema
        .fields()
        .iter()
        .map(|field| field.name().as_str())
        .collect();
    insert_statement_text(table_name, &column_names)
}

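/// Error emitted while writing Arrow record batches to the database.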
#[derive(Debug, Error)]
pub enum WriterError {
    #[error("Failure to bind the array parameter buffers to the statement.\n{0}")]
    BindParameterBuffers(#[source] odbc_api::Error),
    #[error("Failure to execute the SQL statement, sending the data to the database.\n{0}")]
    ExecuteStatement(#[source] odbc_api::Error),
    #[error("An error occurred rebinding a parameter buffer to the SQL statement.\n{0}")]
    RebindBuffer(#[source] odbc_api::Error),
    #[error("The arrow data type {0} is not supported for insertion.")]
    UnsupportedArrowDataType(DataType),
    #[error("An error occurred extracting a record batch from the record batch reader.\n{0}")]
    ReadingRecordBatch(#[source] ArrowError),
    #[error("Unable to parse '{time_zone}' into a valid IANA time zone.")]
    InvalidTimeZone { time_zone: Arc<str> },
    #[error("An error occurred preparing the SQL statement. SQL:\n{sql}\n{source}")]
    PreparingInsertStatement {
        #[source]
        source: odbc_api::Error,
        sql: String,
    },
}

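/// Writes Arrow [`RecordBatch`]es into a database table using a prepared ODBC insert statement
/// with bound parameter buffers. One [`WriteStrategy`] per column copies the Arrow values into
/// those buffers.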
pub struct OdbcWriter<S> {
    inserter: ColumnarBulkInserter<S, AnyBuffer>,
    strategies: Vec<Box<dyn WriteStrategy>>,
}

impl<S> OdbcWriter<S>
where
    S: AsStatementRef,
{
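    /// Constructs a writer from a prepared insert statement. One [`WriteStrategy`] is chosen per
    /// field of `schema`, and parameter buffers with room for `row_capacity` rows are bound to
    /// the statement.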
    pub fn new(
        row_capacity: usize,
        schema: &Schema,
        statement: Prepared<S>,
    ) -> Result<Self, WriterError> {
        let strategies: Vec<_> = schema
            .fields()
            .iter()
            .map(|field| field_to_write_strategy(field.as_ref()))
            .collect::<Result<_, _>>()?;
        let descriptions = strategies.iter().map(|cws| cws.buffer_desc());
        let inserter = statement
            .into_column_inserter(row_capacity, descriptions)
            .map_err(WriterError::BindParameterBuffers)?;

        Ok(Self {
            inserter,
            strategies,
        })
    }

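    /// Writes every record batch yielded by `reader` to the database and flushes any rows still
    /// buffered afterwards.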
    pub fn write_all(
        &mut self,
        reader: impl Iterator<Item = Result<RecordBatch, ArrowError>>,
    ) -> Result<(), WriterError> {
        for result in reader {
            let record_batch = result.map_err(WriterError::ReadingRecordBatch)?;
            self.write_batch(&record_batch)?;
        }
        self.flush()?;
        Ok(())
    }

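    /// Copies the rows of `record_batch` into the bound parameter buffers, splitting the batch
    /// into chunks no larger than the remaining buffer capacity. Whenever the buffers are full
    /// their contents are sent to the database; rows of a partially filled buffer stay buffered
    /// until the next write or an explicit [`Self::flush`].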
    pub fn write_batch(&mut self, record_batch: &RecordBatch) -> Result<(), WriterError> {
        let capacity = self.inserter.capacity();
        let mut remaining_rows = record_batch.num_rows();
        while remaining_rows != 0 {
            let chunk_size = min(capacity - self.inserter.num_rows(), remaining_rows);
            let param_offset = self.inserter.num_rows();
            self.inserter.set_num_rows(param_offset + chunk_size);
            let chunk = record_batch.slice(record_batch.num_rows() - remaining_rows, chunk_size);
            for (index, (array, strategy)) in chunk
                .columns()
                .iter()
                .zip(self.strategies.iter())
                .enumerate()
            {
                strategy.write_rows(param_offset, self.inserter.column_mut(index), array)?
            }

            if self.inserter.num_rows() == capacity {
                self.flush()?;
            }
            remaining_rows -= chunk_size;
        }

        Ok(())
    }

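    /// Executes the prepared statement with all currently buffered rows and clears the buffers.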
    pub fn flush(&mut self) -> Result<(), WriterError> {
        self.inserter
            .execute()
            .map_err(WriterError::ExecuteStatement)?;
        self.inserter.clear();
        Ok(())
    }
}

impl<'env> OdbcWriter<StatementConnection<Connection<'env>>> {
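    /// Like [`OdbcWriter::with_connection`], but takes ownership of the connection. The prepared
    /// insert statement owns the connection for as long as the writer lives.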
    pub fn from_connection(
        connection: Connection<'env>,
        schema: &Schema,
        table_name: &str,
        row_capacity: usize,
    ) -> Result<Self, WriterError> {
        let sql = insert_statement_from_schema(schema, table_name);
        let statement = connection
            .into_prepared(&sql)
            .map_err(|source| WriterError::PreparingInsertStatement { source, sql })?;
        Self::new(row_capacity, schema, statement)
    }
}

impl<'o> OdbcWriter<StatementImpl<'o>> {
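    /// Creates a writer which borrows `connection` and inserts into `table_name`. The insert
    /// statement is generated from `schema` and prepared on the connection.
    ///
    /// A sketch of manual use, assuming `connection`, `schema` (an `Arc<Schema>`) and a
    /// collection of record batches `batches` are already set up; [`insert_into_table`] wraps
    /// this pattern:
    ///
    /// ```ignore
    /// let mut writer = OdbcWriter::with_connection(&connection, schema.as_ref(), "MyTable", 1000)?;
    /// for batch in batches {
    ///     writer.write_batch(&batch)?;
    /// }
    /// writer.flush()?;
    /// ```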
    pub fn with_connection(
        connection: &'o Connection<'o>,
        schema: &Schema,
        table_name: &str,
        row_capacity: usize,
    ) -> Result<Self, WriterError> {
        let sql = insert_statement_from_schema(schema, table_name);
        let statement = connection
            .prepare(&sql)
            .map_err(|source| WriterError::PreparingInsertStatement { source, sql })?;
        Self::new(row_capacity, schema, statement)
    }
}

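/// Transfers the values of an Arrow array into an ODBC parameter buffer for one column of the
/// insert statement. Each column of the inserted schema is handled by its own strategy, chosen
/// from the field's Arrow data type.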
pub trait WriteStrategy {
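    /// Description of the ODBC buffer type used to transfer this column's values to the database.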
    fn buffer_desc(&self) -> BufferDesc;

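    /// Copies the values of `array` into `column_buf`, starting at position `param_offset`
    /// within the buffer.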
    fn write_rows(
        &self,
        param_offset: usize,
        column_buf: AnySliceMut<'_>,
        array: &dyn Array,
    ) -> Result<(), WriterError>;
}

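/// Chooses the [`WriteStrategy`] used to transfer values of `field` to the database, based on the
/// field's data type and nullability.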
fn field_to_write_strategy(field: &Field) -> Result<Box<dyn WriteStrategy>, WriterError> {
    let is_nullable = field.is_nullable();
    let strategy = match field.data_type() {
        DataType::Utf8 => Box::new(Utf8ToNativeText {}),
        DataType::Boolean => boolean_to_bit(is_nullable),
        DataType::LargeUtf8 => Box::new(LargeUtf8ToNativeText {}),
        DataType::Int8 => Int8Type::identical(is_nullable),
        DataType::Int16 => Int16Type::identical(is_nullable),
        DataType::Int32 => Int32Type::identical(is_nullable),
        DataType::Int64 => Int64Type::identical(is_nullable),
        DataType::UInt8 => UInt8Type::identical(is_nullable),
        DataType::Float16 => Float16Type::map_with(is_nullable, |half| half.to_f32()),
        DataType::Float32 => Float32Type::identical(is_nullable),
        DataType::Float64 => Float64Type::identical(is_nullable),
        DataType::Timestamp(time_unit, time_zone) => {
            insert_timestamp_strategy(is_nullable, time_unit, time_zone.clone())?
        }
        DataType::Date32 => Date32Type::map_with(is_nullable, epoch_to_date),
        DataType::Date64 => Date64Type::map_with(is_nullable, |ms_since_epoch| {
            // Date64 stores milliseconds since the UNIX epoch; convert to whole days before
            // mapping to an ODBC date.
            epoch_to_date((ms_since_epoch / 86_400_000).try_into().unwrap())
        }),
        DataType::Time32(TimeUnit::Second) => {
            Time32SecondType::map_with(is_nullable, sec_since_midnight_to_time)
        }
        DataType::Time32(TimeUnit::Millisecond) => {
            Box::new(NullableTimeAsText::<Time32MillisecondType>::new())
        }
        DataType::Time64(TimeUnit::Microsecond) => {
            Box::new(NullableTimeAsText::<Time64MicrosecondType>::new())
        }
        DataType::Time64(TimeUnit::Nanosecond) => {
            Box::new(NullableTimeAsText::<Time64NanosecondType>::new())
        }
        DataType::Binary => Box::new(VariadicBinary::new(1)),
        DataType::FixedSizeBinary(length) => {
            Box::new(VariadicBinary::new((*length).try_into().unwrap()))
        }
        DataType::Decimal128(precision, scale) => {
            Box::new(NullableDecimal128AsText::new(*precision, *scale))
        }
        DataType::Decimal256(precision, scale) => {
            Box::new(NullableDecimal256AsText::new(*precision, *scale))
        }
        unsupported => return Err(WriterError::UnsupportedArrowDataType(unsupported.clone())),
    };
    Ok(strategy)
}