// rusqlite-ic 0.28.1
//
// Ergonomic wrapper for SQLite
// Documentation
//! Add, remove, or modify a collation
use std::cmp::Ordering;
use std::os::raw::{c_char, c_int, c_void};
use std::panic::{catch_unwind, UnwindSafe};
use std::ptr;
use std::slice;

use crate::ffi;
use crate::{str_to_cstring, Connection, InnerConnection, Result};

// FIXME: duplicated from function.rs.
/// Destructor callback: takes back ownership of a heap value of type `T`
/// that was handed out as a raw pointer, and drops it.
///
/// # Safety
///
/// `p` must originate from `Box::<T>::into_raw` and must not be used again
/// after this call.
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
    let reclaimed: Box<T> = Box::from_raw(p.cast::<T>());
    drop(reclaimed);
}

impl Connection {
    /// Registers `x_compare` as the collation named `collation_name`,
    /// replacing any collation already registered under that name.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the underlying connection reports while
    /// registering the collation.
    #[inline]
    pub fn create_collation<C>(&self, collation_name: &str, x_compare: C) -> Result<()>
    where
        C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
    {
        let mut inner = self.db.borrow_mut();
        inner.create_collation(collation_name, x_compare)
    }

    /// Installs `x_coll_needed` as the "collation needed" callback of this
    /// connection (see `sqlite3_collation_needed`).
    ///
    /// The callback receives a connection and the name of the collation that
    /// is being asked for.
    ///
    /// # Errors
    ///
    /// Propagates the error reported by the underlying connection, if any.
    #[inline]
    pub fn collation_needed(
        &self,
        x_coll_needed: fn(&Connection, &str) -> Result<()>,
    ) -> Result<()> {
        let mut inner = self.db.borrow_mut();
        inner.collation_needed(x_coll_needed)
    }

    /// Unregisters the collation named `collation_name`.
    ///
    /// # Errors
    ///
    /// Propagates the error reported by the underlying connection, if any.
    #[inline]
    pub fn remove_collation(&self, collation_name: &str) -> Result<()> {
        let mut inner = self.db.borrow_mut();
        inner.remove_collation(collation_name)
    }
}

impl InnerConnection {
    /// Registers `x_compare` with SQLite as the UTF-8 collation
    /// `collation_name`.
    ///
    /// The closure is boxed and its raw pointer is passed to
    /// `sqlite3_create_collation_v2` as user data, together with
    /// `free_boxed_value::<C>` as the `xDestroy` callback.
    ///
    /// # Errors
    ///
    /// Fails when `str_to_cstring` rejects `collation_name` or when SQLite
    /// returns an error code. The closure is dropped on both error paths.
    fn create_collation<C>(&mut self, collation_name: &str, x_compare: C) -> Result<()>
    where
        C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
    {
        // Compare trampoline handed to SQLite: `arg1` is the boxed closure,
        // `arg2`/`arg3` and `arg4`/`arg5` are the (length, pointer) pairs of
        // the two operands.
        unsafe extern "C" fn call_boxed_closure<C>(
            arg1: *mut c_void,
            arg2: c_int,
            arg3: *const c_void,
            arg4: c_int,
            arg5: *const c_void,
        ) -> c_int
        where
            C: Fn(&str, &str) -> Ordering,
        {
            // The user closure runs inside `catch_unwind` so that a panic is
            // caught here instead of unwinding out of this `extern "C"` fn.
            let r = catch_unwind(|| {
                let boxed_f: *mut C = arg1.cast::<C>();
                assert!(!boxed_f.is_null(), "Internal error - null function pointer");
                // Invalid UTF-8 in either operand is replaced, not rejected.
                let s1 = {
                    let c_slice = slice::from_raw_parts(arg3.cast::<u8>(), arg2 as usize);
                    String::from_utf8_lossy(c_slice)
                };
                let s2 = {
                    let c_slice = slice::from_raw_parts(arg5.cast::<u8>(), arg4 as usize);
                    String::from_utf8_lossy(c_slice)
                };
                (*boxed_f)(s1.as_ref(), s2.as_ref())
            });

            match r {
                Ok(Ordering::Less) => -1,
                Ok(Ordering::Equal) => 0,
                Ok(Ordering::Greater) => 1,
                // FIXME: a panic cannot be reported through this callback's
                // return value; it is currently mapped to "less than".
                Err(_) => -1,
            }
        }

        // Convert the name BEFORE turning the closure into a raw pointer: if
        // the conversion fails, `?` returns early and `x_compare` is dropped
        // normally instead of being leaked behind `Box::into_raw`.
        let c_name = str_to_cstring(collation_name)?;
        let boxed_f: *mut C = Box::into_raw(Box::new(x_compare));
        let flags = ffi::SQLITE_UTF8;
        let r = unsafe {
            ffi::sqlite3_create_collation_v2(
                self.db(),
                c_name.as_ptr(),
                flags,
                boxed_f.cast::<c_void>(),
                Some(call_boxed_closure::<C>),
                Some(free_boxed_value::<C>),
            )
        };
        let res = self.decode_result(r);
        // The xDestroy callback is not called if the sqlite3_create_collation_v2()
        // function fails, so the closure has to be reclaimed here.
        if res.is_err() {
            drop(unsafe { Box::from_raw(boxed_f) });
        }
        res
    }

    /// Installs `x_coll_needed` through `sqlite3_collation_needed`.
    ///
    /// The function pointer itself is passed as the user-data argument and is
    /// recovered inside the trampoline with `transmute`.
    ///
    /// # Errors
    ///
    /// Returns the decoded SQLite error when the registration fails.
    fn collation_needed(
        &mut self,
        x_coll_needed: fn(&Connection, &str) -> Result<()>,
    ) -> Result<()> {
        use std::mem;
        // Trampoline handed to SQLite: `arg1` carries the user callback,
        // `arg2` the database handle and `arg3` the requested collation name.
        unsafe extern "C" fn collation_needed_callback(
            arg1: *mut c_void,
            arg2: *mut ffi::sqlite3,
            e_text_rep: c_int,
            arg3: *const c_char,
        ) {
            use std::ffi::CStr;
            use std::str;

            // NOTE(review): the user callback is skipped entirely unless
            // SQLite asks for UTF-8 - confirm this is the intended behaviour.
            if e_text_rep != ffi::SQLITE_UTF8 {
                // TODO: validate
                return;
            }

            let callback: fn(&Connection, &str) -> Result<()> = mem::transmute(arg1);
            // FIXME: this callback returns nothing, so neither a panic nor an
            // `Err` produced by the user callback can be reported to SQLite;
            // both are deliberately discarded here.
            let _ = catch_unwind(|| {
                let conn = Connection::from_handle(arg2).unwrap();
                let collation_name = {
                    let c_slice = CStr::from_ptr(arg3).to_bytes();
                    str::from_utf8(c_slice).expect("illegal collation sequence name")
                };
                callback(&conn, collation_name)
            });
        }

        let r = unsafe {
            ffi::sqlite3_collation_needed(
                self.db(),
                x_coll_needed as *mut c_void,
                Some(collation_needed_callback),
            )
        };
        self.decode_result(r)
    }

    /// Removes the UTF-8 collation `collation_name` by registering the name
    /// again with no compare function, no destructor and null user data.
    ///
    /// # Errors
    ///
    /// Fails when `str_to_cstring` rejects `collation_name` or when SQLite
    /// returns an error code.
    #[inline]
    fn remove_collation(&mut self, collation_name: &str) -> Result<()> {
        let c_name = str_to_cstring(collation_name)?;
        let r = unsafe {
            ffi::sqlite3_create_collation_v2(
                self.db(),
                c_name.as_ptr(),
                ffi::SQLITE_UTF8,
                ptr::null_mut(),
                None,
                None,
            )
        };
        self.decode_result(r)
    }
}

#[cfg(test)]
mod test {
    use crate::{Connection, Result};
    use fallible_streaming_iterator::FallibleStreamingIterator;
    use std::cmp::Ordering;
    use unicase::UniCase;

    /// Collation under test: orders two strings by their `UniCase` wrappers.
    fn unicase_compare(s1: &str, s2: &str) -> Ordering {
        let (lhs, rhs) = (UniCase::new(s1), UniCase::new(s2));
        lhs.cmp(&rhs)
    }

    /// Registers the collation up front, then runs the shared scenario.
    #[test]
    fn test_unicase() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_collation("unicase", unicase_compare)?;
        collate(db)
    }

    /// Shared scenario: two spellings that the `unicase` collation is
    /// expected to treat as equal must collapse into one DISTINCT row.
    fn collate(db: Connection) -> Result<()> {
        db.execute_batch(
            "CREATE TABLE foo (bar);
             INSERT INTO foo (bar) VALUES ('Maße');
             INSERT INTO foo (bar) VALUES ('MASSE');",
        )?;
        let mut stmt = db.prepare("SELECT DISTINCT bar COLLATE unicase FROM foo ORDER BY 1")?;
        let distinct_rows = stmt.query([])?.count()?;
        assert_eq!(distinct_rows, 1);
        Ok(())
    }

    /// "Collation needed" callback: registers `unicase` on demand and ignores
    /// every other collation name.
    fn collation_needed(db: &Connection, collation_name: &str) -> Result<()> {
        match collation_name {
            "unicase" => db.create_collation(collation_name, unicase_compare),
            _ => Ok(()),
        }
    }

    /// Only the callback is installed; the collation itself must be created
    /// lazily while the scenario runs.
    #[test]
    fn test_collation_needed() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.collation_needed(collation_needed)?;
        collate(db)
    }
}