mecomp_storage/db/crud/analysis.rs

1//! CRUD operations for the analysis table
2
3use one_or_many::OneOrMany;
4use surrealdb::{Connection, RecordId, Surreal};
5use tracing::instrument;
6
7use crate::{
8    db::{
9        queries::analysis::{
10            add_to_song, nearest_neighbors, nearest_neighbors_to_many, read_for_song, read_song,
11            read_songs_without_analysis,
12        },
13        schemas::{
14            analysis::{Analysis, AnalysisId, TABLE_NAME},
15            song::{Song, SongId},
16        },
17    },
18    errors::{Error, StorageResult},
19};
20
21impl Analysis {
22    /// create a new analysis for the given song
23    ///
24    /// If an analysis already exists for the song, this will return None.
25    #[instrument]
26    pub async fn create<C: Connection>(
27        db: &Surreal<C>,
28        song_id: SongId,
29        analysis: Self,
30    ) -> StorageResult<Option<Self>> {
31        if Self::read_for_song(db, song_id.clone()).await?.is_some() {
32            return Ok(None);
33        }
34
35        // create the analysis
36        let result: Option<Self> = db
37            .create(RecordId::from_inner(analysis.id.clone()))
38            .content(analysis)
39            .await?;
40
41        if let Some(analysis) = result {
42            // relate the song to the analysis
43            db.query(add_to_song())
44                .bind(("id", analysis.id.clone()))
45                .bind(("song", song_id))
46                .await?;
47
48            // return the analysis
49            Ok(Some(analysis))
50        } else {
51            Ok(None)
52        }
53    }
54
55    #[instrument]
56    pub async fn read<C: Connection>(
57        db: &Surreal<C>,
58        id: AnalysisId,
59    ) -> StorageResult<Option<Self>> {
60        Ok(db.select(RecordId::from_inner(id)).await?)
61    }
62
63    #[instrument]
64    pub async fn read_all<C: Connection>(db: &Surreal<C>) -> StorageResult<Vec<Self>> {
65        Ok(db.select(TABLE_NAME).await?)
66    }
67
68    /// Read the analysis for a song
69    ///
70    /// If the song does not have an analysis, this will return None.
71    #[instrument]
72    pub async fn read_for_song<C: Connection>(
73        db: &Surreal<C>,
74        song_id: SongId,
75    ) -> StorageResult<Option<Self>> {
76        Ok(db
77            .query(read_for_song())
78            .bind(("song", song_id))
79            .await?
80            .take(0)?)
81    }
82
83    /// Read the analysis for OneOrMany song(s)
84    ///
85    /// Needed for clustering(?)
86    ///
87    /// We return a Vec<Option<Analysis>>, where None means the song doesn't have an analysis, so that it's up to the caller to handle songs without analyses.
88    #[instrument]
89    pub async fn read_for_songs<C: Connection>(
90        db: &Surreal<C>,
91        song_ids: Vec<SongId>,
92    ) -> StorageResult<Vec<Option<Self>>> {
93        futures::future::try_join_all(song_ids.into_iter().map(|id| Self::read_for_song(db, id)))
94            .await
95    }
96
97    /// Read the song for an analysis
98    #[instrument]
99    pub async fn read_song<C: Connection>(db: &Surreal<C>, id: AnalysisId) -> StorageResult<Song> {
100        Option::<Song>::map_or_else(
101            db.query(read_song()).bind(("id", id)).await?.take(0)?,
102            || Err(Error::NotFound),
103            Ok,
104        )
105    }
106
107    /// Read the song for OneOrMany analyses
108    ///
109    /// needed to convert a list of analyses (such as what we get from nearest_neighbors) into a list of songs
110    #[instrument]
111    pub async fn read_songs<C: Connection>(
112        db: &Surreal<C>,
113        ids: OneOrMany<AnalysisId>,
114    ) -> StorageResult<OneOrMany<Song>> {
115        futures::future::try_join_all(ids.into_iter().map(|id| Self::read_song(db, id)))
116            .await
117            .map(OneOrMany::from)
118    }
119
120    /// Get all the songs that don't have an analysis
121    #[instrument]
122    pub async fn read_songs_without_analysis<C: Connection>(
123        db: &Surreal<C>,
124    ) -> StorageResult<Vec<Song>> {
125        Ok(db.query(read_songs_without_analysis()).await?.take(0)?)
126    }
127
128    /// Delete an analysis
129    #[instrument]
130    pub async fn delete<C: Connection>(
131        db: &Surreal<C>,
132        id: AnalysisId,
133    ) -> StorageResult<Option<Self>> {
134        Ok(db.delete(RecordId::from_inner(id)).await?)
135    }
136
137    /// Find the `n` nearest neighbors to an analysis
138    #[instrument]
139    pub async fn nearest_neighbors<C: Connection>(
140        db: &Surreal<C>,
141        id: AnalysisId,
142        n: u32,
143    ) -> StorageResult<Vec<Self>> {
144        let features = Self::read(db, id.clone())
145            .await?
146            .ok_or(Error::NotFound)?
147            .features;
148
149        Ok(db
150            .query(nearest_neighbors(n))
151            .bind(("id", id))
152            .bind(("target", features))
153            .await?
154            .take(0)?)
155    }
156
157    /// Find the `n` nearest neighbors to a list of analyses
158    ///
159    /// The provided analyses should not be included in the results
160    #[instrument]
161    pub async fn nearest_neighbors_to_many<C: Connection>(
162        db: &Surreal<C>,
163        ids: Vec<AnalysisId>,
164        n: u32,
165    ) -> StorageResult<Vec<Self>> {
166        // find the average "features" of the given analyses
167        let analyses =
168            futures::future::try_join_all(ids.iter().map(|id| Self::read(db, id.clone())))
169                .await?
170                .into_iter()
171                .map(|analysis| analysis.ok_or(Error::NotFound))
172                .collect::<Result<Vec<Self>, Error>>()?;
173
174        #[allow(clippy::cast_precision_loss)]
175        let num_analyses = analyses.len() as f64;
176
177        let avg_features = analyses.iter().fold(vec![0.; 20], |acc, analysis| {
178            acc.iter()
179                .zip(analysis.features.iter())
180                .map(|(a, b)| a + (b / num_analyses))
181                .collect::<Vec<_>>()
182        });
183
184        Ok(db
185            .query(nearest_neighbors_to_many(n))
186            .bind(("ids", ids))
187            .bind(("target", avg_features))
188            .await?
189            .take(0)?)
190    }
191}
192
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        db::schemas::song::SongChangeSet,
        test_utils::{arb_song_case, create_song_with_overrides, init_test_database},
    };

    use anyhow::Result;
    use pretty_assertions::assert_eq;

    /// Build an analysis with a fresh id whose features are all equal to `value`.
    fn analysis_filled_with(value: f64) -> Analysis {
        Analysis {
            id: Analysis::generate_id(),
            features: [value; 20],
        }
    }

    /// Build an analysis with a fresh id whose leading features are `leading` and the rest 0.
    fn analysis_with_leading_features(leading: &[f64]) -> Analysis {
        let mut features = [0.; 20];
        features[..leading.len()].copy_from_slice(leading);
        Analysis {
            id: Analysis::generate_id(),
            features,
        }
    }

    #[tokio::test]
    async fn test_create() -> Result<()> {
        let db = init_test_database().await?;

        let song =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis = analysis_filled_with(0.);

        // create the analysis
        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        // if we try to create another analysis for the same song, we get Ok(None)
        let analysis = analysis_filled_with(1.);
        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, None);

        Ok(())
    }

    #[tokio::test]
    async fn test_read() -> Result<()> {
        let db = init_test_database().await?;

        let song =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis = analysis_filled_with(0.);

        // create the analysis
        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        // read the analysis
        let result = Analysis::read(&db, analysis.id.clone()).await?;
        assert_eq!(result, Some(analysis));

        Ok(())
    }

    #[tokio::test]
    async fn test_read_all() -> Result<()> {
        let db = init_test_database().await?;

        let song =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis = analysis_filled_with(0.);

        // create the analysis
        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        // read all the analyses
        let result = Analysis::read_all(&db).await?;
        assert_eq!(result, vec![analysis]);

        Ok(())
    }

    #[tokio::test]
    async fn test_read_for_song() -> Result<()> {
        let db = init_test_database().await?;

        let song =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis = analysis_filled_with(0.);

        // the song doesn't have an analysis yet
        let result = Analysis::read_for_song(&db, song.id.clone()).await?;
        assert_eq!(result, None);

        // create the analysis
        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        // read the analysis for the song
        let result = Analysis::read_for_song(&db, song.id.clone()).await?;
        assert_eq!(result, Some(analysis));

        Ok(())
    }

    #[tokio::test]
    async fn test_read_for_songs() -> Result<()> {
        let db = init_test_database().await?;

        let song1 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;
        let song2 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;
        let song3 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis1 = analysis_filled_with(0.);
        let analysis2 = analysis_filled_with(1.);

        // create the analyses
        let result = Analysis::create(&db, song1.id.clone(), analysis1.clone()).await?;
        assert_eq!(result, Some(analysis1.clone()));
        let result = Analysis::create(&db, song2.id.clone(), analysis2.clone()).await?;
        assert_eq!(result, Some(analysis2.clone()));

        // read the analyses for the songs; order follows the input, None for song3
        let result = Analysis::read_for_songs(
            &db,
            vec![song1.id.clone(), song2.id.clone(), song3.id.clone()],
        )
        .await?;
        assert_eq!(result, vec![Some(analysis1), Some(analysis2), None]);

        Ok(())
    }

    #[tokio::test]
    async fn test_read_song() -> Result<()> {
        let db = init_test_database().await?;

        let song =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis = analysis_filled_with(0.);

        // create the analysis
        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        // read the song for the analysis
        let result = Analysis::read_song(&db, analysis.id.clone()).await?;
        assert_eq!(result, song);

        Ok(())
    }

    #[tokio::test]
    async fn test_read_songs() -> Result<()> {
        let db = init_test_database().await?;

        let song1 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;
        let song2 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis1 = analysis_filled_with(0.);
        let analysis2 = analysis_filled_with(1.);

        // create the analyses
        let result = Analysis::create(&db, song1.id.clone(), analysis1.clone()).await?;
        assert_eq!(result, Some(analysis1.clone()));
        let result = Analysis::create(&db, song2.id.clone(), analysis2.clone()).await?;
        assert_eq!(result, Some(analysis2.clone()));

        // read the songs for the analyses
        let result = Analysis::read_songs(
            &db,
            OneOrMany::Many(vec![analysis1.id.clone(), analysis2.id.clone()]),
        )
        .await?;
        assert_eq!(result, OneOrMany::Many(vec![song1, song2]));

        Ok(())
    }

    #[tokio::test]
    async fn test_read_songs_without_analysis() -> Result<()> {
        let db = init_test_database().await?;

        let song1 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;
        let song2 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        // read the songs without an analysis
        let result = Analysis::read_songs_without_analysis(&db).await?;
        assert_eq!(result.len(), 2);
        assert!(result.contains(&song1));
        assert!(result.contains(&song2));

        let analysis1 = analysis_filled_with(0.);
        let analysis2 = analysis_filled_with(0.);

        // create the analysis
        let result = Analysis::create(&db, song1.id.clone(), analysis1.clone()).await?;
        assert_eq!(result, Some(analysis1.clone()));

        // read the songs without an analysis
        let result = Analysis::read_songs_without_analysis(&db).await?;
        assert_eq!(result, vec![song2.clone()]);

        // create the analysis
        let result = Analysis::create(&db, song2.id.clone(), analysis2.clone()).await?;
        assert_eq!(result, Some(analysis2.clone()));

        // read the songs without an analysis
        let result = Analysis::read_songs_without_analysis(&db).await?;
        assert_eq!(result, vec![]);

        Ok(())
    }

    #[tokio::test]
    async fn test_delete() -> Result<()> {
        let db = init_test_database().await?;

        let song =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis = analysis_filled_with(0.);

        // create the analysis
        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        // delete the analysis
        let result = Analysis::delete(&db, analysis.id.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        // if we try to read the analysis, we get None
        let result = Analysis::read(&db, analysis.id.clone()).await?;
        assert_eq!(result, None);

        // if we try to read the analysis for the song, we get None
        let result = Analysis::read_for_song(&db, song.id.clone()).await?;
        assert_eq!(result, None);

        Ok(())
    }

    #[tokio::test]
    async fn test_nearest_neighbors() -> Result<()> {
        let db = init_test_database().await?;

        let song1 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;
        let song2 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;
        let song3 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis1 = analysis_filled_with(0.);
        let analysis2 = analysis_filled_with(0.);
        let analysis3 = analysis_filled_with(1.);

        // create the analyses
        let result1 = Analysis::create(&db, song1.id.clone(), analysis1.clone()).await?;
        assert_eq!(result1, Some(analysis1.clone()));
        let result2 = Analysis::create(&db, song2.id.clone(), analysis2.clone()).await?;
        assert_eq!(result2, Some(analysis2.clone()));
        let result3 = Analysis::create(&db, song3.id.clone(), analysis3.clone()).await?;
        assert_eq!(result3, Some(analysis3.clone()));

        // find the nearest neighbor to analysis1 (analysis1 itself is excluded)
        let result = Analysis::nearest_neighbors(&db, analysis1.id, 1).await?;
        assert_eq!(result, vec![analysis2.clone()]);

        // an analysis that doesn't exist is reported as NotFound
        let result = Analysis::nearest_neighbors(&db, Analysis::generate_id(), 1).await;
        assert!(matches!(result, Err(Error::NotFound)));

        Ok(())
    }

    #[tokio::test]
    async fn test_nearest_neighbors_to_many() -> Result<()> {
        let db = init_test_database().await?;

        // the average of the first two is [0.5, 0.5, 0, ...]; the third is closest to that
        // average both by distance and by direction, and the fourth is far away on both counts
        let analyses = [
            analysis_with_leading_features(&[1., 0., 0.]),
            analysis_with_leading_features(&[0., 1., 0.]),
            analysis_with_leading_features(&[0.6, 0.6, 0.]),
            analysis_with_leading_features(&[0., 0., 1.]),
        ];
        for analysis in &analyses {
            let song =
                create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default())
                    .await?;
            let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
            assert_eq!(result, Some(analysis.clone()));
        }

        // find the nearest neighbor to the first two analyses
        let result = Analysis::nearest_neighbors_to_many(
            &db,
            vec![analyses[0].id.clone(), analyses[1].id.clone()],
            1,
        )
        .await?;
        assert_eq!(result, vec![analyses[2].clone()]);

        // an analysis that doesn't exist is reported as NotFound
        let result =
            Analysis::nearest_neighbors_to_many(&db, vec![Analysis::generate_id()], 1).await;
        assert!(matches!(result, Err(Error::NotFound)));

        Ok(())
    }

    #[tokio::test]
    async fn test_analysis_deleted_when_song_deleted() -> Result<()> {
        let db = init_test_database().await?;

        let song =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis = analysis_filled_with(0.);

        // create the analysis
        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        // delete the song
        let result = Song::delete(&db, song.id.clone()).await?;
        assert_eq!(result, Some(song.clone()));

        // if we try to read the song, we get None
        let result = Song::read(&db, song.id.clone()).await?;
        assert_eq!(result, None);

        // if we try to read the analysis, we get None
        let result = Analysis::read(&db, analysis.id.clone()).await?;
        assert_eq!(result, None);

        // if we try to read the analysis for the song, we get None
        let result = Analysis::read_for_song(&db, song.id.clone()).await?;
        assert_eq!(result, None);

        // if we try to read the songs without an analysis, we get an empty list
        let result = Analysis::read_songs_without_analysis(&db).await?;
        assert_eq!(result, vec![]);

        // if we try to read the song for the analysis, we get an error
        let result = Analysis::read_song(&db, analysis.id.clone()).await;
        assert!(matches!(result, Err(Error::NotFound)));

        Ok(())
    }
}