mecomp_core/
config.rs

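//! Handles the configuration of the daemon.
//!
//! This module is responsible for parsing the `Mecomp.toml` config file, merging in
//! `MECOMP_`-prefixed environment variables, and applying overrides that were parsed from CLI
//! arguments (the RPC port and the log level).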
5
6use config::{Config, ConfigError, Environment, File};
7use one_or_many::OneOrMany;
8use serde::Deserialize;
9
10use std::{path::PathBuf, str::FromStr};
11
12use mecomp_storage::util::MetadataConflictResolution;
13
/// Default contents of the config file, embedded at compile time from `Mecomp.toml`
/// (the path is resolved relative to this source file).
///
/// [`Settings::get_config_path`] writes this to disk when no config file exists yet.
pub static DEFAULT_CONFIG: &str = include_str!("../Mecomp.toml");
15
16#[derive(Clone, Debug, Deserialize, Default, PartialEq, Eq)]
17pub struct Settings {
18    /// General Daemon Settings
19    #[serde(default)]
20    pub daemon: DaemonSettings,
21    /// Parameters for the reclustering algorithm.
22    #[serde(default)]
23    pub reclustering: ReclusterSettings,
24    /// Settings for the TUI
25    #[serde(default)]
26    pub tui: TuiSettings,
27}
28
29impl Settings {
30    /// Load settings from the config file, environment variables, and CLI arguments.
31    ///
32    /// The config file is located at the path specified by the `--config` flag.
33    ///
34    /// The environment variables are prefixed with `MECOMP_`.
35    ///
36    /// # Arguments
37    ///
38    /// * `flags` - The parsed CLI arguments.
39    ///
40    /// # Errors
41    ///
42    /// This function will return an error if the config file is not found or if the config file is
43    /// invalid.
44    #[inline]
45    pub fn init(
46        config: PathBuf,
47        port: Option<u16>,
48        log_level: Option<log::LevelFilter>,
49    ) -> Result<Self, ConfigError> {
50        let s = Config::builder()
51            .add_source(File::from(config))
52            .add_source(Environment::with_prefix("MECOMP"))
53            .build()?;
54
55        let mut settings: Self = s.try_deserialize()?;
56
57        for path in &mut settings.daemon.library_paths {
58            *path = shellexpand::tilde(&path.to_string_lossy())
59                .into_owned()
60                .into();
61        }
62
63        if let Some(port) = port {
64            settings.daemon.rpc_port = port;
65        }
66
67        if let Some(log_level) = log_level {
68            settings.daemon.log_level = log_level;
69        }
70
71        Ok(settings)
72    }
73
74    /// Get the (default) path to the config file.
75    /// If the config file does not exist at this path, it will be created with the default config.
76    ///
77    /// See [`crate::get_config_dir`] for more information about where this default path is located.
78    ///
79    /// # Errors
80    ///
81    /// This function will return an error if the system config directory (e.g., `~/.config` on linux) could not be found, or if the config file was missing and could not be created.
82    #[inline]
83    pub fn get_config_path() -> Result<PathBuf, std::io::Error> {
84        match crate::get_config_dir() {
85            Ok(config_dir) => {
86                // if the config directory does not exist, create it
87                if !config_dir.exists() {
88                    std::fs::create_dir_all(&config_dir)?;
89                }
90                let config_file = config_dir.join("Mecomp.toml");
91
92                if !config_file.exists() {
93                    std::fs::write(&config_file, DEFAULT_CONFIG)?;
94                }
95
96                Ok(config_file)
97            }
98            Err(e) => {
99                eprintln!("Error: {e}");
100                Err(std::io::Error::new(
101                    std::io::ErrorKind::NotFound,
102                    "Unable to find the config directory for mecomp.",
103                ))
104            }
105        }
106    }
107}
108
109#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
110pub struct DaemonSettings {
111    /// The port to listen on for RPC requests.
112    /// Default is 6600.
113    #[serde(default = "default_port")]
114    pub rpc_port: u16,
115    /// The root paths of the music library.
116    #[serde(default = "default_library_paths")]
117    pub library_paths: Box<[PathBuf]>,
118    /// Separators for artist names in song metadata.
119    /// For example, "Foo, Bar, Baz" would be split into \["Foo", "Bar", "Baz"\]. if the separator is ", ".
120    /// If the separator is not found, the entire string is considered as a single artist.
121    /// If unset, will not split artists.
122    ///
123    /// Users can provide one or many separators, and must provide them as either a single string or an array of strings.
124    ///
125    /// ```toml
126    /// [daemon]
127    /// artist_separator = " & "
128    /// artist_separator = [" & ", "; "]
129    /// ...
130    /// ```
131    #[serde(default, deserialize_with = "de_artist_separator")]
132    pub artist_separator: OneOrMany<String>,
133    #[serde(default)]
134    pub genre_separator: Option<String>,
135    /// how conflicting metadata should be resolved
136    /// "overwrite" - overwrite the metadata with new metadata
137    /// "skip" - skip the file (keep old metadata)
138    #[serde(default)]
139    pub conflict_resolution: MetadataConflictResolution,
140    /// What level of logging to use.
141    /// Default is "info".
142    #[serde(default = "default_log_level")]
143    #[serde(deserialize_with = "de_log_level")]
144    pub log_level: log::LevelFilter,
145}
146
147fn de_artist_separator<'de, D>(deserializer: D) -> Result<OneOrMany<String>, D::Error>
148where
149    D: serde::Deserializer<'de>,
150{
151    let v = OneOrMany::<String>::deserialize(deserializer)?
152        .into_iter()
153        .filter(|s| !s.is_empty())
154        .collect::<OneOrMany<String>>();
155    if v.is_empty() {
156        Ok(OneOrMany::None)
157    } else {
158        Ok(v)
159    }
160}
161
162fn de_log_level<'de, D>(deserializer: D) -> Result<log::LevelFilter, D::Error>
163where
164    D: serde::Deserializer<'de>,
165{
166    let s = String::deserialize(deserializer)?;
167    Ok(log::LevelFilter::from_str(&s).unwrap_or_else(|_| default_log_level()))
168}
169
/// RPC port used when `[daemon] rpc_port` is not set.
const fn default_port() -> u16 {
    const DEFAULT_RPC_PORT: u16 = 6600;
    DEFAULT_RPC_PORT
}
173
174fn default_library_paths() -> Box<[PathBuf]> {
175    vec![shellexpand::tilde("~/Music/").into_owned().into()].into_boxed_slice()
176}
177
178const fn default_log_level() -> log::LevelFilter {
179    log::LevelFilter::Info
180}
181
182impl Default for DaemonSettings {
183    #[inline]
184    fn default() -> Self {
185        Self {
186            rpc_port: default_port(),
187            library_paths: default_library_paths(),
188            artist_separator: OneOrMany::None,
189            genre_separator: None,
190            conflict_resolution: MetadataConflictResolution::Overwrite,
191            log_level: default_log_level(),
192        }
193    }
194}
195
196#[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)]
197#[serde(rename_all = "lowercase")]
198pub enum ClusterAlgorithm {
199    KMeans,
200    #[default]
201    GMM,
202}
203
204impl From<ClusterAlgorithm> for mecomp_analysis::clustering::ClusteringMethod {
205    #[inline]
206    fn from(algo: ClusterAlgorithm) -> Self {
207        match algo {
208            ClusterAlgorithm::KMeans => Self::KMeans,
209            ClusterAlgorithm::GMM => Self::GaussianMixtureModel,
210        }
211    }
212}
213
214#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq)]
215pub struct ReclusterSettings {
216    /// The number of reference datasets to use for the gap statistic.
217    /// (which is used to determine the optimal number of clusters)
218    /// 50 will give a decent estimate but for the best results use more,
219    /// 500 will give a very good estimate but be very slow.
220    /// We default to 250 in release mode.
221    #[serde(default = "default_gap_statistic_reference_datasets")]
222    pub gap_statistic_reference_datasets: usize,
223    /// The maximum number of clusters to create.
224    /// This is the upper bound on the number of clusters that can be created.
225    /// Increase if you're getting a "could not find optimal k" error.
226    /// Default is 24.
227    #[serde(default = "default_max_clusters")]
228    pub max_clusters: usize,
229    /// The clustering algorithm to use.
230    /// Either "kmeans" or "gmm".
231    #[serde(default)]
232    pub algorithm: ClusterAlgorithm,
233}
234
/// Number of gap-statistic reference datasets used when the setting is absent.
const fn default_gap_statistic_reference_datasets() -> usize {
    const REFERENCE_DATASETS: usize = 50;
    REFERENCE_DATASETS
}
238
/// Upper bound on the number of clusters used when `max_clusters` is not set:
/// 16 in builds with `debug_assertions`, 24 otherwise.
const fn default_max_clusters() -> usize {
    // `cfg!` expands to a boolean literal, so this is a single expression instead of two
    // `#[cfg]`-gated `return` statements.
    if cfg!(debug_assertions) {
        16
    } else {
        24
    }
}
245
246impl Default for ReclusterSettings {
247    #[inline]
248    fn default() -> Self {
249        Self {
250            gap_statistic_reference_datasets: default_gap_statistic_reference_datasets(),
251            max_clusters: default_max_clusters(),
252            algorithm: ClusterAlgorithm::default(),
253        }
254    }
255}
256
257#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
258pub struct TuiSettings {
259    /// How many songs should be queried for when starting a radio.
260    /// Default is 20.
261    #[serde(default = "default_radio_count")]
262    pub radio_count: u32,
263}
264
/// Number of songs queried when starting a radio, if `[tui] radio_count` is not set.
const fn default_radio_count() -> u32 {
    const RADIO_COUNT: u32 = 20;
    RADIO_COUNT
}
268
269impl Default for TuiSettings {
270    #[inline]
271    fn default() -> Self {
272        Self {
273            radio_count: default_radio_count(),
274        }
275    }
276}
277
// Unit tests for the custom deserializers and for loading settings from a file.
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use rstest::rstest;

    // NOTE(review): this wrapper is not referenced by any test in this module (the tests call
    // `de_artist_separator` directly); consider using it or removing it.
    #[derive(Debug, PartialEq, Eq, Deserialize)]
    #[serde(transparent)]
    struct ArtistSeparatorTest {
        #[serde(deserialize_with = "de_artist_separator")]
        artist_separator: OneOrMany<String>,
    }

    // Both an empty array and an empty string must deserialize successfully to an empty
    // separator list.
    #[rstest]
    #[case(Vec::<String>::new())]
    #[case("")]
    fn test_de_artist_separator_empty<'de, D>(#[case] input: D)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let deserializer = input.into_deserializer();
        let result: Result<OneOrMany<String>, _> = de_artist_separator(deserializer);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    // Single strings and arrays are both accepted; the last case checks that empty strings
    // inside an array are filtered out while the remaining order is preserved.
    #[rstest]
    #[case(vec![" & "], String::from(" & ").into())]
    #[case(" & ", String::from(" & ").into())]
    #[case(vec![" & ", "; "], vec![String::from(" & "), String::from("; ")].into())]
    #[case(vec!["", " & ", "", "; "], vec![String::from(" & "), String::from("; ")].into())]
    fn test_de_artist_separator<'de, D>(#[case] input: D, #[case] expected: OneOrMany<String>)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let deserializer = input.into_deserializer();
        let result: Result<OneOrMany<String>, _> = de_artist_separator(deserializer);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), expected);
    }

    // Loads a fully specified config file with no CLI overrides and checks every field.
    // `library_paths` contains no `~`, so tilde expansion leaves it unchanged.
    #[test]
    fn test_init_config() {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(
            &config_path,
            r#"            
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"

[tui]
radio_count = 21
            "#,
        )
        .unwrap();

        let expected = Settings {
            daemon: DaemonSettings {
                rpc_port: 6600,
                library_paths: ["/Music".into()].into(),
                artist_separator: vec!["; ".into()].into(),
                genre_separator: Some(", ".into()),
                conflict_resolution: MetadataConflictResolution::Overwrite,
                log_level: log::LevelFilter::Debug,
            },
            reclustering: ReclusterSettings {
                gap_statistic_reference_datasets: 50,
                max_clusters: 24,
                algorithm: ClusterAlgorithm::GMM,
            },
            tui: TuiSettings { radio_count: 21 },
        };

        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings, expected);
    }

    // The bundled `DEFAULT_CONFIG` must always load without error.
    #[test]
    fn test_default_config_works() {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(&config_path, DEFAULT_CONFIG).unwrap();

        let settings = Settings::init(config_path, None, None);

        assert!(settings.is_ok(), "Error: {:?}", settings.err());
    }
}