datafusion_remote_table/connection/
sqlite.rs1use crate::connection::{RemoteDbType, projections_contains};
2use crate::{
3 Connection, ConnectionOptions, DFResult, Pool, RemoteField, RemoteSchema, RemoteSchemaRef,
4 RemoteType, SqliteType,
5};
6use datafusion::arrow::array::{
7 ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int64Builder, NullBuilder, RecordBatch,
8 StringBuilder, make_builder,
9};
10use datafusion::arrow::datatypes::{DataType, SchemaRef};
11use datafusion::common::{DataFusionError, project_schema};
12use datafusion::execution::SendableRecordBatchStream;
13use datafusion::physical_plan::memory::MemoryStream;
14use datafusion::prelude::Expr;
15use itertools::Itertools;
16use rusqlite::types::ValueRef;
17use rusqlite::{Column, Row, Rows};
18use std::collections::HashMap;
19use std::path::PathBuf;
20use std::sync::Arc;
21
22#[derive(Debug)]
23pub struct SqlitePool {
24 pool: tokio_rusqlite::Connection,
25}
26
27pub async fn connect_sqlite(path: &PathBuf) -> DFResult<SqlitePool> {
28 let pool = tokio_rusqlite::Connection::open(path).await.map_err(|e| {
29 DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
30 })?;
31 Ok(SqlitePool { pool })
32}
33
34#[async_trait::async_trait]
35impl Pool for SqlitePool {
36 async fn get(&self) -> DFResult<Arc<dyn Connection>> {
37 let conn = self.pool.clone();
38 Ok(Arc::new(SqliteConnection { conn }))
39 }
40}
41
42#[derive(Debug)]
43pub struct SqliteConnection {
44 conn: tokio_rusqlite::Connection,
45}
46
47#[async_trait::async_trait]
48impl Connection for SqliteConnection {
49 async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
50 let sql = sql.to_string();
51 self.conn
52 .call(move |conn| {
53 let mut stmt = conn.prepare(&sql)?;
54 let columns: Vec<OwnedColumn> =
55 stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
56 let rows = stmt.query([])?;
57
58 let remote_schema = Arc::new(
59 build_remote_schema(columns.as_slice(), rows)
60 .map_err(|e| tokio_rusqlite::Error::Other(Box::new(e)))?,
61 );
62 let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
63 Ok((remote_schema, arrow_schema))
64 })
65 .await
66 .map_err(|e| DataFusionError::Execution(format!("Failed to infer schema: {e:?}")))
67 }
68
69 async fn query(
70 &self,
71 _conn_options: &ConnectionOptions,
72 sql: &str,
73 table_schema: SchemaRef,
74 projection: Option<&Vec<usize>>,
75 filters: &[Expr],
76 limit: Option<usize>,
77 ) -> DFResult<SendableRecordBatchStream> {
78 let projected_schema = project_schema(&table_schema, projection)?;
79 let sql = RemoteDbType::Sqlite
80 .try_rewrite_query(sql, filters, limit)
81 .unwrap_or_else(|| sql.to_string());
82 let sql_clone = sql.clone();
83 let projection = projection.cloned();
84 let batch = self
85 .conn
86 .call(move |conn| {
87 let mut stmt = conn.prepare(&sql)?;
88 let columns: Vec<OwnedColumn> =
89 stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
90 let rows = stmt.query([])?;
91
92 let batch = rows_to_batch(rows, &table_schema, columns, projection.as_ref())
93 .map_err(|e| tokio_rusqlite::Error::Other(e.into()))?;
94 Ok(batch)
95 })
96 .await
97 .map_err(|e| {
98 DataFusionError::Execution(format!(
99 "Failed to execute query {sql_clone} on sqlite: {e:?}"
100 ))
101 })?;
102
103 let memory_stream = MemoryStream::try_new(vec![batch], projected_schema, None)?;
104 Ok(Box::pin(memory_stream))
105 }
106}
107
/// Owned copy of the column metadata taken from a prepared statement.
#[derive(Debug)]
struct OwnedColumn {
    /// Column name as reported by the statement.
    name: String,
    /// Declared type of the column, or `None` when the statement reports none.
    decl_type: Option<String>,
}
113
114fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
115 OwnedColumn {
116 name: sqlite_col.name().to_string(),
117 decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
118 }
119}
120
121fn decl_type_to_remote_type(decl_type: &str) -> DFResult<RemoteType> {
122 if "null".eq(decl_type) {
123 return Ok(RemoteType::Sqlite(SqliteType::Null));
124 }
125 if ["tinyint", "smallint", "int", "integer", "bigint"].contains(&decl_type) {
126 return Ok(RemoteType::Sqlite(SqliteType::Integer));
127 }
128 if ["real", "float", "double"].contains(&decl_type) {
129 return Ok(RemoteType::Sqlite(SqliteType::Real));
130 }
131 if decl_type.starts_with("real") {
132 return Ok(RemoteType::Sqlite(SqliteType::Real));
133 }
134 if ["text", "varchar", "char", "string"].contains(&decl_type) {
135 return Ok(RemoteType::Sqlite(SqliteType::Text));
136 }
137 if decl_type.starts_with("char")
138 || decl_type.starts_with("varchar")
139 || decl_type.starts_with("text")
140 {
141 return Ok(RemoteType::Sqlite(SqliteType::Text));
142 }
143 if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
144 return Ok(RemoteType::Sqlite(SqliteType::Blob));
145 }
146 if decl_type.starts_with("binary") || decl_type.starts_with("varbinary") {
147 return Ok(RemoteType::Sqlite(SqliteType::Blob));
148 }
149 Err(DataFusionError::NotImplemented(format!(
150 "Unsupported sqlite decl type: {decl_type}",
151 )))
152}
153
154fn build_remote_schema(columns: &[OwnedColumn], mut rows: Rows) -> DFResult<RemoteSchema> {
155 let mut remote_field_map = HashMap::with_capacity(columns.len());
156 let mut unknown_cols = vec![];
157 for (col_idx, col) in columns.iter().enumerate() {
158 if let Some(decl_type) = &col.decl_type {
159 let remote_type = decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?;
160 remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
161 } else {
162 unknown_cols.push(col_idx);
163 }
164 }
165
166 if !unknown_cols.is_empty() {
167 while let Some(row) = rows.next().map_err(|e| {
168 DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
169 })? {
170 let mut to_be_removed = vec![];
171 for col_idx in unknown_cols.iter() {
172 let value_ref = row.get_ref(*col_idx).map_err(|e| {
173 DataFusionError::Execution(format!(
174 "Failed to get value ref for column {col_idx}: {e:?}"
175 ))
176 })?;
177 match value_ref {
178 ValueRef::Null => {}
179 ValueRef::Integer(_) => {
180 remote_field_map.insert(
181 *col_idx,
182 RemoteField::new(
183 columns[*col_idx].name.clone(),
184 RemoteType::Sqlite(SqliteType::Integer),
185 true,
186 ),
187 );
188 to_be_removed.push(*col_idx);
189 }
190 ValueRef::Real(_) => {
191 remote_field_map.insert(
192 *col_idx,
193 RemoteField::new(
194 columns[*col_idx].name.clone(),
195 RemoteType::Sqlite(SqliteType::Real),
196 true,
197 ),
198 );
199 to_be_removed.push(*col_idx);
200 }
201 ValueRef::Text(_) => {
202 remote_field_map.insert(
203 *col_idx,
204 RemoteField::new(
205 columns[*col_idx].name.clone(),
206 RemoteType::Sqlite(SqliteType::Text),
207 true,
208 ),
209 );
210 to_be_removed.push(*col_idx);
211 }
212 ValueRef::Blob(_) => {
213 remote_field_map.insert(
214 *col_idx,
215 RemoteField::new(
216 columns[*col_idx].name.clone(),
217 RemoteType::Sqlite(SqliteType::Blob),
218 true,
219 ),
220 );
221 to_be_removed.push(*col_idx);
222 }
223 }
224 }
225 for col_idx in to_be_removed.iter() {
226 unknown_cols.retain(|&x| x != *col_idx);
227 }
228 if unknown_cols.is_empty() {
229 break;
230 }
231 }
232 }
233
234 if !unknown_cols.is_empty() {
235 return Err(DataFusionError::NotImplemented(format!(
236 "Failed to infer sqlite decl type for columns: {unknown_cols:?}"
237 )));
238 }
239 let remote_fields = remote_field_map
240 .into_iter()
241 .sorted_by_key(|entry| entry.0)
242 .map(|entry| entry.1)
243 .collect::<Vec<_>>();
244 Ok(RemoteSchema::new(remote_fields))
245}
246
247fn rows_to_batch(
248 mut rows: Rows,
249 table_schema: &SchemaRef,
250 columns: Vec<OwnedColumn>,
251 projection: Option<&Vec<usize>>,
252) -> DFResult<RecordBatch> {
253 let projected_schema = project_schema(table_schema, projection)?;
254 let mut array_builders = vec![];
255 for field in table_schema.fields() {
256 let builder = make_builder(field.data_type(), 1000);
257 array_builders.push(builder);
258 }
259
260 while let Some(row) = rows.next().map_err(|e| {
261 DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
262 })? {
263 append_rows_to_array_builders(
264 row,
265 table_schema,
266 &columns,
267 projection,
268 array_builders.as_mut_slice(),
269 )?;
270 }
271
272 let projected_columns = array_builders
273 .into_iter()
274 .enumerate()
275 .filter(|(idx, _)| projections_contains(projection, *idx))
276 .map(|(_, mut builder)| builder.finish())
277 .collect::<Vec<ArrayRef>>();
278 Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
279}
280
/// Reads column `$index` of `$row` as `Option<$value_ty>` and appends it to `$builder`.
///
/// `$builder` is downcast to `$builder_ty` first. `$field` and `$col` are only used
/// in diagnostics. The expansion uses `?`, so it must appear in a function whose
/// error type can hold a `DataFusionError`.
///
/// # Panics
///
/// Panics if `$builder` is not a `$builder_ty`.
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let Some(typed_builder) = $builder.as_any_mut().downcast_mut::<$builder_ty>() else {
            panic!(
                "Failed to downcast builder to {} for {:?} and {:?}",
                stringify!($builder_ty),
                $field,
                $col
            )
        };

        let cell: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        if let Some(value) = cell {
            typed_builder.append_value(value)
        } else {
            typed_builder.append_null()
        }
    }};
}
310
311fn append_rows_to_array_builders(
312 row: &Row,
313 table_schema: &SchemaRef,
314 columns: &[OwnedColumn],
315 projection: Option<&Vec<usize>>,
316 array_builders: &mut [Box<dyn ArrayBuilder>],
317) -> DFResult<()> {
318 for (idx, field) in table_schema.fields.iter().enumerate() {
319 if !projections_contains(projection, idx) {
320 continue;
321 }
322 let builder = &mut array_builders[idx];
323 let col = columns.get(idx);
324 match field.data_type() {
325 DataType::Null => {
326 let builder = builder
327 .as_any_mut()
328 .downcast_mut::<NullBuilder>()
329 .expect("Failed to downcast builder to NullBuilder");
330 builder.append_null();
331 }
332 DataType::Int64 => {
333 handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
334 }
335 DataType::Float64 => {
336 handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
337 }
338 DataType::Utf8 => {
339 handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
340 }
341 DataType::Binary => {
342 handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
343 }
344 _ => {
345 return Err(DataFusionError::NotImplemented(format!(
346 "Unsupported data type {:?} for col: {:?}",
347 field.data_type(),
348 col
349 )));
350 }
351 }
352 }
353 Ok(())
354}