1use crate::connection::{RemoteDbType, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, Literalize, Pool, PoolState, RemoteField,
4 RemoteSchema, RemoteSchemaRef, RemoteSource, RemoteType, SqliteConnectionOptions, SqliteType,
5 literalize_array,
6};
7use arrow::array::{
8 ArrayBuilder, ArrayRef, BinaryBuilder, BinaryViewBuilder, Float64Builder, Int32Builder,
9 Int64Builder, NullBuilder, RecordBatch, RecordBatchOptions, StringBuilder, StringViewBuilder,
10 make_builder,
11};
12use arrow::datatypes::{DataType, SchemaRef};
13use datafusion_common::{DataFusionError, project_schema};
14use datafusion_execution::SendableRecordBatchStream;
15use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
16use itertools::Itertools;
17use log::{debug, error};
18use rusqlite::types::ValueRef;
19use rusqlite::{Column, Row, Rows};
20use std::any::Any;
21use std::collections::HashMap;
22use std::path::PathBuf;
23use std::sync::Arc;
24use std::sync::atomic::{AtomicUsize, Ordering};
25
/// A pseudo connection pool for a SQLite database.
///
/// No connections are kept open: only the database path and a counter of
/// handed-out [`SqliteConnection`] handles are stored. Each handle opens its
/// own `rusqlite::Connection` on demand.
#[derive(Debug)]
pub struct SqlitePool {
    // Path to the sqlite database file; cloned into every connection handle.
    path: PathBuf,
    // Number of live `SqliteConnection` handles; incremented in `Pool::get`,
    // decremented in `SqliteConnection::drop`.
    connections: Arc<AtomicUsize>,
}
31
/// Creates a [`SqlitePool`] for the database at `options.path`.
///
/// Opens (and immediately drops) one connection purely to validate that the
/// database file can be opened, surfacing failures eagerly instead of at the
/// first query. Subsequent connections are opened on demand per handle.
pub async fn connect_sqlite(options: &SqliteConnectionOptions) -> DFResult<SqlitePool> {
    // Probe connection: the handle is intentionally discarded.
    let _ = rusqlite::Connection::open(&options.path).map_err(|e| {
        DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
    })?;
    Ok(SqlitePool {
        path: options.path.clone(),
        connections: Arc::new(AtomicUsize::new(0)),
    })
}
41
#[async_trait::async_trait]
impl Pool for SqlitePool {
    /// Hands out a new connection handle.
    ///
    /// There is no real pooling: `get` only bumps the live-handle counter and
    /// returns a handle that opens an actual `rusqlite::Connection` lazily in
    /// each of its methods. The counter is decremented when the handle is
    /// dropped (see `SqliteConnection::drop`).
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        self.connections.fetch_add(1, Ordering::SeqCst);
        Ok(Arc::new(SqliteConnection {
            path: self.path.clone(),
            pool_connections: self.connections.clone(),
        }))
    }

    /// Reports the number of outstanding handles. `idle_connections` is
    /// always 0 because no connections are kept open by this pool.
    async fn state(&self) -> DFResult<PoolState> {
        Ok(PoolState {
            connections: self.connections.load(Ordering::SeqCst),
            idle_connections: 0,
        })
    }
}
59
/// A handle to a sqlite database handed out by [`SqlitePool`].
///
/// No `rusqlite::Connection` is held open; each method opens its own
/// connection to `path`. Dropping the handle decrements the shared counter.
#[derive(Debug)]
pub struct SqliteConnection {
    // Path to the sqlite database file.
    path: PathBuf,
    // Live-handle counter shared with the owning `SqlitePool`.
    pool_connections: Arc<AtomicUsize>,
}
65
impl Drop for SqliteConnection {
    /// Decrements the pool's live-handle counter (incremented in
    /// `SqlitePool::get`), keeping `PoolState::connections` accurate.
    fn drop(&mut self) {
        self.pool_connections.fetch_sub(1, Ordering::SeqCst);
    }
}
71
72#[async_trait::async_trait]
73impl Connection for SqliteConnection {
74 fn as_any(&self) -> &dyn Any {
75 self
76 }
77
78 async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef> {
79 let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
80 DataFusionError::Plan(format!("Failed to open sqlite connection: {e:?}"))
81 })?;
82 match source {
83 RemoteSource::Table(table) => {
84 let sql = format!(
86 "PRAGMA table_info({})",
87 RemoteDbType::Sqlite.sql_table_name(table)
88 );
89 let mut stmt = conn.prepare(&sql).map_err(|e| {
90 DataFusionError::Plan(format!("Failed to prepare sqlite statement: {e:?}"))
91 })?;
92 let rows = stmt.query([]).map_err(|e| {
93 DataFusionError::Plan(format!("Failed to query sqlite statement: {e:?}"))
94 })?;
95 let remote_schema = Arc::new(build_remote_schema_for_table(rows)?);
96 Ok(remote_schema)
97 }
98 RemoteSource::Query(_query) => {
99 let sql = RemoteDbType::Sqlite.limit_1_query_if_possible(source);
100 let mut stmt = conn.prepare(&sql).map_err(|e| {
101 DataFusionError::Plan(format!("Failed to prepare sqlite statement: {e:?}"))
102 })?;
103 let columns: Vec<OwnedColumn> =
104 stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
105 let rows = stmt.query([]).map_err(|e| {
106 DataFusionError::Plan(format!("Failed to query sqlite statement: {e:?}"))
107 })?;
108
109 let remote_schema =
110 Arc::new(build_remote_schema_for_query(columns.as_slice(), rows)?);
111 Ok(remote_schema)
112 }
113 }
114 }
115
116 async fn query(
117 &self,
118 conn_options: &ConnectionOptions,
119 source: &RemoteSource,
120 table_schema: SchemaRef,
121 projection: Option<&Vec<usize>>,
122 unparsed_filters: &[String],
123 limit: Option<usize>,
124 ) -> DFResult<SendableRecordBatchStream> {
125 let projected_schema = project_schema(&table_schema, projection)?;
126 let sql = RemoteDbType::Sqlite.rewrite_query(source, unparsed_filters, limit);
127 debug!("[remote-table] executing sqlite query: {sql}");
128
129 let (tx, mut rx) = tokio::sync::mpsc::channel::<DFResult<RecordBatch>>(1);
130 let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
131 DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
132 })?;
133
134 let projection = projection.cloned();
135 let chunk_size = conn_options.stream_chunk_size();
136
137 spawn_background_task(tx, conn, sql, table_schema, projection, chunk_size);
138
139 let stream = async_stream::stream! {
140 while let Some(batch) = rx.recv().await {
141 yield batch;
142 }
143 };
144 Ok(Box::pin(RecordBatchStreamAdapter::new(
145 projected_schema,
146 stream,
147 )))
148 }
149
150 async fn insert(
151 &self,
152 _conn_options: &ConnectionOptions,
153 literalizer: Arc<dyn Literalize>,
154 table: &[String],
155 remote_schema: RemoteSchemaRef,
156 batch: RecordBatch,
157 ) -> DFResult<usize> {
158 let conn = rusqlite::Connection::open(&self.path).map_err(|e| {
159 DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
160 })?;
161
162 let mut columns = Vec::with_capacity(remote_schema.fields.len());
163 for i in 0..batch.num_columns() {
164 let input_field = batch.schema_ref().field(i);
165 let remote_field = &remote_schema.fields[i];
166 if remote_field.auto_increment && input_field.is_nullable() {
167 continue;
168 }
169
170 let remote_type = remote_schema.fields[i].remote_type.clone();
171 let array = batch.column(i);
172 let column = literalize_array(literalizer.as_ref(), array, remote_type)?;
173 columns.push(column);
174 }
175
176 let num_rows = columns[0].len();
177 let num_columns = columns.len();
178
179 let mut values = Vec::with_capacity(num_rows);
180 for i in 0..num_rows {
181 let mut value = Vec::with_capacity(num_columns);
182 for col in columns.iter() {
183 value.push(col[i].as_str());
184 }
185 values.push(format!("({})", value.join(",")));
186 }
187
188 let mut col_names = Vec::with_capacity(remote_schema.fields.len());
189 for (remote_field, input_field) in remote_schema
190 .fields
191 .iter()
192 .zip(batch.schema_ref().fields.iter())
193 {
194 if remote_field.auto_increment && input_field.is_nullable() {
195 continue;
196 }
197 col_names.push(RemoteDbType::Sqlite.sql_identifier(&remote_field.name));
198 }
199
200 let sql = format!(
201 "INSERT INTO {} ({}) VALUES {}",
202 RemoteDbType::Sqlite.sql_table_name(table),
203 col_names.join(","),
204 values.join(",")
205 );
206
207 let count = conn.execute(&sql, []).map_err(|e| {
208 DataFusionError::Execution(format!(
209 "Failed to execute insert statement on sqlite: {e:?}, sql: {sql}"
210 ))
211 })?;
212
213 Ok(count)
214 }
215}
216
/// Owned copy of a rusqlite `Column`'s metadata, detached from the lifetime
/// of the statement that produced it.
#[derive(Debug)]
struct OwnedColumn {
    // Column name as reported by the prepared statement.
    name: String,
    // Declared type from the table definition; `None` when sqlite reports no
    // declared type (presumably computed/expression columns — see rusqlite's
    // `Column::decl_type`).
    decl_type: Option<String>,
}
222
223fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
224 OwnedColumn {
225 name: sqlite_col.name().to_string(),
226 decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
227 }
228}
229
230fn decl_type_to_remote_type(decl_type: &str) -> DFResult<SqliteType> {
231 if [
232 "tinyint", "smallint", "int", "integer", "bigint", "int2", "int4", "int8",
233 ]
234 .contains(&decl_type)
235 {
236 return Ok(SqliteType::Integer);
237 }
238 if ["real", "float", "double", "numeric"].contains(&decl_type) {
239 return Ok(SqliteType::Real);
240 }
241 if decl_type.starts_with("real") || decl_type.starts_with("numeric") {
242 return Ok(SqliteType::Real);
243 }
244 if ["text", "varchar", "char", "string"].contains(&decl_type) {
245 return Ok(SqliteType::Text);
246 }
247 if decl_type.starts_with("char")
248 || decl_type.starts_with("varchar")
249 || decl_type.starts_with("text")
250 {
251 return Ok(SqliteType::Text);
252 }
253 if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
254 return Ok(SqliteType::Blob);
255 }
256 if decl_type.starts_with("binary") || decl_type.starts_with("varbinary") {
257 return Ok(SqliteType::Blob);
258 }
259 Err(DataFusionError::NotImplemented(format!(
260 "Unsupported sqlite decl type: {decl_type}",
261 )))
262}
263
264fn build_remote_schema_for_table(mut rows: Rows) -> DFResult<RemoteSchema> {
265 let mut remote_fields = vec![];
266 while let Some(row) = rows.next().map_err(|e| {
267 DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
268 })? {
269 let name = row.get::<_, String>(1).map_err(|e| {
270 DataFusionError::Execution(format!("Failed to get col name from sqlite row: {e:?}"))
271 })?;
272 let decl_type = row.get::<_, String>(2).map_err(|e| {
273 DataFusionError::Execution(format!("Failed to get decl type from sqlite row: {e:?}"))
274 })?;
275 let remote_type = decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?;
276 let nullable = row.get::<_, i64>(3).map_err(|e| {
277 DataFusionError::Execution(format!("Failed to get nullable from sqlite row: {e:?}"))
278 })? == 0;
279 remote_fields.push(RemoteField::new(
280 &name,
281 RemoteType::Sqlite(remote_type),
282 nullable,
283 ));
284 }
285 Ok(RemoteSchema::new(remote_fields))
286}
287
288fn build_remote_schema_for_query(
289 columns: &[OwnedColumn],
290 mut rows: Rows,
291) -> DFResult<RemoteSchema> {
292 let mut remote_field_map = HashMap::with_capacity(columns.len());
293 let mut unknown_cols = vec![];
294 for (col_idx, col) in columns.iter().enumerate() {
295 if let Some(decl_type) = &col.decl_type {
296 let remote_type =
297 RemoteType::Sqlite(decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?);
298 remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
299 } else {
300 unknown_cols.push(col_idx);
302 }
303 }
304
305 if !unknown_cols.is_empty() {
306 while let Some(row) = rows.next().map_err(|e| {
307 DataFusionError::Plan(format!("Failed to get next row from sqlite: {e:?}"))
308 })? {
309 let mut to_be_removed = vec![];
310 for col_idx in unknown_cols.iter() {
311 let value_ref = row.get_ref(*col_idx).map_err(|e| {
312 DataFusionError::Plan(format!(
313 "Failed to get value ref for column {col_idx}: {e:?}"
314 ))
315 })?;
316 match value_ref {
317 ValueRef::Null => {}
318 ValueRef::Integer(_) => {
319 remote_field_map.insert(
320 *col_idx,
321 RemoteField::new(
322 columns[*col_idx].name.clone(),
323 RemoteType::Sqlite(SqliteType::Integer),
324 true,
325 ),
326 );
327 to_be_removed.push(*col_idx);
328 }
329 ValueRef::Real(_) => {
330 remote_field_map.insert(
331 *col_idx,
332 RemoteField::new(
333 columns[*col_idx].name.clone(),
334 RemoteType::Sqlite(SqliteType::Real),
335 true,
336 ),
337 );
338 to_be_removed.push(*col_idx);
339 }
340 ValueRef::Text(_) => {
341 remote_field_map.insert(
342 *col_idx,
343 RemoteField::new(
344 columns[*col_idx].name.clone(),
345 RemoteType::Sqlite(SqliteType::Text),
346 true,
347 ),
348 );
349 to_be_removed.push(*col_idx);
350 }
351 ValueRef::Blob(_) => {
352 remote_field_map.insert(
353 *col_idx,
354 RemoteField::new(
355 columns[*col_idx].name.clone(),
356 RemoteType::Sqlite(SqliteType::Blob),
357 true,
358 ),
359 );
360 to_be_removed.push(*col_idx);
361 }
362 }
363 }
364 for col_idx in to_be_removed.iter() {
365 unknown_cols.retain(|&x| x != *col_idx);
366 }
367 if unknown_cols.is_empty() {
368 break;
369 }
370 }
371 }
372
373 if !unknown_cols.is_empty() {
374 return Err(DataFusionError::Plan(format!(
375 "Failed to infer sqlite decl type for columns: {}",
376 unknown_cols
377 .iter()
378 .map(|idx| &columns[*idx].name)
379 .join(", ")
380 )));
381 }
382 let remote_fields = remote_field_map
383 .into_iter()
384 .sorted_by_key(|entry| entry.0)
385 .map(|entry| entry.1)
386 .collect::<Vec<_>>();
387 Ok(RemoteSchema::new(remote_fields))
388}
389
/// Runs `sql` against `conn` on a dedicated OS thread, converting the result
/// rows into record batches of up to `chunk_size` rows and sending them (or
/// the first error encountered) through `tx`.
///
/// A small current-thread tokio runtime plus a `LocalSet` is built on the new
/// thread solely to drive the async channel `send`s; the rusqlite work itself
/// is blocking. The task ends when the rows are exhausted, an error occurs,
/// or the receiver side of `tx` is dropped (send fails).
fn spawn_background_task(
    tx: tokio::sync::mpsc::Sender<DFResult<RecordBatch>>,
    conn: rusqlite::Connection,
    sql: String,
    table_schema: SchemaRef,
    projection: Option<Vec<usize>>,
    chunk_size: usize,
) {
    std::thread::spawn(move || {
        let runtime = match tokio::runtime::Builder::new_current_thread().build() {
            Ok(runtime) => runtime,
            Err(e) => {
                // Without a runtime we cannot await a channel send, so this
                // failure can only be logged, not forwarded to the consumer.
                error!("Failed to create tokio runtime to run sqlite query: {e:?}");
                return;
            }
        };
        let local_set = tokio::task::LocalSet::new();
        local_set.block_on(&runtime, async move {
            let mut stmt = match conn.prepare(&sql) {
                Ok(stmt) => stmt,
                Err(e) => {
                    // Receiver may already be gone; a failed send is ignored.
                    let _ = tx
                        .send(Err(DataFusionError::Execution(format!(
                            "Failed to prepare sqlite statement: {e:?}"
                        ))))
                        .await;
                    return;
                }
            };
            // Column metadata is copied out before `query` takes a mutable
            // borrow of the statement.
            let columns: Vec<OwnedColumn> =
                stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
            let mut rows = match stmt.query([]) {
                Ok(rows) => rows,
                Err(e) => {
                    let _ = tx
                        .send(Err(DataFusionError::Execution(format!(
                            "Failed to query sqlite statement: {e:?}"
                        ))))
                        .await;
                    return;
                }
            };

            // Read chunk by chunk until the result set is exhausted
            // (`is_empty`) or the receiver hangs up (send returns Err).
            loop {
                let (batch, is_empty) = match rows_to_batch(
                    &mut rows,
                    &table_schema,
                    &columns,
                    projection.as_ref(),
                    chunk_size,
                ) {
                    Ok((batch, is_empty)) => (batch, is_empty),
                    Err(e) => {
                        let _ = tx
                            .send(Err(DataFusionError::Execution(format!(
                                "Failed to convert rows to batch: {e:?}"
                            ))))
                            .await;
                        return;
                    }
                };
                if is_empty {
                    break;
                }
                if tx.send(Ok(batch)).await.is_err() {
                    return;
                }
            }
        });
    });
}
461
462fn rows_to_batch(
463 rows: &mut Rows,
464 table_schema: &SchemaRef,
465 columns: &[OwnedColumn],
466 projection: Option<&Vec<usize>>,
467 chunk_size: usize,
468) -> DFResult<(RecordBatch, bool)> {
469 let projected_schema = project_schema(table_schema, projection)?;
470 let mut array_builders = vec![];
471 for field in table_schema.fields() {
472 let builder = make_builder(field.data_type(), 1000);
473 array_builders.push(builder);
474 }
475
476 let mut is_empty = true;
477 let mut row_count = 0;
478 while let Some(row) = rows.next().map_err(|e| {
479 DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
480 })? {
481 is_empty = false;
482 row_count += 1;
483 append_rows_to_array_builders(
484 row,
485 table_schema,
486 columns,
487 projection,
488 array_builders.as_mut_slice(),
489 )?;
490 if row_count >= chunk_size {
491 break;
492 }
493 }
494
495 let projected_columns = array_builders
496 .into_iter()
497 .enumerate()
498 .filter(|(idx, _)| projections_contains(projection, *idx))
499 .map(|(_, mut builder)| builder.finish())
500 .collect::<Vec<ArrayRef>>();
501 let options = RecordBatchOptions::new().with_row_count(Some(row_count));
502 Ok((
503 RecordBatch::try_new_with_options(projected_schema, projected_columns, &options)?,
504 is_empty,
505 ))
506}
507
/// Appends one value from a sqlite row to a concrete Arrow builder.
///
/// Downcasts `$builder` to `$builder_ty` — panicking on mismatch, since that
/// indicates an internal schema/builder inconsistency rather than bad data —
/// then reads column `$index` of `$row` as an `Option<$value_ty>` and appends
/// the value, or a null when the stored value is NULL.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });

        // A failed read (e.g. incompatible stored type) surfaces as an
        // execution error carrying field/column context for diagnosis.
        let v: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}
537
/// Appends the values of one sqlite `row` to the per-column `array_builders`.
///
/// Builders are indexed by position in `table_schema` (one builder per table
/// column); columns outside `projection` are skipped entirely. `columns`
/// supplies metadata used only for error/panic context. Returns a
/// `NotImplemented` error for Arrow data types this reader does not handle.
fn append_rows_to_array_builders(
    row: &Row,
    table_schema: &SchemaRef,
    columns: &[OwnedColumn],
    projection: Option<&Vec<usize>>,
    array_builders: &mut [Box<dyn ArrayBuilder>],
) -> DFResult<()> {
    for (idx, field) in table_schema.fields.iter().enumerate() {
        // Skip columns the query projection filtered out.
        if !projections_contains(projection, idx) {
            continue;
        }
        let builder = &mut array_builders[idx];
        // Only used to enrich error messages below.
        let col = columns.get(idx);
        match field.data_type() {
            DataType::Null => {
                let builder = builder
                    .as_any_mut()
                    .downcast_mut::<NullBuilder>()
                    .expect("Failed to downcast builder to NullBuilder");
                builder.append_null();
            }
            DataType::Int32 => {
                handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx);
            }
            DataType::Int64 => {
                handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
            }
            DataType::Float64 => {
                handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
            }
            DataType::Utf8 => {
                handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
            }
            DataType::Utf8View => {
                handle_primitive_type!(builder, field, col, StringViewBuilder, String, row, idx);
            }
            DataType::Binary => {
                handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
            }
            DataType::BinaryView => {
                handle_primitive_type!(builder, field, col, BinaryViewBuilder, Vec<u8>, row, idx);
            }
            _ => {
                return Err(DataFusionError::NotImplemented(format!(
                    "Unsupported data type {} for col: {:?}",
                    field.data_type(),
                    col
                )));
            }
        }
    }
    Ok(())
}