1use crate::connection::{RemoteDbType, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, Pool, RemoteField, RemoteSchema, RemoteSchemaRef,
4 RemoteType, SqliteType,
5};
6use datafusion::arrow::array::{
7 ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int32Builder, Int64Builder, NullBuilder,
8 RecordBatch, RecordBatchOptions, StringBuilder, make_builder,
9};
10use datafusion::arrow::datatypes::{DataType, SchemaRef};
11use datafusion::common::{DataFusionError, project_schema};
12use datafusion::execution::SendableRecordBatchStream;
13use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
14use derive_getters::Getters;
15use derive_with::With;
16use itertools::Itertools;
17use log::{debug, error};
18use rusqlite::types::ValueRef;
19use rusqlite::{Column, Row, Rows};
20use std::any::Any;
21use std::collections::HashMap;
22use std::path::PathBuf;
23use std::sync::Arc;
24
/// Options controlling how a sqlite database is opened and streamed.
#[derive(Debug, Clone, With, Getters)]
pub struct SqliteConnectionOptions {
    /// Filesystem path of the sqlite database file.
    pub path: PathBuf,
    /// Maximum number of rows collected into each streamed `RecordBatch`.
    pub stream_chunk_size: usize,
}
30
31impl SqliteConnectionOptions {
32 pub fn new(path: PathBuf) -> Self {
33 Self {
34 path,
35 stream_chunk_size: 2048,
36 }
37 }
38}
39
impl From<SqliteConnectionOptions> for ConnectionOptions {
    /// Wraps sqlite-specific options in the backend-agnostic options enum.
    fn from(options: SqliteConnectionOptions) -> Self {
        ConnectionOptions::Sqlite(options)
    }
}
45
/// A `Pool` implementation for sqlite. Only the database path is stored;
/// actual rusqlite connections are opened on demand by `SqliteConnection`.
#[derive(Debug)]
pub struct SqlitePool {
    // Path to the sqlite database file shared by all connections from this pool.
    path: PathBuf,
}
50
51pub async fn connect_sqlite(options: &SqliteConnectionOptions) -> DFResult<SqlitePool> {
52 let _ = rusqlite::Connection::open(&options.path).map_err(|e| {
53 DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
54 })?;
55 Ok(SqlitePool {
56 path: options.path.clone(),
57 })
58}
59
#[async_trait::async_trait]
impl Pool for SqlitePool {
    /// Hands out a `SqliteConnection` referencing the pooled database path.
    /// No sqlite handle is opened here; connections open lazily per operation.
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        Ok(Arc::new(SqliteConnection {
            path: self.path.clone(),
        }))
    }
}
68
/// A logical sqlite "connection": holds the database path and opens a fresh
/// rusqlite connection for each schema-inference or query operation.
#[derive(Debug)]
pub struct SqliteConnection {
    // Path to the sqlite database file.
    path: PathBuf,
}
73
#[async_trait::async_trait]
impl Connection for SqliteConnection {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Derives the remote schema for `sql` by preparing it and inspecting the
    /// statement's column metadata, then scanning result rows for any columns
    /// that lack a declared type (see `build_remote_schema`).
    ///
    /// NOTE(review): the rusqlite calls below are synchronous and run directly
    /// on the async executor thread. The query is first wrapped via
    /// `query_limit_1` (presumably restricting it to a single row — confirm in
    /// `RemoteDbType`), which should keep the blocking window short.
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
        let sql = RemoteDbType::Sqlite.query_limit_1(sql);
        let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
        })?;
        let mut stmt = conn.prepare(&sql).map_err(|e| {
            DataFusionError::Execution(format!("Failed to prepare sqlite statement: {e:?}"))
        })?;
        // Copy column metadata out of the statement before `query` takes a
        // mutable borrow of `stmt`.
        let columns: Vec<OwnedColumn> =
            stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
        let rows = stmt.query([]).map_err(|e| {
            DataFusionError::Execution(format!("Failed to query sqlite statement: {e:?}"))
        })?;

        let remote_schema = Arc::new(build_remote_schema(columns.as_slice(), rows)?);
        Ok(remote_schema)
    }

    /// Executes `sql` (rewritten with the pushed-down filters and limit) and
    /// returns a stream of record batches in the projected schema.
    ///
    /// The query itself runs on a dedicated background thread (see
    /// `spawn_background_task`); batches arrive over a bounded channel of
    /// capacity 1, so the producer is back-pressured by the stream consumer.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        let sql = RemoteDbType::Sqlite.rewrite_query(sql, unparsed_filters, limit);
        debug!("[remote-table] executing sqlite query: {sql}");

        let (tx, mut rx) = tokio::sync::mpsc::channel::<DFResult<RecordBatch>>(1);
        let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
        })?;

        // Clone what the background thread needs so it owns its inputs.
        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();

        spawn_background_task(tx, conn, sql, table_schema, projection, chunk_size);

        // Adapt the channel into a RecordBatchStream; the stream ends when the
        // background task drops `tx` (completion or error already sent).
        let stream = async_stream::stream! {
            while let Some(batch) = rx.recv().await {
                yield batch;
            }
        };
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}
132
/// Owned snapshot of a `rusqlite::Column` (name + declared type), detached
/// from the statement's borrow so it can outlive / travel with the query.
#[derive(Debug)]
struct OwnedColumn {
    // Column name as reported by sqlite.
    name: String,
    // Declared column type, if any; `None` for expression columns etc.
    decl_type: Option<String>,
}
138
139fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
140 OwnedColumn {
141 name: sqlite_col.name().to_string(),
142 decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
143 }
144}
145
146fn decl_type_to_remote_type(decl_type: &str) -> DFResult<SqliteType> {
147 if [
148 "tinyint", "smallint", "int", "integer", "bigint", "int2", "int4", "int8",
149 ]
150 .contains(&decl_type)
151 {
152 return Ok(SqliteType::Integer);
153 }
154 if ["real", "float", "double", "numeric"].contains(&decl_type) {
155 return Ok(SqliteType::Real);
156 }
157 if decl_type.starts_with("real") || decl_type.starts_with("numeric") {
158 return Ok(SqliteType::Real);
159 }
160 if ["text", "varchar", "char", "string"].contains(&decl_type) {
161 return Ok(SqliteType::Text);
162 }
163 if decl_type.starts_with("char")
164 || decl_type.starts_with("varchar")
165 || decl_type.starts_with("text")
166 {
167 return Ok(SqliteType::Text);
168 }
169 if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
170 return Ok(SqliteType::Blob);
171 }
172 if decl_type.starts_with("binary") || decl_type.starts_with("varbinary") {
173 return Ok(SqliteType::Blob);
174 }
175 Err(DataFusionError::NotImplemented(format!(
176 "Unsupported sqlite decl type: {decl_type}",
177 )))
178}
179
180fn build_remote_schema(columns: &[OwnedColumn], mut rows: Rows) -> DFResult<RemoteSchema> {
181 let mut remote_field_map = HashMap::with_capacity(columns.len());
182 let mut unknown_cols = vec![];
183 for (col_idx, col) in columns.iter().enumerate() {
184 if let Some(decl_type) = &col.decl_type {
185 let remote_type =
186 RemoteType::Sqlite(decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?);
187 remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
188 } else {
189 unknown_cols.push(col_idx);
191 }
192 }
193
194 if !unknown_cols.is_empty() {
195 while let Some(row) = rows.next().map_err(|e| {
196 DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
197 })? {
198 let mut to_be_removed = vec![];
199 for col_idx in unknown_cols.iter() {
200 let value_ref = row.get_ref(*col_idx).map_err(|e| {
201 DataFusionError::Execution(format!(
202 "Failed to get value ref for column {col_idx}: {e:?}"
203 ))
204 })?;
205 match value_ref {
206 ValueRef::Null => {}
207 ValueRef::Integer(_) => {
208 remote_field_map.insert(
209 *col_idx,
210 RemoteField::new(
211 columns[*col_idx].name.clone(),
212 RemoteType::Sqlite(SqliteType::Integer),
213 true,
214 ),
215 );
216 to_be_removed.push(*col_idx);
217 }
218 ValueRef::Real(_) => {
219 remote_field_map.insert(
220 *col_idx,
221 RemoteField::new(
222 columns[*col_idx].name.clone(),
223 RemoteType::Sqlite(SqliteType::Real),
224 true,
225 ),
226 );
227 to_be_removed.push(*col_idx);
228 }
229 ValueRef::Text(_) => {
230 remote_field_map.insert(
231 *col_idx,
232 RemoteField::new(
233 columns[*col_idx].name.clone(),
234 RemoteType::Sqlite(SqliteType::Text),
235 true,
236 ),
237 );
238 to_be_removed.push(*col_idx);
239 }
240 ValueRef::Blob(_) => {
241 remote_field_map.insert(
242 *col_idx,
243 RemoteField::new(
244 columns[*col_idx].name.clone(),
245 RemoteType::Sqlite(SqliteType::Blob),
246 true,
247 ),
248 );
249 to_be_removed.push(*col_idx);
250 }
251 }
252 }
253 for col_idx in to_be_removed.iter() {
254 unknown_cols.retain(|&x| x != *col_idx);
255 }
256 if unknown_cols.is_empty() {
257 break;
258 }
259 }
260 }
261
262 if !unknown_cols.is_empty() {
263 return Err(DataFusionError::NotImplemented(format!(
264 "Failed to infer sqlite decl type for columns: {unknown_cols:?}"
265 )));
266 }
267 let remote_fields = remote_field_map
268 .into_iter()
269 .sorted_by_key(|entry| entry.0)
270 .map(|entry| entry.1)
271 .collect::<Vec<_>>();
272 Ok(RemoteSchema::new(remote_fields))
273}
274
/// Runs the sqlite query on a dedicated OS thread and streams `RecordBatch`es
/// of up to `chunk_size` rows back through `tx`.
///
/// The statement is prepared and iterated entirely on the spawned thread; a
/// single-threaded tokio runtime is built there solely to drive the async
/// `tx.send(..)` calls. All errors after the thread starts are forwarded
/// through the channel; if the receiver is dropped, sending fails and the task
/// exits early.
fn spawn_background_task(
    tx: tokio::sync::mpsc::Sender<DFResult<RecordBatch>>,
    conn: rusqlite::Connection,
    sql: String,
    table_schema: SchemaRef,
    projection: Option<Vec<usize>>,
    chunk_size: usize,
) {
    std::thread::spawn(move || {
        // A current-thread runtime suffices: the async block only awaits
        // channel sends, no I/O or timers.
        let runtime = match tokio::runtime::Builder::new_current_thread().build() {
            Ok(runtime) => runtime,
            Err(e) => {
                // Without a runtime we cannot even send an error batch;
                // log and abandon the task (receiver sees the channel close).
                error!("Failed to create tokio runtime to run sqlite query: {e:?}");
                return;
            }
        };
        // Drive the whole async pipeline to completion on this thread.
        let local_set = tokio::task::LocalSet::new();
        local_set.block_on(&runtime, async move {
            let mut stmt = match conn.prepare(&sql) {
                Ok(stmt) => stmt,
                Err(e) => {
                    // Forward the failure to the stream consumer; ignore send
                    // errors (receiver already gone).
                    let _ = tx
                        .send(Err(DataFusionError::Execution(format!(
                            "Failed to prepare sqlite statement: {e:?}"
                        ))))
                        .await;
                    return;
                }
            };
            // Snapshot column metadata before `query` mutably borrows `stmt`.
            let columns: Vec<OwnedColumn> =
                stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
            let mut rows = match stmt.query([]) {
                Ok(rows) => rows,
                Err(e) => {
                    let _ = tx
                        .send(Err(DataFusionError::Execution(format!(
                            "Failed to query sqlite statement: {e:?}"
                        ))))
                        .await;
                    return;
                }
            };

            // Drain rows chunk-by-chunk until the cursor is exhausted or the
            // receiver hangs up.
            loop {
                let (batch, is_empty) = match rows_to_batch(
                    &mut rows,
                    &table_schema,
                    &columns,
                    projection.as_ref(),
                    chunk_size,
                ) {
                    Ok((batch, is_empty)) => (batch, is_empty),
                    Err(e) => {
                        let _ = tx
                            .send(Err(DataFusionError::Execution(format!(
                                "Failed to convert rows to batch: {e:?}"
                            ))))
                            .await;
                        return;
                    }
                };
                // `is_empty` means no row was read: the cursor is exhausted
                // and the final (empty) batch is not sent downstream.
                if is_empty {
                    break;
                }
                // A send error means the stream consumer dropped; stop early.
                if tx.send(Ok(batch)).await.is_err() {
                    return;
                }
            }
        });
    });
}
346
347fn rows_to_batch(
348 rows: &mut Rows,
349 table_schema: &SchemaRef,
350 columns: &[OwnedColumn],
351 projection: Option<&Vec<usize>>,
352 chunk_size: usize,
353) -> DFResult<(RecordBatch, bool)> {
354 let projected_schema = project_schema(table_schema, projection)?;
355 let mut array_builders = vec![];
356 for field in table_schema.fields() {
357 let builder = make_builder(field.data_type(), 1000);
358 array_builders.push(builder);
359 }
360
361 let mut is_empty = true;
362 let mut row_count = 0;
363 while let Some(row) = rows.next().map_err(|e| {
364 DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
365 })? {
366 is_empty = false;
367 row_count += 1;
368 append_rows_to_array_builders(
369 row,
370 table_schema,
371 columns,
372 projection,
373 array_builders.as_mut_slice(),
374 )?;
375 if row_count >= chunk_size {
376 break;
377 }
378 }
379
380 let projected_columns = array_builders
381 .into_iter()
382 .enumerate()
383 .filter(|(idx, _)| projections_contains(projection, *idx))
384 .map(|(_, mut builder)| builder.finish())
385 .collect::<Vec<ArrayRef>>();
386 let options = RecordBatchOptions::new().with_row_count(Some(row_count));
387 Ok((
388 RecordBatch::try_new_with_options(projected_schema, projected_columns, &options)?,
389 is_empty,
390 ))
391}
392
/// Appends one nullable value from `$row` at column `$index` into `$builder`.
///
/// Downcasts `$builder` to the concrete `$builder_ty` (panicking on mismatch —
/// a schema/builder wiring bug, not a data error), fetches the value as
/// `Option<$value_ty>`, and appends the value or a null accordingly.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });

        // rusqlite converts the stored value to $value_ty; NULL becomes None.
        let v: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}
422
/// Appends a single sqlite `row` into the per-column array builders.
///
/// `array_builders` is indexed by absolute column position in `table_schema`;
/// columns outside `projection` are skipped (their builders stay empty and are
/// discarded by the caller).
///
/// # Errors
/// Fails if a value cannot be read as the expected Rust type or if a field's
/// arrow data type is not one of the supported primitives.
fn append_rows_to_array_builders(
    row: &Row,
    table_schema: &SchemaRef,
    columns: &[OwnedColumn],
    projection: Option<&Vec<usize>>,
    array_builders: &mut [Box<dyn ArrayBuilder>],
) -> DFResult<()> {
    for (idx, field) in table_schema.fields.iter().enumerate() {
        // Only materialize projected columns.
        if !projections_contains(projection, idx) {
            continue;
        }
        let builder = &mut array_builders[idx];
        // `col` is only used for error/panic context below.
        let col = columns.get(idx);
        match field.data_type() {
            DataType::Null => {
                let builder = builder
                    .as_any_mut()
                    .downcast_mut::<NullBuilder>()
                    .expect("Failed to downcast builder to NullBuilder");
                builder.append_null();
            }
            DataType::Int32 => {
                handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx);
            }
            DataType::Int64 => {
                handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
            }
            DataType::Float64 => {
                handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
            }
            DataType::Utf8 => {
                handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
            }
            DataType::Binary => {
                handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
            }
            _ => {
                return Err(DataFusionError::NotImplemented(format!(
                    "Unsupported data type {} for col: {:?}",
                    field.data_type(),
                    col
                )));
            }
        }
    }
    Ok(())
}