datafusion_remote_table/connection/sqlite.rs

use crate::connection::{RemoteDbType, projections_contains};
use crate::{
    Connection, ConnectionOptions, DFResult, Pool, RemoteField, RemoteSchema, RemoteSchemaRef,
    RemoteType, SqliteType,
};
use datafusion::arrow::array::{
    ArrayBuilder, ArrayRef, BinaryBuilder, Float64Builder, Int64Builder, NullBuilder, RecordBatch,
    StringBuilder, make_builder,
};
use datafusion::arrow::datatypes::{DataType, SchemaRef};
use datafusion::common::{DataFusionError, project_schema};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::memory::MemoryStream;
use datafusion::prelude::Expr;
use itertools::Itertools;
use regex::Regex;
use rusqlite::types::ValueRef;
use rusqlite::{Column, Row, Rows};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, LazyLock};

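/// Pool implementation backed by a single shared `tokio_rusqlite::Connection` handle.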
#[derive(Debug)]
pub struct SqlitePool {
    pool: tokio_rusqlite::Connection,
}

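/// Opens an asynchronous SQLite connection to the database file at `path` and
/// wraps it in a [`SqlitePool`].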
pub async fn connect_sqlite(path: &PathBuf) -> DFResult<SqlitePool> {
    let pool = tokio_rusqlite::Connection::open(path).await.map_err(|e| {
        DataFusionError::Execution(format!("Failed to open sqlite connection: {e:?}"))
    })?;
    Ok(SqlitePool { pool })
}

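// `get` clones the shared `tokio_rusqlite::Connection` handle, so every connection
// returned by the pool operates on the same underlying SQLite database handle.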
#[async_trait::async_trait]
impl Pool for SqlitePool {
    async fn get(&self) -> DFResult<Arc<dyn Connection>> {
        let conn = self.pool.clone();
        Ok(Arc::new(SqliteConnection { conn }))
    }
}

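/// A single logical connection handed out by [`SqlitePool`]. All database work is
/// run through `tokio_rusqlite::Connection::call`, which executes the closure
/// against the blocking `rusqlite` connection.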
#[derive(Debug)]
pub struct SqliteConnection {
    conn: tokio_rusqlite::Connection,
}

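// `infer_schema` prepares the statement, derives a RemoteSchema from the column
// metadata (falling back to scanning rows when a declared type is missing), and
// converts it to an Arrow schema. `query` optionally rewrites the SQL to push down
// filters and limit, materializes all rows into a single RecordBatch, and returns
// it as a one-batch MemoryStream.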
#[async_trait::async_trait]
impl Connection for SqliteConnection {
    async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)> {
        let sql = sql.to_string();
        self.conn
            .call(move |conn| {
                let mut stmt = conn.prepare(&sql)?;
                let columns: Vec<OwnedColumn> =
                    stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
                let rows = stmt.query([])?;

                let remote_schema = Arc::new(
                    build_remote_schema(columns.as_slice(), rows)
                        .map_err(|e| tokio_rusqlite::Error::Other(Box::new(e)))?,
                );
                let arrow_schema = Arc::new(remote_schema.to_arrow_schema());
                Ok((remote_schema, arrow_schema))
            })
            .await
            .map_err(|e| DataFusionError::Execution(format!("Failed to infer schema: {e:?}")))
    }

    async fn query(
        &self,
        _conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream> {
        let projected_schema = project_schema(&table_schema, projection)?;
        let sql = RemoteDbType::Sqlite
            .try_rewrite_query(sql, filters, limit)
            .unwrap_or_else(|| sql.to_string());
        let sql_clone = sql.clone();
        let projection = projection.cloned();
        let batch = self
            .conn
            .call(move |conn| {
                let mut stmt = conn.prepare(&sql)?;
                let columns: Vec<OwnedColumn> =
                    stmt.columns().iter().map(sqlite_col_to_owned_col).collect();
                let rows = stmt.query([])?;

                let batch = rows_to_batch(rows, &table_schema, columns, projection.as_ref())
                    .map_err(|e| tokio_rusqlite::Error::Other(e.into()))?;
                Ok(batch)
            })
            .await
            .map_err(|e| {
                DataFusionError::Execution(format!(
                    "Failed to execute query {sql_clone} on sqlite: {e:?}"
                ))
            })?;

        let memory_stream = MemoryStream::try_new(vec![batch], projected_schema, None)?;
        Ok(Box::pin(memory_stream))
    }
}

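/// Owned copy of the metadata from `rusqlite::Column`. The name and declared type
/// are copied into `String`s up front because the borrowed `Column` cannot be held
/// once the statement is queried.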
#[derive(Debug)]
struct OwnedColumn {
    name: String,
    decl_type: Option<String>,
}

fn sqlite_col_to_owned_col(sqlite_col: &Column) -> OwnedColumn {
    OwnedColumn {
        name: sqlite_col.name().to_string(),
        decl_type: sqlite_col.decl_type().map(|x| x.to_string()),
    }
}

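// Declared types with a length parameter, e.g. `CHAR(10)`, `VARCHAR(255)` or
// `VARBINARY(16)`, are matched case-insensitively by these patterns.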
static CHAR_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^(?i)char\(\d+\)$").unwrap());
static VARCHAR_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^(?i)varchar\(\d+\)$").unwrap());
static TEXT_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^(?i)text\(\d+\)$").unwrap());
static BINARY_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^(?i)binary\(\d+\)$").unwrap());
static VARBINARY_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"^(?i)varbinary\(\d+\)$").unwrap());

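/// Maps a lowercased SQLite declared column type to one of the five storage
/// classes (`Null`, `Integer`, `Real`, `Text`, `Blob`); unrecognized declarations
/// produce a `NotImplemented` error.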
fn decl_type_to_remote_type(decl_type: &str) -> DFResult<RemoteType> {
    if "null".eq(decl_type) {
        return Ok(RemoteType::Sqlite(SqliteType::Null));
    }
    if ["tinyint", "smallint", "int", "integer", "bigint"].contains(&decl_type) {
        return Ok(RemoteType::Sqlite(SqliteType::Integer));
    }
    if ["real", "float", "double"].contains(&decl_type) {
        return Ok(RemoteType::Sqlite(SqliteType::Real));
    }
    if ["text", "varchar", "char", "string"].contains(&decl_type) {
        return Ok(RemoteType::Sqlite(SqliteType::Text));
    }
    if CHAR_RE.is_match(decl_type) || VARCHAR_RE.is_match(decl_type) || TEXT_RE.is_match(decl_type)
    {
        return Ok(RemoteType::Sqlite(SqliteType::Text));
    }
    if ["binary", "varbinary", "tinyblob", "blob"].contains(&decl_type) {
        return Ok(RemoteType::Sqlite(SqliteType::Blob));
    }
    if BINARY_RE.is_match(decl_type) || VARBINARY_RE.is_match(decl_type) {
        return Ok(RemoteType::Sqlite(SqliteType::Blob));
    }
    Err(DataFusionError::NotImplemented(format!(
        "Unsupported sqlite decl type: {decl_type}",
    )))
}

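/// Builds a [`RemoteSchema`] for the prepared statement. Columns with a declared
/// type are mapped directly via [`decl_type_to_remote_type`]; columns without one
/// are inferred from the first non-NULL value encountered while scanning the
/// result rows. If any column can never be resolved, a `NotImplemented` error is
/// returned.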
fn build_remote_schema(columns: &[OwnedColumn], mut rows: Rows) -> DFResult<RemoteSchema> {
    let mut remote_field_map = HashMap::with_capacity(columns.len());
    let mut unknown_cols = vec![];
    for (col_idx, col) in columns.iter().enumerate() {
        if let Some(decl_type) = &col.decl_type {
            let remote_type = decl_type_to_remote_type(&decl_type.to_ascii_lowercase())?;
            remote_field_map.insert(col_idx, RemoteField::new(&col.name, remote_type, true));
        } else {
            unknown_cols.push(col_idx);
        }
    }

    // For columns with no declared type, scan the result rows until a non-NULL
    // value reveals the storage class.
    if !unknown_cols.is_empty() {
        while let Some(row) = rows.next().map_err(|e| {
            DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
        })? {
            let mut to_be_removed = vec![];
            for col_idx in unknown_cols.iter() {
                let value_ref = row.get_ref(*col_idx).map_err(|e| {
                    DataFusionError::Execution(format!(
                        "Failed to get value ref for column {col_idx}: {e:?}"
                    ))
                })?;
                match value_ref {
                    ValueRef::Null => {}
                    ValueRef::Integer(_) => {
                        remote_field_map.insert(
                            *col_idx,
                            RemoteField::new(
                                columns[*col_idx].name.clone(),
                                RemoteType::Sqlite(SqliteType::Integer),
                                true,
                            ),
                        );
                        to_be_removed.push(*col_idx);
                    }
                    ValueRef::Real(_) => {
                        remote_field_map.insert(
                            *col_idx,
                            RemoteField::new(
                                columns[*col_idx].name.clone(),
                                RemoteType::Sqlite(SqliteType::Real),
                                true,
                            ),
                        );
                        to_be_removed.push(*col_idx);
                    }
                    ValueRef::Text(_) => {
                        remote_field_map.insert(
                            *col_idx,
                            RemoteField::new(
                                columns[*col_idx].name.clone(),
                                RemoteType::Sqlite(SqliteType::Text),
                                true,
                            ),
                        );
                        to_be_removed.push(*col_idx);
                    }
                    ValueRef::Blob(_) => {
                        remote_field_map.insert(
                            *col_idx,
                            RemoteField::new(
                                columns[*col_idx].name.clone(),
                                RemoteType::Sqlite(SqliteType::Blob),
                                true,
                            ),
                        );
                        to_be_removed.push(*col_idx);
                    }
                }
            }
            for col_idx in to_be_removed.iter() {
                unknown_cols.retain(|&x| x != *col_idx);
            }
            if unknown_cols.is_empty() {
                break;
            }
        }
    }

    if !unknown_cols.is_empty() {
        return Err(DataFusionError::NotImplemented(format!(
            "Failed to infer sqlite decl type for columns: {unknown_cols:?}"
        )));
    }
    // Emit fields in their original column order.
    let remote_fields = remote_field_map
        .into_iter()
        .sorted_by_key(|entry| entry.0)
        .map(|entry| entry.1)
        .collect::<Vec<_>>();
    Ok(RemoteSchema::new(remote_fields))
}

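/// Drains the SQLite result rows into a single Arrow [`RecordBatch`]. Builders are
/// created for every field of the table schema so indices line up with the result
/// columns, but only the projected columns are appended to and finished.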
fn rows_to_batch(
    mut rows: Rows,
    table_schema: &SchemaRef,
    columns: Vec<OwnedColumn>,
    projection: Option<&Vec<usize>>,
) -> DFResult<RecordBatch> {
    let projected_schema = project_schema(table_schema, projection)?;
    // One builder per table column so that column indices stay aligned with the
    // result set; non-projected builders simply remain empty.
    let mut array_builders = vec![];
    for field in table_schema.fields() {
        let builder = make_builder(field.data_type(), 1000);
        array_builders.push(builder);
    }

    while let Some(row) = rows.next().map_err(|e| {
        DataFusionError::Execution(format!("Failed to get next row from sqlite: {e:?}"))
    })? {
        append_rows_to_array_builders(
            row,
            table_schema,
            &columns,
            projection,
            array_builders.as_mut_slice(),
        )?;
    }

    // Finish only the builders selected by the projection.
    let projected_columns = array_builders
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| projections_contains(projection, *idx))
        .map(|(_, mut builder)| builder.finish())
        .collect::<Vec<ArrayRef>>();
    Ok(RecordBatch::try_new(projected_schema, projected_columns)?)
}

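/// Downcasts the generic `ArrayBuilder` to the concrete builder type and appends
/// the row value for the given column index, appending NULL when the value is
/// absent.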
macro_rules! handle_primitive_type {
    ($builder:expr, $field:expr, $col:expr, $builder_ty:ty, $value_ty:ty, $row:expr, $index:expr) => {{
        let builder = $builder
            .as_any_mut()
            .downcast_mut::<$builder_ty>()
            .unwrap_or_else(|| {
                panic!(
                    "Failed to downcast builder to {} for {:?} and {:?}",
                    stringify!($builder_ty),
                    $field,
                    $col
                )
            });

        let v: Option<$value_ty> = $row.get($index).map_err(|e| {
            DataFusionError::Execution(format!(
                "Failed to get optional {} value for {:?} and {:?}: {e:?}",
                stringify!($value_ty),
                $field,
                $col
            ))
        })?;

        match v {
            Some(v) => builder.append_value(v),
            None => builder.append_null(),
        }
    }};
}

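/// Appends one SQLite row to the per-column builders, skipping columns that are
/// not in the projection. Only the Null, Int64, Float64, Utf8, and Binary Arrow
/// types are handled; anything else yields a `NotImplemented` error.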
fn append_rows_to_array_builders(
    row: &Row,
    table_schema: &SchemaRef,
    columns: &[OwnedColumn],
    projection: Option<&Vec<usize>>,
    array_builders: &mut [Box<dyn ArrayBuilder>],
) -> DFResult<()> {
    for (idx, field) in table_schema.fields.iter().enumerate() {
        if !projections_contains(projection, idx) {
            continue;
        }
        let builder = &mut array_builders[idx];
        let col = columns.get(idx);
        match field.data_type() {
            DataType::Null => {
                let builder = builder
                    .as_any_mut()
                    .downcast_mut::<NullBuilder>()
                    .expect("Failed to downcast builder to NullBuilder");
                builder.append_null();
            }
            DataType::Int64 => {
                handle_primitive_type!(builder, field, col, Int64Builder, i64, row, idx);
            }
            DataType::Float64 => {
                handle_primitive_type!(builder, field, col, Float64Builder, f64, row, idx);
            }
            DataType::Utf8 => {
                handle_primitive_type!(builder, field, col, StringBuilder, String, row, idx);
            }
            DataType::Binary => {
                handle_primitive_type!(builder, field, col, BinaryBuilder, Vec<u8>, row, idx);
            }
            _ => {
                return Err(DataFusionError::NotImplemented(format!(
                    "Unsupported data type {:?} for col: {:?}",
                    field.data_type(),
                    col
                )));
            }
        }
    }
    Ok(())
}