datafusion_remote_table/connection/
sqlite.rs

1use crate::connection::{RemoteDbType, projections_contains};
2use crate::{
3    Connection, ConnectionOptions, DFResult, Pool, RemoteField, RemoteSchema, RemoteSchemaRef,
4    RemoteType, SqliteType,
5};
6use datafusion::arrow::array::{
7    ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int32Builder, Int64Builder, NullBuilder,
8    RecordBatch, RecordBatchOptions, StringBuilder, make_builder,
9};
10use datafusion::arrow::datatypes::{DataType, SchemaRef};
11use datafusion::common::{DataFusionError, project_schema};
12use datafusion::execution::SendableRecordBatchStream;
13use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
14use derive_getters::Getters;
15use derive_with::With;
16use itertools::Itertools;
17use log::debug;
18use rusqlite::types::ValueRef;
19use rusqlite::{Column, Row, Rows};
20use std::collections::HashMap;
21use std::path::PathBuf;
22use std::sync::Arc;
23
24#[derive(Debug, Clone, With, Getters)]
25pub struct SqliteConnectionOptions {
26    pub path: PathBuf,
27    pub stream_chunk_size: usize,
28}
29
30impl SqliteConnectionOptions {
31    pub fn new(path: PathBuf) -> Self {
32        Self {
33            path,
34            stream_chunk_size: 2048,
35        }
36    }
37}
38
39#[derive(Debug)]
40pub struct SqlitePool {
41    pool: tokio_rusqlite::Connection,
42}
43
44pub async fn connect_sqlite(options: &SqliteConnectionOptions) -> DFResult<SqlitePool> {
45    let pool = tokio_rusqlite::Connection::open(&options.path)
46        .await
47        .map_err(|e| {
48            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
49        })?;
50    Ok(SqlitePool { pool })
51}
52
53#[async_trait::async_trait]
54impl Pool for SqlitePool {
55    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
56        let conn = self.pool.clone();
57        Ok(Arc::new(SqliteConnection { conn }))
58    }
59}
60
61#[derive(Debug)]
62pub struct SqliteConnection {
63    conn: tokio_rusqlite::Connection,
64}
65
66#[async_trait::async_trait]
67impl Connection for SqliteConnection {
68    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
69        let sql = RemoteDbType::Sqlite.query_limit_1(sql);
70        self.conn
71            .call(move |conn| {
72                let mut stmt = conn.prepare(&sql)?;
73                let columns: Vec<OwnedColumn> =
74                    stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
75                let rows = stmt.query([])?;
76
77                let remote_schema = Arc::new(
78                    build_remote_schema(columns.as_slice(), rows)
79                        .map_err(|e| tokio_rusqlite::Error::Other(Box::new(e)))?,
80                );
81                Ok(remote_schema)
82            })
83            .await
84            .map_err(|e| DataFusionError::Execution(format!("Failed to infer schema: {e:?}")))
85    }
86
87    async fn query(
88        &self,
89        conn_options: &ConnectionOptions,
90        sql: &str,
91        table_schema: SchemaRef,
92        projection: Option<&Vec<usize>>,
93        unparsed_filters: &[String],
94        limit: Option<usize>,
95    ) -> DFResult<SendableRecordBatchStream> {
96        let projected_schema = project_schema(&table_schema, projection)?;
97        let sql = RemoteDbType::Sqlite.rewrite_query(sql, unparsed_filters, limit);
98        debug!("[remote-table] executing sqlite query: {sql}");
99
100        let conn = self.conn.clone();
101        let projection = projection.cloned();
102        let limit = conn_options.stream_chunk_size();
103        let stream = async_stream::stream! {
104            let mut offset = 0;
105            loop {
106                let sql = format!("SELECT * FROM ({sql}) LIMIT {limit} OFFSET {offset}");
107                let sql_clone = sql.clone();
108                let conn = conn.clone();
109                let projection = projection.clone();
110                let table_schema = table_schema.clone();
111                let (batch, is_empty) = conn
112                    .call(move |conn| {
113                        let mut stmt = conn.prepare(&sql)?;
114                        let columns: Vec<OwnedColumn> =
115                            stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
116                        let rows = stmt.query([])?;
117
118                        rows_to_batch(rows, &table_schema, columns, projection.as_ref())
119                            .map_err(|e| tokio_rusqlite::Error::Other(e.into()))
120                })
121                .await
122                .map_err(|e| {
123                    DataFusionError::Execution(format!(
124                        "Failed to execute query {sql_clone} on sqlite: {e:?}"
125                    ))
126                })?;
127                if is_empty {
128                    break;
129                }
130                yield Ok(batch);
131                offset += limit;
132            }
133        };
134        Ok(Box::pin(RecordBatchStreamAdapter::new(
135            projected_schema,
136            stream,
137        )))
138    }
139}
140
/// Owned copy of the column metadata read from a prepared statement.
#[derive(Debug)]
struct OwnedColumn {
    /// Column name as reported by the statement.
    name: String,
    /// Declared type of the column, when the statement reports one.
    decl_type: Option<String>,
}
146
147fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
148    OwnedColumn {
149        name: sqlite_col.name().to_string(),
150        decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
151    }
152}
153
154fn decl_type_to_remote_type(decl_type: &str) -> DFResult<SqliteType> {
155    if ["tinyint", "smallint", "int", "integer", "bigint"].contains(&decl_type) {
156        return Ok(SqliteType::Integer);
157    }
158    if ["real", "float", "double"].contains(&decl_type) {
159        return Ok(SqliteType::Real);
160    }
161    if decl_type.starts_with("real") {
162        return Ok(SqliteType::Real);
163    }
164    if ["text", "varchar", "char", "string"].contains(&decl_type) {
165        return Ok(SqliteType::Text);
166    }
167    if decl_type.starts_with("char")
168        || decl_type.starts_with("varchar")
169        || decl_type.starts_with("text")
170    {
171        return Ok(SqliteType::Text);
172    }
173    if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
174        return Ok(SqliteType::Blob);
175    }
176    if decl_type.starts_with("binary") || decl_type.starts_with("varbinary") {
177        return Ok(SqliteType::Blob);
178    }
179    Err(DataFusionError::NotImplemented(format!(
180        "Unsupported sqlite decl type: {decl_type}",
181    )))
182}
183
184fn build_remote_schema(columns: &[OwnedColumn], mut rows: Rows) -> DFResult<RemoteSchema> {
185    let mut remote_field_map = HashMap::with_capacity(columns.len());
186    let mut unknown_cols = vec![];
187    for (col_idx, col) in columns.iter().enumerate() {
188        if let Some(decl_type) = &col.decl_type {
189            let remote_type =
190                RemoteType::Sqlite(decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?);
191            remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
192        } else {
193            unknown_cols.push(col_idx);
194        }
195    }
196
197    if !unknown_cols.is_empty() {
198        while let Some(row) = rows.next().map_err(|e| {
199            DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
200        })? {
201            let mut to_be_removed = vec![];
202            for col_idx in unknown_cols.iter() {
203                let value_ref = row.get_ref(*col_idx).map_err(|e| {
204                    DataFusionError::Execution(format!(
205                        "Failed to get value ref for column {col_idx}: {e:?}"
206                    ))
207                })?;
208                match value_ref {
209                    ValueRef::Null => {}
210                    ValueRef::Integer(_) => {
211                        remote_field_map.insert(
212                            *col_idx,
213                            RemoteField::new(
214                                columns[*col_idx].name.clone(),
215                                RemoteType::Sqlite(SqliteType::Integer),
216                                true,
217                            ),
218                        );
219                        to_be_removed.push(*col_idx);
220                    }
221                    ValueRef::Real(_) => {
222                        remote_field_map.insert(
223                            *col_idx,
224                            RemoteField::new(
225                                columns[*col_idx].name.clone(),
226                                RemoteType::Sqlite(SqliteType::Real),
227                                true,
228                            ),
229                        );
230                        to_be_removed.push(*col_idx);
231                    }
232                    ValueRef::Text(_) => {
233                        remote_field_map.insert(
234                            *col_idx,
235                            RemoteField::new(
236                                columns[*col_idx].name.clone(),
237                                RemoteType::Sqlite(SqliteType::Text),
238                                true,
239                            ),
240                        );
241                        to_be_removed.push(*col_idx);
242                    }
243                    ValueRef::Blob(_) => {
244                        remote_field_map.insert(
245                            *col_idx,
246                            RemoteField::new(
247                                columns[*col_idx].name.clone(),
248                                RemoteType::Sqlite(SqliteType::Blob),
249                                true,
250                            ),
251                        );
252                        to_be_removed.push(*col_idx);
253                    }
254                }
255            }
256            for col_idx in to_be_removed.iter() {
257                unknown_cols.retain(|&x| x != *col_idx);
258            }
259            if unknown_cols.is_empty() {
260                break;
261            }
262        }
263    }
264
265    if !unknown_cols.is_empty() {
266        return Err(DataFusionError::NotImplemented(format!(
267            "Failed to infer sqlite decl type for columns: {unknown_cols:?}"
268        )));
269    }
270    let remote_fields = remote_field_map
271        .into_iter()
272        .sorted_by_key(|entry| entry.0)
273        .map(|entry| entry.1)
274        .collect::<Vec<_>>();
275    Ok(RemoteSchema::new(remote_fields))
276}
277
278fn rows_to_batch(
279    mut rows: Rows,
280    table_schema: &SchemaRef,
281    columns: Vec<OwnedColumn>,
282    projection: Option<&Vec<usize>>,
283) -> DFResult<(RecordBatch, bool)> {
284    let projected_schema = project_schema(table_schema, projection)?;
285    let mut array_builders = vec![];
286    for field in table_schema.fields() {
287        let builder = make_builder(field.data_type(), 1000);
288        array_builders.push(builder);
289    }
290
291    let mut is_empty = true;
292    let mut row_count = 0;
293    while let Some(row) = rows.next().map_err(|e| {
294        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
295    })? {
296        is_empty = false;
297        row_count += 1;
298        append_rows_to_array_builders(
299            row,
300            table_schema,
301            &columns,
302            projection,
303            array_builders.as_mut_slice(),
304        )?;
305    }
306
307    let projected_columns = array_builders
308        .into_iter()
309        .enumerate()
310        .filter(|(idx, _)| projections_contains(projection, *idx))
311        .map(|(_, mut builder)| builder.finish())
312        .collect::<Vec<ArrayRef>>();
313    let options = RecordBatchOptions::new().with_row_count(Some(row_count));
314    Ok((
315        RecordBatch::try_new_with_options(projected_schema, projected_columns, &options)?,
316        is_empty,
317    ))
318}
319
/// Reads the value at `$index` from `$row` as `Option<$value_ty>` and appends
/// it (or a null) to `$builder`, which is downcast to `$builder_ty` first.
///
/// `$field` and `$col` are only used in diagnostics. A failed downcast
/// panics; a failed read makes the enclosing function return a
/// `DataFusionError::Execution` through `?`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        let value = $row
            .get::<_, Option<$value_ty>>($index)
            .map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                    stringify!($value_ty),
                    $field,
                    $col
                ))
            })?;

        // `append_option` appends the value when present and a null otherwise.
        typed_builder.append_option(value)
    }};
}
349
350fn append_rows_to_array_builders(
351    row: &Row,
352    table_schema: &SchemaRef,
353    columns: &[OwnedColumn],
354    projection: Option<&Vec<usize>>,
355    array_builders: &mut [Box<dyn ArrayBuilder>],
356) -> DFResult<()> {
357    for (idx, field) in table_schema.fields.iter().enumerate() {
358        if !projections_contains(projection, idx) {
359            continue;
360        }
361        let builder = &mut array_builders[idx];
362        let col = columns.get(idx);
363        match field.data_type() {
364            DataType::Null => {
365                let builder = builder
366                    .as_any_mut()
367                    .downcast_mut::<NullBuilder>()
368                    .expect("Failed to downcast builder to NullBuilder");
369                builder.append_null();
370            }
371            DataType::Int32 => {
372                handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx);
373            }
374            DataType::Int64 => {
375                handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
376            }
377            DataType::Float64 => {
378                handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
379            }
380            DataType::Utf8 => {
381                handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
382            }
383            DataType::Binary => {
384                handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
385            }
386            _ => {
387                return Err(DataFusionError::NotImplemented(format!(
388                    "Unsupported data type {:?} for col: {:?}",
389                    field.data_type(),
390                    col
391                )));
392            }
393        }
394    }
395    Ok(())
396}