mecomp_storage/
test_utils.rs

1use std::{
2    ops::{Range, RangeInclusive},
3    path::PathBuf,
4    str::FromStr,
5    sync::Arc,
6    time::Duration,
7};
8
9use anyhow::Result;
10use lofty::{config::WriteOptions, file::TaggedFileExt, prelude::*, probe::Probe, tag::Accessor};
11use one_or_many::OneOrMany;
12use rand::{seq::IteratorRandom, Rng};
13#[cfg(feature = "db")]
14use surrealdb::{
15    engine::local::{Db, Mem},
16    sql::Id,
17    Connection, Surreal,
18};
19
20#[cfg(feature = "analysis")]
21use crate::db::schemas::analysis::Analysis;
22#[cfg(not(feature = "db"))]
23use crate::db::schemas::Id;
24use crate::db::schemas::{
25    album::Album,
26    artist::Artist,
27    collection::Collection,
28    playlist::Playlist,
29    song::{Song, SongChangeSet, SongMetadata},
30};
31
/// Separator used to join multiple artist names into a single tag value when generating test
/// songs, and passed as the artist separator when loading their metadata back.
pub const ARTIST_NAME_SEPARATOR: &str = ", ";
33
/// Build a fresh, in-memory SurrealDB instance for tests.
///
/// The instance is scoped to the `test` namespace and `test` database, has the crate's custom
/// analyzer registered, and defines the tables for [`Album`], [`Artist`], [`Song`],
/// [`Collection`], [`Playlist`] and `DynamicPlaylist` (plus `Analysis` when the `analysis`
/// feature is enabled). Useful for testing queries and mutations.
///
/// # Errors
///
/// Returns an error if any setup step fails: creating the in-memory engine, selecting the
/// namespace/database, registering the analyzer, or registering the tables.
#[cfg(feature = "db")]
pub async fn init_test_database() -> surrealdb::Result<Surreal<Db>> {
    use crate::db::schemas::dynamic::DynamicPlaylist;

    let database = Surreal::new::<Mem>(()).await?;
    database.use_ns("test").use_db("test").await?;

    crate::db::register_custom_analyzer(&database).await?;

    surrealqlx::register_tables!(
        &database,
        Album,
        Artist,
        Song,
        Collection,
        Playlist,
        DynamicPlaylist
    )?;

    // The analysis table only exists when the `analysis` feature is compiled in.
    #[cfg(feature = "analysis")]
    surrealqlx::register_tables!(&database, Analysis)?;

    Ok(database)
}
62
/// Initialize an in-memory test database pre-populated with some basic state.
///
/// # What will be created
///
/// - a playlist named "Playlist 0"
/// - a collection named "Collection 0"
/// - optionally, the passed `DynamicPlaylist`
/// - `song_count` songs whose tags are determined by `song_case_func`
/// - a real audio file in the given `TempDir` for each song
///
/// # `song_case_func`
///
/// Signature: `FnMut(usize) -> (SongCase, bool, bool)`
/// - argument: the index of the song being created, in `0..song_count`
/// - returns: `(the song case to generate the song from, whether the song should be added to
///   the playlist, whether the song should be added to the collection)`
///
/// # Panics
///
/// Panics if an error occurs during any of the above steps. This is intended only for tests,
/// so panicking ensures the test fails and the backtrace points at the failing step in here.
#[cfg(feature = "db")]
pub async fn init_test_database_with_state<SCF>(
    song_count: std::num::NonZero<usize>,
    mut song_case_func: SCF,
    dynamic: Option<crate::db::schemas::dynamic::DynamicPlaylist>,
    tempdir: &tempfile::TempDir,
) -> Arc<Surreal<Db>>
where
    SCF: FnMut(usize) -> (SongCase, bool, bool) + Send + Sync,
{
    use anyhow::Context;

    use crate::db::schemas::dynamic::DynamicPlaylist;

    let db = Arc::new(init_test_database().await.unwrap());

    // create the playlist, collection, and optionally the dynamic playlist
    let playlist = Playlist {
        id: Playlist::generate_id(),
        name: "Playlist 0".into(),
        runtime: Duration::from_secs(0),
        song_count: 0,
    };
    let playlist = Playlist::create(&db, playlist).await.unwrap().unwrap();

    let collection = Collection {
        id: Collection::generate_id(),
        name: "Collection 0".into(),
        runtime: Duration::from_secs(0),
        song_count: 0,
    };
    let collection = Collection::create(&db, collection).await.unwrap().unwrap();

    if let Some(dynamic) = dynamic {
        let _ = DynamicPlaylist::create(&db, dynamic)
            .await
            .unwrap()
            .unwrap();
    }

    // create the songs
    for i in 0..song_count.get() {
        let (song_case, add_to_playlist, add_to_collection) = song_case_func(i);

        // `with_context` only formats the (Debug) message if an error actually occurred,
        // instead of allocating it for every song.
        let metadata = create_song_metadata(tempdir, song_case.clone())
            .with_context(|| format!("failed to create metadata for song case {song_case:?}"))
            .unwrap();

        let song = Song::try_load_into_db(&db, metadata)
            .await
            .with_context(|| format!("Failed to load into db the song case: {song_case:?}"))
            .unwrap();

        if add_to_playlist {
            Playlist::add_songs(&db, playlist.id.clone(), vec![song.id.clone()])
                .await
                .unwrap();
        }
        if add_to_collection {
            Collection::add_songs(&db, collection.id.clone(), vec![song.id.clone()])
                .await
                .unwrap();
        }
    }

    db
}
160
/// Insert a synthetic song described by a [`SongCase`] straight into the database, apply
/// `overrides` if they differ from the default change set, and return the stored record.
///
/// The song is shallow: only the `Song` record itself is written; no artist, album-artist or
/// album records are created for it. Its title is `Song {song}`, artists and album artists are
/// `Artist {n}`, the album is `Album {album}`, the genre is `Genre {genre}`, the runtime is 120
/// seconds, and the path is `{id}.mp3` built from the generated record id (no file is written).
///
/// # Errors
///
/// Returns an error if creating, updating, or reading back the song fails.
///
/// # Panics
///
/// Panics if the song can't be read from the database after creation.
#[cfg(feature = "db")]
pub async fn create_song_with_overrides<C: Connection>(
    db: &Surreal<C>,
    SongCase {
        song,
        artists,
        album_artists,
        album,
        genre,
    }: SongCase,
    overrides: SongChangeSet,
) -> Result<Song> {
    // Turn each numeric artist identifier into an "Artist N" display name.
    let artist_names = |numbers: &[u8]| -> Vec<Arc<str>> {
        numbers
            .iter()
            .map(|n| Arc::from(format!("Artist {n}")))
            .collect()
    };

    let id = Song::generate_id();
    let path = PathBuf::from_str(&format!("{}.mp3", id.id))?;

    let record = Song {
        id: id.clone(),
        title: Arc::from(format!("Song {song}")),
        artist: artist_names(&artists).into(),
        album_artist: artist_names(&album_artists).into(),
        album: Arc::from(format!("Album {album}")),
        genre: OneOrMany::One(Arc::from(format!("Genre {genre}"))),
        runtime: Duration::from_secs(120),
        track: None,
        disc: None,
        release_year: None,
        extension: Arc::from("mp3"),
        path,
    };

    Song::create(db, record).await?;

    let has_overrides = overrides != SongChangeSet::default();
    if has_overrides {
        Song::update(db, id.clone(), overrides).await?;
    }

    let stored = Song::read(db, id).await?.expect("Song should exist");
    Ok(stored)
}
217
218/// Creates a song file with the given case and overrides.
219/// The song file is created in a temporary directory.
220/// The song metadata is created from the song file.
221/// The song is not added to the database.
222///
223/// # Errors
224///
225/// This function will return an error if the song metadata cannot be created.
226pub fn create_song_metadata(
227    tempdir: &tempfile::TempDir,
228    SongCase {
229        song,
230        artists,
231        album_artists,
232        album,
233        genre,
234    }: SongCase,
235) -> Result<SongMetadata> {
236    // we have an example mp3 in `assets/`, we want to take that and create a new audio file with psuedorandom id3 tags
237    let base_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
238        .join("../assets/music.mp3")
239        .canonicalize()?;
240
241    let mut tagged_file = Probe::open(&base_path)?.read()?;
242    let tag = match tagged_file.primary_tag_mut() {
243        Some(primary_tag) => primary_tag,
244        // If the "primary" tag doesn't exist, we just grab the
245        // first tag we can find. Realistically, a tag reader would likely
246        // iterate through the tags to find a suitable one.
247        None => tagged_file
248            .first_tag_mut()
249            .ok_or_else(|| anyhow::anyhow!("ERROR: No tags found"))?,
250    };
251
252    tag.insert_text(
253        ItemKey::AlbumArtist,
254        album_artists
255            .iter()
256            .map(|a| format!("Artist {a}"))
257            .collect::<Vec<_>>()
258            .join(ARTIST_NAME_SEPARATOR),
259    );
260
261    tag.remove_artist();
262    tag.set_artist(
263        artists
264            .iter()
265            .map(|a| format!("Artist {a}"))
266            .collect::<Vec<_>>()
267            .join(ARTIST_NAME_SEPARATOR),
268    );
269
270    tag.remove_album();
271    tag.set_album(format!("Album {album}"));
272
273    tag.remove_title();
274    tag.set_title(format!("Song {song}"));
275
276    tag.remove_genre();
277    tag.set_genre(format!("Genre {genre}"));
278
279    let new_path = tempdir.path().join(format!("song_{}.mp3", Id::ulid()));
280    // copy the base file to the new path
281    std::fs::copy(&base_path, &new_path)?;
282    // write the new tags to the new file
283    tag.save_to_path(&new_path, WriteOptions::default())?;
284
285    // now, we need to load a SongMetadata from the new file
286    Ok(SongMetadata::load_from_path(
287        new_path,
288        &OneOrMany::One(ARTIST_NAME_SEPARATOR.to_string()),
289        None,
290    )?)
291}
292
/// A synthetic song described by small integer identifiers.
///
/// The helpers in this module render each number as a display string: `song` becomes
/// `"Song {n}"`, each artist / album artist becomes `"Artist {n}"`, `album` becomes
/// `"Album {n}"`, and `genre` becomes `"Genre {n}"`.
#[derive(Debug, Clone)]
pub struct SongCase {
    /// Number used for the song title.
    pub song: u8,
    /// Numbers of the song's artists.
    pub artists: Vec<u8>,
    /// Numbers of the album's artists.
    pub album_artists: Vec<u8>,
    /// Number used for the album title.
    pub album: u8,
    /// Number used for the genre.
    pub genre: u8,
}

impl SongCase {
    /// Bundle the given identifiers into a [`SongCase`].
    #[must_use]
    pub const fn new(
        song: u8,
        artists: Vec<u8>,
        album_artists: Vec<u8>,
        album: u8,
        genre: u8,
    ) -> Self {
        Self {
            genre,
            album,
            album_artists,
            artists,
            song,
        }
    }
}
320
321pub fn arb_song_case() -> impl Fn() -> SongCase {
322    || {
323        let artist_item_strategy = move || {
324            (0..=10u8)
325                .choose(&mut rand::thread_rng())
326                .unwrap_or_default()
327        };
328        let rng = &mut rand::thread_rng();
329        let artists = arb_vec(&artist_item_strategy, 1..=10)()
330            .into_iter()
331            .collect::<std::collections::HashSet<_>>()
332            .into_iter()
333            .collect::<Vec<_>>();
334        let album_artists = arb_vec(&artist_item_strategy, 1..=10)()
335            .into_iter()
336            .collect::<std::collections::HashSet<_>>()
337            .into_iter()
338            .collect::<Vec<_>>();
339        let song = (0..=10u8).choose(rng).unwrap_or_default();
340        let album = (0..=10u8).choose(rng).unwrap_or_default();
341        let genre = (0..=10u8).choose(rng).unwrap_or_default();
342
343        SongCase::new(song, artists, album_artists, album, genre)
344    }
345}
346
347pub fn arb_vec<T>(
348    item_strategy: &impl Fn() -> T,
349    range: RangeInclusive<usize>,
350) -> impl Fn() -> Vec<T> + '_
351where
352    T: Clone + std::fmt::Debug + Sized,
353{
354    move || {
355        let size = range
356            .clone()
357            .choose(&mut rand::thread_rng())
358            .unwrap_or_default();
359        std::iter::repeat_with(item_strategy).take(size).collect()
360    }
361}
362
/// Where [`arb_vec_and_index`] places the generated index relative to the generated vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexMode {
    /// An index drawn from `0..vec.len()` (`0` when the vector is empty).
    InBounds,
    /// An index drawn from `vec.len()..=vec.len() + vec.len() / 2`; never a valid index.
    OutOfBounds,
}
367
368pub fn arb_vec_and_index<T>(
369    item_strategy: &impl Fn() -> T,
370    range: RangeInclusive<usize>,
371    index_mode: IndexMode,
372) -> impl Fn() -> (Vec<T>, usize) + '_
373where
374    T: Clone + std::fmt::Debug + Sized,
375{
376    move || {
377        let vec = arb_vec(item_strategy, range.clone())();
378        let index = match index_mode {
379            IndexMode::InBounds => 0..vec.len(),
380            #[allow(clippy::range_plus_one)]
381            IndexMode::OutOfBounds => vec.len()..(vec.len() + vec.len() / 2 + 1),
382        }
383        .choose(&mut rand::thread_rng())
384        .unwrap_or_default();
385        (vec, index)
386    }
387}
388
/// How [`arb_vec_and_range_and_index`] chooses the start of the generated range.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RangeStartMode {
    /// A start drawn from `0..vec.len()` (`0` when the vector is empty).
    Standard,
    /// A start of exactly `0`.
    Zero,
    /// A start drawn from `vec.len()..=vec.len() + vec.len() / 2`, past the end of the vector.
    OutOfBounds,
}
394
/// How [`arb_vec_and_range_and_index`] chooses the (exclusive) end of the generated range.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RangeEndMode {
    /// An end equal to the start, producing an empty range.
    Start,
    /// An end drawn from `start..vec.len()`.
    Standard,
    /// An end drawn from within `vec.len()..=vec.len() + vec.len() / 2`, at or past the end
    /// of the vector.
    OutOfBounds,
}
400
/// Where [`arb_vec_and_range_and_index`] places the generated index.
///
/// If the chosen span is empty, no index is generated (`None`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RangeIndexMode {
    /// An index drawn from `0..vec.len()`.
    InBounds,
    /// An index drawn from the generated range `start..end`.
    InRange,
    /// An index drawn from `end..vec.len()`.
    AfterRangeInBounds,
    /// An index drawn from `vec.len()..=vec.len() + vec.len() / 2`; never a valid index.
    OutOfBounds,
    /// An index drawn from `0..start`.
    BeforeRange,
}
408
409// Returns a tuple of a Vec of T and a Range<usize>
410// where the start is a random index in the Vec
411// and the end is a random index in the Vec that is greater than or equal to the start
412pub fn arb_vec_and_range_and_index<T>(
413    item_strategy: &impl Fn() -> T,
414    range: RangeInclusive<usize>,
415    range_start_mode: RangeStartMode,
416    range_end_mode: RangeEndMode,
417    index_mode: RangeIndexMode,
418) -> impl Fn() -> (Vec<T>, Range<usize>, Option<usize>) + '_
419where
420    T: Clone + std::fmt::Debug + Sized,
421{
422    move || {
423        let rng = &mut rand::thread_rng();
424        let vec = arb_vec(item_strategy, range.clone())();
425        let start = match range_start_mode {
426            RangeStartMode::Standard => 0..vec.len(),
427            #[allow(clippy::range_plus_one)]
428            RangeStartMode::OutOfBounds => vec.len()..(vec.len() + vec.len() / 2 + 1),
429            RangeStartMode::Zero => 0..1,
430        }
431        .choose(rng)
432        .unwrap_or_default();
433        let end = match range_end_mode {
434            RangeEndMode::Standard => start..vec.len(),
435            #[allow(clippy::range_plus_one)]
436            RangeEndMode::OutOfBounds => vec.len()..(vec.len() + vec.len() / 2 + 1).max(start),
437            #[allow(clippy::range_plus_one)]
438            RangeEndMode::Start => start..(start + 1),
439        }
440        .choose(rng)
441        .unwrap_or_default();
442
443        let index = match index_mode {
444            RangeIndexMode::InBounds => 0..vec.len(),
445            RangeIndexMode::InRange => start..end,
446            RangeIndexMode::AfterRangeInBounds => end..vec.len(),
447            #[allow(clippy::range_plus_one)]
448            RangeIndexMode::OutOfBounds => vec.len()..(vec.len() + vec.len() / 2 + 1),
449            RangeIndexMode::BeforeRange => 0..start,
450        }
451        .choose(rng);
452
453        (vec, start..end, index)
454    }
455}
456
457pub fn arb_analysis_features() -> impl Fn() -> [f64; 20] {
458    move || {
459        let rng = &mut rand::thread_rng();
460        let mut features = [0.0; 20];
461        for feature in &mut features {
462            *feature = rng.gen_range(-1.0..1.0);
463        }
464        features
465    }
466}
467
// `init_test_database` and `create_song_with_overrides` only exist with the `db` feature,
// so these tests must be gated on it too or the crate fails to compile tests without `db`.
#[cfg(all(test, feature = "db"))]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[tokio::test]
    async fn test_create_song() {
        let db = init_test_database().await.unwrap();
        let song_case = SongCase::new(0, vec![0], vec![0], 0, 0);

        // Creating the song must succeed; surface the full error chain otherwise.
        let song = create_song_with_overrides(&db, song_case, SongChangeSet::default())
            .await
            .unwrap_or_else(|e| panic!("Error creating song: {e:?}"));

        // The song can be read back from the database and matches what was returned.
        let song_from_db = Song::read(&db, song.id.clone()).await.unwrap().unwrap();
        assert_eq!(song, song_from_db);
    }
}