datafusion_remote_table/connection/
sqlite.rs

1use crate::connection::{RemoteDbType, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, Pool, RemoteField, RemoteSchema, RemoteSchemaRef,
4    RemoteType, SqliteType,
5};
6use datafusion::arrow::array::{
7    ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int64Builder, NullBuilder, RecordBatch,
8    StringBuilder, make_builder,
9};
10use datafusion::arrow::datatypes::{DataType, SchemaRef};
11use datafusion::common::{DataFusionError, project_schema};
12use datafusion::execution::SendableRecordBatchStream;
13use datafusion::physical_plan::memory::MemoryStream;
14use datafusion::prelude::Expr;
15use itertools::Itertools;
16use regex::Regex;
17use rusqlite::types::ValueRef;
18use rusqlite::{Column, Row, Rows};
19use std::collections::HashMap;
20use std::path::PathBuf;
21use std::sync::{Arc, LazyLock};
22
23#[derive(Debug)]
24pub struct SqlitePool {
25    pool: tokio_rusqlite::Connection,
26}
27
28pub async fn connect_sqlite(path: &PathBuf) -> DFResult<SqlitePool> {
29    let pool = tokio_rusqlite::Connection::open(path).await.map_err(|e| {
30        DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
31    })?;
32    Ok(SqlitePool { pool })
33}
34
35#[async_trait::async_trait]
36impl Pool for SqlitePool {
37    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
38        let conn = self.pool.clone();
39        Ok(Arc::new(SqliteConnection { conn }))
40    }
41}
42
43#[derive(Debug)]
44pub struct SqliteConnection {
45    conn: tokio_rusqlite::Connection,
46}
47
48#[async_trait::async_trait]
49impl Connection for SqliteConnection {
50    async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
51        let sql = sql.to_string();
52        self.conn
53            .call(move |conn| {
54                let mut stmt = conn.prepare(&sql)?;
55                let columns: Vec<OwnedColumn> =
56                    stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
57                let rows = stmt.query([])?;
58
59                let remote_schema = Arc::new(
60                    build_remote_schema(columns.as_slice(), rows)
61                        .map_err(|e| tokio_rusqlite::Error::Other(Box::new(e)))?,
62                );
63                let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
64                Ok((remote_schema, arrow_schema))
65            })
66            .await
67            .map_err(|e| DataFusionError::Execution(format!("Failed to infer schema: {e:?}")))
68    }
69
70    async fn query(
71        &self,
72        _conn_options: &ConnectionOptions,
73        sql: &str,
74        table_schema: SchemaRef,
75        projection: Option<&Vec<usize>>,
76        filters: &[Expr],
77        limit: Option<usize>,
78    ) -> DFResult<SendableRecordBatchStream> {
79        let projected_schema = project_schema(&table_schema, projection)?;
80        let sql = RemoteDbType::Sqlite
81            .try_rewrite_query(sql, filters, limit)
82            .unwrap_or_else(|| sql.to_string());
83        let sql_clone = sql.clone();
84        let projection = projection.cloned();
85        let batch = self
86            .conn
87            .call(move |conn| {
88                let mut stmt = conn.prepare(&sql)?;
89                let columns: Vec<OwnedColumn> =
90                    stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
91                let rows = stmt.query([])?;
92
93                let batch = rows_to_batch(rows, &table_schema, columns, projection.as_ref())
94                    .map_err(|e| tokio_rusqlite::Error::Other(e.into()))?;
95                Ok(batch)
96            })
97            .await
98            .map_err(|e| {
99                DataFusionError::Execution(format!(
100                    "Failed to execute query {sql_clone} on sqlite: {e:?}"
101                ))
102            })?;
103
104        let memory_stream = MemoryStream::try_new(vec![batch], projected_schema, None)?;
105        Ok(Box::pin(memory_stream))
106    }
107}
108
/// Owned copy of the column metadata this module needs from a prepared
/// statement, so it can be used after the borrowed `Column` is gone.
#[derive(Debug)]
struct OwnedColumn {
    /// Column name as reported by the statement.
    name: String,
    /// Declared type of the column, or `None` when the statement reports none.
    decl_type: Option<String>,
}
114
115fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
116    OwnedColumn {
117        name: sqlite_col.name().to_string(),
118        decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
119    }
120}
121
122static CHAR_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^(?i)char\(\d+\)$").unwrap());
123static VARCHAR_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^(?i)varchar\(\d+\)$").unwrap());
124static TEXT_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^(?i)text\(\d+\)$").unwrap());
125static BINARY_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^(?i)binary\(\d+\)$").unwrap());
126static VARBINARY_RE: LazyLock<Regex> =
127    LazyLock::new(|| Regex::new(r"^(?i)varbinary\(\d+\)$").unwrap());
128
129fn decl_type_to_remote_type(decl_type: &str) -> DFResult<RemoteType> {
130    if "null".eq(decl_type) {
131        return Ok(RemoteType::Sqlite(SqliteType::Null));
132    }
133    if ["tinyint", "smallint", "int", "integer", "bigint"].contains(&decl_type) {
134        return Ok(RemoteType::Sqlite(SqliteType::Integer));
135    }
136    if ["real", "float", "double"].contains(&decl_type) {
137        return Ok(RemoteType::Sqlite(SqliteType::Real));
138    }
139    if ["text", "varchar", "char", "string"].contains(&decl_type) {
140        return Ok(RemoteType::Sqlite(SqliteType::Text));
141    }
142    if CHAR_RE.is_match(decl_type) || VARCHAR_RE.is_match(decl_type) || TEXT_RE.is_match(decl_type)
143    {
144        return Ok(RemoteType::Sqlite(SqliteType::Text));
145    }
146    if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
147        return Ok(RemoteType::Sqlite(SqliteType::Blob));
148    }
149    if BINARY_RE.is_match(decl_type) || VARBINARY_RE.is_match(decl_type) {
150        return Ok(RemoteType::Sqlite(SqliteType::Blob));
151    }
152    Err(DataFusionError::NotImplemented(format!(
153        "Unsupported sqlite decl type: {decl_type}",
154    )))
155}
156
157fn build_remote_schema(columns: &[OwnedColumn], mut rows: Rows) -> DFResult<RemoteSchema> {
158    let mut remote_field_map = HashMap::with_capacity(columns.len());
159    let mut unknown_cols = vec![];
160    for (col_idx, col) in columns.iter().enumerate() {
161        if let Some(decl_type) = &col.decl_type {
162            let remote_type = decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?;
163            remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
164        } else {
165            unknown_cols.push(col_idx);
166        }
167    }
168
169    if !unknown_cols.is_empty() {
170        while let Some(row) = rows.next().map_err(|e| {
171            DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
172        })? {
173            let mut to_be_removed = vec![];
174            for col_idx in unknown_cols.iter() {
175                let value_ref = row.get_ref(*col_idx).map_err(|e| {
176                    DataFusionError::Execution(format!(
177                        "Failed to get value ref for column {col_idx}: {e:?}"
178                    ))
179                })?;
180                match value_ref {
181                    ValueRef::Null => {}
182                    ValueRef::Integer(_) => {
183                        remote_field_map.insert(
184                            *col_idx,
185                            RemoteField::new(
186                                columns[*col_idx].name.clone(),
187                                RemoteType::Sqlite(SqliteType::Integer),
188                                true,
189                            ),
190                        );
191                        to_be_removed.push(*col_idx);
192                    }
193                    ValueRef::Real(_) => {
194                        remote_field_map.insert(
195                            *col_idx,
196                            RemoteField::new(
197                                columns[*col_idx].name.clone(),
198                                RemoteType::Sqlite(SqliteType::Real),
199                                true,
200                            ),
201                        );
202                        to_be_removed.push(*col_idx);
203                    }
204                    ValueRef::Text(_) => {
205                        remote_field_map.insert(
206                            *col_idx,
207                            RemoteField::new(
208                                columns[*col_idx].name.clone(),
209                                RemoteType::Sqlite(SqliteType::Text),
210                                true,
211                            ),
212                        );
213                        to_be_removed.push(*col_idx);
214                    }
215                    ValueRef::Blob(_) => {
216                        remote_field_map.insert(
217                            *col_idx,
218                            RemoteField::new(
219                                columns[*col_idx].name.clone(),
220                                RemoteType::Sqlite(SqliteType::Blob),
221                                true,
222                            ),
223                        );
224                        to_be_removed.push(*col_idx);
225                    }
226                }
227            }
228            for col_idx in to_be_removed.iter() {
229                unknown_cols.retain(|&x| x != *col_idx);
230            }
231            if unknown_cols.is_empty() {
232                break;
233            }
234        }
235    }
236
237    if !unknown_cols.is_empty() {
238        return Err(DataFusionError::NotImplemented(format!(
239            "Failed to infer sqlite decl type for columns: {unknown_cols:?}"
240        )));
241    }
242    let remote_fields = remote_field_map
243        .into_iter()
244        .sorted_by_key(|entry| entry.0)
245        .map(|entry| entry.1)
246        .collect::<Vec<_>>();
247    Ok(RemoteSchema::new(remote_fields))
248}
249
250fn rows_to_batch(
251    mut rows: Rows,
252    table_schema: &SchemaRef,
253    columns: Vec<OwnedColumn>,
254    projection: Option<&Vec<usize>>,
255) -> DFResult<RecordBatch> {
256    let projected_schema = project_schema(table_schema, projection)?;
257    let mut array_builders = vec![];
258    for field in table_schema.fields() {
259        let builder = make_builder(field.data_type(), 1000);
260        array_builders.push(builder);
261    }
262
263    while let Some(row) = rows.next().map_err(|e| {
264        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
265    })? {
266        append_rows_to_array_builders(
267            row,
268            table_schema,
269            &columns,
270            projection,
271            array_builders.as_mut_slice(),
272        )?;
273    }
274
275    let projected_columns = array_builders
276        .into_iter()
277        .enumerate()
278        .filter(|(idx, _)| projections_contains(projection, *idx))
279        .map(|(_, mut builder)| builder.finish())
280        .collect::<Vec<ArrayRef>>();
281    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
282}
283
/// Reads an optional `$value_ty` from `$row` at `$index` and appends it (or a
/// null) to `$builder`, which must be a `$builder_ty`.
///
/// Must be used inside a function whose error type accepts
/// `DataFusionError`, because a failed read is propagated with `?`.
/// Panics when `$builder` is not a `$builder_ty`; `$field` and `$col` are only
/// used in the diagnostic messages.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        let fetched: Result<Option<$value_ty>, _> = $row.get($index);
        let maybe_value = fetched.map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        if let Some(value) = maybe_value {
            typed_builder.append_value(value);
        } else {
            typed_builder.append_null();
        }
    }};
}
313
314fn append_rows_to_array_builders(
315    row: &Row,
316    table_schema: &SchemaRef,
317    columns: &[OwnedColumn],
318    projection: Option<&Vec<usize>>,
319    array_builders: &mut [Box<dyn ArrayBuilder>],
320) -> DFResult<()> {
321    for (idx, field) in table_schema.fields.iter().enumerate() {
322        if !projections_contains(projection, idx) {
323            continue;
324        }
325        let builder = &mut array_builders[idx];
326        let col = columns.get(idx);
327        match field.data_type() {
328            DataType::Null => {
329                let builder = builder
330                    .as_any_mut()
331                    .downcast_mut::<NullBuilder>()
332                    .expect("Failed to downcast builder to NullBuilder");
333                builder.append_null();
334            }
335            DataType::Int64 => {
336                handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
337            }
338            DataType::Float64 => {
339                handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
340            }
341            DataType::Utf8 => {
342                handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
343            }
344            DataType::Binary => {
345                handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
346            }
347            _ => {
348                return Err(DataFusionError::NotImplemented(format!(
349                    "Unsupported data type {:?} for col: {:?}",
350                    field.data_type(),
351                    col
352                )));
353            }
354        }
355    }
356    Ok(())
357}