1use config::{Config, ConfigError, Environment, File};
7use one_or_many::OneOrMany;
8use serde::Deserialize;
9
10use std::{num::NonZeroUsize, path::PathBuf, str::FromStr};
11
12use mecomp_storage::util::MetadataConflictResolution;
13
/// Contents of the default configuration file, embedded at compile time from
/// `../Mecomp.toml` (the path is resolved relative to this source file).
///
/// [`Settings::get_config_path`] writes this text to disk when no `Mecomp.toml` exists yet.
pub static DEFAULT_CONFIG: &str = include_str!("../Mecomp.toml");
15
/// Top-level application settings, built by [`Settings::init`] from the config file
/// layered with `MECOMP`-prefixed environment variables.
///
/// Every section is optional: a section missing from the config falls back to that
/// section type's `Default` implementation.
#[derive(Clone, Debug, Deserialize, Default, PartialEq, Eq)]
pub struct Settings {
    /// The `[daemon]` table.
    #[serde(default)]
    pub daemon: DaemonSettings,
    /// The `[analysis]` table.
    #[serde(default)]
    pub analysis: AnalysisSettings,
    /// The `[reclustering]` table.
    #[serde(default)]
    pub reclustering: ReclusterSettings,
    /// The `[tui]` table (including the nested `[tui.colors]` table).
    #[serde(default)]
    pub tui: TuiSettings,
}
31
32impl Settings {
33 #[inline]
48 pub fn init(
49 config: PathBuf,
50 port: Option<u16>,
51 log_level: Option<log::LevelFilter>,
52 ) -> Result<Self, ConfigError> {
53 let s = Config::builder()
54 .add_source(File::from(config))
55 .add_source(Environment::with_prefix("MECOMP"))
56 .build()?;
57
58 let mut settings: Self = s.try_deserialize()?;
59
60 for path in &mut settings.daemon.library_paths {
61 *path = shellexpand::tilde(&path.to_string_lossy())
62 .into_owned()
63 .into();
64 }
65
66 if let Some(port) = port {
67 settings.daemon.rpc_port = port;
68 }
69
70 if let Some(log_level) = log_level {
71 settings.daemon.log_level = log_level;
72 }
73
74 Ok(settings)
75 }
76
77 #[inline]
86 pub fn get_config_path() -> Result<PathBuf, std::io::Error> {
87 match crate::get_config_dir() {
88 Ok(config_dir) => {
89 if !config_dir.exists() {
91 std::fs::create_dir_all(&config_dir)?;
92 }
93 let config_file = config_dir.join("Mecomp.toml");
94
95 if !config_file.exists() {
96 std::fs::write(&config_file, DEFAULT_CONFIG)?;
97 }
98
99 Ok(config_file)
100 }
101 Err(e) => {
102 eprintln!("Error: {e}");
103 Err(std::io::Error::new(
104 std::io::ErrorKind::NotFound,
105 "Unable to find the config directory for mecomp.",
106 ))
107 }
108 }
109 }
110}
111
/// Settings read from the `[daemon]` table of the config file.
///
/// Every key is optional; each field's default is named in its `#[serde(default...)]`
/// attribute.
#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
pub struct DaemonSettings {
    /// The RPC port (presumably the port the daemon's RPC server binds to — confirm).
    /// Defaults to `default_port()` (6600); `Settings::init` can override it.
    #[serde(default = "default_port")]
    pub rpc_port: u16,
    /// Library directories (presumably scanned for music — confirm). Defaults to
    /// `~/Music/`; `Settings::init` expands a leading `~` in each entry.
    #[serde(default = "default_library_paths")]
    pub library_paths: Box<[PathBuf]>,
    /// Separator string(s), presumably used to split artist names (inferred from the name —
    /// confirm against the consumer). Accepts a single string or a list; empty strings
    /// are dropped, and an all-empty value becomes `OneOrMany::None` (see `de_artist_separator`).
    #[serde(default, deserialize_with = "de_artist_separator")]
    pub artist_separator: OneOrMany<String>,
    /// Artist names that presumably must not be split by `artist_separator` (inferred from
    /// the name and the `test_artist_names_to_not_split` test — confirm). Accepts a single
    /// string or a list; `OneOrMany::None` when omitted.
    #[serde(default)]
    pub protected_artist_names: OneOrMany<String>,
    /// Optional separator, presumably for splitting genre strings (inferred from the name —
    /// confirm). `None` when omitted.
    #[serde(default)]
    pub genre_separator: Option<String>,
    /// Metadata conflict-resolution strategy; the variants' semantics live in
    /// `mecomp_storage::util::MetadataConflictResolution`. Defaults to that type's `Default`.
    #[serde(default)]
    pub conflict_resolution: MetadataConflictResolution,
    /// Log level filter. Defaults to `Info`. Unrecognized strings also fall back to `Info`
    /// instead of producing an error (see `de_log_level`). `Settings::init` can override it.
    #[serde(default = "default_log_level")]
    #[serde(deserialize_with = "de_log_level")]
    pub log_level: log::LevelFilter,
}
168
169fn de_artist_separator<'de, D>(deserializer: D) -> Result<OneOrMany<String>, D::Error>
170where
171 D: serde::Deserializer<'de>,
172{
173 let v = OneOrMany::<String>::deserialize(deserializer)?
174 .into_iter()
175 .filter(|s| !s.is_empty())
176 .collect::<OneOrMany<String>>();
177 if v.is_empty() {
178 Ok(OneOrMany::None)
179 } else {
180 Ok(v)
181 }
182}
183
184fn de_log_level<'de, D>(deserializer: D) -> Result<log::LevelFilter, D::Error>
185where
186 D: serde::Deserializer<'de>,
187{
188 let s = String::deserialize(deserializer)?;
189 Ok(log::LevelFilter::from_str(&s).unwrap_or_else(|_| default_log_level()))
190}
191
/// Default RPC port, used when `daemon.rpc_port` is omitted.
const fn default_port() -> u16 {
    const PORT: u16 = 6_600;
    PORT
}
195
196fn default_library_paths() -> Box<[PathBuf]> {
197 vec![shellexpand::tilde("~/Music/").into_owned().into()].into_boxed_slice()
198}
199
200const fn default_log_level() -> log::LevelFilter {
201 log::LevelFilter::Info
202}
203
204impl Default for DaemonSettings {
205 #[inline]
206 fn default() -> Self {
207 Self {
208 rpc_port: default_port(),
209 library_paths: default_library_paths(),
210 artist_separator: OneOrMany::None,
211 protected_artist_names: OneOrMany::None,
212 genre_separator: None,
213 conflict_resolution: MetadataConflictResolution::Overwrite,
214 log_level: default_log_level(),
215 }
216 }
217}
218
/// Which kind of audio analysis to perform.
///
/// Deserialized from lowercase strings (`"features"` or `"embedding"`). The default is
/// [`AnalysisKind::Features`].
#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum AnalysisKind {
    /// `"features"` (the default).
    #[default]
    Features,
    /// `"embedding"`.
    Embedding,
}
226
227#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
228pub struct AnalysisSettings {
229 #[serde(default)]
237 pub kind: AnalysisKind,
238 #[serde(default)]
246 pub num_threads: Option<NonZeroUsize>,
247 #[serde(default)]
256 pub model_path: Option<PathBuf>,
257}
258
259impl Default for AnalysisSettings {
260 #[inline]
261 fn default() -> Self {
262 Self {
263 kind: AnalysisKind::default(),
264 num_threads: None,
265 model_path: None,
266 }
267 }
268}
269
/// Clustering algorithm used for reclustering.
///
/// Deserialized from lowercase strings (`"kmeans"` or `"gmm"`). The default is
/// [`ClusterAlgorithm::GMM`].
#[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ClusterAlgorithm {
    /// `"kmeans"`: K-means clustering.
    KMeans,
    /// `"gmm"` (the default): maps to the analysis crate's Gaussian mixture model method.
    #[default]
    GMM,
}
277
#[cfg(feature = "analysis")]
impl From<ClusterAlgorithm> for mecomp_analysis::clustering::ClusteringMethod {
    /// Map the config-level [`ClusterAlgorithm`] onto the analysis crate's clustering method.
    #[inline]
    fn from(algorithm: ClusterAlgorithm) -> Self {
        match algorithm {
            ClusterAlgorithm::GMM => Self::GaussianMixtureModel,
            ClusterAlgorithm::KMeans => Self::KMeans,
        }
    }
}
288
/// Projection method applied during reclustering.
///
/// Deserialized from lowercase strings (`"none"`, `"tsne"` or `"pca"`). The default is
/// [`ProjectionMethod::TSne`].
#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ProjectionMethod {
    /// `"none"`: no projection.
    None,
    /// `"tsne"` (the default): t-SNE.
    #[default]
    TSne,
    /// `"pca"`: principal component analysis.
    Pca,
}
297
#[cfg(feature = "analysis")]
impl From<ProjectionMethod> for mecomp_analysis::clustering::ProjectionMethod {
    /// Map the config-level [`ProjectionMethod`] onto the analysis crate's variant of the
    /// same name.
    #[inline]
    fn from(method: ProjectionMethod) -> Self {
        match method {
            ProjectionMethod::Pca => Self::Pca,
            ProjectionMethod::TSne => Self::TSne,
            ProjectionMethod::None => Self::None,
        }
    }
}
309
/// Settings read from the `[reclustering]` table of the config file.
///
/// Every key is optional; each field's default is named in its serde attribute.
#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq)]
pub struct ReclusterSettings {
    /// Number of reference datasets, presumably used when computing the gap statistic
    /// (inferred from the name — confirm). Defaults to 50.
    #[serde(default = "default_gap_statistic_reference_datasets")]
    pub gap_statistic_reference_datasets: u32,
    /// Upper bound on the number of clusters (inferred from the name — confirm).
    /// Defaults to 24.
    #[serde(default = "default_max_clusters")]
    pub max_clusters: usize,
    /// Clustering algorithm (`"kmeans"` or `"gmm"`); defaults to GMM.
    #[serde(default)]
    pub algorithm: ClusterAlgorithm,
    /// Projection method (`"none"`, `"tsne"` or `"pca"`); defaults to t-SNE.
    #[serde(default)]
    pub projection_method: ProjectionMethod,
}
336
/// Default for `reclustering.gap_statistic_reference_datasets` when the key is omitted.
const fn default_gap_statistic_reference_datasets() -> u32 {
    const REFERENCE_DATASETS: u32 = 50;
    REFERENCE_DATASETS
}
340
/// Default for `reclustering.max_clusters` when the key is omitted.
const fn default_max_clusters() -> usize {
    const MAX_CLUSTERS: usize = 24;
    MAX_CLUSTERS
}
344
345impl Default for ReclusterSettings {
346 #[inline]
347 fn default() -> Self {
348 Self {
349 gap_statistic_reference_datasets: default_gap_statistic_reference_datasets(),
350 max_clusters: default_max_clusters(),
351 algorithm: ClusterAlgorithm::default(),
352 projection_method: ProjectionMethod::default(),
353 }
354 }
355}
356
/// Settings read from the `[tui]` table of the config file.
#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
pub struct TuiSettings {
    /// Count used by the TUI's radio feature (inferred from the name — confirm exact
    /// meaning in the TUI). Defaults to 20.
    #[serde(default = "default_radio_count")]
    pub radio_count: u32,
    /// The `[tui.colors]` table; every color is unset when the table is omitted.
    #[serde(default)]
    pub colors: TuiColorScheme,
}
374
/// Default for `tui.radio_count` when the key is omitted.
const fn default_radio_count() -> u32 {
    const RADIO_COUNT: u32 = 20;
    RADIO_COUNT
}
378
379impl Default for TuiSettings {
380 #[inline]
381 fn default() -> Self {
382 Self {
383 radio_count: default_radio_count(),
384 colors: TuiColorScheme::default(),
385 }
386 }
387}
388
/// Optional color overrides read from the `[tui.colors]` table.
///
/// Each entry is a raw string, e.g. `"PINK_900"` or `"WHITE"` as in the tests.
/// NOTE(review): the strings are not validated here; interpretation presumably happens in
/// the TUI — confirm. An omitted key is `None`.
#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Default)]
pub struct TuiColorScheme {
    /// Border color of the main application frame.
    pub app_border: Option<String>,
    /// Text color on the main application border.
    pub app_border_text: Option<String>,
    /// Border color of unfocused panes.
    pub border_unfocused: Option<String>,
    /// Border color of the focused pane.
    pub border_focused: Option<String>,
    /// Border color of popups.
    pub popup_border: Option<String>,
    /// Normal text color.
    pub text_normal: Option<String>,
    /// Highlighted text color.
    pub text_highlight: Option<String>,
    /// Alternate highlighted text color.
    pub text_highlight_alt: Option<String>,
    /// Color of the filled part of gauges.
    pub gauge_filled: Option<String>,
    /// Color of the unfilled part of gauges.
    pub gauge_unfilled: Option<String>,
}
407
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use rstest::rstest;

    /// Wrapper routing a single value through `de_artist_separator` via `#[serde(transparent)]`.
    ///
    /// NOTE(review): no test in this module constructs or deserializes this type (hence the
    /// `allow(dead_code)`); confirm whether it is still needed.
    #[derive(Debug, PartialEq, Eq, Deserialize)]
    #[allow(dead_code)]
    #[serde(transparent)]
    struct ArtistSeparatorTest {
        #[serde(deserialize_with = "de_artist_separator")]
        artist_separator: OneOrMany<String>,
    }

    /// An empty list and an empty string must both deserialize to an empty separator set.
    #[rstest]
    #[case(Vec::<String>::new())]
    #[case("")]
    fn test_de_artist_separator_empty<'de, D>(#[case] input: D)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let deserializer = input.into_deserializer();
        let result: Result<OneOrMany<String>, _> = de_artist_separator(deserializer);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    /// Single strings and lists are both accepted, and empty entries inside a list are dropped.
    #[rstest]
    #[case(vec![" & "], String::from(" & ").into())]
    #[case(" & ", String::from(" & ").into())]
    #[case(vec![" & ", "; "], vec![String::from(" & "), String::from("; ")].into())]
    #[case(vec!["", " & ", "", "; "], vec![String::from(" & "), String::from("; ")].into())]
    fn test_de_artist_separator<'de, D>(#[case] input: D, #[case] expected: OneOrMany<String>)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let deserializer = input.into_deserializer();
        let result: Result<OneOrMany<String>, _> = de_artist_separator(deserializer);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), expected);
    }

    /// A fully specified config deserializes to exactly the expected settings. The omitted
    /// `protected_artist_names` and `projection_method` keys and the omitted `[analysis]`
    /// table fall back to their defaults.
    #[test]
    fn test_init_config() {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(
            &config_path,
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"

[tui]
radio_count = 21
[tui.colors]
app_border = "PINK_900"
app_border_text = "PINK_300"
border_unfocused = "RED_900"
border_focused = "RED_200"
popup_border = "LIGHT_BLUE_500"
text_normal = "WHITE"
text_highlight = "RED_600"
text_highlight_alt = "RED_200"
gauge_filled = "WHITE"
gauge_unfilled = "BLACK"
            "#,
        )
        .unwrap();

        let expected = Settings {
            daemon: DaemonSettings {
                rpc_port: 6600,
                library_paths: ["/Music".into()].into(),
                artist_separator: vec!["; ".into()].into(),
                protected_artist_names: OneOrMany::None,
                genre_separator: Some(", ".into()),
                conflict_resolution: MetadataConflictResolution::Overwrite,
                log_level: log::LevelFilter::Debug,
            },
            analysis: AnalysisSettings::default(),
            reclustering: ReclusterSettings {
                gap_statistic_reference_datasets: 50,
                max_clusters: 24,
                algorithm: ClusterAlgorithm::GMM,
                projection_method: ProjectionMethod::TSne,
            },
            tui: TuiSettings {
                radio_count: 21,
                colors: TuiColorScheme {
                    app_border: Some("PINK_900".into()),
                    app_border_text: Some("PINK_300".into()),
                    border_unfocused: Some("RED_900".into()),
                    border_focused: Some("RED_200".into()),
                    popup_border: Some("LIGHT_BLUE_500".into()),
                    text_normal: Some("WHITE".into()),
                    text_highlight: Some("RED_600".into()),
                    text_highlight_alt: Some("RED_200".into()),
                    gauge_filled: Some("WHITE".into()),
                    gauge_unfilled: Some("BLACK".into()),
                },
            },
        };

        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings, expected);
    }

    /// Without a `[tui.colors]` table, every color in the scheme is `None`.
    ///
    /// NOTE(review): the fixture and expectation are identical to
    /// `test_artist_names_to_not_split`; consider dropping `protected_artist_names` here so
    /// each test isolates one concern.
    #[test]
    fn test_tui_colors_unset() {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(
            &config_path,
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
protected_artist_names = ["Foo & Bar"]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"

[tui]
radio_count = 21
            "#,
        )
        .unwrap();

        let expected = Settings {
            daemon: DaemonSettings {
                rpc_port: 6600,
                library_paths: ["/Music".into()].into(),
                artist_separator: vec!["; ".into()].into(),
                protected_artist_names: "Foo & Bar".to_string().into(),
                genre_separator: Some(", ".into()),
                conflict_resolution: MetadataConflictResolution::Overwrite,
                log_level: log::LevelFilter::Debug,
            },
            analysis: AnalysisSettings::default(),
            reclustering: ReclusterSettings {
                gap_statistic_reference_datasets: 50,
                max_clusters: 24,
                algorithm: ClusterAlgorithm::GMM,
                projection_method: ProjectionMethod::TSne,
            },
            tui: TuiSettings {
                radio_count: 21,
                colors: TuiColorScheme::default(),
            },
        };

        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings, expected);
    }

    /// A one-element `protected_artist_names` list deserializes to a value equal to the
    /// single-string form.
    ///
    /// NOTE(review): this only checks deserialization, not that protected names are actually
    /// kept intact when artists are split.
    #[test]
    fn test_artist_names_to_not_split() {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(
            &config_path,
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
protected_artist_names = ["Foo & Bar"]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"

[tui]
radio_count = 21
            "#,
        )
        .unwrap();

        let expected = Settings {
            daemon: DaemonSettings {
                rpc_port: 6600,
                library_paths: ["/Music".into()].into(),
                artist_separator: vec!["; ".into()].into(),
                protected_artist_names: "Foo & Bar".to_string().into(),
                genre_separator: Some(", ".into()),
                conflict_resolution: MetadataConflictResolution::Overwrite,
                log_level: log::LevelFilter::Debug,
            },
            analysis: AnalysisSettings::default(),
            reclustering: ReclusterSettings {
                gap_statistic_reference_datasets: 50,
                max_clusters: 24,
                algorithm: ClusterAlgorithm::GMM,
                projection_method: ProjectionMethod::TSne,
            },
            tui: TuiSettings {
                radio_count: 21,
                colors: TuiColorScheme::default(),
            },
        };

        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings, expected);
    }

    /// The bundled `DEFAULT_CONFIG` must load through `Settings::init` without error.
    #[test]
    fn test_default_config_works() {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(&config_path, DEFAULT_CONFIG).unwrap();

        let settings = Settings::init(config_path, None, None);

        assert!(settings.is_ok(), "Error: {:?}", settings.err());
    }
}