mecomp_daemon/
config.rs

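//! Handles the configuration of the daemon.
//!
//! This module is responsible for loading the config file (see [`DEFAULT_CONFIG`] for the bundled
//! default), reading `MECOMP_`-prefixed environment variables, and applying overrides taken from
//! already-parsed CLI arguments (the RPC port and the log level used for the logger).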
5
6use config::{Config, ConfigError, Environment, File};
7use one_or_many::OneOrMany;
8use serde::Deserialize;
9
10use std::{path::PathBuf, str::FromStr};
11
12use mecomp_storage::util::MetadataConflictResolution;
13
14pub static DEFAULT_CONFIG: &str = include_str!("../Mecomp.toml");
15
16#[derive(Clone, Debug, Deserialize, Default, PartialEq, Eq)]
17pub struct Settings {
18    /// General Daemon Settings
19    #[serde(default)]
20    pub daemon: DaemonSettings,
21    /// Parameters for the reclustering algorithm.
22    #[serde(default)]
23    pub reclustering: ReclusterSettings,
24}
25
26impl Settings {
27    /// Load settings from the config file, environment variables, and CLI arguments.
28    ///
29    /// The config file is located at the path specified by the `--config` flag.
30    ///
31    /// The environment variables are prefixed with `MECOMP_`.
32    ///
33    /// # Arguments
34    ///
35    /// * `flags` - The parsed CLI arguments.
36    ///
37    /// # Errors
38    ///
39    /// This function will return an error if the config file is not found or if the config file is
40    /// invalid.
41    pub fn init(
42        config: PathBuf,
43        port: Option<u16>,
44        log_level: Option<log::LevelFilter>,
45    ) -> Result<Self, ConfigError> {
46        let s = Config::builder()
47            .add_source(File::from(config))
48            .add_source(Environment::with_prefix("MECOMP"))
49            .build()?;
50
51        let mut settings: Self = s.try_deserialize()?;
52
53        for path in &mut settings.daemon.library_paths {
54            *path = shellexpand::tilde(&path.to_string_lossy())
55                .into_owned()
56                .into();
57        }
58
59        if let Some(port) = port {
60            settings.daemon.rpc_port = port;
61        }
62
63        if let Some(log_level) = log_level {
64            settings.daemon.log_level = log_level;
65        }
66
67        Ok(settings)
68    }
69}
70
71#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
72pub struct DaemonSettings {
73    /// The port to listen on for RPC requests.
74    /// Default is 6600.
75    #[serde(default = "default_port")]
76    pub rpc_port: u16,
77    /// The root paths of the music library.
78    #[serde(default = "default_library_paths")]
79    pub library_paths: Box<[PathBuf]>,
80    /// Separators for artist names in song metadata.
81    /// For example, "Foo, Bar, Baz" would be split into \["Foo", "Bar", "Baz"\]. if the separator is ", ".
82    /// If the separator is not found, the entire string is considered as a single artist.
83    /// If unset, will not split artists.
84    ///
85    /// Users can provide one or many separators, and must provide them as either a single string or an array of strings.
86    ///
87    /// ```toml
88    /// [daemon]
89    /// artist_separator = " & "
90    /// artist_separator = [" & ", "; "]
91    ///
92    ///
93    /// ```
94    #[serde(default, deserialize_with = "de_artist_separator")]
95    pub artist_separator: OneOrMany<String>,
96    #[serde(default)]
97    pub genre_separator: Option<String>,
98    /// how conflicting metadata should be resolved
99    /// "overwrite" - overwrite the metadata with new metadata
100    /// "skip" - skip the file (keep old metadata)
101    #[serde(default)]
102    pub conflict_resolution: MetadataConflictResolution,
103    /// What level of logging to use.
104    /// Default is "info".
105    #[serde(default = "default_log_level")]
106    #[serde(deserialize_with = "de_log_level")]
107    pub log_level: log::LevelFilter,
108}
109
110fn de_artist_separator<'de, D>(deserializer: D) -> Result<OneOrMany<String>, D::Error>
111where
112    D: serde::Deserializer<'de>,
113{
114    let v = OneOrMany::<String>::deserialize(deserializer)?
115        .into_iter()
116        .filter(|s| !s.is_empty())
117        .collect::<OneOrMany<String>>();
118    if v.is_empty() {
119        Ok(OneOrMany::None)
120    } else {
121        Ok(v)
122    }
123}
124
125fn de_log_level<'de, D>(deserializer: D) -> Result<log::LevelFilter, D::Error>
126where
127    D: serde::Deserializer<'de>,
128{
129    let s = String::deserialize(deserializer)?;
130    Ok(log::LevelFilter::from_str(&s).unwrap_or_else(|_| default_log_level()))
131}
132
/// Default RPC port, used when `rpc_port` is not configured.
const fn default_port() -> u16 {
    const PORT: u16 = 6600;
    PORT
}
136
137fn default_library_paths() -> Box<[PathBuf]> {
138    vec![shellexpand::tilde("~/Music/").into_owned().into()].into_boxed_slice()
139}
140
141const fn default_log_level() -> log::LevelFilter {
142    log::LevelFilter::Info
143}
144
145impl Default for DaemonSettings {
146    fn default() -> Self {
147        Self {
148            rpc_port: default_port(),
149            library_paths: default_library_paths(),
150            artist_separator: OneOrMany::None,
151            genre_separator: None,
152            conflict_resolution: MetadataConflictResolution::Overwrite,
153            log_level: default_log_level(),
154        }
155    }
156}
157
158#[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)]
159#[serde(rename_all = "lowercase")]
160pub enum ClusterAlgorithm {
161    KMeans,
162    #[default]
163    GMM,
164}
165
#[cfg(feature = "analysis")]
impl From<ClusterAlgorithm> for mecomp_analysis::clustering::ClusteringMethod {
    /// Maps the configuration-level algorithm choice onto the analysis crate's method.
    fn from(algorithm: ClusterAlgorithm) -> Self {
        match algorithm {
            ClusterAlgorithm::GMM => Self::GaussianMixtureModel,
            ClusterAlgorithm::KMeans => Self::KMeans,
        }
    }
}
175
176#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq)]
177pub struct ReclusterSettings {
178    /// The number of reference datasets to use for the gap statistic.
179    /// (which is used to determine the optimal number of clusters)
180    /// 50 will give a decent estimate but for the best results use more,
181    /// 500 will give a very good estimate but be very slow.
182    /// We default to 250 in release mode.
183    #[serde(default = "default_gap_statistic_reference_datasets")]
184    pub gap_statistic_reference_datasets: usize,
185    /// The maximum number of clusters to create.
186    /// This is the upper bound on the number of clusters that can be created.
187    /// Increase if you're getting a "could not find optimal k" error.
188    /// Default is 24.
189    #[serde(default = "default_max_clusters")]
190    pub max_clusters: usize,
191    /// The clustering algorithm to use.
192    /// Either "kmeans" or "gmm".
193    #[serde(default)]
194    pub algorithm: ClusterAlgorithm,
195}
196
/// Default number of reference datasets for the gap statistic.
const fn default_gap_statistic_reference_datasets() -> usize {
    50_usize
}
200
/// Default upper bound on the number of clusters: 16 in debug builds, 24 in release builds.
const fn default_max_clusters() -> usize {
    if cfg!(debug_assertions) { 16 } else { 24 }
}
207
208impl Default for ReclusterSettings {
209    fn default() -> Self {
210        Self {
211            gap_statistic_reference_datasets: default_gap_statistic_reference_datasets(),
212            max_clusters: default_max_clusters(),
213            algorithm: ClusterAlgorithm::default(),
214        }
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use rstest::rstest;

    #[rstest]
    #[case(Vec::<String>::new())]
    #[case("")]
    fn test_de_artist_separator_empty<'de, D>(#[case] input: D)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let deserializer = input.into_deserializer();
        let result: Result<OneOrMany<String>, _> = de_artist_separator(deserializer);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[rstest]
    #[case(vec![" & "], String::from(" & ").into())]
    #[case(" & ", String::from(" & ").into())]
    #[case(vec![" & ", "; "], vec![String::from(" & "), String::from("; ")].into())]
    #[case(vec!["", " & ", "", "; "], vec![String::from(" & "), String::from("; ")].into())]
    fn test_de_artist_separator<'de, D>(#[case] input: D, #[case] expected: OneOrMany<String>)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let deserializer = input.into_deserializer();
        let result: Result<OneOrMany<String>, _> = de_artist_separator(deserializer);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), expected);
    }

    /// Known level names parse case-insensitively; unknown names fall back to the default.
    #[rstest]
    #[case("trace", log::LevelFilter::Trace)]
    #[case("DEBUG", log::LevelFilter::Debug)]
    #[case("off", log::LevelFilter::Off)]
    #[case("not-a-level", log::LevelFilter::Info)]
    fn test_de_log_level<'de, D>(#[case] input: D, #[case] expected: log::LevelFilter)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let result = de_log_level(input.into_deserializer());
        assert_eq!(result.unwrap(), expected);
    }

    #[test]
    fn test_init_config() {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(
            &config_path,
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"
            "#,
        )
        .unwrap();

        let expected = Settings {
            daemon: DaemonSettings {
                rpc_port: 6600,
                library_paths: ["/Music".into()].into(),
                artist_separator: vec!["; ".into()].into(),
                genre_separator: Some(", ".into()),
                conflict_resolution: MetadataConflictResolution::Overwrite,
                log_level: log::LevelFilter::Debug,
            },
            reclustering: ReclusterSettings {
                gap_statistic_reference_datasets: 50,
                max_clusters: 24,
                algorithm: ClusterAlgorithm::GMM,
            },
        };

        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings, expected);
    }

    /// CLI overrides beat the config file, `~` in library paths is expanded, and a missing
    /// `[reclustering]` table falls back to its defaults.
    #[test]
    fn test_init_config_overrides_and_tilde_expansion() {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(
            &config_path,
            r#"
[daemon]
rpc_port = 6600
library_paths = ["~/Music"]
log_level = "debug"
"#,
        )
        .unwrap();

        let settings =
            Settings::init(config_path, Some(1234), Some(log::LevelFilter::Warn)).unwrap();

        assert_eq!(settings.daemon.rpc_port, 1234);
        assert_eq!(settings.daemon.log_level, log::LevelFilter::Warn);
        let expected_paths: Box<[PathBuf]> =
            [PathBuf::from(shellexpand::tilde("~/Music").into_owned())].into();
        assert_eq!(settings.daemon.library_paths, expected_paths);
        assert_eq!(settings.reclustering, ReclusterSettings::default());
    }

    #[test]
    fn test_default_config_works() {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(&config_path, DEFAULT_CONFIG).unwrap();

        let settings = Settings::init(config_path, None, None);

        assert!(settings.is_ok(), "Error: {:?}", settings.err());
    }
}