use std::sync::{Arc, Weak};
use crate::core::{Error, Result};
use crate::executor::context::ExecutionContext;
use crate::executor::query_cache::CachedPlanRef;
use crate::parser::Parser;
use super::database::{Database, DatabaseInnerHandle, FromValue};
use super::params::Params;
use super::rows::Rows;
/// A prepared SQL statement bound to a database.
///
/// The database is held only through a `Weak` handle, so an outstanding
/// `Statement` does not keep the database alive; once the inner handle has
/// been dropped, executing the statement returns an error.
#[derive(Clone)]
pub struct Statement {
    /// Weak handle to the owning database, upgraded on every call (see `get_db`).
    db_weak: Weak<DatabaseInnerHandle>,
    /// The SQL text the statement was prepared from. Returned by `sql()` and
    /// re-dispatched to the database when there is no cached plan.
    sql: String,
    /// Cached plan, present only when the SQL parsed to exactly one statement
    /// (or a plan for it was already in the query cache). `None` for any
    /// other statement count.
    plan: Option<CachedPlanRef>,
}
impl Statement {
pub(crate) fn new(
db_weak: Weak<DatabaseInnerHandle>,
sql: String,
db: &Database,
) -> Result<Self> {
let plan = {
let executor = db
.executor()
.lock()
.map_err(|_| Error::LockAcquisitionFailed("executor".to_string()))?;
if let Some(cached) = executor.query_cache().get(&sql) {
Some(cached)
} else {
let mut parser = Parser::new(&sql);
let mut program = parser
.parse_program()
.map_err(|e| Error::parse(e.to_string()))?;
if program.statements.len() == 1 {
let stmt = program.statements.pop().unwrap();
if matches!(stmt, crate::parser::ast::Statement::Expression(_)) {
return Err(Error::parse(format!(
"invalid SQL: unrecognised statement: {}",
sql
)));
}
let (has_params, param_count) = crate::executor::count_parameters(&stmt);
let stmt_arc = Arc::new(stmt);
let cached =
executor
.query_cache()
.put(&sql, stmt_arc, has_params, param_count);
Some(cached)
} else {
for s in &program.statements {
if matches!(s, crate::parser::ast::Statement::Expression(_)) {
return Err(Error::parse(format!(
"invalid SQL: unrecognised statement: {}",
sql
)));
}
}
None
}
}
};
Ok(Self { db_weak, sql, plan })
}
#[inline]
fn get_db(&self) -> Result<Database> {
self.db_weak
.upgrade()
.map(Database::from_inner)
.ok_or_else(|| Error::internal("Database was dropped"))
}
pub fn execute<P: Params>(&self, params: P) -> Result<i64> {
let db = self.get_db()?;
if let Some(plan) = &self.plan {
let executor = db
.executor()
.lock()
.map_err(|_| Error::LockAcquisitionFailed("executor".to_string()))?;
let ctx = ExecutionContext::with_params(params.into_params());
let result = executor.execute_with_cached_plan(plan, &ctx)?;
Ok(result.rows_affected())
} else {
db.execute(&self.sql, params)
}
}
pub fn query<P: Params>(&self, params: P) -> Result<Rows> {
let db = self.get_db()?;
if let Some(plan) = &self.plan {
let executor = db
.executor()
.lock()
.map_err(|_| Error::LockAcquisitionFailed("executor".to_string()))?;
let ctx = ExecutionContext::with_params(params.into_params());
let result = executor.execute_with_cached_plan(plan, &ctx)?;
Ok(Rows::new(result))
} else {
db.query(&self.sql, params)
}
}
pub fn query_one<T: FromValue, P: Params>(&self, params: P) -> Result<T> {
let row = self.query(params)?.next().ok_or(Error::NoRowsReturned)??;
row.get(0)
}
pub fn query_opt<T: FromValue, P: Params>(&self, params: P) -> Result<Option<T>> {
match self.query(params)?.next() {
None => Ok(None),
Some(Err(e)) => Err(e),
Some(Ok(row)) => Ok(Some(row.get(0)?)),
}
}
pub fn sql(&self) -> &str {
&self.sql
}
#[cfg(feature = "ffi")]
pub(crate) fn ast_statement(&self) -> Option<&Arc<crate::parser::ast::Statement>> {
self.plan.as_ref().map(|p| &p.statement)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory database containing an empty `users` table.
    fn users_db() -> Database {
        let db = Database::open_in_memory().unwrap();
        db.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)", ())
            .unwrap();
        db
    }

    /// Opens an in-memory `users` table holding Alice, Bob and Charlie at ids 1-3.
    fn seeded_users_db() -> Database {
        let db = users_db();
        db.execute(
            "INSERT INTO users VALUES ($1, $2), ($3, $4), ($5, $6)",
            (1, "Alice", 2, "Bob", 3, "Charlie"),
        )
        .unwrap();
        db
    }

    #[test]
    fn test_prepared_statement_execute() {
        let db = users_db();
        let stmt = db.prepare("INSERT INTO users VALUES ($1, $2)").unwrap();
        for (id, name) in [(1, "Alice"), (2, "Bob"), (3, "Charlie")] {
            stmt.execute((id, name)).unwrap();
        }
        let count: i64 = db.query_one("SELECT COUNT(*) FROM users", ()).unwrap();
        assert_eq!(count, 3);
    }

    #[test]
    fn test_prepared_statement_query() {
        let db = seeded_users_db();
        let stmt = db.prepare("SELECT name FROM users WHERE id = $1").unwrap();
        for (id, expected) in [(1, "Alice"), (2, "Bob"), (3, "Charlie")] {
            let name: String = stmt.query_one((id,)).unwrap();
            assert_eq!(name, expected);
        }
    }

    #[test]
    fn test_prepared_statement_query_one_no_rows() {
        // `query_one` must report an empty result as `NoRowsReturned`, not succeed.
        let db = users_db();
        let stmt = db.prepare("SELECT name FROM users WHERE id = $1").unwrap();
        let result = stmt.query_one::<String, _>((1,));
        assert!(matches!(result, Err(Error::NoRowsReturned)));
    }

    #[test]
    fn test_prepared_statement_query_opt() {
        let db = seeded_users_db();
        let stmt = db.prepare("SELECT name FROM users WHERE id = $1").unwrap();
        let found: Option<String> = stmt.query_opt((1,)).unwrap();
        assert_eq!(found, Some("Alice".to_string()));
        let missing: Option<String> = stmt.query_opt((999,)).unwrap();
        assert_eq!(missing, None);
    }

    #[test]
    fn test_prepared_statement_sql() {
        let db = Database::open_in_memory().unwrap();
        let stmt = db.prepare("SELECT 1").unwrap();
        assert_eq!(stmt.sql(), "SELECT 1");
    }
}