1use config::{Config, ConfigError, Environment, File};
7use one_or_many::OneOrMany;
8use serde::Deserialize;
9
10use std::{path::PathBuf, str::FromStr};
11
12use mecomp_storage::util::MetadataConflictResolution;
13
14pub static DEFAULT_CONFIG: &str = include_str!("../Mecomp.toml");
15
16#[derive(Clone, Debug, Deserialize, Default, PartialEq, Eq)]
17pub struct Settings {
18 #[serde(default)]
20 pub daemon: DaemonSettings,
21 #[serde(default)]
23 pub reclustering: ReclusterSettings,
24 #[serde(default)]
26 pub tui: TuiSettings,
27}
28
29impl Settings {
30 #[inline]
45 pub fn init(
46 config: PathBuf,
47 port: Option<u16>,
48 log_level: Option<log::LevelFilter>,
49 ) -> Result<Self, ConfigError> {
50 let s = Config::builder()
51 .add_source(File::from(config))
52 .add_source(Environment::with_prefix("MECOMP"))
53 .build()?;
54
55 let mut settings: Self = s.try_deserialize()?;
56
57 for path in &mut settings.daemon.library_paths {
58 *path = shellexpand::tilde(&path.to_string_lossy())
59 .into_owned()
60 .into();
61 }
62
63 if let Some(port) = port {
64 settings.daemon.rpc_port = port;
65 }
66
67 if let Some(log_level) = log_level {
68 settings.daemon.log_level = log_level;
69 }
70
71 Ok(settings)
72 }
73
74 #[inline]
83 pub fn get_config_path() -> Result<PathBuf, std::io::Error> {
84 match crate::get_config_dir() {
85 Ok(config_dir) => {
86 if !config_dir.exists() {
88 std::fs::create_dir_all(&config_dir)?;
89 }
90 let config_file = config_dir.join("Mecomp.toml");
91
92 if !config_file.exists() {
93 std::fs::write(&config_file, DEFAULT_CONFIG)?;
94 }
95
96 Ok(config_file)
97 }
98 Err(e) => {
99 eprintln!("Error: {e}");
100 Err(std::io::Error::new(
101 std::io::ErrorKind::NotFound,
102 "Unable to find the config directory for mecomp.",
103 ))
104 }
105 }
106 }
107}
108
109#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
110pub struct DaemonSettings {
111 #[serde(default = "default_port")]
114 pub rpc_port: u16,
115 #[serde(default = "default_library_paths")]
117 pub library_paths: Box<[PathBuf]>,
118 #[serde(default, deserialize_with = "de_artist_separator")]
132 pub artist_separator: OneOrMany<String>,
133 #[serde(default)]
151 pub protected_artist_names: OneOrMany<String>,
152 #[serde(default)]
153 pub genre_separator: Option<String>,
154 #[serde(default)]
158 pub conflict_resolution: MetadataConflictResolution,
159 #[serde(default = "default_log_level")]
162 #[serde(deserialize_with = "de_log_level")]
163 pub log_level: log::LevelFilter,
164}
165
166fn de_artist_separator<'de, D>(deserializer: D) -> Result<OneOrMany<String>, D::Error>
167where
168 D: serde::Deserializer<'de>,
169{
170 let v = OneOrMany::<String>::deserialize(deserializer)?
171 .into_iter()
172 .filter(|s| !s.is_empty())
173 .collect::<OneOrMany<String>>();
174 if v.is_empty() {
175 Ok(OneOrMany::None)
176 } else {
177 Ok(v)
178 }
179}
180
181fn de_log_level<'de, D>(deserializer: D) -> Result<log::LevelFilter, D::Error>
182where
183 D: serde::Deserializer<'de>,
184{
185 let s = String::deserialize(deserializer)?;
186 Ok(log::LevelFilter::from_str(&s).unwrap_or_else(|_| default_log_level()))
187}
188
/// Default value for `DaemonSettings::rpc_port`.
const fn default_port() -> u16 {
    const PORT: u16 = 6600;
    PORT
}
192
193fn default_library_paths() -> Box<[PathBuf]> {
194 vec![shellexpand::tilde("~/Music/").into_owned().into()].into_boxed_slice()
195}
196
197const fn default_log_level() -> log::LevelFilter {
198 log::LevelFilter::Info
199}
200
201impl Default for DaemonSettings {
202 #[inline]
203 fn default() -> Self {
204 Self {
205 rpc_port: default_port(),
206 library_paths: default_library_paths(),
207 artist_separator: OneOrMany::None,
208 protected_artist_names: OneOrMany::None,
209 genre_separator: None,
210 conflict_resolution: MetadataConflictResolution::Overwrite,
211 log_level: default_log_level(),
212 }
213 }
214}
215
216#[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)]
217#[serde(rename_all = "lowercase")]
218pub enum ClusterAlgorithm {
219 KMeans,
220 #[default]
221 GMM,
222}
223
224impl From<ClusterAlgorithm> for mecomp_analysis::clustering::ClusteringMethod {
225 #[inline]
226 fn from(algo: ClusterAlgorithm) -> Self {
227 match algo {
228 ClusterAlgorithm::KMeans => Self::KMeans,
229 ClusterAlgorithm::GMM => Self::GaussianMixtureModel,
230 }
231 }
232}
233
234#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq)]
235pub struct ReclusterSettings {
236 #[serde(default = "default_gap_statistic_reference_datasets")]
242 pub gap_statistic_reference_datasets: usize,
243 #[serde(default = "default_max_clusters")]
248 pub max_clusters: usize,
249 #[serde(default)]
252 pub algorithm: ClusterAlgorithm,
253}
254
/// Default value for `ReclusterSettings::gap_statistic_reference_datasets`.
const fn default_gap_statistic_reference_datasets() -> usize {
    const REFERENCE_DATASETS: usize = 50;
    REFERENCE_DATASETS
}
258
/// Default value for `ReclusterSettings::max_clusters`: 16 when debug
/// assertions are enabled, 24 otherwise.
const fn default_max_clusters() -> usize {
    if cfg!(debug_assertions) {
        16
    } else {
        24
    }
}
265
266impl Default for ReclusterSettings {
267 #[inline]
268 fn default() -> Self {
269 Self {
270 gap_statistic_reference_datasets: default_gap_statistic_reference_datasets(),
271 max_clusters: default_max_clusters(),
272 algorithm: ClusterAlgorithm::default(),
273 }
274 }
275}
276
277#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
278pub struct TuiSettings {
279 #[serde(default = "default_radio_count")]
282 pub radio_count: u32,
283}
284
/// Default value for `TuiSettings::radio_count`.
const fn default_radio_count() -> u32 {
    const RADIO_COUNT: u32 = 20;
    RADIO_COUNT
}
288
289impl Default for TuiSettings {
290 #[inline]
291 fn default() -> Self {
292 Self {
293 radio_count: default_radio_count(),
294 }
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use rstest::rstest;

    /// Transparent wrapper whose only field is deserialized through
    /// `de_artist_separator`.
    #[derive(Debug, PartialEq, Eq, Deserialize)]
    #[serde(transparent)]
    struct ArtistSeparatorTest {
        #[serde(deserialize_with = "de_artist_separator")]
        artist_separator: OneOrMany<String>,
    }

    /// Writes `contents` to `config.toml` inside a fresh temporary directory.
    ///
    /// The directory guard is returned together with the path so that the file
    /// stays on disk for as long as the caller holds the guard.
    fn write_temp_config(contents: &str) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("config.toml");
        std::fs::write(&path, contents).unwrap();
        (dir, path)
    }

    /// The settings both explicit-config tests expect, differing only in the
    /// protected artist names.
    fn expected_settings(protected_artist_names: OneOrMany<String>) -> Settings {
        let daemon = DaemonSettings {
            rpc_port: 6600,
            library_paths: ["/Music".into()].into(),
            artist_separator: vec!["; ".into()].into(),
            protected_artist_names,
            genre_separator: Some(", ".into()),
            conflict_resolution: MetadataConflictResolution::Overwrite,
            log_level: log::LevelFilter::Debug,
        };
        let reclustering = ReclusterSettings {
            gap_statistic_reference_datasets: 50,
            max_clusters: 24,
            algorithm: ClusterAlgorithm::GMM,
        };
        let tui = TuiSettings { radio_count: 21 };

        Settings {
            daemon,
            reclustering,
            tui,
        }
    }

    // Inputs made up solely of empty separators must deserialize to an empty value.
    #[rstest]
    #[case(Vec::<String>::new())]
    #[case("")]
    fn test_de_artist_separator_empty<'de, D>(#[case] input: D)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let parsed: Result<OneOrMany<String>, _> = de_artist_separator(input.into_deserializer());
        assert!(parsed.is_ok());
        assert!(parsed.unwrap().is_empty());
    }

    // Non-empty separators are kept; empty ones are filtered out.
    #[rstest]
    #[case(vec![" & "], String::from(" & ").into())]
    #[case(" & ", String::from(" & ").into())]
    #[case(vec![" & ", "; "], vec![String::from(" & "), String::from("; ")].into())]
    #[case(vec!["", " & ", "", "; "], vec![String::from(" & "), String::from("; ")].into())]
    fn test_de_artist_separator<'de, D>(#[case] input: D, #[case] expected: OneOrMany<String>)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let parsed: Result<OneOrMany<String>, _> = de_artist_separator(input.into_deserializer());
        assert!(parsed.is_ok());
        assert_eq!(parsed.unwrap(), expected);
    }

    #[test]
    fn test_init_config() {
        let (_temp_dir, config_path) = write_temp_config(
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"

[tui]
radio_count = 21
            "#,
        );

        let expected = expected_settings(OneOrMany::None);

        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings, expected);
    }

    #[test]
    fn test_artist_names_to_not_split() {
        let (_temp_dir, config_path) = write_temp_config(
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
protected_artist_names = ["Foo & Bar"]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"

[tui]
radio_count = 21
            "#,
        );

        let expected = expected_settings(OneOrMany::One("Foo & Bar".into()));

        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings, expected);
    }

    // The bundled default configuration must load without errors.
    #[test]
    fn test_default_config_works() {
        let (_temp_dir, config_path) = write_temp_config(DEFAULT_CONFIG);

        let settings = Settings::init(config_path, None, None);

        assert!(settings.is_ok(), "Error: {:?}", settings.err());
    }
}