1use crate::ffi as bindings;
26
27use crate::{
28 check_import_db, ImportDbError, MemChunksFile, OsCallback, SQLiteIoMethods, SQLiteVfs,
29 SQLiteVfsFile, VfsAppData, VfsError, VfsFile, VfsResult, VfsStore,
30};
31
32use alloc::boxed::Box;
33use alloc::string::String;
34use alloc::vec::Vec;
35use alloc::{format, vec};
36use core::cell::RefCell;
37use core::ffi::CStr;
38use core::marker::PhantomData;
39use core::time::Duration;
40use hashbrown::HashMap;
41
/// Name under which the in-memory VFS is looked up in, and registered with, SQLite.
const VFS_NAME: &CStr = c"memvfs";
43
/// Result alias for this module's helpers; the error type is always [`MemVfsError`].
type Result<T> = core::result::Result<T, MemVfsError>;
45
/// A file held by the in-memory VFS.
///
/// Both variants wrap a [`MemChunksFile`]; the variant only records how the
/// file was opened.
pub enum MemFile {
    /// A file opened with `SQLITE_OPEN_MAIN_DB` set, or one created by the
    /// import helpers on `MemVfsUtil`.
    Main(MemChunksFile),
    /// Any file opened without `SQLITE_OPEN_MAIN_DB`.
    Temp(MemChunksFile),
}
50
51impl MemFile {
52 fn new(flags: i32) -> Self {
53 if flags & bindings::SQLITE_OPEN_MAIN_DB == 0 {
54 Self::Temp(MemChunksFile::default())
55 } else {
56 Self::Main(MemChunksFile::waiting_for_write())
57 }
58 }
59
60 fn file(&self) -> &MemChunksFile {
61 let (MemFile::Main(file) | MemFile::Temp(file)) = self;
62 file
63 }
64
65 fn file_mut(&mut self) -> &mut MemChunksFile {
66 let (MemFile::Main(file) | MemFile::Temp(file)) = self;
67 file
68 }
69}
70
71impl VfsFile for MemFile {
72 fn read(&self, buf: &mut [u8], offset: usize) -> VfsResult<bool> {
73 self.file().read(buf, offset)
74 }
75
76 fn write(&mut self, buf: &[u8], offset: usize) -> VfsResult<()> {
77 self.file_mut().write(buf, offset)
78 }
79
80 fn truncate(&mut self, size: usize) -> VfsResult<()> {
81 self.file_mut().truncate(size)
82 }
83
84 fn flush(&mut self) -> VfsResult<()> {
85 self.file_mut().flush()
86 }
87
88 fn size(&self) -> VfsResult<usize> {
89 self.file().size()
90 }
91}
92
/// Files owned by the VFS, keyed by file name, behind a `RefCell` so they can
/// be mutated through a shared reference.
type MemAppData = RefCell<HashMap<String, MemFile>>;
94
/// Stateless [`VfsStore`] implementation; the files themselves live in the
/// [`MemAppData`] attached to the VFS.
#[derive(Copy, Clone, Default)]
struct MemStore;
97
98impl VfsStore<MemFile, MemAppData> for MemStore {
99 fn add_file(vfs: *mut bindings::sqlite3_vfs, file: &str, flags: i32) -> VfsResult<()> {
100 let app_data = unsafe { Self::app_data(vfs) };
101 app_data
102 .borrow_mut()
103 .insert(file.into(), MemFile::new(flags));
104 Ok(())
105 }
106
107 fn contains_file(vfs: *mut bindings::sqlite3_vfs, file: &str) -> VfsResult<bool> {
108 let app_data = unsafe { Self::app_data(vfs) };
109 Ok(app_data.borrow().contains_key(file))
110 }
111
112 fn delete_file(vfs: *mut bindings::sqlite3_vfs, file: &str) -> VfsResult<()> {
113 let app_data = unsafe { Self::app_data(vfs) };
114 if app_data.borrow_mut().remove(file).is_none() {
115 return Err(VfsError::new(
116 bindings::SQLITE_IOERR_DELETE,
117 format!("{file} not found"),
118 ));
119 }
120 Ok(())
121 }
122
123 fn with_file<F: Fn(&MemFile) -> VfsResult<i32>>(
124 vfs_file: &SQLiteVfsFile,
125 f: F,
126 ) -> VfsResult<i32> {
127 let name = unsafe { vfs_file.name() };
128 let app_data = unsafe { Self::app_data(vfs_file.vfs) };
129 match app_data.borrow().get(name) {
130 Some(file) => f(file),
131 None => Err(VfsError::new(
132 bindings::SQLITE_IOERR,
133 format!("{name} not found"),
134 )),
135 }
136 }
137
138 fn with_file_mut<F: Fn(&mut MemFile) -> VfsResult<i32>>(
139 vfs_file: &SQLiteVfsFile,
140 f: F,
141 ) -> VfsResult<i32> {
142 let name = unsafe { vfs_file.name() };
143 let app_data = unsafe { Self::app_data(vfs_file.vfs) };
144 match app_data.borrow_mut().get_mut(name) {
145 Some(file) => f(file),
146 None => Err(VfsError::new(
147 bindings::SQLITE_IOERR,
148 format!("{name} not found"),
149 )),
150 }
151 }
152}
153
/// Marker type carrying the [`SQLiteIoMethods`] implementation for the
/// in-memory VFS.
#[derive(Clone, Copy, Default)]
struct MemIoMethods;
156
/// Wires the in-memory file, app-data and store types into the generic I/O
/// method implementation.
impl SQLiteIoMethods for MemIoMethods {
    type File = MemFile;
    type AppData = MemAppData;
    type Store = MemStore;

    // NOTE(review): presumably the `iVersion` reported in `sqlite3_io_methods` — confirm.
    const VERSION: ::core::ffi::c_int = 1;
}
164
/// Marker type carrying the [`SQLiteVfs`] implementation for the in-memory VFS.
///
/// `C` supplies the OS callbacks (sleep, randomness, wall-clock time); it is
/// only used at the type level, hence the `PhantomData`.
#[derive(Clone, Copy, Default)]
struct MemVfs<C>(PhantomData<C>);
167
168impl<C> SQLiteVfs<MemIoMethods> for MemVfs<C>
169where
170 C: OsCallback,
171{
172 const VERSION: ::core::ffi::c_int = 1;
173
174 fn sleep(dur: Duration) {
175 C::sleep(dur);
176 }
177
178 fn random(buf: &mut [u8]) {
179 C::random(buf);
180 }
181
182 fn epoch_timestamp_in_ms() -> i64 {
183 C::epoch_timestamp_in_ms()
184 }
185}
186
/// Errors returned by the [`MemVfsUtil`] helpers.
#[derive(thiserror::Error, Debug)]
pub enum MemVfsError {
    /// The bytes passed to `import_db` were rejected by `check_import_db`.
    #[error(transparent)]
    ImportDb(#[from] ImportDbError),
    /// Any other failure, described by the contained message.
    #[error("Generic error: {0}")]
    Generic(String),
}
194
/// Handle for managing the files stored in the in-memory VFS.
///
/// Field `0` is the app data of the registered VFS as returned by `install`;
/// `C` is the [`OsCallback`] the VFS is installed with.
pub struct MemVfsUtil<C>(&'static VfsAppData<MemAppData>, PhantomData<C>);
197
198impl<C> Default for MemVfsUtil<C>
199where
200 C: OsCallback,
201{
202 fn default() -> Self {
203 MemVfsUtil::new()
204 }
205}
206
207impl<C> MemVfsUtil<C>
208where
209 C: OsCallback,
210{
211 pub fn new() -> Self {
213 MemVfsUtil(unsafe { install::<C>() }, PhantomData)
215 }
216}
217
218impl<C> MemVfsUtil<C>
219where
220 C: OsCallback,
221{
222 fn import_db_unchecked_impl(
223 &self,
224 filename: &str,
225 bytes: &[u8],
226 page_size: usize,
227 clear_wal: bool,
228 ) -> Result<()> {
229 if self.exists(filename) {
230 return Err(MemVfsError::Generic(format!(
231 "{filename} file already exists"
232 )));
233 }
234
235 self.0.borrow_mut().insert(filename.into(), {
236 let mut file = MemFile::Main(MemChunksFile::new(page_size));
237 file.write(bytes, 0).unwrap();
238 if clear_wal {
239 file.write(&[1, 1], 18).unwrap();
241 }
242 file
243 });
244
245 Ok(())
246 }
247
248 pub fn import_db(&self, filename: &str, bytes: &[u8]) -> Result<()> {
256 let page_size = check_import_db(bytes)?;
257 self.import_db_unchecked_impl(filename, bytes, page_size, true)
258 }
259
260 pub fn import_db_unchecked(
262 &self,
263 filename: &str,
264 bytes: &[u8],
265 page_size: usize,
266 ) -> Result<()> {
267 self.import_db_unchecked_impl(filename, bytes, page_size, false)
268 }
269
270 pub fn export_db(&self, filename: &str) -> Result<Vec<u8>> {
272 let name2file = self.0.borrow();
273
274 if let Some(file) = name2file.get(filename) {
275 let file_size = file.size().unwrap();
276 let mut ret = vec![0; file_size];
277 file.read(&mut ret, 0).unwrap();
278 Ok(ret)
279 } else {
280 Err(MemVfsError::Generic(
281 "The file to be exported does not exist".into(),
282 ))
283 }
284 }
285
286 pub fn delete_db(&self, filename: &str) {
288 self.0.borrow_mut().remove(filename);
289 }
290
291 pub fn clear_all(&self) {
293 core::mem::take(&mut *self.0.borrow_mut());
294 }
295
296 pub fn exists(&self, filename: &str) -> bool {
298 self.0.borrow().contains_key(filename)
299 }
300
301 pub fn list(&self) -> Vec<String> {
303 self.0.borrow().keys().cloned().collect()
304 }
305
306 pub fn count(&self) -> usize {
308 self.0.borrow().len()
309 }
310}
311
312pub unsafe fn install<C: OsCallback>() -> &'static VfsAppData<MemAppData> {
320 let vfs = bindings::sqlite3_vfs_find(VFS_NAME.as_ptr());
321
322 let vfs = if vfs.is_null() {
323 let vfs = Box::leak(Box::new(MemVfs::<C>::vfs(
324 VFS_NAME.as_ptr(),
325 VfsAppData::new(MemAppData::default()).leak(),
326 )));
327 assert_eq!(
328 bindings::sqlite3_vfs_register(vfs, 1),
329 bindings::SQLITE_OK,
330 "failed to register memvfs"
331 );
332 vfs as *mut bindings::sqlite3_vfs
333 } else {
334 vfs
335 };
336
337 MemStore::app_data(vfs)
338}
339
340pub unsafe fn uninstall() {
349 let vfs = bindings::sqlite3_vfs_find(VFS_NAME.as_ptr());
350
351 if !vfs.is_null() {
352 assert_eq!(
353 bindings::sqlite3_vfs_unregister(vfs),
354 bindings::SQLITE_OK,
355 "failed to unregister memvfs"
356 );
357 drop(VfsAppData::<MemAppData>::from_raw(
358 (*vfs).pAppData as *mut _,
359 ));
360 drop(Box::from_raw(vfs));
361 }
362}
363
#[cfg(test)]
mod tests {
    use crate::memvfs::{MemAppData, MemFile, MemStore};
    use crate::test_suite::test_vfs_store;
    use crate::VfsAppData;

    /// Runs the shared `VfsStore` test suite against the in-memory store.
    #[test]
    fn test_memory_vfs_store() {
        let app_data = VfsAppData::new(MemAppData::default());
        let outcome = test_vfs_store::<MemAppData, MemFile, MemStore>(app_data);
        outcome.unwrap();
    }
}