1use crate::{BlissError, BlissResult, Song, NUMBER_FEATURES};
11use extended_isolation_forest::{Forest, ForestOptions};
12use ndarray::{Array, Array1, Array2, Axis};
13use ndarray_stats::QuantileExt;
14use noisy_float::prelude::*;
15use std::collections::HashMap;
16
/// Builds a [`DistanceMetric`] from a set of reference feature vectors.
///
/// This module implements it for any `Fn(&Array1<f32>, &Array1<f32>) -> f32 + 'static`
/// (for instance [`euclidean_distance`] or [`cosine_distance`]) and for
/// [`ForestOptions`].
pub trait DistanceMetricBuilder {
    /// Builds a metric that measures distances from `vectors`.
    ///
    /// The returned metric may borrow `self` (for `'a`), but not `vectors`.
    fn build<'a>(&'a self, vectors: &[Array1<f32>]) -> Box<dyn DistanceMetric + 'a>;
}
29
/// A distance between a fixed reference set (the vectors the metric was built
/// from) and a candidate feature vector.
pub trait DistanceMetric {
    /// Distance from the reference set to `vector`.
    ///
    /// The callers in this module treat smaller values as "closer": they sort
    /// ascending or take the minimum.
    fn distance(&self, vector: &Array1<f32>) -> f32;
}
35
/// [`DistanceMetric`] backed by a plain pairwise distance function.
///
/// The distance to a candidate is the sum of `func(reference, candidate)` over
/// every stored reference vector.
pub struct FunctionDistanceMetric<'a, F: Fn(&Array1<f32>, &Array1<f32>) -> f32> {
    /// The pairwise distance function, borrowed from the builder.
    func: &'a F,
    /// Owned copies of the reference vectors the metric was built from.
    state: Vec<Array1<f32>>,
}
41
42impl<F> DistanceMetricBuilder for F
43where
44 F: Fn(&Array1<f32>, &Array1<f32>) -> f32 + 'static,
45{
46 fn build<'a>(&'a self, vectors: &[Array1<f32>]) -> Box<dyn DistanceMetric + 'a> {
47 Box::new(FunctionDistanceMetric {
48 func: self,
49 state: vectors.iter().map(|s| s.to_owned()).collect(),
50 })
51 }
52}
53
54impl<F: Fn(&Array1<f32>, &Array1<f32>) -> f32 + 'static> DistanceMetric
55 for FunctionDistanceMetric<'_, F>
56{
57 fn distance(&self, vector: &Array1<f32>) -> f32 {
58 self.state.iter().map(|v| (self.func)(v, vector)).sum()
59 }
60}
61
62pub fn euclidean_distance(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
66 let m = Array::eye(a.len());
70 (a - b).dot(&m).dot(&(a - b)).sqrt()
71}
72
73pub fn cosine_distance(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
77 let similarity = a.dot(b) / (a.dot(a).sqrt() * b.dot(b).sqrt());
78 1. - similarity
79}
80
81pub fn mahalanobis_distance_builder(m: Array2<f32>) -> impl Fn(&Array1<f32>, &Array1<f32>) -> f32 {
130 move |a: &Array1<f32>, b: &Array1<f32>| mahalanobis_distance(a, b, &m)
131}
132
133pub fn mahalanobis_distance(a: &Array1<f32>, b: &Array1<f32>, m: &Array2<f32>) -> f32 {
141 (a - b).dot(m).dot(&(a - b)).sqrt()
142}
143
144fn feature_array1_to_array(f: &Array1<f32>) -> [f32; NUMBER_FEATURES] {
145 f.as_slice()
146 .expect("Couldn't convert feature vector to slice")
147 .try_into()
148 .expect("Couldn't convert slice to array")
149}
150
151impl DistanceMetricBuilder for ForestOptions {
152 fn build(&self, vectors: &[Array1<f32>]) -> Box<dyn DistanceMetric> {
153 let a = &*vectors
154 .iter()
155 .map(feature_array1_to_array)
156 .collect::<Vec<_>>();
157
158 if self.sample_size > vectors.len() {
159 let mut opts = self.clone();
160 opts.sample_size = self.sample_size.min(vectors.len());
161 Box::new(Forest::from_slice(a, &opts).unwrap())
162 } else {
163 Box::new(Forest::from_slice(a, self).unwrap())
164 }
165 }
166}
167
168impl DistanceMetric for Forest<f32, NUMBER_FEATURES> {
169 fn distance(&self, vector: &Array1<f32>) -> f32 {
170 self.score(&feature_array1_to_array(vector)) as f32
171 }
172}
173
174pub fn closest_to_songs<'a, T: AsRef<Song> + Clone + 'a>(
178 initial_songs: &[T],
179 candidate_songs: &[T],
180 metric_builder: &'a dyn DistanceMetricBuilder,
181) -> impl Iterator<Item = T> + 'a {
182 let initial_songs = initial_songs
183 .iter()
184 .map(|c| c.as_ref().analysis.as_arr1())
185 .collect::<Vec<_>>();
186 let metric = metric_builder.build(&initial_songs);
187 let mut candidate_songs = candidate_songs.to_vec();
188 candidate_songs
189 .sort_by_cached_key(|song| n32(metric.distance(&song.as_ref().analysis.as_arr1())));
190 candidate_songs.into_iter()
191}
192
/// Iterator behind [`song_to_song`].
///
/// Each step yields the song of `pool` closest to `vectors`, then makes that
/// song the only reference vector.
struct SongToSongIterator<'a, T: AsRef<Song> + Clone> {
    /// Songs that have not been yielded yet.
    pool: Vec<T>,
    /// Reference feature vectors the next pick is measured against.
    vectors: Vec<Array1<f32>>,
    /// Used to build a fresh metric from `vectors` at every step.
    metric_builder: &'a dyn DistanceMetricBuilder,
}
198
199impl<T: AsRef<Song> + Clone> Iterator for SongToSongIterator<'_, T> {
200 type Item = T;
201
202 fn next(&mut self) -> Option<T> {
203 if self.pool.is_empty() {
204 return None;
205 }
206 let metric = self.metric_builder.build(&self.vectors);
207 let distances: Array1<f32> = Array::from_shape_fn(self.pool.len(), |j| {
208 metric.distance(&self.pool[j].as_ref().analysis.as_arr1())
209 });
210 let idx = distances.argmin().unwrap();
211 self.vectors.clear();
215 let song = self.pool.remove(idx);
216 self.vectors.push(song.as_ref().analysis.as_arr1());
217 Some(song)
218 }
219}
220
221pub fn song_to_song<'a, T: AsRef<Song> + Clone + 'a>(
231 initial_songs: &[T],
232 candidate_songs: &[T],
233 metric_builder: &'a dyn DistanceMetricBuilder,
234) -> impl Iterator<Item = T> + 'a {
235 let vectors = initial_songs
236 .iter()
237 .map(|s| s.as_ref().analysis.as_arr1())
238 .collect::<Vec<_>>();
239 let pool = candidate_songs.to_vec();
240 let iterator = SongToSongIterator {
241 vectors,
242 metric_builder,
243 pool,
244 };
245 iterator.into_iter()
246}
247
248pub fn dedup_playlist<'a, T: AsRef<Song>>(
260 playlist: impl Iterator<Item = T> + 'a,
261 distance_threshold: Option<f32>,
262) -> impl Iterator<Item = T> + 'a {
263 dedup_playlist_custom_distance(playlist, distance_threshold, &euclidean_distance)
264}
265
266pub fn dedup_playlist_custom_distance<'a, T: AsRef<Song>>(
280 playlist: impl Iterator<Item = T> + 'a,
281 distance_threshold: Option<f32>,
282 metric_builder: &'a dyn DistanceMetricBuilder,
283) -> impl Iterator<Item = T> + 'a {
284 let mut peekable = playlist.peekable();
285 let final_iterator = std::iter::from_fn(move || {
286 if let Some(s1) = peekable.next() {
287 loop {
288 if let Some(s2) = peekable.peek() {
289 let s1_ref = s1.as_ref();
290 let s2_ref = s2.as_ref();
291 let vector = [s1_ref.analysis.as_arr1()];
292 let metric = metric_builder.build(&vector);
293 let is_same = n32(metric.distance(&s2_ref.analysis.as_arr1()))
294 < distance_threshold.unwrap_or(0.05)
295 || (s1_ref.title.is_some()
296 && s2_ref.title.is_some()
297 && s1_ref.artist.is_some()
298 && s2_ref.artist.is_some()
299 && s1_ref.title == s2_ref.title
300 && s1_ref.artist == s2_ref.artist);
301 if is_same {
302 peekable.next();
303 continue;
304 } else {
305 return Some(s1);
306 }
307 }
308 return Some(s1);
309 }
310 }
311 None
312 });
313 final_iterator
314}
315
316pub fn closest_album_to_group<T: AsRef<Song> + Clone>(
337 group: Vec<T>,
338 pool: Vec<T>,
339) -> BlissResult<Vec<T>> {
340 let mut albums_analysis: HashMap<&str, Array2<f32>> = HashMap::new();
341 let mut albums = Vec::new();
342
343 let pool = pool
345 .into_iter()
346 .filter(|s| !group.iter().any(|gs| gs.as_ref() == s.as_ref()))
347 .collect::<Vec<_>>();
348 for song in &pool {
349 if let Some(album) = &song.as_ref().album {
350 if let Some(analysis) = albums_analysis.get_mut(album as &str) {
351 analysis
352 .push_row(song.as_ref().analysis.as_arr1().view())
353 .map_err(|e| {
354 BlissError::ProviderError(format!("while computing distances: {e}"))
355 })?;
356 } else {
357 let mut array = Array::zeros((1, song.as_ref().analysis.as_arr1().len()));
358 array.assign(&song.as_ref().analysis.as_arr1());
359 albums_analysis.insert(album, array);
360 }
361 }
362 }
363 let number_features = group[0].as_ref().analysis.as_vec().len();
364 let mut group_analysis = Array::zeros((group.len(), number_features));
365 for (song, mut column) in group.iter().zip(group_analysis.axis_iter_mut(Axis(0))) {
366 column.assign(&song.as_ref().analysis.as_arr1());
367 }
368 let first_analysis = group_analysis
369 .mean_axis(Axis(0))
370 .ok_or_else(|| BlissError::ProviderError(String::from("Mean of empty slice")))?;
371 for (album, analysis) in albums_analysis.iter() {
372 let mean_analysis = analysis
373 .mean_axis(Axis(0))
374 .ok_or_else(|| BlissError::ProviderError(String::from("Mean of empty slice")))?;
375 let album = album.to_owned();
376 albums.push((album, mean_analysis.to_owned()));
377 }
378
379 albums.sort_by_key(|(_, analysis)| n32(euclidean_distance(&first_analysis, analysis)));
380 let mut playlist = group;
381 for (album, _) in albums {
382 let mut al = pool
383 .iter()
384 .filter(|s| s.as_ref().album.as_deref() == Some(album))
385 .cloned()
386 .collect::<Vec<T>>();
387 al.sort_by(|s1, s2| {
388 let track_number1 = s1.as_ref().track_number.to_owned();
389 let track_number2 = s2.as_ref().track_number.to_owned();
390 let disc_number1 = s1.as_ref().disc_number.to_owned();
391 let disc_number2 = s2.as_ref().disc_number.to_owned();
392 (disc_number1, track_number1).cmp(&(disc_number2, track_number2))
393 });
394 playlist.extend(al);
395 }
396 Ok(playlist)
397}
398
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Analysis, FeaturesVersion};
    use ndarray::arr1;
    use std::path::Path;

    #[derive(Debug, Clone, PartialEq)]
    struct CustomSong {
        something: bool,
        bliss_song: Song,
    }

    impl AsRef<Song> for CustomSong {
        fn as_ref(&self) -> &Song {
            &self.bliss_song
        }
    }

    /// Wraps `bliss_song` in a `CustomSong`, to check that the playlist functions
    /// work on any `T: AsRef<Song>`, not only on `Song`.
    fn custom_song(bliss_song: Song) -> CustomSong {
        CustomSong {
            something: true,
            bliss_song,
        }
    }

    /// Expands `(value, count)` runs into a feature vector, e.g.
    /// `[(2., 2), (1., 1)]` gives `[2., 2., 1.]`.
    fn feature_runs(runs: &[(f32, usize)]) -> Vec<f32> {
        runs.iter()
            .flat_map(|&(value, count)| std::iter::repeat(value).take(count))
            .collect()
    }

    /// Builds a song located at `path` whose analysis is `features`, using the
    /// latest features version.
    fn new_song(path: &str, features: Vec<f32>) -> Song {
        Song {
            path: Path::new(path).to_path_buf(),
            analysis: Analysis::new(features, FeaturesVersion::LATEST).unwrap(),
            ..Default::default()
        }
    }

    /// Builds a song located at `path` whose `version.feature_count()` features
    /// all equal `value`.
    fn uniform_song(path: &str, value: f32, version: FeaturesVersion) -> Song {
        Song {
            path: Path::new(path).to_path_buf(),
            analysis: Analysis::new(vec![value; version.feature_count()], version).unwrap(),
            ..Default::default()
        }
    }

    #[test]
    fn test_dedup_playlist_custom_distance() {
        let first_song = new_song("path-to-first", feature_runs(&[(1., 23)]));
        let first_song_dupe = new_song("path-to-dupe", feature_runs(&[(1., 23)]));
        let second_song = Song {
            title: Some(String::from("dupe-title")),
            artist: Some(String::from("dupe-artist")),
            ..new_song(
                "path-to-second",
                feature_runs(&[(2., 16), (1.9, 1), (1., 6)]),
            )
        };
        let third_song = Song {
            title: Some(String::from("dupe-title")),
            artist: Some(String::from("dupe-artist")),
            ..new_song(
                "path-to-third",
                feature_runs(&[(2., 16), (2.5, 1), (1., 6)]),
            )
        };
        let fourth_song = Song {
            artist: Some(String::from("no-dupe-artist")),
            title: Some(String::from("dupe-title")),
            ..new_song(
                "path-to-fourth",
                feature_runs(&[(2., 16), (0., 1), (1., 6)]),
            )
        };
        let fifth_song = new_song(
            "path-to-fourth",
            feature_runs(&[(2., 16), (0.001, 1), (1., 6)]),
        );

        let playlist = vec![
            first_song.clone(),
            first_song_dupe.clone(),
            second_song.clone(),
            third_song.clone(),
            fourth_song.clone(),
            fifth_song.clone(),
        ];
        let default_dedup = vec![first_song.clone(), second_song.clone(), fourth_song.clone()];

        let deduped =
            dedup_playlist_custom_distance(playlist.clone().into_iter(), None, &euclidean_distance)
                .collect::<Vec<_>>();
        assert_eq!(deduped, default_dedup);
        let deduped = dedup_playlist_custom_distance(
            playlist.clone().into_iter(),
            Some(20.),
            &euclidean_distance,
        )
        .collect::<Vec<_>>();
        assert_eq!(deduped, vec![first_song.clone()]);
        let deduped = dedup_playlist(playlist.clone().into_iter(), Some(20.)).collect::<Vec<_>>();
        assert_eq!(deduped, vec![first_song.clone()]);
        let deduped = dedup_playlist(playlist.into_iter(), None).collect::<Vec<_>>();
        assert_eq!(deduped, default_dedup);

        // Same checks, through a type that only implements `AsRef<Song>`.
        let first_song = custom_song(first_song);
        let first_song_dupe = custom_song(first_song_dupe);
        let second_song = custom_song(second_song);
        let third_song = custom_song(third_song);
        let fourth_song = custom_song(fourth_song);
        let fifth_song = custom_song(fifth_song);

        let playlist = vec![
            first_song.clone(),
            first_song_dupe.clone(),
            second_song.clone(),
            third_song.clone(),
            fourth_song.clone(),
            fifth_song.clone(),
        ];
        let default_dedup = vec![first_song.clone(), second_song.clone(), fourth_song.clone()];

        let deduped =
            dedup_playlist_custom_distance(playlist.clone().into_iter(), None, &euclidean_distance)
                .collect::<Vec<_>>();
        assert_eq!(deduped, default_dedup);
        let deduped = dedup_playlist_custom_distance(
            playlist.clone().into_iter(),
            Some(20.),
            &cosine_distance,
        )
        .collect::<Vec<_>>();
        assert_eq!(deduped, vec![first_song.clone()]);
        let deduped = dedup_playlist(playlist.clone().into_iter(), Some(20.)).collect::<Vec<_>>();
        assert_eq!(deduped, vec![first_song.clone()]);
        let deduped = dedup_playlist(playlist.into_iter(), None).collect::<Vec<_>>();
        assert_eq!(deduped, default_dedup);
    }

    #[test]
    fn test_song_to_song() {
        let first_song = new_song("path-to-first", feature_runs(&[(1., 23)]));
        let first_song_dupe = new_song("path-to-dupe", feature_runs(&[(1., 23)]));
        let second_song = new_song(
            "path-to-second",
            feature_runs(&[(2., 16), (1.9, 1), (1., 6)]),
        );
        let third_song = new_song(
            "path-to-third",
            feature_runs(&[(2., 16), (2.5, 1), (1., 6)]),
        );
        let fourth_song = new_song(
            "path-to-fourth",
            feature_runs(&[(2., 16), (0., 1), (1., 6)]),
        );

        let songs = vec![
            &first_song,
            &third_song,
            &first_song_dupe,
            &second_song,
            &fourth_song,
        ];
        let songs =
            song_to_song(&[&first_song], &songs, &euclidean_distance).collect::<Vec<_>>();
        assert_eq!(
            songs,
            vec![
                &first_song,
                &first_song_dupe,
                &second_song,
                &third_song,
                &fourth_song,
            ],
        );

        let first_song = custom_song(first_song);
        let second_song = custom_song(second_song);
        let first_song_dupe = custom_song(first_song_dupe);
        let third_song = custom_song(third_song);
        let fourth_song = custom_song(fourth_song);

        let songs: Vec<&CustomSong> = vec![
            &first_song,
            &first_song_dupe,
            &third_song,
            &fourth_song,
            &second_song,
        ];
        let songs =
            song_to_song(&[&first_song], &songs, &euclidean_distance).collect::<Vec<_>>();
        assert_eq!(
            songs,
            vec![
                &first_song,
                &first_song_dupe,
                &second_song,
                &third_song,
                &fourth_song,
            ],
        );
    }

    #[test]
    fn test_sort_closest_to_songs() {
        let first_song = new_song("path-to-first", feature_runs(&[(1., 23)]));
        let first_song_dupe = new_song("path-to-dupe", feature_runs(&[(1., 23)]));
        let second_song = new_song(
            "path-to-second",
            feature_runs(&[(2., 16), (1.9, 1), (1., 6)]),
        );
        let third_song = new_song(
            "path-to-third",
            feature_runs(&[(2., 16), (2.5, 1), (1., 6)]),
        );
        let fourth_song = new_song(
            "path-to-fourth",
            feature_runs(&[(2., 16), (0., 1), (1., 6)]),
        );
        let fifth_song = new_song(
            "path-to-fifth",
            feature_runs(&[(2., 16), (0., 1), (1., 6)]),
        );

        let songs = [
            &fifth_song,
            &fourth_song,
            &first_song,
            &first_song_dupe,
            &second_song,
            &third_song,
        ];
        let playlist: Vec<_> =
            closest_to_songs(&[&first_song], &songs, &euclidean_distance).collect();
        assert_eq!(
            playlist,
            [
                &first_song,
                &first_song_dupe,
                &second_song,
                &fifth_song,
                &fourth_song,
                &third_song
            ],
        );

        let first_song = custom_song(first_song);
        let second_song = custom_song(second_song);
        let first_song_dupe = custom_song(first_song_dupe);
        let third_song = custom_song(third_song);
        let fourth_song = custom_song(fourth_song);
        let fifth_song = custom_song(fifth_song);

        let songs = [
            &second_song,
            &first_song,
            &fourth_song,
            &first_song_dupe,
            &third_song,
            &fifth_song,
        ];
        let playlist: Vec<_> =
            closest_to_songs(&[&first_song], &songs, &euclidean_distance).collect();
        assert_eq!(
            playlist,
            [
                &first_song,
                &first_song_dupe,
                &second_song,
                &fourth_song,
                &fifth_song,
                &third_song
            ],
        );
    }

    #[test]
    fn test_mahalanobis_distance() {
        let a = arr1(&feature_runs(&[(1., 19), (0., 1)]));
        let b = arr1(&feature_runs(&[(1., 1), (0., 15), (1., 1), (0., 3)]));
        let m = Array2::eye(FeaturesVersion::Version1.feature_count())
            * arr1(&feature_runs(&[(1., 2), (0., 18)]));

        let distance = mahalanobis_distance_builder(m);
        assert_eq!(distance(&a, &b), 1.);
    }

    #[test]
    fn test_mahalanobis_distance_with_songs() {
        let first_song = new_song("path-to-first", feature_runs(&[(1., 23)]));
        let second_song = new_song(
            "path-to-second",
            feature_runs(&[
                (1.5, 1),
                (5., 1),
                (6., 1),
                (5., 1),
                (6., 1),
                (6., 1),
                (1., 17),
            ]),
        );
        let third_song = new_song("path-to-third", feature_runs(&[(5., 1), (1., 22)]));
        let m = Array2::eye(NUMBER_FEATURES) * arr1(&feature_runs(&[(1.0, 1), (0., 22)]));
        let distance = mahalanobis_distance_builder(m);

        let playlist = closest_to_songs(
            &[first_song.clone()],
            &[third_song.clone(), second_song.clone()],
            &distance,
        )
        .collect::<Vec<_>>();
        assert_eq!(playlist, vec![second_song, third_song]);
    }

    #[test]
    fn test_euclidean_distance() {
        let a = arr1(&feature_runs(&[(1., 19), (0., 1)]));
        let b = arr1(&feature_runs(&[(0., 16), (1., 1), (0., 3)]));
        assert_eq!(euclidean_distance(&a, &b), 4.242640687119285);

        let a = arr1(&[0.5; 20]);
        let b = arr1(&[0.5; 20]);
        assert_eq!(euclidean_distance(&a, &b), 0.);
    }

    #[test]
    fn test_cosine_distance() {
        let a = arr1(&feature_runs(&[(1., 19), (0., 1)]));
        let b = arr1(&feature_runs(&[(0., 16), (1., 1), (0., 3)]));
        assert_eq!(cosine_distance(&a, &b), 0.7705842661294382);

        let a = arr1(&[0.5; 20]);
        let b = arr1(&[0.5; 20]);
        assert_eq!(cosine_distance(&a, &b), 0.);
    }

    #[test]
    fn test_closest_to_group() {
        for version in [FeaturesVersion::Version1, FeaturesVersion::Version2] {
            let first_song = Song {
                album: Some(String::from("Album")),
                artist: Some(String::from("Artist")),
                track_number: Some(1),
                disc_number: Some(1),
                ..uniform_song("path-to-first", 0., version)
            };
            let second_song = Song {
                album: Some(String::from("Album")),
                artist: Some(String::from("Another Artist")),
                track_number: Some(2),
                disc_number: Some(1),
                ..uniform_song("path-to-third", 10., version)
            };
            let first_song_other_album_disc_1 = Song {
                album: Some(String::from("Another Album")),
                artist: Some(String::from("Artist")),
                track_number: Some(1),
                disc_number: Some(1),
                ..uniform_song("path-to-second-2", 0.15, version)
            };
            let second_song_other_album_disc_1 = Song {
                album: Some(String::from("Another Album")),
                artist: Some(String::from("Artist")),
                track_number: Some(2),
                disc_number: Some(1),
                ..uniform_song("path-to-second", 0.1, version)
            };
            let first_song_other_album_disc_2 = Song {
                album: Some(String::from("Another Album")),
                artist: Some(String::from("Another Artist")),
                track_number: Some(1),
                disc_number: Some(2),
                ..uniform_song("path-to-fourth", 20., version)
            };
            let second_song_other_album_disc_2 = Song {
                album: Some(String::from("Another Album")),
                artist: Some(String::from("Another Artist")),
                track_number: Some(4),
                disc_number: Some(2),
                ..uniform_song("path-to-fourth", 20., version)
            };
            let song_no_album = Song {
                artist: Some(String::from("Third Artist")),
                album: None,
                ..uniform_song("path-to-fifth", 40., version)
            };

            let pool = vec![
                first_song.clone(),
                second_song_other_album_disc_1.clone(),
                second_song_other_album_disc_2.clone(),
                second_song.clone(),
                first_song_other_album_disc_2.clone(),
                first_song_other_album_disc_1.clone(),
                song_no_album.clone(),
            ];
            let group = vec![first_song.clone(), second_song.clone()];
            assert_eq!(
                vec![
                    first_song.clone(),
                    second_song.clone(),
                    first_song_other_album_disc_1.clone(),
                    second_song_other_album_disc_1.clone(),
                    first_song_other_album_disc_2.clone(),
                    second_song_other_album_disc_2.clone(),
                ],
                closest_album_to_group(group, pool).unwrap(),
            );

            let first_song = custom_song(first_song);
            let second_song = custom_song(second_song);
            let first_song_other_album_disc_1 = custom_song(first_song_other_album_disc_1);
            let second_song_other_album_disc_1 = custom_song(second_song_other_album_disc_1);
            let first_song_other_album_disc_2 = custom_song(first_song_other_album_disc_2);
            let second_song_other_album_disc_2 = custom_song(second_song_other_album_disc_2);
            let song_no_album = custom_song(song_no_album);

            let pool = vec![
                first_song.clone(),
                second_song_other_album_disc_2.clone(),
                second_song_other_album_disc_1.clone(),
                second_song.clone(),
                first_song_other_album_disc_2.clone(),
                first_song_other_album_disc_1.clone(),
                song_no_album.clone(),
            ];
            let group = vec![first_song.clone(), second_song.clone()];
            assert_eq!(
                vec![
                    first_song.clone(),
                    second_song.clone(),
                    first_song_other_album_disc_1.clone(),
                    second_song_other_album_disc_1.clone(),
                    first_song_other_album_disc_2.clone(),
                    second_song_other_album_disc_2.clone(),
                ],
                closest_album_to_group(group, pool).unwrap(),
            );
        }
    }

    #[test]
    fn test_forest_options() {
        let mozart_piano_19 = [
            new_song(
                "path-to-first",
                vec![
                    0.5522649, -0.8664422, -0.81236243, -0.9475107, -0.76129013, -0.90520144,
                    -0.8474938, -0.8924977, 0.4956385, 0.5076021, -0.5037869, -0.61038315,
                    -0.47157913, -0.48194122, -0.36397678, -0.6443357, -0.9713509, -0.9781786,
                    -0.98285836, -0.983834, -0.983834, -0.983834, -0.983834,
                ],
            ),
            new_song(
                "path-to-second",
                vec![
                    0.28091776, -0.86352056, -0.8175835, -0.9497457, -0.77833027, -0.91656536,
                    -0.8477104, -0.889485, 0.41879785, 0.45311546, -0.6252063, -0.6838323,
                    -0.5326821, -0.63320035, -0.5573063, -0.7433087, -0.9815542, -0.98570454,
                    -0.98824924, -0.9903612, -0.9903612, -0.9903612, -0.9903612,
                ],
            ),
            new_song(
                "path-to-third",
                vec![
                    0.5978223, -0.84076107, -0.7841455, -0.886415, -0.72486377, -0.8015111,
                    -0.79157853, -0.7739525, 0.517207, 0.535398, -0.30007458, -0.3972137,
                    -0.41319674, -0.40709, -0.32283908, -0.5261506, -0.9656949, -0.9715169,
                    -0.97524375, -0.9756616, -0.9756616, -0.9756616, -0.9756616,
                ],
            ),
        ];

        let kind_of_blue = [
            new_song(
                "path-to-fourth",
                vec![
                    0.35871255, -0.8679545, -0.6833263, -0.87800264, -0.7235142, -0.73546195,
                    -0.48577756, -0.7732977, 0.51237035, 0.5379869, -0.00649637, -0.534671,
                    -0.5743973, -0.5706258, -0.43162197, -0.6356183, -0.97918683, -0.98091763,
                    -0.9845511, -0.98359185, -0.98359185, -0.98359185, -0.98359185,
                ],
            ),
            new_song(
                "path-to-fifth",
                vec![
                    0.2806753, -0.85013694, -0.66921043, -0.8938313, -0.6848732, -0.75377,
                    -0.48747814, -0.793482, 0.44880342, 0.461563, -0.115760505, -0.535959,
                    -0.5749081, -0.55055845, -0.37976396, -0.538705, -0.97972554, -0.97890633,
                    -0.98290455, -0.98231846, -0.98231846, -0.98231846, -0.98231846,
                ],
            ),
            new_song(
                "path-to-sixth",
                vec![
                    0.1545173, -0.8991263, -0.79770947, -0.87425447, -0.77811325, -0.71051484,
                    -0.7369138, -0.8515074, 0.387398, 0.42035806, -0.30229717, -0.624056,
                    -0.6458885, -0.66208386, -0.5866134, -0.7613628, -0.98656195, -0.98821944,
                    -0.99072844, -0.98729765, -0.98729765, -0.98729765, -0.98729765,
                ],
            ),
            new_song(
                "path-to-seventh",
                vec![
                    0.3853314, -0.8475499, -0.64330614, -0.85917395, -0.6624141, -0.6356613,
                    -0.40988427, -0.7480691, 0.45981812, 0.47096932, -0.19245929, -0.5228787,
                    -0.42246288, -0.52656835, -0.45702273, -0.569838, -0.97620565, -0.97741324,
                    -0.97741324, -0.97741324, -0.97741324, -0.9776932, -0.98088175,
                ],
            ),
            new_song(
                "path-to-eight",
                vec![
                    0.18926656, -0.86667925, -0.7294189, -0.856192, -0.7180501, -0.66697484,
                    -0.6093149, -0.82118326, 0.3888924, 0.42430043, -0.4414854, -0.6957753,
                    -0.7092425, -0.68237424, -0.55543846, -0.77678657, -0.98610276, -0.98707336,
                    -0.99165493, -0.99011236, -0.99011236, -0.99011236, -0.99011236,
                ],
            ),
        ];

        let mozart_piano_23 = [
            new_song(
                "path-to-ninth",
                vec![
                    0.38328362, -0.8752751, -0.8165319, -0.948534, -0.77668643, -0.9051969,
                    -0.8473458, -0.88643366, 0.49641085, 0.5132351, -0.41367024, -0.5279201,
                    -0.46787983, -0.49218357, -0.42164963, -0.6597451, -0.97317076, -0.9800342,
                    -0.9832096, -0.98385316, -0.98385316, -0.98385316, -0.98385316,
                ],
            ),
            new_song(
                "path-to-tenth",
                vec![
                    0.4301988, -0.89864063, -0.84993315, -0.9518692, -0.8329567, -0.9293889,
                    -0.8605237, -0.8901016, 0.35011983, 0.3822446, -0.6384951, -0.7537949,
                    -0.5867439, -0.57371, -0.5662942, -0.76130676, -0.9845436, -0.9833387,
                    -0.9902381, -0.9905396, -0.9905396, -0.9905396, -0.9905396,
                ],
            ),
            new_song(
                "path-to-eleventh",
                vec![
                    0.42334664, -0.8632808, -0.80268145, -0.91918564, -0.7522441, -0.8721291,
                    -0.81877685, -0.8166921, 0.53626525, 0.540933, -0.34771818, -0.45362264,
                    -0.35523874, -0.4072432, -0.25506926, -0.553644, -0.9624399, -0.9706371,
                    -0.9753268, -0.9764576, -0.9764576, -0.9764576, -0.9764576,
                ],
            ),
        ];

        let songs: Vec<&Song> = mozart_piano_19
            .iter()
            .chain(kind_of_blue.iter())
            .chain(mozart_piano_23.iter())
            .collect();

        let opts = ForestOptions {
            n_trees: 1000,
            sample_size: 200,
            max_tree_depth: None,
            extension_level: 10,
        };
        let playlist: Vec<_> = closest_to_songs(
            &mozart_piano_19.iter().collect::<Vec<&Song>>(),
            &songs,
            &opts,
        )
        .collect();
        for e in &kind_of_blue {
            assert!(playlist[playlist.len() - 5..].contains(&e));
        }
    }
}