sqlite_vfs_http/
vfs.rs

1use super::*;
2use rand::{thread_rng, Rng};
3use sqlite_vfs::{OpenKind, OpenOptions, Vfs};
4use std::{
5    io::{Error, ErrorKind},
6    time::Duration,
7};
8
9pub struct HttpVfs {
10    pub(crate) client: Option<Client>,
11    pub(crate) block_size: usize,
12    pub(crate) download_threshold: usize,
13}
14
15impl Vfs for HttpVfs {
16    type Handle = Connection;
17
18    fn open(&self, db: &str, opts: OpenOptions) -> Result<Self::Handle, Error> {
19        if opts.kind != OpenKind::MainDb {
20            return Err(Error::new(
21                ErrorKind::ReadOnlyFilesystem,
22                "only main database supported",
23            ));
24        }
25
26        Ok(Connection::new(
27            db,
28            self.client.clone(),
29            self.block_size,
30            self.download_threshold,
31        )?)
32    }
33
34    fn delete(&self, _db: &str) -> Result<(), Error> {
35        Err(Error::new(
36            ErrorKind::ReadOnlyFilesystem,
37            "delete operation is not supported",
38        ))
39    }
40
41    fn exists(&self, _db: &str) -> Result<bool, Error> {
42        Ok(false)
43    }
44
45    fn temporary_name(&self) -> String {
46        String::from("main.db")
47    }
48
49    fn random(&self, buffer: &mut [i8]) {
50        Rng::fill(&mut thread_rng(), buffer);
51    }
52
53    fn sleep(&self, duration: Duration) -> Duration {
54        std::thread::sleep(duration);
55        duration
56    }
57}
58
#[cfg(test)]
mod tests {
    //! End-to-end test: a local Rocket server publishes two SQLite files and
    //! they are queried through the HTTP VFS.

    use std::future::Future;

    use super::*;
    use rusqlite::{Connection, OpenFlags};
    use tokio::time::sleep;

    const QUERY_SQLITE_MASTER: &str = "SELECT count(1) FROM sqlite_master WHERE type = 'table'";
    const QUERY_TEST: &str = "SELECT name FROM test";

    /// Pause between two readiness probes of the fixture server.
    const STARTUP_POLL_INTERVAL: Duration = Duration::from_millis(100);
    /// Maximum number of readiness probes (about ten seconds in total), so a
    /// server that never comes up fails the test instead of hanging it.
    const STARTUP_MAX_ATTEMPTS: u32 = 100;

    /// Fixture HTTP server that serves pre-built SQLite database files.
    mod server {
        use rocket::{custom, figment::Figment, get, routes, Config, Shutdown, State};
        use rocket_seek_stream::SeekStream;
        use rusqlite::Connection;
        use std::{collections::HashMap, fs::read, io::Cursor, thread::JoinHandle};
        use tempfile::tempdir;
        use tokio::runtime::Runtime;

        /// TCP port the fixture server listens on; shared with the client
        /// side so the two cannot drift apart.
        pub const PORT: u16 = 4096;

        /// Builds the fixture databases on disk and returns their raw bytes
        /// keyed by index.
        ///
        /// Database `0` holds two empty tables; database `1` holds a `test`
        /// table with the rows `Alice` and `Bob`.
        fn init_database() -> HashMap<i64, Vec<u8>> {
            let schemas = [
                vec![
                    "PRAGMA journal_mode = MEMORY;",
                    "CREATE TABLE test1 (id INTEGER PRIMARY KEY, name TEXT);",
                    "CREATE TABLE test2 (id INTEGER PRIMARY KEY, name TEXT);",
                ],
                vec![
                    "PRAGMA journal_mode = MEMORY;",
                    "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT);",
                    "INSERT INTO test (name) VALUES ('Alice');",
                    "INSERT INTO test (name) VALUES ('Bob');",
                ],
            ];
            let mut database = HashMap::new();

            // The directory is removed when `temp` is dropped at the end of
            // this function; by then every file has been read into memory.
            let temp = tempdir().unwrap();

            for (i, schema) in schemas.into_iter().enumerate() {
                let path = temp.path().join(format!("{i}.db"));
                let conn = Connection::open(&path).unwrap();
                conn.execute_batch(&schema.join("\n")).unwrap();
                // Close before reading so the file on disk is complete.
                conn.close().unwrap();
                database.insert(i as i64, read(&path).unwrap());
            }

            database
        }

        /// Serves the database stored under `id` as a seekable stream, or
        /// `None` when no such database exists.
        #[get("/<id>")]
        pub async fn database(
            db: &State<HashMap<i64, Vec<u8>>>,
            id: i64,
        ) -> Option<SeekStream<'static>> {
            db.get(&id).map(|buffer| {
                // The stream must own its data (`'static`), hence the copy.
                let cursor = Cursor::new(buffer.clone());
                SeekStream::with_opts(cursor, buffer.len() as u64, None)
            })
        }

        /// Asks Rocket to shut down so the server thread can be joined.
        #[get("/shutdown")]
        pub async fn shutdown(shutdown: Shutdown) -> &'static str {
            shutdown.notify();
            "Shutting down..."
        }

        /// Starts the fixture server on its own thread with its own Tokio
        /// runtime and returns the handle carrying Rocket's launch result.
        pub fn launch() -> JoinHandle<Result<(), rocket::Error>> {
            std::thread::spawn(|| {
                let rt = Runtime::new().unwrap();
                rt.block_on(async {
                    custom(Figment::from(Config::default()).merge(("port", PORT)))
                        .manage(init_database())
                        .mount("/", routes![database, shutdown])
                        .launch()
                        .await?;

                    Ok(())
                })
            })
        }
    }

    /// Runs `future` against a freshly launched fixture server.
    ///
    /// `future` receives the server's base URL. The server is asked to shut
    /// down and its thread is joined even when `future` fails, so a failing
    /// test body does not leave the server running.
    ///
    /// # Errors
    ///
    /// Returns an error when the server fails to launch or does not become
    /// ready within `STARTUP_MAX_ATTEMPTS` probes, when `future` fails (this
    /// error takes precedence), or when shutting the server down fails.
    async fn init_server<C, F>(future: C) -> anyhow::Result<()>
    where
        C: FnOnce(String) -> F,
        F: Future<Output = anyhow::Result<()>>,
    {
        let base = format!("http://127.0.0.1:{}", server::PORT);
        let server = server::launch();

        // Wait for the server to start: no route is mounted at `/`, so a 404
        // response proves Rocket is answering requests.
        let mut ready = false;
        for _ in 0..STARTUP_MAX_ATTEMPTS {
            // A finished thread means the launch failed (e.g. port in use);
            // polling any further would never succeed.
            if server.is_finished() {
                break;
            }
            if let Ok(resp) = reqwest::get(base.as_str()).await {
                if resp.status() == 404 {
                    ready = true;
                    break;
                }
            }
            sleep(STARTUP_POLL_INTERVAL).await;
        }

        if !ready {
            if server.is_finished() {
                // Surface Rocket's own launch error when there is one.
                server.join().unwrap()?;
            }
            anyhow::bail!("test server did not become ready on {}", base);
        }

        let outcome = future(base.clone()).await;

        // Always request shutdown, even after a failed test body. Only join
        // when the request went through; otherwise the join would never
        // return.
        let shutdown_url = format!("{base}/shutdown");
        let stopped = match reqwest::get(shutdown_url.as_str()).await {
            Ok(_) => server.join().unwrap().map_err(anyhow::Error::from),
            Err(err) => Err(anyhow::Error::from(err)),
        };

        // The test body's error is the more useful one to report.
        outcome?;
        stopped
    }

    /// Opens both fixture databases through the HTTP VFS and checks their
    /// contents.
    #[tokio::test]
    async fn test_http_vfs() {
        init_server(|base| async move {
            vfs::register_http_vfs();

            // Database 0: two tables, no rows.
            {
                let conn = Connection::open_with_flags_and_vfs(
                    format!("{base}/0"),
                    OpenFlags::SQLITE_OPEN_READ_WRITE
                        | OpenFlags::SQLITE_OPEN_CREATE
                        | OpenFlags::SQLITE_OPEN_NO_MUTEX,
                    HTTP_VFS,
                )?;
                assert_eq!(
                    conn.query_row::<usize, _, _>(QUERY_SQLITE_MASTER, [], |row| row.get(0))?,
                    2
                );
            }

            // Database 1: the `test` table with two named rows.
            {
                let conn = Connection::open_with_flags_and_vfs(
                    format!("{base}/1"),
                    OpenFlags::SQLITE_OPEN_READ_WRITE
                        | OpenFlags::SQLITE_OPEN_CREATE
                        | OpenFlags::SQLITE_OPEN_NO_MUTEX,
                    HTTP_VFS,
                )?;
                let mut stmt = conn.prepare(QUERY_TEST)?;
                assert_eq!(
                    stmt.query_map([], |row| row.get::<_, String>(0))?
                        .collect::<Result<Vec<_>, _>>()?,
                    vec!["Alice".to_string(), "Bob".to_string()]
                );
            }

            Ok(())
        })
        .await
        .unwrap();
    }
}