1use crate::connection::{RemoteDbType, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, Pool, RemoteField, RemoteSchema, RemoteSchemaRef,
4 RemoteType, SqliteType,
5};
6use datafusion::arrow::array::{
7 ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int32Builder, Int64Builder, NullBuilder,
8 RecordBatch, RecordBatchOptions, StringBuilder, make_builder,
9};
10use datafusion::arrow::datatypes::{DataType, SchemaRef};
11use datafusion::common::{DataFusionError, project_schema};
12use datafusion::execution::SendableRecordBatchStream;
13use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
14use derive_getters::Getters;
15use derive_with::With;
16use itertools::Itertools;
17use log::debug;
18use rusqlite::types::ValueRef;
19use rusqlite::{Column, Row, Rows};
20use std::any::Any;
21use std::collections::HashMap;
22use std::path::PathBuf;
23use std::sync::Arc;
24
/// Options describing how to connect to (and stream results from) a SQLite database file.
#[derive(Debug, Clone, With, Getters)]
pub struct SqliteConnectionOptions {
    /// Filesystem path of the SQLite database; `connect_sqlite` opens it with
    /// `tokio_rusqlite::Connection::open`.
    pub path: PathBuf,
    /// Page size (rows per `LIMIT`/`OFFSET` query) used when streaming query results;
    /// `SqliteConnectionOptions::new` sets it to 2048.
    /// NOTE(review): assumed to be what `ConnectionOptions::stream_chunk_size` returns for
    /// the SQLite variant — confirm.
    pub stream_chunk_size: usize,
}
30
31impl SqliteConnectionOptions {
32 pub fn new(path: PathBuf) -> Self {
33 Self {
34 path,
35 stream_chunk_size: 2048,
36 }
37 }
38}
39
40impl From<SqliteConnectionOptions> for ConnectionOptions {
41 fn from(options: SqliteConnectionOptions) -> Self {
42 ConnectionOptions::Sqlite(options)
43 }
44}
45
/// "Pool" for SQLite backed by a single `tokio_rusqlite::Connection`.
///
/// Every `Pool::get` call hands out a clone of this one handle rather than opening a new
/// database connection.
#[derive(Debug)]
pub struct SqlitePool {
    /// The connection handle opened by `connect_sqlite`.
    pool: tokio_rusqlite::Connection,
}
50
51pub async fn connect_sqlite(options: &SqliteConnectionOptions) -> DFResult<SqlitePool> {
52 let pool = tokio_rusqlite::Connection::open(&options.path)
53 .await
54 .map_err(|e| {
55 DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
56 })?;
57 Ok(SqlitePool { pool })
58}
59
60#[async_trait::async_trait]
61impl Pool for SqlitePool {
62 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
63 let conn = self.pool.clone();
64 Ok(Arc::new(SqliteConnection { conn }))
65 }
66}
67
/// [`Connection`] implementation that runs SQL against SQLite through a
/// `tokio_rusqlite::Connection` handle.
#[derive(Debug)]
pub struct SqliteConnection {
    /// Handle whose `call` method runs closures against the underlying rusqlite connection.
    conn: tokio_rusqlite::Connection,
}
72
73#[async_trait::async_trait]
74impl Connection for SqliteConnection {
75 fn as_any(&self) -> &dyn Any {
76 self
77 }
78
79 async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
80 let sql = RemoteDbType::Sqlite.query_limit_1(sql);
81 self.conn
82 .call(move |conn| {
83 let mut stmt = conn.prepare(&sql)?;
84 let columns: Vec<OwnedColumn> =
85 stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
86 let rows = stmt.query([])?;
87
88 let remote_schema = Arc::new(
89 build_remote_schema(columns.as_slice(), rows)
90 .map_err(|e| tokio_rusqlite::Error::Other(Box::new(e)))?,
91 );
92 Ok(remote_schema)
93 })
94 .await
95 .map_err(|e| DataFusionError::Execution(format!("Failed to infer schema: {e:?}")))
96 }
97
98 async fn query(
99 &self,
100 conn_options: &ConnectionOptions,
101 sql: &str,
102 table_schema: SchemaRef,
103 projection: Option<&Vec<usize>>,
104 unparsed_filters: &[String],
105 limit: Option<usize>,
106 ) -> DFResult<SendableRecordBatchStream> {
107 let projected_schema = project_schema(&table_schema, projection)?;
108 let sql = RemoteDbType::Sqlite.rewrite_query(sql, unparsed_filters, limit);
109 debug!("[remote-table] executing sqlite query: {sql}");
110
111 let conn = self.conn.clone();
112 let projection = projection.cloned();
113 let limit = conn_options.stream_chunk_size();
114 let stream = async_stream::stream! {
115 let mut offset = 0;
116 loop {
117 let sql = format!("SELECT * FROM ({sql}) LIMIT {limit} OFFSET {offset}");
118 let sql_clone = sql.clone();
119 let conn = conn.clone();
120 let projection = projection.clone();
121 let table_schema = table_schema.clone();
122 let (batch, is_empty) = conn
123 .call(move |conn| {
124 let mut stmt = conn.prepare(&sql)?;
125 let columns: Vec<OwnedColumn> =
126 stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
127 let rows = stmt.query([])?;
128
129 rows_to_batch(rows, &table_schema, columns, projection.as_ref())
130 .map_err(|e| tokio_rusqlite::Error::Other(e.into()))
131 })
132 .await
133 .map_err(|e| {
134 DataFusionError::Execution(format!(
135 "Failed to execute query {sql_clone} on sqlite: {e:?}"
136 ))
137 })?;
138 if is_empty {
139 break;
140 }
141 yield Ok(batch);
142 offset += limit;
143 }
144 };
145 Ok(Box::pin(RecordBatchStreamAdapter::new(
146 projected_schema,
147 stream,
148 )))
149 }
150}
151
/// Owned copy of one result column's metadata, taken from a prepared rusqlite statement.
///
/// It is copied out so the metadata can be kept after the statement is borrowed mutably to
/// run the query.
#[derive(Debug)]
struct OwnedColumn {
    /// Column name as reported by the statement.
    name: String,
    /// Declared column type, if the statement reports one. SQLite documents that none is
    /// reported for expression/subquery columns; `build_remote_schema` infers those from the
    /// data instead.
    decl_type: Option<String>,
}
157
158fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
159 OwnedColumn {
160 name: sqlite_col.name().to_string(),
161 decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
162 }
163}
164
165fn decl_type_to_remote_type(decl_type: &str) -> DFResult<SqliteType> {
166 if ["tinyint", "smallint", "int", "integer", "bigint"].contains(&decl_type) {
167 return Ok(SqliteType::Integer);
168 }
169 if ["real", "float", "double"].contains(&decl_type) {
170 return Ok(SqliteType::Real);
171 }
172 if decl_type.starts_with("real") {
173 return Ok(SqliteType::Real);
174 }
175 if ["text", "varchar", "char", "string"].contains(&decl_type) {
176 return Ok(SqliteType::Text);
177 }
178 if decl_type.starts_with("char")
179 || decl_type.starts_with("varchar")
180 || decl_type.starts_with("text")
181 {
182 return Ok(SqliteType::Text);
183 }
184 if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
185 return Ok(SqliteType::Blob);
186 }
187 if decl_type.starts_with("binary") || decl_type.starts_with("varbinary") {
188 return Ok(SqliteType::Blob);
189 }
190 Err(DataFusionError::NotImplemented(format!(
191 "Unsupported sqlite decl type: {decl_type}",
192 )))
193}
194
195fn build_remote_schema(columns: &[OwnedColumn], mut rows: Rows) -> DFResult<RemoteSchema> {
196 let mut remote_field_map = HashMap::with_capacity(columns.len());
197 let mut unknown_cols = vec![];
198 for (col_idx, col) in columns.iter().enumerate() {
199 if let Some(decl_type) = &col.decl_type {
200 let remote_type =
201 RemoteType::Sqlite(decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?);
202 remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
203 } else {
204 unknown_cols.push(col_idx);
205 }
206 }
207
208 if !unknown_cols.is_empty() {
209 while let Some(row) = rows.next().map_err(|e| {
210 DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
211 })? {
212 let mut to_be_removed = vec![];
213 for col_idx in unknown_cols.iter() {
214 let value_ref = row.get_ref(*col_idx).map_err(|e| {
215 DataFusionError::Execution(format!(
216 "Failed to get value ref for column {col_idx}: {e:?}"
217 ))
218 })?;
219 match value_ref {
220 ValueRef::Null => {}
221 ValueRef::Integer(_) => {
222 remote_field_map.insert(
223 *col_idx,
224 RemoteField::new(
225 columns[*col_idx].name.clone(),
226 RemoteType::Sqlite(SqliteType::Integer),
227 true,
228 ),
229 );
230 to_be_removed.push(*col_idx);
231 }
232 ValueRef::Real(_) => {
233 remote_field_map.insert(
234 *col_idx,
235 RemoteField::new(
236 columns[*col_idx].name.clone(),
237 RemoteType::Sqlite(SqliteType::Real),
238 true,
239 ),
240 );
241 to_be_removed.push(*col_idx);
242 }
243 ValueRef::Text(_) => {
244 remote_field_map.insert(
245 *col_idx,
246 RemoteField::new(
247 columns[*col_idx].name.clone(),
248 RemoteType::Sqlite(SqliteType::Text),
249 true,
250 ),
251 );
252 to_be_removed.push(*col_idx);
253 }
254 ValueRef::Blob(_) => {
255 remote_field_map.insert(
256 *col_idx,
257 RemoteField::new(
258 columns[*col_idx].name.clone(),
259 RemoteType::Sqlite(SqliteType::Blob),
260 true,
261 ),
262 );
263 to_be_removed.push(*col_idx);
264 }
265 }
266 }
267 for col_idx in to_be_removed.iter() {
268 unknown_cols.retain(|&x| x != *col_idx);
269 }
270 if unknown_cols.is_empty() {
271 break;
272 }
273 }
274 }
275
276 if !unknown_cols.is_empty() {
277 return Err(DataFusionError::NotImplemented(format!(
278 "Failed to infer sqlite decl type for columns: {unknown_cols:?}"
279 )));
280 }
281 let remote_fields = remote_field_map
282 .into_iter()
283 .sorted_by_key(|entry| entry.0)
284 .map(|entry| entry.1)
285 .collect::<Vec<_>>();
286 Ok(RemoteSchema::new(remote_fields))
287}
288
289fn rows_to_batch(
290 mut rows: Rows,
291 table_schema: &SchemaRef,
292 columns: Vec<OwnedColumn>,
293 projection: Option<&Vec<usize>>,
294) -> DFResult<(RecordBatch, bool)> {
295 let projected_schema = project_schema(table_schema, projection)?;
296 let mut array_builders = vec![];
297 for field in table_schema.fields() {
298 let builder = make_builder(field.data_type(), 1000);
299 array_builders.push(builder);
300 }
301
302 let mut is_empty = true;
303 let mut row_count = 0;
304 while let Some(row) = rows.next().map_err(|e| {
305 DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
306 })? {
307 is_empty = false;
308 row_count += 1;
309 append_rows_to_array_builders(
310 row,
311 table_schema,
312 &columns,
313 projection,
314 array_builders.as_mut_slice(),
315 )?;
316 }
317
318 let projected_columns = array_builders
319 .into_iter()
320 .enumerate()
321 .filter(|(idx, _)| projections_contains(projection, *idx))
322 .map(|(_, mut builder)| builder.finish())
323 .collect::<Vec<ArrayRef>>();
324 let options = RecordBatchOptions::new().with_row_count(Some(row_count));
325 Ok((
326 RecordBatch::try_new_with_options(projected_schema, projected_columns, &options)?,
327 is_empty,
328 ))
329}
330
/// Appends column `$index` of `$row` to `$builder`, treating it as a nullable `$value_ty`.
///
/// `$builder` is downcast to the concrete `$builder_ty`, and a SQL NULL becomes a null slot.
/// `$field` and `$col` are used only in diagnostics.
///
/// Panics if the builder is not a `$builder_ty`. Uses `?` to return a
/// `DataFusionError::Execution` from the enclosing function if the value cannot be read as
/// `Option<$value_ty>`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let typed_builder = match $builder.as_any_mut().downcast_mut::<$builder_ty>() {
            Some(typed) => typed,
            None => panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            ),
        };

        let maybe_value: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        // `append_option` appends the value for `Some` and a null for `None`.
        typed_builder.append_option(maybe_value);
    }};
}
360
361fn append_rows_to_array_builders(
362 row: &Row,
363 table_schema: &SchemaRef,
364 columns: &[OwnedColumn],
365 projection: Option<&Vec<usize>>,
366 array_builders: &mut [Box<dyn ArrayBuilder>],
367) -> DFResult<()> {
368 for (idx, field) in table_schema.fields.iter().enumerate() {
369 if !projections_contains(projection, idx) {
370 continue;
371 }
372 let builder = &mut array_builders[idx];
373 let col = columns.get(idx);
374 match field.data_type() {
375 DataType::Null => {
376 let builder = builder
377 .as_any_mut()
378 .downcast_mut::<NullBuilder>()
379 .expect("Failed to downcast builder to NullBuilder");
380 builder.append_null();
381 }
382 DataType::Int32 => {
383 handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx);
384 }
385 DataType::Int64 => {
386 handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
387 }
388 DataType::Float64 => {
389 handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
390 }
391 DataType::Utf8 => {
392 handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
393 }
394 DataType::Binary => {
395 handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
396 }
397 _ => {
398 return Err(DataFusionError::NotImplemented(format!(
399 "Unsupported data type {:?} for col: {:?}",
400 field.data_type(),
401 col
402 )));
403 }
404 }
405 }
406 Ok(())
407}