use crate::connection::{RemoteDbType, projections_contains};
use crate::{
    Connection, ConnectionOptions, DFResult, Pool, RemoteField, RemoteSchema, RemoteSchemaRef,
    RemoteType, SqliteType,
};
use datafusion::arrow::array::{
    ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int32Builder, Int64Builder, NullBuilder,
    RecordBatch, StringBuilder, make_builder,
};
use datafusion::arrow::datatypes::{DataType, SchemaRef};
use datafusion::common::{DataFusionError, project_schema};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::prelude::Expr;
use derive_getters::Getters;
use derive_with::With;
use itertools::Itertools;
use rusqlite::types::ValueRef;
use rusqlite::{Column, Row, Rows};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;

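/// Options for opening a SQLite database file; `stream_chunk_size` is the
/// number of rows fetched per page when streaming query results (2048 by default).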
#[derive(Debug, Clone, With, Getters)]
pub struct SqliteConnectionOptions {
    pub path: PathBuf,
    pub stream_chunk_size: usize,
}

impl SqliteConnectionOptions {
    pub fn new(path: PathBuf) -> Self {
        Self {
            path,
            stream_chunk_size: 2048,
        }
    }
}

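/// A [`Pool`] backed by a single `tokio_rusqlite::Connection`; cloning the
/// inner handle shares the same underlying SQLite connection.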
#[derive(Debug)]
pub struct SqlitePool {
    pool: tokio_rusqlite::Connection,
}

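/// Opens the SQLite database at `options.path` and wraps it in a [`SqlitePool`].
///
/// A minimal usage sketch (assuming a Tokio runtime; `example.db` and
/// `some_table` are placeholders):
///
/// ```ignore
/// let options = SqliteConnectionOptions::new(PathBuf::from("example.db"));
/// let pool = connect_sqlite(&options).await?;
/// let conn = pool.get().await?;
/// let schema = conn.infer_schema("SELECT * FROM some_table").await?;
/// ```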
pub async fn connect_sqlite(options: &SqliteConnectionOptions) -> DFResult<SqlitePool> {
    let pool = tokio_rusqlite::Connection::open(&options.path)
        .await
        .map_err(|e| {
            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
        })?;
    Ok(SqlitePool { pool })
}

#[async_trait::async_trait]
impl Pool for SqlitePool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.clone();
        Ok(Arc::new(SqliteConnection { conn }))
    }
}

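/// A single logical connection to the SQLite database, shared through a cloned
/// `tokio_rusqlite::Connection` handle.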
#[derive(Debug)]
pub struct SqliteConnection {
    conn: tokio_rusqlite::Connection,
}

#[async_trait::async_trait]
impl Connection for SqliteConnection {
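    /// Infers the remote schema by preparing the query (restricted via
    /// `RemoteDbType::Sqlite.query_limit_1`), reading each column's declared
    /// type, and falling back to value-based inference for columns without one.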
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
        let sql = RemoteDbType::Sqlite.query_limit_1(sql)?;
        self.conn
            .call(move |conn| {
                let mut stmt = conn.prepare(&sql)?;
                let columns: Vec<OwnedColumn> =
                    stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
                let rows = stmt.query([])?;

                let remote_schema = Arc::new(
                    build_remote_schema(columns.as_slice(), rows)
                        .map_err(|e| tokio_rusqlite::Error::Other(Box::new(e)))?,
                );
                Ok(remote_schema)
            })
            .await
            .map_err(|e| DataFusionError::Execution(format!("Failed to infer schema: {e:?}")))
    }

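    /// Executes the (possibly rewritten) query and streams the results as
    /// Arrow [`RecordBatch`]es, fetching `stream_chunk_size` rows per
    /// `LIMIT`/`OFFSET` page.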
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        let sql = RemoteDbType::Sqlite.try_rewrite_query(sql, filters, limit)?;
        let conn = self.conn.clone();
        let projection = projection.cloned();
        let limit = conn_options.stream_chunk_size();
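        // Page through the rewritten query with LIMIT/OFFSET, yielding one
        // RecordBatch per page until a page comes back empty.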
        let stream = async_stream::stream! {
            let mut offset = 0;
            loop {
                let sql = format!("SELECT * FROM ({sql}) LIMIT {limit} OFFSET {offset}");
                let sql_clone = sql.clone();
                let conn = conn.clone();
                let projection = projection.clone();
                let table_schema = table_schema.clone();
                let (batch, is_empty) = conn
                    .call(move |conn| {
                        let mut stmt = conn.prepare(&sql)?;
                        let columns: Vec<OwnedColumn> =
                            stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
                        let rows = stmt.query([])?;

                        rows_to_batch(rows, &table_schema, columns, projection.as_ref())
                            .map_err(|e| tokio_rusqlite::Error::Other(e.into()))
                    })
                    .await
                    .map_err(|e| {
                        DataFusionError::Execution(format!(
                            "Failed to execute query {sql_clone} on sqlite: {e:?}"
                        ))
                    })?;
                if is_empty {
                    break;
                }
                yield Ok(batch);
                offset += limit;
            }
        };
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}

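/// Owned copy of a `rusqlite::Column`'s name and declared type, so the
/// metadata can outlive the prepared statement.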
#[derive(Debug)]
struct OwnedColumn {
    name: String,
    decl_type: Option<String>,
}

fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
    OwnedColumn {
        name: sqlite_col.name().to_string(),
        decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
    }
}

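/// Maps a lowercased SQLite declared column type (e.g. `varchar(255)`) to a
/// [`SqliteType`], matching by exact name or by prefix.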
fn decl_type_to_remote_type(decl_type: &str) -> DFResult<SqliteType> {
    if ["tinyint", "smallint", "int", "integer", "bigint"].contains(&decl_type) {
        return Ok(SqliteType::Integer);
    }
    if ["float", "double"].contains(&decl_type) || decl_type.starts_with("real") {
        return Ok(SqliteType::Real);
    }
    if decl_type == "string"
        || decl_type.starts_with("char")
        || decl_type.starts_with("varchar")
        || decl_type.starts_with("text")
    {
        return Ok(SqliteType::Text);
    }
    if ["tinyblob", "blob"].contains(&decl_type)
        || decl_type.starts_with("binary")
        || decl_type.starts_with("varbinary")
    {
        return Ok(SqliteType::Blob);
    }
    Err(DataFusionError::NotImplemented(format!(
        "Unsupported sqlite decl type: {decl_type}",
    )))
}

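/// Builds a [`RemoteSchema`] from column metadata. Columns without a declared
/// type are inferred by scanning rows until a non-NULL value is seen; columns
/// that remain unknown produce a `NotImplemented` error.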
fn build_remote_schema(columns: &[OwnedColumn], mut rows: Rows) -> DFResult<RemoteSchema> {
    let mut remote_field_map = HashMap::with_capacity(columns.len());
    let mut unknown_cols = vec![];
    for (col_idx, col) in columns.iter().enumerate() {
        if let Some(decl_type) = &col.decl_type {
            let remote_type =
                RemoteType::Sqlite(decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?);
            remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
        } else {
            unknown_cols.push(col_idx);
        }
    }

    if !unknown_cols.is_empty() {
        while let Some(row) = rows.next().map_err(|e| {
            DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
        })? {
            let mut to_be_removed = vec![];
            for col_idx in unknown_cols.iter() {
                let value_ref = row.get_ref(*col_idx).map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to get value ref for column {col_idx}: {e:?}"
                    ))
                })?;
                let sqlite_type = match value_ref {
                    ValueRef::Null => continue,
                    ValueRef::Integer(_) => SqliteType::Integer,
                    ValueRef::Real(_) => SqliteType::Real,
                    ValueRef::Text(_) => SqliteType::Text,
                    ValueRef::Blob(_) => SqliteType::Blob,
                };
                remote_field_map.insert(
                    *col_idx,
                    RemoteField::new(
                        columns[*col_idx].name.clone(),
                        RemoteType::Sqlite(sqlite_type),
                        true,
                    ),
                );
                to_be_removed.push(*col_idx);
            }
            for col_idx in to_be_removed.iter() {
                unknown_cols.retain(|&x| x != *col_idx);
            }
            if unknown_cols.is_empty() {
                break;
            }
        }
    }

    if !unknown_cols.is_empty() {
        return Err(DataFusionError::NotImplemented(format!(
            "Failed to infer sqlite decl type for columns: {unknown_cols:?}"
        )));
    }
    let remote_fields = remote_field_map
        .into_iter()
        .sorted_by_key(|entry| entry.0)
        .map(|entry| entry.1)
        .collect::<Vec<_>>();
    Ok(RemoteSchema::new(remote_fields))
}

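/// Drains all remaining rows into Arrow builders and returns the projected
/// [`RecordBatch`] together with a flag indicating whether the result set was empty.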
fn rows_to_batch(
    mut rows: Rows,
    table_schema: &SchemaRef,
    columns: Vec<OwnedColumn>,
    projection: Option<&Vec<usize>>,
) -> DFResult<(RecordBatch, bool)> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), 1000);
        array_builders.push(builder);
    }

    let mut is_empty = true;
    while let Some(row) = rows.next().map_err(|e| {
        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
    })? {
        is_empty = false;
        append_rows_to_array_builders(
            row,
            table_schema,
            &columns,
            projection,
            array_builders.as_mut_slice(),
        )?;
    }

    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok((
        RecordBatch::try_new(projected_schema, projected_columns)?,
        is_empty,
    ))
}

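/// Appends one optional `$value_ty` value from the current row to the
/// downcast `$builder_ty`, appending a NULL when the column value is NULL.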
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });

        let v: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}

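/// Appends the projected columns of a single SQLite row to the corresponding
/// Arrow array builders, dispatching on each field's Arrow data type.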
fn append_rows_to_array_builders(
    row: &Row,
    table_schema: &SchemaRef,
    columns: &[OwnedColumn],
    projection: Option<&Vec<usize>>,
    array_builders: &mut [Box<dyn ArrayBuilder>],
) -> DFResult<()> {
    for (idx, field) in table_schema.fields.iter().enumerate() {
        if !projections_contains(projection, idx) {
            continue;
        }
        let builder = &mut array_builders[idx];
        let col = columns.get(idx);
        match field.data_type() {
            DataType::Null => {
                let builder = builder
                    .as_any_mut()
                    .downcast_mut::<NullBuilder>()
                    .expect("Failed to downcast builder to NullBuilder");
                builder.append_null();
            }
            DataType::Int32 => {
                handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx);
            }
            DataType::Int64 => {
                handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
            }
            DataType::Float64 => {
                handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
            }
            DataType::Utf8 => {
                handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
            }
            DataType::Binary => {
                handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
            }
            _ => {
                return Err(DataFusionError::NotImplemented(format!(
                    "Unsupported data type {:?} for col: {:?}",
                    field.data_type(),
                    col
                )));
            }
        }
    }
    Ok(())
}