rusqlite/
collation.rs

1//! Add, remove, or modify a collation
2use std::cmp::Ordering;
3use std::os::raw::{c_char, c_int, c_void};
4use std::panic::{catch_unwind, UnwindSafe};
5use std::ptr;
6use std::slice;
7
8use crate::ffi;
9use crate::{str_to_cstring, Connection, InnerConnection, Result};
10
// FIXME copy/paste from function.rs
/// Destructor callback handed to SQLite: takes back ownership of a `Box<T>`
/// that was leaked with `Box::into_raw` and drops it.
///
/// # Safety
///
/// `p` must come from `Box::<T>::into_raw` (with the same `T`) and must not
/// have been freed already; it is dangling once this function returns.
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
    let owned: Box<T> = Box::from_raw(p.cast());
    drop(owned);
}
15
16impl Connection {
17    /// Add or modify a collation.
18    #[inline]
19    pub fn create_collation<C>(&self, collation_name: &str, x_compare: C) -> Result<()>
20    where
21        C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
22    {
23        self.db
24            .borrow_mut()
25            .create_collation(collation_name, x_compare)
26    }
27
28    /// Collation needed callback
29    #[inline]
30    pub fn collation_needed(
31        &self,
32        x_coll_needed: fn(&Connection, &str) -> Result<()>,
33    ) -> Result<()> {
34        self.db.borrow_mut().collation_needed(x_coll_needed)
35    }
36
37    /// Remove collation.
38    #[inline]
39    pub fn remove_collation(&self, collation_name: &str) -> Result<()> {
40        self.db.borrow_mut().remove_collation(collation_name)
41    }
42}
43
44impl InnerConnection {
45    fn create_collation<C>(&mut self, collation_name: &str, x_compare: C) -> Result<()>
46    where
47        C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
48    {
49        unsafe extern "C" fn call_boxed_closure<C>(
50            arg1: *mut c_void,
51            arg2: c_int,
52            arg3: *const c_void,
53            arg4: c_int,
54            arg5: *const c_void,
55        ) -> c_int
56        where
57            C: Fn(&str, &str) -> Ordering,
58        {
59            let r = catch_unwind(|| {
60                let boxed_f: *mut C = arg1.cast::<C>();
61                assert!(!boxed_f.is_null(), "Internal error - null function pointer");
62                let s1 = {
63                    let c_slice = slice::from_raw_parts(arg3.cast::<u8>(), arg2 as usize);
64                    String::from_utf8_lossy(c_slice)
65                };
66                let s2 = {
67                    let c_slice = slice::from_raw_parts(arg5.cast::<u8>(), arg4 as usize);
68                    String::from_utf8_lossy(c_slice)
69                };
70                (*boxed_f)(s1.as_ref(), s2.as_ref())
71            });
72            let t = match r {
73                Err(_) => {
74                    return -1; // FIXME How ?
75                }
76                Ok(r) => r,
77            };
78
79            match t {
80                Ordering::Less => -1,
81                Ordering::Equal => 0,
82                Ordering::Greater => 1,
83            }
84        }
85
86        let boxed_f: *mut C = Box::into_raw(Box::new(x_compare));
87        let c_name = str_to_cstring(collation_name)?;
88        let flags = ffi::SQLITE_UTF8;
89        let r = unsafe {
90            ffi::sqlite3_create_collation_v2(
91                self.db(),
92                c_name.as_ptr(),
93                flags,
94                boxed_f.cast::<c_void>(),
95                Some(call_boxed_closure::<C>),
96                Some(free_boxed_value::<C>),
97            )
98        };
99        let res = self.decode_result(r);
100        // The xDestroy callback is not called if the sqlite3_create_collation_v2()
101        // function fails.
102        if res.is_err() {
103            drop(unsafe { Box::from_raw(boxed_f) });
104        }
105        res
106    }
107
108    fn collation_needed(
109        &mut self,
110        x_coll_needed: fn(&Connection, &str) -> Result<()>,
111    ) -> Result<()> {
112        use std::mem;
113        #[allow(clippy::needless_return)]
114        unsafe extern "C" fn collation_needed_callback(
115            arg1: *mut c_void,
116            arg2: *mut ffi::sqlite3,
117            e_text_rep: c_int,
118            arg3: *const c_char,
119        ) {
120            use std::ffi::CStr;
121            use std::str;
122
123            if e_text_rep != ffi::SQLITE_UTF8 {
124                // TODO: validate
125                return;
126            }
127
128            let callback: fn(&Connection, &str) -> Result<()> = mem::transmute(arg1);
129            let res = catch_unwind(|| {
130                let conn = Connection::from_handle(arg2).unwrap();
131                let collation_name = {
132                    let c_slice = CStr::from_ptr(arg3).to_bytes();
133                    str::from_utf8(c_slice).expect("illegal collation sequence name")
134                };
135                callback(&conn, collation_name)
136            });
137            if res.is_err() {
138                return; // FIXME How ?
139            }
140        }
141
142        let r = unsafe {
143            ffi::sqlite3_collation_needed(
144                self.db(),
145                x_coll_needed as *mut c_void,
146                Some(collation_needed_callback),
147            )
148        };
149        self.decode_result(r)
150    }
151
152    #[inline]
153    fn remove_collation(&mut self, collation_name: &str) -> Result<()> {
154        let c_name = str_to_cstring(collation_name)?;
155        let r = unsafe {
156            ffi::sqlite3_create_collation_v2(
157                self.db(),
158                c_name.as_ptr(),
159                ffi::SQLITE_UTF8,
160                ptr::null_mut(),
161                None,
162                None,
163            )
164        };
165        self.decode_result(r)
166    }
167}
168
#[cfg(test)]
mod test {
    use crate::{Connection, Result};
    use fallible_streaming_iterator::FallibleStreamingIterator;
    use std::cmp::Ordering;
    use unicase::UniCase;

    /// Comparator used for the `unicase` collation: orders strings by their
    /// `UniCase` wrappers.
    fn unicase_compare(s1: &str, s2: &str) -> Ordering {
        let (left, right) = (UniCase::new(s1), UniCase::new(s2));
        left.cmp(&right)
    }

    /// Collation-needed hook that supplies `unicase` on demand and ignores
    /// every other collation name.
    fn collation_needed(db: &Connection, collation_name: &str) -> Result<()> {
        match collation_name {
            "unicase" => db.create_collation(collation_name, unicase_compare),
            _ => Ok(()),
        }
    }

    /// Inserts 'Maße' and 'MASSE' and expects them to collapse into a single
    /// DISTINCT row under the `unicase` collation.
    fn collate(db: Connection) -> Result<()> {
        db.execute_batch(
            "CREATE TABLE foo (bar);
             INSERT INTO foo (bar) VALUES ('Maße');
             INSERT INTO foo (bar) VALUES ('MASSE');",
        )?;
        let mut stmt = db.prepare("SELECT DISTINCT bar COLLATE unicase FROM foo ORDER BY 1")?;
        let distinct_rows = stmt.query([])?.count()?;
        assert_eq!(distinct_rows, 1);
        Ok(())
    }

    #[test]
    fn test_unicase() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_collation("unicase", unicase_compare)?;
        collate(db)
    }

    #[test]
    fn test_collation_needed() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.collation_needed(collation_needed)?;
        collate(db)
    }
}