mecomp_storage/db/
mod.rs

1#[cfg(feature = "db")]
2pub mod crud;
3#[cfg(feature = "db")]
4pub mod health;
5#[cfg(feature = "db")]
6pub(crate) mod queries;
7pub mod schemas;
8
9#[cfg(feature = "db")]
10use surrealdb::{Surreal, engine::local::Db, opt::Config};
11
/// Location of the on-disk database.
///
/// Written at most once by `set_database_path` and read by `init_database`.
#[cfg(feature = "db")]
#[cfg(not(tarpaulin_include))]
static DB_DIR: once_cell::sync::OnceCell<std::path::PathBuf> = once_cell::sync::OnceCell::new();
/// Fallback database location, used by `init_database` when `DB_DIR` was never set.
///
/// The directory is created lazily on first access, and a creation failure panics. Rust statics
/// are never dropped, so this `TempDir`'s drop-time cleanup never runs. The directory therefore
/// stays on disk after the process exits.
#[cfg(feature = "db")]
#[cfg(not(tarpaulin_include))]
static TEMP_DB_DIR: once_cell::sync::Lazy<tempfile::TempDir> = once_cell::sync::Lazy::new(|| {
    tempfile::tempdir().expect("Failed to create temporary directory")
});
20
/// Name under which `register_custom_analyzer` defines the crate's full-text search analyzer.
///
/// NOTE: if you change this, you must also go through the schemas and update the analyzer
/// names used by their full-text indexes to match.
pub const FULL_TEXT_SEARCH_ANALYZER_NAME: &str = "custom_analyzer";
23
/// Prime the location that [`init_database`] will open the database from.
///
/// The path can be primed only once per process. Any later call is rejected and leaves the
/// originally primed path in place.
///
/// # Errors
///
/// Returns `crate::errors::Error::DbPathSetError`, carrying the rejected path, if a database
/// path has already been primed.
#[cfg(feature = "db")]
#[allow(clippy::missing_inline_in_public_items)]
pub fn set_database_path(path: std::path::PathBuf) -> Result<(), crate::errors::Error> {
    match DB_DIR.set(path) {
        Ok(()) => {
            log::info!("Primed database path");
            Ok(())
        }
        Err(rejected) => Err(crate::errors::Error::DbPathSetError(rejected)),
    }
}
38
/// Open the database and make sure everything it needs has been defined.
///
/// The database is stored at the path primed with [`set_database_path`]. If no path was primed,
/// a process-wide temporary directory is used instead and a warning is logged. That fallback
/// almost certainly means this function was called too early.
///
/// The connection is opened in strict mode. This function then:
///
/// 1. defines (if missing) and selects the `mecomp` namespace;
/// 2. defines (if missing) and selects the `music` database;
/// 3. registers the custom full-text search analyzer;
/// 4. registers the schema tables, plus the analysis table when the `analysis` feature is on;
/// 5. registers the relation tables.
///
/// # Errors
///
/// Returns an error if:
///
/// - the database cannot be opened;
/// - defining or selecting the namespace or database fails;
/// - registering the analyzer, the tables, or the relation tables fails.
#[cfg(feature = "db")]
#[allow(clippy::missing_inline_in_public_items)]
pub async fn init_database() -> surrealdb::Result<Surreal<Db>> {
    let config = Config::new().strict();
    let db_path = DB_DIR.get().cloned().unwrap_or_else(|| {
        log::warn!(
            "DB_DIR not set, defaulting to a temporary directory `{}`, this is likely a bug because `set_database_path` should be called before `init_database`",
            TEMP_DB_DIR.path().display()
        );
        TEMP_DB_DIR.path().to_path_buf()
    });
    let db = Surreal::new((db_path, config)).await?;

    // Awaiting a query only fails for errors that affect the request as a whole. Errors raised
    // by individual statements are carried inside the returned `Response`, so `check()` is
    // needed to propagate them instead of silently ignoring them. Both statements use
    // `IF NOT EXISTS`, so they are no-ops on an already-initialized database.
    db.query("DEFINE NAMESPACE IF NOT EXISTS mecomp")
        .await?
        .check()?;
    db.use_ns("mecomp").await?;
    db.query("DEFINE DATABASE IF NOT EXISTS music")
        .await?
        .check()?;
    db.use_db("music").await?;

    // The analyzer must exist before the tables whose full-text indexes reference it.
    register_custom_analyzer(&db).await?;
    surrealqlx::register_tables!(
        &db,
        schemas::album::Album,
        schemas::artist::Artist,
        schemas::song::Song,
        schemas::collection::Collection,
        schemas::playlist::Playlist,
        schemas::dynamic::DynamicPlaylist
    )?;
    #[cfg(feature = "analysis")]
    surrealqlx::register_tables!(&db, schemas::analysis::Analysis)?;

    queries::relations::define_relation_tables(&db).await?;

    Ok(db)
}
79
/// Define the crate's full-text search analyzer on the given connection.
///
/// The analyzer is named by `FULL_TEXT_SEARCH_ANALYZER_NAME`. It uses the class tokenizer and
/// the `ascii`, `lowercase`, `edgengram(1, 10)` and `snowball(English)` filters, listed in that
/// order.
///
/// # Errors
///
/// Propagates any error returned while sending the definition query.
#[cfg(feature = "db")]
pub(crate) async fn register_custom_analyzer<C>(db: &Surreal<C>) -> surrealdb::Result<()>
where
    C: surrealdb::Connection,
{
    use queries::define_analyzer;
    use surrealdb::sql::Tokenizer;

    let filters = [
        "ascii",
        "lowercase",
        "edgengram(1, 10)",
        "snowball(English)",
    ];
    let statement =
        define_analyzer(FULL_TEXT_SEARCH_ANALYZER_NAME, Some(Tokenizer::Class), &filters);

    db.query(statement).await.map(|_response| ())
}
102
#[cfg(test)]
mod test {
    use super::schemas::{
        album::Album, artist::Artist, collection::Collection, dynamic::DynamicPlaylist,
        playlist::Playlist, song::Song,
    };
    use super::*;

    use surrealdb::engine::local::Mem;
    use surrealqlx::traits::Table;

    /// On a fresh in-memory database, the following must all succeed:
    ///
    /// - registering the analyzer;
    /// - registering every schema table;
    /// - registering the relation tables;
    /// - re-initializing a table afterwards.
    #[tokio::test]
    async fn test_register_tables() -> anyhow::Result<()> {
        // In-memory engine with the same strict config that `init_database` uses.
        let db = Surreal::new::<Mem>(Config::new().strict()).await?;

        // Define and select a throwaway namespace and database before touching any tables.
        db.query("DEFINE NAMESPACE IF NOT EXISTS test").await?;
        db.use_ns("test").await?;
        db.query("DEFINE DATABASE IF NOT EXISTS test").await?;
        db.use_db("test").await?;

        // The schemas' full-text indexes refer to the custom analyzer, so define it first.
        register_custom_analyzer(&db).await?;

        // Run each table's generated init queries once to prove they execute cleanly.
        <Album as Table>::init_table(&db).await?;
        <Artist as Table>::init_table(&db).await?;
        <Song as Table>::init_table(&db).await?;
        <Collection as Table>::init_table(&db).await?;
        <Playlist as Table>::init_table(&db).await?;
        <DynamicPlaylist as Table>::init_table(&db).await?;

        // Relation tables come after the tables they connect.
        queries::relations::define_relation_tables(&db).await?;

        // Re-initializing an existing table must not disturb the tables or data already there.
        <Album as Table>::init_table(&db).await?;

        Ok(())
    }
}
145
#[cfg(test)]
mod minimal_reproduction {
    //! This module contains minimal reproductions of issues from MECOMP's past.
    //! They exist to ensure that the issues are indeed fixed.
    use serde::{Deserialize, Serialize};
    use surrealdb::{RecordId, Surreal, engine::local::Mem, method::Stats};

    use crate::db::queries::generic::{Count, count};

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct User {
        id: RecordId,
        name: String,
        age: usize,
        favorite_numbers: [usize; 7],
    }

    /// Schema for the `users` table, split into three transactions: the table, its fields, and
    /// its indexes. Every `BEGIN` is paired with a `COMMIT` so that each transaction, including
    /// the index definitions, is explicitly committed.
    static SCHEMA_SQL: &str = r"
    BEGIN;
    DEFINE TABLE users SCHEMAFULL;
    COMMIT;
    BEGIN;
    DEFINE FIELD id ON users TYPE record;
    DEFINE FIELD name ON users TYPE string;
    DEFINE FIELD age ON users TYPE int;
    DEFINE FIELD favorite_numbers ON users TYPE array<int>;
    COMMIT;
    BEGIN;
    DEFINE INDEX users_name_unique_index ON users FIELDS name UNIQUE;
    DEFINE INDEX users_age_normal_index ON users FIELDS age;
    DEFINE INDEX users_favorite_numbers_vector_index ON users FIELDS favorite_numbers MTREE DIMENSION 7;
    COMMIT;
    ";
    const NUMBER_OF_USERS: usize = 100;

    /// Exercises create/select/count/delete against a schemafull, indexed table.
    ///
    /// It also checks that the `SELECT count() ... GROUP ALL` syntax and the older
    /// `array::len((SELECT ...))` form agree on the number of records.
    #[tokio::test]
    async fn minimal_reproduction() {
        let db = Surreal::new::<Mem>(()).await.unwrap();
        db.use_ns("test").use_db("test").await.unwrap();

        db.query(SCHEMA_SQL).await.unwrap();

        let cnt: Option<Count> = db
            // new syntax
            .query(count("users"))
            .await
            .unwrap()
            .take(0)
            .unwrap();

        assert_eq!(cnt, Some(Count::new(0)));

        let john_id = RecordId::from(("users", "0"));
        let john = User {
            id: john_id.clone(),
            name: "John".to_string(),
            age: 42,
            favorite_numbers: [1, 2, 3, 4, 5, 6, 7],
        };

        let sally_id = RecordId::from(("users", "1"));
        let sally = User {
            id: sally_id.clone(),
            name: "Sally".to_string(),
            age: 24,
            favorite_numbers: [8, 9, 10, 11, 12, 13, 14],
        };

        let result: Option<User> = db
            .create(john_id.clone())
            .content(john.clone())
            .await
            .unwrap();

        assert_eq!(result, Some(john.clone()));

        let result: Option<User> = db
            .create(sally_id.clone())
            .content(sally.clone())
            .await
            .unwrap();

        assert_eq!(result, Some(sally.clone()));

        let result: Option<User> = db.select(john_id).await.unwrap();

        assert_eq!(result, Some(john.clone()));

        // create like 100 more users
        for i in 2..NUMBER_OF_USERS {
            let user_id = RecordId::from(("users", i.to_string()));
            let user = User {
                id: user_id.clone(),
                name: format!("User {i}"),
                age: i,
                favorite_numbers: [i; 7],
            };
            let _: Option<User> = db.create(user_id.clone()).content(user).await.unwrap();
        }

        let mut resp_new = db
            // new syntax
            .query("SELECT count() FROM users GROUP ALL")
            .with_stats()
            .await
            .unwrap();
        let res = resp_new.take(0).unwrap();
        let cnt: Option<Count> = res.1.unwrap();
        assert_eq!(cnt, Some(Count::new(NUMBER_OF_USERS)));
        let stats_new: Stats = res.0;

        let mut resp_old = db
            // old syntax
            .query("RETURN array::len((SELECT * FROM users))")
            .with_stats()
            .await
            .unwrap();
        let res = resp_old.take(0).unwrap();
        let cnt: Option<usize> = res.1.unwrap();
        assert_eq!(cnt, Some(NUMBER_OF_USERS));
        let stats_old: Stats = res.0;

        // Both queries must report an execution time, but comparing the two is informational
        // only. Wall-clock timings vary with machine load and instrumentation (e.g. coverage
        // runs), so a hard `new < old` assertion would make this test flaky.
        let time_new = stats_new
            .execution_time
            .expect("execution time should be reported for the new syntax");
        let time_old = stats_old
            .execution_time
            .expect("execution time should be reported for the old syntax");
        eprintln!("count() GROUP ALL took {time_new:?}; array::len((SELECT *)) took {time_old:?}");

        let result: Vec<User> = db.delete("users").await.unwrap();

        assert_eq!(result.len(), NUMBER_OF_USERS);
        assert!(result.contains(&john), "Result does not contain 'john'");
        assert!(result.contains(&sally), "Result does not contain 'sally'");
    }
}
278}