mecomp_storage/
test_utils.rs

1use std::{
2    ops::{Range, RangeInclusive},
3    path::PathBuf,
4    str::FromStr,
5    sync::Arc,
6    time::Duration,
7};
8
9use anyhow::Result;
10use lofty::{config::WriteOptions, file::TaggedFileExt, prelude::*, probe::Probe, tag::Accessor};
11use one_or_many::OneOrMany;
12use rand::{seq::IteratorRandom, Rng};
13#[cfg(feature = "db")]
14use surrealdb::{
15    engine::local::{Db, Mem},
16    sql::Id,
17    Connection, Surreal,
18};
19
20#[cfg(feature = "analysis")]
21use crate::db::schemas::analysis::Analysis;
22#[cfg(not(feature = "db"))]
23use crate::db::schemas::Id;
24use crate::db::schemas::{
25    album::Album,
26    artist::Artist,
27    collection::Collection,
28    playlist::Playlist,
29    song::{Song, SongChangeSet, SongMetadata},
30};
31
/// Separator used to join multiple artist / album-artist names into a single tag value.
///
/// `create_song_metadata` joins names with it when writing tags, and passes it to
/// `SongMetadata::load_from_path` as the artist separator when reading them back.
pub const ARTIST_NAME_SEPARATOR: &str = ", ";
33
/// Create an in-memory `SurrealDB` instance for tests, set up with the same tables
/// (and the custom analyzer) as the main database.
///
/// The returned connection has already selected the `test` namespace and `test` database,
/// so it is ready for exercising queries and mutations.
///
/// # Errors
///
/// Returns an error if the in-memory engine cannot be started, if selecting the
/// namespace/database fails, or if registering the analyzer or any of the tables fails.
#[cfg(feature = "db")]
#[allow(clippy::missing_inline_in_public_items)]
pub async fn init_test_database() -> surrealdb::Result<Surreal<Db>> {
    use crate::db::schemas::dynamic::DynamicPlaylist;

    let database = Surreal::new::<Mem>(()).await?;
    database.use_ns("test").use_db("test").await?;

    // register the custom analyzer first, then every table schema
    crate::db::register_custom_analyzer(&database).await?;
    surrealqlx::register_tables!(
        &database,
        Album,
        Artist,
        Song,
        Collection,
        Playlist,
        DynamicPlaylist
    )?;

    // the analysis table is only part of the schema when the `analysis` feature is on
    #[cfg(feature = "analysis")]
    surrealqlx::register_tables!(&database, Analysis)?;

    Ok(database)
}
63
/// Initialize a test database with some basic state.
///
/// # What will be created:
///
/// - a playlist named "Playlist 0"
/// - a collection named "Collection 0"
/// - optionally, the passed `DynamicPlaylist`
/// - `song_count` songs whose values are determined by the given `song_case_func`
/// - a file in the given `TempDir` for each song
///
/// # `song_case_func`
///
/// Signature: `FnMut(usize) -> (SongCase, bool, bool)`
/// - argument `i`: which song this is, in `0..song_count`
/// - returns: `(the song case to use when generating the song, whether the song should be
///   added to the playlist, whether it should be added to the collection)`
///
/// Note: this actually creates audio files for the songs in the passed `TempDir`.
///
/// # Panics
///
/// Panics if an error occurs during the above process. This is intended to only be used for
/// testing, so panicking when something goes wrong ensures that tests will fail and the
/// backtrace will point to whatever line caused the panic in here.
#[cfg(feature = "db")]
#[allow(clippy::missing_inline_in_public_items)]
pub async fn init_test_database_with_state<SCF>(
    song_count: std::num::NonZero<usize>,
    mut song_case_func: SCF,
    dynamic: Option<crate::db::schemas::dynamic::DynamicPlaylist>,
    tempdir: &tempfile::TempDir,
) -> Arc<Surreal<Db>>
where
    SCF: FnMut(usize) -> (SongCase, bool, bool) + Send + Sync,
{
    use anyhow::Context;

    use crate::db::schemas::dynamic::DynamicPlaylist;

    let db = Arc::new(init_test_database().await.unwrap());

    // create the playlist, collection, and optionally the dynamic playlist
    let playlist = Playlist {
        id: Playlist::generate_id(),
        name: "Playlist 0".into(),
        runtime: Duration::from_secs(0),
        song_count: 0,
    };
    let playlist = Playlist::create(&db, playlist).await.unwrap().unwrap();

    let collection = Collection {
        id: Collection::generate_id(),
        name: "Collection 0".into(),
        runtime: Duration::from_secs(0),
        song_count: 0,
    };
    let collection = Collection::create(&db, collection).await.unwrap().unwrap();

    if let Some(dynamic) = dynamic {
        let _ = DynamicPlaylist::create(&db, dynamic)
            .await
            .unwrap()
            .unwrap();
    }

    // create the songs
    for i in 0..song_count.get() {
        let (song_case, add_to_playlist, add_to_collection) = song_case_func(i);

        // `with_context` only formats the (Debug-printed) case when the step actually fails,
        // instead of allocating a message for every song.
        let metadata = create_song_metadata(tempdir, song_case.clone())
            .with_context(|| format!("failed to create metadata for song case {song_case:?}"))
            .unwrap();

        let song = Song::try_load_into_db(&db, metadata)
            .await
            .with_context(|| format!("Failed to load into db the song case: {song_case:?}"))
            .unwrap();

        if add_to_playlist {
            Playlist::add_songs(&db, playlist.id.clone(), vec![song.id.clone()])
                .await
                .unwrap();
        }
        if add_to_collection {
            Collection::add_songs(&db, collection.id.clone(), vec![song.id.clone()])
                .await
                .unwrap();
        }
    }

    db
}
162
/// Create a song in the database from the given case, then apply `overrides` if they are
/// not the default (empty) change set.
///
/// The title, artists, album artists, album and genre are derived from the numbers in the
/// case (`Song {n}`, `Artist {n}`, `Album {n}`, `Genre {n}`). The runtime is fixed at 120
/// seconds, and the path is `<id>.mp3`.
///
/// The created song is shallow: the artists, album artists, and album are *not* created in
/// the database.
///
/// # Errors
///
/// Returns an error if building the path, creating the song, updating it with the overrides,
/// or reading it back fails.
///
/// # Panics
///
/// Panics if the song can't be read from the database after creation.
#[cfg(feature = "db")]
#[allow(clippy::missing_inline_in_public_items)]
pub async fn create_song_with_overrides<C: Connection>(
    db: &Surreal<C>,
    SongCase {
        song,
        artists,
        album_artists,
        album,
        genre,
    }: SongCase,
    overrides: SongChangeSet,
) -> Result<Song> {
    /// `"Artist {n}"` for every number in `numbers`, keeping their order.
    fn artist_names(numbers: &[u8]) -> Vec<Arc<str>> {
        numbers
            .iter()
            .map(|n| Arc::<str>::from(format!("Artist {n}")))
            .collect()
    }

    let id = Song::generate_id();
    let path = PathBuf::from_str(&format!("{}.mp3", id.id))?;

    let new_song = Song {
        id,
        title: Arc::from(format!("Song {song}")),
        artist: artist_names(&artists).into(),
        album_artist: artist_names(&album_artists).into(),
        album: Arc::from(format!("Album {album}")),
        genre: OneOrMany::One(Arc::from(format!("Genre {genre}"))),
        runtime: Duration::from_secs(120),
        track: None,
        disc: None,
        release_year: None,
        extension: Arc::from("mp3"),
        path,
    };

    Song::create(db, new_song.clone()).await?;

    // only issue an update when there is actually something to change
    let has_overrides = overrides != SongChangeSet::default();
    if has_overrides {
        Song::update(db, new_song.id.clone(), overrides).await?;
    }

    Ok(Song::read(db, new_song.id)
        .await?
        .expect("Song should exist"))
}
220
221/// Creates a song file with the given case and overrides.
222/// The song file is created in a temporary directory.
223/// The song metadata is created from the song file.
224/// The song is not added to the database.
225///
226/// # Errors
227///
228/// This function will return an error if the song metadata cannot be created.
229#[allow(clippy::missing_inline_in_public_items)]
230pub fn create_song_metadata(
231    tempdir: &tempfile::TempDir,
232    SongCase {
233        song,
234        artists,
235        album_artists,
236        album,
237        genre,
238    }: SongCase,
239) -> Result<SongMetadata> {
240    // we have an example mp3 in `assets/`, we want to take that and create a new audio file with psuedorandom id3 tags
241    let base_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
242        .join("../assets/music.mp3")
243        .canonicalize()?;
244
245    let mut tagged_file = Probe::open(&base_path)?.read()?;
246    let tag = match tagged_file.primary_tag_mut() {
247        Some(primary_tag) => primary_tag,
248        // If the "primary" tag doesn't exist, we just grab the
249        // first tag we can find. Realistically, a tag reader would likely
250        // iterate through the tags to find a suitable one.
251        None => tagged_file
252            .first_tag_mut()
253            .ok_or_else(|| anyhow::anyhow!("ERROR: No tags found"))?,
254    };
255
256    tag.insert_text(
257        ItemKey::AlbumArtist,
258        album_artists
259            .iter()
260            .map(|a| format!("Artist {a}"))
261            .collect::<Vec<_>>()
262            .join(ARTIST_NAME_SEPARATOR),
263    );
264
265    tag.remove_artist();
266    tag.set_artist(
267        artists
268            .iter()
269            .map(|a| format!("Artist {a}"))
270            .collect::<Vec<_>>()
271            .join(ARTIST_NAME_SEPARATOR),
272    );
273
274    tag.remove_album();
275    tag.set_album(format!("Album {album}"));
276
277    tag.remove_title();
278    tag.set_title(format!("Song {song}"));
279
280    tag.remove_genre();
281    tag.set_genre(format!("Genre {genre}"));
282
283    let new_path = tempdir.path().join(format!("song_{}.mp3", Id::ulid()));
284    // copy the base file to the new path
285    std::fs::copy(&base_path, &new_path)?;
286    // write the new tags to the new file
287    tag.save_to_path(&new_path, WriteOptions::default())?;
288
289    // now, we need to load a SongMetadata from the new file
290    Ok(SongMetadata::load_from_path(
291        new_path,
292        &OneOrMany::One(ARTIST_NAME_SEPARATOR.to_string()),
293        None,
294    )?)
295}
296
/// A compact, number-based description of a test song.
///
/// Each number is turned into a name by the helpers that consume it (e.g. `song: 3` becomes
/// the title `Song 3`, `artists: vec![1, 2]` becomes `Artist 1` and `Artist 2`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SongCase {
    /// Number used for the song title.
    pub song: u8,
    /// Numbers used for the song's artists.
    pub artists: Vec<u8>,
    /// Numbers used for the song's album artists.
    pub album_artists: Vec<u8>,
    /// Number used for the album name.
    pub album: u8,
    /// Number used for the genre name.
    pub genre: u8,
}

impl SongCase {
    /// Build a case from its parts, in field order.
    #[must_use]
    #[inline]
    pub const fn new(
        song: u8,
        artists: Vec<u8>,
        album_artists: Vec<u8>,
        album: u8,
        genre: u8,
    ) -> Self {
        Self {
            song,
            artists,
            album_artists,
            album,
            genre,
        }
    }
}
325
326#[inline]
327pub const fn arb_song_case() -> impl Fn() -> SongCase {
328    || {
329        let artist_item_strategy = move || {
330            (0..=10u8)
331                .choose(&mut rand::thread_rng())
332                .unwrap_or_default()
333        };
334        let rng = &mut rand::thread_rng();
335        let artists = arb_vec(&artist_item_strategy, 1..=10)()
336            .into_iter()
337            .collect::<std::collections::HashSet<_>>()
338            .into_iter()
339            .collect::<Vec<_>>();
340        let album_artists = arb_vec(&artist_item_strategy, 1..=10)()
341            .into_iter()
342            .collect::<std::collections::HashSet<_>>()
343            .into_iter()
344            .collect::<Vec<_>>();
345        let song = (0..=10u8).choose(rng).unwrap_or_default();
346        let album = (0..=10u8).choose(rng).unwrap_or_default();
347        let genre = (0..=10u8).choose(rng).unwrap_or_default();
348
349        SongCase::new(song, artists, album_artists, album, genre)
350    }
351}
352
353#[inline]
354pub const fn arb_vec<T>(
355    item_strategy: &impl Fn() -> T,
356    range: RangeInclusive<usize>,
357) -> impl Fn() -> Vec<T> + '_
358where
359    T: Clone + std::fmt::Debug + Sized,
360{
361    move || {
362        let size = range
363            .clone()
364            .choose(&mut rand::thread_rng())
365            .unwrap_or_default();
366        std::iter::repeat_with(item_strategy).take(size).collect()
367    }
368}
369
/// Where the index produced by [`arb_vec_and_index`] should fall relative to the vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexMode {
    /// An index picked from `0..vec.len()`.
    InBounds,
    /// An index picked from `vec.len()..=vec.len() + vec.len() / 2`, i.e. past the end.
    OutOfBounds,
}
374
375#[inline]
376pub const fn arb_vec_and_index<T>(
377    item_strategy: &impl Fn() -> T,
378    range: RangeInclusive<usize>,
379    index_mode: IndexMode,
380) -> impl Fn() -> (Vec<T>, usize) + '_
381where
382    T: Clone + std::fmt::Debug + Sized,
383{
384    move || {
385        let vec = arb_vec(item_strategy, range.clone())();
386        let index = match index_mode {
387            IndexMode::InBounds => 0..vec.len(),
388            #[allow(clippy::range_plus_one)]
389            IndexMode::OutOfBounds => vec.len()..(vec.len() + vec.len() / 2 + 1),
390        }
391        .choose(&mut rand::thread_rng())
392        .unwrap_or_default();
393        (vec, index)
394    }
395}
396
/// How [`arb_vec_and_range_and_index`] picks the start of its range.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RangeStartMode {
    /// A start picked from `0..vec.len()`.
    Standard,
    /// A start of exactly `0`.
    Zero,
    /// A start picked from `vec.len()..=vec.len() + vec.len() / 2`, i.e. past the end.
    OutOfBounds,
}
402
/// How [`arb_vec_and_range_and_index`] picks the end of its range.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RangeEndMode {
    /// An end equal to the start, i.e. an empty range.
    Start,
    /// An end picked from `start..vec.len()`.
    Standard,
    /// An end at or past `vec.len()`.
    OutOfBounds,
}
408
/// Where the optional index produced by [`arb_vec_and_range_and_index`] should fall.
///
/// In every mode the index is `None` when the described region is empty.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RangeIndexMode {
    /// An index picked from `0..vec.len()`.
    InBounds,
    /// An index picked from `start..end`.
    InRange,
    /// An index picked from `end..vec.len()`.
    AfterRangeInBounds,
    /// An index picked from `vec.len()..=vec.len() + vec.len() / 2`, i.e. past the end.
    OutOfBounds,
    /// An index picked from `0..start`.
    BeforeRange,
}
416
417// Returns a tuple of a Vec of T and a Range<usize>
418// where the start is a random index in the Vec
419// and the end is a random index in the Vec that is greater than or equal to the start
420#[inline]
421pub const fn arb_vec_and_range_and_index<T>(
422    item_strategy: &impl Fn() -> T,
423    range: RangeInclusive<usize>,
424    range_start_mode: RangeStartMode,
425    range_end_mode: RangeEndMode,
426    index_mode: RangeIndexMode,
427) -> impl Fn() -> (Vec<T>, Range<usize>, Option<usize>) + '_
428where
429    T: Clone + std::fmt::Debug + Sized,
430{
431    move || {
432        let rng = &mut rand::thread_rng();
433        let vec = arb_vec(item_strategy, range.clone())();
434        let start = match range_start_mode {
435            RangeStartMode::Standard => 0..vec.len(),
436            #[allow(clippy::range_plus_one)]
437            RangeStartMode::OutOfBounds => vec.len()..(vec.len() + vec.len() / 2 + 1),
438            RangeStartMode::Zero => 0..1,
439        }
440        .choose(rng)
441        .unwrap_or_default();
442        let end = match range_end_mode {
443            RangeEndMode::Standard => start..vec.len(),
444            #[allow(clippy::range_plus_one)]
445            RangeEndMode::OutOfBounds => vec.len()..(vec.len() + vec.len() / 2 + 1).max(start),
446            #[allow(clippy::range_plus_one)]
447            RangeEndMode::Start => start..(start + 1),
448        }
449        .choose(rng)
450        .unwrap_or_default();
451
452        let index = match index_mode {
453            RangeIndexMode::InBounds => 0..vec.len(),
454            RangeIndexMode::InRange => start..end,
455            RangeIndexMode::AfterRangeInBounds => end..vec.len(),
456            #[allow(clippy::range_plus_one)]
457            RangeIndexMode::OutOfBounds => vec.len()..(vec.len() + vec.len() / 2 + 1),
458            RangeIndexMode::BeforeRange => 0..start,
459        }
460        .choose(rng);
461
462        (vec, start..end, index)
463    }
464}
465
466#[inline]
467pub const fn arb_analysis_features() -> impl Fn() -> [f64; 20] {
468    move || {
469        let rng = &mut rand::thread_rng();
470        let mut features = [0.0; 20];
471        for feature in &mut features {
472            *feature = rng.gen_range(-1.0..1.0);
473        }
474        features
475    }
476}
477
// `init_test_database` and `create_song_with_overrides` only exist with the `db` feature,
// so these tests are only compiled when it is enabled.
#[cfg(all(test, feature = "db"))]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[tokio::test]
    async fn test_create_song() {
        let db = init_test_database().await.unwrap();
        // Create a test case
        let song_case = SongCase::new(0, vec![0], vec![0], 0, 0);

        // Create the song, failing the test with the full error if that doesn't work
        let song = create_song_with_overrides(&db, song_case, SongChangeSet::default())
            .await
            .unwrap_or_else(|e| panic!("Error creating song: {e:?}"));

        // Assert that we can get the song from the database
        let song_from_db = Song::read(&db, song.id.clone()).await.unwrap().unwrap();

        // Assert that the song from the database is the same as the song we created
        assert_eq!(song, song_from_db);
    }
}