1use crate::connection::{RemoteDbType, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, Pool, RemoteField, RemoteSchema, RemoteSchemaRef,
4 RemoteType, SqliteType, TableSource, Unparse, unparse_array,
5};
6use datafusion::arrow::array::{
7 ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int32Builder, Int64Builder, NullBuilder,
8 RecordBatch, RecordBatchOptions, StringBuilder, make_builder,
9};
10use datafusion::arrow::datatypes::{DataType, SchemaRef};
11use datafusion::common::{DataFusionError, project_schema};
12use datafusion::execution::SendableRecordBatchStream;
13use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
14use derive_getters::Getters;
15use derive_with::With;
16use futures::StreamExt;
17use itertools::Itertools;
18use log::{debug, error};
19use rusqlite::types::ValueRef;
20use rusqlite::{Column, Row, Rows};
21use std::any::Any;
22use std::collections::HashMap;
23use std::path::PathBuf;
24use std::sync::Arc;
25
/// Options for connecting to a SQLite database file.
#[derive(Debug, Clone, With, Getters)]
pub struct SqliteConnectionOptions {
    /// Path of the SQLite database file.
    pub path: PathBuf,
    /// Row chunk size used when streaming query results (defaults to 2048 via `new`).
    /// NOTE(review): presumably surfaced through `ConnectionOptions::stream_chunk_size()` and
    /// used as the maximum rows per streamed batch — confirm.
    pub stream_chunk_size: usize,
}
31
32impl SqliteConnectionOptions {
33 pub fn new(path: PathBuf) -> Self {
34 Self {
35 path,
36 stream_chunk_size: 2048,
37 }
38 }
39}
40
41impl From<SqliteConnectionOptions> for ConnectionOptions {
42 fn from(options: SqliteConnectionOptions) -> Self {
43 ConnectionOptions::Sqlite(options)
44 }
45}
46
/// Pool for a SQLite database file.
///
/// It only stores the file path; `get` hands out a `SqliteConnection` for that same path.
#[derive(Debug)]
pub struct SqlitePool {
    path: PathBuf,
}
51
52pub async fn connect_sqlite(options: &SqliteConnectionOptions) -> DFResult<SqlitePool> {
53 let _ = rusqlite::Connection::open(&options.path).map_err(|e| {
54 DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
55 })?;
56 Ok(SqlitePool {
57 path: options.path.clone(),
58 })
59}
60
61#[async_trait::async_trait]
62impl Pool for SqlitePool {
63 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
64 Ok(Arc::new(SqliteConnection {
65 path: self.path.clone(),
66 }))
67 }
68}
69
/// Handle to a SQLite database file used for schema inference, queries and inserts.
///
/// Only the path is stored; each operation of the `Connection` impl opens its own
/// `rusqlite::Connection` to this file.
#[derive(Debug)]
pub struct SqliteConnection {
    path: PathBuf,
}
74
75#[async_trait::async_trait]
76impl Connection for SqliteConnection {
77 fn as_any(&self) -> &dyn Any {
78 self
79 }
80
81 async fn infer_schema(&self, source: &TableSource) -> DFResult<RemoteSchemaRef> {
82 let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
83 DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
84 })?;
85 match source {
86 TableSource::Table(table) => {
87 let sql = format!(
89 "PRAGMA table_info({})",
90 RemoteDbType::Sqlite.sql_table_name(table)
91 );
92 let mut stmt = conn.prepare(&sql).map_err(|e| {
93 DataFusionError::Execution(format!("Failed to prepare sqlite statement: {e:?}"))
94 })?;
95 let rows = stmt.query([]).map_err(|e| {
96 DataFusionError::Execution(format!("Failed to query sqlite statement: {e:?}"))
97 })?;
98 let remote_schema = Arc::new(build_remote_schema_for_table(rows)?);
99 Ok(remote_schema)
100 }
101 TableSource::Query(_query) => {
102 let sql = RemoteDbType::Sqlite.limit_1_query_if_possible(source);
103 let mut stmt = conn.prepare(&sql).map_err(|e| {
104 DataFusionError::Execution(format!("Failed to prepare sqlite statement: {e:?}"))
105 })?;
106 let columns: Vec<OwnedColumn> =
107 stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
108 let rows = stmt.query([]).map_err(|e| {
109 DataFusionError::Execution(format!("Failed to query sqlite statement: {e:?}"))
110 })?;
111
112 let remote_schema =
113 Arc::new(build_remote_schema_for_query(columns.as_slice(), rows)?);
114 Ok(remote_schema)
115 }
116 }
117 }
118
119 async fn query(
120 &self,
121 conn_options: &ConnectionOptions,
122 source: &TableSource,
123 table_schema: SchemaRef,
124 projection: Option<&Vec<usize>>,
125 unparsed_filters: &[String],
126 limit: Option<usize>,
127 ) -> DFResult<SendableRecordBatchStream> {
128 let projected_schema = project_schema(&table_schema, projection)?;
129 let sql = RemoteDbType::Sqlite.rewrite_query(source, unparsed_filters, limit);
130 debug!("[remote-table] executing sqlite query: {sql}");
131
132 let (tx, mut rx) = tokio::sync::mpsc::channel::<DFResult<RecordBatch>>(1);
133 let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
134 DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
135 })?;
136
137 let projection = projection.cloned();
138 let chunk_size = conn_options.stream_chunk_size();
139
140 spawn_background_task(tx, conn, sql, table_schema, projection, chunk_size);
141
142 let stream = async_stream::stream! {
143 while let Some(batch) = rx.recv().await {
144 yield batch;
145 }
146 };
147 Ok(Box::pin(RecordBatchStreamAdapter::new(
148 projected_schema,
149 stream,
150 )))
151 }
152
153 async fn insert(
154 &self,
155 _conn_options: &ConnectionOptions,
156 unparser: Arc<dyn Unparse>,
157 table: &[String],
158 remote_schema: RemoteSchemaRef,
159 mut input: SendableRecordBatchStream,
160 ) -> DFResult<usize> {
161 let input_schema = input.schema();
162 let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
163 DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
164 })?;
165
166 let mut total_count = 0;
167 while let Some(batch) = input.next().await {
168 let batch = batch?;
169
170 let mut columns = Vec::with_capacity(remote_schema.fields.len());
171 for i in 0..batch.num_columns() {
172 let input_field = input_schema.field(i);
173 let remote_field = &remote_schema.fields[i];
174 if remote_field.auto_increment && input_field.is_nullable() {
175 continue;
176 }
177
178 let remote_type = remote_schema.fields[i].remote_type.clone();
179 let array = batch.column(i);
180 let column = unparse_array(unparser.as_ref(), array, remote_type)?;
181 columns.push(column);
182 }
183
184 let num_rows = columns[0].len();
185 let num_columns = columns.len();
186
187 let mut values = Vec::with_capacity(num_rows);
188 for i in 0..num_rows {
189 let mut value = Vec::with_capacity(num_columns);
190 for col in columns.iter() {
191 value.push(col[i].as_str());
192 }
193 values.push(format!("({})", value.join(",")));
194 }
195
196 let mut col_names = Vec::with_capacity(remote_schema.fields.len());
197 for (remote_field, input_field) in
198 remote_schema.fields.iter().zip(input_schema.fields.iter())
199 {
200 if remote_field.auto_increment && input_field.is_nullable() {
201 continue;
202 }
203 col_names.push(RemoteDbType::Sqlite.sql_identifier(&remote_field.name));
204 }
205
206 let sql = format!(
207 "INSERT INTO {} ({}) VALUES {}",
208 RemoteDbType::Sqlite.sql_table_name(table),
209 col_names.join(","),
210 values.join(",")
211 );
212
213 let count = conn.execute(&sql, []).map_err(|e| {
214 DataFusionError::Execution(format!(
215 "Failed to execute insert statement on sqlite: {e:?}, sql: {sql}"
216 ))
217 })?;
218 total_count += count;
219 }
220
221 Ok(total_count)
222 }
223}
224
/// Owned copy of a prepared statement's column metadata.
#[derive(Debug)]
struct OwnedColumn {
    /// Column name as reported by the statement.
    name: String,
    /// Declared type, if the statement reports one. `None` columns are typed from their
    /// values by `build_remote_schema_for_query`.
    decl_type: Option<String>,
}
230
231fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
232 OwnedColumn {
233 name: sqlite_col.name().to_string(),
234 decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
235 }
236}
237
238fn decl_type_to_remote_type(decl_type: &str) -> DFResult<SqliteType> {
239 if [
240 "tinyint", "smallint", "int", "integer", "bigint", "int2", "int4", "int8",
241 ]
242 .contains(&decl_type)
243 {
244 return Ok(SqliteType::Integer);
245 }
246 if ["real", "float", "double", "numeric"].contains(&decl_type) {
247 return Ok(SqliteType::Real);
248 }
249 if decl_type.starts_with("real") || decl_type.starts_with("numeric") {
250 return Ok(SqliteType::Real);
251 }
252 if ["text", "varchar", "char", "string"].contains(&decl_type) {
253 return Ok(SqliteType::Text);
254 }
255 if decl_type.starts_with("char")
256 || decl_type.starts_with("varchar")
257 || decl_type.starts_with("text")
258 {
259 return Ok(SqliteType::Text);
260 }
261 if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
262 return Ok(SqliteType::Blob);
263 }
264 if decl_type.starts_with("binary") || decl_type.starts_with("varbinary") {
265 return Ok(SqliteType::Blob);
266 }
267 Err(DataFusionError::NotImplemented(format!(
268 "Unsupported sqlite decl type: {decl_type}",
269 )))
270}
271
272fn build_remote_schema_for_table(mut rows: Rows) -> DFResult<RemoteSchema> {
273 let mut remote_fields = vec![];
274 while let Some(row) = rows.next().map_err(|e| {
275 DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
276 })? {
277 let name = row.get::<_, String>(1).map_err(|e| {
278 DataFusionError::Execution(format!("Failed to get col name from sqlite row: {e:?}"))
279 })?;
280 let decl_type = row.get::<_, String>(2).map_err(|e| {
281 DataFusionError::Execution(format!("Failed to get decl type from sqlite row: {e:?}"))
282 })?;
283 let remote_type = decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?;
284 let nullable = row.get::<_, i64>(3).map_err(|e| {
285 DataFusionError::Execution(format!("Failed to get nullable from sqlite row: {e:?}"))
286 })? == 0;
287 remote_fields.push(RemoteField::new(
288 &name,
289 RemoteType::Sqlite(remote_type),
290 nullable,
291 ));
292 }
293 Ok(RemoteSchema::new(remote_fields))
294}
295
296fn build_remote_schema_for_query(
297 columns: &[OwnedColumn],
298 mut rows: Rows,
299) -> DFResult<RemoteSchema> {
300 let mut remote_field_map = HashMap::with_capacity(columns.len());
301 let mut unknown_cols = vec![];
302 for (col_idx, col) in columns.iter().enumerate() {
303 if let Some(decl_type) = &col.decl_type {
304 let remote_type =
305 RemoteType::Sqlite(decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?);
306 remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
307 } else {
308 unknown_cols.push(col_idx);
310 }
311 }
312
313 if !unknown_cols.is_empty() {
314 while let Some(row) = rows.next().map_err(|e| {
315 DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
316 })? {
317 let mut to_be_removed = vec![];
318 for col_idx in unknown_cols.iter() {
319 let value_ref = row.get_ref(*col_idx).map_err(|e| {
320 DataFusionError::Execution(format!(
321 "Failed to get value ref for column {col_idx}: {e:?}"
322 ))
323 })?;
324 match value_ref {
325 ValueRef::Null => {}
326 ValueRef::Integer(_) => {
327 remote_field_map.insert(
328 *col_idx,
329 RemoteField::new(
330 columns[*col_idx].name.clone(),
331 RemoteType::Sqlite(SqliteType::Integer),
332 true,
333 ),
334 );
335 to_be_removed.push(*col_idx);
336 }
337 ValueRef::Real(_) => {
338 remote_field_map.insert(
339 *col_idx,
340 RemoteField::new(
341 columns[*col_idx].name.clone(),
342 RemoteType::Sqlite(SqliteType::Real),
343 true,
344 ),
345 );
346 to_be_removed.push(*col_idx);
347 }
348 ValueRef::Text(_) => {
349 remote_field_map.insert(
350 *col_idx,
351 RemoteField::new(
352 columns[*col_idx].name.clone(),
353 RemoteType::Sqlite(SqliteType::Text),
354 true,
355 ),
356 );
357 to_be_removed.push(*col_idx);
358 }
359 ValueRef::Blob(_) => {
360 remote_field_map.insert(
361 *col_idx,
362 RemoteField::new(
363 columns[*col_idx].name.clone(),
364 RemoteType::Sqlite(SqliteType::Blob),
365 true,
366 ),
367 );
368 to_be_removed.push(*col_idx);
369 }
370 }
371 }
372 for col_idx in to_be_removed.iter() {
373 unknown_cols.retain(|&x| x != *col_idx);
374 }
375 if unknown_cols.is_empty() {
376 break;
377 }
378 }
379 }
380
381 if !unknown_cols.is_empty() {
382 return Err(DataFusionError::NotImplemented(format!(
383 "Failed to infer sqlite decl type for columns: {unknown_cols:?}"
384 )));
385 }
386 let remote_fields = remote_field_map
387 .into_iter()
388 .sorted_by_key(|entry| entry.0)
389 .map(|entry| entry.1)
390 .collect::<Vec<_>>();
391 Ok(RemoteSchema::new(remote_fields))
392}
393
394fn spawn_background_task(
395 tx: tokio::sync::mpsc::Sender<DFResult<RecordBatch>>,
396 conn: rusqlite::Connection,
397 sql: String,
398 table_schema: SchemaRef,
399 projection: Option<Vec<usize>>,
400 chunk_size: usize,
401) {
402 std::thread::spawn(move || {
403 let runtime = match tokio::runtime::Builder::new_current_thread().build() {
404 Ok(runtime) => runtime,
405 Err(e) => {
406 error!("Failed to create tokio runtime to run sqlite query: {e:?}");
407 return;
408 }
409 };
410 let local_set = tokio::task::LocalSet::new();
411 local_set.block_on(&runtime, async move {
412 let mut stmt = match conn.prepare(&sql) {
413 Ok(stmt) => stmt,
414 Err(e) => {
415 let _ = tx
416 .send(Err(DataFusionError::Execution(format!(
417 "Failed to prepare sqlite statement: {e:?}"
418 ))))
419 .await;
420 return;
421 }
422 };
423 let columns: Vec<OwnedColumn> =
424 stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
425 let mut rows = match stmt.query([]) {
426 Ok(rows) => rows,
427 Err(e) => {
428 let _ = tx
429 .send(Err(DataFusionError::Execution(format!(
430 "Failed to query sqlite statement: {e:?}"
431 ))))
432 .await;
433 return;
434 }
435 };
436
437 loop {
438 let (batch, is_empty) = match rows_to_batch(
439 &mut rows,
440 &table_schema,
441 &columns,
442 projection.as_ref(),
443 chunk_size,
444 ) {
445 Ok((batch, is_empty)) => (batch, is_empty),
446 Err(e) => {
447 let _ = tx
448 .send(Err(DataFusionError::Execution(format!(
449 "Failed to convert rows to batch: {e:?}"
450 ))))
451 .await;
452 return;
453 }
454 };
455 if is_empty {
456 break;
457 }
458 if tx.send(Ok(batch)).await.is_err() {
459 return;
460 }
461 }
462 });
463 });
464}
465
466fn rows_to_batch(
467 rows: &mut Rows,
468 table_schema: &SchemaRef,
469 columns: &[OwnedColumn],
470 projection: Option<&Vec<usize>>,
471 chunk_size: usize,
472) -> DFResult<(RecordBatch, bool)> {
473 let projected_schema = project_schema(table_schema, projection)?;
474 let mut array_builders = vec![];
475 for field in table_schema.fields() {
476 let builder = make_builder(field.data_type(), 1000);
477 array_builders.push(builder);
478 }
479
480 let mut is_empty = true;
481 let mut row_count = 0;
482 while let Some(row) = rows.next().map_err(|e| {
483 DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
484 })? {
485 is_empty = false;
486 row_count += 1;
487 append_rows_to_array_builders(
488 row,
489 table_schema,
490 columns,
491 projection,
492 array_builders.as_mut_slice(),
493 )?;
494 if row_count >= chunk_size {
495 break;
496 }
497 }
498
499 let projected_columns = array_builders
500 .into_iter()
501 .enumerate()
502 .filter(|(idx, _)| projections_contains(projection, *idx))
503 .map(|(_, mut builder)| builder.finish())
504 .collect::<Vec<ArrayRef>>();
505 let options = RecordBatchOptions::new().with_row_count(Some(row_count));
506 Ok((
507 RecordBatch::try_new_with_options(projected_schema, projected_columns, &options)?,
508 is_empty,
509 ))
510}
511
/// Appends the value at `$index` of `$row` to `$builder`, which must be a `$builder_ty`.
///
/// The value is read as `Option<$value_ty>`, so SQL NULL becomes an Arrow null. A read
/// failure returns early from the calling function with `DataFusionError::Execution` (the
/// expansion uses `?`). A builder of the wrong concrete type panics. `$field` and `$col` are
/// only used in messages.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let typed_builder = match $builder.as_any_mut().downcast_mut::<$builder_ty>() {
            Some(typed_builder) => typed_builder,
            None => panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            ),
        };

        let maybe_value: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        if let Some(value) = maybe_value {
            typed_builder.append_value(value);
        } else {
            typed_builder.append_null();
        }
    }};
}
541
542fn append_rows_to_array_builders(
543 row: &Row,
544 table_schema: &SchemaRef,
545 columns: &[OwnedColumn],
546 projection: Option<&Vec<usize>>,
547 array_builders: &mut [Box<dyn ArrayBuilder>],
548) -> DFResult<()> {
549 for (idx, field) in table_schema.fields.iter().enumerate() {
550 if !projections_contains(projection, idx) {
551 continue;
552 }
553 let builder = &mut array_builders[idx];
554 let col = columns.get(idx);
555 match field.data_type() {
556 DataType::Null => {
557 let builder = builder
558 .as_any_mut()
559 .downcast_mut::<NullBuilder>()
560 .expect("Failed to downcast builder to NullBuilder");
561 builder.append_null();
562 }
563 DataType::Int32 => {
564 handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx);
565 }
566 DataType::Int64 => {
567 handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
568 }
569 DataType::Float64 => {
570 handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
571 }
572 DataType::Utf8 => {
573 handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
574 }
575 DataType::Binary => {
576 handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
577 }
578 _ => {
579 return Err(DataFusionError::NotImplemented(format!(
580 "Unsupported data type {} for col: {:?}",
581 field.data_type(),
582 col
583 )));
584 }
585 }
586 }
587 Ok(())
588}