//! `odbc_writer` — stream the contents of Arrow record batches into a database table via ODBC.
1use std::{borrow::Cow, cmp::min, sync::Arc};
2
3use thiserror::Error;
4
5use arrow::{
6    array::Array,
7    datatypes::{
8        DataType, Date32Type, Date64Type, Field, Float16Type, Float32Type, Float64Type, Int8Type,
9        Int16Type, Int32Type, Int64Type, Schema, Time32MillisecondType, Time32SecondType,
10        Time64MicrosecondType, Time64NanosecondType, TimeUnit, UInt8Type,
11    },
12    error::ArrowError,
13    record_batch::{RecordBatch, RecordBatchReader},
14};
15use odbc_api::{
16    BindParamDesc, ColumnarBulkInserter, Connection, ConnectionTransitions, Prepared,
17    buffers::{AnyBuffer, AnySliceMut},
18    handles::{AsStatementRef, StatementConnection, StatementImpl, StatementParent},
19    parameter::WithDataType,
20};
21
22use crate::{
23    date_time::{NullableTimeAsText, epoch_to_date, sec_since_midnight_to_time},
24    decimal::{NullableDecimal128AsText, NullableDecimal256AsText},
25    odbc_writer::timestamp::insert_timestamp_strategy,
26};
27
28use self::{
29    binary::VariadicBinary,
30    boolean::boolean_to_bit,
31    map_arrow_to_odbc::MapArrowToOdbc,
32    text::{LargeUtf8ToNativeText, Utf8ToNativeText},
33};
34
35mod binary;
36mod boolean;
37mod map_arrow_to_odbc;
38mod text;
39mod timestamp;
40
41/// Fastest and most convinient way to stream the contents of arrow record batches into a database
42/// table. For usecase there you want to insert repeatedly into the same table from different
43/// streams it is more efficient to create an instance of [`self::OdbcWriter`] and reuse it.
44///
45/// **Note:**
46///
47/// If table or column names are derived from user input, be sure to sanatize the input in order to
48/// prevent SQL injection attacks.
49pub fn insert_into_table(
50    connection: &Connection,
51    batches: &mut impl RecordBatchReader,
52    table_name: &str,
53    batch_size: usize,
54) -> Result<(), WriterError> {
55    let schema = batches.schema();
56    let mut inserter =
57        OdbcWriter::with_connection(connection, schema.as_ref(), table_name, batch_size)?;
58    inserter.write_all(batches)
59}
60
61/// Generates an insert statement using the table and column names.
62///
63/// `INSERT INTO <table> (<column_names 0>, <column_names 1>, ...) VALUES (?, ?, ...)`
64fn insert_statement_text(table: &str, column_names: &[&'_ str]) -> String {
65    // Generate statement text from table name and headline
66    let column_names = column_names
67        .iter()
68        .map(|cn| quote_column_name(cn))
69        .collect::<Vec<_>>();
70    let columns = column_names.join(", ");
71    let values = column_names
72        .iter()
73        .map(|_| "?")
74        .collect::<Vec<_>>()
75        .join(", ");
76    // Do not finish the statement with a semicolon. There is anecodtical evidence of IBM db2 not
77    // allowing the command, because it expects now multiple statements.
78    // See: <https://github.com/pacman82/arrow-odbc/issues/63>
79    format!("INSERT INTO {table} ({columns}) VALUES ({values})")
80}
81
82/// Wraps column name in quotes, if need be.
83fn quote_column_name(column_name: &str) -> Cow<'_, str> {
84    // We do not want to apply quoting in case the string is already quoted. See:
85    // <https://github.com/pacman82/arrow-odbc-py/issues/162>
86    //
87    // Another approach would have been to apply quoting after detecting keywords. Yet the list of
88    // reserved keywords is large. There is also the issue with different databases having different
89    // quoting rules. So the strategy choosen here is to apply quoting in less situations and not
90    // more, so the user has more control over the final statement. This crate is about arrow and
91    // odbc, less so about SQL dialects and statement construction.
92    let is_already_quoted = || {
93        (column_name.starts_with('"') && column_name.ends_with('"'))
94            || column_name.starts_with('[') && column_name.ends_with(']')
95            || column_name.starts_with('`') && column_name.ends_with('`')
96    };
97    let contains_invalid_characters = || column_name.contains(|c| !valid_in_column_name(c));
98    let needs_quotes = contains_invalid_characters() && !is_already_quoted();
99    if needs_quotes {
100        Cow::Owned(format!("\"{column_name}\""))
101    } else {
102        Cow::Borrowed(column_name)
103    }
104}
105
/// Check if this character is allowed in an unquoted column name
fn valid_in_column_name(c: char) -> bool {
    // See:
    // <https://stackoverflow.com/questions/4200351/what-characters-are-valid-in-an-sql-server-database-name>
    matches!(c, '@' | '$' | '#' | '_') || c.is_alphanumeric()
}
112
113/// Creates an SQL insert statement from an arrow schema. The resulting statement will have one
114/// placeholer (`?`) for each column in the statement.
115///
116/// **Note:**
117///
118/// If the column name contains any character which would make it not a valid qualifier for transact
119/// SQL it will be wrapped in double quotes (`"`) within the insert schema. Valid names consist of
120/// alpha numeric characters, `@`, `$`, `#` and `_`.
121///
122/// # Example
123///
124/// ```
125/// use arrow_odbc::{
126///     insert_statement_from_schema,
127///     arrow::datatypes::{Field, DataType, Schema},
128/// };
129///
130/// let field_a = Field::new("a", DataType::Int64, false);
131/// let field_b = Field::new("b", DataType::Boolean, false);
132///
133/// let schema = Schema::new(vec![field_a, field_b]);
134/// let sql = insert_statement_from_schema(&schema, "MyTable");
135///
136/// assert_eq!("INSERT INTO MyTable (a, b) VALUES (?, ?)", sql)
137/// ```
138///
139/// This function is automatically invoked by [`crate::OdbcWriter::with_connection`].
140pub fn insert_statement_from_schema(schema: &Schema, table_name: &str) -> String {
141    let fields = schema.fields();
142    let num_columns = fields.len();
143    let column_names: Vec<_> = (0..num_columns)
144        .map(|i| fields[i].name().as_str())
145        .collect();
146    insert_statement_text(table_name, &column_names)
147}
148
149/// Emitted writing values from arror arrays into a table on the database
150#[derive(Debug, Error)]
151pub enum WriterError {
152    #[error("Failure to bind the array parameter buffers to the statement.\n{0}")]
153    BindParameterBuffers(#[source] odbc_api::Error),
154    #[error("Failure to execute the sql statement, sending the data to the database.\n{0}")]
155    ExecuteStatment(#[source] odbc_api::Error),
156    #[error("An error occured rebinding a parameter buffer to the sql statement.\n{0}")]
157    RebindBuffer(#[source] odbc_api::Error),
158    #[error("The arrow data type {0} is not supported for insertion.")]
159    UnsupportedArrowDataType(DataType),
160    #[error("An error occured extracting a record batch from an error reader.\n{0}")]
161    ReadingRecordBatch(#[source] ArrowError),
162    #[error("Unable to parse '{time_zone}' into a valid IANA time zone.")]
163    InvalidTimeZone { time_zone: Arc<str> },
164    #[error("An error occurred preparing SQL statement. SQL:\n{sql}\n{source}")]
165    PreparingInsertStatement {
166        #[source]
167        source: odbc_api::Error,
168        sql: String,
169    },
170}
171
/// Inserts batches from an [`arrow::record_batch::RecordBatchReader`] into a database.
///
/// `S` is the statement type the bulk inserter owns (e.g. a connection-borrowing or
/// connection-owning ODBC statement).
pub struct OdbcWriter<S> {
    /// Prepared statement with bound array parameter buffers. Data is copied into these buffers
    /// until they are full. Then we execute the statement. This is repeated until we run out of
    /// data.
    inserter: ColumnarBulkInserter<S, WithDataType<AnyBuffer>>,
    /// For each field in the arrow schema we decide on which buffer to use to send the parameters
    /// to the database, and need to remember how to copy the data from an arrow array to an odbc
    /// mutable buffer slice for any column. One strategy per column, in schema order.
    strategies: Vec<Box<dyn WriteStrategy>>,
}
183
184impl<S> OdbcWriter<S>
185where
186    S: AsStatementRef,
187{
188    /// Construct a new ODBC writer using an alredy existing prepared statement. Usually you want to
189    /// call a higher level constructor like [`Self::with_connection`]. Yet, this constructor is
190    /// useful in two scenarios.
191    ///
192    /// 1. The prepared statement is already constructed and you do not want to spend the time to
193    ///    prepare it again.
194    /// 2. You want to use the arrow arrays as arrar parameters for a statement, but that statement
195    ///    is not necessarily an INSERT statement with a simple 1to1 mapping of columns between
196    ///    table and arrow schema.
197    ///
198    /// # Parameters
199    ///
200    /// * `row_capacity`: The amount of rows send to the database in each chunk. With the exception
201    ///   of the last chunk, which may be smaller.
202    /// * `schema`: Schema needs to have one column for each positional parameter of the statement
203    ///   and match the data which will be supplied to the instance later. Otherwise your code will
204    ///   panic.
205    /// * `statement`: A prepared statement whose SQL text representation contains one placeholder
206    ///   for each column. The order of the placeholers must correspond to the orders of the columns
207    ///   in the `schema`.
208    pub fn new(
209        row_capacity: usize,
210        schema: &Schema,
211        statement: Prepared<S>,
212    ) -> Result<Self, WriterError> {
213        let strategies: Vec<_> = schema
214            .fields()
215            .iter()
216            .map(|field| field_to_write_strategy(field.as_ref()))
217            .collect::<Result<_, _>>()?;
218        let descriptions = strategies.iter().map(|cws| cws.buffer_desc());
219        let inserter = statement
220            .into_column_inserter(row_capacity, descriptions)
221            .map_err(WriterError::BindParameterBuffers)?;
222
223        Ok(Self {
224            inserter,
225            strategies,
226        })
227    }
228
229    /// Consumes all the batches in the record batch reader and sends them chunk by chunk to the
230    /// database.
231    pub fn write_all(
232        &mut self,
233        reader: impl Iterator<Item = Result<RecordBatch, ArrowError>>,
234    ) -> Result<(), WriterError> {
235        for result in reader {
236            let record_batch = result.map_err(WriterError::ReadingRecordBatch)?;
237            self.write_batch(&record_batch)?;
238        }
239        self.flush()?;
240        Ok(())
241    }
242
243    /// Consumes a single batch and sends it chunk by chunk to the database. The last batch may not
244    /// be consumed until [`Self::flush`] is called.
245    pub fn write_batch(&mut self, record_batch: &RecordBatch) -> Result<(), WriterError> {
246        let capacity = self.inserter.capacity();
247        let mut remanining_rows = record_batch.num_rows();
248        // The record batch may contain more rows than the capacity of our writer can hold. So we
249        // need to be able to fill the buffers multiple times and send them to the database in
250        // between.
251        while remanining_rows != 0 {
252            let chunk_size = min(capacity - self.inserter.num_rows(), remanining_rows);
253            let param_offset = self.inserter.num_rows();
254            self.inserter.set_num_rows(param_offset + chunk_size);
255            let chunk = record_batch.slice(record_batch.num_rows() - remanining_rows, chunk_size);
256            for (index, (array, strategy)) in chunk
257                .columns()
258                .iter()
259                .zip(self.strategies.iter())
260                .enumerate()
261            {
262                strategy.write_rows(param_offset, self.inserter.column_mut(index), array)?
263            }
264
265            // If we used up all capacity we send the parameters to the database and reset the
266            // parameter buffers.
267            if self.inserter.num_rows() == capacity {
268                self.flush()?;
269            }
270            remanining_rows -= chunk_size;
271        }
272
273        Ok(())
274    }
275
276    /// The number of row in an individual record batch must not necessarily match the capacity of
277    /// the buffers owned by this writer. Therfore sometimes records are not send to the database
278    /// immediatly but rather we wait for the buffers to be filled then reading the next batch. Once
279    /// we reach the last batch however, there is no "next batch" anymore. In that case we call this
280    /// method in order to send the remainder of the records to the database as well.
281    pub fn flush(&mut self) -> Result<(), WriterError> {
282        self.inserter
283            .execute()
284            .map_err(WriterError::ExecuteStatment)?;
285        self.inserter.clear();
286        Ok(())
287    }
288}
289
290impl<C> OdbcWriter<StatementConnection<C>>
291where
292    C: StatementParent,
293{
294    /// A writer which takes ownership of the connection and inserts the given schema into a table
295    /// with matching column names.
296    ///
297    /// **Note:**
298    ///
299    /// If the column name contains any character which would make it not a valid qualifier for transact
300    /// SQL it will be wrapped in double quotes (`"`) within the insert schema. Valid names consist of
301    /// alpha numeric characters, `@`, `$`, `#` and `_`.
302    pub fn from_connection<C2>(
303        connection: C2,
304        schema: &Schema,
305        table_name: &str,
306        row_capacity: usize,
307    ) -> Result<Self, WriterError>
308    where
309        C2: ConnectionTransitions<StatementParent = C>,
310    {
311        let sql = insert_statement_from_schema(schema, table_name);
312        let statement = connection
313            .into_prepared(&sql)
314            .map_err(|source| WriterError::PreparingInsertStatement { source, sql })?;
315        Self::new(row_capacity, schema, statement)
316    }
317}
318
319impl<'o> OdbcWriter<StatementImpl<'o>> {
320    /// A writer which borrows the connection and inserts the given schema into a table with
321    /// matching column names.
322    ///
323    /// **Note:**
324    ///
325    /// If the column name contains any character which would make it not a valid qualifier for transact
326    /// SQL it will be wrapped in double quotes (`"`) within the insert schema. Valid names consist of
327    /// alpha numeric characters, `@`, `$`, `#` and `_`.
328    pub fn with_connection(
329        connection: &'o Connection<'o>,
330        schema: &Schema,
331        table_name: &str,
332        row_capacity: usize,
333    ) -> Result<Self, WriterError> {
334        let sql = insert_statement_from_schema(schema, table_name);
335        let statement = connection
336            .prepare(&sql)
337            .map_err(|source| WriterError::PreparingInsertStatement { source, sql })?;
338        Self::new(row_capacity, schema, statement)
339    }
340}
341
/// Strategy for transferring the values of one arrow column into the corresponding bound ODBC
/// parameter buffer. One implementation exists per supported arrow data type mapping.
pub trait WriteStrategy {
    /// Describe the buffer used to hold the array parameters for the column
    fn buffer_desc(&self) -> BindParamDesc;

    /// Copy the values of `array` into the parameter buffer, starting at `param_offset`.
    ///
    /// # Parameters
    ///
    /// * `param_offset`: Start writing parameters at that position. Number of rows in the parameter
    ///   buffer before inserting the current chunk.
    /// * `column_buf`: Buffer to write the data into
    /// * `array`: Buffer to read the data from
    fn write_rows(
        &self,
        param_offset: usize,
        column_buf: AnySliceMut<'_>,
        array: &dyn Array,
    ) -> Result<(), WriterError>;
}
359
/// Selects a [`WriteStrategy`] for a single arrow field, based on its data type and nullability.
///
/// Returns [`WriterError::UnsupportedArrowDataType`] if the field's type has no mapping to an
/// ODBC parameter buffer.
fn field_to_write_strategy(field: &Field) -> Result<Box<dyn WriteStrategy>, WriterError> {
    let is_nullable = field.is_nullable();
    let strategy = match field.data_type() {
        DataType::Utf8 => Box::new(Utf8ToNativeText {}),
        DataType::Boolean => boolean_to_bit(is_nullable),
        DataType::LargeUtf8 => Box::new(LargeUtf8ToNativeText {}),
        // Integer and float types with a direct ODBC counterpart are copied unchanged.
        DataType::Int8 => Int8Type::identical(is_nullable),
        DataType::Int16 => Int16Type::identical(is_nullable),
        DataType::Int32 => Int32Type::identical(is_nullable),
        DataType::Int64 => Int64Type::identical(is_nullable),
        DataType::UInt8 => UInt8Type::identical(is_nullable),
        // Half precision has no ODBC buffer type; each value is widened to f32.
        DataType::Float16 => Float16Type::map_with(is_nullable, |half| half.to_f32()),
        DataType::Float32 => Float32Type::identical(is_nullable),
        DataType::Float64 => Float64Type::identical(is_nullable),
        DataType::Timestamp(time_unit, time_zone) => {
            insert_timestamp_strategy(is_nullable, &time_unit, time_zone.clone())?
        }
        DataType::Date32 => Date32Type::map_with(is_nullable, epoch_to_date),
        // NOTE(review): arrow `Date64` is defined as *milliseconds* since the UNIX epoch, yet the
        // raw value is forwarded to `epoch_to_date` exactly like `Date32`'s day count, and the
        // `i64 -> i32` conversion will panic for values beyond `i32::MAX` (~25 days worth of
        // milliseconds) — verify against `epoch_to_date`'s expected unit.
        DataType::Date64 => Date64Type::map_with(is_nullable, |days_since_epoch| {
            epoch_to_date(days_since_epoch.try_into().unwrap())
        }),
        DataType::Time32(TimeUnit::Second) => {
            Time32SecondType::map_with(is_nullable, sec_since_midnight_to_time)
        }
        // Sub-second time resolutions are transferred as text (see `NullableTimeAsText`).
        DataType::Time32(TimeUnit::Millisecond) => {
            Box::new(NullableTimeAsText::<Time32MillisecondType>::new())
        }
        DataType::Time64(TimeUnit::Microsecond) => {
            Box::new(NullableTimeAsText::<Time64MicrosecondType>::new())
        }
        DataType::Time64(TimeUnit::Nanosecond) => {
            Box::new(NullableTimeAsText::<Time64NanosecondType>::new())
        }
        DataType::Binary => Box::new(VariadicBinary::new(1)),
        // Fixed size binaries reuse the variadic strategy with the known element length.
        DataType::FixedSizeBinary(length) => {
            Box::new(VariadicBinary::new((*length).try_into().unwrap()))
        }
        // Decimals are rendered as text; precision and scale determine the buffer element size.
        DataType::Decimal128(precision, scale) => {
            Box::new(NullableDecimal128AsText::new(*precision, *scale))
        }
        DataType::Decimal256(precision, scale) => {
            Box::new(NullableDecimal256AsText::new(*precision, *scale))
        }
        unsupported => return Err(WriterError::UnsupportedArrowDataType(unsupported.clone())),
    };
    Ok(strategy)
}