1use std::cmp::Ordering;
3use std::os::raw::{c_char, c_int, c_void};
4use std::panic::{catch_unwind, UnwindSafe};
5use std::ptr;
6use std::slice;
7
8use crate::ffi;
9use crate::{str_to_cstring, Connection, InnerConnection, Result};
10
/// Destructor callback: reclaims and drops a heap value previously leaked as a raw pointer.
///
/// # Safety
///
/// `p` must have been produced by `Box::into_raw` for a `Box<T>` of the same `T`, and it must
/// not be used (or freed) again after this call.
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
    // Re-materialise the box so that its destructor runs and the allocation is released.
    let reclaimed: Box<T> = Box::from_raw(p.cast::<T>());
    drop(reclaimed);
}
15
16impl Connection {
17 #[inline]
19 pub fn create_collation<'c, C>(&'c self, collation_name: &str, x_compare: C) -> Result<()>
20 where
21 C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'c,
22 {
23 self.db
24 .borrow_mut()
25 .create_collation(collation_name, x_compare)
26 }
27
28 #[inline]
30 pub fn collation_needed(
31 &self,
32 x_coll_needed: fn(&Connection, &str) -> Result<()>,
33 ) -> Result<()> {
34 self.db.borrow_mut().collation_needed(x_coll_needed)
35 }
36
37 #[inline]
39 pub fn remove_collation(&self, collation_name: &str) -> Result<()> {
40 self.db.borrow_mut().remove_collation(collation_name)
41 }
42}
43
44impl InnerConnection {
45 fn create_collation<'c, C>(&'c mut self, collation_name: &str, x_compare: C) -> Result<()>
46 where
47 C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'c,
48 {
49 unsafe extern "C" fn call_boxed_closure<C>(
50 arg1: *mut c_void,
51 arg2: c_int,
52 arg3: *const c_void,
53 arg4: c_int,
54 arg5: *const c_void,
55 ) -> c_int
56 where
57 C: Fn(&str, &str) -> Ordering,
58 {
59 let r = catch_unwind(|| {
60 let boxed_f: *mut C = arg1 as *mut C;
61 assert!(!boxed_f.is_null(), "Internal error - null function pointer");
62 let s1 = {
63 let c_slice = slice::from_raw_parts(arg3 as *const u8, arg2 as usize);
64 String::from_utf8_lossy(c_slice)
65 };
66 let s2 = {
67 let c_slice = slice::from_raw_parts(arg5 as *const u8, arg4 as usize);
68 String::from_utf8_lossy(c_slice)
69 };
70 (*boxed_f)(s1.as_ref(), s2.as_ref())
71 });
72 let t = match r {
73 Err(_) => {
74 return -1; }
76 Ok(r) => r,
77 };
78
79 match t {
80 Ordering::Less => -1,
81 Ordering::Equal => 0,
82 Ordering::Greater => 1,
83 }
84 }
85
86 let boxed_f: *mut C = Box::into_raw(Box::new(x_compare));
87 let c_name = str_to_cstring(collation_name)?;
88 let flags = ffi::SQLITE_UTF8;
89 let r = unsafe {
90 ffi::sqlite3_create_collation_v2(
91 self.db(),
92 c_name.as_ptr(),
93 flags,
94 boxed_f as *mut c_void,
95 Some(call_boxed_closure::<C>),
96 Some(free_boxed_value::<C>),
97 )
98 };
99 let res = self.decode_result(r);
100 if res.is_err() {
102 drop(unsafe { Box::from_raw(boxed_f) });
103 }
104 res
105 }
106
107 fn collation_needed(
108 &mut self,
109 x_coll_needed: fn(&Connection, &str) -> Result<()>,
110 ) -> Result<()> {
111 use std::mem;
112 unsafe extern "C" fn collation_needed_callback(
113 arg1: *mut c_void,
114 arg2: *mut ffi::sqlite3,
115 e_text_rep: c_int,
116 arg3: *const c_char,
117 ) {
118 use std::ffi::CStr;
119 use std::str;
120
121 if e_text_rep != ffi::SQLITE_UTF8 {
122 return;
124 }
125
126 let callback: fn(&Connection, &str) -> Result<()> = mem::transmute(arg1);
127 let res = catch_unwind(|| {
128 let conn = Connection::from_handle(arg2).unwrap();
129 let collation_name = {
130 let c_slice = CStr::from_ptr(arg3).to_bytes();
131 str::from_utf8(c_slice).expect("illegal coallation sequence name")
132 };
133 callback(&conn, collation_name)
134 });
135 if res.is_err() {
136 return; }
138 }
139
140 let r = unsafe {
141 ffi::sqlite3_collation_needed(
142 self.db(),
143 x_coll_needed as *mut c_void,
144 Some(collation_needed_callback),
145 )
146 };
147 self.decode_result(r)
148 }
149
150 #[inline]
151 fn remove_collation(&mut self, collation_name: &str) -> Result<()> {
152 let c_name = str_to_cstring(collation_name)?;
153 let r = unsafe {
154 ffi::sqlite3_create_collation_v2(
155 self.db(),
156 c_name.as_ptr(),
157 ffi::SQLITE_UTF8,
158 ptr::null_mut(),
159 None,
160 None,
161 )
162 };
163 self.decode_result(r)
164 }
165}
166
#[cfg(test)]
mod test {
    use crate::{Connection, Result};
    use fallible_streaming_iterator::FallibleStreamingIterator;
    use std::cmp::Ordering;
    use unicase::UniCase;

    /// Compares two strings through their `UniCase` wrappers.
    fn unicase_compare(s1: &str, s2: &str) -> Ordering {
        let (lhs, rhs) = (UniCase::new(s1), UniCase::new(s2));
        lhs.cmp(&rhs)
    }

    /// Registering the collation up front makes `COLLATE unicase` usable.
    #[test]
    fn test_unicase() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.create_collation("unicase", unicase_compare)?;
        collate(conn)
    }

    /// Inserts two spellings that the test expects `unicase` to treat as equal, then checks
    /// that `SELECT DISTINCT` under that collation yields a single row.
    fn collate(db: Connection) -> Result<()> {
        db.execute_batch(
            "CREATE TABLE foo (bar);
             INSERT INTO foo (bar) VALUES ('Maße');
             INSERT INTO foo (bar) VALUES ('MASSE');",
        )?;
        let mut stmt = db.prepare("SELECT DISTINCT bar COLLATE unicase FROM foo ORDER BY 1")?;
        let distinct_rows = stmt.query([])?.count()?;
        assert_eq!(distinct_rows, 1);
        Ok(())
    }

    /// "Collation needed" callback: registers `unicase` on demand, ignores any other name.
    fn collation_needed(db: &Connection, collation_name: &str) -> Result<()> {
        match collation_name {
            "unicase" => db.create_collation(collation_name, unicase_compare),
            _ => Ok(()),
        }
    }

    /// The collation is not registered up front; it must be supplied through the callback.
    #[test]
    fn test_collation_needed() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.collation_needed(collation_needed)?;
        collate(conn)
    }
}