Skip to main content

mecomp_core/
config.rs

//! Handles the configuration of the daemon.
//!
//! This module is responsible for parsing the `Mecomp.toml` config file, layering environment
//! variables and CLI overrides on top of it, and resolving the log level used to set up the logger.
5
6use config::{Config, ConfigError, Environment, File};
7use one_or_many::OneOrMany;
8use serde::Deserialize;
9
10use std::{num::NonZeroUsize, path::PathBuf, str::FromStr};
11
12use mecomp_storage::util::MetadataConflictResolution;
13
14pub static DEFAULT_CONFIG: &str = include_str!("../Mecomp.toml");
15
16#[derive(Clone, Debug, Deserialize, Default, PartialEq, Eq)]
17pub struct Settings {
18    /// General Daemon Settings
19    #[serde(default)]
20    pub daemon: DaemonSettings,
21    /// Settings for song analysis
22    #[serde(default)]
23    pub analysis: AnalysisSettings,
24    /// Parameters for the reclustering algorithm.
25    #[serde(default)]
26    pub reclustering: ReclusterSettings,
27    /// Settings for the TUI
28    #[serde(default)]
29    pub tui: TuiSettings,
30}
31
32impl Settings {
33    /// Load settings from the config file, environment variables, and CLI arguments.
34    ///
35    /// The config file is located at the path specified by the `--config` flag.
36    ///
37    /// The environment variables are prefixed with `MECOMP_`.
38    ///
39    /// # Arguments
40    ///
41    /// * `flags` - The parsed CLI arguments.
42    ///
43    /// # Errors
44    ///
45    /// This function will return an error if the config file is not found or if the config file is
46    /// invalid.
47    #[inline]
48    pub fn init(
49        config: PathBuf,
50        port: Option<u16>,
51        log_level: Option<log::LevelFilter>,
52    ) -> Result<Self, ConfigError> {
53        let s = Config::builder()
54            .add_source(File::from(config))
55            .add_source(Environment::with_prefix("MECOMP"))
56            .build()?;
57
58        let mut settings: Self = s.try_deserialize()?;
59
60        for path in &mut settings.daemon.library_paths {
61            *path = shellexpand::tilde(&path.to_string_lossy())
62                .into_owned()
63                .into();
64        }
65
66        if let Some(port) = port {
67            settings.daemon.rpc_port = port;
68        }
69
70        if let Some(log_level) = log_level {
71            settings.daemon.log_level = log_level;
72        }
73
74        Ok(settings)
75    }
76
77    /// Get the (default) path to the config file.
78    /// If the config file does not exist at this path, it will be created with the default config.
79    ///
80    /// See [`crate::get_config_dir`] for more information about where this default path is located.
81    ///
82    /// # Errors
83    ///
84    /// This function will return an error if the system config directory (e.g., `~/.config` on linux) could not be found, or if the config file was missing and could not be created.
85    #[inline]
86    pub fn get_config_path() -> Result<PathBuf, std::io::Error> {
87        match crate::get_config_dir() {
88            Ok(config_dir) => {
89                // if the config directory does not exist, create it
90                if !config_dir.exists() {
91                    std::fs::create_dir_all(&config_dir)?;
92                }
93                let config_file = config_dir.join("Mecomp.toml");
94
95                if !config_file.exists() {
96                    std::fs::write(&config_file, DEFAULT_CONFIG)?;
97                }
98
99                Ok(config_file)
100            }
101            Err(e) => {
102                eprintln!("Error: {e}");
103                Err(std::io::Error::new(
104                    std::io::ErrorKind::NotFound,
105                    "Unable to find the config directory for mecomp.",
106                ))
107            }
108        }
109    }
110}
111
112#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
113pub struct DaemonSettings {
114    /// The port to listen on for RPC requests.
115    /// Default is 6600.
116    #[serde(default = "default_port")]
117    pub rpc_port: u16,
118    /// The root paths of the music library.
119    #[serde(default = "default_library_paths")]
120    pub library_paths: Box<[PathBuf]>,
121    /// Separators for artist names in song metadata.
122    /// For example, "Foo, Bar, Baz" would be split into \["Foo", "Bar", "Baz"\]. if the separator is ", ".
123    /// If the separator is not found, the entire string is considered as a single artist.
124    /// If unset, will not split artists.
125    ///
126    /// Users can provide one or many separators, and must provide them as either a single string or an array of strings.
127    ///
128    /// ```toml
129    /// [daemon]
130    /// artist_separator = " & "
131    /// artist_separator = [" & ", "; "]
132    /// ...
133    /// ```
134    #[serde(default, deserialize_with = "de_artist_separator")]
135    pub artist_separator: OneOrMany<String>,
136    /// Exceptions for artist name separation, for example:
137    /// "Foo & Bar; Baz" would be split into \["Foo", "Bar", "Baz"\] if the separators are set to "&" and "; ".
138    ///
139    /// However, if the following exception is set:
140    /// ```toml
141    /// [daemon]
142    /// protected_artist_names = ["Foo & Bar"]
143    /// ```
144    /// Then the artist "Foo & Bar; Baz" would be split into \["Foo & Bar", "Baz"\].
145    ///
146    /// Note that the exception applies to the entire "name", so:
147    /// ```toml
148    /// [daemon]
149    /// protected_artist_names = ["Foo & Bar"]
150    /// ```
151    /// would split "Foo & Bar" into \["Foo & Bar"\],
152    /// but "Foo & Bar Baz" would still be split into \["Foo", "Bar Baz"\].
153    #[serde(default)]
154    pub protected_artist_names: OneOrMany<String>,
155    #[serde(default)]
156    pub genre_separator: Option<String>,
157    /// how conflicting metadata should be resolved
158    /// "overwrite" - overwrite the metadata with new metadata
159    /// "skip" - skip the file (keep old metadata)
160    #[serde(default)]
161    pub conflict_resolution: MetadataConflictResolution,
162    /// What level of logging to use.
163    /// Default is "info".
164    #[serde(default = "default_log_level")]
165    #[serde(deserialize_with = "de_log_level")]
166    pub log_level: log::LevelFilter,
167}
168
169fn de_artist_separator<'de, D>(deserializer: D) -> Result<OneOrMany<String>, D::Error>
170where
171    D: serde::Deserializer<'de>,
172{
173    let v = OneOrMany::<String>::deserialize(deserializer)?
174        .into_iter()
175        .filter(|s| !s.is_empty())
176        .collect::<OneOrMany<String>>();
177    if v.is_empty() {
178        Ok(OneOrMany::None)
179    } else {
180        Ok(v)
181    }
182}
183
184fn de_log_level<'de, D>(deserializer: D) -> Result<log::LevelFilter, D::Error>
185where
186    D: serde::Deserializer<'de>,
187{
188    let s = String::deserialize(deserializer)?;
189    Ok(log::LevelFilter::from_str(&s).unwrap_or_else(|_| default_log_level()))
190}
191
/// Default RPC port for the daemon (`daemon.rpc_port`).
const fn default_port() -> u16 {
    const PORT: u16 = 6600;
    PORT
}
195
196fn default_library_paths() -> Box<[PathBuf]> {
197    vec![shellexpand::tilde("~/Music/").into_owned().into()].into_boxed_slice()
198}
199
200const fn default_log_level() -> log::LevelFilter {
201    log::LevelFilter::Info
202}
203
204impl Default for DaemonSettings {
205    #[inline]
206    fn default() -> Self {
207        Self {
208            rpc_port: default_port(),
209            library_paths: default_library_paths(),
210            artist_separator: OneOrMany::None,
211            protected_artist_names: OneOrMany::None,
212            genre_separator: None,
213            conflict_resolution: MetadataConflictResolution::Overwrite,
214            log_level: default_log_level(),
215        }
216    }
217}
218
219#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Eq)]
220#[serde(rename_all = "lowercase")]
221pub enum AnalysisKind {
222    #[default]
223    Features,
224    Embedding,
225}
226
227#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
228pub struct AnalysisSettings {
229    /// The kind of analysis to perform, either "features" or "embedding".
230    /// "features" will compute traditional audio features (tempo, key, etc.)
231    /// "embedding" will compute neural audio embedding using a pre-trained model.
232    /// Default is "features".
233    ///
234    /// Note that regardless of this setting, both features and embedding will be computed during analysis.
235    /// This only determines the kind used for clustering, radio, and other such tasks
236    #[serde(default)]
237    pub kind: AnalysisKind,
238    /// The number of threads to use for analysis.
239    /// Default is the number of logical CPUs on the system.
240    ///
241    /// Note that:
242    /// - increasing this number may increase memory usage significantly during analysis.
243    /// - setting this number to more than the number of logical CPUs will have no effect (saturates at number of logical CPUs).
244    /// - leave this unset to use the default.
245    #[serde(default)]
246    pub num_threads: Option<NonZeroUsize>,
247    /// You can optionally override the model used for generating audio embeddings.
248    /// Requirements:
249    /// - The model must be in the ONNX format with opset version 16 or higher.
250    /// - The model should expect mono audio samples at a sample rate of 22,050 Hz.
251    /// - The input tensor must be name "audio" and have shape [B, N] where N a dynamic length corresponding to the number of audio samples in the song, and B is the batch size.
252    /// - The output tensor must be name "embedding" and have shape [B, 32] corresponding to a 32-dimensional embedding vector. B is the batch size.
253    ///
254    /// If unset, or a non-existent/invalid path, the built-in model (which is bundled into the daemon binary) will be used.
255    #[serde(default)]
256    pub model_path: Option<PathBuf>,
257}
258
259impl Default for AnalysisSettings {
260    #[inline]
261    fn default() -> Self {
262        Self {
263            kind: AnalysisKind::default(),
264            num_threads: None,
265            model_path: None,
266        }
267    }
268}
269
270#[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)]
271#[serde(rename_all = "lowercase")]
272pub enum ClusterAlgorithm {
273    KMeans,
274    #[default]
275    GMM,
276}
277
#[cfg(feature = "analysis")]
impl From<ClusterAlgorithm> for mecomp_analysis::clustering::ClusteringMethod {
    /// Map the config-level algorithm choice onto the analysis crate's clustering method.
    #[inline]
    fn from(algorithm: ClusterAlgorithm) -> Self {
        use ClusterAlgorithm as Algo;
        match algorithm {
            Algo::GMM => Self::GaussianMixtureModel,
            Algo::KMeans => Self::KMeans,
        }
    }
}
288
289#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Eq)]
290#[serde(rename_all = "lowercase")]
291pub enum ProjectionMethod {
292    #[default]
293    None,
294    TSne,
295    Pca,
296}
297
#[cfg(feature = "analysis")]
impl From<ProjectionMethod> for mecomp_analysis::clustering::ProjectionMethod {
    /// Map the config-level projection choice onto the analysis crate's projection method.
    #[inline]
    fn from(method: ProjectionMethod) -> Self {
        use ProjectionMethod as Configured;
        match method {
            Configured::Pca => Self::Pca,
            Configured::TSne => Self::TSne,
            Configured::None => Self::None,
        }
    }
}
309
310#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq)]
311pub struct ReclusterSettings {
312    /// The number of reference datasets to use for the gap statistic.
313    /// (which is used to determine the optimal number of clusters)
314    /// 50 will give a decent estimate but for the best results use more,
315    /// 500 will give a very good estimate but be very slow.
316    /// We default to 250 in release mode.
317    #[serde(default = "default_gap_statistic_reference_datasets")]
318    pub gap_statistic_reference_datasets: u32,
319    /// The maximum number of clusters to create.
320    /// This is the upper bound on the number of clusters that can be created.
321    /// Increase if you're getting a "could not find optimal k" error.
322    /// Default is 24.
323    #[serde(default = "default_max_clusters")]
324    pub max_clusters: usize,
325    /// The clustering algorithm to use.
326    /// Either "kmeans" or "gmm".
327    #[serde(default)]
328    pub algorithm: ClusterAlgorithm,
329    /// The projection method to preprocess the data with before clustering.
330    /// Either "tsne", "pca", or "none".
331    /// Default is "none".
332    #[serde(default)]
333    pub projection_method: ProjectionMethod,
334}
335
/// Default number of reference datasets used by the gap statistic.
const fn default_gap_statistic_reference_datasets() -> u32 {
    const REFERENCE_DATASETS: u32 = 50;
    REFERENCE_DATASETS
}
339
/// Default upper bound on the number of clusters: 16 in debug builds, 24 otherwise.
const fn default_max_clusters() -> usize {
    if cfg!(debug_assertions) {
        16
    } else {
        24
    }
}
346
347impl Default for ReclusterSettings {
348    #[inline]
349    fn default() -> Self {
350        Self {
351            gap_statistic_reference_datasets: default_gap_statistic_reference_datasets(),
352            max_clusters: default_max_clusters(),
353            algorithm: ClusterAlgorithm::default(),
354            projection_method: ProjectionMethod::default(),
355        }
356    }
357}
358
359#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
360pub struct TuiSettings {
361    /// How many songs should be queried for when starting a radio.
362    /// Default is 20.
363    #[serde(default = "default_radio_count")]
364    pub radio_count: u32,
365    /// The color scheme to use for the TUI.
366    /// Each color is either:
367    /// - a hex string in the format `#RRGGBB`.
368    ///   example: `#FFFFFF` for white.
369    /// - a material design color name in format "<COLOR>_<SHADE>".
370    ///   so "pink", `red-900`,  `light-blue_500`, `red900`, etc. are all invalid.
371    ///   but `PINK_900`, `RED_900`, `LIGHT_BLUE_500` are valid.
372    ///   - Exceptions are `WHITE` and `BLACK`, which are always valid.
373    #[serde(default)]
374    pub colors: TuiColorScheme,
375}
376
/// Default number of songs queried when starting a radio.
const fn default_radio_count() -> u32 {
    const RADIO_COUNT: u32 = 20;
    RADIO_COUNT
}
380
381impl Default for TuiSettings {
382    #[inline]
383    fn default() -> Self {
384        Self {
385            radio_count: default_radio_count(),
386            colors: TuiColorScheme::default(),
387        }
388    }
389}
390
391#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Default)]
392pub struct TuiColorScheme {
393    /// app border colors
394    pub app_border: Option<String>,
395    pub app_border_text: Option<String>,
396    /// border colors
397    pub border_unfocused: Option<String>,
398    pub border_focused: Option<String>,
399    /// popup border color
400    pub popup_border: Option<String>,
401    /// text colors
402    pub text_normal: Option<String>,
403    pub text_highlight: Option<String>,
404    pub text_highlight_alt: Option<String>,
405    /// gauge colors, such as song progress bar
406    pub gauge_filled: Option<String>,
407    pub gauge_unfilled: Option<String>,
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use rstest::rstest;

    /// Write `contents` to a `config.toml` inside a fresh temporary directory.
    ///
    /// The returned [`tempfile::TempDir`] must be kept alive while the path is in use, since
    /// dropping it deletes the directory.
    fn write_config(contents: &str) -> (tempfile::TempDir, PathBuf) {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(&config_path, contents).unwrap();
        (temp_dir, config_path)
    }

    #[rstest]
    #[case(Vec::<String>::new())]
    #[case("")]
    fn test_de_artist_separator_empty<'de, D>(#[case] input: D)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let deserializer = input.into_deserializer();
        let result: Result<OneOrMany<String>, _> = de_artist_separator(deserializer);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[rstest]
    #[case(vec![" & "], String::from(" & ").into())]
    #[case(" & ", String::from(" & ").into())]
    #[case(vec![" & ", "; "], vec![String::from(" & "), String::from("; ")].into())]
    #[case(vec!["", " & ", "", "; "], vec![String::from(" & "), String::from("; ")].into())]
    fn test_de_artist_separator<'de, D>(#[case] input: D, #[case] expected: OneOrMany<String>)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let deserializer = input.into_deserializer();
        let result: Result<OneOrMany<String>, _> = de_artist_separator(deserializer);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), expected);
    }

    #[test]
    fn test_init_config() {
        let (_temp_dir, config_path) = write_config(
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"

[tui]
radio_count = 21
[tui.colors]
app_border = "PINK_900"
app_border_text = "PINK_300"
border_unfocused = "RED_900"
border_focused = "RED_200"
popup_border = "LIGHT_BLUE_500"
text_normal = "WHITE"
text_highlight = "RED_600"
text_highlight_alt = "RED_200"
gauge_filled = "WHITE"
gauge_unfilled = "BLACK"
"#,
        );

        let expected = Settings {
            daemon: DaemonSettings {
                rpc_port: 6600,
                library_paths: ["/Music".into()].into(),
                artist_separator: vec!["; ".into()].into(),
                protected_artist_names: OneOrMany::None,
                genre_separator: Some(", ".into()),
                conflict_resolution: MetadataConflictResolution::Overwrite,
                log_level: log::LevelFilter::Debug,
            },
            analysis: AnalysisSettings::default(),
            reclustering: ReclusterSettings {
                gap_statistic_reference_datasets: 50,
                max_clusters: 24,
                algorithm: ClusterAlgorithm::GMM,
                projection_method: ProjectionMethod::None,
            },
            tui: TuiSettings {
                radio_count: 21,
                colors: TuiColorScheme {
                    app_border: Some("PINK_900".into()),
                    app_border_text: Some("PINK_300".into()),
                    border_unfocused: Some("RED_900".into()),
                    border_focused: Some("RED_200".into()),
                    popup_border: Some("LIGHT_BLUE_500".into()),
                    text_normal: Some("WHITE".into()),
                    text_highlight: Some("RED_600".into()),
                    text_highlight_alt: Some("RED_200".into()),
                    gauge_filled: Some("WHITE".into()),
                    gauge_unfilled: Some("BLACK".into()),
                },
            },
        };

        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings, expected);
    }

    /// Omitting `[tui.colors]` entirely yields the default (empty) color scheme.
    #[test]
    fn test_tui_colors_unset() {
        let (_temp_dir, config_path) = write_config(
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"

[tui]
radio_count = 21
"#,
        );

        let expected = Settings {
            daemon: DaemonSettings {
                rpc_port: 6600,
                library_paths: ["/Music".into()].into(),
                artist_separator: vec!["; ".into()].into(),
                protected_artist_names: OneOrMany::None,
                genre_separator: Some(", ".into()),
                conflict_resolution: MetadataConflictResolution::Overwrite,
                log_level: log::LevelFilter::Debug,
            },
            analysis: AnalysisSettings::default(),
            reclustering: ReclusterSettings {
                gap_statistic_reference_datasets: 50,
                max_clusters: 24,
                algorithm: ClusterAlgorithm::GMM,
                projection_method: ProjectionMethod::None,
            },
            tui: TuiSettings {
                radio_count: 21,
                colors: TuiColorScheme::default(),
            },
        };

        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings, expected);
    }

    /// Protected artist names are parsed alongside multiple artist separators.
    #[test]
    fn test_artist_names_to_not_split() {
        let (_temp_dir, config_path) = write_config(
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = [" & ", "; "]
protected_artist_names = ["Foo & Bar", "Baz & Qux"]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"

[tui]
radio_count = 21
"#,
        );

        let expected = Settings {
            daemon: DaemonSettings {
                rpc_port: 6600,
                library_paths: ["/Music".into()].into(),
                artist_separator: vec![String::from(" & "), String::from("; ")].into(),
                protected_artist_names: vec![
                    String::from("Foo & Bar"),
                    String::from("Baz & Qux"),
                ]
                .into(),
                genre_separator: Some(", ".into()),
                conflict_resolution: MetadataConflictResolution::Overwrite,
                log_level: log::LevelFilter::Debug,
            },
            analysis: AnalysisSettings::default(),
            reclustering: ReclusterSettings {
                gap_statistic_reference_datasets: 50,
                max_clusters: 24,
                algorithm: ClusterAlgorithm::GMM,
                projection_method: ProjectionMethod::None,
            },
            tui: TuiSettings {
                radio_count: 21,
                colors: TuiColorScheme::default(),
            },
        };

        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings, expected);
    }

    /// An unrecognized log level does not fail loading; it falls back to the default.
    #[test]
    fn test_invalid_log_level_falls_back_to_default() {
        let (_temp_dir, config_path) = write_config(
            r#"
[daemon]
log_level = "verbose"
"#,
        );

        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings.daemon.log_level, default_log_level());
    }

    /// Explicit `port` / `log_level` arguments take precedence over the config file.
    #[test]
    fn test_cli_overrides_take_precedence() {
        let (_temp_dir, config_path) = write_config(
            r#"
[daemon]
rpc_port = 6600
log_level = "debug"
"#,
        );

        let settings =
            Settings::init(config_path, Some(1234), Some(log::LevelFilter::Trace)).unwrap();

        assert_eq!(settings.daemon.rpc_port, 1234);
        assert_eq!(settings.daemon.log_level, log::LevelFilter::Trace);
    }

    /// A leading `~` in library paths is expanded to the home directory.
    #[test]
    fn test_library_paths_tilde_expanded() {
        let (_temp_dir, config_path) = write_config(
            r#"
[daemon]
library_paths = ["~/Music"]
"#,
        );

        let settings = Settings::init(config_path, None, None).unwrap();

        let expected: Box<[PathBuf]> =
            [PathBuf::from(shellexpand::tilde("~/Music").into_owned())].into();
        assert_eq!(settings.daemon.library_paths, expected);
    }

    #[test]
    fn test_default_config_works() {
        let (_temp_dir, config_path) = write_config(DEFAULT_CONFIG);

        let settings = Settings::init(config_path, None, None);

        assert!(settings.is_ok(), "Error: {:?}", settings.err());
    }
}