arrow_odbc/
odbc_writer.rs

use std::{borrow::Cow, cmp::min};

use thiserror::Error;

use arrow::{
    array::Array,
    datatypes::{
        DataType, Date32Type, Date64Type, Field, Float16Type, Float32Type, Float64Type, Int8Type,
        Int16Type, Int32Type, Int64Type, Schema, Time32MillisecondType, Time32SecondType,
        Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt8Type,
    },
    error::ArrowError,
    record_batch::{RecordBatch, RecordBatchReader},
};
use odbc_api::{
    ColumnarBulkInserter, Connection, Prepared, StatementConnection,
    buffers::{AnyBuffer, AnySliceMut, BufferDesc},
    handles::{AsStatementRef, StatementImpl},
};

use crate::{
    date_time::{
        NullableTimeAsText, epoch_to_date, epoch_to_timestamp_ms, epoch_to_timestamp_ns,
        epoch_to_timestamp_s, epoch_to_timestamp_us, sec_since_midnight_to_time,
    },
    decimal::{NullableDecimal128AsText, NullableDecimal256AsText},
};

use self::{
    binary::VariadicBinary,
    boolean::boolean_to_bit,
    map_arrow_to_odbc::MapArrowToOdbc,
    text::{LargeUtf8ToNativeText, Utf8ToNativeText},
};

mod binary;
mod boolean;
mod map_arrow_to_odbc;
mod text;

/// Fastest and most convenient way to stream the contents of arrow record batches into a database
/// table. For use cases where you want to insert repeatedly into the same table from different
/// streams it is more efficient to create an instance of [`self::OdbcWriter`] and reuse it.
///
/// **Note:**
///
/// If table or column names are derived from user input, be sure to sanitize the input in order to
/// prevent SQL injection attacks.
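///
/// # Example
///
/// A minimal sketch, assuming `odbc_api` is available through the `arrow_odbc::odbc_api`
/// re-export and a suitable [`RecordBatchReader`] is already at hand. Table name and batch size
/// are illustrative:
///
/// ```no_run
/// use arrow_odbc::{
///     WriterError, insert_into_table,
///     arrow::record_batch::RecordBatchReader,
///     odbc_api::Connection,
/// };
///
/// fn fill_table(
///     connection: &Connection<'_>,
///     reader: &mut impl RecordBatchReader,
/// ) -> Result<(), WriterError> {
///     // Stream all batches into `MyTable`, sending at most 1000 rows per roundtrip.
///     insert_into_table(connection, reader, "MyTable", 1000)
/// }
/// ```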
pub fn insert_into_table(
    connection: &Connection,
    batches: &mut impl RecordBatchReader,
    table_name: &str,
    batch_size: usize,
) -> Result<(), WriterError> {
    let schema = batches.schema();
    let mut inserter =
        OdbcWriter::with_connection(connection, schema.as_ref(), table_name, batch_size)?;
    inserter.write_all(batches)
}

/// Generates an insert statement using the table and column names.
///
/// `INSERT INTO <table> (<column_names 0>, <column_names 1>, ...) VALUES (?, ?, ...)`
fn insert_statement_text(table: &str, column_names: &[&'_ str]) -> String {
    // Generate statement text from table name and column names
    let column_names = column_names
        .iter()
        .map(|cn| quote_column_name(cn))
        .collect::<Vec<_>>();
    let columns = column_names.join(", ");
    let values = column_names
        .iter()
        .map(|_| "?")
        .collect::<Vec<_>>()
        .join(", ");
    // Do not finish the statement with a semicolon. There is anecdotal evidence of IBM DB2 not
    // allowing the command, because it then expects multiple statements.
    // See: <https://github.com/pacman82/arrow-odbc/issues/63>
    format!("INSERT INTO {table} ({columns}) VALUES ({values})")
}

/// Wraps the column name in quotes, if need be
fn quote_column_name(column_name: &str) -> Cow<'_, str> {
    if column_name.contains(|c| !valid_in_column_name(c)) {
        Cow::Owned(format!("\"{column_name}\""))
    } else {
        Cow::Borrowed(column_name)
    }
}

/// Check if this character is allowed in an unquoted column name
fn valid_in_column_name(c: char) -> bool {
    // See:
    // <https://stackoverflow.com/questions/4200351/what-characters-are-valid-in-an-sql-server-database-name>
    c.is_alphanumeric() || c == '@' || c == '$' || c == '#' || c == '_'
}

/// Creates an SQL insert statement from an arrow schema. The resulting statement will have one
/// placeholder (`?`) for each column in the statement.
///
/// **Note:**
///
/// If a column name contains any character which would make it an invalid qualifier for
/// Transact-SQL it is wrapped in double quotes (`"`) within the insert statement. Valid names
/// consist of alphanumeric characters, `@`, `$`, `#` and `_`.
///
/// # Example
///
/// ```
/// use arrow_odbc::{
///     insert_statement_from_schema,
///     arrow::datatypes::{Field, DataType, Schema},
/// };
///
/// let field_a = Field::new("a", DataType::Int64, false);
/// let field_b = Field::new("b", DataType::Boolean, false);
///
/// let schema = Schema::new(vec![field_a, field_b]);
/// let sql = insert_statement_from_schema(&schema, "MyTable");
///
/// assert_eq!("INSERT INTO MyTable (a, b) VALUES (?, ?)", sql)
/// ```
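///
/// Column names containing characters that are not valid in an unquoted identifier are wrapped
/// in double quotes:
///
/// ```
/// use arrow_odbc::{
///     insert_statement_from_schema,
///     arrow::datatypes::{Field, DataType, Schema},
/// };
///
/// let field = Field::new("int column", DataType::Int64, false);
///
/// let schema = Schema::new(vec![field]);
/// let sql = insert_statement_from_schema(&schema, "MyTable");
///
/// assert_eq!("INSERT INTO MyTable (\"int column\") VALUES (?)", sql)
/// ```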
///
/// This function is automatically invoked by [`crate::OdbcWriter::with_connection`].
pub fn insert_statement_from_schema(schema: &Schema, table_name: &str) -> String {
    let fields = schema.fields();
    let num_columns = fields.len();
    let column_names: Vec<_> = (0..num_columns)
        .map(|i| fields[i].name().as_str())
        .collect();
    insert_statement_text(table_name, &column_names)
}

/// Emitted while writing values from arrow arrays into a table on the database
#[derive(Debug, Error)]
pub enum WriterError {
    #[error("Failure to bind the array parameter buffers to the statement.\n{0}")]
    BindParameterBuffers(#[source] odbc_api::Error),
    #[error("Failure to execute the sql statement, sending the data to the database.\n{0}")]
    ExecuteStatement(#[source] odbc_api::Error),
    #[error("An error occurred rebinding a parameter buffer to the sql statement.\n{0}")]
    RebindBuffer(#[source] odbc_api::Error),
    #[error("The arrow data type {0} is not supported for insertion.")]
    UnsupportedArrowDataType(DataType),
    #[error("An error occurred extracting a record batch from a record batch reader.\n{0}")]
    ReadingRecordBatch(#[source] ArrowError),
    #[error("An error occurred preparing SQL statement. SQL:\n{sql}\n{source}")]
    PreparingInsertStatement {
        #[source]
        source: odbc_api::Error,
        sql: String,
    },
}

/// Inserts batches from an [`arrow::record_batch::RecordBatchReader`] into a database.
pub struct OdbcWriter<S> {
    /// Prepared statement with bound array parameter buffers. Data is copied into these buffers
    /// until they are full. Then we execute the statement. This is repeated until we run out of
    /// data.
    inserter: ColumnarBulkInserter<S, AnyBuffer>,
    /// For each field in the arrow schema we decide on which buffer to use to send the parameters
    /// to the database, and need to remember how to copy the data from an arrow array to an odbc
    /// mutable buffer slice for any column.
    strategies: Vec<Box<dyn WriteStrategy>>,
}

impl<S> OdbcWriter<S>
where
    S: AsStatementRef,
{
    /// Construct a new ODBC writer using an already existing prepared statement. Usually you want
    /// to call a higher level constructor like [`Self::with_connection`]. Yet, this constructor is
    /// useful in two scenarios.
    ///
    /// 1. The prepared statement is already constructed and you do not want to spend the time to
    ///    prepare it again.
    /// 2. You want to use the arrow arrays as array parameters for a statement, but that statement
    ///    is not necessarily an INSERT statement with a simple one-to-one mapping of columns
    ///    between table and arrow schema.
    ///
    /// # Parameters
    ///
    /// * `row_capacity`: The number of rows sent to the database in each chunk. With the exception
    ///   of the last chunk, which may be smaller.
    /// * `schema`: Schema needs to have one column for each positional parameter of the statement
    ///   and match the data which will be supplied to the instance later. Otherwise your code will
    ///   panic.
    /// * `statement`: A prepared statement whose SQL text representation contains one placeholder
    ///   for each column. The order of the placeholders must correspond to the order of the
    ///   columns in the `schema`.
    pub fn new(
        row_capacity: usize,
        schema: &Schema,
        statement: Prepared<S>,
    ) -> Result<Self, WriterError> {
        let strategies: Vec<_> = schema
            .fields()
            .iter()
            .map(|field| field_to_write_strategy(field.as_ref()))
            .collect::<Result<_, _>>()?;
        let descriptions = strategies.iter().map(|cws| cws.buffer_desc());
        let inserter = statement
            .into_column_inserter(row_capacity, descriptions)
            .map_err(WriterError::BindParameterBuffers)?;

        Ok(Self {
            inserter,
            strategies,
        })
    }

    /// Consumes all the batches in the record batch reader and sends them chunk by chunk to the
    /// database.
    pub fn write_all(
        &mut self,
        reader: impl Iterator<Item = Result<RecordBatch, ArrowError>>,
    ) -> Result<(), WriterError> {
        for result in reader {
            let record_batch = result.map_err(WriterError::ReadingRecordBatch)?;
            self.write_batch(&record_batch)?;
        }
        self.flush()?;
        Ok(())
    }

    /// Consumes a single batch and sends it chunk by chunk to the database. The final chunk may
    /// not be sent to the database until [`Self::flush`] is called.
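    ///
    /// # Example
    ///
    /// A sketch of driving the writer by hand. The explicit [`Self::flush`] at the end sends the
    /// rows still waiting in the parameter buffers:
    ///
    /// ```no_run
    /// use arrow_odbc::{
    ///     OdbcWriter, WriterError,
    ///     arrow::record_batch::RecordBatch,
    ///     odbc_api::handles::AsStatementRef,
    /// };
    ///
    /// fn write_manually<S: AsStatementRef>(
    ///     writer: &mut OdbcWriter<S>,
    ///     batches: &[RecordBatch],
    /// ) -> Result<(), WriterError> {
    ///     for batch in batches {
    ///         writer.write_batch(batch)?;
    ///     }
    ///     writer.flush()
    /// }
    /// ```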
    pub fn write_batch(&mut self, record_batch: &RecordBatch) -> Result<(), WriterError> {
        let capacity = self.inserter.capacity();
        let mut remaining_rows = record_batch.num_rows();
        // The record batch may contain more rows than the capacity of our writer can hold. So we
        // need to be able to fill the buffers multiple times and send them to the database in
        // between.
        while remaining_rows != 0 {
            let chunk_size = min(capacity - self.inserter.num_rows(), remaining_rows);
            let param_offset = self.inserter.num_rows();
            self.inserter.set_num_rows(param_offset + chunk_size);
            let chunk = record_batch.slice(record_batch.num_rows() - remaining_rows, chunk_size);
            for (index, (array, strategy)) in chunk
                .columns()
                .iter()
                .zip(self.strategies.iter())
                .enumerate()
            {
                strategy.write_rows(param_offset, self.inserter.column_mut(index), array)?
            }

            // If we used up all capacity we send the parameters to the database and reset the
            // parameter buffers.
            if self.inserter.num_rows() == capacity {
                self.flush()?;
            }
            remaining_rows -= chunk_size;
        }

        Ok(())
    }

    /// The number of rows in an individual record batch does not necessarily match the capacity
    /// of the buffers owned by this writer. Therefore sometimes records are not sent to the
    /// database immediately, but rather we wait for the buffers to be filled by reading further
    /// batches. Once we reach the last batch however, there is no "next batch" anymore. In that
    /// case we call this method in order to send the remainder of the records to the database as
    /// well.
    pub fn flush(&mut self) -> Result<(), WriterError> {
        self.inserter
            .execute()
            .map_err(WriterError::ExecuteStatement)?;
        self.inserter.clear();
        Ok(())
    }
}

impl<'env> OdbcWriter<StatementConnection<'env>> {
    /// A writer which takes ownership of the connection and inserts the given schema into a table
    /// with matching column names.
    ///
    /// **Note:**
    ///
    /// If a column name contains any character which would make it an invalid qualifier for
    /// Transact-SQL it is wrapped in double quotes (`"`) within the insert statement. Valid names
    /// consist of alphanumeric characters, `@`, `$`, `#` and `_`.
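    ///
    /// # Example
    ///
    /// A sketch assuming a data source named `MyDb`; connection string, table name and row
    /// capacity are illustrative:
    ///
    /// ```no_run
    /// use arrow_odbc::{
    ///     OdbcWriter,
    ///     arrow::datatypes::Schema,
    ///     odbc_api::{ConnectionOptions, Environment, StatementConnection},
    /// };
    ///
    /// fn owned_writer<'env>(
    ///     env: &'env Environment,
    ///     schema: &Schema,
    /// ) -> Result<OdbcWriter<StatementConnection<'env>>, Box<dyn std::error::Error>> {
    ///     let connection =
    ///         env.connect_with_connection_string("DSN=MyDb;", ConnectionOptions::default())?;
    ///     // The writer takes ownership of the connection and may outlive this function.
    ///     Ok(OdbcWriter::from_connection(connection, schema, "MyTable", 1000)?)
    /// }
    /// ```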
    pub fn from_connection(
        connection: Connection<'env>,
        schema: &Schema,
        table_name: &str,
        row_capacity: usize,
    ) -> Result<Self, WriterError> {
        let sql = insert_statement_from_schema(schema, table_name);
        let statement = connection
            .into_prepared(&sql)
            .map_err(|source| WriterError::PreparingInsertStatement { source, sql })?;
        Self::new(row_capacity, schema, statement)
    }
}

impl<'o> OdbcWriter<StatementImpl<'o>> {
    /// A writer which borrows the connection and inserts the given schema into a table with
    /// matching column names.
    ///
    /// **Note:**
    ///
    /// If a column name contains any character which would make it an invalid qualifier for
    /// Transact-SQL it is wrapped in double quotes (`"`) within the insert statement. Valid names
    /// consist of alphanumeric characters, `@`, `$`, `#` and `_`.
    pub fn with_connection(
        connection: &'o Connection<'o>,
        schema: &Schema,
        table_name: &str,
        row_capacity: usize,
    ) -> Result<Self, WriterError> {
        let sql = insert_statement_from_schema(schema, table_name);
        let statement = connection
            .prepare(&sql)
            .map_err(|source| WriterError::PreparingInsertStatement { source, sql })?;
        Self::new(row_capacity, schema, statement)
    }
}

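/// Defines how the values of an arrow array are copied into the bound ODBC parameter buffer of
/// one column.
///
/// A sketch of a custom strategy for a non-nullable `Int32` column. The strategies shipped with
/// this crate are generated generically; the `BufferDesc::I32` variant and the `as_slice`
/// accessor used here are assumptions about the `odbc_api` version in use:
///
/// ```ignore
/// use arrow::array::{Array, Int32Array};
/// use odbc_api::buffers::{AnySliceMut, BufferDesc};
///
/// struct NonNullableInt32;
///
/// impl WriteStrategy for NonNullableInt32 {
///     fn buffer_desc(&self) -> BufferDesc {
///         BufferDesc::I32 { nullable: false }
///     }
///
///     fn write_rows(
///         &self,
///         param_offset: usize,
///         column_buf: AnySliceMut<'_>,
///         array: &dyn Array,
///     ) -> Result<(), WriterError> {
///         let from = array.as_any().downcast_ref::<Int32Array>().unwrap();
///         let to = column_buf.as_slice::<i32>().unwrap();
///         // Append after the rows already present in the buffer.
///         to[param_offset..param_offset + from.len()].copy_from_slice(from.values());
///         Ok(())
///     }
/// }
/// ```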
pub trait WriteStrategy {
    /// Describe the buffer used to hold the array parameters for the column
    fn buffer_desc(&self) -> BufferDesc;

    /// # Parameters
    ///
    /// * `param_offset`: Start writing parameters at that position. Number of rows in the parameter
    ///   buffer before inserting the current chunk.
    /// * `column_buf`: Buffer to write the data into
    /// * `array`: Buffer to read the data from
    fn write_rows(
        &self,
        param_offset: usize,
        column_buf: AnySliceMut<'_>,
        array: &dyn Array,
    ) -> Result<(), WriterError>;
}

fn field_to_write_strategy(field: &Field) -> Result<Box<dyn WriteStrategy>, WriterError> {
    let is_nullable = field.is_nullable();
    let strategy = match field.data_type() {
        DataType::Utf8 => Box::new(Utf8ToNativeText {}),
        DataType::Boolean => boolean_to_bit(is_nullable),
        DataType::LargeUtf8 => Box::new(LargeUtf8ToNativeText {}),
        DataType::Int8 => Int8Type::identical(is_nullable),
        DataType::Int16 => Int16Type::identical(is_nullable),
        DataType::Int32 => Int32Type::identical(is_nullable),
        DataType::Int64 => Int64Type::identical(is_nullable),
        DataType::UInt8 => UInt8Type::identical(is_nullable),
        DataType::Float16 => Float16Type::map_with(is_nullable, |half| half.to_f32()),
        DataType::Float32 => Float32Type::identical(is_nullable),
        DataType::Float64 => Float64Type::identical(is_nullable),
        DataType::Timestamp(TimeUnit::Second, None) => {
            TimestampSecondType::map_with(is_nullable, epoch_to_timestamp_s)
        }
        DataType::Timestamp(TimeUnit::Millisecond, None) => {
            TimestampMillisecondType::map_with(is_nullable, epoch_to_timestamp_ms)
        }
        DataType::Timestamp(TimeUnit::Microsecond, None) => {
            TimestampMicrosecondType::map_with(is_nullable, epoch_to_timestamp_us)
        }
        DataType::Timestamp(TimeUnit::Nanosecond, None) => {
            TimestampNanosecondType::map_with(is_nullable, |ns| {
                // Drop the last two digits of precision, since we bind it with precision 7 and
                // not 9.
                epoch_to_timestamp_ns((ns / 100) * 100)
            })
        }
        DataType::Date32 => Date32Type::map_with(is_nullable, epoch_to_date),
        DataType::Date64 => Date64Type::map_with(is_nullable, |ms_since_epoch| {
            // Date64 encodes milliseconds since the epoch, while `epoch_to_date` expects days
            // (the unit of Date32). Convert before the narrowing cast, which would otherwise
            // overflow `i32` for almost any realistic date.
            const MS_PER_DAY: i64 = 24 * 60 * 60 * 1000;
            epoch_to_date((ms_since_epoch / MS_PER_DAY).try_into().unwrap())
        }),
        DataType::Time32(TimeUnit::Second) => {
            Time32SecondType::map_with(is_nullable, sec_since_midnight_to_time)
        }
        DataType::Time32(TimeUnit::Millisecond) => {
            Box::new(NullableTimeAsText::<Time32MillisecondType>::new())
        }
        DataType::Time64(TimeUnit::Microsecond) => {
            Box::new(NullableTimeAsText::<Time64MicrosecondType>::new())
        }
        DataType::Time64(TimeUnit::Nanosecond) => {
            Box::new(NullableTimeAsText::<Time64NanosecondType>::new())
        }
        DataType::Binary => Box::new(VariadicBinary::new(1)),
        DataType::FixedSizeBinary(length) => {
            Box::new(VariadicBinary::new((*length).try_into().unwrap()))
        }
        DataType::Decimal128(precision, scale) => {
            Box::new(NullableDecimal128AsText::new(*precision, *scale))
        }
        DataType::Decimal256(precision, scale) => {
            Box::new(NullableDecimal256AsText::new(*precision, *scale))
        }
        unsupported => return Err(WriterError::UnsupportedArrowDataType(unsupported.clone())),
    };
    Ok(strategy)
}