aqueducts_odbc/
lib.rs

1//! # Aqueducts ODBC Integration
2//!
3//! This crate provides ODBC connectivity for Aqueducts pipelines, enabling integration
4//! with databases that support ODBC drivers (PostgreSQL, SQL Server, MySQL, etc.).
5//!
6//! ## Features
7//!
8//! - **OdbcSource**: Read data from ODBC-compatible databases using SQL queries
9//! - **OdbcDestination**: Write data to ODBC-compatible databases with transaction support
10//! - **Connection Pooling**: Efficient connection management for high-throughput operations
11//! - **Transaction Support**: ACID compliance with automatic rollback on errors
12//!
13//! ## Usage
14//!
15//! This crate is typically used through the main `aqueducts` meta-crate with the `odbc` feature:
16//!
17//! ```toml
18//! [dependencies]
19//! aqueducts = { version = "0.10", features = ["odbc"] }
20//! ```
21//!
22//! The ODBC integration is automatically registered when the feature is enabled.
//! Configure ODBC sources and destinations in your pipeline YAML/JSON/TOML files.
24
25mod error;
26
27use std::sync::Arc;
28
29use aqueducts_schemas::destinations::WriteMode;
30use arrow_odbc::{
31    insert_into_table,
32    odbc_api::{ConnectionOptions, Environment},
33    OdbcReaderBuilder, OdbcWriter,
34};
35use datafusion::{
36    arrow::{
37        array::{RecordBatch, RecordBatchIterator},
38        compute::concat_batches,
39        datatypes::Schema,
40        error::ArrowError,
41    },
42    catalog::MemTable,
43    prelude::SessionContext,
44};
45use error::Result;
46use tracing::error;
47
48/// Register a table via ODBC using [arrow-odbc](https://docs.rs/arrow-odbc)
49#[doc(hidden)]
50pub async fn register_odbc_source(
51    ctx: Arc<SessionContext>,
52    connection_string: &str,
53    query: &str,
54    source_name: &str,
55) -> error::Result<()> {
56    let odbc_environment = Environment::new().unwrap();
57
58    let connection = odbc_environment
59        .connect_with_connection_string(connection_string, ConnectionOptions::default())?;
60
61    let parameters = ();
62
63    let cursor = connection
64        .execute(query, parameters, None)?
65        .expect("SELECT statement must produce a cursor");
66
67    let reader = OdbcReaderBuilder::new().build(cursor)?;
68
69    let batches = reader
70        .into_iter()
71        .collect::<std::result::Result<Vec<RecordBatch>, ArrowError>>()?;
72
73    let df = ctx.read_batches(batches)?;
74
75    let schema = df.schema().clone();
76    let partitioned = df.collect_partitioned().await?;
77    let table = MemTable::try_new(Arc::new(schema.as_arrow().clone()), partitioned)?;
78
79    ctx.register_table(source_name, Arc::new(table))?;
80
81    Ok(())
82}
83
84/// Checks if the provided table for the destination exists
85/// will try to query one record from the provided table name
86#[doc(hidden)]
87pub async fn register_odbc_destination(
88    connection_string: &str,
89    destination_name: &str,
90) -> Result<()> {
91    let odbc_environment = Environment::new().unwrap();
92
93    let connection = odbc_environment
94        .connect_with_connection_string(connection_string, ConnectionOptions::default())?;
95
96    let parameters = ();
97
98    let query = format!("SELECT * FROM {destination_name} LIMIT 1");
99    connection
100        .execute(query.as_str(), parameters, None)?
101        .expect("SELECT statement must produce a cursor");
102
103    Ok(())
104}
105
106#[doc(hidden)]
107pub async fn write_arrow_batches(
108    connection_string: &str,
109    destination_name: &str,
110    write_mode: WriteMode,
111    batches: Vec<datafusion::arrow::array::RecordBatch>,
112    schema: std::sync::Arc<datafusion::arrow::datatypes::Schema>,
113    batch_size: usize,
114) -> error::Result<()> {
115    match write_mode {
116        WriteMode::Append => {
117            append_arrow_batches(
118                connection_string,
119                destination_name,
120                batches,
121                schema,
122                batch_size,
123            )
124            .await
125        }
126        WriteMode::Custom(custom_statements) => {
127            custom(
128                connection_string,
129                custom_statements.pre_insert.clone(),
130                custom_statements.insert.as_str(),
131                batches,
132                schema,
133                batch_size,
134            )
135            .await
136        }
137    }
138}
139
140/// Write arrow batches to a table via ODBC using [arrow-odbc](https://docs.rs/arrow-odbc)
141async fn append_arrow_batches(
142    connection_string: &str,
143    destination_name: &str,
144    batches: Vec<RecordBatch>,
145    schema: Arc<Schema>,
146    batch_size: usize,
147) -> Result<()> {
148    let odbc_environment = Environment::new().unwrap();
149
150    let connection = odbc_environment
151        .connect_with_connection_string(connection_string, ConnectionOptions::default())?;
152
153    let batches = [concat_batches(&schema, batches.iter())?];
154    let mut record_batch_iterator = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
155
156    insert_into_table(
157        &connection,
158        &mut record_batch_iterator,
159        destination_name,
160        batch_size,
161    )?;
162
163    Ok(())
164}
165
166/// Performs an insert with a prepared statement provided.
167/// Optionally, it can execute preliminary statements (such as `delete from ...`).
168/// All statements are executed within the same transaction and it gets rolled back
169/// in case of any errors.
170async fn custom(
171    connection_string: &str,
172    pre_insert: Option<String>,
173    insert: &str,
174    batches: Vec<RecordBatch>,
175    schema: Arc<Schema>,
176    batch_size: usize,
177) -> Result<()> {
178    let odbc_environment = Environment::new()?;
179
180    let connection = odbc_environment
181        .connect_with_connection_string(connection_string, ConnectionOptions::default())?;
182
183    let batches = [concat_batches(&schema, batches.iter())?];
184    let record_batch_iterator =
185        RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
186
187    let mut writer = OdbcWriter::new(batch_size, &schema, connection.prepare(insert)?)?;
188
189    let _ = connection.set_autocommit(false);
190
191    let result = || -> Result<()> {
192        if let Some(stmt) = pre_insert {
193            connection.execute(&stmt, (), None)?;
194        }
195        writer.write_all(record_batch_iterator)?;
196
197        Ok(())
198    };
199
200    match result() {
201        Ok(_) => {
202            connection.commit()?;
203            Ok(())
204        }
205        Err(err) => {
206            connection.rollback()?;
207            error!("ROLLBACK transaction: {err:?}");
208            Err(err)
209        }
210    }
211}
212
// Integration tests. They are only compiled with the `odbc_tests` feature and connect to
// the data source described by `CONNECTION_STRING`, using the tables named in each test.
#[cfg(all(test, feature = "odbc_tests"))]
mod tests {
    use arrow_odbc::odbc_api::{ConnectionOptions, Environment};
    use arrow_odbc::OdbcReaderBuilder;
    use datafusion::arrow::array::*;
    use datafusion::arrow::error::ArrowError;
    use datafusion::{assert_batches_eq, prelude::*};
    use std::sync::Arc;

    use super::*;

    /// Connection string shared by all tests.
    const CONNECTION_STRING: &str = "\
        Driver={PostgreSQL Unicode};\
        Server=localhost;\
        UID=postgres;\
        PWD=postgres;\
    ";

    /// Executes a statement whose result set (if any) is not needed.
    fn execute_statement(statement: &str) {
        let odbc_environment = Environment::new().unwrap();
        let connection = odbc_environment
            .connect_with_connection_string(CONNECTION_STRING, ConnectionOptions::default())
            .unwrap();

        let _ = connection.execute(statement, (), None).unwrap();
    }

    /// Runs `query` and returns every record batch of its result set.
    fn fetch_all(query: &str) -> Vec<RecordBatch> {
        let odbc_environment = Environment::new().unwrap();
        let connection = odbc_environment
            .connect_with_connection_string(CONNECTION_STRING, ConnectionOptions::default())
            .unwrap();

        let cursor = connection
            .execute(query, (), None)
            .unwrap()
            .expect("SELECT statement must produce a cursor");

        let batches = OdbcReaderBuilder::new()
            .build(cursor)
            .unwrap()
            .collect::<std::result::Result<Vec<RecordBatch>, ArrowError>>()
            .unwrap();

        batches
    }

    /// Builds a two-column (`id`, `value`) record batch.
    fn id_value_batch(ids: Vec<i32>, values: Vec<&str>) -> RecordBatch {
        RecordBatch::try_from_iter(vec![
            ("id", Arc::new(Int32Array::from(ids)) as ArrayRef),
            ("value", Arc::new(StringArray::from(values)) as ArrayRef),
        ])
        .unwrap()
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_register_odbc_source_ok() {
        let ctx = Arc::new(SessionContext::new());

        register_odbc_source(
            ctx.clone(),
            CONNECTION_STRING,
            "SELECT * FROM temp_readings WHERE timestamp::date BETWEEN '2024-01-01' AND '2024-01-31'",
            "my_table",
        )
        .await
        .unwrap();

        let result = ctx
            .sql("SELECT count(*) num_rows FROM my_table")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        assert_batches_eq!(
            &[
                "+----------+",
                "| num_rows |",
                "+----------+",
                "| 1000     |",
                "+----------+",
            ],
            result.as_slice()
        );
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_register_odbc_destination_ok() {
        let result = register_odbc_destination(CONNECTION_STRING, "temp_readings_empty").await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_write_arrow_batches_ok() {
        let locations = (0..1000).collect::<Vec<i32>>();
        let timestamps = (1704067200..1704068200).collect::<Vec<i64>>();
        let temperatures = (0..1000).map(|i| i as f64).collect::<Vec<f64>>();
        let humidity = (0..1000).map(|i| i as f64).collect::<Vec<f64>>();
        let conditions = (0..1000)
            .map(|i| format!("CONDITION_{i}"))
            .collect::<Vec<String>>();

        let a: ArrayRef = Arc::new(Int32Array::from(locations));
        let b: ArrayRef = Arc::new(TimestampSecondArray::from(timestamps));
        let c: ArrayRef = Arc::new(Float64Array::from(temperatures));
        let d: ArrayRef = Arc::new(Float64Array::from(humidity));
        let e: ArrayRef = Arc::new(StringArray::from(conditions));

        let record_batch = RecordBatch::try_from_iter(vec![
            ("location_id", a),
            ("timestamp", b),
            ("temperature_c", c),
            ("humidity", d),
            ("weather_condition", e),
        ])
        .unwrap();
        let schema = record_batch.schema();

        let result = append_arrow_batches(
            CONNECTION_STRING,
            "temp_readings_empty",
            vec![record_batch],
            schema,
            100,
        )
        .await;

        assert!(result.is_ok());
    }

    /// Tests a transaction with a delete and an insert
    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_custom_delete_insert_ok() {
        execute_statement("truncate test_custom_delete_insert_ok");

        let record_batch = id_value_batch(vec![1, 2], vec!["original", "original"]);
        let schema = record_batch.schema();

        // The fixture rows must really be there, otherwise the checks below are meaningless.
        append_arrow_batches(
            CONNECTION_STRING,
            "test_custom_delete_insert_ok",
            vec![record_batch],
            schema.clone(),
            100,
        )
        .await
        .unwrap();

        custom(
            CONNECTION_STRING,
            Some("delete from test_custom_delete_insert_ok where id = 1".to_string()),
            "insert into test_custom_delete_insert_ok values (?, ?)",
            vec![id_value_batch(vec![1], vec!["updated"])],
            schema,
            50,
        )
        .await
        .unwrap();

        // Compare the complete result at once so that an empty table fails the test.
        let result = fetch_all("select * from test_custom_delete_insert_ok order by id");
        assert_batches_eq!(
            [
                "+----+----------+",
                "| id | value    |",
                "+----+----------+",
                "| 1  | updated  |",
                "| 2  | original |",
                "+----+----------+",
            ],
            result.as_slice()
        );
    }

    /// Checks transaction is rolled back in case of error
    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_custom_delete_insert_failed() {
        execute_statement("truncate test_custom_delete_insert_failed");

        let record_batch = id_value_batch(vec![1, 2], vec!["original", "original"]);
        let schema = record_batch.schema();

        // The fixture rows must really be there, otherwise the checks below are meaningless.
        append_arrow_batches(
            CONNECTION_STRING,
            "test_custom_delete_insert_failed",
            vec![record_batch],
            schema.clone(),
            100,
        )
        .await
        .unwrap();

        let result = custom(
            CONNECTION_STRING,
            Some("delete from test_custom_delete_insert_failed where id = 1".to_string()),
            "insert into WRONG_TABLE values (?, ?)",
            vec![id_value_batch(vec![1], vec!["updated"])],
            schema,
            50,
        )
        .await;

        // Inserting into a table that does not exist has to surface as an error.
        assert!(result.is_err());

        // Compare the complete result at once so that an empty table fails the test.
        let rows = fetch_all("select * from test_custom_delete_insert_failed order by id");
        assert_batches_eq!(
            [
                "+----+----------+",
                "| id | value    |",
                "+----+----------+",
                "| 1  | original |",
                "| 2  | original |",
                "+----+----------+",
            ],
            rows.as_slice()
        );
    }
}