use crate::connection::{RemoteDbType, projections_contains};
use crate::{
    Connection, ConnectionOptions, DFResult, Literalize, Pool, RemoteField, RemoteSchema,
    RemoteSchemaRef, RemoteSource, RemoteType, SqliteConnectionOptions, SqliteType,
    literalize_array,
};
use datafusion::arrow::array::{
    ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int32Builder, Int64Builder, NullBuilder,
    RecordBatch, RecordBatchOptions, StringBuilder, make_builder,
};
use datafusion::arrow::datatypes::{DataType, SchemaRef};
use datafusion::common::{DataFusionError, project_schema};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::StreamExt;
use itertools::Itertools;
use log::{debug, error};
use rusqlite::types::ValueRef;
use rusqlite::{Column, Row, Rows};
use std::any::Any;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;

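/// A minimal SQLite "pool": it stores only the database path, and each use opens a
/// fresh `rusqlite::Connection` rather than handing out pooled connections.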
#[derive(Debug)]
pub struct SqlitePool {
    path: PathBuf,
}

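/// Opens (and immediately drops) a connection to verify that the SQLite database at
/// `options.path` is reachable, then returns a [`SqlitePool`] that just stores the path.
///
/// Illustrative sketch of intended usage (assumes `SqliteConnectionOptions` offers a
/// path-based constructor, which is not shown in this file; marked `ignore` for that
/// reason):
///
/// ```ignore
/// let options = SqliteConnectionOptions::new("data.db");
/// let pool = connect_sqlite(&options).await?;
/// let conn = pool.get().await?;
/// ```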
pub async fn connect_sqlite(options: &SqliteConnectionOptions) -> DFResult<SqlitePool> {
    let _ = rusqlite::Connection::open(&options.path).map_err(|e| {
        DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
    })?;
    Ok(SqlitePool {
        path: options.path.clone(),
    })
}

#[async_trait::async_trait]
impl Pool for SqlitePool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        Ok(Arc::new(SqliteConnection {
            path: self.path.clone(),
        }))
    }
}

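/// A [`Connection`] implementation backed by a SQLite database file; every operation
/// opens its own `rusqlite::Connection` from the stored path.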
#[derive(Debug)]
pub struct SqliteConnection {
    path: PathBuf,
}

#[async_trait::async_trait]
impl Connection for SqliteConnection {
    fn as_any(&self) -> &dyn Any {
        self
    }

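    /// Infers the remote schema: for a table source via `PRAGMA table_info`, and for a
    /// query source from the declared column types, probing the first row(s) when a
    /// declared type is missing.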
    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
        let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
            DataFusionError::Plan(format!("Failed to open sqlite connection: {e:?}"))
        })?;
        match source {
            RemoteSource::Table(table) => {
                let sql = format!(
                    "PRAGMA table_info({})",
                    RemoteDbType::Sqlite.sql_table_name(table)
                );
                let mut stmt = conn.prepare(&sql).map_err(|e| {
                    DataFusionError::Plan(format!("Failed to prepare sqlite statement: {e:?}"))
                })?;
                let rows = stmt.query([]).map_err(|e| {
                    DataFusionError::Plan(format!("Failed to query sqlite statement: {e:?}"))
                })?;
                let remote_schema = Arc::new(build_remote_schema_for_table(rows)?);
                Ok(remote_schema)
            }
            RemoteSource::Query(_query) => {
                let sql = RemoteDbType::Sqlite.limit_1_query_if_possible(source);
                let mut stmt = conn.prepare(&sql).map_err(|e| {
                    DataFusionError::Plan(format!("Failed to prepare sqlite statement: {e:?}"))
                })?;
                let columns: Vec<OwnedColumn> =
                    stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
                let rows = stmt.query([]).map_err(|e| {
                    DataFusionError::Plan(format!("Failed to query sqlite statement: {e:?}"))
                })?;

                let remote_schema =
                    Arc::new(build_remote_schema_for_query(columns.as_slice(), rows)?);
                Ok(remote_schema)
            }
        }
    }

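    /// Executes the (rewritten) query on a dedicated thread and streams the projected
    /// results back as Arrow record batches of at most `stream_chunk_size` rows each.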
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        source: &RemoteSource,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        let sql = RemoteDbType::Sqlite.rewrite_query(source, unparsed_filters, limit);
        debug!("[remote-table] executing sqlite query: {sql}");

        let (tx, mut rx) = tokio::sync::mpsc::channel::<DFResult<RecordBatch>>(1);
        let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
        })?;

        let projection = projection.cloned();
        let chunk_size = conn_options.stream_chunk_size();

        spawn_background_task(tx, conn, sql, table_schema, projection, chunk_size);

        let stream = async_stream::stream! {
            while let Some(batch) = rx.recv().await {
                yield batch;
            }
        };
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }

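    /// Inserts every batch from `input` into `table` by rendering the rows as SQL
    /// literals and issuing one multi-row `INSERT` statement per batch. Auto-increment
    /// columns whose input field is nullable are omitted so SQLite can assign them.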
    async fn insert(
        &self,
        _conn_options: &ConnectionOptions,
        literalizer: Arc<dyn Literalize>,
        table: &[String],
        remote_schema: RemoteSchemaRef,
        mut input: SendableRecordBatchStream,
    ) -> DFResult<usize> {
        let input_schema = input.schema();
        let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
        })?;

        let mut total_count = 0;
        while let Some(batch) = input.next().await {
            let batch = batch?;

            // Skip empty batches: `INSERT INTO t (...) VALUES` with no tuples is not
            // valid SQL.
            if batch.num_rows() == 0 {
                continue;
            }

            let mut columns = Vec::with_capacity(remote_schema.fields.len());
            for i in 0..batch.num_columns() {
                let input_field = input_schema.field(i);
                let remote_field = &remote_schema.fields[i];
                if remote_field.auto_increment && input_field.is_nullable() {
                    continue;
                }

                let remote_type = remote_schema.fields[i].remote_type.clone();
                let array = batch.column(i);
                let column = literalize_array(literalizer.as_ref(), array, remote_type)?;
                columns.push(column);
            }

            let num_rows = columns[0].len();
            let num_columns = columns.len();

            let mut values = Vec::with_capacity(num_rows);
            for i in 0..num_rows {
                let mut value = Vec::with_capacity(num_columns);
                for col in columns.iter() {
                    value.push(col[i].as_str());
                }
                values.push(format!("({})", value.join(",")));
            }

            let mut col_names = Vec::with_capacity(remote_schema.fields.len());
            for (remote_field, input_field) in
                remote_schema.fields.iter().zip(input_schema.fields.iter())
            {
                if remote_field.auto_increment && input_field.is_nullable() {
                    continue;
                }
                col_names.push(RemoteDbType::Sqlite.sql_identifier(&remote_field.name));
            }

            let sql = format!(
                "INSERT INTO {} ({}) VALUES {}",
                RemoteDbType::Sqlite.sql_table_name(table),
                col_names.join(","),
                values.join(",")
            );

            let count = conn.execute(&sql, []).map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to execute insert statement on sqlite: {e:?}, sql: {sql}"
                ))
            })?;
            total_count += count;
        }

        Ok(total_count)
    }
}

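/// An owned copy of a `rusqlite` column's name and declared type, so the metadata can
/// outlive the prepared statement it came from.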
#[derive(Debug)]
struct OwnedColumn {
    name: String,
    decl_type: Option<String>,
}

fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
    OwnedColumn {
        name: sqlite_col.name().to_string(),
        decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
    }
}

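/// Maps a lowercased SQLite declared column type (e.g. `"integer"`, `"varchar(20)"`,
/// `"blob"`) to a [`SqliteType`], matching either an exact name or a known prefix.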
fn decl_type_to_remote_type(decl_type: &str) -> DFResult<SqliteType> {
    if [
        "tinyint", "smallint", "int", "integer", "bigint", "int2", "int4", "int8",
    ]
    .contains(&decl_type)
    {
        return Ok(SqliteType::Integer);
    }
    if ["real", "float", "double", "numeric"].contains(&decl_type) {
        return Ok(SqliteType::Real);
    }
    if decl_type.starts_with("real") || decl_type.starts_with("numeric") {
        return Ok(SqliteType::Real);
    }
    if ["text", "varchar", "char", "string"].contains(&decl_type) {
        return Ok(SqliteType::Text);
    }
    if decl_type.starts_with("char")
        || decl_type.starts_with("varchar")
        || decl_type.starts_with("text")
    {
        return Ok(SqliteType::Text);
    }
    if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
        return Ok(SqliteType::Blob);
    }
    if decl_type.starts_with("binary") || decl_type.starts_with("varbinary") {
        return Ok(SqliteType::Blob);
    }
    Err(DataFusionError::NotImplemented(format!(
        "Unsupported sqlite decl type: {decl_type}",
    )))
}

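/// Builds a [`RemoteSchema`] from `PRAGMA table_info` rows, whose columns are
/// `(cid, name, type, notnull, dflt_value, pk)`; a column is nullable when `notnull`
/// is 0.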
fn build_remote_schema_for_table(mut rows: Rows) -> DFResult<RemoteSchema> {
    let mut remote_fields = vec![];
    while let Some(row) = rows.next().map_err(|e| {
        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
    })? {
        let name = row.get::<_, String>(1).map_err(|e| {
            DataFusionError::Execution(format!("Failed to get col name from sqlite row: {e:?}"))
        })?;
        let decl_type = row.get::<_, String>(2).map_err(|e| {
            DataFusionError::Execution(format!("Failed to get decl type from sqlite row: {e:?}"))
        })?;
        let remote_type = decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?;
        let nullable = row.get::<_, i64>(3).map_err(|e| {
            DataFusionError::Execution(format!("Failed to get nullable from sqlite row: {e:?}"))
        })? == 0;
        remote_fields.push(RemoteField::new(
            &name,
            RemoteType::Sqlite(remote_type),
            nullable,
        ));
    }
    Ok(RemoteSchema::new(remote_fields))
}

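/// Builds a [`RemoteSchema`] for an arbitrary query. Columns with a declared type are
/// mapped directly; for the rest, rows are scanned and each column's type is taken from
/// the first non-NULL value observed. Columns that never yield a non-NULL value are
/// reported as an error.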
fn build_remote_schema_for_query(
    columns: &[OwnedColumn],
    mut rows: Rows,
) -> DFResult<RemoteSchema> {
    let mut remote_field_map = HashMap::with_capacity(columns.len());
    let mut unknown_cols = vec![];
    for (col_idx, col) in columns.iter().enumerate() {
        if let Some(decl_type) = &col.decl_type {
            let remote_type =
                RemoteType::Sqlite(decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?);
            remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
        } else {
            unknown_cols.push(col_idx);
        }
    }

    if !unknown_cols.is_empty() {
        while let Some(row) = rows.next().map_err(|e| {
            DataFusionError::Plan(format!("Failed to get next row from sqlite: {e:?}"))
        })? {
            let mut to_be_removed = vec![];
            for col_idx in unknown_cols.iter() {
                let value_ref = row.get_ref(*col_idx).map_err(|e| {
                    DataFusionError::Plan(format!(
                        "Failed to get value ref for column {col_idx}: {e:?}"
                    ))
                })?;
                let sqlite_type = match value_ref {
                    // A NULL value tells us nothing about the column type; keep probing.
                    ValueRef::Null => continue,
                    ValueRef::Integer(_) => SqliteType::Integer,
                    ValueRef::Real(_) => SqliteType::Real,
                    ValueRef::Text(_) => SqliteType::Text,
                    ValueRef::Blob(_) => SqliteType::Blob,
                };
                remote_field_map.insert(
                    *col_idx,
                    RemoteField::new(
                        columns[*col_idx].name.clone(),
                        RemoteType::Sqlite(sqlite_type),
                        true,
                    ),
                );
                to_be_removed.push(*col_idx);
            }
            unknown_cols.retain(|col_idx| !to_be_removed.contains(col_idx));
            if unknown_cols.is_empty() {
                break;
            }
        }
    }

    if !unknown_cols.is_empty() {
        return Err(DataFusionError::Plan(format!(
            "Failed to infer sqlite decl type for columns: {}",
            unknown_cols
                .iter()
                .map(|idx| &columns[*idx].name)
                .join(", ")
        )));
    }
    let remote_fields = remote_field_map
        .into_iter()
        .sorted_by_key(|entry| entry.0)
        .map(|entry| entry.1)
        .collect::<Vec<_>>();
    Ok(RemoteSchema::new(remote_fields))
}

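/// Spawns a dedicated OS thread that executes the blocking SQLite query and sends
/// record batches of up to `chunk_size` rows through `tx`; it stops when the result
/// set is exhausted, an error has been sent, or the receiver side is dropped.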
fn spawn_background_task(
    tx: tokio::sync::mpsc::Sender<DFResult<RecordBatch>>,
    conn: rusqlite::Connection,
    sql: String,
    table_schema: SchemaRef,
    projection: Option<Vec<usize>>,
    chunk_size: usize,
) {
    std::thread::spawn(move || {
        let runtime = match tokio::runtime::Builder::new_current_thread().build() {
            Ok(runtime) => runtime,
            Err(e) => {
                error!("Failed to create tokio runtime to run sqlite query: {e:?}");
                return;
            }
        };
        let local_set = tokio::task::LocalSet::new();
        local_set.block_on(&runtime, async move {
            let mut stmt = match conn.prepare(&sql) {
                Ok(stmt) => stmt,
                Err(e) => {
                    let _ = tx
                        .send(Err(DataFusionError::Execution(format!(
                            "Failed to prepare sqlite statement: {e:?}"
                        ))))
                        .await;
                    return;
                }
            };
            let columns: Vec<OwnedColumn> =
                stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
            let mut rows = match stmt.query([]) {
                Ok(rows) => rows,
                Err(e) => {
                    let _ = tx
                        .send(Err(DataFusionError::Execution(format!(
                            "Failed to query sqlite statement: {e:?}"
                        ))))
                        .await;
                    return;
                }
            };

            loop {
                let (batch, is_empty) = match rows_to_batch(
                    &mut rows,
                    &table_schema,
                    &columns,
                    projection.as_ref(),
                    chunk_size,
                ) {
                    Ok((batch, is_empty)) => (batch, is_empty),
                    Err(e) => {
                        let _ = tx
                            .send(Err(DataFusionError::Execution(format!(
                                "Failed to convert rows to batch: {e:?}"
                            ))))
                            .await;
                        return;
                    }
                };
                if is_empty {
                    break;
                }
                if tx.send(Ok(batch)).await.is_err() {
                    return;
                }
            }
        });
    });
}

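/// Drains up to `chunk_size` rows from `rows` into Arrow builders and returns the
/// projected [`RecordBatch`] together with a flag indicating whether no rows were read
/// (i.e. the result set is exhausted).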
fn rows_to_batch(
    rows: &mut Rows,
    table_schema: &SchemaRef,
    columns: &[OwnedColumn],
    projection: Option<&Vec<usize>>,
    chunk_size: usize,
) -> DFResult<(RecordBatch, bool)> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), 1000);
        array_builders.push(builder);
    }

    let mut is_empty = true;
    let mut row_count = 0;
    while let Some(row) = rows.next().map_err(|e| {
        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
    })? {
        is_empty = false;
        row_count += 1;
        append_rows_to_array_builders(
            row,
            table_schema,
            columns,
            projection,
            array_builders.as_mut_slice(),
        )?;
        if row_count >= chunk_size {
            break;
        }
    }

    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    let options = RecordBatchOptions::new().with_row_count(Some(row_count));
    Ok((
        RecordBatch::try_new_with_options(projected_schema, projected_columns, &options)?,
        is_empty,
    ))
}

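/// Downcasts `$builder` to the concrete `$builder_ty`, reads the value at `$index` from
/// the row as an `Option<$value_ty>`, and appends it (or a null) to the builder.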
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });

        let v: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}

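/// Appends the projected values of a single SQLite row to the corresponding Arrow array
/// builders, dispatching on each table field's Arrow data type.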
fn append_rows_to_array_builders(
    row: &Row,
    table_schema: &SchemaRef,
    columns: &[OwnedColumn],
    projection: Option<&Vec<usize>>,
    array_builders: &mut [Box<dyn ArrayBuilder>],
) -> DFResult<()> {
    for (idx, field) in table_schema.fields.iter().enumerate() {
        if !projections_contains(projection, idx) {
            continue;
        }
        let builder = &mut array_builders[idx];
        let col = columns.get(idx);
        match field.data_type() {
            DataType::Null => {
                let builder = builder
                    .as_any_mut()
                    .downcast_mut::<NullBuilder>()
                    .expect("Failed to downcast builder to NullBuilder");
                builder.append_null();
            }
            DataType::Int32 => {
                handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx);
            }
            DataType::Int64 => {
                handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
            }
            DataType::Float64 => {
                handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
            }
            DataType::Utf8 => {
                handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
            }
            DataType::Binary => {
                handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
            }
            _ => {
                return Err(DataFusionError::NotImplemented(format!(
                    "Unsupported data type {} for col: {:?}",
                    field.data_type(),
                    col
                )));
            }
        }
    }
    Ok(())
}