datafusion_remote_table/connection/sqlite.rs

use crate::connection::{RemoteDbType, projections_contains};
use crate::{
    Connection, ConnectionOptions, DFResult, Pool, RemoteField, RemoteSchema, RemoteSchemaRef,
    RemoteType, SqliteType,
};
use datafusion::arrow::array::{
    ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int32Builder, Int64Builder, NullBuilder,
    RecordBatch, RecordBatchOptions, StringBuilder, make_builder,
};
use datafusion::arrow::datatypes::{DataType, SchemaRef};
use datafusion::common::{DataFusionError, project_schema};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use derive_getters::Getters;
use derive_with::With;
use itertools::Itertools;
use log::{debug, error};
use rusqlite::types::ValueRef;
use rusqlite::{Column, Row, Rows};
use std::any::Any;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
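/// Options for connecting to a SQLite database file.
///
/// `stream_chunk_size` controls how many rows are read from SQLite before each
/// `RecordBatch` is emitted on the output stream (2048 by default).
///
/// A minimal construction sketch (the file path here is illustrative):
///
/// ```ignore
/// let options: ConnectionOptions =
///     SqliteConnectionOptions::new(PathBuf::from("path/to/db.sqlite")).into();
/// ```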
#[derive(Debug, Clone, With, Getters)]
pub struct SqliteConnectionOptions {
    pub path: PathBuf,
    pub stream_chunk_size: usize,
}

impl SqliteConnectionOptions {
    pub fn new(path: PathBuf) -> Self {
        Self {
            path,
            stream_chunk_size: 2048,
        }
    }
}

impl From<SqliteConnectionOptions> for ConnectionOptions {
    fn from(options: SqliteConnectionOptions) -> Self {
        ConnectionOptions::Sqlite(options)
    }
}
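/// A minimal "pool" for SQLite: it only stores the database path, and `get`
/// returns a `SqliteConnection` that opens rusqlite connections on demand.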
#[derive(Debug)]
pub struct SqlitePool {
    path: PathBuf,
}
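/// Validates the options by opening (and immediately dropping) a rusqlite
/// connection to the database file, then returns a pool holding only the path.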
pub async fn connect_sqlite(options: &SqliteConnectionOptions) -> DFResult<SqlitePool> {
    let _ = rusqlite::Connection::open(&options.path).map_err(|e| {
        DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
    })?;
    Ok(SqlitePool {
        path: options.path.clone(),
    })
}
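// Each call to `get` hands out a fresh `SqliteConnection`; the underlying
// rusqlite connection is opened per operation inside that connection's methods.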
#[async_trait::async_trait]
impl Pool for SqlitePool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        Ok(Arc::new(SqliteConnection {
            path: self.path.clone(),
        }))
    }
}
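/// A logical connection that stores the database path and opens a new rusqlite
/// connection for each schema-inference or query call.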
#[derive(Debug)]
pub struct SqliteConnection {
    path: PathBuf,
}

#[async_trait::async_trait]
impl Connection for SqliteConnection {
    fn as_any(&self) -> &dyn Any {
        self
    }
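    // Infers the remote schema by preparing the query (wrapped via
    // `query_limit_1`), mapping declared column types, and falling back to
    // scanning row values for expression columns that have no declared type.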
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
        let sql = RemoteDbType::Sqlite.query_limit_1(sql);
        let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
        })?;
        let mut stmt = conn.prepare(&sql).map_err(|e| {
            DataFusionError::Execution(format!("Failed to prepare sqlite statement: {e:?}"))
        })?;
        let columns: Vec<OwnedColumn> =
            stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
        let rows = stmt.query([]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to query sqlite statement: {e:?}"))
        })?;

        let remote_schema = Arc::new(build_remote_schema(columns.as_slice(), rows)?);
        Ok(remote_schema)
    }
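    // Rewrites the query with the pushed-down filters and limit, then runs it
    // on a dedicated background thread; record batches are bridged back into an
    // async stream through an mpsc channel so the blocking rusqlite calls never
    // run on the async executor.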
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        let sql = RemoteDbType::Sqlite.rewrite_query(sql, unparsed_filters, limit);
        debug!("[remote-table] executing sqlite query: {sql}");

        let (tx, mut rx) = tokio::sync::mpsc::channel::<DFResult<RecordBatch>>(1);
        let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
        })?;

        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();

        spawn_background_task(tx, conn, sql, table_schema, projection, chunk_size);

        let stream = async_stream::stream! {
            while let Some(batch) = rx.recv().await {
                yield batch;
            }
        };
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}
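/// Owned copy of a statement column's metadata. rusqlite's `Column` borrows
/// from the prepared statement, so the name and declared type are copied out
/// before the statement is used to run the query.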
#[derive(Debug)]
struct OwnedColumn {
    name: String,
    decl_type: Option<String>,
}

fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
    OwnedColumn {
        name: sqlite_col.name().to_string(),
        decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
    }
}
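/// Maps a lower-cased declared SQL type (e.g. `bigint`, `varchar(255)`) to a
/// SQLite storage class, first by exact name and then by prefix for
/// parameterized declarations.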
fn decl_type_to_remote_type(decl_type: &str) -> DFResult<SqliteType> {
    if [
        "tinyint", "smallint", "int", "integer", "bigint", "int2", "int4", "int8",
    ]
    .contains(&decl_type)
    {
        return Ok(SqliteType::Integer);
    }
    if ["real", "float", "double", "numeric"].contains(&decl_type) {
        return Ok(SqliteType::Real);
    }
    if decl_type.starts_with("real") || decl_type.starts_with("numeric") {
        return Ok(SqliteType::Real);
    }
    if ["text", "varchar", "char", "string"].contains(&decl_type) {
        return Ok(SqliteType::Text);
    }
    if decl_type.starts_with("char")
        || decl_type.starts_with("varchar")
        || decl_type.starts_with("text")
    {
        return Ok(SqliteType::Text);
    }
    if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
        return Ok(SqliteType::Blob);
    }
    if decl_type.starts_with("binary") || decl_type.starts_with("varbinary") {
        return Ok(SqliteType::Blob);
    }
    Err(DataFusionError::NotImplemented(format!(
        "Unsupported sqlite decl type: {decl_type}",
    )))
}
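/// Builds the remote schema from the statement's columns. Columns with a
/// declared type are mapped directly; expression columns (no declared type) are
/// typed by scanning rows until a non-NULL value reveals their storage class.
/// Columns that remain unknown after the scan produce an error.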
fn build_remote_schema(columns: &[OwnedColumn], mut rows: Rows) -> DFResult<RemoteSchema> {
    let mut remote_field_map = HashMap::with_capacity(columns.len());
    let mut unknown_cols = vec![];
    for (col_idx, col) in columns.iter().enumerate() {
        if let Some(decl_type) = &col.decl_type {
            let remote_type =
                RemoteType::Sqlite(decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?);
            remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
        } else {
            // decl_type is None for expression columns; infer their type from row data below
            unknown_cols.push(col_idx);
        }
    }

    if !unknown_cols.is_empty() {
        while let Some(row) = rows.next().map_err(|e| {
            DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
        })? {
            let mut to_be_removed = vec![];
            for col_idx in unknown_cols.iter() {
                let value_ref = row.get_ref(*col_idx).map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to get value ref for column {col_idx}: {e:?}"
                    ))
                })?;
                match value_ref {
                    ValueRef::Null => {}
                    ValueRef::Integer(_) => {
                        remote_field_map.insert(
                            *col_idx,
                            RemoteField::new(
                                columns[*col_idx].name.clone(),
                                RemoteType::Sqlite(SqliteType::Integer),
                                true,
                            ),
                        );
                        to_be_removed.push(*col_idx);
                    }
                    ValueRef::Real(_) => {
                        remote_field_map.insert(
                            *col_idx,
                            RemoteField::new(
                                columns[*col_idx].name.clone(),
                                RemoteType::Sqlite(SqliteType::Real),
                                true,
                            ),
                        );
                        to_be_removed.push(*col_idx);
                    }
                    ValueRef::Text(_) => {
                        remote_field_map.insert(
                            *col_idx,
                            RemoteField::new(
                                columns[*col_idx].name.clone(),
                                RemoteType::Sqlite(SqliteType::Text),
                                true,
                            ),
                        );
                        to_be_removed.push(*col_idx);
                    }
                    ValueRef::Blob(_) => {
                        remote_field_map.insert(
                            *col_idx,
                            RemoteField::new(
                                columns[*col_idx].name.clone(),
                                RemoteType::Sqlite(SqliteType::Blob),
                                true,
                            ),
                        );
                        to_be_removed.push(*col_idx);
                    }
                }
            }
            for col_idx in to_be_removed.iter() {
                unknown_cols.retain(|&x| x != *col_idx);
            }
            if unknown_cols.is_empty() {
                break;
            }
        }
    }

    if !unknown_cols.is_empty() {
        return Err(DataFusionError::NotImplemented(format!(
            "Failed to infer sqlite decl type for columns: {unknown_cols:?}"
        )));
    }
    let remote_fields = remote_field_map
        .into_iter()
        .sorted_by_key(|entry| entry.0)
        .map(|entry| entry.1)
        .collect::<Vec<_>>();
    Ok(RemoteSchema::new(remote_fields))
}
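/// Runs the blocking rusqlite query on a dedicated OS thread, using a
/// current-thread Tokio runtime so the channel sends can be awaited. Batches
/// and errors are streamed back through `tx`; if the receiver is dropped, the
/// task stops.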
fn spawn_background_task(
    tx: tokio::sync::mpsc::Sender<DFResult<RecordBatch>>,
    conn: rusqlite::Connection,
    sql: String,
    table_schema: SchemaRef,
    projection: Option<Vec<usize>>,
    chunk_size: usize,
) {
    std::thread::spawn(move || {
        let runtime = match tokio::runtime::Builder::new_current_thread().build() {
            Ok(runtime) => runtime,
            Err(e) => {
                error!("Failed to create tokio runtime to run sqlite query: {e:?}");
                return;
            }
        };
        let local_set = tokio::task::LocalSet::new();
        local_set.block_on(&runtime, async move {
            let mut stmt = match conn.prepare(&sql) {
                Ok(stmt) => stmt,
                Err(e) => {
                    let _ = tx
                        .send(Err(DataFusionError::Execution(format!(
                            "Failed to prepare sqlite statement: {e:?}"
                        ))))
                        .await;
                    return;
                }
            };
            let columns: Vec<OwnedColumn> =
                stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
            let mut rows = match stmt.query([]) {
                Ok(rows) => rows,
                Err(e) => {
                    let _ = tx
                        .send(Err(DataFusionError::Execution(format!(
                            "Failed to query sqlite statement: {e:?}"
                        ))))
                        .await;
                    return;
                }
            };

            loop {
                let (batch, is_empty) = match rows_to_batch(
                    &mut rows,
                    &table_schema,
                    &columns,
                    projection.as_ref(),
                    chunk_size,
                ) {
                    Ok((batch, is_empty)) => (batch, is_empty),
                    Err(e) => {
                        let _ = tx
                            .send(Err(DataFusionError::Execution(format!(
                                "Failed to convert rows to batch: {e:?}"
                            ))))
                            .await;
                        return;
                    }
                };
                if is_empty {
                    break;
                }
                if tx.send(Ok(batch)).await.is_err() {
                    return;
                }
            }
        });
    });
}
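/// Drains up to `chunk_size` rows from the cursor into Arrow builders and
/// assembles a `RecordBatch` of the projected columns. Builders are created for
/// every column in the table schema, but only projected ones are filled and
/// finished. The returned flag is `true` when no rows were read, i.e. the
/// cursor is exhausted.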
fn rows_to_batch(
    rows: &mut Rows,
    table_schema: &SchemaRef,
    columns: &[OwnedColumn],
    projection: Option<&Vec<usize>>,
    chunk_size: usize,
) -> DFResult<(RecordBatch, bool)> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), 1000);
        array_builders.push(builder);
    }

    let mut is_empty = true;
    let mut row_count = 0;
    while let Some(row) = rows.next().map_err(|e| {
        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
    })? {
        is_empty = false;
        row_count += 1;
        append_rows_to_array_builders(
            row,
            table_schema,
            columns,
            projection,
            array_builders.as_mut_slice(),
        )?;
        if row_count >= chunk_size {
            break;
        }
    }

    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    let options = RecordBatchOptions::new().with_row_count(Some(row_count));
    Ok((
        RecordBatch::try_new_with_options(projected_schema, projected_columns, &options)?,
        is_empty,
    ))
}
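// Downcasts the builder to the concrete Arrow builder type (panicking on a
// mismatch), reads the value at the given index from the row as an `Option`,
// and appends either the value or a null.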
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });

        let v: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}
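/// Appends one row's projected values to the per-column builders, dispatching
/// on the Arrow data type of each field in the table schema.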
fn append_rows_to_array_builders(
    row: &Row,
    table_schema: &SchemaRef,
    columns: &[OwnedColumn],
    projection: Option<&Vec<usize>>,
    array_builders: &mut [Box<dyn ArrayBuilder>],
) -> DFResult<()> {
    for (idx, field) in table_schema.fields.iter().enumerate() {
        if !projections_contains(projection, idx) {
            continue;
        }
        let builder = &mut array_builders[idx];
        let col = columns.get(idx);
        match field.data_type() {
            DataType::Null => {
                let builder = builder
                    .as_any_mut()
                    .downcast_mut::<NullBuilder>()
                    .expect("Failed to downcast builder to NullBuilder");
                builder.append_null();
            }
            DataType::Int32 => {
                handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx);
            }
            DataType::Int64 => {
                handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
            }
            DataType::Float64 => {
                handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
            }
            DataType::Utf8 => {
                handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
            }
            DataType::Binary => {
                handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
            }
            _ => {
                return Err(DataFusionError::NotImplemented(format!(
                    "Unsupported data type {} for col: {:?}",
                    field.data_type(),
                    col
                )));
            }
        }
    }
    Ok(())
}