datafusion_table_providers/sql/db_connection_pool/dbconnection/
sqliteconn.rs1use std::any::Any;
2
3use crate::sql::arrow_sql_gen::sqlite::rows_to_arrow;
4use crate::util::schema::SchemaValidator;
5use crate::UnsupportedTypeAction;
6use arrow::datatypes::SchemaRef;
7use arrow_schema::DataType;
8use async_trait::async_trait;
9use datafusion::execution::SendableRecordBatchStream;
10use datafusion::physical_plan::memory::MemoryStream;
11use datafusion::sql::TableReference;
12use rusqlite::ToSql;
13use snafu::prelude::*;
14use tokio_rusqlite::Connection;
15
16use super::AsyncDbConnection;
17use super::DbConnection;
18use super::Result;
19
20#[derive(Debug, Snafu)]
21pub enum Error {
22 #[snafu(display("ConnectionError {source}"))]
23 ConnectionError { source: tokio_rusqlite::Error },
24
25 #[snafu(display("Unable to query: {source}"))]
26 QueryError { source: rusqlite::Error },
27
28 #[snafu(display("Failed to convert query result to Arrow: {source}"))]
29 ConversionError {
30 source: crate::sql::arrow_sql_gen::sqlite::Error,
31 },
32}
33
34pub struct SqliteConnection {
35 pub conn: Connection,
36}
37
38impl SchemaValidator for SqliteConnection {
39 type Error = super::Error;
40
41 fn is_data_type_supported(data_type: &DataType) -> bool {
42 match data_type {
43 DataType::Dictionary(_, _) | DataType::Interval(_) | DataType::Map(_, _) => false,
44 DataType::List(inner_field)
45 | DataType::FixedSizeList(inner_field, _)
46 | DataType::LargeList(inner_field) => {
47 match inner_field.data_type() {
48 dt if dt.is_primitive() => true,
49 DataType::Utf8
50 | DataType::Binary
51 | DataType::Utf8View
52 | DataType::BinaryView
53 | DataType::Boolean => true,
54 _ => false, }
56 }
57 DataType::Struct(inner_fields) => inner_fields
58 .iter()
59 .all(|field| Self::is_data_type_supported(field.data_type())),
60 _ => true,
61 }
62 }
63
64 fn unsupported_type_error(data_type: &DataType, field_name: &str) -> Self::Error {
65 super::Error::UnsupportedDataType {
66 data_type: data_type.to_string(),
67 field_name: field_name.to_string(),
68 }
69 }
70}
71
72impl DbConnection<Connection, &'static (dyn ToSql + Sync)> for SqliteConnection {
73 fn as_any(&self) -> &dyn Any {
74 self
75 }
76
77 fn as_any_mut(&mut self) -> &mut dyn Any {
78 self
79 }
80
81 fn as_async(&self) -> Option<&dyn AsyncDbConnection<Connection, &'static (dyn ToSql + Sync)>> {
82 Some(self)
83 }
84}
85
86#[async_trait]
87impl AsyncDbConnection<Connection, &'static (dyn ToSql + Sync)> for SqliteConnection {
88 fn new(conn: Connection) -> Self {
89 SqliteConnection { conn }
90 }
91
92 async fn tables(&self, _schema: &str) -> Result<Vec<String>, super::Error> {
93 let tables = self
94 .conn
95 .call(move |conn| {
96 let mut stmt = conn.prepare(
97 "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
98 )?;
99 let rows = stmt.query_map([], |row| row.get::<_, String>(0))?;
100 let tables: Result<Vec<_>, rusqlite::Error> = rows.collect();
101 Ok(tables?)
102 })
103 .await
104 .boxed()
105 .context(super::UnableToGetTablesSnafu)?;
106
107 Ok(tables)
108 }
109
110 async fn schemas(&self) -> Result<Vec<String>, super::Error> {
111 Ok(vec!["main".to_string()])
112 }
113
114 async fn get_schema(
115 &self,
116 table_reference: &TableReference,
117 ) -> Result<SchemaRef, super::Error> {
118 let table_reference = table_reference.to_quoted_string();
119 let schema: SchemaRef = self
120 .conn
121 .call(move |conn| {
122 let mut stmt = conn.prepare(&format!("SELECT * FROM {table_reference} LIMIT 1"))?;
123 let column_count = stmt.column_count();
124 let rows = stmt.query([])?;
125 let rec = rows_to_arrow(rows, column_count, None)
126 .context(ConversionSnafu)
127 .map_err(to_tokio_rusqlite_error)?;
128 let schema = rec.schema();
129 Ok(schema)
130 })
131 .await
132 .boxed()
133 .context(super::UnableToGetSchemaSnafu)?;
134
135 Self::handle_unsupported_schema(&schema, UnsupportedTypeAction::Error)
136 }
137
138 async fn query_arrow(
139 &self,
140 sql: &str,
141 params: &[&'static (dyn ToSql + Sync)],
142 projected_schema: Option<SchemaRef>,
143 ) -> Result<SendableRecordBatchStream> {
144 let sql = sql.to_string();
145 let params = params.to_vec();
146
147 let rec = self
148 .conn
149 .call(move |conn| {
150 let mut stmt = conn.prepare(sql.as_str())?;
151 for (i, param) in params.iter().enumerate() {
152 stmt.raw_bind_parameter(i + 1, param)?;
153 }
154 let column_count = stmt.column_count();
155 let rows = stmt.raw_query();
156
157 let rec = rows_to_arrow(rows, column_count, projected_schema)
158 .context(ConversionSnafu)
159 .map_err(to_tokio_rusqlite_error)?;
160 Ok(rec)
161 })
162 .await
163 .context(ConnectionSnafu)?;
164
165 let schema = rec.schema();
166 let recs = if rec.num_rows() > 0 {
167 vec![rec]
168 } else {
169 vec![]
170 };
171 Ok(Box::pin(MemoryStream::try_new(recs, schema, None)?))
172 }
173
174 async fn execute(&self, sql: &str, params: &[&'static (dyn ToSql + Sync)]) -> Result<u64> {
175 let sql = sql.to_string();
176 let params = params.to_vec();
177
178 let rows_modified = self
179 .conn
180 .call(move |conn| {
181 let mut stmt = conn.prepare(sql.as_str())?;
182 for (i, param) in params.iter().enumerate() {
183 stmt.raw_bind_parameter(i + 1, param)?;
184 }
185 let rows_modified = stmt.raw_execute()?;
186 Ok(rows_modified)
187 })
188 .await
189 .context(ConnectionSnafu)?;
190 Ok(rows_modified as u64)
191 }
192}
193
194fn to_tokio_rusqlite_error(e: impl Into<Error>) -> tokio_rusqlite::Error {
195 tokio_rusqlite::Error::Other(Box::new(e.into()))
196}