use crate::{CypherResult, Error, Result};
use crate::query_builder::CypherQuery;
#[cfg(not(feature = "bundled-extension"))]
use std::path::PathBuf;
use std::path::Path;
/// A SQLite connection on which the GraphQLite extension has been loaded, so Cypher queries
/// can be run alongside ordinary SQL.
///
/// Construct it with [`Connection::open`], [`Connection::open_in_memory`] or
/// [`Connection::from_rusqlite`].
#[derive(Debug)]
pub struct Connection {
    /// The wrapped rusqlite connection; all queries are issued through it.
    conn: rusqlite::Connection,
}
impl Connection {
    /// Opens (or creates) the SQLite database at `path` and loads the GraphQLite extension.
    ///
    /// # Errors
    /// Fails if the database cannot be opened or the extension cannot be loaded.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        let conn = rusqlite::Connection::open(path)?;
        Self::from_rusqlite(conn)
    }

    /// Opens a fresh in-memory database and loads the GraphQLite extension.
    ///
    /// # Errors
    /// Fails if the in-memory database cannot be created or the extension cannot be loaded.
    pub fn open_in_memory() -> Result<Self> {
        let conn = rusqlite::Connection::open_in_memory()?;
        Self::from_rusqlite(conn)
    }

    /// Wraps an existing rusqlite connection after loading the extension that is bundled
    /// into this crate.
    ///
    /// # Errors
    /// Fails if the bundled extension cannot be loaded.
    #[cfg(feature = "bundled-extension")]
    pub fn from_rusqlite(conn: rusqlite::Connection) -> Result<Self> {
        crate::platform::load_bundled_extension(&conn)?;
        Ok(Connection { conn })
    }

    /// Wraps an existing rusqlite connection after locating the extension library on disk
    /// (via `find_extension`) and loading it.
    ///
    /// # Errors
    /// Fails if no extension library is found or it cannot be loaded and verified.
    #[cfg(not(feature = "bundled-extension"))]
    pub fn from_rusqlite(conn: rusqlite::Connection) -> Result<Self> {
        let extension_path = find_extension()?;
        load_extension(&conn, &extension_path)?;
        Ok(Connection { conn })
    }

    /// Opens the database at `path` and loads the extension from the explicit
    /// `extension_path` instead of searching the default locations.
    ///
    /// # Errors
    /// Fails if the database cannot be opened or the extension cannot be loaded and verified.
    #[cfg(not(feature = "bundled-extension"))]
    pub fn open_with_extension<P: AsRef<Path>, E: AsRef<Path>>(
        path: P,
        extension_path: E,
    ) -> Result<Self> {
        let conn = rusqlite::Connection::open(path)?;
        load_extension(&conn, extension_path.as_ref())?;
        Ok(Connection { conn })
    }

    /// Runs a Cypher `query` through the extension's `cypher()` SQL function.
    ///
    /// A SQL `NULL` result yields [`CypherResult::empty`].
    ///
    /// # Errors
    /// Returns [`Error::Cypher`] when the extension reports an error, and propagates errors
    /// from SQLite and from `CypherResult::from_json`.
    pub fn cypher(&self, query: &str) -> Result<CypherResult> {
        let output: Option<String> =
            self.conn
                .query_row("SELECT cypher(?1)", [query], |row| row.get(0))?;
        Self::decode_cypher_output(output)
    }

    /// Runs a Cypher `query` with `params` serialized to JSON and passed as the second
    /// argument of `cypher()`.
    #[deprecated(since = "0.4.0", note = "Use cypher_builder() instead")]
    pub fn cypher_with_params(
        &self,
        query: &str,
        params: &serde_json::Value,
    ) -> Result<CypherResult> {
        self.execute_cypher_with_params(query, params)
    }

    /// Shared implementation for parameterized Cypher execution.
    ///
    /// # Errors
    /// Returns [`Error::Cypher`] if `params` cannot be serialized or the extension reports an
    /// error, and propagates errors from SQLite and from `CypherResult::from_json`.
    pub(crate) fn execute_cypher_with_params(
        &self,
        query: &str,
        params: &serde_json::Value,
    ) -> Result<CypherResult> {
        let params_json = serde_json::to_string(params)
            .map_err(|e| Error::Cypher(format!("Failed to serialize params: {}", e)))?;
        let output: Option<String> = self.conn.query_row(
            "SELECT cypher(?1, ?2)",
            rusqlite::params![query, params_json],
            |row| row.get(0),
        )?;
        Self::decode_cypher_output(output)
    }

    /// Starts a [`CypherQuery`] builder for `query` bound to this connection.
    pub fn cypher_builder<'a>(&'a self, query: &'a str) -> CypherQuery<'a> {
        CypherQuery::new(self, query)
    }

    /// Executes a single parameterless SQL statement and returns the value reported by
    /// rusqlite's `Connection::execute` (the number of changed rows).
    ///
    /// # Errors
    /// Propagates any SQLite error.
    pub fn execute(&self, sql: &str) -> Result<usize> {
        Ok(self.conn.execute(sql, [])?)
    }

    /// Borrows the underlying rusqlite connection for direct SQL access.
    pub fn sqlite_connection(&self) -> &rusqlite::Connection {
        &self.conn
    }

    /// Turns the raw text returned by the `cypher()` SQL function into a [`CypherResult`].
    ///
    /// The extension reports failures in-band. Text starting with `Error` or `{"error"` is
    /// converted into an error rather than parsed as a result set. `NULL` means an empty result.
    fn decode_cypher_output(output: Option<String>) -> Result<CypherResult> {
        match output {
            None => Ok(CypherResult::empty()),
            Some(text) if text.starts_with("Error") || text.starts_with("{\"error\"") => {
                Err(parse_structured_error(&text))
            }
            Some(text) => CypherResult::from_json(&text),
        }
    }
}
#[cfg(not(feature = "bundled-extension"))]
fn find_extension() -> Result<PathBuf> {
let ext_name = if cfg!(target_os = "macos") {
"graphqlite.dylib"
} else if cfg!(target_os = "windows") {
"graphqlite.dll"
} else {
"graphqlite.so"
};
let search_paths: Vec<PathBuf> = vec![
std::env::var("GRAPHQLITE_EXTENSION_PATH")
.ok()
.map(PathBuf::from)
.unwrap_or_default(),
PathBuf::from("build").join(ext_name),
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.parent()
.unwrap_or(Path::new("."))
.parent()
.unwrap_or(Path::new("."))
.join("build")
.join(ext_name),
PathBuf::from("/usr/local/lib").join(ext_name),
PathBuf::from("/usr/lib").join(ext_name),
];
for path in search_paths {
if path.exists() {
return Ok(path);
}
}
Err(Error::ExtensionNotFound(format!(
"Could not find {}. Build with 'make extension' or set GRAPHQLITE_EXTENSION_PATH",
ext_name
)))
}
/// Loads the GraphQLite extension library at `path` into `conn` and checks that it responds.
///
/// Extension loading is enabled only around the load call. It is disabled again whether or
/// not the load succeeds, so a failed load never leaves loading switched on.
///
/// # Errors
/// Propagates SQLite errors from enabling, loading or disabling extension loading, and from
/// calling `graphqlite_test()`. Returns [`Error::ExtensionNotFound`] if `graphqlite_test()`
/// does not report success.
#[cfg(not(feature = "bundled-extension"))]
fn load_extension(conn: &rusqlite::Connection, path: &std::path::Path) -> Result<()> {
    // The library's final file extension is stripped before loading. SQLite retries a name it
    // cannot open with the platform's shared-library suffix appended.
    // NOTE(review): confirm why the suffix is stripped rather than passed through unchanged.
    let load_path = path.with_extension("");

    // SAFETY: loading a native extension runs its initialisation code in-process. `path` is
    // either located by `find_extension` or supplied explicitly by the caller as the
    // GraphQLite library, which is trusted.
    unsafe {
        conn.load_extension_enable()?;
        let loaded = conn.load_extension(&load_path, None);
        // Always turn loading back off before reporting a load failure; an early `?` here
        // would otherwise skip the disable call.
        let disabled = conn.load_extension_disable();
        loaded?;
        disabled?;
    }

    // Sanity-check that the extension's SQL functions are actually registered.
    let test: String = conn.query_row("SELECT graphqlite_test()", [], |row| row.get(0))?;
    if !test.to_lowercase().contains("successfully") {
        return Err(Error::ExtensionNotFound(
            "Extension loaded but verification failed".to_string(),
        ));
    }
    Ok(())
}
fn parse_structured_error(s: &str) -> Error {
if let Ok(v) = serde_json::from_str::<serde_json::Value>(s) {
if let Some(msg) = v.get("error").and_then(|e| e.as_str()) {
return Error::Cypher(msg.to_string());
}
}
Error::Cypher(s.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the first prebuilt extension (macOS, then Linux naming) found in the `build/`
    /// directory two levels above this crate's manifest, if any.
    #[cfg(not(feature = "bundled-extension"))]
    fn get_test_extension_path() -> Option<std::path::PathBuf> {
        let build_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .unwrap()
            .parent()
            .unwrap()
            .join("build");
        ["graphqlite.dylib", "graphqlite.so"]
            .iter()
            .map(|file_name| build_dir.join(file_name))
            .find(|candidate| candidate.exists())
    }

    /// When a locally built extension is present, the search must succeed.
    #[test]
    #[cfg(not(feature = "bundled-extension"))]
    fn test_find_extension() {
        if get_test_extension_path().is_some() {
            assert!(find_extension().is_ok());
        }
    }

    /// With the bundled extension, an in-memory connection must open cleanly.
    #[test]
    #[cfg(feature = "bundled-extension")]
    fn test_bundled_connection() {
        let opened = Connection::open_in_memory();
        assert!(
            opened.is_ok(),
            "Failed to open connection: {:?}",
            opened.err()
        );
    }
}