1use one_or_many::OneOrMany;
4use surrealdb::{Connection, RecordId, Surreal};
5use tracing::instrument;
6
7use crate::{
8 db::{
9 queries::analysis::{
10 add_to_song, nearest_neighbors, nearest_neighbors_to_many, read_for_song, read_song,
11 read_songs_without_analysis,
12 },
13 schemas::{
14 analysis::{Analysis, AnalysisId, TABLE_NAME},
15 song::{Song, SongId},
16 },
17 },
18 errors::{Error, StorageResult},
19};
20
21impl Analysis {
22 #[instrument]
26 pub async fn create<C: Connection>(
27 db: &Surreal<C>,
28 song_id: SongId,
29 analysis: Self,
30 ) -> StorageResult<Option<Self>> {
31 if Self::read_for_song(db, song_id.clone()).await?.is_some() {
32 return Ok(None);
33 }
34
35 let result: Option<Self> = db
37 .create(RecordId::from_inner(analysis.id.clone()))
38 .content(analysis)
39 .await?;
40
41 if let Some(analysis) = result {
42 db.query(add_to_song())
44 .bind(("id", analysis.id.clone()))
45 .bind(("song", song_id))
46 .await?;
47
48 Ok(Some(analysis))
50 } else {
51 Ok(None)
52 }
53 }
54
55 #[instrument]
56 pub async fn read<C: Connection>(
57 db: &Surreal<C>,
58 id: AnalysisId,
59 ) -> StorageResult<Option<Self>> {
60 Ok(db.select(RecordId::from_inner(id)).await?)
61 }
62
63 #[instrument]
64 pub async fn read_all<C: Connection>(db: &Surreal<C>) -> StorageResult<Vec<Self>> {
65 Ok(db.select(TABLE_NAME).await?)
66 }
67
68 #[instrument]
72 pub async fn read_for_song<C: Connection>(
73 db: &Surreal<C>,
74 song_id: SongId,
75 ) -> StorageResult<Option<Self>> {
76 Ok(db
77 .query(read_for_song())
78 .bind(("song", song_id))
79 .await?
80 .take(0)?)
81 }
82
83 #[instrument]
89 pub async fn read_for_songs<C: Connection>(
90 db: &Surreal<C>,
91 song_ids: Vec<SongId>,
92 ) -> StorageResult<Vec<Option<Self>>> {
93 futures::future::try_join_all(song_ids.into_iter().map(|id| Self::read_for_song(db, id)))
94 .await
95 }
96
97 #[instrument]
99 pub async fn read_song<C: Connection>(db: &Surreal<C>, id: AnalysisId) -> StorageResult<Song> {
100 Option::<Song>::map_or_else(
101 db.query(read_song()).bind(("id", id)).await?.take(0)?,
102 || Err(Error::NotFound),
103 Ok,
104 )
105 }
106
107 #[instrument]
111 pub async fn read_songs<C: Connection>(
112 db: &Surreal<C>,
113 ids: OneOrMany<AnalysisId>,
114 ) -> StorageResult<OneOrMany<Song>> {
115 futures::future::try_join_all(ids.into_iter().map(|id| Self::read_song(db, id)))
116 .await
117 .map(OneOrMany::from)
118 }
119
120 #[instrument]
122 pub async fn read_songs_without_analysis<C: Connection>(
123 db: &Surreal<C>,
124 ) -> StorageResult<Vec<Song>> {
125 Ok(db.query(read_songs_without_analysis()).await?.take(0)?)
126 }
127
128 #[instrument]
130 pub async fn delete<C: Connection>(
131 db: &Surreal<C>,
132 id: AnalysisId,
133 ) -> StorageResult<Option<Self>> {
134 Ok(db.delete(RecordId::from_inner(id)).await?)
135 }
136
137 #[instrument]
139 pub async fn nearest_neighbors<C: Connection>(
140 db: &Surreal<C>,
141 id: AnalysisId,
142 n: u32,
143 ) -> StorageResult<Vec<Self>> {
144 let features = Self::read(db, id.clone())
145 .await?
146 .ok_or(Error::NotFound)?
147 .features;
148
149 Ok(db
150 .query(nearest_neighbors(n))
151 .bind(("id", id))
152 .bind(("target", features))
153 .await?
154 .take(0)?)
155 }
156
157 #[instrument]
161 pub async fn nearest_neighbors_to_many<C: Connection>(
162 db: &Surreal<C>,
163 ids: Vec<AnalysisId>,
164 n: u32,
165 ) -> StorageResult<Vec<Self>> {
166 let analyses =
168 futures::future::try_join_all(ids.iter().map(|id| Self::read(db, id.clone())))
169 .await?
170 .into_iter()
171 .map(|analysis| analysis.ok_or(Error::NotFound))
172 .collect::<Result<Vec<Self>, Error>>()?;
173
174 #[allow(clippy::cast_precision_loss)]
175 let num_analyses = analyses.len() as f64;
176
177 let avg_features = analyses.iter().fold(vec![0.; 20], |acc, analysis| {
178 acc.iter()
179 .zip(analysis.features.iter())
180 .map(|(a, b)| a + (b / num_analyses))
181 .collect::<Vec<_>>()
182 });
183
184 Ok(db
185 .query(nearest_neighbors_to_many(n))
186 .bind(("ids", ids))
187 .bind(("target", avg_features))
188 .await?
189 .take(0)?)
190 }
191}
192
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        db::schemas::song::SongChangeSet,
        test_utils::{arb_song_case, create_song_with_overrides, init_test_database},
    };

    use anyhow::Result;
    use pretty_assertions::assert_eq;

    #[tokio::test]
    async fn test_create() -> Result<()> {
        let db = init_test_database().await?;

        let song =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis = Analysis {
            id: Analysis::generate_id(),
            features: [0.; 20],
        };

        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        // A second analysis for the same song must be rejected.
        let analysis = Analysis {
            id: Analysis::generate_id(),
            features: [1.; 20],
        };
        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, None);

        Ok(())
    }

    #[tokio::test]
    async fn test_read() -> Result<()> {
        let db = init_test_database().await?;

        let song =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis = Analysis {
            id: Analysis::generate_id(),
            features: [0.; 20],
        };

        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        let result = Analysis::read(&db, analysis.id.clone()).await?;
        assert_eq!(result, Some(analysis));

        Ok(())
    }

    #[tokio::test]
    async fn test_read_all() -> Result<()> {
        let db = init_test_database().await?;

        let song =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis = Analysis {
            id: Analysis::generate_id(),
            features: [0.; 20],
        };

        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        let result = Analysis::read_all(&db).await?;
        assert_eq!(result, vec![analysis]);

        Ok(())
    }

    #[tokio::test]
    async fn test_read_for_song() -> Result<()> {
        let db = init_test_database().await?;

        let song =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis = Analysis {
            id: Analysis::generate_id(),
            features: [0.; 20],
        };

        let result = Analysis::read_for_song(&db, song.id.clone()).await?;
        assert_eq!(result, None);

        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        let result = Analysis::read_for_song(&db, song.id.clone()).await?;
        assert_eq!(result, Some(analysis));

        Ok(())
    }

    #[tokio::test]
    async fn test_read_for_songs() -> Result<()> {
        let db = init_test_database().await?;

        let song1 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;
        let song2 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;
        let song3 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis1 = Analysis {
            id: Analysis::generate_id(),
            features: [0.; 20],
        };
        let analysis2 = Analysis {
            id: Analysis::generate_id(),
            features: [1.; 20],
        };

        let result = Analysis::create(&db, song1.id.clone(), analysis1.clone()).await?;
        assert_eq!(result, Some(analysis1.clone()));
        let result = Analysis::create(&db, song2.id.clone(), analysis2.clone()).await?;
        assert_eq!(result, Some(analysis2.clone()));

        let result = Analysis::read_for_songs(
            &db,
            vec![song1.id.clone(), song2.id.clone(), song3.id.clone()],
        )
        .await?;
        assert_eq!(result, vec![Some(analysis1), Some(analysis2), None]);

        Ok(())
    }

    #[tokio::test]
    async fn test_read_song() -> Result<()> {
        let db = init_test_database().await?;

        let song =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis = Analysis {
            id: Analysis::generate_id(),
            features: [0.; 20],
        };

        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        let result = Analysis::read_song(&db, analysis.id.clone()).await?;
        assert_eq!(result, song);

        Ok(())
    }

    #[tokio::test]
    async fn test_read_songs() -> Result<()> {
        let db = init_test_database().await?;

        let song1 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;
        let song2 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis1 = Analysis {
            id: Analysis::generate_id(),
            features: [0.; 20],
        };
        let analysis2 = Analysis {
            id: Analysis::generate_id(),
            features: [1.; 20],
        };

        let result = Analysis::create(&db, song1.id.clone(), analysis1.clone()).await?;
        assert_eq!(result, Some(analysis1.clone()));
        let result = Analysis::create(&db, song2.id.clone(), analysis2.clone()).await?;
        assert_eq!(result, Some(analysis2.clone()));

        let result = Analysis::read_songs(
            &db,
            OneOrMany::Many(vec![analysis1.id.clone(), analysis2.id.clone()]),
        )
        .await?;
        assert_eq!(result, OneOrMany::Many(vec![song1, song2]));

        Ok(())
    }

    #[tokio::test]
    async fn test_read_songs_without_analysis() -> Result<()> {
        let db = init_test_database().await?;

        let song1 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;
        let song2 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let result = Analysis::read_songs_without_analysis(&db).await?;
        assert_eq!(result.len(), 2);
        assert!(result.contains(&song1));
        assert!(result.contains(&song2));

        let analysis1 = Analysis {
            id: Analysis::generate_id(),
            features: [0.; 20],
        };
        let analysis2 = Analysis {
            id: Analysis::generate_id(),
            features: [0.; 20],
        };

        let result = Analysis::create(&db, song1.id.clone(), analysis1.clone()).await?;
        assert_eq!(result, Some(analysis1.clone()));

        let result = Analysis::read_songs_without_analysis(&db).await?;
        assert_eq!(result, vec![song2.clone()]);

        let result = Analysis::create(&db, song2.id.clone(), analysis2.clone()).await?;
        assert_eq!(result, Some(analysis2.clone()));

        let result = Analysis::read_songs_without_analysis(&db).await?;
        assert_eq!(result, vec![]);

        Ok(())
    }

    #[tokio::test]
    async fn test_delete() -> Result<()> {
        let db = init_test_database().await?;

        let song =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis = Analysis {
            id: Analysis::generate_id(),
            features: [0.; 20],
        };

        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        let result = Analysis::delete(&db, analysis.id.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        let result = Analysis::read(&db, analysis.id.clone()).await?;
        assert_eq!(result, None);

        let result = Analysis::read_for_song(&db, song.id.clone()).await?;
        assert_eq!(result, None);

        Ok(())
    }

    #[tokio::test]
    async fn test_nearest_neighbors() -> Result<()> {
        let db = init_test_database().await?;

        let song1 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;
        let song2 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;
        let song3 =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis1 = Analysis {
            id: Analysis::generate_id(),
            features: [0.; 20],
        };
        let analysis2 = Analysis {
            id: Analysis::generate_id(),
            features: [0.; 20],
        };
        let analysis3 = Analysis {
            id: Analysis::generate_id(),
            features: [1.; 20],
        };

        let result1 = Analysis::create(&db, song1.id.clone(), analysis1.clone()).await?;
        assert_eq!(result1, Some(analysis1.clone()));
        let result2 = Analysis::create(&db, song2.id.clone(), analysis2.clone()).await?;
        assert_eq!(result2, Some(analysis2.clone()));
        let result3 = Analysis::create(&db, song3.id.clone(), analysis3.clone()).await?;
        assert_eq!(result3, Some(analysis3.clone()));

        let result = Analysis::nearest_neighbors(&db, analysis1.id, 1).await?;
        assert_eq!(result, vec![analysis2.clone()]);

        Ok(())
    }

    /// `nearest_neighbors` must report `NotFound` when the target analysis does not exist.
    #[tokio::test]
    async fn test_nearest_neighbors_not_found() -> Result<()> {
        let db = init_test_database().await?;

        // Nothing has been created, so a freshly generated id cannot resolve.
        let result = Analysis::nearest_neighbors(&db, Analysis::generate_id(), 1).await;
        assert!(matches!(result, Err(Error::NotFound)));

        Ok(())
    }

    /// `nearest_neighbors_to_many` must fail with `NotFound` if any one of the ids is missing,
    /// even when the others resolve.
    #[tokio::test]
    async fn test_nearest_neighbors_to_many_not_found() -> Result<()> {
        let db = init_test_database().await?;

        let song =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis = Analysis {
            id: Analysis::generate_id(),
            features: [0.; 20],
        };

        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        let result = Analysis::nearest_neighbors_to_many(
            &db,
            vec![analysis.id.clone(), Analysis::generate_id()],
            1,
        )
        .await;
        assert!(matches!(result, Err(Error::NotFound)));

        Ok(())
    }

    #[tokio::test]
    async fn test_analysis_deleted_when_song_deleted() -> Result<()> {
        let db = init_test_database().await?;

        let song =
            create_song_with_overrides(&db, arb_song_case()(), SongChangeSet::default()).await?;

        let analysis = Analysis {
            id: Analysis::generate_id(),
            features: [0.; 20],
        };

        let result = Analysis::create(&db, song.id.clone(), analysis.clone()).await?;
        assert_eq!(result, Some(analysis.clone()));

        let result = Song::delete(&db, song.id.clone()).await?;
        assert_eq!(result, Some(song.clone()));

        let result = Song::read(&db, song.id.clone()).await?;
        assert_eq!(result, None);

        let result = Analysis::read(&db, analysis.id.clone()).await?;
        assert_eq!(result, None);

        let result = Analysis::read_for_song(&db, song.id.clone()).await?;
        assert_eq!(result, None);

        let result = Analysis::read_songs_without_analysis(&db).await?;
        assert_eq!(result, vec![]);

        let result = Analysis::read_song(&db, analysis.id.clone()).await;
        assert!(matches!(result, Err(Error::NotFound)));

        Ok(())
    }
}