datafusion_remote_table/connection/
sqlite.rs

1use crate::connection::{RemoteDbType, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, Pool, RemoteField, RemoteSchema, RemoteSchemaRef,
4    RemoteType, SqliteType,
5};
6use datafusion::arrow::array::{
7    ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int32Builder, Int64Builder, NullBuilder,
8    RecordBatch, RecordBatchOptions, StringBuilder, make_builder,
9};
10use datafusion::arrow::datatypes::{DataType, SchemaRef};
11use datafusion::common::{DataFusionError, project_schema};
12use datafusion::execution::SendableRecordBatchStream;
13use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
14use derive_getters::Getters;
15use derive_with::With;
16use itertools::Itertools;
17use log::debug;
18use rusqlite::types::ValueRef;
19use rusqlite::{Column, Row, Rows};
20use std::any::Any;
21use std::collections::HashMap;
22use std::path::PathBuf;
23use std::sync::Arc;
24
25#[derive(Debug, Clone, With, Getters)]
26pub struct SqliteConnectionOptions {
27    pub path: PathBuf,
28    pub stream_chunk_size: usize,
29}
30
31impl SqliteConnectionOptions {
32    pub fn new(path: PathBuf) -> Self {
33        Self {
34            path,
35            stream_chunk_size: 2048,
36        }
37    }
38}
39
40impl From<SqliteConnectionOptions> for ConnectionOptions {
41    fn from(options: SqliteConnectionOptions) -> Self {
42        ConnectionOptions::Sqlite(options)
43    }
44}
45
46#[derive(Debug)]
47pub struct SqlitePool {
48    pool: tokio_rusqlite::Connection,
49}
50
51pub async fn connect_sqlite(options: &SqliteConnectionOptions) -> DFResult<SqlitePool> {
52    let pool = tokio_rusqlite::Connection::open(&options.path)
53        .await
54        .map_err(|e| {
55            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
56        })?;
57    Ok(SqlitePool { pool })
58}
59
60#[async_trait::async_trait]
61impl Pool for SqlitePool {
62    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
63        let conn = self.pool.clone();
64        Ok(Arc::new(SqliteConnection { conn }))
65    }
66}
67
68#[derive(Debug)]
69pub struct SqliteConnection {
70    conn: tokio_rusqlite::Connection,
71}
72
73#[async_trait::async_trait]
74impl Connection for SqliteConnection {
75    fn as_any(&self) -> &dyn Any {
76        self
77    }
78
79    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
80        let sql = RemoteDbType::Sqlite.query_limit_1(sql);
81        self.conn
82            .call(move |conn| {
83                let mut stmt = conn.prepare(&sql)?;
84                let columns: Vec<OwnedColumn> =
85                    stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
86                let rows = stmt.query([])?;
87
88                let remote_schema = Arc::new(
89                    build_remote_schema(columns.as_slice(), rows)
90                        .map_err(|e| tokio_rusqlite::Error::Other(Box::new(e)))?,
91                );
92                Ok(remote_schema)
93            })
94            .await
95            .map_err(|e| DataFusionError::Execution(format!("Failed to infer schema: {e:?}")))
96    }
97
98    async fn query(
99        &self,
100        conn_options: &ConnectionOptions,
101        sql: &str,
102        table_schema: SchemaRef,
103        projection: Option<&Vec<usize>>,
104        unparsed_filters: &[String],
105        limit: Option<usize>,
106    ) -> DFResult<SendableRecordBatchStream> {
107        let projected_schema = project_schema(&table_schema, projection)?;
108        let sql = RemoteDbType::Sqlite.rewrite_query(sql, unparsed_filters, limit);
109        debug!("[remote-table] executing sqlite query: {sql}");
110
111        let conn = self.conn.clone();
112        let projection = projection.cloned();
113        let limit = conn_options.stream_chunk_size();
114        let stream = async_stream::stream! {
115            let mut offset = 0;
116            loop {
117                let sql = format!("SELECT * FROM ({sql}) LIMIT {limit} OFFSET {offset}");
118                let sql_clone = sql.clone();
119                let conn = conn.clone();
120                let projection = projection.clone();
121                let table_schema = table_schema.clone();
122                let (batch, is_empty) = conn
123                    .call(move |conn| {
124                        let mut stmt = conn.prepare(&sql)?;
125                        let columns: Vec<OwnedColumn> =
126                            stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
127                        let rows = stmt.query([])?;
128
129                        rows_to_batch(rows, &table_schema, columns, projection.as_ref())
130                            .map_err(|e| tokio_rusqlite::Error::Other(e.into()))
131                })
132                .await
133                .map_err(|e| {
134                    DataFusionError::Execution(format!(
135                        "Failed to execute query {sql_clone} on sqlite: {e:?}"
136                    ))
137                })?;
138                if is_empty {
139                    break;
140                }
141                yield Ok(batch);
142                offset += limit;
143            }
144        };
145        Ok(Box::pin(RecordBatchStreamAdapter::new(
146            projected_schema,
147            stream,
148        )))
149    }
150}
151
152#[derive(Debug)]
153struct OwnedColumn {
154    name: String,
155    decl_type: Option<String>,
156}
157
158fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
159    OwnedColumn {
160        name: sqlite_col.name().to_string(),
161        decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
162    }
163}
164
165fn decl_type_to_remote_type(decl_type: &str) -> DFResult<SqliteType> {
166    if ["tinyint", "smallint", "int", "integer", "bigint"].contains(&decl_type) {
167        return Ok(SqliteType::Integer);
168    }
169    if ["real", "float", "double"].contains(&decl_type) {
170        return Ok(SqliteType::Real);
171    }
172    if decl_type.starts_with("real") {
173        return Ok(SqliteType::Real);
174    }
175    if ["text", "varchar", "char", "string"].contains(&decl_type) {
176        return Ok(SqliteType::Text);
177    }
178    if decl_type.starts_with("char")
179        || decl_type.starts_with("varchar")
180        || decl_type.starts_with("text")
181    {
182        return Ok(SqliteType::Text);
183    }
184    if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
185        return Ok(SqliteType::Blob);
186    }
187    if decl_type.starts_with("binary") || decl_type.starts_with("varbinary") {
188        return Ok(SqliteType::Blob);
189    }
190    Err(DataFusionError::NotImplemented(format!(
191        "Unsupported sqlite decl type: {decl_type}",
192    )))
193}
194
195fn build_remote_schema(columns: &[OwnedColumn], mut rows: Rows) -> DFResult<RemoteSchema> {
196    let mut remote_field_map = HashMap::with_capacity(columns.len());
197    let mut unknown_cols = vec![];
198    for (col_idx, col) in columns.iter().enumerate() {
199        if let Some(decl_type) = &col.decl_type {
200            let remote_type =
201                RemoteType::Sqlite(decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?);
202            remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
203        } else {
204            // None for expressions
205            unknown_cols.push(col_idx);
206        }
207    }
208
209    if !unknown_cols.is_empty() {
210        while let Some(row) = rows.next().map_err(|e| {
211            DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
212        })? {
213            let mut to_be_removed = vec![];
214            for col_idx in unknown_cols.iter() {
215                let value_ref = row.get_ref(*col_idx).map_err(|e| {
216                    DataFusionError::Execution(format!(
217                        "Failed to get value ref for column {col_idx}: {e:?}"
218                    ))
219                })?;
220                match value_ref {
221                    ValueRef::Null => {}
222                    ValueRef::Integer(_) => {
223                        remote_field_map.insert(
224                            *col_idx,
225                            RemoteField::new(
226                                columns[*col_idx].name.clone(),
227                                RemoteType::Sqlite(SqliteType::Integer),
228                                true,
229                            ),
230                        );
231                        to_be_removed.push(*col_idx);
232                    }
233                    ValueRef::Real(_) => {
234                        remote_field_map.insert(
235                            *col_idx,
236                            RemoteField::new(
237                                columns[*col_idx].name.clone(),
238                                RemoteType::Sqlite(SqliteType::Real),
239                                true,
240                            ),
241                        );
242                        to_be_removed.push(*col_idx);
243                    }
244                    ValueRef::Text(_) => {
245                        remote_field_map.insert(
246                            *col_idx,
247                            RemoteField::new(
248                                columns[*col_idx].name.clone(),
249                                RemoteType::Sqlite(SqliteType::Text),
250                                true,
251                            ),
252                        );
253                        to_be_removed.push(*col_idx);
254                    }
255                    ValueRef::Blob(_) => {
256                        remote_field_map.insert(
257                            *col_idx,
258                            RemoteField::new(
259                                columns[*col_idx].name.clone(),
260                                RemoteType::Sqlite(SqliteType::Blob),
261                                true,
262                            ),
263                        );
264                        to_be_removed.push(*col_idx);
265                    }
266                }
267            }
268            for col_idx in to_be_removed.iter() {
269                unknown_cols.retain(|&x| x != *col_idx);
270            }
271            if unknown_cols.is_empty() {
272                break;
273            }
274        }
275    }
276
277    if !unknown_cols.is_empty() {
278        return Err(DataFusionError::NotImplemented(format!(
279            "Failed to infer sqlite decl type for columns: {unknown_cols:?}"
280        )));
281    }
282    let remote_fields = remote_field_map
283        .into_iter()
284        .sorted_by_key(|entry| entry.0)
285        .map(|entry| entry.1)
286        .collect::<Vec<_>>();
287    Ok(RemoteSchema::new(remote_fields))
288}
289
290fn rows_to_batch(
291    mut rows: Rows,
292    table_schema: &SchemaRef,
293    columns: Vec<OwnedColumn>,
294    projection: Option<&Vec<usize>>,
295) -> DFResult<(RecordBatch, bool)> {
296    let projected_schema = project_schema(table_schema, projection)?;
297    let mut array_builders = vec![];
298    for field in table_schema.fields() {
299        let builder = make_builder(field.data_type(), 1000);
300        array_builders.push(builder);
301    }
302
303    let mut is_empty = true;
304    let mut row_count = 0;
305    while let Some(row) = rows.next().map_err(|e| {
306        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
307    })? {
308        is_empty = false;
309        row_count += 1;
310        append_rows_to_array_builders(
311            row,
312            table_schema,
313            &columns,
314            projection,
315            array_builders.as_mut_slice(),
316        )?;
317    }
318
319    let projected_columns = array_builders
320        .into_iter()
321        .enumerate()
322        .filter(|(idx, _)| projections_contains(projection, *idx))
323        .map(|(_, mut builder)| builder.finish())
324        .collect::<Vec<ArrayRef>>();
325    let options = RecordBatchOptions::new().with_row_count(Some(row_count));
326    Ok((
327        RecordBatch::try_new_with_options(projected_schema, projected_columns, &options)?,
328        is_empty,
329    ))
330}
331
/// Reads column `$index` of `$row` as `Option<$value_ty>` and appends it to `$builder`
/// after downcasting the builder to `$builder_ty`.
///
/// `$field` and `$col` are only used in diagnostics. The macro must be expanded inside a
/// function returning `DFResult<_>`: a failed read returns a
/// `DataFusionError::Execution` from that function.
///
/// # Panics
/// Panics if `$builder` is not actually a `$builder_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        // SQL NULL is read as `None` and becomes a null slot in the array.
        let value: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        if let Some(value) = value {
            typed_builder.append_value(value);
        } else {
            typed_builder.append_null();
        }
    }};
}
361
362fn append_rows_to_array_builders(
363    row: &Row,
364    table_schema: &SchemaRef,
365    columns: &[OwnedColumn],
366    projection: Option<&Vec<usize>>,
367    array_builders: &mut [Box<dyn ArrayBuilder>],
368) -> DFResult<()> {
369    for (idx, field) in table_schema.fields.iter().enumerate() {
370        if !projections_contains(projection, idx) {
371            continue;
372        }
373        let builder = &mut array_builders[idx];
374        let col = columns.get(idx);
375        match field.data_type() {
376            DataType::Null => {
377                let builder = builder
378                    .as_any_mut()
379                    .downcast_mut::<NullBuilder>()
380                    .expect("Failed to downcast builder to NullBuilder");
381                builder.append_null();
382            }
383            DataType::Int32 => {
384                handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx);
385            }
386            DataType::Int64 => {
387                handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
388            }
389            DataType::Float64 => {
390                handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
391            }
392            DataType::Utf8 => {
393                handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
394            }
395            DataType::Binary => {
396                handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
397            }
398            _ => {
399                return Err(DataFusionError::NotImplemented(format!(
400                    "Unsupported data type {:?} for col: {:?}",
401                    field.data_type(),
402                    col
403                )));
404            }
405        }
406    }
407    Ok(())
408}