use crate::database::Database;
use crate::error::Error;
use crate::ffi::ffi;
use crate::query_result::QueryResult;
use crate::value::Value;
use cxx::UniquePtr;
use std::cell::UnsafeCell;
use std::convert::TryInto;
pub struct PreparedStatement {
pub(crate) statement: UniquePtr<ffi::PreparedStatement>,
}
impl PreparedStatement {
pub fn get_statement_type(&self) -> ffi::StatementType {
ffi::prepared_statement_get_statement_type(&self.statement)
}
}
pub struct Connection<'a> {
conn: UnsafeCell<UniquePtr<ffi::Connection<'a>>>,
}
unsafe impl Send for Connection<'_> {}
unsafe impl Sync for Connection<'_> {}
impl<'a> Connection<'a> {
pub fn new(database: &'a Database) -> Result<Self, Error> {
let db = unsafe { (*database.db.get()).pin_mut() };
Ok(Connection {
conn: UnsafeCell::new(ffi::database_connect(db)?),
})
}
pub fn set_max_num_threads_for_exec(&mut self, num_threads: u64) {
ffi::connection_set_max_num_thread_for_exec(self.conn.get_mut().pin_mut(), num_threads);
}
pub fn get_max_num_threads_for_exec(&self) -> u64 {
ffi::connection_get_max_num_thread_for_exec(unsafe { (*self.conn.get()).pin_mut() })
}
pub fn prepare(&self, query: &str) -> Result<PreparedStatement, Error> {
let statement = ffi::connection_prepare(
unsafe { (*self.conn.get()).pin_mut() },
ffi::StringView::new(query),
)?;
if ffi::prepared_statement_is_success(&statement) {
Ok(PreparedStatement { statement })
} else {
Err(Error::FailedPreparedStatement(
ffi::prepared_statement_error_message(&statement),
))
}
}
pub fn query(&self, query: &str) -> Result<QueryResult<'a>, Error> {
let conn = unsafe { (*self.conn.get()).pin_mut() };
let result = ffi::connection_query(conn, ffi::StringView::new(query))?;
if ffi::query_result_is_success(&result) {
Ok(QueryResult { result })
} else {
Err(Error::FailedQuery(ffi::query_result_get_error_message(
&result,
)))
}
}
pub fn execute(
&self,
prepared_statement: &mut PreparedStatement,
params: Vec<(&str, Value)>,
) -> Result<QueryResult<'a>, Error> {
let mut cxx_params = ffi::new_params();
for (key, value) in params {
let ffi_value: cxx::UniquePtr<ffi::Value> = value.try_into()?;
ffi::query_params_insert(cxx_params.pin_mut(), key, ffi_value);
}
let conn = unsafe { (*self.conn.get()).pin_mut() };
let result =
ffi::connection_execute(conn, prepared_statement.statement.pin_mut(), cxx_params)?;
if ffi::query_result_is_success(&result) {
Ok(QueryResult { result })
} else {
Err(Error::FailedQuery(ffi::query_result_get_error_message(
&result,
)))
}
}
pub fn interrupt(&self) -> Result<(), Error> {
let conn = unsafe { (*self.conn.get()).pin_mut() };
Ok(ffi::connection_interrupt(conn)?)
}
pub fn set_query_timeout(&self, timeout_ms: u64) {
let conn = unsafe { (*self.conn.get()).pin_mut() };
ffi::connection_set_query_timeout(conn, timeout_ms);
}
}
#[cfg(test)]
mod tests {
    // Each test creates a fresh on-disk database inside its own temporary directory.
    use crate::database::SYSTEM_CONFIG_FOR_TESTS;
    use crate::{Connection, Database, Value};
    use anyhow::{Error, Result};

    /// A thread count set on the connection must be readable back unchanged.
    #[test]
    fn test_connection_threads() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let db = Database::new(temp_dir.path().join("test"), SYSTEM_CONFIG_FOR_TESTS)?;
        let mut conn = Connection::new(&db)?;
        conn.set_max_num_threads_for_exec(5);
        assert_eq!(conn.get_max_num_threads_for_exec(), 5);
        temp_dir.close()?;
        Ok(())
    }

    /// A syntax error must surface as an error whose message pinpoints the bad token.
    #[test]
    fn test_invalid_query() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let db = Database::new(temp_dir.path().join("test"), SYSTEM_CONFIG_FOR_TESTS)?;
        let conn = Connection::new(&db)?;
        conn.query("CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY(name));")?;
        conn.query("CREATE (:Person {name: 'Alice', age: 25});")?;
        conn.query("CREATE (:Person {name: 'Bob', age: 30});")?;
        let result: Error = conn
            .query("MATCH (a:Person RETURN a.name AS NAME, a.age AS AGE;")
            .expect_err("Invalid syntax in query should produce an error")
            .into();
        // Built with `concat!` and explicit `\n` so that re-indenting this file cannot
        // silently change the expected text.
        // NOTE(review): 17 spaces aligns the carets under `RETURN`: the reported offset
        // is 16, plus one for the opening quote of the echoed query. Confirm against the
        // parser's actual output if this assertion fails.
        assert_eq!(
            result.to_string(),
            concat!(
                "Query execution failed: Parser exception: ",
                "Invalid input <MATCH (a:Person RETURN>: expected rule oC_SingleQuery ",
                "(line: 1, offset: 16)\n",
                "\"MATCH (a:Person RETURN a.name AS NAME, a.age AS AGE;\"\n",
                "                 ^^^^^^",
            )
        );
        Ok(())
    }

    /// A single `query` call containing several statements must run all of them.
    #[test]
    fn test_multiple_statement_query() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let db = Database::new(temp_dir.path().join("test"), SYSTEM_CONFIG_FOR_TESTS)?;
        let conn = Connection::new(&db)?;
        conn.query("CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY(name));")?;
        conn.query(concat!(
            "CREATE (:Person {name: 'Alice', age: 25});\n",
            "CREATE (:Person {name: 'Bob', age: 30});",
        ))?;
        // Verify that both statements took effect, not just the first or the last.
        let names: Vec<Vec<Value>> = conn
            .query("MATCH (a:Person) RETURN a.name ORDER BY a.name;")?
            .collect();
        assert_eq!(
            names,
            vec![
                vec![Value::String("Alice".to_string())],
                vec![Value::String("Bob".to_string())],
            ]
        );
        Ok(())
    }

    /// Rows come back with the declared column types (INT16 -> `Value::Int16`).
    #[test]
    fn test_query_result() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let db = Database::new(temp_dir.path().join("test"), SYSTEM_CONFIG_FOR_TESTS)?;
        let conn = Connection::new(&db)?;
        conn.query("CREATE NODE TABLE Person(name STRING, age INT16, PRIMARY KEY(name));")?;
        conn.query("CREATE (:Person {name: 'Alice', age: 25});")?;
        // Compare the whole result set. Asserting inside a `for` loop would pass
        // vacuously if the query returned no rows.
        let rows: Vec<Vec<Value>> = conn
            .query("MATCH (a:Person) RETURN a.name AS NAME, a.age AS AGE;")?
            .collect();
        assert_eq!(
            rows,
            vec![vec![Value::String("Alice".to_string()), Value::Int16(25)]]
        );
        temp_dir.close()?;
        Ok(())
    }

    /// A prepared statement with a `$param` binds by name (without the `$`).
    #[test]
    fn test_params() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let db = Database::new(temp_dir.path().join("test"), SYSTEM_CONFIG_FOR_TESTS)?;
        let conn = Connection::new(&db)?;
        conn.query("CREATE NODE TABLE Person(name STRING, age INT16, PRIMARY KEY(name));")?;
        conn.query("CREATE (:Person {name: 'Alice', age: 25});")?;
        conn.query("CREATE (:Person {name: 'Bob', age: 30});")?;
        let mut statement = conn.prepare("MATCH (a:Person) WHERE a.age = $age RETURN a.name;")?;
        // Collect all rows so that an empty result fails the test instead of skipping it.
        let rows: Vec<Vec<Value>> = conn
            .execute(&mut statement, vec![("age", Value::Int16(25))])?
            .collect();
        assert_eq!(rows, vec![vec![Value::String("Alice".to_string())]]);
        temp_dir.close()?;
        Ok(())
    }

    /// One `Connection` queried concurrently from two scoped threads.
    #[test]
    fn test_multithreaded_single_conn() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let db = Database::new(temp_dir.path().join("test"), SYSTEM_CONFIG_FOR_TESTS)?;
        let conn = Connection::new(&db)?;
        conn.query("CREATE NODE TABLE Person(name STRING, age INT32, PRIMARY KEY(name));")?;
        conn.query("CREATE (:Person {name: 'Alice', age: 25});")?;
        conn.query("CREATE (:Person {name: 'Bob', age: 30});")?;
        let (alice, bob) = std::thread::scope(|s| -> Result<(Vec<Value>, Vec<Value>)> {
            let alice_thread = s.spawn(|| -> Result<Vec<Value>> {
                let mut result = conn.query(
                    "MATCH (a:Person) WHERE a.name = \"Alice\" RETURN a.name AS NAME, a.age AS AGE;",
                )?;
                Ok(result.next().unwrap())
            });
            let bob_thread = s.spawn(|| -> Result<Vec<Value>> {
                let mut result = conn.query(
                    "MATCH (a:Person) WHERE a.name = \"Bob\" RETURN a.name AS NAME, a.age AS AGE;",
                )?;
                Ok(result.next().unwrap())
            });
            Ok((alice_thread.join().unwrap()?, bob_thread.join().unwrap()?))
        })?;
        assert_eq!(alice, vec!["Alice".into(), 25.into()]);
        assert_eq!(bob, vec!["Bob".into(), 30.into()]);
        temp_dir.close()?;
        Ok(())
    }

    /// Two threads, each opening its own `Connection` to a shared `Database`.
    #[test]
    fn test_multithreaded_multiple_conn() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let db = Database::new(temp_dir.path().join("test"), SYSTEM_CONFIG_FOR_TESTS)?;
        let conn = Connection::new(&db)?;
        conn.query("CREATE NODE TABLE Person(name STRING, age INT32, PRIMARY KEY(name));")?;
        conn.query("CREATE (:Person {name: 'Alice', age: 25});")?;
        conn.query("CREATE (:Person {name: 'Bob', age: 30});")?;
        let (alice, bob) = std::thread::scope(|s| -> Result<(Vec<Value>, Vec<Value>)> {
            let alice_thread = s.spawn(|| -> Result<Vec<Value>> {
                let conn = Connection::new(&db)?;
                let mut result = conn.query(
                    "MATCH (a:Person) WHERE a.name = \"Alice\" RETURN a.name AS NAME, a.age AS AGE;",
                )?;
                Ok(result.next().unwrap())
            });
            let bob_thread = s.spawn(|| -> Result<Vec<Value>> {
                let conn = Connection::new(&db)?;
                let mut result = conn.query(
                    "MATCH (a:Person) WHERE a.name = \"Bob\" RETURN a.name AS NAME, a.age AS AGE;",
                )?;
                Ok(result.next().unwrap())
            });
            Ok((alice_thread.join().unwrap()?, bob_thread.join().unwrap()?))
        })?;
        assert_eq!(alice, vec!["Alice".into(), 25.into()]);
        assert_eq!(bob, vec!["Bob".into(), 30.into()]);
        temp_dir.close()?;
        Ok(())
    }

    /// Generates one `#[test]` per extension name. Each test loads
    /// `$LBUG_LOCAL_EXTENSIONS/<name>/build/lib<name>.lbug_extension`.
    /// The tests only compile with the `extension_tests` feature enabled.
    macro_rules! extension_tests {
        ($($name:ident,)*) => {
            $(
                #[test]
                #[cfg(feature = "extension_tests")]
                fn $name() -> Result<()> {
                    let temp_dir = tempfile::tempdir()?;
                    let db = Database::new(temp_dir.path().join("testdb"), SYSTEM_CONFIG_FOR_TESTS)?;
                    let conn = Connection::new(&db)?;
                    let directory = std::env::var("LBUG_LOCAL_EXTENSIONS")?;
                    // On Windows, normalise path separators to `/`.
                    // NOTE(review): presumably so that backslashes are not interpreted
                    // inside the quoted path in the LOAD EXTENSION statement — confirm.
                    let directory = if cfg!(windows) {
                        directory.replace('\\', "/")
                    } else {
                        directory
                    };
                    let name = stringify!($name);
                    conn.query(&format!("LOAD EXTENSION '{directory}/{name}/build/lib{name}.lbug_extension'"))?;
                    Ok(())
                }
            )*
        }
    }

    extension_tests! {
        fts, duckdb, httpfs, postgres, sqlite, json, delta, iceberg, vector,
    }
}