datafusion_remote_table/connection/
sqlite.rs

1use crate::connection::{RemoteDbType, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, Pool, RemoteField, RemoteSchema, RemoteSchemaRef,
4    RemoteType, SqliteType,
5};
6use datafusion::arrow::array::{
7    ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int64Builder, NullBuilder, RecordBatch,
8    StringBuilder, make_builder,
9};
10use datafusion::arrow::datatypes::{DataType, SchemaRef};
11use datafusion::common::{DataFusionError, project_schema};
12use datafusion::execution::SendableRecordBatchStream;
13use datafusion::physical_plan::memory::MemoryStream;
14use datafusion::prelude::Expr;
15use itertools::Itertools;
16use rusqlite::types::ValueRef;
17use rusqlite::{Column, Row, Rows};
18use std::collections::HashMap;
19use std::path::PathBuf;
20use std::sync::Arc;
21
22#[derive(Debug)]
23pub struct SqlitePool {
24    pool: tokio_rusqlite::Connection,
25}
26
27pub async fn connect_sqlite(path: &PathBuf) -> DFResult<SqlitePool> {
28    let pool = tokio_rusqlite::Connection::open(path).await.map_err(|e| {
29        DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
30    })?;
31    Ok(SqlitePool { pool })
32}
33
34#[async_trait::async_trait]
35impl Pool for SqlitePool {
36    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
37        let conn = self.pool.clone();
38        Ok(Arc::new(SqliteConnection { conn }))
39    }
40}
41
42#[derive(Debug)]
43pub struct SqliteConnection {
44    conn: tokio_rusqlite::Connection,
45}
46
47#[async_trait::async_trait]
48impl Connection for SqliteConnection {
49    async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
50        let sql = sql.to_string();
51        self.conn
52            .call(move |conn| {
53                let mut stmt = conn.prepare(&sql)?;
54                let columns: Vec<OwnedColumn> =
55                    stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
56                let rows = stmt.query([])?;
57
58                let remote_schema = Arc::new(
59                    build_remote_schema(columns.as_slice(), rows)
60                        .map_err(|e| tokio_rusqlite::Error::Other(Box::new(e)))?,
61                );
62                let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
63                Ok((remote_schema, arrow_schema))
64            })
65            .await
66            .map_err(|e| DataFusionError::Execution(format!("Failed to infer schema: {e:?}")))
67    }
68
69    async fn query(
70        &self,
71        _conn_options: &ConnectionOptions,
72        sql: &str,
73        table_schema: SchemaRef,
74        projection: Option<&Vec<usize>>,
75        filters: &[Expr],
76        limit: Option<usize>,
77    ) -> DFResult<SendableRecordBatchStream> {
78        let projected_schema = project_schema(&table_schema, projection)?;
79        let sql = RemoteDbType::Sqlite
80            .try_rewrite_query(sql, filters, limit)
81            .unwrap_or_else(|| sql.to_string());
82        let sql_clone = sql.clone();
83        let projection = projection.cloned();
84        let batch = self
85            .conn
86            .call(move |conn| {
87                let mut stmt = conn.prepare(&sql)?;
88                let columns: Vec<OwnedColumn> =
89                    stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
90                let rows = stmt.query([])?;
91
92                let batch = rows_to_batch(rows, &table_schema, columns, projection.as_ref())
93                    .map_err(|e| tokio_rusqlite::Error::Other(e.into()))?;
94                Ok(batch)
95            })
96            .await
97            .map_err(|e| {
98                DataFusionError::Execution(format!(
99                    "Failed to execute query {sql_clone} on sqlite: {e:?}"
100                ))
101            })?;
102
103        let memory_stream = MemoryStream::try_new(vec![batch], projected_schema, None)?;
104        Ok(Box::pin(memory_stream))
105    }
106}
107
/// Owned copy of the column metadata this module needs from a prepared
/// statement.
#[derive(Debug)]
struct OwnedColumn {
    /// Column name as reported by the statement.
    name: String,
    /// Declared type of the column, or `None` when the statement reports none.
    decl_type: Option<String>,
}
113
114fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
115    OwnedColumn {
116        name: sqlite_col.name().to_string(),
117        decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
118    }
119}
120
121fn decl_type_to_remote_type(decl_type: &str) -> DFResult<RemoteType> {
122    if "null".eq(decl_type) {
123        return Ok(RemoteType::Sqlite(SqliteType::Null));
124    }
125    if ["tinyint", "smallint", "int", "integer", "bigint"].contains(&decl_type) {
126        return Ok(RemoteType::Sqlite(SqliteType::Integer));
127    }
128    if ["real", "float", "double"].contains(&decl_type) {
129        return Ok(RemoteType::Sqlite(SqliteType::Real));
130    }
131    if decl_type.starts_with("real") {
132        return Ok(RemoteType::Sqlite(SqliteType::Real));
133    }
134    if ["text", "varchar", "char", "string"].contains(&decl_type) {
135        return Ok(RemoteType::Sqlite(SqliteType::Text));
136    }
137    if decl_type.starts_with("char")
138        || decl_type.starts_with("varchar")
139        || decl_type.starts_with("text")
140    {
141        return Ok(RemoteType::Sqlite(SqliteType::Text));
142    }
143    if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
144        return Ok(RemoteType::Sqlite(SqliteType::Blob));
145    }
146    if decl_type.starts_with("binary") || decl_type.starts_with("varbinary") {
147        return Ok(RemoteType::Sqlite(SqliteType::Blob));
148    }
149    Err(DataFusionError::NotImplemented(format!(
150        "Unsupported sqlite decl type: {decl_type}",
151    )))
152}
153
154fn build_remote_schema(columns: &[OwnedColumn], mut rows: Rows) -> DFResult<RemoteSchema> {
155    let mut remote_field_map = HashMap::with_capacity(columns.len());
156    let mut unknown_cols = vec![];
157    for (col_idx, col) in columns.iter().enumerate() {
158        if let Some(decl_type) = &col.decl_type {
159            let remote_type = decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?;
160            remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
161        } else {
162            unknown_cols.push(col_idx);
163        }
164    }
165
166    if !unknown_cols.is_empty() {
167        while let Some(row) = rows.next().map_err(|e| {
168            DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
169        })? {
170            let mut to_be_removed = vec![];
171            for col_idx in unknown_cols.iter() {
172                let value_ref = row.get_ref(*col_idx).map_err(|e| {
173                    DataFusionError::Execution(format!(
174                        "Failed to get value ref for column {col_idx}: {e:?}"
175                    ))
176                })?;
177                match value_ref {
178                    ValueRef::Null => {}
179                    ValueRef::Integer(_) => {
180                        remote_field_map.insert(
181                            *col_idx,
182                            RemoteField::new(
183                                columns[*col_idx].name.clone(),
184                                RemoteType::Sqlite(SqliteType::Integer),
185                                true,
186                            ),
187                        );
188                        to_be_removed.push(*col_idx);
189                    }
190                    ValueRef::Real(_) => {
191                        remote_field_map.insert(
192                            *col_idx,
193                            RemoteField::new(
194                                columns[*col_idx].name.clone(),
195                                RemoteType::Sqlite(SqliteType::Real),
196                                true,
197                            ),
198                        );
199                        to_be_removed.push(*col_idx);
200                    }
201                    ValueRef::Text(_) => {
202                        remote_field_map.insert(
203                            *col_idx,
204                            RemoteField::new(
205                                columns[*col_idx].name.clone(),
206                                RemoteType::Sqlite(SqliteType::Text),
207                                true,
208                            ),
209                        );
210                        to_be_removed.push(*col_idx);
211                    }
212                    ValueRef::Blob(_) => {
213                        remote_field_map.insert(
214                            *col_idx,
215                            RemoteField::new(
216                                columns[*col_idx].name.clone(),
217                                RemoteType::Sqlite(SqliteType::Blob),
218                                true,
219                            ),
220                        );
221                        to_be_removed.push(*col_idx);
222                    }
223                }
224            }
225            for col_idx in to_be_removed.iter() {
226                unknown_cols.retain(|&x| x != *col_idx);
227            }
228            if unknown_cols.is_empty() {
229                break;
230            }
231        }
232    }
233
234    if !unknown_cols.is_empty() {
235        return Err(DataFusionError::NotImplemented(format!(
236            "Failed to infer sqlite decl type for columns: {unknown_cols:?}"
237        )));
238    }
239    let remote_fields = remote_field_map
240        .into_iter()
241        .sorted_by_key(|entry| entry.0)
242        .map(|entry| entry.1)
243        .collect::<Vec<_>>();
244    Ok(RemoteSchema::new(remote_fields))
245}
246
247fn rows_to_batch(
248    mut rows: Rows,
249    table_schema: &SchemaRef,
250    columns: Vec<OwnedColumn>,
251    projection: Option<&Vec<usize>>,
252) -> DFResult<RecordBatch> {
253    let projected_schema = project_schema(table_schema, projection)?;
254    let mut array_builders = vec![];
255    for field in table_schema.fields() {
256        let builder = make_builder(field.data_type(), 1000);
257        array_builders.push(builder);
258    }
259
260    while let Some(row) = rows.next().map_err(|e| {
261        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
262    })? {
263        append_rows_to_array_builders(
264            row,
265            table_schema,
266            &columns,
267            projection,
268            array_builders.as_mut_slice(),
269        )?;
270    }
271
272    let projected_columns = array_builders
273        .into_iter()
274        .enumerate()
275        .filter(|(idx, _)| projections_contains(projection, *idx))
276        .map(|(_, mut builder)| builder.finish())
277        .collect::<Vec<ArrayRef>>();
278    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
279}
280
/// Reads an optional `$value_ty` from `$row` at `$index` and appends it (or a
/// null) to `$builder`, which is downcast to `$builder_ty` first.
///
/// `$field` and `$col` are only used in diagnostics. The macro uses `?`, so
/// it must be expanded inside a function that returns a `Result` whose error
/// can be built from `DataFusionError`.
///
/// # Panics
///
/// Panics when `$builder` is not a `$builder_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let typed_builder = match $builder.as_any_mut().downcast_mut::<$builder_ty>() {
            Some(typed) => typed,
            None => panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            ),
        };

        let fetched: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        if let Some(value) = fetched {
            typed_builder.append_value(value);
        } else {
            typed_builder.append_null();
        }
    }};
}
310
311fn append_rows_to_array_builders(
312    row: &Row,
313    table_schema: &SchemaRef,
314    columns: &[OwnedColumn],
315    projection: Option<&Vec<usize>>,
316    array_builders: &mut [Box<dyn ArrayBuilder>],
317) -> DFResult<()> {
318    for (idx, field) in table_schema.fields.iter().enumerate() {
319        if !projections_contains(projection, idx) {
320            continue;
321        }
322        let builder = &mut array_builders[idx];
323        let col = columns.get(idx);
324        match field.data_type() {
325            DataType::Null => {
326                let builder = builder
327                    .as_any_mut()
328                    .downcast_mut::<NullBuilder>()
329                    .expect("Failed to downcast builder to NullBuilder");
330                builder.append_null();
331            }
332            DataType::Int64 => {
333                handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
334            }
335            DataType::Float64 => {
336                handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
337            }
338            DataType::Utf8 => {
339                handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
340            }
341            DataType::Binary => {
342                handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
343            }
344            _ => {
345                return Err(DataFusionError::NotImplemented(format!(
346                    "Unsupported data type {:?} for col: {:?}",
347                    field.data_type(),
348                    col
349                )));
350            }
351        }
352    }
353    Ok(())
354}