1use crate::connection::{RemoteDbType, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, Pool, RemoteField, RemoteSchema, RemoteSchemaRef,
4 RemoteType, SqliteType,
5};
6use datafusion::arrow::array::{
7 ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int32Builder, Int64Builder, NullBuilder,
8 RecordBatch, RecordBatchOptions, StringBuilder, make_builder,
9};
10use datafusion::arrow::datatypes::{DataType, SchemaRef};
11use datafusion::common::{DataFusionError, project_schema};
12use datafusion::execution::SendableRecordBatchStream;
13use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
14use derive_getters::Getters;
15use derive_with::With;
16use itertools::Itertools;
17use log::debug;
18use rusqlite::types::ValueRef;
19use rusqlite::{Column, Row, Rows};
20use std::collections::HashMap;
21use std::path::PathBuf;
22use std::sync::Arc;
23
24#[derive(Debug, Clone, With, Getters)]
25pub struct SqliteConnectionOptions {
26 pub path: PathBuf,
27 pub stream_chunk_size: usize,
28}
29
30impl SqliteConnectionOptions {
31 pub fn new(path: PathBuf) -> Self {
32 Self {
33 path,
34 stream_chunk_size: 2048,
35 }
36 }
37}
38
39#[derive(Debug)]
40pub struct SqlitePool {
41 pool: tokio_rusqlite::Connection,
42}
43
44pub async fn connect_sqlite(options: &SqliteConnectionOptions) -> DFResult<SqlitePool> {
45 let pool = tokio_rusqlite::Connection::open(&options.path)
46 .await
47 .map_err(|e| {
48 DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
49 })?;
50 Ok(SqlitePool { pool })
51}
52
53#[async_trait::async_trait]
54impl Pool for SqlitePool {
55 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
56 let conn = self.pool.clone();
57 Ok(Arc::new(SqliteConnection { conn }))
58 }
59}
60
61#[derive(Debug)]
62pub struct SqliteConnection {
63 conn: tokio_rusqlite::Connection,
64}
65
66#[async_trait::async_trait]
67impl Connection for SqliteConnection {
68 async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef> {
69 let sql = RemoteDbType::Sqlite.query_limit_1(sql);
70 self.conn
71 .call(move |conn| {
72 let mut stmt = conn.prepare(&sql)?;
73 let columns: Vec<OwnedColumn> =
74 stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
75 let rows = stmt.query([])?;
76
77 let remote_schema = Arc::new(
78 build_remote_schema(columns.as_slice(), rows)
79 .map_err(|e| tokio_rusqlite::Error::Other(Box::new(e)))?,
80 );
81 Ok(remote_schema)
82 })
83 .await
84 .map_err(|e| DataFusionError::Execution(format!("Failed to infer schema: {e:?}")))
85 }
86
87 async fn query(
88 &self,
89 conn_options: &ConnectionOptions,
90 sql: &str,
91 table_schema: SchemaRef,
92 projection: Option<&Vec<usize>>,
93 unparsed_filters: &[String],
94 limit: Option<usize>,
95 ) -> DFResult<SendableRecordBatchStream> {
96 let projected_schema = project_schema(&table_schema, projection)?;
97 let sql = RemoteDbType::Sqlite.rewrite_query(sql, unparsed_filters, limit);
98 debug!("[remote-table] executing sqlite query: {sql}");
99
100 let conn = self.conn.clone();
101 let projection = projection.cloned();
102 let limit = conn_options.stream_chunk_size();
103 let stream = async_stream::stream! {
104 let mut offset = 0;
105 loop {
106 let sql = format!("SELECT * FROM ({sql}) LIMIT {limit} OFFSET {offset}");
107 let sql_clone = sql.clone();
108 let conn = conn.clone();
109 let projection = projection.clone();
110 let table_schema = table_schema.clone();
111 let (batch, is_empty) = conn
112 .call(move |conn| {
113 let mut stmt = conn.prepare(&sql)?;
114 let columns: Vec<OwnedColumn> =
115 stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
116 let rows = stmt.query([])?;
117
118 rows_to_batch(rows, &table_schema, columns, projection.as_ref())
119 .map_err(|e| tokio_rusqlite::Error::Other(e.into()))
120 })
121 .await
122 .map_err(|e| {
123 DataFusionError::Execution(format!(
124 "Failed to execute query {sql_clone} on sqlite: {e:?}"
125 ))
126 })?;
127 if is_empty {
128 break;
129 }
130 yield Ok(batch);
131 offset += limit;
132 }
133 };
134 Ok(Box::pin(RecordBatchStreamAdapter::new(
135 projected_schema,
136 stream,
137 )))
138 }
139}
140
/// Owned snapshot of one result column's metadata.
///
/// Lets the metadata outlive the borrowed rusqlite statement it came from.
#[derive(Debug)]
struct OwnedColumn {
    /// Column name as reported by the prepared statement.
    name: String,
    /// Declared type text, if SQLite reports one. When `None`, `build_remote_schema`
    /// infers the type from row values instead.
    decl_type: Option<String>,
}
146
147fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
148 OwnedColumn {
149 name: sqlite_col.name().to_string(),
150 decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
151 }
152}
153
154fn decl_type_to_remote_type(decl_type: &str) -> DFResult<SqliteType> {
155 if ["tinyint", "smallint", "int", "integer", "bigint"].contains(&decl_type) {
156 return Ok(SqliteType::Integer);
157 }
158 if ["real", "float", "double"].contains(&decl_type) {
159 return Ok(SqliteType::Real);
160 }
161 if decl_type.starts_with("real") {
162 return Ok(SqliteType::Real);
163 }
164 if ["text", "varchar", "char", "string"].contains(&decl_type) {
165 return Ok(SqliteType::Text);
166 }
167 if decl_type.starts_with("char")
168 || decl_type.starts_with("varchar")
169 || decl_type.starts_with("text")
170 {
171 return Ok(SqliteType::Text);
172 }
173 if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
174 return Ok(SqliteType::Blob);
175 }
176 if decl_type.starts_with("binary") || decl_type.starts_with("varbinary") {
177 return Ok(SqliteType::Blob);
178 }
179 Err(DataFusionError::NotImplemented(format!(
180 "Unsupported sqlite decl type: {decl_type}",
181 )))
182}
183
184fn build_remote_schema(columns: &[OwnedColumn], mut rows: Rows) -> DFResult<RemoteSchema> {
185 let mut remote_field_map = HashMap::with_capacity(columns.len());
186 let mut unknown_cols = vec![];
187 for (col_idx, col) in columns.iter().enumerate() {
188 if let Some(decl_type) = &col.decl_type {
189 let remote_type =
190 RemoteType::Sqlite(decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?);
191 remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
192 } else {
193 unknown_cols.push(col_idx);
194 }
195 }
196
197 if !unknown_cols.is_empty() {
198 while let Some(row) = rows.next().map_err(|e| {
199 DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
200 })? {
201 let mut to_be_removed = vec![];
202 for col_idx in unknown_cols.iter() {
203 let value_ref = row.get_ref(*col_idx).map_err(|e| {
204 DataFusionError::Execution(format!(
205 "Failed to get value ref for column {col_idx}: {e:?}"
206 ))
207 })?;
208 match value_ref {
209 ValueRef::Null => {}
210 ValueRef::Integer(_) => {
211 remote_field_map.insert(
212 *col_idx,
213 RemoteField::new(
214 columns[*col_idx].name.clone(),
215 RemoteType::Sqlite(SqliteType::Integer),
216 true,
217 ),
218 );
219 to_be_removed.push(*col_idx);
220 }
221 ValueRef::Real(_) => {
222 remote_field_map.insert(
223 *col_idx,
224 RemoteField::new(
225 columns[*col_idx].name.clone(),
226 RemoteType::Sqlite(SqliteType::Real),
227 true,
228 ),
229 );
230 to_be_removed.push(*col_idx);
231 }
232 ValueRef::Text(_) => {
233 remote_field_map.insert(
234 *col_idx,
235 RemoteField::new(
236 columns[*col_idx].name.clone(),
237 RemoteType::Sqlite(SqliteType::Text),
238 true,
239 ),
240 );
241 to_be_removed.push(*col_idx);
242 }
243 ValueRef::Blob(_) => {
244 remote_field_map.insert(
245 *col_idx,
246 RemoteField::new(
247 columns[*col_idx].name.clone(),
248 RemoteType::Sqlite(SqliteType::Blob),
249 true,
250 ),
251 );
252 to_be_removed.push(*col_idx);
253 }
254 }
255 }
256 for col_idx in to_be_removed.iter() {
257 unknown_cols.retain(|&x| x != *col_idx);
258 }
259 if unknown_cols.is_empty() {
260 break;
261 }
262 }
263 }
264
265 if !unknown_cols.is_empty() {
266 return Err(DataFusionError::NotImplemented(format!(
267 "Failed to infer sqlite decl type for columns: {unknown_cols:?}"
268 )));
269 }
270 let remote_fields = remote_field_map
271 .into_iter()
272 .sorted_by_key(|entry| entry.0)
273 .map(|entry| entry.1)
274 .collect::<Vec<_>>();
275 Ok(RemoteSchema::new(remote_fields))
276}
277
278fn rows_to_batch(
279 mut rows: Rows,
280 table_schema: &SchemaRef,
281 columns: Vec<OwnedColumn>,
282 projection: Option<&Vec<usize>>,
283) -> DFResult<(RecordBatch, bool)> {
284 let projected_schema = project_schema(table_schema, projection)?;
285 let mut array_builders = vec![];
286 for field in table_schema.fields() {
287 let builder = make_builder(field.data_type(), 1000);
288 array_builders.push(builder);
289 }
290
291 let mut is_empty = true;
292 let mut row_count = 0;
293 while let Some(row) = rows.next().map_err(|e| {
294 DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
295 })? {
296 is_empty = false;
297 row_count += 1;
298 append_rows_to_array_builders(
299 row,
300 table_schema,
301 &columns,
302 projection,
303 array_builders.as_mut_slice(),
304 )?;
305 }
306
307 let projected_columns = array_builders
308 .into_iter()
309 .enumerate()
310 .filter(|(idx, _)| projections_contains(projection, *idx))
311 .map(|(_, mut builder)| builder.finish())
312 .collect::<Vec<ArrayRef>>();
313 let options = RecordBatchOptions::new().with_row_count(Some(row_count));
314 Ok((
315 RecordBatch::try_new_with_options(projected_schema, projected_columns, &options)?,
316 is_empty,
317 ))
318}
319
/// Reads column `$index` of `$row` as `Option<$value_ty>` and appends it to `$builder`
/// after downcasting that builder to `$builder_ty`. NULL becomes `append_null`.
///
/// Must be invoked inside a function returning a `Result` whose error type accepts
/// `DataFusionError`, because read failures are propagated with `?`. Panics if `$builder`
/// is not a `$builder_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let typed_builder = match $builder.as_any_mut().downcast_mut::<$builder_ty>() {
            Some(typed) => typed,
            None => panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            ),
        };

        let maybe_value: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        if let Some(value) = maybe_value {
            typed_builder.append_value(value);
        } else {
            typed_builder.append_null();
        }
    }};
}
349
350fn append_rows_to_array_builders(
351 row: &Row,
352 table_schema: &SchemaRef,
353 columns: &[OwnedColumn],
354 projection: Option<&Vec<usize>>,
355 array_builders: &mut [Box<dyn ArrayBuilder>],
356) -> DFResult<()> {
357 for (idx, field) in table_schema.fields.iter().enumerate() {
358 if !projections_contains(projection, idx) {
359 continue;
360 }
361 let builder = &mut array_builders[idx];
362 let col = columns.get(idx);
363 match field.data_type() {
364 DataType::Null => {
365 let builder = builder
366 .as_any_mut()
367 .downcast_mut::<NullBuilder>()
368 .expect("Failed to downcast builder to NullBuilder");
369 builder.append_null();
370 }
371 DataType::Int32 => {
372 handle_primitive_type!(builder, field, col, Int32Builder, i32, row, idx);
373 }
374 DataType::Int64 => {
375 handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
376 }
377 DataType::Float64 => {
378 handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
379 }
380 DataType::Utf8 => {
381 handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
382 }
383 DataType::Binary => {
384 handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
385 }
386 _ => {
387 return Err(DataFusionError::NotImplemented(format!(
388 "Unsupported data type {:?} for col: {:?}",
389 field.data_type(),
390 col
391 )));
392 }
393 }
394 }
395 Ok(())
396}