Skip to main content

mecomp_storage/db/
mod.rs

1#[cfg(feature = "db")]
2pub mod crud;
3#[cfg(feature = "db")]
4pub mod health;
5#[cfg(feature = "db")]
6pub(crate) mod queries;
7pub mod schemas;
8
9#[cfg(feature = "db")]
10use surrealdb::{Surreal, engine::local::Db, opt::Config};
11
/// Directory holding the on-disk database.
///
/// Primed at most once by [`set_database_path`] and read by [`init_database`].
#[cfg(feature = "db")]
#[cfg(not(tarpaulin_include))]
static DB_DIR: once_cell::sync::OnceCell<std::path::PathBuf> = once_cell::sync::OnceCell::new();

/// Fallback database location used by [`init_database`] when [`DB_DIR`] was never set.
///
/// The directory is created lazily on first access.
#[cfg(feature = "db")]
#[cfg(not(tarpaulin_include))]
static TEMP_DB_DIR: once_cell::sync::Lazy<tempfile::TempDir> =
    once_cell::sync::Lazy::new(create_temp_db_dir);

/// Initializer for [`TEMP_DB_DIR`].
///
/// # Panics
///
/// Panics if a temporary directory cannot be created.
#[cfg(feature = "db")]
#[cfg(not(tarpaulin_include))]
fn create_temp_db_dir() -> tempfile::TempDir {
    tempfile::tempdir().expect("Failed to create temporary directory")
}
20
/// Register the directory the database will be stored in.
///
/// Call this before [`init_database`]. The path can only be primed once.
///
/// # Errors
///
/// Returns [`crate::errors::Error::DbPathSetError`], carrying the rejected `path`,
/// if a database path has already been set.
#[cfg(feature = "db")]
#[allow(clippy::missing_inline_in_public_items)]
pub fn set_database_path(path: std::path::PathBuf) -> Result<(), crate::errors::Error> {
    if let Err(rejected) = DB_DIR.set(path) {
        // `OnceCell::set` hands the value back when the cell is already initialized.
        return Err(crate::errors::Error::DbPathSetError(rejected));
    }

    log::info!("Primed database path");
    Ok(())
}
35
/// Initialize the database with the necessary tables.
///
/// The database is opened at the path registered with [`set_database_path`]. If no path
/// was registered, a process-wide temporary directory is used instead and a warning is
/// logged, because that almost certainly indicates a call-order bug.
///
/// Initialization proceeds as follows:
///
/// 1. make sure a `temp` subdirectory exists under the database path and hand it to
///    SurrealDB via `Config::temporary_directory`; if it cannot be created, a warning
///    is logged and no temporary directory is configured;
/// 2. open the database in strict mode;
/// 3. define (if needed) and select the `mecomp` namespace and the `music` database;
/// 4. register the custom analyzer, the table schemas, and the relation tables.
///
/// # Errors
///
/// This function will return an error if:
///
/// - the database cannot be opened;
/// - any namespace/database definition statement fails;
/// - registering the analyzer, the tables, or the relations fails.
#[cfg(feature = "db")]
#[allow(clippy::missing_inline_in_public_items)]
pub async fn init_database() -> surrealqlx::Result<Surreal<Db>> {
    use surrealqlx::surrql;

    // Get the database path, or use a temporary directory if not set
    let db_path = DB_DIR.get().cloned().unwrap_or_else(|| {
        log::warn!("DB_DIR not set, defaulting to a temporary directory `{}`, this is likely a bug because `set_database_path` should be called before `init_database`", TEMP_DB_DIR.path().display());
        TEMP_DB_DIR.path().to_path_buf()
    });

    // The temporary directory is best-effort: on failure, say why and carry on without one
    // instead of aborting initialization (previously the error was silently discarded).
    let temp_dir = match ensure_temp_dir(&db_path) {
        Ok(dir) => Some(dir),
        Err(err) => {
            log::warn!(
                "Failed to create temporary directory under `{}`, continuing without one: {err}",
                db_path.display()
            );
            None
        }
    };
    let config = Config::new().strict().temporary_directory(temp_dir);
    let db = Surreal::new((db_path, config)).await?;

    // Awaiting a query only reports transport-level failures; `check()` surfaces errors
    // raised by the statements themselves, so a failed definition is reported here rather
    // than as a confusing strict-mode error later on.
    db.query(surrql!("DEFINE NAMESPACE IF NOT EXISTS mecomp"))
        .await?
        .check()?;
    db.use_ns("mecomp").await?;
    db.query(surrql!("DEFINE DATABASE IF NOT EXISTS music"))
        .await?
        .check()?;
    db.use_db("music").await?;

    register_custom_analyzer(&db).await?;
    surrealqlx::register_tables!(
        &db,
        schemas::album::Album,
        schemas::artist::Artist,
        schemas::song::Song,
        schemas::collection::Collection,
        schemas::playlist::Playlist,
        schemas::dynamic::DynamicPlaylist
    )?;
    #[cfg(feature = "analysis")]
    surrealqlx::register_tables!(&db, schemas::analysis::Analysis)?;

    queries::relations::define_relation_tables(&db).await?;

    Ok(db)
}

/// Create (if needed) the `temp` subdirectory of `db_path` and return its path.
///
/// This succeeds when the directory already exists. It fails if the directory cannot be
/// created, including when a non-directory entry already occupies that path.
#[cfg(feature = "db")]
fn ensure_temp_dir(db_path: &std::path::Path) -> std::io::Result<std::path::PathBuf> {
    let temp_dir = db_path.join("temp");
    std::fs::create_dir_all(&temp_dir)?;
    Ok(temp_dir)
}
89
/// Define `custom_analyzer`, the analyzer that the schema indexes refer to by name.
///
/// The definition is applied as a `surrealqlx` migration named `custom_analyzer`, brought
/// up to the latest version on `db`. Its `down` step removes the analyzer.
///
/// # Errors
///
/// Returns an error if running the migration fails.
#[cfg(feature = "db")]
pub(crate) async fn register_custom_analyzer<C>(db: &Surreal<C>) -> surrealqlx::Result<()>
where
    C: surrealdb::Connection,
{
    use surrealqlx::{
        migrations::{M, Migrations},
        surrql,
    };

    // NOTE: if you change this, you must go through the schemas and update the index analyzer names
    let define_sql = surrql!(
        "DEFINE ANALYZER IF NOT EXISTS custom_analyzer
         TOKENIZERS class
         FILTERS ascii,lowercase,edgengram(1,10),snowball(English);"
    );
    let remove_sql = surrql!("REMOVE ANALYZER IF EXISTS custom_analyzer;");

    Migrations::new("custom_analyzer", vec![M::up(define_sql).down(remove_sql)])
        .to_latest(db)
        .await?;

    Ok(())
}
118
#[cfg(test)]
mod test {
    use super::schemas::{
        album::Album, artist::Artist, collection::Collection, dynamic::DynamicPlaylist,
        playlist::Playlist, song::Song,
    };
    use super::*;

    use surrealdb::engine::local::Mem;
    use surrealqlx::traits::Table;

    /// Run the same schema setup as [`init_database`] against an in-memory, strict-mode
    /// database, then re-initialize one table to check that initialization is repeatable.
    #[tokio::test]
    async fn test_register_tables() -> anyhow::Result<()> {
        let config = Config::new().strict();
        // use an in-memory db for testing
        let db = Surreal::new::<Mem>(config).await?;

        db.query("DEFINE NAMESPACE IF NOT EXISTS test").await?;
        db.use_ns("test").await?;
        db.query("DEFINE DATABASE IF NOT EXISTS test").await?;
        db.use_db("test").await?;

        // register the custom analyzer
        register_custom_analyzer(&db).await?;

        // first we init all the tables to ensure that the queries made by the macro work without error
        <Album as Table>::init_table(&db).await?;
        <Artist as Table>::init_table(&db).await?;
        <Song as Table>::init_table(&db).await?;
        <Collection as Table>::init_table(&db).await?;
        <Playlist as Table>::init_table(&db).await?;
        <DynamicPlaylist as Table>::init_table(&db).await?;
        // `init_database` also registers the analysis table when this feature is enabled,
        // so cover it here too.
        #[cfg(feature = "analysis")]
        <super::schemas::analysis::Analysis as Table>::init_table(&db).await?;

        // then we init the relation tables
        queries::relations::define_relation_tables(&db).await?;

        // then we try initializing one of the tables again to ensure that initialization won't mess with existing tables/data
        <Album as Table>::init_table(&db).await?;

        Ok(())
    }
}
161
#[cfg(test)]
mod minimal_reproduction {
    //! This module contains minimal reproductions of issues from MECOMPs past.
    //! They exist to ensure that the issues are indeed fixed.
    use serde::{Deserialize, Serialize};
    use surrealdb::{RecordId, Surreal, engine::local::Mem, method::Stats};
    use surrealqlx::surrql;

    use crate::db::queries::generic::{Count, count};

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct User {
        id: RecordId,
        name: String,
        age: u64,
        favorite_numbers: [u64; 7],
    }

    /// Schema for the `users` table: table, fields, then indexes, each in its own transaction.
    ///
    /// Every `BEGIN` is paired with a `COMMIT`. The index transaction previously lacked its
    /// `COMMIT`, which left it open at the end of the query.
    static SCHEMA_SQL: &str = r"
    BEGIN;
    DEFINE TABLE users SCHEMAFULL;
    COMMIT;
    BEGIN;
    DEFINE FIELD id ON users TYPE record;
    DEFINE FIELD name ON users TYPE string;
    DEFINE FIELD age ON users TYPE int;
    DEFINE FIELD favorite_numbers ON users TYPE array<int>;
    COMMIT;
    BEGIN;
    DEFINE INDEX users_name_unique_index ON users FIELDS name UNIQUE;
    DEFINE INDEX users_age_normal_index ON users FIELDS age;
    DEFINE INDEX users_favorite_numbers_vector_index ON users FIELDS favorite_numbers MTREE DIMENSION 7;
    COMMIT;
    ";
    const NUMBER_OF_USERS: u64 = 100;

    /// Exercise create/select/count/delete on an indexed table.
    ///
    /// Also compares the `count() ... GROUP ALL` and `array::len(SELECT ...)` ways of
    /// counting rows.
    #[tokio::test]
    async fn minimal_reproduction() {
        let db = Surreal::new::<Mem>(()).await.unwrap();
        db.use_ns("test").use_db("test").await.unwrap();

        db.query(SCHEMA_SQL).await.unwrap();

        let cnt: Option<Count> = db
            // new syntax
            .query(count())
            .bind(("table", "users"))
            .await
            .unwrap()
            .take(0)
            .unwrap();

        assert_eq!(cnt, Some(Count::new(0)));

        let john_id = RecordId::from(("users", "0"));
        let john = User {
            id: john_id.clone(),
            name: "John".to_string(),
            age: 42,
            favorite_numbers: [1, 2, 3, 4, 5, 6, 7],
        };

        let sally_id = RecordId::from(("users", "1"));
        let sally = User {
            id: sally_id.clone(),
            name: "Sally".to_string(),
            age: 24,
            favorite_numbers: [8, 9, 10, 11, 12, 13, 14],
        };

        let result: Option<User> = db
            .create(john_id.clone())
            .content(john.clone())
            .await
            .unwrap();

        assert_eq!(result, Some(john.clone()));

        let result: Option<User> = db
            .create(sally_id.clone())
            .content(sally.clone())
            .await
            .unwrap();

        assert_eq!(result, Some(sally.clone()));

        let result: Option<User> = db.select(john_id).await.unwrap();

        assert_eq!(result, Some(john.clone()));

        // create like 100 more users
        for i in 2..NUMBER_OF_USERS {
            let user_id = RecordId::from(("users", i.to_string()));
            let user = User {
                id: user_id.clone(),
                name: format!("User {i}"),
                age: i,
                favorite_numbers: [i; 7],
            };
            let _: Option<User> = db.create(user_id.clone()).content(user).await.unwrap();
        }

        let mut resp_new = db
            // new syntax
            .query(surrql!("SELECT count() FROM users GROUP ALL"))
            .with_stats()
            .await
            .unwrap();
        let res = resp_new.take(0).unwrap();
        let cnt: Option<Count> = res.1.unwrap();
        assert_eq!(cnt, Some(Count::new(NUMBER_OF_USERS)));
        let stats_new: Stats = res.0;

        let mut resp_old = db
            // old syntax
            .query(surrql!("RETURN array::len((SELECT * FROM users))"))
            .with_stats()
            .await
            .unwrap();
        let res = resp_old.take(0).unwrap();
        let cnt: Option<u64> = res.1.unwrap();
        assert_eq!(cnt, Some(NUMBER_OF_USERS));
        let stats_old: Stats = res.0;

        // The aggregate form is expected to be cheaper, but the wall-clock times of two tiny
        // in-memory queries are too noisy to gate a unit test on. Report the comparison
        // instead of asserting it, so the test is not flaky.
        let new_time = stats_new
            .execution_time
            .expect("query stats should include an execution time");
        let old_time = stats_old
            .execution_time
            .expect("query stats should include an execution time");
        if new_time >= old_time {
            eprintln!(
                "note: `SELECT count() ... GROUP ALL` took {new_time:?}, \
                 `array::len((SELECT ...))` took {old_time:?}"
            );
        }

        let result: Vec<User> = db.delete("users").await.unwrap();

        assert_eq!(result.len() as u64, NUMBER_OF_USERS);
        assert!(result.contains(&john), "Result does not contain 'john'");
        assert!(result.contains(&sally), "Result does not contain 'sally'");
    }
}