1use std::marker::PhantomData;
2use std::path::Path;
3use std::str::FromStr;
4
5use crate::wal::{ffi::make_wal_manager, Wal, WalManager};
6
7#[cfg(not(feature = "rusqlite"))]
8type RawConnection = *mut crate::ffi::sqlite3;
9#[cfg(feature = "rusqlite")]
10type RawConnection = rusqlite::Connection;
11
12#[cfg(not(feature = "rusqlite"))]
13pub type OpenFlags = std::ffi::c_int;
14#[cfg(feature = "rusqlite")]
15pub type OpenFlags = rusqlite::OpenFlags;
16
17#[cfg(feature = "rusqlite")]
18type Error = rusqlite::Error;
19#[cfg(not(feature = "rusqlite"))]
20type Error = crate::Error;
21
/// Cipher used for on-disk database encryption.
///
/// `Aes256Cbc` is both the default and, currently, the only supported cipher.
#[derive(Clone, Debug, Default)]
pub enum Cipher {
    /// AES-256 in CBC mode (parsed from the string `"aes256cbc"`).
    #[default]
    Aes256Cbc,
}
28
29impl FromStr for Cipher {
30 type Err = libsql_ffi::Error;
31
32 fn from_str(s: &str) -> Result<Self, Self::Err> {
33 match s {
34 "aes256cbc" => Ok(Cipher::Aes256Cbc),
35 _ => Err(Self::Err::new(21)),
36 }
37 }
38}
39
40impl Cipher {
41 #[cfg(feature = "encryption")]
42 pub fn cipher_id(&self) -> i32 {
43 let name = match self {
44 Cipher::Aes256Cbc => "aes256cbc\0",
45 };
46 unsafe { sqlite3mc_cipher_index(name.as_ptr() as _) }
47 }
48}
49
50#[derive(Clone, Debug)]
51pub struct EncryptionConfig {
52 pub cipher: Cipher,
53 pub encryption_key: bytes::Bytes,
54}
55
56impl EncryptionConfig {
57 pub fn new(cipher: Cipher, encryption_key: bytes::Bytes) -> Self {
58 Self {
59 cipher,
60 encryption_key,
61 }
62 }
63
64 #[cfg(feature = "encryption")]
65 pub fn cipher_id(&self) -> i32 {
66 self.cipher.cipher_id()
67 }
68}
69
70#[derive(Debug)]
71pub struct Connection<W> {
72 conn: RawConnection,
73 _pth: PhantomData<W>,
74}
75
/// With the `rusqlite` backend, a `Connection` can be used wherever a shared
/// `rusqlite::Connection` reference is expected.
#[cfg(feature = "rusqlite")]
impl<W> std::ops::Deref for Connection<W> {
    type Target = rusqlite::Connection;

    fn deref(&self) -> &rusqlite::Connection {
        let Self { conn, .. } = self;
        conn
    }
}
84
/// Mutable counterpart of the `Deref` impl: exposes the inner `rusqlite::Connection` mutably.
#[cfg(feature = "rusqlite")]
impl<W> std::ops::DerefMut for Connection<W> {
    fn deref_mut(&mut self) -> &mut rusqlite::Connection {
        let Self { conn, .. } = self;
        conn
    }
}
91
#[cfg(feature = "rusqlite")]
impl Connection<crate::wal::Sqlite3Wal> {
    /// Opens a throwaway in-memory connection through plain rusqlite (no custom WAL manager);
    /// presumably intended for tests.
    ///
    /// # Panics
    ///
    /// Panics if rusqlite cannot open the in-memory database.
    pub fn test() -> Self {
        Self {
            conn: rusqlite::Connection::open_in_memory().unwrap(),
            _pth: PhantomData,
        }
    }
}
103
// Native functions that are only declared (and linked) when the `encryption` feature is on.
// NOTE(review): the `sqlite3mc_*` / `sqlite3_key` / `sqlite3_rekey` names look like the
// SQLite3 Multiple Ciphers API, and the `libsql_*` ones like libsql-specific C helpers —
// confirm against the C sources this crate links.
#[cfg(feature = "encryption")]
extern "C" {
    /// Returns the numeric id for a NUL-terminated cipher name (used by `Cipher::cipher_id`).
    fn sqlite3mc_cipher_index(cipher: *const std::ffi::c_void) -> std::ffi::c_int;
    /// Sets a configuration parameter on `db`. `set_encryption_cipher` calls it with the
    /// NUL-terminated parameter name `"default:cipher"` and the cipher id as the value
    /// (despite the `nKey` name); callers treat a return of `-1` as failure.
    fn sqlite3mc_config(
        db: *mut libsql_ffi::sqlite3,
        cipher: *const std::ffi::c_void,
        nKey: std::ffi::c_int,
    ) -> std::ffi::c_int;
    /// Keys `db` with `nKey` bytes at `pKey`; callers treat anything other than `SQLITE_OK`
    /// as failure.
    fn sqlite3_key(
        db: *mut libsql_ffi::sqlite3,
        pKey: *const std::ffi::c_void,
        nKey: std::ffi::c_int,
    ) -> std::ffi::c_int;
    /// Re-keys `db` with `nKey` bytes at `pKey` (wrapped by `reset_encryption_key`).
    fn sqlite3_rekey(
        db: *mut libsql_ffi::sqlite3,
        pKey: *const std::ffi::c_void,
        nKey: std::ffi::c_int,
    ) -> std::ffi::c_int;

    /// Returns the pager belonging to `db` (wrapped by `leak_pager`).
    fn libsql_leak_pager(db: *mut libsql_ffi::sqlite3) -> *mut crate::ffi::Pager;
    /// Presumably fills `iv` with an initial vector derived from `seed`. No length is passed,
    /// so the C side alone decides how many bytes are written — NOTE(review): confirm the size.
    fn libsql_generate_initial_vector(seed: u32, iv: *mut u8);
    /// Presumably derives a key from `password_length` bytes at `user_password` into `digest`.
    /// The digest size is fixed by the C side — NOTE(review): likely 32 bytes for AES-256,
    /// confirm.
    fn libsql_generate_aes256_key(user_password: *const u8, password_length: u32, digest: *mut u8);
}
127
/// Builds a `libsql_pghdr` page header describing `data` as page 1.
///
/// `data` is a fixed 4096-byte buffer (this assumes a 4096-byte page size). With the
/// `encryption` feature, `pPager` is the pager obtained from `_db` through `leak_pager`;
/// without it `pPager` is null and `_db` is unused. All other pointer fields are null and
/// `nRef` is 0.
pub fn pghdr_creator(
    data: &mut [u8; 4096],
    _db: *mut libsql_ffi::sqlite3,
) -> libsql_ffi::libsql_pghdr {
    #[cfg(feature = "encryption")]
    let pager = crate::connection::leak_pager(_db);
    #[cfg(not(feature = "encryption"))]
    let pager = std::ptr::null_mut();

    libsql_ffi::libsql_pghdr {
        pPage: std::ptr::null_mut(),
        pData: data.as_mut_ptr() as _,
        pExtra: std::ptr::null_mut(),
        pCache: std::ptr::null_mut(),
        pDirty: std::ptr::null_mut(),
        pPager: pager,
        pgno: 1,
        // NOTE(review): 0x02 is stored in `pageHash` while `flags` is 0. In SQLite's pcache,
        // 0x002 is the `PGHDR_DIRTY` bit of `flags`; if the intent is to mark the page dirty,
        // this value may be in the wrong field — confirm before changing.
        pageHash: 0x02,
        flags: 0,
        nRef: 0,
        pDirtyNext: std::ptr::null_mut(),
        pDirtyPrev: std::ptr::null_mut(),
    }
}
152
/// Selects `cipher_id` as the default cipher of `db` via `sqlite3mc_config`.
///
/// Returns the raw `sqlite3mc_config` result; callers treat `-1` as failure.
///
/// # Safety
///
/// `db` must be a valid, open database handle.
#[cfg(feature = "encryption")]
pub unsafe fn set_encryption_cipher(db: *mut libsql_ffi::sqlite3, cipher_id: i32) -> i32 {
    // NUL-terminated configuration parameter name expected by the C API.
    const PARAM: &str = "default:cipher\0";
    // SAFETY: `db` validity is the caller's contract; `PARAM` is static and NUL-terminated.
    let rc = unsafe { sqlite3mc_config(db, PARAM.as_ptr().cast(), cipher_id) };
    rc as i32
}
159
/// Keys `db` with the bytes in `key` via `sqlite3_key`.
///
/// Returns the raw SQLite result code (`SQLITE_OK` on success).
///
/// # Safety
///
/// `db` must be a valid, open database handle.
#[cfg(feature = "encryption")]
pub unsafe fn set_encryption_key(db: *mut libsql_ffi::sqlite3, key: &[u8]) -> i32 {
    let key_len = key.len() as std::ffi::c_int;
    // SAFETY: `db` validity is the caller's contract; `key` is a live slice of `key_len` bytes.
    let rc = unsafe { sqlite3_key(db, key.as_ptr().cast(), key_len) };
    rc as i32
}
166
/// Re-keys `db` with the bytes in `key` via `sqlite3_rekey`, returning the raw result code.
///
/// NOTE(review): unlike `set_encryption_key`, this is a safe fn even though it passes `db`
/// straight to C; callers must still supply a valid, open handle.
#[cfg(feature = "encryption")]
pub fn reset_encryption_key(db: *mut libsql_ffi::sqlite3, key: &[u8]) -> i32 {
    let key_len = key.len() as std::ffi::c_int;
    // SAFETY: relies on the caller passing a valid handle (see note above); `key` is a live
    // slice of `key_len` bytes.
    let rc = unsafe { sqlite3_rekey(db, key.as_ptr().cast(), key_len) };
    rc as i32
}
173
/// Returns the pager of `db` via `libsql_leak_pager`.
///
/// NOTE(review): this is a safe fn although it hands `db` straight to C, so callers must pass a
/// valid, open handle. The returned pointer is presumably still owned by `db` — confirm.
#[cfg(feature = "encryption")]
pub fn leak_pager(db: *mut libsql_ffi::sqlite3) -> *mut crate::ffi::Pager {
    // SAFETY: validity of `db` is the caller's responsibility (see note above).
    unsafe { self::libsql_leak_pager(db) }
}
178
/// Fills `iv` using `libsql_generate_initial_vector` with the given `seed`.
///
/// NOTE(review): `iv.len()` is neither passed to nor checked against the C function, so `iv`
/// must be at least as long as whatever the C side writes — confirm the required size and
/// consider asserting it.
#[cfg(feature = "encryption")]
pub fn generate_initial_vector(seed: u32, iv: &mut [u8]) {
    let out = iv.as_mut_ptr();
    // SAFETY: `out` points into a live, writable slice; its sufficient length is assumed
    // (see note above).
    unsafe { libsql_generate_initial_vector(seed, out) }
}
183
/// Derives a key from `user_password` into `digest` using `libsql_generate_aes256_key`.
///
/// NOTE(review): `digest.len()` is not checked; the C side decides how many bytes it writes
/// (presumably 32 for AES-256 — confirm). The password length is cast to `u32`.
#[cfg(feature = "encryption")]
pub fn generate_aes256_key(user_password: &[u8], digest: &mut [u8]) {
    let password_ptr = user_password.as_ptr();
    let password_len = user_password.len() as u32;
    let digest_ptr = digest.as_mut_ptr();
    // SAFETY: both pointers come from live slices; `digest` is assumed large enough (see note).
    unsafe { libsql_generate_aes256_key(password_ptr, password_len, digest_ptr) }
}
194
/// `auto_checkpoint` value for `Connection::open`. It is forwarded to
/// `sqlite3_wal_autocheckpoint`, which disables automatic checkpoints for values <= 0.
pub const NO_AUTOCHECKPOINT: u32 = 0u32;
196
197impl<W: Wal> Connection<W> {
198 pub fn open<T>(
200 path: impl AsRef<Path>,
201 flags: OpenFlags,
202 wal_manager: T,
203 auto_checkpoint: u32,
204 encryption_config: Option<EncryptionConfig>,
205 ) -> Result<Self, Error>
206 where
207 T: WalManager<Wal = W>,
208 {
209 tracing::trace!(
210 "Opening a connection with regular WAL at {}",
211 path.as_ref().display()
212 );
213
214 #[cfg(feature = "rusqlite")]
215 let conn = {
216 let conn = if cfg!(feature = "unix-excl-vfs") {
217 rusqlite::Connection::open_with_flags_vfs_and_wal(
218 path,
219 flags,
220 "unix-excl",
221 make_wal_manager(wal_manager),
222 )
223 } else {
224 rusqlite::Connection::open_with_flags_and_wal(
225 path,
226 flags,
227 make_wal_manager(wal_manager),
228 )
229 }?;
230
231 if !cfg!(feature = "encryption") && encryption_config.is_some() {
232 return Err(Error::SqliteFailure(
233 rusqlite::ffi::Error::new(21),
234 Some("encryption feature is not enabled, the database will not be encrypted on disk"
235 .to_string()),
236 ));
237 }
238 #[cfg(feature = "encryption")]
239 if let Some(cfg) = encryption_config {
240 let cipher_id = cfg.cipher_id();
241 if unsafe { set_encryption_cipher(conn.handle(), cipher_id) } == -1 {
242 return Err(Error::SqliteFailure(
243 rusqlite::ffi::Error::new(21),
244 Some("failed to set encryption cipher".into()),
245 ));
246 };
247 if unsafe { set_encryption_key(conn.handle(), &cfg.encryption_key) }
248 != rusqlite::ffi::SQLITE_OK
249 {
250 return Err(Error::SqliteFailure(
251 rusqlite::ffi::Error::new(21),
252 Some("failed to set encryption key".into()),
253 ));
254 };
255 }
256
257 conn.pragma_update(None, "journal_mode", "WAL")?;
258 unsafe {
259 let rc =
260 rusqlite::ffi::sqlite3_wal_autocheckpoint(conn.handle(), auto_checkpoint as _);
261 if rc != 0 {
262 return Err(rusqlite::Error::SqliteFailure(
263 rusqlite::ffi::Error::new(rc),
264 Some("failed to set auto_checkpoint".into()),
265 ));
266 }
267 }
268
269 conn.busy_timeout(std::time::Duration::from_millis(100))?;
270
271 conn
272 };
273
274 #[cfg(not(feature = "rusqlite"))]
275 let conn = unsafe {
276 #[cfg(unix)]
277 let path = {
278 use std::os::unix::ffi::OsStrExt;
279 std::ffi::CString::new(path.as_ref().as_os_str().as_bytes()).map_err(|_| {
280 crate::error::Error::Bug(
281 "invalid database path containing an internal nul byte",
282 )
283 })?
284 };
285 #[cfg(not(unix))]
286 let path = path
287 .to_str()
288 .ok_or_else(|| crate::error::Error::Bug("database path is not valid unicode"))
289 .and_then(|x| {
290 std::ffi::CString::new(x).map_err(|_| {
291 crate::error::Error::Bug(
292 "invalid database path containing an internal nul byte",
293 )
294 })
295 })?;
296
297 let mut conn: *mut crate::ffi::sqlite3 = std::ptr::null_mut();
298 let vfs = if cfg!(feature = "unix-excl-vfs") {
302 "unix-excl\0".as_ptr() as *const _
303 } else {
304 std::ptr::null_mut()
305 };
306 let mut rc = libsql_ffi::libsql_open_v3(
307 path.as_ptr(),
308 &mut conn as *mut _,
309 flags,
310 vfs,
311 make_wal_manager(wal_manager),
312 );
313
314 if !cfg!(feature = "encryption") && encryption_config.is_some() {
315 return Err(Error::Bug(
316 "encryption feature is not enabled, the database will not be encrypted on disk",
317 ));
318 }
319 #[cfg(feature = "encryption")]
320 if let Some(cfg) = encryption_config {
321 let cipher_id = cfg.cipher_id();
322 if set_encryption_cipher(conn, cipher_id) == -1 {
323 return Err(Error::Bug("failed to set encryption cipher"));
324 }
325 if set_encryption_key(conn, &cfg.encryption_key) != libsql_ffi::SQLITE_OK {
326 return Err(Error::Bug("failed to set encryption key"));
327 }
328 }
329
330 if rc == 0 {
331 rc = libsql_ffi::sqlite3_wal_autocheckpoint(conn, auto_checkpoint as _);
332 }
333
334 if rc != 0 {
335 libsql_ffi::sqlite3_close(conn);
336 return Err(crate::Error::from(rc));
337 }
338
339 assert!(!conn.is_null());
340
341 crate::ffi::sqlite3_busy_timeout(conn, 5000);
342
343 conn
344 };
345
346 Ok(Connection {
347 conn,
348 _pth: PhantomData,
349 })
350 }
351
352 pub unsafe fn handle(&self) -> *mut libsql_ffi::sqlite3 {
357 #[cfg(feature = "rusqlite")]
358 {
359 self.conn.handle()
360 }
361 #[cfg(not(feature = "rusqlite"))]
362 {
363 self.conn
364 }
365 }
366
367 pub fn db_change_counter(&self) -> Result<u32, std::ffi::c_int> {
368 let mut file_ptr: *mut libsql_ffi::sqlite3_file = std::ptr::null_mut();
369 let rc = unsafe {
370 libsql_ffi::sqlite3_file_control(
371 self.handle(),
372 "main\0".as_ptr() as *const _,
373 libsql_ffi::SQLITE_FCNTL_FILE_POINTER,
374 &mut file_ptr as *mut _ as *mut _,
375 )
376 };
377 if rc != libsql_ffi::SQLITE_OK {
378 return Err(rc);
379 }
380 let counter = unsafe {
381 let mut counter: u32 = 0;
382 let file = &*file_ptr;
383 (*file.pMethods).xRead.unwrap()(file_ptr, &mut counter as *mut _ as *mut _, 4, 24);
384 u32::from_be(counter)
385 };
386 Ok(counter)
387 }
388}
389