libsql_sys/
connection.rs

1use std::marker::PhantomData;
2use std::path::Path;
3use std::str::FromStr;
4
5use crate::wal::{ffi::make_wal_manager, Wal, WalManager};
6
7#[cfg(not(feature = "rusqlite"))]
8type RawConnection = *mut crate::ffi::sqlite3;
9#[cfg(feature = "rusqlite")]
10type RawConnection = rusqlite::Connection;
11
12#[cfg(not(feature = "rusqlite"))]
13pub type OpenFlags = std::ffi::c_int;
14#[cfg(feature = "rusqlite")]
15pub type OpenFlags = rusqlite::OpenFlags;
16
17#[cfg(feature = "rusqlite")]
18type Error = rusqlite::Error;
19#[cfg(not(feature = "rusqlite"))]
20type Error = crate::Error;
21
/// Ciphers that can be selected for an encrypted database.
#[derive(Clone, Debug, Default)]
pub enum Cipher {
    /// AES 256 Bit CBC - No HMAC (wxSQLite3). This is the default cipher.
    #[default]
    Aes256Cbc,
}
28
29impl FromStr for Cipher {
30    type Err = libsql_ffi::Error;
31
32    fn from_str(s: &str) -> Result<Self, Self::Err> {
33        match s {
34            "aes256cbc" => Ok(Cipher::Aes256Cbc),
35            _ => Err(Self::Err::new(21)),
36        }
37    }
38}
39
40impl Cipher {
41    #[cfg(feature = "encryption")]
42    pub fn cipher_id(&self) -> i32 {
43        let name = match self {
44            Cipher::Aes256Cbc => "aes256cbc\0",
45        };
46        unsafe { sqlite3mc_cipher_index(name.as_ptr() as _) }
47    }
48}
49
50#[derive(Clone, Debug)]
51pub struct EncryptionConfig {
52    pub cipher: Cipher,
53    pub encryption_key: bytes::Bytes,
54}
55
56impl EncryptionConfig {
57    pub fn new(cipher: Cipher, encryption_key: bytes::Bytes) -> Self {
58        Self {
59            cipher,
60            encryption_key,
61        }
62    }
63
64    #[cfg(feature = "encryption")]
65    pub fn cipher_id(&self) -> i32 {
66        self.cipher.cipher_id()
67    }
68}
69
70#[derive(Debug)]
71pub struct Connection<W> {
72    conn: RawConnection,
73    _pth: PhantomData<W>,
74}
75
#[cfg(feature = "rusqlite")]
impl<W> std::ops::Deref for Connection<W> {
    type Target = rusqlite::Connection;

    /// Gives shared access to the wrapped rusqlite connection.
    fn deref(&self) -> &rusqlite::Connection {
        let Self { conn, .. } = self;
        conn
    }
}
84
#[cfg(feature = "rusqlite")]
impl<W> std::ops::DerefMut for Connection<W> {
    /// Gives exclusive access to the wrapped rusqlite connection.
    fn deref_mut(&mut self) -> &mut rusqlite::Connection {
        let Self { conn, .. } = self;
        conn
    }
}
91
#[cfg(feature = "rusqlite")]
impl Connection<crate::wal::Sqlite3Wal> {
    /// Builds a throwaway in-memory connection typed with `Sqlite3Wal`.
    ///
    /// For tests only: panics if the in-memory database cannot be opened.
    pub fn test() -> Self {
        Self {
            conn: rusqlite::Connection::open_in_memory().unwrap(),
            _pth: PhantomData,
        }
    }
}
103
#[cfg(feature = "encryption")]
extern "C" {
    /// Looks up the sqlite3mc index of the cipher whose NUL-terminated name is `name`.
    fn sqlite3mc_cipher_index(name: *const std::ffi::c_void) -> std::ffi::c_int;

    /// Sets the NUL-terminated sqlite3mc configuration parameter `param` (for example
    /// `"default:cipher"`) to `value` on `db`.
    fn sqlite3mc_config(
        db: *mut libsql_ffi::sqlite3,
        param: *const std::ffi::c_void,
        value: std::ffi::c_int,
    ) -> std::ffi::c_int;

    /// Supplies the `key_len`-byte encryption key `key` for `db`.
    fn sqlite3_key(
        db: *mut libsql_ffi::sqlite3,
        key: *const std::ffi::c_void,
        key_len: std::ffi::c_int,
    ) -> std::ffi::c_int;

    /// Replaces the encryption key of `db` with the `key_len`-byte `key`.
    fn sqlite3_rekey(
        db: *mut libsql_ffi::sqlite3,
        key: *const std::ffi::c_void,
        key_len: std::ffi::c_int,
    ) -> std::ffi::c_int;

    /// libsql extension: returns a `Pager` pointer associated with `db`.
    fn libsql_leak_pager(db: *mut libsql_ffi::sqlite3) -> *mut crate::ffi::Pager;
    /// libsql extension: writes an initial vector derived from `seed` into `iv`.
    fn libsql_generate_initial_vector(seed: u32, iv: *mut u8);
    /// libsql extension: derives an AES-256 key from the `password_len`-byte `password` into
    /// `digest`.
    fn libsql_generate_aes256_key(password: *const u8, password_len: u32, digest: *mut u8);
}
127
128pub fn pghdr_creator(
129    data: &mut [u8; 4096],
130    _db: *mut libsql_ffi::sqlite3,
131) -> libsql_ffi::libsql_pghdr {
132    #[cfg(feature = "encryption")]
133    let pager = crate::connection::leak_pager(_db);
134    #[cfg(not(feature = "encryption"))]
135    let pager = std::ptr::null_mut();
136
137    libsql_ffi::libsql_pghdr {
138        pPage: std::ptr::null_mut(),
139        pData: data.as_mut_ptr() as _,
140        pExtra: std::ptr::null_mut(),
141        pCache: std::ptr::null_mut(),
142        pDirty: std::ptr::null_mut(),
143        pPager: pager,
144        pgno: 1,
145        pageHash: 0x02, // DIRTY
146        flags: 0,
147        nRef: 0,
148        pDirtyNext: std::ptr::null_mut(),
149        pDirtyPrev: std::ptr::null_mut(),
150    }
151}
152
/// Selects the cipher for `db` by setting sqlite3mc's `default:cipher` parameter to
/// `cipher_id`.
///
/// Returns the value reported by `sqlite3mc_config`. Callers in this file treat `-1` as
/// failure.
///
/// # Safety
/// `db` must point to a valid, open sqlite database connection.
#[cfg(feature = "encryption")]
pub unsafe fn set_encryption_cipher(db: *mut libsql_ffi::sqlite3, cipher_id: i32) -> i32 {
    // SAFETY: the caller guarantees `db` is valid; the parameter name is a 'static,
    // NUL-terminated string.
    unsafe { sqlite3mc_config(db, "default:cipher\0".as_ptr() as _, cipher_id) }
}
159
/// Supplies `key` as the encryption key for `db` via `sqlite3_key`.
///
/// Returns the SQLite result code; callers compare it against `SQLITE_OK`.
///
/// # Safety
/// `db` must point to a valid, open sqlite database connection.
#[cfg(feature = "encryption")]
pub unsafe fn set_encryption_key(db: *mut libsql_ffi::sqlite3, key: &[u8]) -> i32 {
    // SAFETY: the caller guarantees `db` is valid; `key` is a live slice for the whole call.
    unsafe { sqlite3_key(db, key.as_ptr() as _, key.len() as _) }
}
166
/// Changes the encryption key of `db` to `key` via `sqlite3_rekey`, returning its result code.
///
/// `db` must point to a valid, open sqlite database connection.
///
/// NOTE(review): unlike `set_encryption_key`, this function is not an `unsafe fn`, so the
/// requirement above is not enforced by the compiler; passing a dangling or null pointer is
/// undefined behaviour. Making it `unsafe fn` would be a breaking change for existing callers.
#[cfg(feature = "encryption")]
pub fn reset_encryption_key(db: *mut libsql_ffi::sqlite3, key: &[u8]) -> i32 {
    // SAFETY: relies on the (unchecked) precondition that `db` is a valid connection; `key`
    // is a live slice for the whole call.
    unsafe { sqlite3_rekey(db, key.as_ptr() as _, key.len() as _) }
}
173
/// Returns the pager pointer that `libsql_leak_pager` reports for `db`.
///
/// NOTE(review): this is a safe fn that forwards an unchecked raw pointer to C; `db` must be a
/// valid connection or the behaviour is undefined.
#[cfg(feature = "encryption")]
pub fn leak_pager(db: *mut libsql_ffi::sqlite3) -> *mut crate::ffi::Pager {
    // SAFETY: relies on the caller passing a valid connection handle.
    let pager = unsafe { libsql_leak_pager(db) };
    pager
}
178
/// Fills the start of `iv` with the initial vector that libsql derives from `seed`.
///
/// # Panics
///
/// Panics if `iv` is shorter than 16 bytes. The C routine writes through a raw pointer without
/// knowing the slice length, so a shorter buffer would otherwise be overrun.
#[cfg(feature = "encryption")]
pub fn generate_initial_vector(seed: u32, iv: &mut [u8]) {
    // NOTE(review): 16 assumes the IV is one AES block, as used by the AES-256-CBC cipher —
    // confirm against the C implementation of `libsql_generate_initial_vector`.
    const IV_LEN: usize = 16;
    assert!(
        iv.len() >= IV_LEN,
        "initial vector buffer must be at least {} bytes, got {}",
        IV_LEN,
        iv.len()
    );
    // SAFETY: `iv` provides at least `IV_LEN` writable bytes, checked above.
    unsafe { libsql_generate_initial_vector(seed, iv.as_mut_ptr()) }
}
183
/// Derives an AES-256 key from `user_password` and writes it to the start of `digest`.
///
/// # Panics
///
/// Panics in either of these cases:
/// * `digest` is shorter than 32 bytes (256 bits, the AES-256 key size). Without this check the
///   C routine would overrun the buffer.
/// * `user_password` is longer than `u32::MAX` bytes. Its length could then not be passed to C
///   without truncation.
#[cfg(feature = "encryption")]
pub fn generate_aes256_key(user_password: &[u8], digest: &mut [u8]) {
    const KEY_LEN: usize = 32; // 256-bit key
    assert!(
        digest.len() >= KEY_LEN,
        "key digest buffer must be at least {} bytes, got {}",
        KEY_LEN,
        digest.len()
    );
    // Previously `len() as u32`, which silently truncated oversized passwords.
    let password_len =
        u32::try_from(user_password.len()).expect("password length does not fit in u32");
    // SAFETY: `user_password` is readable for `password_len` bytes and `digest` provides at
    // least `KEY_LEN` writable bytes.
    unsafe {
        libsql_generate_aes256_key(user_password.as_ptr(), password_len, digest.as_mut_ptr())
    }
}
194
/// `auto_checkpoint` value for `Connection::open` that turns automatic WAL checkpoints off.
///
/// The value is forwarded to `sqlite3_wal_autocheckpoint`, which disables auto-checkpointing
/// for any value less than or equal to zero.
pub const NO_AUTOCHECKPOINT: u32 = 0;
196
197impl<W: Wal> Connection<W> {
198    /// Opens a database with the regular wal methods in the directory pointed to by path
199    pub fn open<T>(
200        path: impl AsRef<Path>,
201        flags: OpenFlags,
202        wal_manager: T,
203        auto_checkpoint: u32,
204        encryption_config: Option<EncryptionConfig>,
205    ) -> Result<Self, Error>
206    where
207        T: WalManager<Wal = W>,
208    {
209        tracing::trace!(
210            "Opening a connection with regular WAL at {}",
211            path.as_ref().display()
212        );
213
214        #[cfg(feature = "rusqlite")]
215        let conn = {
216            let conn = if cfg!(feature = "unix-excl-vfs") {
217                rusqlite::Connection::open_with_flags_vfs_and_wal(
218                    path,
219                    flags,
220                    "unix-excl",
221                    make_wal_manager(wal_manager),
222                )
223            } else {
224                rusqlite::Connection::open_with_flags_and_wal(
225                    path,
226                    flags,
227                    make_wal_manager(wal_manager),
228                )
229            }?;
230
231            if !cfg!(feature = "encryption") && encryption_config.is_some() {
232                return Err(Error::SqliteFailure(
233                    rusqlite::ffi::Error::new(21),
234                    Some("encryption feature is not enabled, the database will not be encrypted on disk"
235                        .to_string()),
236                ));
237            }
238            #[cfg(feature = "encryption")]
239            if let Some(cfg) = encryption_config {
240                let cipher_id = cfg.cipher_id();
241                if unsafe { set_encryption_cipher(conn.handle(), cipher_id) } == -1 {
242                    return Err(Error::SqliteFailure(
243                        rusqlite::ffi::Error::new(21),
244                        Some("failed to set encryption cipher".into()),
245                    ));
246                };
247                if unsafe { set_encryption_key(conn.handle(), &cfg.encryption_key) }
248                    != rusqlite::ffi::SQLITE_OK
249                {
250                    return Err(Error::SqliteFailure(
251                        rusqlite::ffi::Error::new(21),
252                        Some("failed to set encryption key".into()),
253                    ));
254                };
255            }
256
257            conn.pragma_update(None, "journal_mode", "WAL")?;
258            unsafe {
259                let rc =
260                    rusqlite::ffi::sqlite3_wal_autocheckpoint(conn.handle(), auto_checkpoint as _);
261                if rc != 0 {
262                    return Err(rusqlite::Error::SqliteFailure(
263                        rusqlite::ffi::Error::new(rc),
264                        Some("failed to set auto_checkpoint".into()),
265                    ));
266                }
267            }
268
269            conn.busy_timeout(std::time::Duration::from_millis(100))?;
270
271            conn
272        };
273
274        #[cfg(not(feature = "rusqlite"))]
275        let conn = unsafe {
276            #[cfg(unix)]
277            let path = {
278                use std::os::unix::ffi::OsStrExt;
279                std::ffi::CString::new(path.as_ref().as_os_str().as_bytes()).map_err(|_| {
280                    crate::error::Error::Bug(
281                        "invalid database path containing an internal nul byte",
282                    )
283                })?
284            };
285            #[cfg(not(unix))]
286            let path = path
287                .to_str()
288                .ok_or_else(|| crate::error::Error::Bug("database path is not valid unicode"))
289                .and_then(|x| {
290                    std::ffi::CString::new(x).map_err(|_| {
291                        crate::error::Error::Bug(
292                            "invalid database path containing an internal nul byte",
293                        )
294                    })
295                })?;
296
297            let mut conn: *mut crate::ffi::sqlite3 = std::ptr::null_mut();
298            // We pass a pointer to the WAL methods data to the database connection. This means
299            // that the reference must outlive the connection. This is guaranteed by the marker in
300            // the returned connection.
301            let vfs = if cfg!(feature = "unix-excl-vfs") {
302                "unix-excl\0".as_ptr() as *const _
303            } else {
304                std::ptr::null_mut()
305            };
306            let mut rc = libsql_ffi::libsql_open_v3(
307                path.as_ptr(),
308                &mut conn as *mut _,
309                flags,
310                vfs,
311                make_wal_manager(wal_manager),
312            );
313
314            if !cfg!(feature = "encryption") && encryption_config.is_some() {
315                return Err(Error::Bug(
316                    "encryption feature is not enabled, the database will not be encrypted on disk",
317                ));
318            }
319            #[cfg(feature = "encryption")]
320            if let Some(cfg) = encryption_config {
321                let cipher_id = cfg.cipher_id();
322                if set_encryption_cipher(conn, cipher_id) == -1 {
323                    return Err(Error::Bug("failed to set encryption cipher"));
324                }
325                if set_encryption_key(conn, &cfg.encryption_key) != libsql_ffi::SQLITE_OK {
326                    return Err(Error::Bug("failed to set encryption key"));
327                }
328            }
329
330            if rc == 0 {
331                rc = libsql_ffi::sqlite3_wal_autocheckpoint(conn, auto_checkpoint as _);
332            }
333
334            if rc != 0 {
335                libsql_ffi::sqlite3_close(conn);
336                return Err(crate::Error::from(rc));
337            }
338
339            assert!(!conn.is_null());
340
341            crate::ffi::sqlite3_busy_timeout(conn, 5000);
342
343            conn
344        };
345
346        Ok(Connection {
347            conn,
348            _pth: PhantomData,
349        })
350    }
351
352    /// Returns the raw sqlite handle
353    ///
354    /// # Safety
355    /// The caller is responsible for the returned pointer.
356    pub unsafe fn handle(&self) -> *mut libsql_ffi::sqlite3 {
357        #[cfg(feature = "rusqlite")]
358        {
359            self.conn.handle()
360        }
361        #[cfg(not(feature = "rusqlite"))]
362        {
363            self.conn
364        }
365    }
366
367    pub fn db_change_counter(&self) -> Result<u32, std::ffi::c_int> {
368        let mut file_ptr: *mut libsql_ffi::sqlite3_file = std::ptr::null_mut();
369        let rc = unsafe {
370            libsql_ffi::sqlite3_file_control(
371                self.handle(),
372                "main\0".as_ptr() as *const _,
373                libsql_ffi::SQLITE_FCNTL_FILE_POINTER,
374                &mut file_ptr as *mut _ as *mut _,
375            )
376        };
377        if rc != libsql_ffi::SQLITE_OK {
378            return Err(rc);
379        }
380        let counter = unsafe {
381            let mut counter: u32 = 0;
382            let file = &*file_ptr;
383            (*file.pMethods).xRead.unwrap()(file_ptr, &mut counter as *mut _ as *mut _, 4, 24);
384            u32::from_be(counter)
385        };
386        Ok(counter)
387    }
388}
389// pub struct Connection<'a> {
390//     pub conn: *mut crate::ffi::sqlite3,
391//     _pth: PhantomData<&'a mut ()>,
392// }
393//
394// /// The `Connection` struct is `Send` because `sqlite3` is thread-safe.
395// unsafe impl<'a> Send for Connection<'a> {}
396// unsafe impl<'a> Sync for Connection<'a> {}
397//
398// impl<'a> Connection<'a> {
399//     /// returns a dummy, in-memory connection. For testing purposes only
400//     pub fn test(_: &mut ()) -> Self {
401//         let mut conn: *mut crate::ffi::sqlite3 = std::ptr::null_mut();
402//         let rc = unsafe {
403//             crate::ffi::sqlite3_open(":memory:\0".as_ptr() as *const _, &mut conn as *mut _)
404//         };
405//         assert_eq!(rc, 0);
406//         assert!(!conn.is_null());
407//         Self {
408//             conn,
409//             _pth: PhantomData,
410//         }
411//     }
412//
413//     /// Opens a database with the regular wal methods, given a path to the database file.
414//     pub fn open<W: Wal>(
415//         path: impl AsRef<std::path::Path>,
416//         flags: c_int,
417//         // we technically _only_ need to know about W, but requiring a static ref to the wal_hook ensures that
418//         // it has been instantiated and lives for long enough
419//         _wal_hook: &'static WalMethodsHook<W>,
420//         hook_ctx: &'a mut W::Context,
421//     ) -> Result<Self, crate::Error> {
422//         let path = path.as_ref();
423//         tracing::trace!(
424//             "Opening a connection with regular WAL at {}",
425//             path.display()
426//         );
427//
428//         let conn_str = format!("file:{}?_journal_mode=WAL", path.display());
429//         let filename = CString::new(conn_str).unwrap();
430//         let mut conn: *mut crate::ffi::sqlite3 = std::ptr::null_mut();
431//
432//         unsafe {
433//             // We pass a pointer to the WAL methods data to the database connection. This means
434//             // that the reference must outlive the connection. This is guaranteed by the marker in
435//             // the returned connection.
436//             let rc = crate::ffi::libsql_open_v2(
437//                 filename.as_ptr(),
438//                 &mut conn as *mut _,
439//                 flags,
440//                 std::ptr::null_mut(),
441//                 W::name().as_ptr(),
442//                 hook_ctx as *mut _ as *mut _,
443//             );
444//
445//             if rc != 0 {
446//                 crate::ffi::sqlite3_close(conn);
447//                 return Err(crate::Error::LibError(rc));
448//             }
449//
450//             assert!(!conn.is_null());
451//         };
452//
453//         unsafe {
454//             crate::ffi::sqlite3_busy_timeout(conn, 5000);
455//         }
456//
457//         Ok(Connection {
458//             conn,
459//             _pth: PhantomData,
460//         })
461//     }
462//
463//     pub fn is_autocommit(&self) -> bool {
464//         unsafe { crate::ffi::sqlite3_get_autocommit(self.conn) != 0 }
465//     }
466// }
467//
468// impl Drop for Connection<'_> {
469//     fn drop(&mut self) {
470//         if self.conn.is_null() {
471//             tracing::debug!("Trying to close a null connection");
472//             return;
473//         }
474//         unsafe {
475//             crate::ffi::sqlite3_close(self.conn as *mut _);
476//         }
477//     }
478// }