Skip to main content

rusqlite/
collation.rs

1//! `feature = "collation"` Add, remove, or modify a collation
2use std::cmp::Ordering;
3use std::os::raw::{c_char, c_int, c_void};
4use std::panic::{catch_unwind, UnwindSafe};
5use std::ptr;
6use std::slice;
7
8use crate::ffi;
9use crate::{str_to_cstring, Connection, InnerConnection, Result};
10
/// Destructor handed to SQLite (as `xDestroy`) for user data that was created with
/// `Box::into_raw(Box::new(value))`, `value: T`.
///
/// FIXME: duplicated from `function.rs`; the two copies should share one definition.
///
/// # Safety
///
/// `p` must have been produced by `Box::<T>::into_raw` and must not be used again
/// after this call.
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
    let owned: Box<T> = Box::from_raw(p.cast::<T>());
    drop(owned);
}
15
16impl Connection {
17    /// `feature = "collation"` Add or modify a collation.
18    #[inline]
19    pub fn create_collation<'c, C>(&'c self, collation_name: &str, x_compare: C) -> Result<()>
20    where
21        C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'c,
22    {
23        self.db
24            .borrow_mut()
25            .create_collation(collation_name, x_compare)
26    }
27
28    /// `feature = "collation"` Collation needed callback
29    #[inline]
30    pub fn collation_needed(
31        &self,
32        x_coll_needed: fn(&Connection, &str) -> Result<()>,
33    ) -> Result<()> {
34        self.db.borrow_mut().collation_needed(x_coll_needed)
35    }
36
37    /// `feature = "collation"` Remove collation.
38    #[inline]
39    pub fn remove_collation(&self, collation_name: &str) -> Result<()> {
40        self.db.borrow_mut().remove_collation(collation_name)
41    }
42}
43
44impl InnerConnection {
45    fn create_collation<'c, C>(&'c mut self, collation_name: &str, x_compare: C) -> Result<()>
46    where
47        C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'c,
48    {
49        unsafe extern "C" fn call_boxed_closure<C>(
50            arg1: *mut c_void,
51            arg2: c_int,
52            arg3: *const c_void,
53            arg4: c_int,
54            arg5: *const c_void,
55        ) -> c_int
56        where
57            C: Fn(&str, &str) -> Ordering,
58        {
59            let r = catch_unwind(|| {
60                let boxed_f: *mut C = arg1 as *mut C;
61                assert!(!boxed_f.is_null(), "Internal error - null function pointer");
62                let s1 = {
63                    let c_slice = slice::from_raw_parts(arg3 as *const u8, arg2 as usize);
64                    String::from_utf8_lossy(c_slice)
65                };
66                let s2 = {
67                    let c_slice = slice::from_raw_parts(arg5 as *const u8, arg4 as usize);
68                    String::from_utf8_lossy(c_slice)
69                };
70                (*boxed_f)(s1.as_ref(), s2.as_ref())
71            });
72            let t = match r {
73                Err(_) => {
74                    return -1; // FIXME How ?
75                }
76                Ok(r) => r,
77            };
78
79            match t {
80                Ordering::Less => -1,
81                Ordering::Equal => 0,
82                Ordering::Greater => 1,
83            }
84        }
85
86        let boxed_f: *mut C = Box::into_raw(Box::new(x_compare));
87        let c_name = str_to_cstring(collation_name)?;
88        let flags = ffi::SQLITE_UTF8;
89        let r = unsafe {
90            ffi::sqlite3_create_collation_v2(
91                self.db(),
92                c_name.as_ptr(),
93                flags,
94                boxed_f as *mut c_void,
95                Some(call_boxed_closure::<C>),
96                Some(free_boxed_value::<C>),
97            )
98        };
99        let res = self.decode_result(r);
100        // The xDestroy callback is not called if the sqlite3_create_collation_v2() function fails.
101        if res.is_err() {
102            drop(unsafe { Box::from_raw(boxed_f) });
103        }
104        res
105    }
106
107    fn collation_needed(
108        &mut self,
109        x_coll_needed: fn(&Connection, &str) -> Result<()>,
110    ) -> Result<()> {
111        use std::mem;
112        unsafe extern "C" fn collation_needed_callback(
113            arg1: *mut c_void,
114            arg2: *mut ffi::sqlite3,
115            e_text_rep: c_int,
116            arg3: *const c_char,
117        ) {
118            use std::ffi::CStr;
119            use std::str;
120
121            if e_text_rep != ffi::SQLITE_UTF8 {
122                // TODO: validate
123                return;
124            }
125
126            let callback: fn(&Connection, &str) -> Result<()> = mem::transmute(arg1);
127            let res = catch_unwind(|| {
128                let conn = Connection::from_handle(arg2).unwrap();
129                let collation_name = {
130                    let c_slice = CStr::from_ptr(arg3).to_bytes();
131                    str::from_utf8(c_slice).expect("illegal coallation sequence name")
132                };
133                callback(&conn, collation_name)
134            });
135            if res.is_err() {
136                return; // FIXME How ?
137            }
138        }
139
140        let r = unsafe {
141            ffi::sqlite3_collation_needed(
142                self.db(),
143                x_coll_needed as *mut c_void,
144                Some(collation_needed_callback),
145            )
146        };
147        self.decode_result(r)
148    }
149
150    #[inline]
151    fn remove_collation(&mut self, collation_name: &str) -> Result<()> {
152        let c_name = str_to_cstring(collation_name)?;
153        let r = unsafe {
154            ffi::sqlite3_create_collation_v2(
155                self.db(),
156                c_name.as_ptr(),
157                ffi::SQLITE_UTF8,
158                ptr::null_mut(),
159                None,
160                None,
161            )
162        };
163        self.decode_result(r)
164    }
165}
166
#[cfg(test)]
mod test {
    use crate::{Connection, Result};
    use fallible_streaming_iterator::FallibleStreamingIterator;
    use std::cmp::Ordering;
    use unicase::UniCase;

    /// Case-insensitive comparison with Unicode case folding (e.g. "ß" folds to "ss").
    fn unicase_compare(s1: &str, s2: &str) -> Ordering {
        UniCase::new(s1).cmp(&UniCase::new(s2))
    }

    #[test]
    fn test_unicase() -> Result<()> {
        let db = Connection::open_in_memory()?;

        db.create_collation("unicase", unicase_compare)?;

        collate(db)
    }

    /// Inserts two strings that differ only by case folding and checks that `DISTINCT`
    /// under the `unicase` collation collapses them into a single row.
    fn collate(db: Connection) -> Result<()> {
        db.execute_batch(
            "CREATE TABLE foo (bar);
             INSERT INTO foo (bar) VALUES ('Maße');
             INSERT INTO foo (bar) VALUES ('MASSE');",
        )?;
        let mut stmt = db.prepare("SELECT DISTINCT bar COLLATE unicase FROM foo ORDER BY 1")?;
        let rows = stmt.query([])?;
        assert_eq!(rows.count()?, 1);
        Ok(())
    }

    /// Lazily registers `unicase` when SQLite asks for it; other names are ignored.
    fn collation_needed(db: &Connection, collation_name: &str) -> Result<()> {
        if "unicase" == collation_name {
            db.create_collation(collation_name, unicase_compare)
        } else {
            Ok(())
        }
    }

    #[test]
    fn test_collation_needed() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.collation_needed(collation_needed)?;
        collate(db)
    }

    #[test]
    fn test_remove_collation() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_collation("unicase", unicase_compare)?;
        db.remove_collation("unicase")?;
        // Once removed, SQLite can no longer resolve the collation when preparing.
        assert!(db.prepare("SELECT 'a' COLLATE unicase").is_err());
        Ok(())
    }

    #[test]
    fn test_collation_name_with_nul_is_rejected() -> Result<()> {
        let db = Connection::open_in_memory()?;
        // An interior NUL cannot be passed to SQLite as a C string.
        assert!(db.create_collation("uni\0case", unicase_compare).is_err());
        Ok(())
    }
}