datafusion_remote_table/connection/
sqlite.rs

1use crate::connection::{RemoteDbType, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, Literalize, Pool, PoolState, RemoteField,
4    RemoteSchema, RemoteSchemaRef, RemoteSource, RemoteType, SqliteConnectionOptions, SqliteType,
5    literalize_array,
6};
7use arrow::array::{
8    ArrayBuilder, ArrayRef, BinaryBuilder, BinaryViewBuilder, Float64Builder, Int32Builder,
9    Int64Builder, NullBuilder, RecordBatch, RecordBatchOptions, StringBuilder, StringViewBuilder,
10    make_builder,
11};
12use arrow::datatypes::{DataType, SchemaRef};
13use datafusion_common::{DataFusionError, project_schema};
14use datafusion_execution::SendableRecordBatchStream;
15use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
16use itertools::Itertools;
17use log::{debug, error};
18use rusqlite::types::ValueRef;
19use rusqlite::{Column, Row, Rows};
20use std::any::Any;
21use std::collections::HashMap;
22use std::path::PathBuf;
23use std::sync::Arc;
24use std::sync::atomic::{AtomicUsize, Ordering};
25
/// Handle to a SQLite database file from which connections are handed out.
///
/// No live database handles are stored here: only the database path and a
/// counter shared with every [`SqliteConnection`] created through the pool.
#[derive(Debug)]
pub struct SqlitePool {
    /// Location of the SQLite database file.
    path: PathBuf,
    /// Number of connections handed out and not yet dropped.
    connections: Arc<AtomicUsize>,
}
31
32pub async fn connect_sqlite(options: &SqliteConnectionOptions) -> DFResult<SqlitePool> {
33    let _ = rusqlite::Connection::open(&options.path).map_err(|e| {
34        DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
35    })?;
36    Ok(SqlitePool {
37        path: options.path.clone(),
38        connections: Arc::new(AtomicUsize::new(0)),
39    })
40}
41
42#[async_trait::async_trait]
43impl Pool for SqlitePool {
44    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
45        self.connections.fetch_add(1, Ordering::SeqCst);
46        Ok(Arc::new(SqliteConnection {
47            path: self.path.clone(),
48            pool_connections: self.connections.clone(),
49        }))
50    }
51
52    async fn state(&self) -> DFResult<PoolState> {
53        Ok(PoolState {
54            connections: self.connections.load(Ordering::SeqCst),
55            idle_connections: 0,
56        })
57    }
58}
59
/// A logical connection to a SQLite database file.
///
/// Only the path is stored; each operation opens its own
/// `rusqlite::Connection` against it.
#[derive(Debug)]
pub struct SqliteConnection {
    /// Location of the SQLite database file.
    path: PathBuf,
    /// Counter shared with the owning pool; decremented when this value is dropped.
    pool_connections: Arc<AtomicUsize>,
}
65
66impl Drop for SqliteConnection {
67    fn drop(&mut self) {
68        self.pool_connections.fetch_sub(1, Ordering::SeqCst);
69    }
70}
71
72#[async_trait::async_trait]
73impl Connection for SqliteConnection {
74    fn as_any(&self) -> &dyn Any {
75        self
76    }
77
78    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
79        let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
80            DataFusionError::Plan(format!("Failed to open sqlite connection: {e:?}"))
81        })?;
82        match source {
83            RemoteSource::Table(table) => {
84                // TODO missing auto increment, could use sqlparser to parse create table sql
85                let sql = format!(
86                    "PRAGMA table_info({})",
87                    RemoteDbType::Sqlite.sql_table_name(table)
88                );
89                let mut stmt = conn.prepare(&sql).map_err(|e| {
90                    DataFusionError::Plan(format!("Failed to prepare sqlite statement: {e:?}"))
91                })?;
92                let rows = stmt.query([]).map_err(|e| {
93                    DataFusionError::Plan(format!("Failed to query sqlite statement: {e:?}"))
94                })?;
95                let remote_schema = Arc::new(build_remote_schema_for_table(rows)?);
96                Ok(remote_schema)
97            }
98            RemoteSource::Query(_query) => {
99                let sql = RemoteDbType::Sqlite.limit_1_query_if_possible(source);
100                let mut stmt = conn.prepare(&sql).map_err(|e| {
101                    DataFusionError::Plan(format!("Failed to prepare sqlite statement: {e:?}"))
102                })?;
103                let columns: Vec<OwnedColumn> =
104                    stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
105                let rows = stmt.query([]).map_err(|e| {
106                    DataFusionError::Plan(format!("Failed to query sqlite statement: {e:?}"))
107                })?;
108
109                let remote_schema =
110                    Arc::new(build_remote_schema_for_query(columns.as_slice(), rows)?);
111                Ok(remote_schema)
112            }
113        }
114    }
115
116    async fn query(
117        &self,
118        conn_options: &ConnectionOptions,
119        source: &RemoteSource,
120        table_schema: SchemaRef,
121        projection: Option<&Vec<usize>>,
122        unparsed_filters: &[String],
123        limit: Option<usize>,
124    ) -> DFResult<SendableRecordBatchStream> {
125        let projected_schema = project_schema(&table_schema, projection)?;
126        let sql = RemoteDbType::Sqlite.rewrite_query(source, unparsed_filters, limit);
127        debug!("[remote-table] executing sqlite query: {sql}");
128
129        let (tx, mut rx) = tokio::sync::mpsc::channel::<DFResult<RecordBatch>>(1);
130        let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
131            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
132        })?;
133
134        let projection = projection.cloned();
135        let chunk_size = conn_options.stream_chunk_size();
136
137        spawn_background_task(tx, conn, sql, table_schema, projection, chunk_size);
138
139        let stream = async_stream::stream! {
140            while let Some(batch) = rx.recv().await {
141                yield batch;
142            }
143        };
144        Ok(Box::pin(RecordBatchStreamAdapter::new(
145            projected_schema,
146            stream,
147        )))
148    }
149
150    async fn insert(
151        &self,
152        _conn_options: &ConnectionOptions,
153        literalizer: Arc<dyn Literalize>,
154        table: &[String],
155        remote_schema: RemoteSchemaRef,
156        batch: RecordBatch,
157    ) -> DFResult<usize> {
158        let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
159            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
160        })?;
161
162        let mut columns = Vec::with_capacity(remote_schema.fields.len());
163        for i in 0..batch.num_columns() {
164            let input_field = batch.schema_ref().field(i);
165            let remote_field = &remote_schema.fields[i];
166            if remote_field.auto_increment && input_field.is_nullable() {
167                continue;
168            }
169
170            let remote_type = remote_schema.fields[i].remote_type.clone();
171            let array = batch.column(i);
172            let column = literalize_array(literalizer.as_ref(), array, remote_type)?;
173            columns.push(column);
174        }
175
176        let num_rows = columns[0].len();
177        let num_columns = columns.len();
178
179        let mut values = Vec::with_capacity(num_rows);
180        for i in 0..num_rows {
181            let mut value = Vec::with_capacity(num_columns);
182            for col in columns.iter() {
183                value.push(col[i].as_str());
184            }
185            values.push(format!("({})", value.join(",")));
186        }
187
188        let mut col_names = Vec::with_capacity(remote_schema.fields.len());
189        for (remote_field, input_field) in remote_schema
190            .fields
191            .iter()
192            .zip(batch.schema_ref().fields.iter())
193        {
194            if remote_field.auto_increment && input_field.is_nullable() {
195                continue;
196            }
197            col_names.push(RemoteDbType::Sqlite.sql_identifier(&remote_field.name));
198        }
199
200        let sql = format!(
201            "INSERT INTO {} ({}) VALUES {}",
202            RemoteDbType::Sqlite.sql_table_name(table),
203            col_names.join(","),
204            values.join(",")
205        );
206
207        let count = conn.execute(&sql, []).map_err(|e| {
208            DataFusionError::Execution(format!(
209                "Failed to execute insert statement on sqlite: {e:?}, sql: {sql}"
210            ))
211        })?;
212
213        Ok(count)
214    }
215}
216
/// Owned copy of the column metadata exposed by a prepared SQLite statement.
#[derive(Debug)]
struct OwnedColumn {
    /// Column name as reported by the statement.
    name: String,
    /// Declared type of the column, or `None` when the statement reports none
    /// (e.g. for expressions).
    decl_type: Option<String>,
}
222
223fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
224    OwnedColumn {
225        name: sqlite_col.name().to_string(),
226        decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
227    }
228}
229
230fn decl_type_to_remote_type(decl_type: &str) -> DFResult<SqliteType> {
231    if [
232        "tinyint", "smallint", "int", "integer", "bigint", "int2", "int4", "int8",
233    ]
234    .contains(&decl_type)
235    {
236        return Ok(SqliteType::Integer);
237    }
238    if ["real", "float", "double", "numeric"].contains(&decl_type) {
239        return Ok(SqliteType::Real);
240    }
241    if decl_type.starts_with("real") || decl_type.starts_with("numeric") {
242        return Ok(SqliteType::Real);
243    }
244    if ["text", "varchar", "char", "string"].contains(&decl_type) {
245        return Ok(SqliteType::Text);
246    }
247    if decl_type.starts_with("char")
248        || decl_type.starts_with("varchar")
249        || decl_type.starts_with("text")
250    {
251        return Ok(SqliteType::Text);
252    }
253    if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
254        return Ok(SqliteType::Blob);
255    }
256    if decl_type.starts_with("binary") || decl_type.starts_with("varbinary") {
257        return Ok(SqliteType::Blob);
258    }
259    Err(DataFusionError::NotImplemented(format!(
260        "Unsupported sqlite decl type: {decl_type}",
261    )))
262}
263
264fn build_remote_schema_for_table(mut rows: Rows) -> DFResult<RemoteSchema> {
265    let mut remote_fields = vec![];
266    while let Some(row) = rows.next().map_err(|e| {
267        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
268    })? {
269        let name = row.get::<_, String>(1).map_err(|e| {
270            DataFusionError::Execution(format!("Failed to get col name from sqlite row: {e:?}"))
271        })?;
272        let decl_type = row.get::<_, String>(2).map_err(|e| {
273            DataFusionError::Execution(format!("Failed to get decl type from sqlite row: {e:?}"))
274        })?;
275        let remote_type = decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?;
276        let nullable = row.get::<_, i64>(3).map_err(|e| {
277            DataFusionError::Execution(format!("Failed to get nullable from sqlite row: {e:?}"))
278        })? == 0;
279        remote_fields.push(RemoteField::new(
280            &name,
281            RemoteType::Sqlite(remote_type),
282            nullable,
283        ));
284    }
285    Ok(RemoteSchema::new(remote_fields))
286}
287
288fn build_remote_schema_for_query(
289    columns: &[OwnedColumn],
290    mut rows: Rows,
291) -> DFResult<RemoteSchema> {
292    let mut remote_field_map = HashMap::with_capacity(columns.len());
293    let mut unknown_cols = vec![];
294    for (col_idx, col) in columns.iter().enumerate() {
295        if let Some(decl_type) = &col.decl_type {
296            let remote_type =
297                RemoteType::Sqlite(decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?);
298            remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
299        } else {
300            // None for expressions
301            unknown_cols.push(col_idx);
302        }
303    }
304
305    if !unknown_cols.is_empty() {
306        while let Some(row) = rows.next().map_err(|e| {
307            DataFusionError::Plan(format!("Failed to get next row from sqlite: {e:?}"))
308        })? {
309            let mut to_be_removed = vec![];
310            for col_idx in unknown_cols.iter() {
311                let value_ref = row.get_ref(*col_idx).map_err(|e| {
312                    DataFusionError::Plan(format!(
313                        "Failed to get value ref for column {col_idx}: {e:?}"
314                    ))
315                })?;
316                match value_ref {
317                    ValueRef::Null => {}
318                    ValueRef::Integer(_) => {
319                        remote_field_map.insert(
320                            *col_idx,
321                            RemoteField::new(
322                                columns[*col_idx].name.clone(),
323                                RemoteType::Sqlite(SqliteType::Integer),
324                                true,
325                            ),
326                        );
327                        to_be_removed.push(*col_idx);
328                    }
329                    ValueRef::Real(_) => {
330                        remote_field_map.insert(
331                            *col_idx,
332                            RemoteField::new(
333                                columns[*col_idx].name.clone(),
334                                RemoteType::Sqlite(SqliteType::Real),
335                                true,
336                            ),
337                        );
338                        to_be_removed.push(*col_idx);
339                    }
340                    ValueRef::Text(_) => {
341                        remote_field_map.insert(
342                            *col_idx,
343                            RemoteField::new(
344                                columns[*col_idx].name.clone(),
345                                RemoteType::Sqlite(SqliteType::Text),
346                                true,
347                            ),
348                        );
349                        to_be_removed.push(*col_idx);
350                    }
351                    ValueRef::Blob(_) => {
352                        remote_field_map.insert(
353                            *col_idx,
354                            RemoteField::new(
355                                columns[*col_idx].name.clone(),
356                                RemoteType::Sqlite(SqliteType::Blob),
357                                true,
358                            ),
359                        );
360                        to_be_removed.push(*col_idx);
361                    }
362                }
363            }
364            for col_idx in to_be_removed.iter() {
365                unknown_cols.retain(|&x| x != *col_idx);
366            }
367            if unknown_cols.is_empty() {
368                break;
369            }
370        }
371    }
372
373    if !unknown_cols.is_empty() {
374        return Err(DataFusionError::Plan(format!(
375            "Failed to infer sqlite decl type for columns: {}",
376            unknown_cols
377                .iter()
378                .map(|idx| &columns[*idx].name)
379                .join(", ")
380        )));
381    }
382    let remote_fields = remote_field_map
383        .into_iter()
384        .sorted_by_key(|entry| entry.0)
385        .map(|entry| entry.1)
386        .collect::<Vec<_>>();
387    Ok(RemoteSchema::new(remote_fields))
388}
389
390fn spawn_background_task(
391    tx: tokio::sync::mpsc::Sender<DFResult<RecordBatch>>,
392    conn: rusqlite::Connection,
393    sql: String,
394    table_schema: SchemaRef,
395    projection: Option<Vec<usize>>,
396    chunk_size: usize,
397) {
398    std::thread::spawn(move || {
399        let runtime = match tokio::runtime::Builder::new_current_thread().build() {
400            Ok(runtime) => runtime,
401            Err(e) => {
402                error!("Failed to create tokio runtime to run sqlite query: {e:?}");
403                return;
404            }
405        };
406        let local_set = tokio::task::LocalSet::new();
407        local_set.block_on(&runtime, async move {
408            let mut stmt = match conn.prepare(&sql) {
409                Ok(stmt) => stmt,
410                Err(e) => {
411                    let _ = tx
412                        .send(Err(DataFusionError::Execution(format!(
413                            "Failed to prepare sqlite statement: {e:?}"
414                        ))))
415                        .await;
416                    return;
417                }
418            };
419            let columns: Vec<OwnedColumn> =
420                stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
421            let mut rows = match stmt.query([]) {
422                Ok(rows) => rows,
423                Err(e) => {
424                    let _ = tx
425                        .send(Err(DataFusionError::Execution(format!(
426                            "Failed to query sqlite statement: {e:?}"
427                        ))))
428                        .await;
429                    return;
430                }
431            };
432
433            loop {
434                let (batch, is_empty) = match rows_to_batch(
435                    &mut rows,
436                    &table_schema,
437                    &columns,
438                    projection.as_ref(),
439                    chunk_size,
440                ) {
441                    Ok((batch, is_empty)) => (batch, is_empty),
442                    Err(e) => {
443                        let _ = tx
444                            .send(Err(DataFusionError::Execution(format!(
445                                "Failed to convert rows to batch: {e:?}"
446                            ))))
447                            .await;
448                        return;
449                    }
450                };
451                if is_empty {
452                    break;
453                }
454                if tx.send(Ok(batch)).await.is_err() {
455                    return;
456                }
457            }
458        });
459    });
460}
461
462fn rows_to_batch(
463    rows: &mut Rows,
464    table_schema: &SchemaRef,
465    columns: &[OwnedColumn],
466    projection: Option<&Vec<usize>>,
467    chunk_size: usize,
468) -> DFResult<(RecordBatch, bool)> {
469    let projected_schema = project_schema(table_schema, projection)?;
470    let mut array_builders = vec![];
471    for field in table_schema.fields() {
472        let builder = make_builder(field.data_type(), 1000);
473        array_builders.push(builder);
474    }
475
476    let mut is_empty = true;
477    let mut row_count = 0;
478    while let Some(row) = rows.next().map_err(|e| {
479        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
480    })? {
481        is_empty = false;
482        row_count += 1;
483        append_rows_to_array_builders(
484            row,
485            table_schema,
486            columns,
487            projection,
488            array_builders.as_mut_slice(),
489        )?;
490        if row_count >= chunk_size {
491            break;
492        }
493    }
494
495    let projected_columns = array_builders
496        .into_iter()
497        .enumerate()
498        .filter(|(idx, _)| projections_contains(projection, *idx))
499        .map(|(_, mut builder)| builder.finish())
500        .collect::<Vec<ArrayRef>>();
501    let options = RecordBatchOptions::new().with_row_count(Some(row_count));
502    Ok((
503        RecordBatch::try_new_with_options(projected_schema, projected_columns, &options)?,
504        is_empty,
505    ))
506}
507
/// Reads an optional `$value_ty` from `$row` at `$index` and appends it (or a
/// null) to `$builder`, which must be a `$builder_ty`.
///
/// Panics when the builder is of a different concrete type. A value that
/// cannot be read is returned from the enclosing function as a
/// `DataFusionError::Execution` via `?`. `$field` and `$col` are only used in
/// the diagnostic messages.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let Some(builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        let v: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        if let Some(v) = v {
            builder.append_value(v)
        } else {
            builder.append_null()
        }
    }};
}
537
538fn append_rows_to_array_builders(
539    row: &Row,
540    table_schema: &SchemaRef,
541    columns: &[OwnedColumn],
542    projection: Option<&Vec<usize>>,
543    array_builders: &mut [Box<dyn ArrayBuilder>],
544) -> DFResult<()> {
545    for (idx, field) in table_schema.fields.iter().enumerate() {
546        if !projections_contains(projection, idx) {
547            continue;
548        }
549        let builder = &mut array_builders[idx];
550        let col = columns.get(idx);
551        match field.data_type() {
552            DataType::Null => {
553                let builder = builder
554                    .as_any_mut()
555                    .downcast_mut::<NullBuilder>()
556                    .expect("Failed to downcast builder to NullBuilder");
557                builder.append_null();
558            }
559            DataType::Int32 => {
560                handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx);
561            }
562            DataType::Int64 => {
563                handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
564            }
565            DataType::Float64 => {
566                handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
567            }
568            DataType::Utf8 => {
569                handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
570            }
571            DataType::Utf8View => {
572                handle_primitive_type!(builder, field, col, StringViewBuilder, String, row, idx);
573            }
574            DataType::Binary => {
575                handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
576            }
577            DataType::BinaryView => {
578                handle_primitive_type!(builder, field, col, BinaryViewBuilder, Vec<u8>, row, idx);
579            }
580            _ => {
581                return Err(DataFusionError::NotImplemented(format!(
582                    "Unsupported data type {} for col: {:?}",
583                    field.data_type(),
584                    col
585                )));
586            }
587        }
588    }
589    Ok(())
590}