// datafusion_table_providers/sql/db_connection_pool/dbconnection/sqliteconn.rs

use std::any::Any;

use crate::sql::arrow_sql_gen::sqlite::rows_to_arrow;
use crate::util::schema::SchemaValidator;
use crate::UnsupportedTypeAction;
use arrow::datatypes::SchemaRef;
use arrow_schema::DataType;
use async_trait::async_trait;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::memory::MemoryStream;
use datafusion::sql::TableReference;
use rusqlite::ToSql;
use snafu::prelude::*;
use tokio_rusqlite::Connection;

use super::AsyncDbConnection;
use super::DbConnection;
use super::Result;

20#[derive(Debug, Snafu)]
21pub enum Error {
22    #[snafu(display("ConnectionError {source}"))]
23    ConnectionError { source: tokio_rusqlite::Error },
24
25    #[snafu(display("Unable to query: {source}"))]
26    QueryError { source: rusqlite::Error },
27
28    #[snafu(display("Failed to convert query result to Arrow: {source}"))]
29    ConversionError {
30        source: crate::sql::arrow_sql_gen::sqlite::Error,
31    },
32}
33
34pub struct SqliteConnection {
35    pub conn: Connection,
36}
37
38impl SchemaValidator for SqliteConnection {
39    type Error = super::Error;
40
41    fn is_data_type_supported(data_type: &DataType) -> bool {
42        match data_type {
43            DataType::Dictionary(_, _) | DataType::Interval(_) | DataType::Map(_, _) => false,
44            DataType::List(inner_field)
45            | DataType::FixedSizeList(inner_field, _)
46            | DataType::LargeList(inner_field) => {
47                match inner_field.data_type() {
48                    dt if dt.is_primitive() => true,
49                    DataType::Utf8
50                    | DataType::Binary
51                    | DataType::Utf8View
52                    | DataType::BinaryView
53                    | DataType::Boolean => true,
54                    _ => false, // nested lists don't support anything else yet
55                }
56            }
57            DataType::Struct(inner_fields) => inner_fields
58                .iter()
59                .all(|field| Self::is_data_type_supported(field.data_type())),
60            _ => true,
61        }
62    }
63
64    fn unsupported_type_error(data_type: &DataType, field_name: &str) -> Self::Error {
65        super::Error::UnsupportedDataType {
66            data_type: data_type.to_string(),
67            field_name: field_name.to_string(),
68        }
69    }
70}
71
72impl DbConnection<Connection, &'static (dyn ToSql + Sync)> for SqliteConnection {
73    fn as_any(&self) -> &dyn Any {
74        self
75    }
76
77    fn as_any_mut(&mut self) -> &mut dyn Any {
78        self
79    }
80
81    fn as_async(&self) -> Option<&dyn AsyncDbConnection<Connection, &'static (dyn ToSql + Sync)>> {
82        Some(self)
83    }
84}
85
86#[async_trait]
87impl AsyncDbConnection<Connection, &'static (dyn ToSql + Sync)> for SqliteConnection {
88    fn new(conn: Connection) -> Self {
89        SqliteConnection { conn }
90    }
91
92    async fn tables(&self, _schema: &str) -> Result<Vec<String>, super::Error> {
93        let tables = self
94            .conn
95            .call(move |conn| {
96                let mut stmt = conn.prepare(
97                    "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
98                )?;
99                let rows = stmt.query_map([], |row| row.get::<_, String>(0))?;
100                let tables: Result<Vec<_>, rusqlite::Error> = rows.collect();
101                Ok(tables?)
102            })
103            .await
104            .boxed()
105            .context(super::UnableToGetTablesSnafu)?;
106
107        Ok(tables)
108    }
109
110    async fn schemas(&self) -> Result<Vec<String>, super::Error> {
111        Ok(vec!["main".to_string()])
112    }
113
114    async fn get_schema(
115        &self,
116        table_reference: &TableReference,
117    ) -> Result<SchemaRef, super::Error> {
118        let table_reference = table_reference.to_quoted_string();
119        let schema: SchemaRef = self
120            .conn
121            .call(move |conn| {
122                let mut stmt = conn.prepare(&format!("SELECT * FROM {table_reference} LIMIT 1"))?;
123                let column_count = stmt.column_count();
124                let rows = stmt.query([])?;
125                let rec = rows_to_arrow(rows, column_count, None)
126                    .context(ConversionSnafu)
127                    .map_err(to_tokio_rusqlite_error)?;
128                let schema = rec.schema();
129                Ok(schema)
130            })
131            .await
132            .boxed()
133            .context(super::UnableToGetSchemaSnafu)?;
134
135        Self::handle_unsupported_schema(&schema, UnsupportedTypeAction::Error)
136    }
137
138    async fn query_arrow(
139        &self,
140        sql: &str,
141        params: &[&'static (dyn ToSql + Sync)],
142        projected_schema: Option<SchemaRef>,
143    ) -> Result<SendableRecordBatchStream> {
144        let sql = sql.to_string();
145        let params = params.to_vec();
146
147        let rec = self
148            .conn
149            .call(move |conn| {
150                let mut stmt = conn.prepare(sql.as_str())?;
151                for (i, param) in params.iter().enumerate() {
152                    stmt.raw_bind_parameter(i + 1, param)?;
153                }
154                let column_count = stmt.column_count();
155                let rows = stmt.raw_query();
156
157                let rec = rows_to_arrow(rows, column_count, projected_schema)
158                    .context(ConversionSnafu)
159                    .map_err(to_tokio_rusqlite_error)?;
160                Ok(rec)
161            })
162            .await
163            .context(ConnectionSnafu)?;
164
165        let schema = rec.schema();
166        let recs = if rec.num_rows() > 0 {
167            vec![rec]
168        } else {
169            vec![]
170        };
171        Ok(Box::pin(MemoryStream::try_new(recs, schema, None)?))
172    }
173
174    async fn execute(&self, sql: &str, params: &[&'static (dyn ToSql + Sync)]) -> Result<u64> {
175        let sql = sql.to_string();
176        let params = params.to_vec();
177
178        let rows_modified = self
179            .conn
180            .call(move |conn| {
181                let mut stmt = conn.prepare(sql.as_str())?;
182                for (i, param) in params.iter().enumerate() {
183                    stmt.raw_bind_parameter(i + 1, param)?;
184                }
185                let rows_modified = stmt.raw_execute()?;
186                Ok(rows_modified)
187            })
188            .await
189            .context(ConnectionSnafu)?;
190        Ok(rows_modified as u64)
191    }
192}
193
194fn to_tokio_rusqlite_error(e: impl Into<Error>) -> tokio_rusqlite::Error {
195    tokio_rusqlite::Error::Other(Box::new(e.into()))
196}