use crate::connection::{RemoteDbType, projections_contains};
use crate::{
    Connection, ConnectionOptions, DFResult, Pool, RemoteField, RemoteSchema, RemoteSchemaRef,
    RemoteType, SqliteType,
};
use datafusion::arrow::array::{
    ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int32Builder, Int64Builder, NullBuilder,
    RecordBatch, StringBuilder, make_builder,
};
use datafusion::arrow::datatypes::{DataType, SchemaRef};
use datafusion::common::{DataFusionError, project_schema};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use derive_getters::Getters;
use derive_with::With;
use itertools::Itertools;
use rusqlite::types::ValueRef;
use rusqlite::{Column, Row, Rows};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;

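/// Options for opening a SQLite database: the path to the database file and the
/// number of rows fetched per chunk when streaming query results.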
#[derive(Debug, Clone, With, Getters)]
pub struct SqliteConnectionOptions {
    pub path: PathBuf,
    pub stream_chunk_size: usize,
}

impl SqliteConnectionOptions {
    pub fn new(path: PathBuf) -> Self {
        Self {
            path,
            stream_chunk_size: 2048,
        }
    }
}

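/// A [`Pool`] backed by a single `tokio_rusqlite::Connection`. `tokio_rusqlite`
/// runs the underlying rusqlite connection on a dedicated background thread, so
/// cloned handles all funnel their work through that one connection.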
#[derive(Debug)]
pub struct SqlitePool {
    pool: tokio_rusqlite::Connection,
}

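/// Opens the SQLite database at `options.path` and wraps it in a [`SqlitePool`].
///
/// A minimal usage sketch (assumes a tokio runtime and an existing database file;
/// the path and query are illustrative only):
///
/// ```rust,ignore
/// let options = SqliteConnectionOptions::new(PathBuf::from("data.db"));
/// let pool = connect_sqlite(&options).await?;
/// let conn = pool.get().await?;
/// let remote_schema = conn.infer_schema("SELECT * FROM my_table").await?;
/// ```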
pub async fn connect_sqlite(options: &SqliteConnectionOptions) -> DFResult<SqlitePool> {
    let pool = tokio_rusqlite::Connection::open(&options.path)
        .await
        .map_err(|e| {
            DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
        })?;
    Ok(SqlitePool { pool })
}

#[async_trait::async_trait]
impl Pool for SqlitePool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.clone();
        Ok(Arc::new(SqliteConnection { conn }))
    }
}

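/// A [`Connection`] that executes queries against a SQLite database via
/// `tokio_rusqlite`.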
#[derive(Debug)]
pub struct SqliteConnection {
    conn: tokio_rusqlite::Connection,
}

#[async_trait::async_trait]
impl Connection for SqliteConnection {
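    /// Infers the remote schema by preparing the query wrapped with `LIMIT 1`,
    /// reading each column's declared type, and probing the first row's values
    /// for columns that have no declared type (e.g. expression columns).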
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
        let sql = RemoteDbType::Sqlite.query_limit_1(sql)?;
        self.conn
            .call(move |conn| {
                let mut stmt = conn.prepare(&sql)?;
                let columns: Vec<OwnedColumn> =
                    stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
                let rows = stmt.query([])?;

                let remote_schema = Arc::new(
                    build_remote_schema(columns.as_slice(), rows)
                        .map_err(|e| tokio_rusqlite::Error::Other(Box::new(e)))?,
                );
                Ok(remote_schema)
            })
            .await
            .map_err(|e| DataFusionError::Execution(format!("Failed to infer schema: {e:?}")))
    }

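    /// Streams query results as `RecordBatch`es. The rewritten query (with pushed-down
    /// filters and limit) is paged through with `LIMIT {chunk} OFFSET {offset}`, one
    /// batch per chunk of `stream_chunk_size` rows; an empty chunk ends the stream.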
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        let sql = RemoteDbType::Sqlite.try_rewrite_query(sql, unparsed_filters, limit)?;
        let conn = self.conn.clone();
        let projection = projection.cloned();
        let limit = conn_options.stream_chunk_size();
        let stream = async_stream::stream! {
            let mut offset = 0;
            loop {
                let sql = format!("SELECT * FROM ({sql}) LIMIT {limit} OFFSET {offset}");
                let sql_clone = sql.clone();
                let conn = conn.clone();
                let projection = projection.clone();
                let table_schema = table_schema.clone();
                let (batch, is_empty) = conn
                    .call(move |conn| {
                        let mut stmt = conn.prepare(&sql)?;
                        let columns: Vec<OwnedColumn> =
                            stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
                        let rows = stmt.query([])?;

                        rows_to_batch(rows, &table_schema, columns, projection.as_ref())
                            .map_err(|e| tokio_rusqlite::Error::Other(e.into()))
                    })
                    .await
                    .map_err(|e| {
                        DataFusionError::Execution(format!(
                            "Failed to execute query {sql_clone} on sqlite: {e:?}"
                        ))
                    })?;
                if is_empty {
                    break;
                }
                yield Ok(batch);
                offset += limit;
            }
        };
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            projected_schema,
            stream,
        )))
    }
}

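/// An owned copy of a statement column's name and declared type. `rusqlite::Column`
/// borrows from the prepared statement, so an owned version is needed to keep the
/// metadata around after the statement is dropped and across closure boundaries.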
#[derive(Debug)]
struct OwnedColumn {
    name: String,
    decl_type: Option<String>,
}

fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
    OwnedColumn {
        name: sqlite_col.name().to_string(),
        decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
    }
}

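/// Maps a lowercased SQLite declared column type (e.g. `integer`, `varchar(255)`)
/// to the corresponding [`SqliteType`]. Unrecognized declarations yield a
/// `NotImplemented` error.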
fn decl_type_to_remote_type(decl_type: &str) -> DFResult<SqliteType> {
    if ["tinyint", "smallint", "int", "integer", "bigint"].contains(&decl_type) {
        return Ok(SqliteType::Integer);
    }
    if ["real", "float", "double"].contains(&decl_type) {
        return Ok(SqliteType::Real);
    }
    if decl_type.starts_with("real") {
        return Ok(SqliteType::Real);
    }
    if ["text", "varchar", "char", "string"].contains(&decl_type) {
        return Ok(SqliteType::Text);
    }
    if decl_type.starts_with("char")
        || decl_type.starts_with("varchar")
        || decl_type.starts_with("text")
    {
        return Ok(SqliteType::Text);
    }
    if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
        return Ok(SqliteType::Blob);
    }
    if decl_type.starts_with("binary") || decl_type.starts_with("varbinary") {
        return Ok(SqliteType::Blob);
    }
    Err(DataFusionError::NotImplemented(format!(
        "Unsupported sqlite decl type: {decl_type}",
    )))
}

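/// Builds a [`RemoteSchema`] from the statement columns. Columns with a declared
/// type are mapped directly; for the rest, rows are scanned and the storage class
/// of the first non-NULL value is used. If any column is still unknown after the
/// scan, a `NotImplemented` error is returned.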
fn build_remote_schema(columns: &[OwnedColumn], mut rows: Rows) -> DFResult<RemoteSchema> {
    let mut remote_field_map = HashMap::with_capacity(columns.len());
    let mut unknown_cols = vec![];
    for (col_idx, col) in columns.iter().enumerate() {
        if let Some(decl_type) = &col.decl_type {
            let remote_type =
                RemoteType::Sqlite(decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?);
            remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
        } else {
            unknown_cols.push(col_idx);
        }
    }

    if !unknown_cols.is_empty() {
        while let Some(row) = rows.next().map_err(|e| {
            DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
        })? {
            let mut to_be_removed = vec![];
            for col_idx in unknown_cols.iter() {
                let value_ref = row.get_ref(*col_idx).map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to get value ref for column {col_idx}: {e:?}"
                    ))
                })?;
                let inferred = match value_ref {
                    ValueRef::Null => None,
                    ValueRef::Integer(_) => Some(SqliteType::Integer),
                    ValueRef::Real(_) => Some(SqliteType::Real),
                    ValueRef::Text(_) => Some(SqliteType::Text),
                    ValueRef::Blob(_) => Some(SqliteType::Blob),
                };
                if let Some(sqlite_type) = inferred {
                    remote_field_map.insert(
                        *col_idx,
                        RemoteField::new(
                            columns[*col_idx].name.clone(),
                            RemoteType::Sqlite(sqlite_type),
                            true,
                        ),
                    );
                    to_be_removed.push(*col_idx);
                }
            }
            for col_idx in to_be_removed.iter() {
                unknown_cols.retain(|&x| x != *col_idx);
            }
            if unknown_cols.is_empty() {
                break;
            }
        }
    }

    if !unknown_cols.is_empty() {
        return Err(DataFusionError::NotImplemented(format!(
            "Failed to infer sqlite decl type for columns: {unknown_cols:?}"
        )));
    }
    let remote_fields = remote_field_map
        .into_iter()
        .sorted_by_key(|entry| entry.0)
        .map(|entry| entry.1)
        .collect::<Vec<_>>();
    Ok(RemoteSchema::new(remote_fields))
}

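/// Drains the row cursor into Arrow array builders (one per `table_schema` field),
/// then finishes only the projected builders into a `RecordBatch`. The boolean in
/// the returned tuple is `true` when the cursor yielded no rows, which the query
/// stream uses as its stop condition.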
fn rows_to_batch(
    mut rows: Rows,
    table_schema: &SchemaRef,
    columns: Vec<OwnedColumn>,
    projection: Option<&Vec<usize>>,
) -> DFResult<(RecordBatch, bool)> {
    let projected_schema = project_schema(table_schema, projection)?;
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), 1000);
        array_builders.push(builder);
    }

    let mut is_empty = true;
    while let Some(row) = rows.next().map_err(|e| {
        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
    })? {
        is_empty = false;
        append_rows_to_array_builders(
            row,
            table_schema,
            &columns,
            projection,
            array_builders.as_mut_slice(),
        )?;
    }

    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok((
        RecordBatch::try_new(projected_schema, projected_columns)?,
        is_empty,
    ))
}

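/// Downcasts `$builder` to the concrete `$builder_ty`, reads column `$index` of the
/// row as an `Option<$value_ty>`, and appends the value (or a null for SQL NULL).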
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });

        let v: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}

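/// Appends one row's values to the builders for the projected fields, dispatching on
/// each field's Arrow `DataType` to the matching SQLite getter. Unsupported types
/// return a `NotImplemented` error.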
fn append_rows_to_array_builders(
    row: &Row,
    table_schema: &SchemaRef,
    columns: &[OwnedColumn],
    projection: Option<&Vec<usize>>,
    array_builders: &mut [Box<dyn ArrayBuilder>],
) -> DFResult<()> {
    for (idx, field) in table_schema.fields.iter().enumerate() {
        if !projections_contains(projection, idx) {
            continue;
        }
        let builder = &mut array_builders[idx];
        let col = columns.get(idx);
        match field.data_type() {
            DataType::Null => {
                let builder = builder
                    .as_any_mut()
                    .downcast_mut::<NullBuilder>()
                    .expect("Failed to downcast builder to NullBuilder");
                builder.append_null();
            }
            DataType::Int32 => {
                handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx);
            }
            DataType::Int64 => {
                handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
            }
            DataType::Float64 => {
                handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
            }
            DataType::Utf8 => {
                handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
            }
            DataType::Binary => {
                handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
            }
            _ => {
                return Err(DataFusionError::NotImplemented(format!(
                    "Unsupported data type {:?} for col: {:?}",
                    field.data_type(),
                    col
                )));
            }
        }
    }
    Ok(())
}