datafusion_remote_table/connection/
sqlite.rs

1use crate::connection::{RemoteDbType, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, Literalize, Pool, RemoteField, RemoteSchema,
4    RemoteSchemaRef, RemoteSource, RemoteType, SqliteConnectionOptions, SqliteType,
5    literalize_array,
6};
7use datafusion::arrow::array::{
8    ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int32Builder, Int64Builder, NullBuilder,
9    RecordBatch, RecordBatchOptions, StringBuilder, make_builder,
10};
11use datafusion::arrow::datatypes::{DataType, SchemaRef};
12use datafusion::common::{DataFusionError, project_schema};
13use datafusion::execution::SendableRecordBatchStream;
14use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
15use futures::StreamExt;
16use itertools::Itertools;
17use log::{debug, error};
18use rusqlite::types::ValueRef;
19use rusqlite::{Column, Row, Rows};
20use std::any::Any;
21use std::collections::HashMap;
22use std::path::PathBuf;
23use std::sync::Arc;
24
/// Handle to a SQLite database identified by its file path.
///
/// Only the path is stored here; every [`Pool::get`] call hands out a
/// [`SqliteConnection`] carrying a copy of that path.
#[derive(Debug)]
pub struct SqlitePool {
    /// Location of the SQLite database file.
    path: PathBuf,
}
29
30pub async fn connect_sqlite(options: &SqliteConnectionOptions) -> DFResult<SqlitePool> {
31    let _ = rusqlite::Connection::open(&options.path).map_err(|e| {
32        DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
33    })?;
34    Ok(SqlitePool {
35        path: options.path.clone(),
36    })
37}
38
39#[async_trait::async_trait]
40impl Pool for SqlitePool {
41    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
42        Ok(Arc::new(SqliteConnection {
43            path: self.path.clone(),
44        }))
45    }
46}
47
/// Logical connection to a SQLite database file.
///
/// Holds only the database path; each operation opens its own
/// `rusqlite::Connection` from it.
#[derive(Debug)]
pub struct SqliteConnection {
    /// Location of the SQLite database file.
    path: PathBuf,
}
52
53#[async_trait::async_trait]
54impl Connection for SqliteConnection {
55    fn as_any(&self) -> &dyn Any {
56        self
57    }
58
59    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
60        let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
61            DataFusionError::Plan(format!("Failed to open sqlite connection: {e:?}"))
62        })?;
63        match source {
64            RemoteSource::Table(table) => {
65                // TODO missing auto increment, could use sqlparser to parse create table sql
66                let sql = format!(
67                    "PRAGMA table_info({})",
68                    RemoteDbType::Sqlite.sql_table_name(table)
69                );
70                let mut stmt = conn.prepare(&sql).map_err(|e| {
71                    DataFusionError::Plan(format!("Failed to prepare sqlite statement: {e:?}"))
72                })?;
73                let rows = stmt.query([]).map_err(|e| {
74                    DataFusionError::Plan(format!("Failed to query sqlite statement: {e:?}"))
75                })?;
76                let remote_schema = Arc::new(build_remote_schema_for_table(rows)?);
77                Ok(remote_schema)
78            }
79            RemoteSource::Query(_query) => {
80                let sql = RemoteDbType::Sqlite.limit_1_query_if_possible(source);
81                let mut stmt = conn.prepare(&sql).map_err(|e| {
82                    DataFusionError::Plan(format!("Failed to prepare sqlite statement: {e:?}"))
83                })?;
84                let columns: Vec<OwnedColumn> =
85                    stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
86                let rows = stmt.query([]).map_err(|e| {
87                    DataFusionError::Plan(format!("Failed to query sqlite statement: {e:?}"))
88                })?;
89
90                let remote_schema =
91                    Arc::new(build_remote_schema_for_query(columns.as_slice(), rows)?);
92                Ok(remote_schema)
93            }
94        }
95    }
96
97    async fn query(
98        &self,
99        conn_options: &ConnectionOptions,
100        source: &RemoteSource,
101        table_schema: SchemaRef,
102        projection: Option<&Vec<usize>>,
103        unparsed_filters: &[String],
104        limit: Option<usize>,
105    ) -> DFResult<SendableRecordBatchStream> {
106        let projected_schema = project_schema(&table_schema, projection)?;
107        let sql = RemoteDbType::Sqlite.rewrite_query(source, unparsed_filters, limit);
108        debug!("[remote-table] executing sqlite query: {sql}");
109
110        let (tx, mut rx) = tokio::sync::mpsc::channel::<DFResult<RecordBatch>>(1);
111        let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
112            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
113        })?;
114
115        let projection = projection.cloned();
116        let chunk_size = conn_options.stream_chunk_size();
117
118        spawn_background_task(tx, conn, sql, table_schema, projection, chunk_size);
119
120        let stream = async_stream::stream! {
121            while let Some(batch) = rx.recv().await {
122                yield batch;
123            }
124        };
125        Ok(Box::pin(RecordBatchStreamAdapter::new(
126            projected_schema,
127            stream,
128        )))
129    }
130
131    async fn insert(
132        &self,
133        _conn_options: &ConnectionOptions,
134        literalizer: Arc<dyn Literalize>,
135        table: &[String],
136        remote_schema: RemoteSchemaRef,
137        mut input: SendableRecordBatchStream,
138    ) -> DFResult<usize> {
139        let input_schema = input.schema();
140        let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
141            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
142        })?;
143
144        let mut total_count = 0;
145        while let Some(batch) = input.next().await {
146            let batch = batch?;
147
148            let mut columns = Vec::with_capacity(remote_schema.fields.len());
149            for i in 0..batch.num_columns() {
150                let input_field = input_schema.field(i);
151                let remote_field = &remote_schema.fields[i];
152                if remote_field.auto_increment && input_field.is_nullable() {
153                    continue;
154                }
155
156                let remote_type = remote_schema.fields[i].remote_type.clone();
157                let array = batch.column(i);
158                let column = literalize_array(literalizer.as_ref(), array, remote_type)?;
159                columns.push(column);
160            }
161
162            let num_rows = columns[0].len();
163            let num_columns = columns.len();
164
165            let mut values = Vec::with_capacity(num_rows);
166            for i in 0..num_rows {
167                let mut value = Vec::with_capacity(num_columns);
168                for col in columns.iter() {
169                    value.push(col[i].as_str());
170                }
171                values.push(format!("({})", value.join(",")));
172            }
173
174            let mut col_names = Vec::with_capacity(remote_schema.fields.len());
175            for (remote_field, input_field) in
176                remote_schema.fields.iter().zip(input_schema.fields.iter())
177            {
178                if remote_field.auto_increment && input_field.is_nullable() {
179                    continue;
180                }
181                col_names.push(RemoteDbType::Sqlite.sql_identifier(&remote_field.name));
182            }
183
184            let sql = format!(
185                "INSERT INTO {} ({}) VALUES {}",
186                RemoteDbType::Sqlite.sql_table_name(table),
187                col_names.join(","),
188                values.join(",")
189            );
190
191            let count = conn.execute(&sql, []).map_err(|e| {
192                DataFusionError::Execution(format!(
193                    "Failed to execute insert statement on sqlite: {e:?}, sql: {sql}"
194                ))
195            })?;
196            total_count += count;
197        }
198
199        Ok(total_count)
200    }
201}
202
/// Owned copy of the metadata of one result column of a prepared statement.
#[derive(Debug)]
struct OwnedColumn {
    /// Column name as reported by SQLite.
    name: String,
    /// Declared type of the column; `None` when SQLite reports none
    /// (for example for expressions).
    decl_type: Option<String>,
}
208
209fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
210    OwnedColumn {
211        name: sqlite_col.name().to_string(),
212        decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
213    }
214}
215
216fn decl_type_to_remote_type(decl_type: &str) -> DFResult<SqliteType> {
217    if [
218        "tinyint", "smallint", "int", "integer", "bigint", "int2", "int4", "int8",
219    ]
220    .contains(&decl_type)
221    {
222        return Ok(SqliteType::Integer);
223    }
224    if ["real", "float", "double", "numeric"].contains(&decl_type) {
225        return Ok(SqliteType::Real);
226    }
227    if decl_type.starts_with("real") || decl_type.starts_with("numeric") {
228        return Ok(SqliteType::Real);
229    }
230    if ["text", "varchar", "char", "string"].contains(&decl_type) {
231        return Ok(SqliteType::Text);
232    }
233    if decl_type.starts_with("char")
234        || decl_type.starts_with("varchar")
235        || decl_type.starts_with("text")
236    {
237        return Ok(SqliteType::Text);
238    }
239    if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
240        return Ok(SqliteType::Blob);
241    }
242    if decl_type.starts_with("binary") || decl_type.starts_with("varbinary") {
243        return Ok(SqliteType::Blob);
244    }
245    Err(DataFusionError::NotImplemented(format!(
246        "Unsupported sqlite decl type: {decl_type}",
247    )))
248}
249
250fn build_remote_schema_for_table(mut rows: Rows) -> DFResult<RemoteSchema> {
251    let mut remote_fields = vec![];
252    while let Some(row) = rows.next().map_err(|e| {
253        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
254    })? {
255        let name = row.get::<_, String>(1).map_err(|e| {
256            DataFusionError::Execution(format!("Failed to get col name from sqlite row: {e:?}"))
257        })?;
258        let decl_type = row.get::<_, String>(2).map_err(|e| {
259            DataFusionError::Execution(format!("Failed to get decl type from sqlite row: {e:?}"))
260        })?;
261        let remote_type = decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?;
262        let nullable = row.get::<_, i64>(3).map_err(|e| {
263            DataFusionError::Execution(format!("Failed to get nullable from sqlite row: {e:?}"))
264        })? == 0;
265        remote_fields.push(RemoteField::new(
266            &name,
267            RemoteType::Sqlite(remote_type),
268            nullable,
269        ));
270    }
271    Ok(RemoteSchema::new(remote_fields))
272}
273
274fn build_remote_schema_for_query(
275    columns: &[OwnedColumn],
276    mut rows: Rows,
277) -> DFResult<RemoteSchema> {
278    let mut remote_field_map = HashMap::with_capacity(columns.len());
279    let mut unknown_cols = vec![];
280    for (col_idx, col) in columns.iter().enumerate() {
281        if let Some(decl_type) = &col.decl_type {
282            let remote_type =
283                RemoteType::Sqlite(decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?);
284            remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
285        } else {
286            // None for expressions
287            unknown_cols.push(col_idx);
288        }
289    }
290
291    if !unknown_cols.is_empty() {
292        while let Some(row) = rows.next().map_err(|e| {
293            DataFusionError::Plan(format!("Failed to get next row from sqlite: {e:?}"))
294        })? {
295            let mut to_be_removed = vec![];
296            for col_idx in unknown_cols.iter() {
297                let value_ref = row.get_ref(*col_idx).map_err(|e| {
298                    DataFusionError::Plan(format!(
299                        "Failed to get value ref for column {col_idx}: {e:?}"
300                    ))
301                })?;
302                match value_ref {
303                    ValueRef::Null => {}
304                    ValueRef::Integer(_) => {
305                        remote_field_map.insert(
306                            *col_idx,
307                            RemoteField::new(
308                                columns[*col_idx].name.clone(),
309                                RemoteType::Sqlite(SqliteType::Integer),
310                                true,
311                            ),
312                        );
313                        to_be_removed.push(*col_idx);
314                    }
315                    ValueRef::Real(_) => {
316                        remote_field_map.insert(
317                            *col_idx,
318                            RemoteField::new(
319                                columns[*col_idx].name.clone(),
320                                RemoteType::Sqlite(SqliteType::Real),
321                                true,
322                            ),
323                        );
324                        to_be_removed.push(*col_idx);
325                    }
326                    ValueRef::Text(_) => {
327                        remote_field_map.insert(
328                            *col_idx,
329                            RemoteField::new(
330                                columns[*col_idx].name.clone(),
331                                RemoteType::Sqlite(SqliteType::Text),
332                                true,
333                            ),
334                        );
335                        to_be_removed.push(*col_idx);
336                    }
337                    ValueRef::Blob(_) => {
338                        remote_field_map.insert(
339                            *col_idx,
340                            RemoteField::new(
341                                columns[*col_idx].name.clone(),
342                                RemoteType::Sqlite(SqliteType::Blob),
343                                true,
344                            ),
345                        );
346                        to_be_removed.push(*col_idx);
347                    }
348                }
349            }
350            for col_idx in to_be_removed.iter() {
351                unknown_cols.retain(|&x| x != *col_idx);
352            }
353            if unknown_cols.is_empty() {
354                break;
355            }
356        }
357    }
358
359    if !unknown_cols.is_empty() {
360        return Err(DataFusionError::Plan(format!(
361            "Failed to infer sqlite decl type for columns: {}",
362            unknown_cols
363                .iter()
364                .map(|idx| &columns[*idx].name)
365                .join(", ")
366        )));
367    }
368    let remote_fields = remote_field_map
369        .into_iter()
370        .sorted_by_key(|entry| entry.0)
371        .map(|entry| entry.1)
372        .collect::<Vec<_>>();
373    Ok(RemoteSchema::new(remote_fields))
374}
375
376fn spawn_background_task(
377    tx: tokio::sync::mpsc::Sender<DFResult<RecordBatch>>,
378    conn: rusqlite::Connection,
379    sql: String,
380    table_schema: SchemaRef,
381    projection: Option<Vec<usize>>,
382    chunk_size: usize,
383) {
384    std::thread::spawn(move || {
385        let runtime = match tokio::runtime::Builder::new_current_thread().build() {
386            Ok(runtime) => runtime,
387            Err(e) => {
388                error!("Failed to create tokio runtime to run sqlite query: {e:?}");
389                return;
390            }
391        };
392        let local_set = tokio::task::LocalSet::new();
393        local_set.block_on(&runtime, async move {
394            let mut stmt = match conn.prepare(&sql) {
395                Ok(stmt) => stmt,
396                Err(e) => {
397                    let _ = tx
398                        .send(Err(DataFusionError::Execution(format!(
399                            "Failed to prepare sqlite statement: {e:?}"
400                        ))))
401                        .await;
402                    return;
403                }
404            };
405            let columns: Vec<OwnedColumn> =
406                stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
407            let mut rows = match stmt.query([]) {
408                Ok(rows) => rows,
409                Err(e) => {
410                    let _ = tx
411                        .send(Err(DataFusionError::Execution(format!(
412                            "Failed to query sqlite statement: {e:?}"
413                        ))))
414                        .await;
415                    return;
416                }
417            };
418
419            loop {
420                let (batch, is_empty) = match rows_to_batch(
421                    &mut rows,
422                    &table_schema,
423                    &columns,
424                    projection.as_ref(),
425                    chunk_size,
426                ) {
427                    Ok((batch, is_empty)) => (batch, is_empty),
428                    Err(e) => {
429                        let _ = tx
430                            .send(Err(DataFusionError::Execution(format!(
431                                "Failed to convert rows to batch: {e:?}"
432                            ))))
433                            .await;
434                        return;
435                    }
436                };
437                if is_empty {
438                    break;
439                }
440                if tx.send(Ok(batch)).await.is_err() {
441                    return;
442                }
443            }
444        });
445    });
446}
447
448fn rows_to_batch(
449    rows: &mut Rows,
450    table_schema: &SchemaRef,
451    columns: &[OwnedColumn],
452    projection: Option<&Vec<usize>>,
453    chunk_size: usize,
454) -> DFResult<(RecordBatch, bool)> {
455    let projected_schema = project_schema(table_schema, projection)?;
456    let mut array_builders = vec![];
457    for field in table_schema.fields() {
458        let builder = make_builder(field.data_type(), 1000);
459        array_builders.push(builder);
460    }
461
462    let mut is_empty = true;
463    let mut row_count = 0;
464    while let Some(row) = rows.next().map_err(|e| {
465        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
466    })? {
467        is_empty = false;
468        row_count += 1;
469        append_rows_to_array_builders(
470            row,
471            table_schema,
472            columns,
473            projection,
474            array_builders.as_mut_slice(),
475        )?;
476        if row_count >= chunk_size {
477            break;
478        }
479    }
480
481    let projected_columns = array_builders
482        .into_iter()
483        .enumerate()
484        .filter(|(idx, _)| projections_contains(projection, *idx))
485        .map(|(_, mut builder)| builder.finish())
486        .collect::<Vec<ArrayRef>>();
487    let options = RecordBatchOptions::new().with_row_count(Some(row_count));
488    Ok((
489        RecordBatch::try_new_with_options(projected_schema, projected_columns, &options)?,
490        is_empty,
491    ))
492}
493
/// Reads the value at `$index` of `$row` as `Option<$value_ty>` and appends it
/// (or a null) to `$builder`, which must be a `$builder_ty`.
///
/// Must be expanded inside a function that can propagate a `DataFusionError`
/// with `?`. `$field` and `$col` are only used for diagnostics.
///
/// # Panics
///
/// Panics when `$builder` is not a `$builder_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        let value: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        if let Some(value) = value {
            typed_builder.append_value(value)
        } else {
            typed_builder.append_null()
        }
    }};
}
523
524fn append_rows_to_array_builders(
525    row: &Row,
526    table_schema: &SchemaRef,
527    columns: &[OwnedColumn],
528    projection: Option<&Vec<usize>>,
529    array_builders: &mut [Box<dyn ArrayBuilder>],
530) -> DFResult<()> {
531    for (idx, field) in table_schema.fields.iter().enumerate() {
532        if !projections_contains(projection, idx) {
533            continue;
534        }
535        let builder = &mut array_builders[idx];
536        let col = columns.get(idx);
537        match field.data_type() {
538            DataType::Null => {
539                let builder = builder
540                    .as_any_mut()
541                    .downcast_mut::<NullBuilder>()
542                    .expect("Failed to downcast builder to NullBuilder");
543                builder.append_null();
544            }
545            DataType::Int32 => {
546                handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx);
547            }
548            DataType::Int64 => {
549                handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
550            }
551            DataType::Float64 => {
552                handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
553            }
554            DataType::Utf8 => {
555                handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
556            }
557            DataType::Binary => {
558                handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
559            }
560            _ => {
561                return Err(DataFusionError::NotImplemented(format!(
562                    "Unsupported data type {} for col: {:?}",
563                    field.data_type(),
564                    col
565                )));
566            }
567        }
568    }
569    Ok(())
570}