rbdc_sqlite/connection/
collation.rs

1use std::cmp::Ordering;
2use std::ffi::CString;
3use std::fmt::{self, Debug, Formatter};
4use std::os::raw::{c_int, c_void};
5use std::slice;
6use std::str::from_utf8_unchecked;
7use std::sync::Arc;
8
9use libsqlite3_sys::{sqlite3_create_collation_v2, SQLITE_OK, SQLITE_UTF8};
10use rbdc::err_protocol;
11
12use crate::connection::handle::ConnectionHandle;
13use crate::SqliteError;
14use rbdc::error::Error;
15
/// A user-defined SQLite collation: a name paired with a string comparison closure.
///
/// Cloning is cheap, because both the name and the closure are held behind `Arc`s.
#[derive(Clone)]
pub struct Collation {
    /// Name the collation is registered under.
    name: Arc<str>,
    /// The comparison closure itself.
    collate: Arc<dyn Fn(&str, &str) -> Ordering + Send + Sync + 'static>,
    /// C-ABI trampoline handed to SQLite as the compare callback.
    // SAFETY: `call` and `free` must be instantiated for the concrete type stored behind
    // `collate`, because they reinterpret the type-erased user-data pointer as that type.
    call: unsafe extern "C" fn(
        user_data: *mut c_void,
        left_len: c_int,
        left: *const c_void,
        right_len: c_int,
        right: *const c_void,
    ) -> c_int,
    /// C-ABI destructor handed to SQLite to release the user-data pointer.
    free: unsafe extern "C" fn(user_data: *mut c_void),
}
30
31impl Collation {
32    pub fn new<N, F>(name: N, collate: F) -> Self
33    where
34        N: Into<Arc<str>>,
35        F: Fn(&str, &str) -> Ordering + Send + Sync + 'static,
36    {
37        unsafe extern "C" fn drop_arc_value<T>(p: *mut c_void) {
38            drop(Arc::from_raw(p as *mut T));
39        }
40
41        Collation {
42            name: name.into(),
43            collate: Arc::new(collate),
44            call: call_boxed_closure::<F>,
45            free: drop_arc_value::<F>,
46        }
47    }
48
49    pub(crate) fn create(&self, handle: &mut ConnectionHandle) -> Result<(), Error> {
50        let raw_f = Arc::into_raw(Arc::clone(&self.collate));
51        let c_name = CString::new(&*self.name)
52            .map_err(|_| err_protocol!("invalid collation name: {:?}", self.name))?;
53        let flags = SQLITE_UTF8;
54        let r = unsafe {
55            sqlite3_create_collation_v2(
56                handle.as_ptr(),
57                c_name.as_ptr(),
58                flags,
59                raw_f as *mut c_void,
60                Some(self.call),
61                Some(self.free),
62            )
63        };
64
65        if r == SQLITE_OK {
66            Ok(())
67        } else {
68            // The xDestroy callback is not called if the sqlite3_create_collation_v2() function fails.
69            drop(unsafe { Arc::from_raw(raw_f) });
70            Err(Error::from(SqliteError::new(handle.as_ptr())))
71        }
72    }
73}
74
75impl Debug for Collation {
76    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
77        f.debug_struct("Collation")
78            .field("name", &self.name)
79            .finish_non_exhaustive()
80    }
81}
82
83pub(crate) fn create_collation<F>(
84    handle: &mut ConnectionHandle,
85    name: &str,
86    compare: F,
87) -> Result<(), Error>
88where
89    F: Fn(&str, &str) -> Ordering + Send + Sync + 'static,
90{
91    unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
92        drop(Box::from_raw(p as *mut T));
93    }
94
95    let boxed_f: *mut F = Box::into_raw(Box::new(compare));
96    let c_name =
97        CString::new(name).map_err(|_| err_protocol!("invalid collation name: {}", name))?;
98    let flags = SQLITE_UTF8;
99    let r = unsafe {
100        sqlite3_create_collation_v2(
101            handle.as_ptr(),
102            c_name.as_ptr(),
103            flags,
104            boxed_f as *mut c_void,
105            Some(call_boxed_closure::<F>),
106            Some(free_boxed_value::<F>),
107        )
108    };
109
110    if r == SQLITE_OK {
111        Ok(())
112    } else {
113        // The xDestroy callback is not called if the sqlite3_create_collation_v2() function fails.
114        drop(unsafe { Box::from_raw(boxed_f) });
115        Err(Error::from(SqliteError::new(handle.as_ptr())))
116    }
117}
118
/// C-ABI trampoline SQLite calls to compare two strings with a Rust closure.
///
/// `data` is the user-data pointer registered alongside this function and must point to a
/// live `C`. Each `(len, ptr)` pair describes one of the two byte strings being compared.
/// Returns `-1`, `0` or `1` for `Ordering::Less`, `Equal` and `Greater`.
///
/// SQLite does not guarantee that TEXT bytes are well-formed UTF-8. Invalid sequences are
/// therefore replaced with U+FFFD before they reach the closure, instead of producing an
/// invalid `&str`.
///
/// # Safety
///
/// - `data` must be a non-null pointer to a valid `C` for the duration of the call.
/// - Each non-null `ptr` must be readable for `len` bytes when `len > 0`.
unsafe extern "C" fn call_boxed_closure<C>(
    data: *mut c_void,
    left_len: c_int,
    left_ptr: *const c_void,
    right_len: c_int,
    right_ptr: *const c_void,
) -> c_int
where
    C: Fn(&str, &str) -> Ordering,
{
    /// Borrows `len` bytes at `ptr`. A null pointer or non-positive length yields an empty slice.
    unsafe fn raw_bytes<'a>(ptr: *const c_void, len: c_int) -> &'a [u8] {
        match usize::try_from(len) {
            Ok(n) if n > 0 && !ptr.is_null() => {
                // SAFETY: the caller guarantees `ptr` is readable for `n` bytes.
                unsafe { slice::from_raw_parts(ptr.cast::<u8>(), n) }
            }
            // `slice::from_raw_parts` requires a non-null pointer even for an empty slice,
            // and a negative length has no meaningful interpretation.
            _ => &[],
        }
    }

    debug_assert!(!data.is_null());
    // SAFETY: `data` is the pointer registered together with this trampoline and points to a `C`.
    let compare: &C = unsafe { &*(data as *const C) };

    // `from_utf8_lossy` borrows without copying when the bytes are valid UTF-8, and replaces
    // invalid sequences instead of creating an ill-formed `&str` (which would be UB).
    let left = String::from_utf8_lossy(unsafe { raw_bytes(left_ptr, left_len) });
    let right = String::from_utf8_lossy(unsafe { raw_bytes(right_ptr, right_len) });

    match compare(&*left, &*right) {
        Ordering::Less => -1,
        Ordering::Equal => 0,
        Ordering::Greater => 1,
    }
}