use lazy_static::lazy_static;
use std::borrow::BorrowMut;
use std::collections::{HashMap, HashSet};
use std::ffi::{c_char, c_void, CStr, CString};
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use arrow::{
datatypes::SchemaRef as ArrowSchemaRef, ffi::FFI_ArrowSchema, ffi_stream::FFI_ArrowArrayStream,
record_batch::RecordBatch, record_batch::RecordBatchReader,
};
use cxx::{CxxString, CxxVector};
use duckdb::{ffi as cffi, Connection};
use sqlparser::ast as sqlast;
use crate::ast::Ident;
use crate::compile::sql::{create_table_as, select_star_from};
use crate::compile::ConnectionString;
use crate::runtime::functions::{Format, LazyFileRelation};
use crate::runtime::{
self,
error::{rt_unimplemented, Result},
normalize::Normalizer,
sql::{LazySQLParam, SQLEngine, SQLEnginePool},
};
use crate::runtime::{SQLEmbedded, SQLEngineType, SQLParam};
use crate::types::Type;
use crate::types::{arrow::ArrowRecordBatchRelation, Relation, Value};
/// FFI bridge to the C++ glue in `duckdb-extra.hpp`.
///
/// NOTE(review): the `*mut u32` pointers are deliberately opaque stand-ins for
/// DuckDB-internal C++ pointer types that cxx cannot name directly; both sides
/// cast them back to their real types.
#[cxx::bridge]
pub mod cppffi {
    extern "Rust" {
        /// Called from C++ during a replacement scan: builds an Arrow array
        /// stream for the relation behind `data`, projected to `fields`, and
        /// writes it into `dest` (an `FFI_ArrowArrayStream` on the C++ side).
        unsafe fn rust_build_array_stream(
            data: *mut u32,
            fields: &CxxVector<CxxString>,
            dest: *mut u32,
        );
    }
    unsafe extern "C++" {
        include!("queryscript/include/duckdb-extra.hpp");
        type ArrowArrayStreamWrapper;
        type Value;
        /// Returns a pointer to the C++ function that produces the array
        /// stream for the custom arrow scan.
        unsafe fn get_create_stream_fn() -> *mut u32;
        /// Wraps a raw pointer in a DuckDB pointer `Value`.
        unsafe fn duckdb_create_pointer(value: *mut u32) -> *mut Value;
        /// Registers the custom `arrow_scan_qs` table function on the connection.
        unsafe fn init_arrow_scan(connection_ptr: *mut u32);
    }
}
/// Rewrites identifiers in a query into DuckDB-compatible placeholders.
pub struct DuckDBNormalizer {
    // Original identifier -> replacement: "$N" for scalar parameters,
    // "__qs_duck_<n>" for relation parameters.
    params: HashMap<String, String>,
}

// Process-wide counter used to generate unique relation placeholder names.
// A plain (non-`mut`) static is correct here: atomics provide interior
// mutability, so no `static mut` or `unsafe` access is needed.
static NEXT_DUCKDB_PLACEHOLDER: AtomicUsize = AtomicUsize::new(0);

impl DuckDBNormalizer {
    /// Builds a normalizer mapping each scalar parameter to a positional `$N`
    /// placeholder (1-based, in slice order) and each relation to a globally
    /// unique `__qs_duck_<n>` table name.
    pub fn new(scalar_params: &[Ident], relations: &HashSet<String>) -> DuckDBNormalizer {
        let mut params: HashMap<String, String> = scalar_params
            .iter()
            .enumerate()
            .map(|(i, s)| (s.to_string(), format!("${}", i + 1)))
            .collect();
        for relation in relations {
            params.insert(
                relation.to_string(),
                format!(
                    "__qs_duck_{}",
                    NEXT_DUCKDB_PLACEHOLDER.fetch_add(1, Ordering::SeqCst)
                ),
            );
        }
        DuckDBNormalizer { params }
    }
}
impl Normalizer for DuckDBNormalizer {
    /// DuckDB quotes identifiers with double quotes.
    fn quote_style(&self) -> Option<char> {
        Some('"')
    }

    /// Looks up the replacement placeholder registered for `key`, if any.
    fn param(&self, key: &str) -> Option<&str> {
        self.params.get(key).map(String::as_str)
    }
}
/// A table that DuckDB should resolve via replacement scan instead of its
/// catalog: either in-memory Arrow data or a file DuckDB can read natively.
#[derive(Debug, Clone)]
pub enum ReplacementRelation {
    /// In-memory record batches scanned through the custom arrow scan.
    Arrow {
        data: Arc<dyn Relation>,
        schema: ArrowSchemaRef,
    },
    /// A file handed to DuckDB's own readers (read_csv_auto, read_parquet, ...).
    File {
        file_path: String,
        format: Format,
    },
}
// Placeholder table name -> relation to substitute during replacement scans.
type RelationMap = HashMap<String, ReplacementRelation>;

/// Scoped registration of relation placeholders: tracks the names this query
/// added (field 0) alongside the connection's shared map (field 1) so they can
/// be removed again on drop.
struct LocalRelations(HashSet<String>, Arc<Mutex<RelationMap>>);

impl LocalRelations {
    pub fn new(relations: Arc<Mutex<RelationMap>>) -> LocalRelations {
        LocalRelations(HashSet::new(), relations)
    }
}

// Deref to the local name set so callers can use HashSet methods directly.
impl std::ops::Deref for LocalRelations {
    type Target = HashSet<String>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl std::ops::DerefMut for LocalRelations {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
impl std::ops::Drop for LocalRelations {
    /// Unregisters every placeholder this scope added from the shared map.
    fn drop(&mut self) {
        let mut shared = self.1.lock().unwrap();
        for name in &self.0 {
            shared.remove(name);
        }
    }
}
/// SQL engine backed by a DuckDB connection (file-backed or in-memory).
#[derive(Debug)]
pub struct DuckDBEngine {
    conn: ExclusiveConnection,
}
impl DuckDBEngine {
fn new_conn(url: Option<Arc<crate::compile::ConnectionString>>) -> Result<Box<dyn SQLEngine>> {
use std::collections::hash_map::Entry;
let mut conns = DUCKDB_CONNS.lock().unwrap();
let wrapper = match conns.entry(url) {
Entry::Occupied(e) => e.into_mut(),
Entry::Vacant(e) => {
let conn_wrapper = ConnectionSingleton::new(e.key().clone())?;
e.insert(conn_wrapper)
}
};
let base_conn = wrapper.try_clone()?;
Ok(Box::new(DuckDBEngine::build(base_conn)))
}
fn build(conn: ExclusiveConnection) -> DuckDBEngine {
DuckDBEngine { conn }
}
fn eval_in_place(
&mut self,
query: &sqlast::Statement,
file_loads: HashMap<Ident, ReplacementRelation>,
params: HashMap<Ident, SQLParam>,
) -> Result<Arc<dyn Relation>> {
let conn_state = self.conn.get_state();
let mut scalar_params = Vec::new();
let mut relation_params = LocalRelations::new(conn_state.relations.clone());
for key in file_loads.keys() {
relation_params.insert(key.to_string());
}
for (key, param) in params.iter() {
match ¶m.value {
Value::Relation(_) => {
relation_params.insert(key.to_string());
}
Value::Fn(_) => {
return rt_unimplemented!("Function parameters");
}
_ => {
scalar_params.push(key.clone());
}
}
}
scalar_params.sort();
let normalizer = DuckDBNormalizer::new(&scalar_params, &relation_params);
let query = normalizer.normalize(&query).as_result()?;
{
let relations = &mut conn_state.relations.lock()?;
for (key, param) in file_loads.into_iter() {
relations.insert(
normalizer.params.get(&key.to_string()).unwrap().clone(),
param,
);
}
for (key, param) in params.iter() {
match ¶m.value {
Value::Relation(r) => {
relations.insert(
normalizer.params.get(&key.to_string()).unwrap().clone(),
ReplacementRelation::Arrow {
data: r.clone(),
schema: Arc::new((¶m.type_).try_into()?),
},
);
}
_ => {}
}
}
}
let query_string = format!("{}", query);
let duckdb_params: Vec<&dyn duckdb::ToSql> = scalar_params
.iter()
.map(|k| ¶ms.get(k).unwrap().value as &dyn duckdb::ToSql)
.collect();
let mut stmt = conn_state.conn.prepare(&query_string)?;
let query_result = stmt.query_arrow(duckdb_params.as_slice())?;
Ok(ArrowRecordBatchRelation::from_duckdb(query_result))
}
}
impl ArrowRecordBatchRelation {
    /// Materializes a DuckDB Arrow result into a relation by collecting all
    /// record batches eagerly.
    pub fn from_duckdb(query_result: duckdb::Arrow) -> Arc<dyn Relation> {
        ArrowRecordBatchRelation::new(
            query_result.get_schema(),
            Arc::new(query_result.collect::<Vec<RecordBatch>>()),
        )
    }
}
#[async_trait::async_trait]
impl SQLEngine for DuckDBEngine {
    /// Evaluates a query. File-backed lazy relations are forwarded directly to
    /// DuckDB's own readers via replacement scans (never materialized in
    /// memory); all other params are resolved first.
    async fn query(
        &mut self,
        query: &sqlast::Statement,
        params: HashMap<Ident, LazySQLParam>,
    ) -> Result<Arc<dyn Relation>> {
        let mut unresolved_params = HashMap::new();
        let mut file_loads = HashMap::new();
        for (name, param) in params.into_iter() {
            if let Some(file_load) = param.value.as_any().downcast_ref::<LazyFileRelation>() {
                file_loads.insert(
                    name,
                    ReplacementRelation::File {
                        file_path: file_load
                            .file_path
                            .clone()
                            .into_os_string()
                            .into_string()
                            .unwrap(),
                        format: file_load.format.clone(),
                    },
                );
            } else {
                unresolved_params.insert(name, param);
            }
        }
        let resolved_params = runtime::resolve_params(unresolved_params).await?;
        // The DuckDB call itself is blocking; run it on the "expensive" pool.
        runtime::expensive(|| self.eval_in_place(query, file_loads, resolved_params))
    }

    /// Executes a statement, discarding any result set.
    async fn exec(
        &mut self,
        stmt: &sqlast::Statement,
        params: HashMap<Ident, LazySQLParam>,
    ) -> Result<()> {
        self.query(stmt, params).await?;
        Ok(())
    }

    /// Creates `table` (optionally TEMPORARY) from `value` by binding it as a
    /// relation parameter and running CREATE TABLE ... AS SELECT * FROM it.
    async fn load(
        &mut self,
        table: &sqlast::ObjectName,
        value: Value,
        type_: Type,
        temporary: bool,
    ) -> Result<()> {
        let param_name = "__qs_load";
        let query = create_table_as(
            table.clone(),
            select_star_from(&Ident::from(param_name)),
            temporary,
        );
        let params = vec![(
            param_name.into(),
            LazySQLParam {
                name: param_name.into(),
                value: value.into(),
                type_,
            },
        )]
        .into_iter()
        .collect();
        self.query(&query, params).await?;
        Ok(())
    }

    /// Checks information_schema for a table with the given name.
    async fn table_exists(&mut self, table: &sqlast::ObjectName) -> Result<bool> {
        let conn_state = self.conn.get_state();
        let mut stmt = conn_state
            .conn
            .prepare("SELECT 1 FROM information_schema.tables WHERE table_name = ?")?;
        // Propagate query errors with `?` (matching `prepare` above) instead
        // of panicking on them via `.unwrap()`.
        let result = stmt.query_map(&[&table.to_string()], |row| row.get::<usize, u32>(0))?;
        Ok(result.count() > 0)
    }

    fn engine_type(&self) -> SQLEngineType {
        SQLEngineType::DuckDB
    }
}
/// Registers this module's replacement-scan machinery on a fresh connection:
/// the custom arrow scan function (via C++) and the replacement-scan callback
/// pointing at this connection's shared relation map.
///
/// SAFETY notes:
/// - The transmute reinterprets `duckdb::Connection` as the layout mirror in
///   `duckdb_repr` to reach the private raw db/connection handles; this is
///   only sound while `duckdb_repr` matches the duckdb crate's actual struct
///   layout — TODO(review): re-verify on every duckdb crate upgrade.
/// - The raw pointer to `conn_state.relations` handed to DuckDB must stay
///   valid for the connection's lifetime; `ExclusiveConnection` pins the
///   `ConnectionState` to keep its address stable.
fn initialize_duckdb_connection(conn_state: &mut ConnectionState) {
    unsafe {
        let conn: &duckdb_repr::Connection = std::mem::transmute(&mut conn_state.conn);
        let db_wrapper = conn.db.borrow();
        cppffi::init_arrow_scan(db_wrapper.con as *mut u32);
        cffi::duckdb_add_replacement_scan(
            db_wrapper.db,
            Some(replacement_scan_callback),
            &mut conn_state.relations as *mut _ as *mut c_void,
            None,
        );
    }
}
/// A DuckDB connection plus the replacement-scan relation map registered with
/// it (shared across cloned handles to the same underlying connection).
#[derive(Debug)]
struct ConnectionState {
    conn: Connection,
    relations: Arc<Mutex<RelationMap>>,
}
/// Owns a `ConnectionState` pinned on the heap. The pin matters:
/// `initialize_duckdb_connection` hands DuckDB a raw pointer into the state
/// (the relations map), so its address must never change even if this wrapper
/// is moved.
#[derive(Debug)]
struct ExclusiveConnection(Pin<Box<ConnectionState>>);

impl ExclusiveConnection {
    /// Wraps `conn` in pinned state and wires up the replacement-scan
    /// callback before handing it out.
    fn new(conn: Connection) -> ExclusiveConnection {
        let mut conn = Box::pin(ConnectionState {
            conn,
            relations: Arc::new(Mutex::new(HashMap::new())),
        });
        initialize_duckdb_connection(&mut conn);
        ExclusiveConnection(conn)
    }

    /// Clones the underlying DuckDB handle; the relation map is shared, but
    /// NOTE(review): the clone's pinned state registers no callback of its
    /// own — it appears to rely on duckdb-level sharing via `try_clone`.
    /// TODO confirm against duckdb crate semantics.
    fn try_clone(&mut self) -> Result<ExclusiveConnection> {
        let state = self.0.borrow_mut();
        Ok(ExclusiveConnection(Box::pin(ConnectionState {
            conn: state.conn.try_clone()?,
            relations: state.relations.clone(),
        })))
    }

    fn get_state(&mut self) -> &mut ConnectionState {
        self.0.borrow_mut()
    }
}

// SAFETY: presumably sound because each ExclusiveConnection wraps its own
// duckdb handle and the shared map is Mutex-guarded — TODO(review): confirm
// duckdb's raw connection handles tolerate cross-thread use.
unsafe impl Send for ExclusiveConnection {}
unsafe impl Sync for ExclusiveConnection {}
/// The canonical connection for a given database URL, kept in `DUCKDB_CONNS`;
/// engines clone handles off of it.
#[derive(Debug)]
struct ConnectionSingleton(ExclusiveConnection);

impl ConnectionSingleton {
    /// Opens the database: file-backed when `url` is present, in-memory
    /// otherwise.
    fn new(url: Option<Arc<crate::compile::ConnectionString>>) -> Result<ConnectionSingleton> {
        let conn = if let Some(url) = url {
            Connection::open(url.get_url().path())?
        } else {
            Connection::open_in_memory()?
        };
        Ok(Self(ExclusiveConnection::new(conn)))
    }

    /// Hands out a fresh handle sharing this singleton's database.
    fn try_clone(&mut self) -> Result<ExclusiveConnection> {
        self.0.try_clone()
    }
}
lazy_static! {
    // Global registry of connection singletons, keyed by URL (`None` is the
    // shared in-memory database).
    static ref DUCKDB_CONNS: Mutex<HashMap<Option<Arc<crate::compile::ConnectionString>>, ConnectionSingleton>> =
        Mutex::new(HashMap::new());
}
#[async_trait::async_trait]
impl SQLEnginePool for DuckDBEngine {
    /// Opens (or reuses) an engine bound to the given connection string.
    async fn new(url: Arc<ConnectionString>) -> Result<Box<dyn SQLEngine>> {
        DuckDBEngine::new_conn(Some(url))
    }

    /// Ensures the database exists by opening a connection to it; the engine
    /// itself is discarded.
    async fn create(url: Arc<ConnectionString>) -> Result<()> {
        Self::new(url).await.map(|_| ())
    }
}
impl SQLEmbedded for DuckDBEngine {
    /// Returns an engine on the shared in-memory database (the `None` key in
    /// `DUCKDB_CONNS`).
    fn new_embedded() -> Result<Box<dyn SQLEngine>> {
        DuckDBEngine::new_conn(None)
    }
}
/// Reinterprets the opaque pointer DuckDB handed back as the
/// `ReplacementRelation` it was created from.
///
/// # Safety
/// `data` must be a pointer originally produced from a `ReplacementRelation`
/// in the connection's relation map, and that entry must still be alive.
/// NOTE(review): the returned `'static` lifetime is a lie the callers must
/// uphold — the referent actually lives only as long as the map entry.
unsafe fn cast_relation_data(data: *mut u32) -> &'static ReplacementRelation {
    &*(data as *const ReplacementRelation)
}
/// DuckDB replacement-scan hook: when a query references an unknown table,
/// DuckDB calls this with the table name. If the name is in our relation map,
/// we redirect the scan — Arrow relations go through the custom
/// `arrow_scan_qs` function, file relations through DuckDB's native readers.
/// Unknown names simply return, letting DuckDB report the usual error.
///
/// `relation_map` is the `&mut Arc<Mutex<RelationMap>>` pointer registered in
/// `initialize_duckdb_connection`.
///
/// NOTE(review): the raw pointer to `relation` handed to DuckDB below refers
/// into the locked map; this assumes the map entry outlives the scan even
/// though the lock guard is released when this function returns — TODO
/// confirm the entry is not removed before DuckDB consumes it.
#[no_mangle]
pub unsafe extern "C" fn replacement_scan_callback(
    info: cffi::duckdb_replacement_scan_info,
    table_name: *const c_char,
    relation_map: *mut c_void,
) {
    let c_str: &CStr = unsafe { CStr::from_ptr(table_name) };
    let table_str: &str = match c_str.to_str() {
        Ok(s) => s,
        // Non-UTF-8 table names can't be ours; let DuckDB handle it.
        Err(_e) => return,
    };
    let relations = unsafe { &mut *(relation_map as *mut Arc<Mutex<RelationMap>>) };
    let relations = relations.lock().unwrap();
    let relation = match relations.get(table_str) {
        Some(relation) => relation,
        None => return,
    };
    match &relation {
        ReplacementRelation::Arrow { .. } => {
            let fn_name = CString::new("arrow_scan_qs").unwrap();
            cffi::duckdb_replacement_scan_set_function_name(info, fn_name.as_ptr());
            unsafe {
                // Pass three pointer-valued parameters to arrow_scan_qs: the
                // relation itself, the C++ stream-creation fn, and our
                // schema-export fn. DuckDB copies the values, so we destroy
                // our handles immediately after adding them.
                let get_data_fn = cppffi::get_create_stream_fn();
                let mut data_ptr = cppffi::duckdb_create_pointer(relation as *const _ as *mut u32)
                    as cffi::duckdb_value;
                let mut get_data_ptr =
                    cppffi::duckdb_create_pointer(get_data_fn) as cffi::duckdb_value;
                let mut get_schema_ptr =
                    cppffi::duckdb_create_pointer(get_schema as *mut u32) as cffi::duckdb_value;
                cffi::duckdb_replacement_scan_add_parameter(info, data_ptr);
                cffi::duckdb_replacement_scan_add_parameter(info, get_data_ptr);
                cffi::duckdb_replacement_scan_add_parameter(info, get_schema_ptr);
                cffi::duckdb_destroy_value(&mut data_ptr);
                cffi::duckdb_destroy_value(&mut get_data_ptr);
                cffi::duckdb_destroy_value(&mut get_schema_ptr);
            }
        }
        ReplacementRelation::File { file_path, format } => {
            // Delegate to DuckDB's built-in reader for the file format, with
            // the path as its single argument.
            let file_name = CString::new(file_path.as_str()).unwrap();
            let fn_name = CString::new(match format {
                Format::Csv => "read_csv_auto",
                Format::Json => "read_json_auto",
                Format::Parquet => "read_parquet",
            })
            .unwrap();
            cffi::duckdb_replacement_scan_set_function_name(info, fn_name.as_ptr());
            let mut fname_ptr = cffi::duckdb_create_varchar(file_name.as_ptr());
            cffi::duckdb_replacement_scan_add_parameter(info, fname_ptr);
            cffi::duckdb_destroy_value(&mut fname_ptr);
        }
    };
}
/// Exports the Arrow schema of the relation behind `data` into `schema_ptr`
/// (which must point to an `FFI_ArrowSchema`). Called from the C++ side of
/// the custom arrow scan; its address is passed through
/// `replacement_scan_callback` as a pointer parameter.
#[no_mangle]
pub extern "C" fn get_schema(data: *mut u32, schema_ptr: *mut u32) {
    let relation = unsafe { cast_relation_data(data) };
    let schema = match &relation {
        ReplacementRelation::Arrow { schema, .. } => schema,
        _ => unreachable!(
            "should only be in get_schema() if we have data for a replacement relation"
        ),
    };
    let schema_c = FFI_ArrowSchema::try_from(schema.as_ref());
    let schema_c = match schema_c {
        Ok(s) => s,
        Err(e) => {
            // Panicking across the FFI boundary is the only signal we have here.
            panic!("Failed to convert to arrow FFI schema: {:?}", e);
        }
    };
    // Overwrite the destination in place; mem::replace ensures the previous
    // contents (if any) are dropped properly.
    let dest_schema = unsafe { &mut *(schema_ptr as *mut FFI_ArrowSchema) };
    let _old = std::mem::replace(dest_schema, schema_c);
}
/// Builds an `FFI_ArrowArrayStream` over the Arrow relation behind `data`,
/// projected to the requested `fields`, and writes it into `dest` (which must
/// point to an `FFI_ArrowArrayStream`). Called from C++ via the cxx bridge.
fn rust_build_array_stream(data: *mut u32, fields: &CxxVector<CxxString>, dest: *mut u32) {
    let relation = unsafe { cast_relation_data(data) };
    let (schema, data) = match &relation {
        ReplacementRelation::Arrow { schema, data } => (schema.clone(), data.clone()),
        _ => unreachable!(
            "should only be in rust_build_array_stream() if we have data for a replacement relation"
        ),
    };
    let mut batch_reader = VecRecordBatchReader::new(schema, data);
    let schema = batch_reader.schema();
    // Map field name -> column index so DuckDB's requested column names can
    // be translated into a projection.
    let field_map = schema
        .all_fields()
        .iter()
        .enumerate()
        .map(|(i, f)| (f.name().clone(), i))
        .collect::<std::collections::BTreeMap<String, usize>>();
    let indices: Vec<usize> = fields
        .iter()
        .map(|f| *field_map.get(f.to_str().unwrap()).unwrap())
        .collect();
    batch_reader.set_projection(indices);
    let record_batch = Box::new(batch_reader) as Box<dyn arrow::record_batch::RecordBatchReader>;
    // Construct the stream directly — the previous `*Box::new(...)` was a
    // pointless heap round-trip.
    let record_batch_c = FFI_ArrowArrayStream::new(record_batch);
    let dest_record_batch = unsafe { &mut *(dest as *mut FFI_ArrowArrayStream) };
    let _old = std::mem::replace(dest_record_batch, record_batch_c);
}
/// A `RecordBatchReader` over the batches of an in-memory relation, with an
/// optional column projection.
struct VecRecordBatchReader {
    schema: ArrowSchemaRef,
    data: Arc<dyn Relation>,
    // Index of the next batch to yield.
    idx: usize,
    // Column indices to project each batch to; `None` yields full batches.
    projection: Option<Vec<usize>>,
}

impl VecRecordBatchReader {
    pub fn new(schema: ArrowSchemaRef, data: Arc<dyn Relation>) -> VecRecordBatchReader {
        VecRecordBatchReader {
            schema,
            data,
            idx: 0,
            projection: None,
        }
    }

    /// Sets the column projection; an empty index list means "no projection"
    /// (full batches), matching the original behavior.
    pub fn set_projection(&mut self, indices: Vec<usize>) {
        // Idiomatic emptiness check instead of `indices.len() > 0`.
        self.projection = (!indices.is_empty()).then_some(indices);
    }
}
impl Iterator for VecRecordBatchReader {
    type Item = Result<arrow::record_batch::RecordBatch, arrow::error::ArrowError>;

    /// Yields the next batch, applying the projection when one is set.
    fn next(
        &mut self,
    ) -> Option<Result<arrow::record_batch::RecordBatch, arrow::error::ArrowError>> {
        // Guard clause: stop once every batch has been handed out.
        if self.idx >= self.data.num_batches() {
            return None;
        }
        let batch = self.data.batch(self.idx).as_arrow_recordbatch();
        self.idx += 1;
        let item = match self.projection.as_ref() {
            Some(indices) => batch.project(indices),
            None => Ok(batch.clone()),
        };
        Some(item)
    }
}
impl arrow::record_batch::RecordBatchReader for VecRecordBatchReader {
    /// NOTE(review): this returns the relation's full schema even when a
    /// projection is set — consumers appear to rely on the projected batches
    /// only; confirm if the projected schema is ever needed here.
    fn schema(&self) -> arrow::datatypes::SchemaRef {
        self.schema.clone()
    }
}
/// Layout mirror of private types inside the `duckdb` crate, used by
/// `initialize_duckdb_connection` to transmute a `duckdb::Connection` and
/// reach its raw db/connection handles.
///
/// WARNING: field order, types, and count here must match the duckdb crate's
/// own definitions exactly — any drift makes the transmute undefined
/// behavior. Re-check on every duckdb version bump.
#[allow(unused)]
mod duckdb_repr {
    use arrow::datatypes::SchemaRef;
    use duckdb::ffi;
    use hashlink::LruCache;
    use std::cell::RefCell;
    use std::sync::Arc;
    // Mirror of duckdb's RawStatement.
    pub struct RawStatement {
        ptr: ffi::duckdb_prepared_statement,
        result: Option<ffi::duckdb_arrow>,
        schema: Option<SchemaRef>,
        statement_cache_key: Option<Arc<str>>,
    }
    // Mirror of duckdb's StatementCache.
    pub struct StatementCache(RefCell<LruCache<Arc<str>, RawStatement>>);
    // Mirror of duckdb's InnerConnection; `db`/`con` are the handles we need.
    pub struct InnerConnection {
        pub db: ffi::duckdb_database,
        pub con: ffi::duckdb_connection,
        owned: bool,
    }
    // Mirror of duckdb's Connection.
    pub struct Connection {
        pub db: RefCell<InnerConnection>,
        cache: StatementCache,
        path: Option<std::path::PathBuf>,
    }
}
/// Smoke test: constructing an ExclusiveConnection (which transmutes into
/// duckdb internals and registers FFI callbacks) must not crash.
#[test]
fn test_duckdb_init() {
    ExclusiveConnection::new(Connection::open_in_memory().unwrap());
}
/// Verifies that multiple connections (fresh opens and `try_clone`s) to the
/// same on-disk database can interleave DDL and reads without conflicts.
#[test]
fn test_duckdb_concurrency() {
    fn run_query(conn: &mut Connection, query: &str) -> Result<Arc<dyn Relation>> {
        let mut stmt = conn.prepare(query)?;
        let query_result = stmt.query_arrow([])?;
        let ret = ArrowRecordBatchRelation::from_duckdb(query_result);
        Ok(ret)
    }
    // Start from a clean database file; ignore failure if it doesn't exist.
    let _ = std::fs::remove_file("/tmp/test_duckdb_concurrency.duckdb");
    let url = crate::compile::ConnectionString::maybe_parse(
        None,
        "duckdb:///tmp/test_duckdb_concurrency.duckdb",
        &crate::ast::SourceLocation::Unknown,
    )
    .unwrap()
    .unwrap();
    let mut conn = Connection::open(url.get_url().path()).unwrap();
    run_query(&mut conn, "DROP TABLE IF EXISTS t").unwrap();
    run_query(&mut conn, "CREATE TABLE t AS SELECT 1 AS a").unwrap();
    // A second independent open plus a clone of it must both see t, and a
    // view created on one handle must be visible from the other.
    let mut conn1 = Connection::open(url.get_url().path()).unwrap();
    let mut conn2 = conn1.try_clone().unwrap();
    run_query(&mut conn2, "CREATE OR REPLACE VIEW x AS SELECT * FROM t").unwrap();
    run_query(&mut conn2, "SELECT * FROM x").unwrap();
    run_query(&mut conn1, "SELECT * FROM t").unwrap();
    run_query(&mut conn1, "SELECT * FROM x").unwrap();
}
/// Exercises the replacement-scan callback across cloned connection handles:
/// queries against nonexistent tables (dne_1, dne_2) trigger the callback,
/// and the results are deliberately discarded — the test is that nothing
/// crashes, including after clones are created and dropped.
/// (Renamed from `test_replacemnt_scan` to fix the typo; test fns have no
/// callers, so the rename is safe.)
#[test]
fn test_replacement_scan() {
    fn run_query(conn: &mut Connection, query: &str) -> Result<Arc<dyn Relation>> {
        let mut stmt = conn.prepare(query)?;
        let query_result = stmt.query_arrow([])?;
        let ret = ArrowRecordBatchRelation::from_duckdb(query_result);
        Ok(ret)
    }
    // Start from a clean database file; ignore failure if it doesn't exist.
    let _ = std::fs::remove_file("/tmp/test_duckdb_replacement.duckdb");
    let url = crate::compile::ConnectionString::maybe_parse(
        None,
        "duckdb:///tmp/test_duckdb_replacement.duckdb",
        &crate::ast::SourceLocation::Unknown,
    )
    .unwrap()
    .unwrap();
    let mut conn1 = ExclusiveConnection::new(Connection::open(url.get_url().path()).unwrap());
    let _ = run_query(&mut conn1.get_state().conn, "SELECT * FROM dne_1");
    // Clone several handles (two are dropped immediately) to make sure the
    // registered callback survives clone churn.
    let mut conn2 = conn1.try_clone().unwrap();
    conn1.try_clone().unwrap();
    conn1.try_clone().unwrap();
    let conn1 = conn1.get_state();
    let conn2 = conn2.get_state();
    let _ = run_query(&mut conn1.conn, "SELECT * FROM dne_1");
    let _ = run_query(&mut conn2.conn, "SELECT * FROM dne_2");
    let _ = run_query(&mut conn1.conn, "SELECT * FROM dne_1");
}
/// Sanity-checks the information_schema query used by `table_exists`: a table
/// that was never created yields zero rows.
#[test]
fn test_table_not_exists() {
    let mut conn1 = ExclusiveConnection::new(Connection::open_in_memory().unwrap());
    let conn = &mut conn1.get_state().conn;
    let foo = Into::<sqlast::ObjectName>::into(&Into::<Ident>::into("foo"));
    let mut stmt = conn
        .prepare("SELECT 1 FROM information_schema.tables WHERE table_name = ?")
        .unwrap();
    let result = stmt
        .query_map(&[&foo.to_string()], |row| row.get::<usize, u32>(0))
        .unwrap();
    assert!(result.count() == 0)
}