1use std::{
2 ops::{Range, RangeInclusive},
3 path::PathBuf,
4 str::FromStr,
5 sync::Arc,
6 time::Duration,
7};
8
9use anyhow::Result;
10use lofty::{config::WriteOptions, file::TaggedFileExt, prelude::*, probe::Probe, tag::Accessor};
11use one_or_many::OneOrMany;
12use rand::{seq::IteratorRandom, Rng};
13#[cfg(feature = "db")]
14use surrealdb::{
15 engine::local::{Db, Mem},
16 sql::Id,
17 Connection, Surreal,
18};
19
20#[cfg(feature = "analysis")]
21use crate::db::schemas::analysis::Analysis;
22#[cfg(not(feature = "db"))]
23use crate::db::schemas::Id;
24use crate::db::schemas::{
25 album::Album,
26 artist::Artist,
27 collection::Collection,
28 playlist::Playlist,
29 song::{Song, SongChangeSet, SongMetadata},
30};
31
/// Separator used by the helpers in this file when several artist names are joined into
/// a single tag value; the same string is handed to `SongMetadata::load_from_path` as the
/// artist separator when the generated file is read back.
pub const ARTIST_NAME_SEPARATOR: &str = ", ";
33
/// Creates a fresh database on the `Mem` engine, configured for use in tests.
///
/// The connection is switched to namespace `test` / database `test`, the crate's custom
/// analyzer is registered, and the tables for `Album`, `Artist`, `Song`, `Collection`,
/// `Playlist` and `DynamicPlaylist` are registered. When the `analysis` feature is
/// enabled the `Analysis` table is registered as well.
///
/// # Errors
///
/// Returns the first error produced by any of the setup steps (connecting, selecting the
/// namespace/database, registering the analyzer, or registering the tables).
#[cfg(feature = "db")]
pub async fn init_test_database() -> surrealdb::Result<Surreal<Db>> {
    use crate::db::schemas::dynamic::DynamicPlaylist;

    let db = Surreal::new::<Mem>(()).await?;
    db.use_ns("test").use_db("test").await?;

    // The analyzer is registered before the tables are.
    crate::db::register_custom_analyzer(&db).await?;
    surrealqlx::register_tables!(
        &db,
        Album,
        Artist,
        Song,
        Collection,
        Playlist,
        DynamicPlaylist
    )?;
    // The analysis table only exists when the `analysis` feature is compiled in.
    #[cfg(feature = "analysis")]
    surrealqlx::register_tables!(&db, Analysis)?;

    Ok(db)
}
62
/// Builds a test database (see [`init_test_database`]) and fills it with state.
///
/// The database receives:
/// - one playlist named `"Playlist 0"` and one collection named `"Collection 0"`, both
///   created with a zero runtime and a zero song count,
/// - the `dynamic` playlist, when one is given,
/// - `song_count` songs.
///
/// For every index `i` in `0..song_count`, `song_case_func(i)` is called and must return
/// `(song_case, add_to_playlist, add_to_collection)`. A tagged audio file for `song_case`
/// is written into `tempdir` via [`create_song_metadata`], loaded into the database with
/// `Song::try_load_into_db`, and then added to the playlist and/or the collection
/// according to the two flags.
///
/// # Panics
///
/// Panics if any step fails: database initialisation, creating the playlist, collection
/// or dynamic playlist, generating a song's metadata, loading a song into the database,
/// or adding a song to the playlist/collection.
#[cfg(feature = "db")]
pub async fn init_test_database_with_state<SCF>(
    song_count: std::num::NonZero<usize>,
    mut song_case_func: SCF,
    dynamic: Option<crate::db::schemas::dynamic::DynamicPlaylist>,
    tempdir: &tempfile::TempDir,
) -> Arc<Surreal<Db>>
where
    SCF: FnMut(usize) -> (SongCase, bool, bool) + Send + Sync,
{
    use anyhow::Context;

    use crate::db::schemas::dynamic::DynamicPlaylist;

    let db = Arc::new(init_test_database().await.unwrap());

    let playlist = Playlist {
        id: Playlist::generate_id(),
        name: "Playlist 0".into(),
        runtime: Duration::from_secs(0),
        song_count: 0,
    };
    let playlist = Playlist::create(&db, playlist).await.unwrap().unwrap();

    let collection = Collection {
        id: Collection::generate_id(),
        name: "Collection 0".into(),
        runtime: Duration::from_secs(0),
        song_count: 0,
    };
    let collection = Collection::create(&db, collection).await.unwrap().unwrap();

    if let Some(dynamic) = dynamic {
        let _ = DynamicPlaylist::create(&db, dynamic)
            .await
            .unwrap()
            .unwrap();
    }

    for i in 0..(song_count.get()) {
        let (song_case, add_to_playlist, add_to_collection) = song_case_func(i);

        // `with_context` defers building the message until an error actually occurs,
        // so the happy path does not format the song case for every song.
        let metadata = create_song_metadata(tempdir, song_case.clone())
            .with_context(|| format!("failed to create metadata for song case {song_case:?}"))
            .unwrap();

        let song = Song::try_load_into_db(&db, metadata)
            .await
            .with_context(|| format!("Failed to load into db the song case: {song_case:?}"))
            .unwrap();

        if add_to_playlist {
            Playlist::add_songs(&db, playlist.id.clone(), vec![song.id.clone()])
                .await
                .unwrap();
        }
        if add_to_collection {
            Collection::add_songs(&db, collection.id.clone(), vec![song.id.clone()])
                .await
                .unwrap();
        }
    }

    db
}
160
/// Inserts a synthetic song described by a [`SongCase`] directly into `db` and returns
/// the record as it is read back from the database.
///
/// The song is built from the case's numbers: title `"Song {song}"`, artists and album
/// artists `"Artist {n}"`, album `"Album {album}"`, a single genre `"Genre {genre}"`.
/// It gets a runtime of 120 seconds, the extension `"mp3"`, no track/disc/release year,
/// and a path of `"<id>.mp3"` built from the freshly generated song id. This function
/// itself performs no filesystem access.
///
/// When `overrides` differs from `SongChangeSet::default()`, it is applied to the stored
/// song with `Song::update` before the song is read back.
///
/// # Errors
///
/// Returns an error if creating, updating, or reading the song fails.
///
/// # Panics
///
/// Panics if the song cannot be found when it is read back after creation.
#[cfg(feature = "db")]
pub async fn create_song_with_overrides<C: Connection>(
    db: &Surreal<C>,
    SongCase {
        song,
        artists,
        album_artists,
        album,
        genre,
    }: SongCase,
    overrides: SongChangeSet,
) -> Result<Song> {
    let song_id = Song::generate_id();
    let file_name = format!("{}.mp3", song_id.id);

    let record = Song {
        id: song_id.clone(),
        title: format!("Song {song}").into(),
        artist: artists
            .iter()
            .map(|a| format!("Artist {a}"))
            .map(Arc::from)
            .collect::<Vec<_>>()
            .into(),
        album_artist: album_artists
            .iter()
            .map(|a| format!("Artist {a}"))
            .map(Arc::from)
            .collect::<Vec<_>>()
            .into(),
        album: format!("Album {album}").into(),
        genre: OneOrMany::One(format!("Genre {genre}").into()),
        runtime: Duration::from_secs(120),
        track: None,
        disc: None,
        release_year: None,
        extension: "mp3".into(),
        path: PathBuf::from_str(&file_name)?,
    };

    Song::create(db, record).await?;

    // Only touch the stored record again when the caller asked for changes.
    let has_overrides = overrides != SongChangeSet::default();
    if has_overrides {
        Song::update(db, song_id.clone(), overrides).await?;
    }

    // Read back so the caller sees exactly what the database holds.
    let stored = Song::read(db, song_id).await?;
    Ok(stored.expect("Song should exist"))
}
217
218pub fn create_song_metadata(
227 tempdir: &tempfile::TempDir,
228 SongCase {
229 song,
230 artists,
231 album_artists,
232 album,
233 genre,
234 }: SongCase,
235) -> Result<SongMetadata> {
236 let base_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
238 .join("../assets/music.mp3")
239 .canonicalize()?;
240
241 let mut tagged_file = Probe::open(&base_path)?.read()?;
242 let tag = match tagged_file.primary_tag_mut() {
243 Some(primary_tag) => primary_tag,
244 None => tagged_file
248 .first_tag_mut()
249 .ok_or_else(|| anyhow::anyhow!("ERROR: No tags found"))?,
250 };
251
252 tag.insert_text(
253 ItemKey::AlbumArtist,
254 album_artists
255 .iter()
256 .map(|a| format!("Artist {a}"))
257 .collect::<Vec<_>>()
258 .join(ARTIST_NAME_SEPARATOR),
259 );
260
261 tag.remove_artist();
262 tag.set_artist(
263 artists
264 .iter()
265 .map(|a| format!("Artist {a}"))
266 .collect::<Vec<_>>()
267 .join(ARTIST_NAME_SEPARATOR),
268 );
269
270 tag.remove_album();
271 tag.set_album(format!("Album {album}"));
272
273 tag.remove_title();
274 tag.set_title(format!("Song {song}"));
275
276 tag.remove_genre();
277 tag.set_genre(format!("Genre {genre}"));
278
279 let new_path = tempdir.path().join(format!("song_{}.mp3", Id::ulid()));
280 std::fs::copy(&base_path, &new_path)?;
282 tag.save_to_path(&new_path, WriteOptions::default())?;
284
285 Ok(SongMetadata::load_from_path(
287 new_path,
288 &OneOrMany::One(ARTIST_NAME_SEPARATOR.to_string()),
289 None,
290 )?)
291}
292
/// Compact, numeric description of a synthetic song used by the test helpers.
///
/// Each number is turned into a name by the helpers in this file, e.g. `song: 3` becomes
/// the title `"Song 3"` and `artists: vec![1, 2]` becomes `"Artist 1"` and `"Artist 2"`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SongCase {
    /// Number used for the song title (`"Song {song}"`).
    pub song: u8,
    /// Numbers used for the song's artists (`"Artist {n}"` each).
    pub artists: Vec<u8>,
    /// Numbers used for the song's album artists (`"Artist {n}"` each).
    pub album_artists: Vec<u8>,
    /// Number used for the album name (`"Album {album}"`).
    pub album: u8,
    /// Number used for the genre name (`"Genre {genre}"`).
    pub genre: u8,
}

impl SongCase {
    /// Creates a song case from its parts; the arguments are stored unchanged.
    #[must_use]
    pub const fn new(
        song: u8,
        artists: Vec<u8>,
        album_artists: Vec<u8>,
        album: u8,
        genre: u8,
    ) -> Self {
        Self {
            song,
            artists,
            album_artists,
            album,
            genre,
        }
    }
}
320
321pub fn arb_song_case() -> impl Fn() -> SongCase {
322 || {
323 let artist_item_strategy = move || {
324 (0..=10u8)
325 .choose(&mut rand::thread_rng())
326 .unwrap_or_default()
327 };
328 let rng = &mut rand::thread_rng();
329 let artists = arb_vec(&artist_item_strategy, 1..=10)()
330 .into_iter()
331 .collect::<std::collections::HashSet<_>>()
332 .into_iter()
333 .collect::<Vec<_>>();
334 let album_artists = arb_vec(&artist_item_strategy, 1..=10)()
335 .into_iter()
336 .collect::<std::collections::HashSet<_>>()
337 .into_iter()
338 .collect::<Vec<_>>();
339 let song = (0..=10u8).choose(rng).unwrap_or_default();
340 let album = (0..=10u8).choose(rng).unwrap_or_default();
341 let genre = (0..=10u8).choose(rng).unwrap_or_default();
342
343 SongCase::new(song, artists, album_artists, album, genre)
344 }
345}
346
347pub fn arb_vec<T>(
348 item_strategy: &impl Fn() -> T,
349 range: RangeInclusive<usize>,
350) -> impl Fn() -> Vec<T> + '_
351where
352 T: Clone + std::fmt::Debug + Sized,
353{
354 move || {
355 let size = range
356 .clone()
357 .choose(&mut rand::thread_rng())
358 .unwrap_or_default();
359 std::iter::repeat_with(item_strategy).take(size).collect()
360 }
361}
362
/// Selects where [`arb_vec_and_index`] draws its index from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexMode {
    /// The index is drawn from `0..vec.len()`.
    InBounds,
    /// The index is drawn from `vec.len()..(vec.len() + vec.len() / 2 + 1)`.
    OutOfBounds,
}
367
368pub fn arb_vec_and_index<T>(
369 item_strategy: &impl Fn() -> T,
370 range: RangeInclusive<usize>,
371 index_mode: IndexMode,
372) -> impl Fn() -> (Vec<T>, usize) + '_
373where
374 T: Clone + std::fmt::Debug + Sized,
375{
376 move || {
377 let vec = arb_vec(item_strategy, range.clone())();
378 let index = match index_mode {
379 IndexMode::InBounds => 0..vec.len(),
380 #[allow(clippy::range_plus_one)]
381 IndexMode::OutOfBounds => vec.len()..(vec.len() + vec.len() / 2 + 1),
382 }
383 .choose(&mut rand::thread_rng())
384 .unwrap_or_default();
385 (vec, index)
386 }
387}
388
/// Selects how [`arb_vec_and_range_and_index`] draws the start of its range.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RangeStartMode {
    /// The start is drawn from `0..vec.len()`.
    Standard,
    /// The start is always `0`.
    Zero,
    /// The start is drawn from `vec.len()..(vec.len() + vec.len() / 2 + 1)`.
    OutOfBounds,
}
394
/// Selects how [`arb_vec_and_range_and_index`] draws the (exclusive) end of its range.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RangeEndMode {
    /// The end equals the start, so the resulting range is empty.
    Start,
    /// The end is drawn from `start..vec.len()`.
    Standard,
    /// The end is drawn at or past the vector's length.
    OutOfBounds,
}
400
/// Selects where [`arb_vec_and_range_and_index`] draws its index from, relative to the
/// vector and to the generated `start..end` range.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RangeIndexMode {
    /// The index is drawn from `0..vec.len()`.
    InBounds,
    /// The index is drawn from `start..end`.
    InRange,
    /// The index is drawn from `end..vec.len()`.
    AfterRangeInBounds,
    /// The index is drawn from `vec.len()..(vec.len() + vec.len() / 2 + 1)`.
    OutOfBounds,
    /// The index is drawn from `0..start`.
    BeforeRange,
}
408
409pub fn arb_vec_and_range_and_index<T>(
413 item_strategy: &impl Fn() -> T,
414 range: RangeInclusive<usize>,
415 range_start_mode: RangeStartMode,
416 range_end_mode: RangeEndMode,
417 index_mode: RangeIndexMode,
418) -> impl Fn() -> (Vec<T>, Range<usize>, Option<usize>) + '_
419where
420 T: Clone + std::fmt::Debug + Sized,
421{
422 move || {
423 let rng = &mut rand::thread_rng();
424 let vec = arb_vec(item_strategy, range.clone())();
425 let start = match range_start_mode {
426 RangeStartMode::Standard => 0..vec.len(),
427 #[allow(clippy::range_plus_one)]
428 RangeStartMode::OutOfBounds => vec.len()..(vec.len() + vec.len() / 2 + 1),
429 RangeStartMode::Zero => 0..1,
430 }
431 .choose(rng)
432 .unwrap_or_default();
433 let end = match range_end_mode {
434 RangeEndMode::Standard => start..vec.len(),
435 #[allow(clippy::range_plus_one)]
436 RangeEndMode::OutOfBounds => vec.len()..(vec.len() + vec.len() / 2 + 1).max(start),
437 #[allow(clippy::range_plus_one)]
438 RangeEndMode::Start => start..(start + 1),
439 }
440 .choose(rng)
441 .unwrap_or_default();
442
443 let index = match index_mode {
444 RangeIndexMode::InBounds => 0..vec.len(),
445 RangeIndexMode::InRange => start..end,
446 RangeIndexMode::AfterRangeInBounds => end..vec.len(),
447 #[allow(clippy::range_plus_one)]
448 RangeIndexMode::OutOfBounds => vec.len()..(vec.len() + vec.len() / 2 + 1),
449 RangeIndexMode::BeforeRange => 0..start,
450 }
451 .choose(rng);
452
453 (vec, start..end, index)
454 }
455}
456
457pub fn arb_analysis_features() -> impl Fn() -> [f64; 20] {
458 move || {
459 let rng = &mut rand::thread_rng();
460 let mut features = [0.0; 20];
461 for feature in &mut features {
462 *feature = rng.gen_range(-1.0..1.0);
463 }
464 features
465 }
466}
467
// The helpers exercised here only exist with the `db` feature, so the module is gated
// on it as well; otherwise a test build without `db` would fail to compile.
#[cfg(all(test, feature = "db"))]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// A song created through `create_song_with_overrides` (without overrides) must be
    /// identical to what a subsequent `Song::read` returns for its id.
    #[tokio::test]
    async fn test_create_song() {
        let db = init_test_database().await.unwrap();
        let song_case = SongCase::new(0, vec![0], vec![0], 0, 0);

        let song = create_song_with_overrides(&db, song_case, SongChangeSet::default())
            .await
            .unwrap_or_else(|e| panic!("Error creating song: {e:?}"));

        let song_from_db = Song::read(&db, song.id.clone()).await.unwrap().unwrap();

        assert_eq!(song, song_from_db);
    }
}
496}