1mod error;
26
27use std::sync::Arc;
28
29use aqueducts_schemas::destinations::WriteMode;
30use arrow_odbc::{
31 insert_into_table,
32 odbc_api::{ConnectionOptions, Environment},
33 OdbcReaderBuilder, OdbcWriter,
34};
35use datafusion::{
36 arrow::{
37 array::{RecordBatch, RecordBatchIterator},
38 compute::concat_batches,
39 datatypes::Schema,
40 error::ArrowError,
41 },
42 catalog::MemTable,
43 prelude::SessionContext,
44};
45use error::Result;
46use tracing::error;
47
48#[doc(hidden)]
50pub async fn register_odbc_source(
51 ctx: Arc<SessionContext>,
52 connection_string: &str,
53 query: &str,
54 source_name: &str,
55) -> error::Result<()> {
56 let odbc_environment = Environment::new().unwrap();
57
58 let connection = odbc_environment
59 .connect_with_connection_string(connection_string, ConnectionOptions::default())?;
60
61 let parameters = ();
62
63 let cursor = connection
64 .execute(query, parameters, None)?
65 .expect("SELECT statement must produce a cursor");
66
67 let reader = OdbcReaderBuilder::new().build(cursor)?;
68
69 let batches = reader
70 .into_iter()
71 .collect::<std::result::Result<Vec<RecordBatch>, ArrowError>>()?;
72
73 let df = ctx.read_batches(batches)?;
74
75 let schema = df.schema().clone();
76 let partitioned = df.collect_partitioned().await?;
77 let table = MemTable::try_new(Arc::new(schema.as_arrow().clone()), partitioned)?;
78
79 ctx.register_table(source_name, Arc::new(table))?;
80
81 Ok(())
82}
83
84#[doc(hidden)]
87pub async fn register_odbc_destination(
88 connection_string: &str,
89 destination_name: &str,
90) -> Result<()> {
91 let odbc_environment = Environment::new().unwrap();
92
93 let connection = odbc_environment
94 .connect_with_connection_string(connection_string, ConnectionOptions::default())?;
95
96 let parameters = ();
97
98 let query = format!("SELECT * FROM {destination_name} LIMIT 1");
99 connection
100 .execute(query.as_str(), parameters, None)?
101 .expect("SELECT statement must produce a cursor");
102
103 Ok(())
104}
105
106#[doc(hidden)]
107pub async fn write_arrow_batches(
108 connection_string: &str,
109 destination_name: &str,
110 write_mode: WriteMode,
111 batches: Vec<datafusion::arrow::array::RecordBatch>,
112 schema: std::sync::Arc<datafusion::arrow::datatypes::Schema>,
113 batch_size: usize,
114) -> error::Result<()> {
115 match write_mode {
116 WriteMode::Append => {
117 append_arrow_batches(
118 connection_string,
119 destination_name,
120 batches,
121 schema,
122 batch_size,
123 )
124 .await
125 }
126 WriteMode::Custom(custom_statements) => {
127 custom(
128 connection_string,
129 custom_statements.pre_insert.clone(),
130 custom_statements.insert.as_str(),
131 batches,
132 schema,
133 batch_size,
134 )
135 .await
136 }
137 }
138}
139
140async fn append_arrow_batches(
142 connection_string: &str,
143 destination_name: &str,
144 batches: Vec<RecordBatch>,
145 schema: Arc<Schema>,
146 batch_size: usize,
147) -> Result<()> {
148 let odbc_environment = Environment::new().unwrap();
149
150 let connection = odbc_environment
151 .connect_with_connection_string(connection_string, ConnectionOptions::default())?;
152
153 let batches = [concat_batches(&schema, batches.iter())?];
154 let mut record_batch_iterator = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
155
156 insert_into_table(
157 &connection,
158 &mut record_batch_iterator,
159 destination_name,
160 batch_size,
161 )?;
162
163 Ok(())
164}
165
166async fn custom(
171 connection_string: &str,
172 pre_insert: Option<String>,
173 insert: &str,
174 batches: Vec<RecordBatch>,
175 schema: Arc<Schema>,
176 batch_size: usize,
177) -> Result<()> {
178 let odbc_environment = Environment::new()?;
179
180 let connection = odbc_environment
181 .connect_with_connection_string(connection_string, ConnectionOptions::default())?;
182
183 let batches = [concat_batches(&schema, batches.iter())?];
184 let record_batch_iterator =
185 RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
186
187 let mut writer = OdbcWriter::new(batch_size, &schema, connection.prepare(insert)?)?;
188
189 let _ = connection.set_autocommit(false);
190
191 let result = || -> Result<()> {
192 if let Some(stmt) = pre_insert {
193 connection.execute(&stmt, (), None)?;
194 }
195 writer.write_all(record_batch_iterator)?;
196
197 Ok(())
198 };
199
200 match result() {
201 Ok(_) => {
202 connection.commit()?;
203 Ok(())
204 }
205 Err(err) => {
206 connection.rollback()?;
207 error!("ROLLBACK transaction: {err:?}");
208 Err(err)
209 }
210 }
211}
212
// Integration tests: they need the `odbc_tests` feature and, judging by the connection
// string, a PostgreSQL ODBC driver plus a local database containing the tables referenced
// below.
#[cfg(all(test, feature = "odbc_tests"))]
mod tests {
    use datafusion::arrow::array::*;
    use datafusion::{assert_batches_eq, prelude::*};
    use std::sync::Arc;

    use super::*;

    /// Connection string shared by every test in this module.
    const CONNECTION_STRING: &str = "\
        Driver={PostgreSQL Unicode};\
        Server=localhost;\
        UID=postgres;\
        PWD=postgres;\
    ";

    /// Registers a filtered query as a source and checks the row count seen by DataFusion.
    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_register_odbc_source_ok() {
        let ctx = Arc::new(SessionContext::new());

        register_odbc_source(
            ctx.clone(),
            CONNECTION_STRING,
            "SELECT * FROM temp_readings WHERE timestamp::date BETWEEN '2024-01-01' AND '2024-01-31'",
            "my_table",
        )
        .await
        .unwrap();

        let result = ctx
            .sql("SELECT count(*) num_rows FROM my_table")
            .await
            .unwrap()
            .collect()
            .await
            .unwrap();

        assert_batches_eq!(
            &[
                "+----------+",
                "| num_rows |",
                "+----------+",
                "| 1000     |",
                "+----------+",
            ],
            result.as_slice()
        );
    }

    /// Probing an existing destination table succeeds.
    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_register_odbc_destination_ok() {
        // `expect` rather than `assert!(is_ok())` so a failure shows the underlying error.
        register_odbc_destination(CONNECTION_STRING, "temp_readings_empty")
            .await
            .expect("destination table should be reachable");
    }

    /// Appending a 1000-row batch to an existing table succeeds.
    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_write_arrow_batches_ok() {
        let locations = (0..1000).collect::<Vec<i32>>();
        let timestamps = (1704067200..1704068200).collect::<Vec<i64>>();
        let temperatures = (0..1000).map(|i| i as f64).collect::<Vec<f64>>();
        let humidity = (0..1000).map(|i| i as f64).collect::<Vec<f64>>();
        let conditions = (0..1000)
            .map(|i| format!("CONDITION_{i}"))
            .collect::<Vec<String>>();

        let a: ArrayRef = Arc::new(Int32Array::from(locations));
        let b: ArrayRef = Arc::new(TimestampSecondArray::from(timestamps));
        let c: ArrayRef = Arc::new(Float64Array::from(temperatures));
        let d: ArrayRef = Arc::new(Float64Array::from(humidity));
        let e: ArrayRef = Arc::new(StringArray::from(conditions));

        let record_batch = RecordBatch::try_from_iter(vec![
            ("location_id", a),
            ("timestamp", b),
            ("temperature_c", c),
            ("humidity", d),
            ("weather_condition", e),
        ])
        .unwrap();
        let schema = record_batch.schema();

        append_arrow_batches(
            CONNECTION_STRING,
            "temp_readings_empty",
            vec![record_batch],
            schema,
            100,
        )
        .await
        .expect("appending batches should succeed");
    }

    /// A successful custom write applies both the `pre_insert` delete and the insert.
    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_custom_delete_insert_ok() {
        use arrow_odbc::odbc_api::{ConnectionOptions, Environment};
        use arrow_odbc::OdbcReaderBuilder;

        let odbc_environment = Environment::new().unwrap();
        let connection = odbc_environment
            .connect_with_connection_string(CONNECTION_STRING, ConnectionOptions::default())
            .unwrap();
        let _ = connection
            .execute("truncate test_custom_delete_insert_ok", (), None)
            .unwrap();

        let record_batch = RecordBatch::try_from_iter(vec![
            ("id", Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef),
            (
                "value",
                Arc::new(StringArray::from(vec!["original", "original"])) as ArrayRef,
            ),
        ])
        .unwrap();
        let schema = record_batch.schema();

        // Seeding must succeed, otherwise the final assertion checks the wrong state.
        append_arrow_batches(
            CONNECTION_STRING,
            "test_custom_delete_insert_ok",
            vec![record_batch],
            schema.clone(),
            100,
        )
        .await
        .unwrap();

        let new_batch = RecordBatch::try_from_iter(vec![
            ("id", Arc::new(Int32Array::from(vec![1])) as ArrayRef),
            (
                "value",
                Arc::new(StringArray::from(vec!["updated"])) as ArrayRef,
            ),
        ])
        .unwrap();

        custom(
            CONNECTION_STRING,
            Some("delete from test_custom_delete_insert_ok where id = 1".to_string()),
            "insert into test_custom_delete_insert_ok values (?, ?)",
            vec![new_batch],
            schema,
            50,
        )
        .await
        .unwrap();

        let cursor = connection
            .execute(
                "select * from test_custom_delete_insert_ok order by id",
                (),
                None,
            )
            .unwrap()
            .unwrap();

        // Collect first and assert once: asserting inside a loop over the reader would pass
        // vacuously if the table turned out to be empty.
        let batches = OdbcReaderBuilder::new()
            .build(cursor)
            .unwrap()
            .collect::<std::result::Result<Vec<RecordBatch>, ArrowError>>()
            .unwrap();

        assert_batches_eq!(
            [
                "+----+----------+",
                "| id | value    |",
                "+----+----------+",
                "| 1  | updated  |",
                "| 2  | original |",
                "+----+----------+",
            ],
            batches.as_slice()
        );
    }

    /// A failing custom write leaves the table untouched: the `pre_insert` delete must be
    /// rolled back together with the failed insert.
    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_custom_delete_insert_failed() {
        use arrow_odbc::odbc_api::{ConnectionOptions, Environment};
        use arrow_odbc::OdbcReaderBuilder;

        let odbc_environment = Environment::new().unwrap();
        let connection = odbc_environment
            .connect_with_connection_string(CONNECTION_STRING, ConnectionOptions::default())
            .unwrap();
        let _ = connection
            .execute("truncate test_custom_delete_insert_failed", (), None)
            .unwrap();

        let record_batch = RecordBatch::try_from_iter(vec![
            ("id", Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef),
            (
                "value",
                Arc::new(StringArray::from(vec!["original", "original"])) as ArrayRef,
            ),
        ])
        .unwrap();
        let schema = record_batch.schema();

        // Seeding must succeed, otherwise the final assertion checks the wrong state.
        append_arrow_batches(
            CONNECTION_STRING,
            "test_custom_delete_insert_failed",
            vec![record_batch],
            schema.clone(),
            100,
        )
        .await
        .unwrap();

        let new_batch = RecordBatch::try_from_iter(vec![
            ("id", Arc::new(Int32Array::from(vec![1])) as ArrayRef),
            (
                "value",
                Arc::new(StringArray::from(vec!["updated"])) as ArrayRef,
            ),
        ])
        .unwrap();

        let outcome = custom(
            CONNECTION_STRING,
            Some("delete from test_custom_delete_insert_failed where id = 1".to_string()),
            "insert into WRONG_TABLE values (?, ?)",
            vec![new_batch],
            schema,
            50,
        )
        .await;
        assert!(
            outcome.is_err(),
            "inserting into a missing table should fail"
        );

        let cursor = connection
            .execute(
                "select * from test_custom_delete_insert_failed order by id",
                (),
                None,
            )
            .unwrap()
            .unwrap();

        // Collect first and assert once: asserting inside a loop over the reader would pass
        // vacuously if the table turned out to be empty.
        let batches = OdbcReaderBuilder::new()
            .build(cursor)
            .unwrap()
            .collect::<std::result::Result<Vec<RecordBatch>, ArrowError>>()
            .unwrap();

        assert_batches_eq!(
            [
                "+----+----------+",
                "| id | value    |",
                "+----+----------+",
                "| 1  | original |",
                "| 2  | original |",
                "+----+----------+",
            ],
            batches.as_slice()
        );
    }
}