use enum_primitive::FromPrimitive;
use libc::{c_int, c_char};
use std::ffi as std_ffi;
use std::mem;
use std::ptr;
use std::slice;
use std::str;
use std::ffi::CStr;
use std::rc::Rc;
use time::Duration;
use self::SqliteOk::SQLITE_OK;
use self::Step::{SQLITE_ROW, SQLITE_DONE};
pub use super::{
SqliteError,
SqliteErrorCode,
SqliteResult,
};
pub use super::ColumnType;
pub use super::ColumnType::SQLITE_NULL;
use ffi;
// SQLite's success result code (SQLITE_OK = 0). The `enum_from_primitive!`
// wrapper also provides `from_i32` for the enum. In this file the variant is
// cast back with `SQLITE_OK as c_int` to compare against raw FFI return codes
// (see `decode_result` and `Database::drop`).
enum_from_primitive! {
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
#[allow(non_camel_case_types)]
#[allow(missing_docs)]
pub enum SqliteOk {
SQLITE_OK = 0
}
}
// SQLite log-level result codes: SQLITE_NOTICE (27) and SQLITE_WARNING (28).
// NOTE(review): this enum is not referenced anywhere in the visible code;
// presumably intended for a log callback elsewhere — confirm before removing.
enum_from_primitive! {
#[derive(Debug, PartialEq, Eq)]
#[allow(non_camel_case_types)]
enum SqliteLogLevel {
SQLITE_NOTICE = 27,
SQLITE_WARNING = 28,
}
}
/// Owner of the raw SQLite connection handle.
///
/// Shared through an `Rc` by a `DatabaseConnection` and every statement it
/// prepares. The handle is closed (see the `Drop` impl) once the last of
/// those owners goes away.
struct Database {
    handle: *mut ffi::sqlite3,
}
impl Drop for Database {
    /// Closes the connection handle, panicking if `sqlite3_close` does not
    /// report SQLITE_OK.
    fn drop(&mut self) {
        // sqlite3_close refuses to close (SQLITE_BUSY) while statements are
        // still unfinalized; a failure here is treated as a bug.
        let rc = unsafe { ffi::sqlite3_close(self.handle) };
        assert_eq!(rc, SQLITE_OK as c_int);
    }
}
/// An open SQLite database.
///
/// `detailed` decides whether errors carry the connection's
/// `sqlite3_errmsg` text; it starts `true` and is cleared by `ignore_detail`.
pub struct DatabaseConnection {
    db: Rc<Database>,
    detailed: bool,
}
/// Strategy used by `DatabaseConnection::new` to open a database.
pub trait Access {
    /// Opens a connection, writing the new handle through `db`, and returns
    /// the SQLite result code of the open call.
    fn open(self, db: *mut *mut ffi::sqlite3) -> c_int;
}
/// Returns `Some(x)` when `choice` is true and `None` otherwise.
fn maybe<T>(choice: bool, x: T) -> Option<T> {
    match choice {
        true => Some(x),
        false => None,
    }
}
use std::ffi::NulError;
impl From<NulError> for SqliteError {
    /// Reports a string with an interior NUL byte as `SQLITE_MISUSE`, so
    /// `try!` can propagate `CString::new` failures as `SqliteError`s.
    fn from(_err: NulError) -> SqliteError {
        let kind = SqliteErrorCode::SQLITE_MISUSE;
        let desc = "Sql string contained an internal 0 byte";
        SqliteError {
            kind: kind,
            desc: desc,
            detail: None,
        }
    }
}
impl DatabaseConnection {
pub fn new<A: Access>(access: A) -> SqliteResult<DatabaseConnection> {
let mut db = ptr::null_mut();
let result = access.open(&mut db);
match decode_result(result, "sqlite3_open_v2", Some(db)) {
Ok(()) => Ok(DatabaseConnection {
db: Rc::new(Database { handle: db}),
detailed: true,
}),
Err(err) => {
unsafe { ffi::sqlite3_close(db) };
Err(err)
}
}
}
pub fn ignore_detail(&mut self) {
self.detailed = false;
}
pub fn in_memory() -> SqliteResult<DatabaseConnection> {
struct InMemory;
impl Access for InMemory {
fn open(self, db: *mut *mut ffi::sqlite3) -> c_int {
let c_memory = str_charstar(":memory:");
unsafe { ffi::sqlite3_open(c_memory.as_ptr(), db) }
}
}
DatabaseConnection::new(InMemory)
}
pub fn prepare<'db:'st, 'st>(&'db self, sql: &str) -> SqliteResult<PreparedStatement> {
match self.prepare_with_offset(sql) {
Ok((cur, _)) => Ok(cur),
Err(e) => Err(e)
}
}
pub fn prepare_with_offset<'db:'st, 'st>(&'db self, sql: &str)
-> SqliteResult<(PreparedStatement, usize)> {
let mut stmt = ptr::null_mut();
let mut tail = ptr::null();
let z_sql = str_charstar(sql);
let n_byte = sql.len() as c_int;
let r = unsafe { ffi::sqlite3_prepare_v2(self.db.handle, z_sql.as_ptr(), n_byte, &mut stmt, &mut tail) };
match decode_result(r, "sqlite3_prepare_v2", maybe(self.detailed, self.db.handle)) {
Ok(()) => {
let offset = tail as usize - z_sql.as_ptr() as usize;
Ok((PreparedStatement { stmt: stmt , db: self.db.clone(), detailed: self.detailed }, offset))
},
Err(code) => Err(code)
}
}
pub fn errmsg(&mut self) -> String {
DatabaseConnection::_errmsg(self.db.handle)
}
fn _errmsg(db: *mut ffi::sqlite3) -> String {
let errmsg = unsafe { ffi::sqlite3_errmsg(db) };
charstar_str(&(errmsg)).unwrap_or("").to_string()
}
pub fn exec(&mut self, sql: &str) -> SqliteResult<()> {
let c_sql = try!(std_ffi::CString::new(sql.as_bytes()));
let result = unsafe {
ffi::sqlite3_exec(self.db.handle, c_sql.as_ptr(), None,
ptr::null_mut(), ptr::null_mut())
};
decode_result(result, "sqlite3_exec", maybe(self.detailed, self.db.handle))
}
pub fn changes(&self) -> u64 {
let dbh = self.db.handle;
let count = unsafe { ffi::sqlite3_changes(dbh) };
count as u64
}
pub fn busy_timeout(&mut self, d: Duration) -> SqliteResult<()> {
let ms = d.num_milliseconds() as i32;
let result = unsafe { ffi::sqlite3_busy_timeout(self.db.handle, ms) };
decode_result(result, "sqlite3_busy_timeout", maybe(self.detailed, self.db.handle))
}
pub fn last_insert_rowid(&self) -> i64 {
unsafe { ffi::sqlite3_last_insert_rowid(self.db.handle) }
}
pub unsafe fn expose(&mut self) -> *mut ffi::sqlite3 {
self.db.handle
}
}
/// Borrows a NUL-terminated C string as `&str`.
///
/// Returns `None` if the pointer is null or the bytes are not valid UTF-8.
/// The result's lifetime is tied to the reference to the pointer, not to the
/// C buffer. Callers must not hold it past the point where that buffer can
/// be freed or overwritten.
fn charstar_str<'a>(utf_bytes: &'a *const c_char) -> Option<&'a str> {
    if utf_bytes.is_null() {
        return None;
    }
    let c_str = unsafe { CStr::from_ptr(*utf_bytes) };
    // Validate instead of assuming UTF-8: SQLite does not guarantee text it
    // returns is well-formed, and `from_utf8_unchecked` on bad bytes is UB.
    str::from_utf8(c_str.to_bytes()).ok()
}
/// Converts `s` into an owned, NUL-terminated C string.
///
/// A `&str` containing an interior NUL byte cannot be represented as a C
/// string. In that case an empty C string is returned, so the content is
/// silently lost. Callers that must not lose data should call
/// `CString::new` themselves and handle the error.
#[inline]
pub fn str_charstar<'a>(s: &'a str) -> std_ffi::CString {
    // Build the fallback lazily so the common success path allocates once.
    std_ffi::CString::new(s.as_bytes())
        .unwrap_or_else(|_| std_ffi::CString::new("").unwrap())
}
/// A compiled SQL statement.
///
/// Keeps an `Rc` to the connection's `Database`, so the handle stays open for
/// as long as the statement exists. The statement is finalized when dropped.
pub struct PreparedStatement {
    db: Rc<Database>,
    stmt: *mut ffi::sqlite3_stmt,
    detailed: bool,
}
impl Drop for PreparedStatement {
    /// Finalizes the statement; the result code is not checked.
    fn drop(&mut self) {
        let stmt = self.stmt;
        unsafe { ffi::sqlite3_finalize(stmt) };
    }
}
/// Parameter index accepted by the `bind_*` methods; it is cast to `c_int`
/// and passed straight to `sqlite3_bind_*`.
pub type ParamIx = u16;
impl PreparedStatement {
    /// Returns a `ResultSet` cursor over this statement. Rows are produced
    /// by `ResultSet::step`, and the statement is reset when the cursor is
    /// dropped.
    pub fn execute(&mut self) -> ResultSet {
        let statement = self;
        ResultSet { statement: statement }
    }
}
impl PreparedStatement {
pub fn ignore_detail(&mut self) {
self.detailed = false;
}
fn detail_db(&mut self) -> Option<*mut ffi::sqlite3> {
if self.detailed {
let db = unsafe { ffi::sqlite3_db_handle(self.stmt) };
Some(db)
} else {
None
}
}
fn get_detail(&mut self) -> Option<String> {
self.detail_db().map(|db| DatabaseConnection::_errmsg(db))
}
pub fn bind_null(&mut self, i: ParamIx) -> SqliteResult<()> {
let ix = i as c_int;
let r = unsafe { ffi::sqlite3_bind_null(self.stmt, ix ) };
decode_result(r, "sqlite3_bind_null", self.detail_db())
}
pub fn bind_int(&mut self, i: ParamIx, value: i32) -> SqliteResult<()> {
let ix = i as c_int;
let r = unsafe { ffi::sqlite3_bind_int(self.stmt, ix, value) };
decode_result(r, "sqlite3_bind_int", self.detail_db())
}
pub fn bind_int64(&mut self, i: ParamIx, value: i64) -> SqliteResult<()> {
let ix = i as c_int;
let r = unsafe { ffi::sqlite3_bind_int64(self.stmt, ix, value) };
decode_result(r, "sqlite3_bind_int64", self.detail_db())
}
pub fn bind_double(&mut self, i: ParamIx, value: f64) -> SqliteResult<()> {
let ix = i as c_int;
let r = unsafe { ffi::sqlite3_bind_double(self.stmt, ix, value) };
decode_result(r, "sqlite3_bind_double", self.detail_db())
}
pub fn bind_text(&mut self, i: ParamIx, value: &str) -> SqliteResult<()> {
let ix = i as c_int;
let transient = unsafe { mem::transmute(-1 as isize) };
let c_value = str_charstar(value);
let len = value.len() as c_int;
let r = unsafe { ffi::sqlite3_bind_text(self.stmt, ix, c_value.as_ptr(), len, transient) };
decode_result(r, "sqlite3_bind_text", self.detail_db())
}
pub fn bind_blob(&mut self, i: ParamIx, value: &[u8]) -> SqliteResult<()> {
let ix = i as c_int;
let transient = unsafe { mem::transmute(-1 as isize) };
let len = value.len() as c_int;
let val = unsafe { mem::transmute(value.as_ptr()) };
let r = unsafe { ffi::sqlite3_bind_blob(self.stmt, ix, val, len, transient) };
decode_result(r, "sqlite3_bind_blob", self.detail_db())
}
pub fn clear_bindings(&mut self) {
unsafe { ffi::sqlite3_clear_bindings(self.stmt) };
}
pub fn bind_parameter_count(&mut self) -> ParamIx {
let count = unsafe { ffi::sqlite3_bind_parameter_count(self.stmt) };
count as ParamIx
}
pub unsafe fn expose(&mut self) -> *mut ffi::sqlite3_stmt {
self.stmt
}
pub fn changes(&self) -> u64 {
let dbh = self.db.handle;
let count = unsafe { ffi::sqlite3_changes(dbh) };
count as u64
}
}
/// Cursor over the rows of an executing `PreparedStatement`.
///
/// Mutably borrows the statement for its whole lifetime and resets it when
/// dropped.
pub struct ResultSet<'res> {
    statement: &'res mut PreparedStatement,
}
// Non-error codes returned by `sqlite3_step`: SQLITE_ROW (100) means a row is
// available, and SQLITE_DONE (101) means the statement has finished.
// `ResultSet::step` treats every other code as an error.
enum_from_primitive! {
#[derive(Debug, PartialEq, Eq)]
#[allow(non_camel_case_types)]
enum Step {
SQLITE_ROW = 100,
SQLITE_DONE = 101,
}
}
impl<'res> Drop for ResultSet<'res> {
    /// Resets the underlying statement via `sqlite3_reset`; the result code
    /// is not checked.
    fn drop(&mut self) {
        let stmt = self.statement.stmt;
        unsafe {
            ffi::sqlite3_reset(stmt);
        }
    }
}
impl<'res:'row, 'row> ResultSet<'res> {
    /// Advances the cursor.
    ///
    /// Returns `Ok(Some(row))` when a row is available and `Ok(None)` once
    /// the statement is done. Any other `sqlite3_step` result is an error.
    pub fn step(&'row mut self) -> SqliteResult<Option<ResultRow<'res, 'row>>> {
        let rc = unsafe { ffi::sqlite3_step(self.statement.stmt) };
        match Step::from_i32(rc) {
            None => Err(error_result(rc, "step", self.statement.get_detail())),
            Some(SQLITE_DONE) => Ok(None),
            Some(SQLITE_ROW) => Ok(Some(ResultRow { rows: self })),
        }
    }
}
/// The current row of a `ResultSet`.
///
/// Mutably borrows the `ResultSet`, so `step` cannot be called again while a
/// row is alive.
pub struct ResultRow<'res:'row, 'row> {
    rows: &'row mut ResultSet<'res>,
}
/// Column index used by the `column_*` accessors; column 0 is the first.
pub type ColIx = u32;
impl<'res, 'row> ResultRow<'res, 'row> {
    /// Number of columns in the result.
    pub fn column_count(&self) -> ColIx {
        let stmt = self.rows.statement.stmt;
        let result = unsafe { ffi::sqlite3_column_count(stmt) };
        result as ColIx
    }

    /// Calls `f` with the name of column `i`, or returns `default` when
    /// `charstar_str` yields no name for it.
    pub fn with_column_name<T, F: Fn(&str) -> T>(&mut self, i: ColIx, default: T, f: F) -> T {
        let stmt = self.rows.statement.stmt;
        let n = i as c_int;
        let result = unsafe { ffi::sqlite3_column_name(stmt, n) };
        match charstar_str(&result) {
            Some(name) => f(name),
            None => default
        }
    }

    /// Storage class of column `col`; unrecognised codes map to `SQLITE_NULL`.
    pub fn column_type(&self, col: ColIx) -> ColumnType {
        let stmt = self.rows.statement.stmt;
        let i_col = col as c_int;
        let result = unsafe { ffi::sqlite3_column_type(stmt, i_col) };
        ColumnType::from_i32(result).unwrap_or(SQLITE_NULL)
    }

    /// Column `col` as a 32-bit integer (`sqlite3_column_int`).
    pub fn column_int(&self, col: ColIx) -> i32 {
        let stmt = self.rows.statement.stmt;
        let i_col = col as c_int;
        unsafe { ffi::sqlite3_column_int(stmt, i_col) }
    }

    /// Column `col` as a 64-bit integer (`sqlite3_column_int64`).
    pub fn column_int64(&self, col: ColIx) -> i64 {
        let stmt = self.rows.statement.stmt;
        let i_col = col as c_int;
        unsafe { ffi::sqlite3_column_int64(stmt, i_col) }
    }

    /// Column `col` as a double (`sqlite3_column_double`).
    pub fn column_double(&self, col: ColIx) -> f64 {
        let stmt = self.rows.statement.stmt;
        let i_col = col as c_int;
        unsafe { ffi::sqlite3_column_double(stmt, i_col) }
    }

    /// Owned copy of column `col` as text; `None` for SQL NULL or bytes that
    /// are not valid UTF-8.
    pub fn column_text(&self, col: ColIx) -> Option<String> {
        self.column_str(col).map(|s| s.to_string())
    }

    /// Column `col` borrowed as `&str`; `None` for SQL NULL or invalid UTF-8.
    pub fn column_str<'a>(&'a self, col: ColIx) -> Option<&'a str> {
        self.column_slice(col).and_then(|slice| str::from_utf8(slice).ok())
    }

    /// Owned copy of column `col`'s bytes; `None` for SQL NULL.
    pub fn column_blob(&self, col: ColIx) -> Option<Vec<u8>> {
        self.column_slice(col).map(|bs| bs.to_vec())
    }

    /// Borrows column `col`'s bytes.
    ///
    /// Returns `None` for SQL NULL; an empty string or blob yields
    /// `Some(&[])`. The slice points into SQLite-owned memory and is borrowed
    /// from this row.
    pub fn column_slice<'a>(&'a self, col: ColIx) -> Option<&'a [u8]> {
        // Decide NULL-ness from the column type. `sqlite3_column_blob` also
        // returns a null pointer for zero-length values, so testing the
        // pointer alone reported '' and x'' as `None`.
        if let SQLITE_NULL = self.column_type(col) {
            return None;
        }
        let stmt = self.rows.statement.stmt;
        let i_col = col as c_int;
        let bs = unsafe { ffi::sqlite3_column_blob(stmt, i_col) } as *const ::libc::c_uchar;
        // Call column_bytes after column_blob so it measures the representation
        // column_blob just produced.
        let len = unsafe { ffi::sqlite3_column_bytes(stmt, i_col) } as usize;
        if bs.is_null() {
            // Zero-length value. A null pointer with a non-zero length is not
            // expected; report it as missing rather than fabricate data.
            if len == 0 {
                let empty: &'a [u8] = &[];
                return Some(empty);
            }
            return None;
        }
        Some(unsafe { slice::from_raw_parts(bs, len) })
    }
}
/// Turns an SQLite result code into a `SqliteResult`.
///
/// `SQLITE_OK` becomes `Ok(())`. Any other code becomes an error labelled
/// `desc`. When `detail_db` is given, the error also carries that
/// connection's error message.
pub fn decode_result(
    result: c_int,
    desc: &'static str,
    detail_db: Option<*mut ffi::sqlite3>,
) -> SqliteResult<()> {
    if result != SQLITE_OK as c_int {
        let detail = match detail_db {
            Some(db) => Some(DatabaseConnection::_errmsg(db)),
            None => None,
        };
        return Err(error_result(result, desc, detail));
    }
    Ok(())
}
/// Builds a `SqliteError` from a non-OK SQLite result code.
///
/// `result` is first looked up as-is. If it has no `SqliteErrorCode`
/// counterpart (for example an extended result code, which adds information
/// above the low 8 bits), the primary code in the low byte is used instead.
///
/// # Panics
/// If neither `result` nor its primary code maps to a `SqliteErrorCode`.
fn error_result(
    result: c_int,
    desc: &'static str,
    detail: Option<String>
) -> SqliteError {
    let kind = SqliteErrorCode::from_i32(result)
        .or_else(|| SqliteErrorCode::from_i32(result & 0xff))
        .expect("SQLite returned a result code with no SqliteErrorCode equivalent");
    SqliteError {
        kind: kind,
        desc: desc,
        detail: detail
    }
}
#[cfg(test)]
mod test_opening {
    use super::{DatabaseConnection, SqliteResult};
    use time::Duration;

    /// Opening an in-memory database succeeds.
    #[test]
    fn db_construct_typechecks() {
        assert!(DatabaseConnection::in_memory().is_ok())
    }

    /// A busy timeout can be set on a freshly opened connection.
    #[test]
    fn db_busy_timeout() {
        fn go() -> SqliteResult<()> {
            let mut conn = try!(DatabaseConnection::in_memory());
            let timeout = Duration::seconds(2);
            conn.busy_timeout(timeout)
        }
        let outcome = go();
        outcome.unwrap();
    }
}
// Tests for preparing statements, reading rows and error-detail reporting.
#[cfg(test)]
mod tests {
use super::{DatabaseConnection, SqliteResult, ResultSet};
use std::str;
// Preparing a trivial query on an in-memory database succeeds.
#[test]
fn stmt_new_types() {
fn go() -> SqliteResult<()> {
let db = try!(DatabaseConnection::in_memory());
let res = db.prepare("select 1 + 1").map( |_s| () );
res
}
go().unwrap();
}
// Helper: opens an in-memory database, prepares `sql`, executes it and
// hands the result set to `f`, returning whatever `f` returns.
fn with_query<T, F>(sql: &str, mut f: F) -> SqliteResult<T>
where F: FnMut(&mut ResultSet) -> T
{
let db = try!(DatabaseConnection::in_memory());
let mut s = try!(db.prepare(sql));
let mut rows = s.execute();
Ok(f(&mut rows))
}
// Steps through a two-row query, counting rows and summing column 0.
#[test]
fn query_two_rows() {
fn go() -> SqliteResult<(u32, i32)> {
let mut count = 0;
let mut sum = 0i32;
with_query("select 1
union all
select 2", |rows| {
loop {
match rows.step() {
Ok(Some(ref mut row)) => {
count += 1;
sum += row.column_int(0);
},
// Stops at the end of the rows, and also on any step error.
_ => break
}
}
(count, sum)
})
}
assert_eq!(go(), Ok((2, 3)))
}
// A SQL NULL column reads back as `None` from `column_text`.
#[test]
fn query_null_string() {
with_query("select null", |rows| {
match rows.step() {
Ok(Some(ref mut row)) => {
assert_eq!(row.column_text(0), None);
}
_ => { panic!("Expected a row"); }
}
}).unwrap();
}
// With detail enabled (the default), a prepare error carries SQLite's message.
#[test]
fn detailed_errors() {
let go = || -> SqliteResult<()> {
let db = try!(DatabaseConnection::in_memory());
try!(db.prepare("select bogus"));
Ok( () )
};
let err = go().err().unwrap();
assert_eq!(err.detail(), Some("no such column: bogus".to_string()))
}
// After `ignore_detail` on the connection, prepare errors carry no message.
#[test]
fn no_alloc_errors_db() {
let go = || {
let mut db = try!(DatabaseConnection::in_memory());
db.ignore_detail();
try!(db.prepare("select bogus"));
Ok( () )
};
let x: SqliteResult<()> = go();
let err = x.err().unwrap();
assert_eq!(err.detail(), None)
}
// After `ignore_detail` on a statement, a bind error (parameter 3 on a
// statement without parameters) carries no message.
#[test]
fn no_alloc_errors_stmt() {
let db = DatabaseConnection::in_memory().unwrap();
let mut stmt = db.prepare("select 1").unwrap();
stmt.ignore_detail();
let oops = stmt.bind_text(3, "abc");
assert_eq!(oops.err().unwrap().detail(), None)
}
// A blob whose bytes are not valid UTF-8 yields `None` from `column_str`;
// the final assert double-checks that the fixture bytes really are invalid.
#[test]
fn non_utf8_str() {
let mut stmt = DatabaseConnection::in_memory().unwrap().prepare("SELECT x'4546FF'").unwrap();
let mut rows = stmt.execute();
let row = rows.step().unwrap().unwrap();
assert_eq!(row.column_str(0), None);
assert!(str::from_utf8(&[0x45u8, 0x46, 0xff]).is_err());
}
}