datafusion_remote_table/connection/
sqlite.rs

1use crate::connection::{RemoteDbType, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, Pool, RemoteField, RemoteSchema, RemoteSchemaRef,
4    RemoteType, SqliteType, TableSource, Unparse, unparse_array,
5};
6use datafusion::arrow::array::{
7    ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int32Builder, Int64Builder, NullBuilder,
8    RecordBatch, RecordBatchOptions, StringBuilder, make_builder,
9};
10use datafusion::arrow::datatypes::{DataType, SchemaRef};
11use datafusion::common::{DataFusionError, project_schema};
12use datafusion::execution::SendableRecordBatchStream;
13use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
14use derive_getters::Getters;
15use derive_with::With;
16use futures::StreamExt;
17use itertools::Itertools;
18use log::{debug, error};
19use rusqlite::types::ValueRef;
20use rusqlite::{Column, Row, Rows};
21use std::any::Any;
22use std::collections::HashMap;
23use std::path::PathBuf;
24use std::sync::Arc;
25
26#[derive(Debug, Clone, With, Getters)]
27pub struct SqliteConnectionOptions {
28    pub path: PathBuf,
29    pub stream_chunk_size: usize,
30}
31
32impl SqliteConnectionOptions {
33    pub fn new(path: PathBuf) -> Self {
34        Self {
35            path,
36            stream_chunk_size: 2048,
37        }
38    }
39}
40
41impl From<SqliteConnectionOptions> for ConnectionOptions {
42    fn from(options: SqliteConnectionOptions) -> Self {
43        ConnectionOptions::Sqlite(options)
44    }
45}
46
47#[derive(Debug)]
48pub struct SqlitePool {
49    path: PathBuf,
50}
51
52pub async fn connect_sqlite(options: &SqliteConnectionOptions) -> DFResult<SqlitePool> {
53    let _ = rusqlite::Connection::open(&options.path).map_err(|e| {
54        DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
55    })?;
56    Ok(SqlitePool {
57        path: options.path.clone(),
58    })
59}
60
61#[async_trait::async_trait]
62impl Pool for SqlitePool {
63    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
64        Ok(Arc::new(SqliteConnection {
65            path: self.path.clone(),
66        }))
67    }
68}
69
70#[derive(Debug)]
71pub struct SqliteConnection {
72    path: PathBuf,
73}
74
75#[async_trait::async_trait]
76impl Connection for SqliteConnection {
77    fn as_any(&self) -> &dyn Any {
78        self
79    }
80
81    async fn infer_schema(&self, source: &TableSource) -> DFResult<RemoteSchemaRef> {
82        let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
83            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
84        })?;
85        match source {
86            TableSource::Table(table) => {
87                // TODO missing auto increment, could use sqlparser to parse create table sql
88                let sql = format!(
89                    "PRAGMA table_info({})",
90                    RemoteDbType::Sqlite.sql_table_name(table)
91                );
92                let mut stmt = conn.prepare(&sql).map_err(|e| {
93                    DataFusionError::Execution(format!("Failed to prepare sqlite statement: {e:?}"))
94                })?;
95                let rows = stmt.query([]).map_err(|e| {
96                    DataFusionError::Execution(format!("Failed to query sqlite statement: {e:?}"))
97                })?;
98                let remote_schema = Arc::new(build_remote_schema_for_table(rows)?);
99                Ok(remote_schema)
100            }
101            TableSource::Query(_query) => {
102                let sql = RemoteDbType::Sqlite.limit_1_query_if_possible(source);
103                let mut stmt = conn.prepare(&sql).map_err(|e| {
104                    DataFusionError::Execution(format!("Failed to prepare sqlite statement: {e:?}"))
105                })?;
106                let columns: Vec<OwnedColumn> =
107                    stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
108                let rows = stmt.query([]).map_err(|e| {
109                    DataFusionError::Execution(format!("Failed to query sqlite statement: {e:?}"))
110                })?;
111
112                let remote_schema =
113                    Arc::new(build_remote_schema_for_query(columns.as_slice(), rows)?);
114                Ok(remote_schema)
115            }
116        }
117    }
118
119    async fn query(
120        &self,
121        conn_options: &ConnectionOptions,
122        source: &TableSource,
123        table_schema: SchemaRef,
124        projection: Option<&Vec<usize>>,
125        unparsed_filters: &[String],
126        limit: Option<usize>,
127    ) -> DFResult<SendableRecordBatchStream> {
128        let projected_schema = project_schema(&table_schema, projection)?;
129        let sql = RemoteDbType::Sqlite.rewrite_query(source, unparsed_filters, limit);
130        debug!("[remote-table] executing sqlite query: {sql}");
131
132        let (tx, mut rx) = tokio::sync::mpsc::channel::<DFResult<RecordBatch>>(1);
133        let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
134            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
135        })?;
136
137        let projection = projection.cloned();
138        let chunk_size = conn_options.stream_chunk_size();
139
140        spawn_background_task(tx, conn, sql, table_schema, projection, chunk_size);
141
142        let stream = async_stream::stream! {
143            while let Some(batch) = rx.recv().await {
144                yield batch;
145            }
146        };
147        Ok(Box::pin(RecordBatchStreamAdapter::new(
148            projected_schema,
149            stream,
150        )))
151    }
152
153    async fn insert(
154        &self,
155        _conn_options: &ConnectionOptions,
156        unparser: Arc<dyn Unparse>,
157        table: &[String],
158        remote_schema: RemoteSchemaRef,
159        mut input: SendableRecordBatchStream,
160    ) -> DFResult<usize> {
161        let input_schema = input.schema();
162        let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
163            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
164        })?;
165
166        let mut total_count = 0;
167        while let Some(batch) = input.next().await {
168            let batch = batch?;
169
170            let mut columns = Vec::with_capacity(remote_schema.fields.len());
171            for i in 0..batch.num_columns() {
172                let input_field = input_schema.field(i);
173                let remote_field = &remote_schema.fields[i];
174                if remote_field.auto_increment && input_field.is_nullable() {
175                    continue;
176                }
177
178                let remote_type = remote_schema.fields[i].remote_type.clone();
179                let array = batch.column(i);
180                let column = unparse_array(unparser.as_ref(), array, remote_type)?;
181                columns.push(column);
182            }
183
184            let num_rows = columns[0].len();
185            let num_columns = columns.len();
186
187            let mut values = Vec::with_capacity(num_rows);
188            for i in 0..num_rows {
189                let mut value = Vec::with_capacity(num_columns);
190                for col in columns.iter() {
191                    value.push(col[i].as_str());
192                }
193                values.push(format!("({})", value.join(",")));
194            }
195
196            let mut col_names = Vec::with_capacity(remote_schema.fields.len());
197            for (remote_field, input_field) in
198                remote_schema.fields.iter().zip(input_schema.fields.iter())
199            {
200                if remote_field.auto_increment && input_field.is_nullable() {
201                    continue;
202                }
203                col_names.push(RemoteDbType::Sqlite.sql_identifier(&remote_field.name));
204            }
205
206            let sql = format!(
207                "INSERT INTO {} ({}) VALUES {}",
208                RemoteDbType::Sqlite.sql_table_name(table),
209                col_names.join(","),
210                values.join(",")
211            );
212
213            let count = conn.execute(&sql, []).map_err(|e| {
214                DataFusionError::Execution(format!(
215                    "Failed to execute insert statement on sqlite: {e:?}, sql: {sql}"
216                ))
217            })?;
218            total_count += count;
219        }
220
221        Ok(total_count)
222    }
223}
224
225#[derive(Debug)]
226struct OwnedColumn {
227    name: String,
228    decl_type: Option<String>,
229}
230
231fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
232    OwnedColumn {
233        name: sqlite_col.name().to_string(),
234        decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
235    }
236}
237
238fn decl_type_to_remote_type(decl_type: &str) -> DFResult<SqliteType> {
239    if [
240        "tinyint", "smallint", "int", "integer", "bigint", "int2", "int4", "int8",
241    ]
242    .contains(&decl_type)
243    {
244        return Ok(SqliteType::Integer);
245    }
246    if ["real", "float", "double", "numeric"].contains(&decl_type) {
247        return Ok(SqliteType::Real);
248    }
249    if decl_type.starts_with("real") || decl_type.starts_with("numeric") {
250        return Ok(SqliteType::Real);
251    }
252    if ["text", "varchar", "char", "string"].contains(&decl_type) {
253        return Ok(SqliteType::Text);
254    }
255    if decl_type.starts_with("char")
256        || decl_type.starts_with("varchar")
257        || decl_type.starts_with("text")
258    {
259        return Ok(SqliteType::Text);
260    }
261    if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
262        return Ok(SqliteType::Blob);
263    }
264    if decl_type.starts_with("binary") || decl_type.starts_with("varbinary") {
265        return Ok(SqliteType::Blob);
266    }
267    Err(DataFusionError::NotImplemented(format!(
268        "Unsupported sqlite decl type: {decl_type}",
269    )))
270}
271
272fn build_remote_schema_for_table(mut rows: Rows) -> DFResult<RemoteSchema> {
273    let mut remote_fields = vec![];
274    while let Some(row) = rows.next().map_err(|e| {
275        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
276    })? {
277        let name = row.get::<_, String>(1).map_err(|e| {
278            DataFusionError::Execution(format!("Failed to get col name from sqlite row: {e:?}"))
279        })?;
280        let decl_type = row.get::<_, String>(2).map_err(|e| {
281            DataFusionError::Execution(format!("Failed to get decl type from sqlite row: {e:?}"))
282        })?;
283        let remote_type = decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?;
284        let nullable = row.get::<_, i64>(3).map_err(|e| {
285            DataFusionError::Execution(format!("Failed to get nullable from sqlite row: {e:?}"))
286        })? == 0;
287        remote_fields.push(RemoteField::new(
288            &name,
289            RemoteType::Sqlite(remote_type),
290            nullable,
291        ));
292    }
293    Ok(RemoteSchema::new(remote_fields))
294}
295
296fn build_remote_schema_for_query(
297    columns: &[OwnedColumn],
298    mut rows: Rows,
299) -> DFResult<RemoteSchema> {
300    let mut remote_field_map = HashMap::with_capacity(columns.len());
301    let mut unknown_cols = vec![];
302    for (col_idx, col) in columns.iter().enumerate() {
303        if let Some(decl_type) = &col.decl_type {
304            let remote_type =
305                RemoteType::Sqlite(decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?);
306            remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
307        } else {
308            // None for expressions
309            unknown_cols.push(col_idx);
310        }
311    }
312
313    if !unknown_cols.is_empty() {
314        while let Some(row) = rows.next().map_err(|e| {
315            DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
316        })? {
317            let mut to_be_removed = vec![];
318            for col_idx in unknown_cols.iter() {
319                let value_ref = row.get_ref(*col_idx).map_err(|e| {
320                    DataFusionError::Execution(format!(
321                        "Failed to get value ref for column {col_idx}: {e:?}"
322                    ))
323                })?;
324                match value_ref {
325                    ValueRef::Null => {}
326                    ValueRef::Integer(_) => {
327                        remote_field_map.insert(
328                            *col_idx,
329                            RemoteField::new(
330                                columns[*col_idx].name.clone(),
331                                RemoteType::Sqlite(SqliteType::Integer),
332                                true,
333                            ),
334                        );
335                        to_be_removed.push(*col_idx);
336                    }
337                    ValueRef::Real(_) => {
338                        remote_field_map.insert(
339                            *col_idx,
340                            RemoteField::new(
341                                columns[*col_idx].name.clone(),
342                                RemoteType::Sqlite(SqliteType::Real),
343                                true,
344                            ),
345                        );
346                        to_be_removed.push(*col_idx);
347                    }
348                    ValueRef::Text(_) => {
349                        remote_field_map.insert(
350                            *col_idx,
351                            RemoteField::new(
352                                columns[*col_idx].name.clone(),
353                                RemoteType::Sqlite(SqliteType::Text),
354                                true,
355                            ),
356                        );
357                        to_be_removed.push(*col_idx);
358                    }
359                    ValueRef::Blob(_) => {
360                        remote_field_map.insert(
361                            *col_idx,
362                            RemoteField::new(
363                                columns[*col_idx].name.clone(),
364                                RemoteType::Sqlite(SqliteType::Blob),
365                                true,
366                            ),
367                        );
368                        to_be_removed.push(*col_idx);
369                    }
370                }
371            }
372            for col_idx in to_be_removed.iter() {
373                unknown_cols.retain(|&x| x != *col_idx);
374            }
375            if unknown_cols.is_empty() {
376                break;
377            }
378        }
379    }
380
381    if !unknown_cols.is_empty() {
382        return Err(DataFusionError::NotImplemented(format!(
383            "Failed to infer sqlite decl type for columns: {unknown_cols:?}"
384        )));
385    }
386    let remote_fields = remote_field_map
387        .into_iter()
388        .sorted_by_key(|entry| entry.0)
389        .map(|entry| entry.1)
390        .collect::<Vec<_>>();
391    Ok(RemoteSchema::new(remote_fields))
392}
393
394fn spawn_background_task(
395    tx: tokio::sync::mpsc::Sender<DFResult<RecordBatch>>,
396    conn: rusqlite::Connection,
397    sql: String,
398    table_schema: SchemaRef,
399    projection: Option<Vec<usize>>,
400    chunk_size: usize,
401) {
402    std::thread::spawn(move || {
403        let runtime = match tokio::runtime::Builder::new_current_thread().build() {
404            Ok(runtime) => runtime,
405            Err(e) => {
406                error!("Failed to create tokio runtime to run sqlite query: {e:?}");
407                return;
408            }
409        };
410        let local_set = tokio::task::LocalSet::new();
411        local_set.block_on(&runtime, async move {
412            let mut stmt = match conn.prepare(&sql) {
413                Ok(stmt) => stmt,
414                Err(e) => {
415                    let _ = tx
416                        .send(Err(DataFusionError::Execution(format!(
417                            "Failed to prepare sqlite statement: {e:?}"
418                        ))))
419                        .await;
420                    return;
421                }
422            };
423            let columns: Vec<OwnedColumn> =
424                stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
425            let mut rows = match stmt.query([]) {
426                Ok(rows) => rows,
427                Err(e) => {
428                    let _ = tx
429                        .send(Err(DataFusionError::Execution(format!(
430                            "Failed to query sqlite statement: {e:?}"
431                        ))))
432                        .await;
433                    return;
434                }
435            };
436
437            loop {
438                let (batch, is_empty) = match rows_to_batch(
439                    &mut rows,
440                    &table_schema,
441                    &columns,
442                    projection.as_ref(),
443                    chunk_size,
444                ) {
445                    Ok((batch, is_empty)) => (batch, is_empty),
446                    Err(e) => {
447                        let _ = tx
448                            .send(Err(DataFusionError::Execution(format!(
449                                "Failed to convert rows to batch: {e:?}"
450                            ))))
451                            .await;
452                        return;
453                    }
454                };
455                if is_empty {
456                    break;
457                }
458                if tx.send(Ok(batch)).await.is_err() {
459                    return;
460                }
461            }
462        });
463    });
464}
465
466fn rows_to_batch(
467    rows: &mut Rows,
468    table_schema: &SchemaRef,
469    columns: &[OwnedColumn],
470    projection: Option<&Vec<usize>>,
471    chunk_size: usize,
472) -> DFResult<(RecordBatch, bool)> {
473    let projected_schema = project_schema(table_schema, projection)?;
474    let mut array_builders = vec![];
475    for field in table_schema.fields() {
476        let builder = make_builder(field.data_type(), 1000);
477        array_builders.push(builder);
478    }
479
480    let mut is_empty = true;
481    let mut row_count = 0;
482    while let Some(row) = rows.next().map_err(|e| {
483        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
484    })? {
485        is_empty = false;
486        row_count += 1;
487        append_rows_to_array_builders(
488            row,
489            table_schema,
490            columns,
491            projection,
492            array_builders.as_mut_slice(),
493        )?;
494        if row_count >= chunk_size {
495            break;
496        }
497    }
498
499    let projected_columns = array_builders
500        .into_iter()
501        .enumerate()
502        .filter(|(idx, _)| projections_contains(projection, *idx))
503        .map(|(_, mut builder)| builder.finish())
504        .collect::<Vec<ArrayRef>>();
505    let options = RecordBatchOptions::new().with_row_count(Some(row_count));
506    Ok((
507        RecordBatch::try_new_with_options(projected_schema, projected_columns, &options)?,
508        is_empty,
509    ))
510}
511
/// Appends column `$index` of `$row` to `$builder`.
///
/// Downcasts `$builder` to `$builder_ty`, reads column `$index` of `$row` as
/// `Option<$value_ty>`, and appends the value or a null. It panics if the builder has a
/// different concrete type. A failed read becomes a `DataFusionError::Execution` returned via
/// `?` from the enclosing function.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let typed_builder = match $builder.as_any_mut().downcast_mut::<$builder_ty>() {
            Some(typed_builder) => typed_builder,
            None => panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            ),
        };

        let cell: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        if let Some(cell) = cell {
            typed_builder.append_value(cell);
        } else {
            typed_builder.append_null();
        }
    }};
}
541
542fn append_rows_to_array_builders(
543    row: &Row,
544    table_schema: &SchemaRef,
545    columns: &[OwnedColumn],
546    projection: Option<&Vec<usize>>,
547    array_builders: &mut [Box<dyn ArrayBuilder>],
548) -> DFResult<()> {
549    for (idx, field) in table_schema.fields.iter().enumerate() {
550        if !projections_contains(projection, idx) {
551            continue;
552        }
553        let builder = &mut array_builders[idx];
554        let col = columns.get(idx);
555        match field.data_type() {
556            DataType::Null => {
557                let builder = builder
558                    .as_any_mut()
559                    .downcast_mut::<NullBuilder>()
560                    .expect("Failed to downcast builder to NullBuilder");
561                builder.append_null();
562            }
563            DataType::Int32 => {
564                handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx);
565            }
566            DataType::Int64 => {
567                handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
568            }
569            DataType::Float64 => {
570                handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
571            }
572            DataType::Utf8 => {
573                handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
574            }
575            DataType::Binary => {
576                handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
577            }
578            _ => {
579                return Err(DataFusionError::NotImplemented(format!(
580                    "Unsupported data type {} for col: {:?}",
581                    field.data_type(),
582                    col
583                )));
584            }
585        }
586    }
587    Ok(())
588}