Skip to main content

mecomp_storage/db/
mod.rs

1#[cfg(feature = "db")]
2pub mod crud;
3#[cfg(feature = "db")]
4pub mod health;
5#[cfg(feature = "db")]
6pub(crate) mod queries;
7pub mod schemas;
8
9#[cfg(feature = "db")]
10use surrealdb::{Surreal, engine::local::Db, opt::Config};
11
/// Directory holding the on-disk database, primed once by [`set_database_path`].
///
/// This is a write-once cell: a second `set` call fails and hands the rejected path back.
#[cfg(feature = "db")]
#[cfg(not(tarpaulin_include))]
static DB_DIR: std::sync::OnceLock<std::path::PathBuf> = std::sync::OnceLock::new();

/// Fallback database directory used by [`init_database`] when [`DB_DIR`] was never primed.
///
/// The directory is created lazily on first access. Because it lives in a `static`, the
/// `TempDir` guard is never dropped, so its destructor never removes the directory.
#[cfg(feature = "db")]
#[cfg(not(tarpaulin_include))]
static TEMP_DB_DIR: std::sync::LazyLock<tempfile::TempDir> = std::sync::LazyLock::new(|| {
    tempfile::tempdir().expect("Failed to create temporary directory")
});
20
/// Record the directory that [`init_database`] should open the database in.
///
/// The path lives in a write-once cell, so it can be primed only once per process.
///
/// # Errors
///
/// Returns [`crate::errors::Error::DbPathSetError`], carrying the rejected path, if the
/// database path has already been set.
#[cfg(feature = "db")]
#[allow(clippy::missing_inline_in_public_items)]
pub fn set_database_path(path: std::path::PathBuf) -> Result<(), crate::errors::Error> {
    if let Err(rejected) = DB_DIR.set(path) {
        return Err(crate::errors::Error::DbPathSetError(rejected));
    }

    log::info!("Primed database path");
    Ok(())
}
35
/// Open the database and make sure its namespace, database, analyzer and tables exist.
///
/// The database is opened with the local engine in the directory primed by
/// [`set_database_path`]. If no directory was primed, a warning is logged and a
/// process-wide temporary directory is used instead.
///
/// Setup runs in this order:
/// 1. open the database with a strict [`Config`],
/// 2. define (if missing) and select the `mecomp` namespace, then the `music` database,
/// 3. register the custom full-text analyzer,
/// 4. register the schema tables (plus `Analysis` when the `analysis` feature is enabled),
/// 5. define the relation tables.
///
/// # Errors
///
/// This function will return an error if the database cannot be opened or any of the
/// setup queries fail.
#[cfg(feature = "db")]
#[allow(clippy::missing_inline_in_public_items)]
pub async fn init_database() -> surrealqlx::Result<Surreal<Db>> {
    use surrealqlx::surrql;

    let db_path = match DB_DIR.get() {
        Some(primed) => primed.clone(),
        None => {
            let fallback = TEMP_DB_DIR.path();
            log::warn!("DB_DIR not set, defaulting to a temporary directory `{}`, this is likely a bug because `set_database_path` should be called before `init_database`", fallback.display());
            fallback.to_path_buf()
        }
    };

    // With strict mode on, SurrealDB expects namespaces/databases to be defined before use
    // (per SurrealDB's strict-mode docs), hence the explicit DEFINE statements below.
    let db = Surreal::new((db_path, Config::new().strict())).await?;

    db.query(surrql!("DEFINE NAMESPACE IF NOT EXISTS mecomp")).await?;
    db.use_ns("mecomp").await?;

    db.query(surrql!("DEFINE DATABASE IF NOT EXISTS music")).await?;
    db.use_db("music").await?;

    // The schemas' search indexes reference this analyzer, so it must exist first.
    register_custom_analyzer(&db).await?;

    surrealqlx::register_tables!(
        &db,
        schemas::album::Album,
        schemas::artist::Artist,
        schemas::song::Song,
        schemas::collection::Collection,
        schemas::playlist::Playlist,
        schemas::dynamic::DynamicPlaylist
    )?;
    #[cfg(feature = "analysis")]
    surrealqlx::register_tables!(&db, schemas::analysis::Analysis)?;

    queries::relations::define_relation_tables(&db).await?;

    Ok(db)
}
80
/// Install the `custom_analyzer` full-text analyzer through a named `surrealqlx` migration.
///
/// The analyzer tokenizes with `class` and applies the `ascii`, `lowercase`,
/// `edgengram(1,10)` and English `snowball` filters. The migration is registered under the
/// name `custom_analyzer`, and its `down` step removes the analyzer again.
///
/// # Errors
///
/// Returns an error if migrating `db` to the latest version fails.
#[cfg(feature = "db")]
pub(crate) async fn register_custom_analyzer<C>(db: &Surreal<C>) -> surrealqlx::Result<()>
where
    C: surrealdb::Connection,
{
    use surrealqlx::migrations::{M, Migrations};
    use surrealqlx::surrql;

    // NOTE: if you change this, you must go through the schemas and update the index analyzer names
    let install = M::up(surrql!(
        "DEFINE ANALYZER IF NOT EXISTS custom_analyzer
         TOKENIZERS class
         FILTERS ascii,lowercase,edgengram(1,10),snowball(English);"
    ))
    .down(surrql!("REMOVE ANALYZER IF EXISTS custom_analyzer;"));

    Migrations::new("custom_analyzer", vec![install])
        .to_latest(db)
        .await?;

    Ok(())
}
109
#[cfg(test)]
mod test {
    use super::schemas::{
        album::Album, artist::Artist, collection::Collection, dynamic::DynamicPlaylist,
        playlist::Playlist, song::Song,
    };
    use super::*;

    use surrealdb::engine::local::Mem;
    use surrealqlx::traits::Table;

    /// Runs the same table setup as `init_database` against an in-memory database.
    ///
    /// This checks that the queries generated by the schema macros execute cleanly, and
    /// that re-initializing an existing table does not fail.
    #[tokio::test]
    async fn test_register_tables() -> anyhow::Result<()> {
        let config = Config::new().strict();
        // use an in-memory db for testing
        let db = Surreal::new::<Mem>(config).await?;

        db.query("DEFINE NAMESPACE IF NOT EXISTS test").await?;
        db.use_ns("test").await?;
        db.query("DEFINE DATABASE IF NOT EXISTS test").await?;
        db.use_db("test").await?;

        // register the custom analyzer
        register_custom_analyzer(&db).await?;

        // first we init all the table to ensure that the queries made by the macro work without error
        <Album as Table>::init_table(&db).await?;
        <Artist as Table>::init_table(&db).await?;
        <Song as Table>::init_table(&db).await?;
        <Collection as Table>::init_table(&db).await?;
        <Playlist as Table>::init_table(&db).await?;
        <DynamicPlaylist as Table>::init_table(&db).await?;
        // `init_database` also registers the analysis table when that feature is enabled,
        // so exercise its definition queries here too (same position: before relations).
        #[cfg(feature = "analysis")]
        <super::schemas::analysis::Analysis as Table>::init_table(&db).await?;

        // then we init the relation tables
        queries::relations::define_relation_tables(&db).await?;

        // then we try initializing one of the tables again to ensure that initialization won't mess with existing tables/data
        <Album as Table>::init_table(&db).await?;

        Ok(())
    }
}
152
#[cfg(test)]
mod minimal_reproduction {
    //! This module contains minimal reproductions of issues from MECOMP's past.
    //! They exist to ensure that the issues are indeed fixed.
    use serde::{Deserialize, Serialize};
    use surrealdb::{RecordId, Surreal, engine::local::Mem, method::Stats};
    use surrealqlx::surrql;

    use crate::db::queries::generic::{Count, count};

    /// Record shape stored in the `users` table by the reproduction below.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct User {
        id: RecordId,
        name: String,
        age: u64,
        favorite_numbers: [u64; 7],
    }

    /// Schema for the `users` table.
    ///
    /// It defines the table, then its fields, then three indexes: unique on `name`, plain
    /// on `age`, and a 7-dimensional MTREE index on `favorite_numbers`. Each group runs in
    /// its own transaction. Every `BEGIN` is paired with a `COMMIT`, so the index
    /// definitions are not left in an unterminated transaction.
    static SCHEMA_SQL: &str = r"
    BEGIN;
    DEFINE TABLE users SCHEMAFULL;
    COMMIT;
    BEGIN;
    DEFINE FIELD id ON users TYPE record;
    DEFINE FIELD name ON users TYPE string;
    DEFINE FIELD age ON users TYPE int;
    DEFINE FIELD favorite_numbers ON users TYPE array<int>;
    COMMIT;
    BEGIN;
    DEFINE INDEX users_name_unique_index ON users FIELDS name UNIQUE;
    DEFINE INDEX users_age_normal_index ON users FIELDS age;
    DEFINE INDEX users_favorite_numbers_vector_index ON users FIELDS favorite_numbers MTREE DIMENSION 7;
    COMMIT;
    ";
    /// Total number of users inserted (John, Sally, and the generated users).
    const NUMBER_OF_USERS: u64 = 100;

    /// Creates, counts, selects and deletes users on an indexed table.
    ///
    /// It also compares the `count() ... GROUP ALL` form against `array::len(SELECT ...)`.
    #[tokio::test]
    async fn minimal_reproduction() {
        let db = Surreal::new::<Mem>(()).await.unwrap();
        db.use_ns("test").use_db("test").await.unwrap();

        db.query(SCHEMA_SQL).await.unwrap();

        let cnt: Option<Count> = db
            // new syntax
            .query(count())
            .bind(("table", "users"))
            .await
            .unwrap()
            .take(0)
            .unwrap();

        assert_eq!(cnt, Some(Count::new(0)));

        let john_id = RecordId::from(("users", "0"));
        let john = User {
            id: john_id.clone(),
            name: "John".to_string(),
            age: 42,
            favorite_numbers: [1, 2, 3, 4, 5, 6, 7],
        };

        let sally_id = RecordId::from(("users", "1"));
        let sally = User {
            id: sally_id.clone(),
            name: "Sally".to_string(),
            age: 24,
            favorite_numbers: [8, 9, 10, 11, 12, 13, 14],
        };

        let result: Option<User> = db
            .create(john_id.clone())
            .content(john.clone())
            .await
            .unwrap();

        assert_eq!(result, Some(john.clone()));

        let result: Option<User> = db
            .create(sally_id.clone())
            .content(sally.clone())
            .await
            .unwrap();

        assert_eq!(result, Some(sally.clone()));

        let result: Option<User> = db.select(john_id).await.unwrap();

        assert_eq!(result, Some(john.clone()));

        // create like 100 more users
        for i in 2..NUMBER_OF_USERS {
            let user_id = RecordId::from(("users", i.to_string()));
            let user = User {
                id: user_id.clone(),
                name: format!("User {i}"),
                age: i,
                favorite_numbers: [i; 7],
            };
            let _: Option<User> = db.create(user_id.clone()).content(user).await.unwrap();
        }

        let mut resp_new = db
            // new syntax
            .query(surrql!("SELECT count() FROM users GROUP ALL"))
            .with_stats()
            .await
            .unwrap();
        dbg!(&resp_new);
        let res = resp_new.take(0).unwrap();
        let cnt: Option<Count> = res.1.unwrap();
        assert_eq!(cnt, Some(Count::new(NUMBER_OF_USERS)));
        let stats_new: Stats = res.0;

        let mut resp_old = db
            // old syntax
            .query(surrql!("RETURN array::len((SELECT * FROM users))"))
            .with_stats()
            .await
            .unwrap();
        dbg!(&resp_old);
        let res = resp_old.take(0).unwrap();
        let cnt: Option<u64> = res.1.unwrap();
        assert_eq!(cnt, Some(NUMBER_OF_USERS));
        let stats_old: Stats = res.0;

        // just a check to ensure the new syntax is faster
        // NOTE(review): this compares wall-clock execution times, which can be flaky on a
        // heavily loaded machine.
        assert!(stats_new.execution_time.unwrap() < stats_old.execution_time.unwrap());

        let result: Vec<User> = db.delete("users").await.unwrap();

        assert_eq!(result.len() as u64, NUMBER_OF_USERS);
        assert!(result.contains(&john), "Result does not contain 'john'");
        assert!(result.contains(&sally), "Result does not contain 'sally'");
    }
}