1#[cfg(feature = "db")]
2pub mod crud;
3#[cfg(feature = "db")]
4pub mod health;
5#[cfg(feature = "db")]
6pub(crate) mod queries;
7pub mod schemas;
8
9#[cfg(feature = "db")]
10use surrealdb::{Surreal, engine::local::Db, opt::Config};
11
/// Directory holding the on-disk database.
///
/// Written at most once, by `set_database_path`, and read by `init_database`.
#[cfg(feature = "db")]
#[cfg(not(tarpaulin_include))]
static DB_DIR: once_cell::sync::OnceCell<std::path::PathBuf> = once_cell::sync::OnceCell::new();

/// Builds the fallback database directory; panics if the temporary directory cannot be created.
#[cfg(feature = "db")]
#[cfg(not(tarpaulin_include))]
fn create_fallback_db_dir() -> tempfile::TempDir {
    tempfile::tempdir().expect("Failed to create temporary directory")
}

/// Fallback directory used when `DB_DIR` was never set.
///
/// Created on first access. NOTE: statics are never dropped, so this `TempDir`'s
/// destructor does not run and the directory is not removed at process exit.
#[cfg(feature = "db")]
#[cfg(not(tarpaulin_include))]
static TEMP_DB_DIR: once_cell::sync::Lazy<tempfile::TempDir> =
    once_cell::sync::Lazy::new(create_fallback_db_dir);
20
/// Records the directory in which the database should be stored.
///
/// This must happen before `init_database` is called. The path can only be set once
/// per process.
///
/// # Errors
///
/// Returns `crate::errors::Error::DbPathSetError`, carrying the rejected path, if a
/// database path has already been set.
#[cfg(feature = "db")]
#[allow(clippy::missing_inline_in_public_items)]
pub fn set_database_path(path: std::path::PathBuf) -> Result<(), crate::errors::Error> {
    // `OnceCell::set` hands the value back when the cell is already populated.
    if let Err(rejected) = DB_DIR.set(path) {
        return Err(crate::errors::Error::DbPathSetError(rejected));
    }

    log::info!("Primed database path");
    Ok(())
}
35
/// Opens (or creates) the on-disk database and brings its schema up to date.
///
/// The database lives in the directory registered with `set_database_path`. If that was
/// never called, a process-wide temporary directory is used instead and a warning is logged.
/// The function then:
/// - defines and selects the `mecomp` namespace and the `music` database,
/// - registers the custom full-text analyzer,
/// - defines all tables (plus the analysis table with the `analysis` feature),
/// - defines the relation tables.
///
/// # Errors
///
/// Returns an error if the database cannot be opened, if a namespace/database definition
/// statement fails, or if registering the analyzer, tables, or relations fails.
#[cfg(feature = "db")]
#[allow(clippy::missing_inline_in_public_items)]
pub async fn init_database() -> surrealqlx::Result<Surreal<Db>> {
    use surrealqlx::surrql;

    let db_path = resolve_database_path();
    let temp_dir = prepare_temp_dir(&db_path);
    let config = Config::new().strict().temporary_directory(temp_dir);
    let db = Surreal::new((db_path, config)).await?;

    // `query(..).await` only fails on transport-level problems. Errors raised by the
    // statements themselves are stored in the response, so `check` is needed to surface them.
    db.query(surrql!("DEFINE NAMESPACE IF NOT EXISTS mecomp"))
        .await?
        .check()?;
    db.use_ns("mecomp").await?;
    db.query(surrql!("DEFINE DATABASE IF NOT EXISTS music"))
        .await?
        .check()?;
    db.use_db("music").await?;

    register_custom_analyzer(&db).await?;
    surrealqlx::register_tables!(
        &db,
        schemas::album::Album,
        schemas::artist::Artist,
        schemas::song::Song,
        schemas::collection::Collection,
        schemas::playlist::Playlist,
        schemas::dynamic::DynamicPlaylist
    )?;
    #[cfg(feature = "analysis")]
    surrealqlx::register_tables!(&db, schemas::analysis::Analysis)?;

    queries::relations::define_relation_tables(&db).await?;

    Ok(db)
}

/// Returns the directory registered via `set_database_path`.
///
/// When no directory was registered, falls back to the process-wide temporary directory
/// and logs a warning.
#[cfg(feature = "db")]
fn resolve_database_path() -> std::path::PathBuf {
    DB_DIR.get().cloned().unwrap_or_else(|| {
        log::warn!("DB_DIR not set, defaulting to a temporary directory `{}`, this is likely a bug because `set_database_path` should be called before `init_database`", TEMP_DB_DIR.path().display());
        TEMP_DB_DIR.path().to_path_buf()
    })
}

/// Ensures `<db_path>/temp` exists so it can be handed to the engine as its temporary directory.
///
/// This is best-effort. If the directory cannot be created, the reason is logged and `None`
/// is returned, so no temporary directory is configured.
#[cfg(feature = "db")]
fn prepare_temp_dir(db_path: &std::path::Path) -> Option<std::path::PathBuf> {
    let temp_dir = db_path.join("temp");
    // `create_dir_all` succeeds when the directory already exists. It fails if the path
    // exists but is not a directory, which is the case the old `exists()` check let through.
    match std::fs::create_dir_all(&temp_dir) {
        Ok(()) => Some(temp_dir),
        Err(err) => {
            log::warn!(
                "Failed to create temporary directory `{}` ({err}), continuing without one",
                temp_dir.display()
            );
            None
        }
    }
}
89
/// Registers the `custom_analyzer` full-text analyzer through a named migration.
///
/// The analyzer uses the `class` tokenizer and the `ascii`, `lowercase`,
/// `edgengram(1,10)` and `snowball(English)` filters. The migration's down step
/// removes the analyzer again.
///
/// # Errors
///
/// Returns an error if migrating to the latest version fails.
#[cfg(feature = "db")]
pub(crate) async fn register_custom_analyzer<C>(db: &Surreal<C>) -> surrealqlx::Result<()>
where
    C: surrealdb::Connection,
{
    use surrealqlx::migrations::{M, Migrations};
    use surrealqlx::surrql;

    // The single migration step: `up` defines the analyzer, `down` removes it.
    let analyzer_step = M::up(surrql!(
        "DEFINE ANALYZER IF NOT EXISTS custom_analyzer
        TOKENIZERS class
        FILTERS ascii,lowercase,edgengram(1,10),snowball(English);"
    ))
    .down(surrql!("REMOVE ANALYZER IF EXISTS custom_analyzer;"));

    Migrations::new("custom_analyzer", vec![analyzer_step])
        .to_latest(db)
        .await?;

    Ok(())
}
118
#[cfg(test)]
mod test {
    use super::schemas::{
        album::Album, artist::Artist, collection::Collection, dynamic::DynamicPlaylist,
        playlist::Playlist, song::Song,
    };
    use super::*;

    use surrealdb::engine::local::Mem;
    use surrealqlx::traits::Table;

    /// Builds the full schema on a fresh, strict, in-memory database.
    ///
    /// It then initialises the `Album` table a second time, so a repeated
    /// initialisation is exercised as well.
    #[tokio::test]
    async fn test_register_tables() -> anyhow::Result<()> {
        let database = Surreal::new::<Mem>(Config::new().strict()).await?;

        // Define the namespace and database explicitly before selecting them.
        database.query("DEFINE NAMESPACE IF NOT EXISTS test").await?;
        database.use_ns("test").await?;
        database.query("DEFINE DATABASE IF NOT EXISTS test").await?;
        database.use_db("test").await?;

        // Same order as `init_database`: analyzer, tables, then relation tables.
        register_custom_analyzer(&database).await?;

        <Album as Table>::init_table(&database).await?;
        <Artist as Table>::init_table(&database).await?;
        <Song as Table>::init_table(&database).await?;
        <Collection as Table>::init_table(&database).await?;
        <Playlist as Table>::init_table(&database).await?;
        <DynamicPlaylist as Table>::init_table(&database).await?;

        queries::relations::define_relation_tables(&database).await?;

        // Second initialisation of an already-defined table.
        <Album as Table>::init_table(&database).await?;

        Ok(())
    }
}
161
#[cfg(test)]
mod minimal_reproduction {
    use serde::{Deserialize, Serialize};
    use surrealdb::{RecordId, Surreal, engine::local::Mem, method::Stats};
    use surrealqlx::surrql;

    use crate::db::queries::generic::{Count, count};

    /// Record stored in the `users` table by this test.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct User {
        id: RecordId,
        name: String,
        age: u64,
        favorite_numbers: [u64; 7],
    }

    /// Schema for the `users` table, applied as three transactions.
    ///
    /// The transactions cover the table, its fields, and its indexes (a unique index, a
    /// normal index and a 7-dimensional MTREE index). Each `BEGIN` is paired with a
    /// `COMMIT` so that every group, including the indexes, is actually committed.
    static SCHEMA_SQL: &str = r"
    BEGIN;
    DEFINE TABLE users SCHEMAFULL;
    COMMIT;
    BEGIN;
    DEFINE FIELD id ON users TYPE record;
    DEFINE FIELD name ON users TYPE string;
    DEFINE FIELD age ON users TYPE int;
    DEFINE FIELD favorite_numbers ON users TYPE array<int>;
    COMMIT;
    BEGIN;
    DEFINE INDEX users_name_unique_index ON users FIELDS name UNIQUE;
    DEFINE INDEX users_age_normal_index ON users FIELDS age;
    DEFINE INDEX users_favorite_numbers_vector_index ON users FIELDS favorite_numbers MTREE DIMENSION 7;
    COMMIT;
    ";

    /// Total number of users inserted, including John and Sally.
    const NUMBER_OF_USERS: u64 = 100;

    /// Runs create/select/count/delete against an indexed SCHEMAFULL table.
    ///
    /// Also runs two ways of counting the table's records and reports their timings.
    #[tokio::test]
    async fn minimal_reproduction() {
        let db = Surreal::new::<Mem>(()).await.unwrap();
        db.use_ns("test").use_db("test").await.unwrap();

        // `query(..).await` only fails on transport errors. Failures of individual statements
        // are stored in the response, so `check` them to catch a broken schema immediately.
        db.query(SCHEMA_SQL).await.unwrap().check().unwrap();

        let cnt: Option<Count> = db
            .query(count())
            .bind(("table", "users"))
            .await
            .unwrap()
            .take(0)
            .unwrap();

        assert_eq!(cnt, Some(Count::new(0)));

        let john_id = RecordId::from(("users", "0"));
        let john = User {
            id: john_id.clone(),
            name: "John".to_string(),
            age: 42,
            favorite_numbers: [1, 2, 3, 4, 5, 6, 7],
        };

        let sally_id = RecordId::from(("users", "1"));
        let sally = User {
            id: sally_id.clone(),
            name: "Sally".to_string(),
            age: 24,
            favorite_numbers: [8, 9, 10, 11, 12, 13, 14],
        };

        let result: Option<User> = db
            .create(john_id.clone())
            .content(john.clone())
            .await
            .unwrap();

        assert_eq!(result, Some(john.clone()));

        let result: Option<User> = db
            .create(sally_id.clone())
            .content(sally.clone())
            .await
            .unwrap();

        assert_eq!(result, Some(sally.clone()));

        let result: Option<User> = db.select(john_id).await.unwrap();

        assert_eq!(result, Some(john.clone()));

        // Remaining users get ids "2".."NUMBER_OF_USERS - 1" and distinct names,
        // keeping the unique name index satisfied.
        for i in 2..NUMBER_OF_USERS {
            let user_id = RecordId::from(("users", i.to_string()));
            let user = User {
                id: user_id.clone(),
                name: format!("User {i}"),
                age: i,
                favorite_numbers: [i; 7],
            };
            let _: Option<User> = db.create(user_id.clone()).content(user).await.unwrap();
        }

        // Count with `count() ... GROUP ALL`.
        let mut resp_new = db
            .query(surrql!("SELECT count() FROM users GROUP ALL"))
            .with_stats()
            .await
            .unwrap();
        let res = resp_new.take(0).unwrap();
        let cnt: Option<Count> = res.1.unwrap();
        assert_eq!(cnt, Some(Count::new(NUMBER_OF_USERS)));
        let stats_new: Stats = res.0;

        // Count by selecting every record and taking the array length.
        let mut resp_old = db
            .query(surrql!("RETURN array::len((SELECT * FROM users))"))
            .with_stats()
            .await
            .unwrap();
        let res = resp_old.take(0).unwrap();
        let cnt: Option<u64> = res.1.unwrap();
        assert_eq!(cnt, Some(NUMBER_OF_USERS));
        let stats_old: Stats = res.0;

        // Wall-clock timings depend on machine load and build profile, so asserting that one
        // query beats the other makes the test flaky. Report the measurements instead.
        eprintln!(
            "count() GROUP ALL: {:?}, array::len(SELECT *): {:?}",
            stats_new.execution_time, stats_old.execution_time
        );

        let result: Vec<User> = db.delete("users").await.unwrap();

        assert_eq!(result.len() as u64, NUMBER_OF_USERS);
        assert!(result.contains(&john), "Result does not contain 'john'");
        assert!(result.contains(&sally), "Result does not contain 'sally'");
    }
}