datafusion_remote_table/connection/
sqlite.rs

1use crate::connection::{RemoteDbType, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, Pool, RemoteField, RemoteSchema, RemoteSchemaRef,
4    RemoteType, SqliteType,
5};
6use datafusion::arrow::array::{
7    ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int32Builder, Int64Builder, NullBuilder,
8    RecordBatch, StringBuilder, make_builder,
9};
10use datafusion::arrow::datatypes::{DataType, SchemaRef};
11use datafusion::common::{DataFusionError, project_schema};
12use datafusion::execution::SendableRecordBatchStream;
13use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
14use derive_getters::Getters;
15use derive_with::With;
16use itertools::Itertools;
17use rusqlite::types::ValueRef;
18use rusqlite::{Column, Row, Rows};
19use std::collections::HashMap;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23#[derive(Debug, Clone, With, Getters)]
24pub struct SqliteConnectionOptions {
25    pub path: PathBuf,
26    pub stream_chunk_size: usize,
27}
28
29impl SqliteConnectionOptions {
30    pub fn new(path: PathBuf) -> Self {
31        Self {
32            path,
33            stream_chunk_size: 2048,
34        }
35    }
36}
37
38#[derive(Debug)]
39pub struct SqlitePool {
40    pool: tokio_rusqlite::Connection,
41}
42
43pub async fn connect_sqlite(options: &SqliteConnectionOptions) -> DFResult<SqlitePool> {
44    let pool = tokio_rusqlite::Connection::open(&options.path)
45        .await
46        .map_err(|e| {
47            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
48        })?;
49    Ok(SqlitePool { pool })
50}
51
52#[async_trait::async_trait]
53impl Pool for SqlitePool {
54    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
55        let conn = self.pool.clone();
56        Ok(Arc::new(SqliteConnection { conn }))
57    }
58}
59
60#[derive(Debug)]
61pub struct SqliteConnection {
62    conn: tokio_rusqlite::Connection,
63}
64
65#[async_trait::async_trait]
66impl Connection for SqliteConnection {
67    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
68        let sql = RemoteDbType::Sqlite.query_limit_1(sql)?;
69        self.conn
70            .call(move |conn| {
71                let mut stmt = conn.prepare(&sql)?;
72                let columns: Vec<OwnedColumn> =
73                    stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
74                let rows = stmt.query([])?;
75
76                let remote_schema = Arc::new(
77                    build_remote_schema(columns.as_slice(), rows)
78                        .map_err(|e| tokio_rusqlite::Error::Other(Box::new(e)))?,
79                );
80                Ok(remote_schema)
81            })
82            .await
83            .map_err(|e| DataFusionError::Execution(format!("Failed to infer schema: {e:?}")))
84    }
85
86    async fn query(
87        &self,
88        conn_options: &ConnectionOptions,
89        sql: &str,
90        table_schema: SchemaRef,
91        projection: Option<&Vec<usize>>,
92        unparsed_filters: &[String],
93        limit: Option<usize>,
94    ) -> DFResult<SendableRecordBatchStream> {
95        let projected_schema = project_schema(&table_schema, projection)?;
96        let sql = RemoteDbType::Sqlite.try_rewrite_query(sql, unparsed_filters, limit)?;
97        let conn = self.conn.clone();
98        let projection = projection.cloned();
99        let limit = conn_options.stream_chunk_size();
100        let stream = async_stream::stream! {
101            let mut offset = 0;
102            loop {
103                let sql = format!("SELECT * FROM ({sql}) LIMIT {limit} OFFSET {offset}");
104                let sql_clone = sql.clone();
105                let conn = conn.clone();
106                let projection = projection.clone();
107                let table_schema = table_schema.clone();
108                let (batch, is_empty) = conn
109                    .call(move |conn| {
110                        let mut stmt = conn.prepare(&sql)?;
111                        let columns: Vec<OwnedColumn> =
112                            stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
113                        let rows = stmt.query([])?;
114
115                        rows_to_batch(rows, &table_schema, columns, projection.as_ref())
116                            .map_err(|e| tokio_rusqlite::Error::Other(e.into()))
117                })
118                .await
119                .map_err(|e| {
120                    DataFusionError::Execution(format!(
121                        "Failed to execute query {sql_clone} on sqlite: {e:?}"
122                    ))
123                })?;
124                if is_empty {
125                    break;
126                }
127                yield Ok(batch);
128                offset += limit;
129            }
130        };
131        Ok(Box::pin(RecordBatchStreamAdapter::new(
132            projected_schema,
133            stream,
134        )))
135    }
136}
137
/// Owned copy of the column metadata this module needs from a prepared
/// statement: the column name and, when available, its declared type.
#[derive(Debug)]
struct OwnedColumn {
    /// Column name as reported by the statement.
    name: String,
    /// Declared type of the column; `None` when the statement reports none.
    decl_type: Option<String>,
}
143
144fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
145    OwnedColumn {
146        name: sqlite_col.name().to_string(),
147        decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
148    }
149}
150
151fn decl_type_to_remote_type(decl_type: &str) -> DFResult<SqliteType> {
152    if ["tinyint", "smallint", "int", "integer", "bigint"].contains(&decl_type) {
153        return Ok(SqliteType::Integer);
154    }
155    if ["real", "float", "double"].contains(&decl_type) {
156        return Ok(SqliteType::Real);
157    }
158    if decl_type.starts_with("real") {
159        return Ok(SqliteType::Real);
160    }
161    if ["text", "varchar", "char", "string"].contains(&decl_type) {
162        return Ok(SqliteType::Text);
163    }
164    if decl_type.starts_with("char")
165        || decl_type.starts_with("varchar")
166        || decl_type.starts_with("text")
167    {
168        return Ok(SqliteType::Text);
169    }
170    if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
171        return Ok(SqliteType::Blob);
172    }
173    if decl_type.starts_with("binary") || decl_type.starts_with("varbinary") {
174        return Ok(SqliteType::Blob);
175    }
176    Err(DataFusionError::NotImplemented(format!(
177        "Unsupported sqlite decl type: {decl_type}",
178    )))
179}
180
181fn build_remote_schema(columns: &[OwnedColumn], mut rows: Rows) -> DFResult<RemoteSchema> {
182    let mut remote_field_map = HashMap::with_capacity(columns.len());
183    let mut unknown_cols = vec![];
184    for (col_idx, col) in columns.iter().enumerate() {
185        if let Some(decl_type) = &col.decl_type {
186            let remote_type =
187                RemoteType::Sqlite(decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?);
188            remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
189        } else {
190            unknown_cols.push(col_idx);
191        }
192    }
193
194    if !unknown_cols.is_empty() {
195        while let Some(row) = rows.next().map_err(|e| {
196            DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
197        })? {
198            let mut to_be_removed = vec![];
199            for col_idx in unknown_cols.iter() {
200                let value_ref = row.get_ref(*col_idx).map_err(|e| {
201                    DataFusionError::Execution(format!(
202                        "Failed to get value ref for column {col_idx}: {e:?}"
203                    ))
204                })?;
205                match value_ref {
206                    ValueRef::Null => {}
207                    ValueRef::Integer(_) => {
208                        remote_field_map.insert(
209                            *col_idx,
210                            RemoteField::new(
211                                columns[*col_idx].name.clone(),
212                                RemoteType::Sqlite(SqliteType::Integer),
213                                true,
214                            ),
215                        );
216                        to_be_removed.push(*col_idx);
217                    }
218                    ValueRef::Real(_) => {
219                        remote_field_map.insert(
220                            *col_idx,
221                            RemoteField::new(
222                                columns[*col_idx].name.clone(),
223                                RemoteType::Sqlite(SqliteType::Real),
224                                true,
225                            ),
226                        );
227                        to_be_removed.push(*col_idx);
228                    }
229                    ValueRef::Text(_) => {
230                        remote_field_map.insert(
231                            *col_idx,
232                            RemoteField::new(
233                                columns[*col_idx].name.clone(),
234                                RemoteType::Sqlite(SqliteType::Text),
235                                true,
236                            ),
237                        );
238                        to_be_removed.push(*col_idx);
239                    }
240                    ValueRef::Blob(_) => {
241                        remote_field_map.insert(
242                            *col_idx,
243                            RemoteField::new(
244                                columns[*col_idx].name.clone(),
245                                RemoteType::Sqlite(SqliteType::Blob),
246                                true,
247                            ),
248                        );
249                        to_be_removed.push(*col_idx);
250                    }
251                }
252            }
253            for col_idx in to_be_removed.iter() {
254                unknown_cols.retain(|&x| x != *col_idx);
255            }
256            if unknown_cols.is_empty() {
257                break;
258            }
259        }
260    }
261
262    if !unknown_cols.is_empty() {
263        return Err(DataFusionError::NotImplemented(format!(
264            "Failed to infer sqlite decl type for columns: {unknown_cols:?}"
265        )));
266    }
267    let remote_fields = remote_field_map
268        .into_iter()
269        .sorted_by_key(|entry| entry.0)
270        .map(|entry| entry.1)
271        .collect::<Vec<_>>();
272    Ok(RemoteSchema::new(remote_fields))
273}
274
275fn rows_to_batch(
276    mut rows: Rows,
277    table_schema: &SchemaRef,
278    columns: Vec<OwnedColumn>,
279    projection: Option<&Vec<usize>>,
280) -> DFResult<(RecordBatch, bool)> {
281    let projected_schema = project_schema(table_schema, projection)?;
282    let mut array_builders = vec![];
283    for field in table_schema.fields() {
284        let builder = make_builder(field.data_type(), 1000);
285        array_builders.push(builder);
286    }
287
288    let mut is_empty = true;
289    while let Some(row) = rows.next().map_err(|e| {
290        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
291    })? {
292        is_empty = false;
293        append_rows_to_array_builders(
294            row,
295            table_schema,
296            &columns,
297            projection,
298            array_builders.as_mut_slice(),
299        )?;
300    }
301
302    let projected_columns = array_builders
303        .into_iter()
304        .enumerate()
305        .filter(|(idx, _)| projections_contains(projection, *idx))
306        .map(|(_, mut builder)| builder.finish())
307        .collect::<Vec<ArrayRef>>();
308    Ok((
309        RecordBatch::try_new(projected_schema, projected_columns)?,
310        is_empty,
311    ))
312}
313
/// Reads column `$index` of `$row` as `Option<$value_ty>` and appends it to
/// `$builder`, which must be a `$builder_ty`.
///
/// `$field` and `$col` are only used to give context in error and panic
/// messages. The macro uses `?`, so it must be expanded inside a function
/// whose error type is `DataFusionError`.
///
/// Panics if `$builder` is not actually a `$builder_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        let value: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        if let Some(value) = value {
            typed_builder.append_value(value)
        } else {
            typed_builder.append_null()
        }
    }};
}
343
344fn append_rows_to_array_builders(
345    row: &Row,
346    table_schema: &SchemaRef,
347    columns: &[OwnedColumn],
348    projection: Option<&Vec<usize>>,
349    array_builders: &mut [Box<dyn ArrayBuilder>],
350) -> DFResult<()> {
351    for (idx, field) in table_schema.fields.iter().enumerate() {
352        if !projections_contains(projection, idx) {
353            continue;
354        }
355        let builder = &mut array_builders[idx];
356        let col = columns.get(idx);
357        match field.data_type() {
358            DataType::Null => {
359                let builder = builder
360                    .as_any_mut()
361                    .downcast_mut::<NullBuilder>()
362                    .expect("Failed to downcast builder to NullBuilder");
363                builder.append_null();
364            }
365            DataType::Int32 => {
366                handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx);
367            }
368            DataType::Int64 => {
369                handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
370            }
371            DataType::Float64 => {
372                handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
373            }
374            DataType::Utf8 => {
375                handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
376            }
377            DataType::Binary => {
378                handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
379            }
380            _ => {
381                return Err(DataFusionError::NotImplemented(format!(
382                    "Unsupported data type {:?} for col: {:?}",
383                    field.data_type(),
384                    col
385                )));
386            }
387        }
388    }
389    Ok(())
390}