1use config::{Config, ConfigError, Environment, File};
7use one_or_many::OneOrMany;
8use serde::Deserialize;
9
10use std::{num::NonZeroUsize, path::PathBuf, str::FromStr};
11
12use mecomp_storage::util::MetadataConflictResolution;
13
/// The default `Mecomp.toml`, embedded into the binary at compile time.
///
/// Written to disk by [`Settings::get_config_path`] when no config file exists yet.
pub static DEFAULT_CONFIG: &str = include_str!("../Mecomp.toml");
15
16#[derive(Clone, Debug, Deserialize, Default, PartialEq, Eq)]
17pub struct Settings {
18 #[serde(default)]
20 pub daemon: DaemonSettings,
21 #[serde(default)]
23 pub analysis: AnalysisSettings,
24 #[serde(default)]
26 pub reclustering: ReclusterSettings,
27 #[serde(default)]
29 pub tui: TuiSettings,
30}
31
32impl Settings {
33 #[inline]
48 pub fn init(
49 config: PathBuf,
50 port: Option<u16>,
51 log_level: Option<log::LevelFilter>,
52 ) -> Result<Self, ConfigError> {
53 let s = Config::builder()
54 .add_source(File::from(config))
55 .add_source(Environment::with_prefix("MECOMP"))
56 .build()?;
57
58 let mut settings: Self = s.try_deserialize()?;
59
60 for path in &mut settings.daemon.library_paths {
61 *path = shellexpand::tilde(&path.to_string_lossy())
62 .into_owned()
63 .into();
64 }
65
66 if let Some(port) = port {
67 settings.daemon.rpc_port = port;
68 }
69
70 if let Some(log_level) = log_level {
71 settings.daemon.log_level = log_level;
72 }
73
74 Ok(settings)
75 }
76
77 #[inline]
86 pub fn get_config_path() -> Result<PathBuf, std::io::Error> {
87 match crate::get_config_dir() {
88 Ok(config_dir) => {
89 if !config_dir.exists() {
91 std::fs::create_dir_all(&config_dir)?;
92 }
93 let config_file = config_dir.join("Mecomp.toml");
94
95 if !config_file.exists() {
96 std::fs::write(&config_file, DEFAULT_CONFIG)?;
97 }
98
99 Ok(config_file)
100 }
101 Err(e) => {
102 eprintln!("Error: {e}");
103 Err(std::io::Error::new(
104 std::io::ErrorKind::NotFound,
105 "Unable to find the config directory for mecomp.",
106 ))
107 }
108 }
109 }
110}
111
112#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
113pub struct DaemonSettings {
114 #[serde(default = "default_port")]
117 pub rpc_port: u16,
118 #[serde(default = "default_library_paths")]
120 pub library_paths: Box<[PathBuf]>,
121 #[serde(default, deserialize_with = "de_artist_separator")]
135 pub artist_separator: OneOrMany<String>,
136 #[serde(default)]
154 pub protected_artist_names: OneOrMany<String>,
155 #[serde(default)]
156 pub genre_separator: Option<String>,
157 #[serde(default)]
161 pub conflict_resolution: MetadataConflictResolution,
162 #[serde(default = "default_log_level")]
165 #[serde(deserialize_with = "de_log_level")]
166 pub log_level: log::LevelFilter,
167}
168
169fn de_artist_separator<'de, D>(deserializer: D) -> Result<OneOrMany<String>, D::Error>
170where
171 D: serde::Deserializer<'de>,
172{
173 let v = OneOrMany::<String>::deserialize(deserializer)?
174 .into_iter()
175 .filter(|s| !s.is_empty())
176 .collect::<OneOrMany<String>>();
177 if v.is_empty() {
178 Ok(OneOrMany::None)
179 } else {
180 Ok(v)
181 }
182}
183
184fn de_log_level<'de, D>(deserializer: D) -> Result<log::LevelFilter, D::Error>
185where
186 D: serde::Deserializer<'de>,
187{
188 let s = String::deserialize(deserializer)?;
189 Ok(log::LevelFilter::from_str(&s).unwrap_or_else(|_| default_log_level()))
190}
191
/// Default RPC port for the daemon.
const fn default_port() -> u16 {
    const PORT: u16 = 6_600;
    PORT
}
195
196fn default_library_paths() -> Box<[PathBuf]> {
197 vec![shellexpand::tilde("~/Music/").into_owned().into()].into_boxed_slice()
198}
199
200const fn default_log_level() -> log::LevelFilter {
201 log::LevelFilter::Info
202}
203
204impl Default for DaemonSettings {
205 #[inline]
206 fn default() -> Self {
207 Self {
208 rpc_port: default_port(),
209 library_paths: default_library_paths(),
210 artist_separator: OneOrMany::None,
211 protected_artist_names: OneOrMany::None,
212 genre_separator: None,
213 conflict_resolution: MetadataConflictResolution::Overwrite,
214 log_level: default_log_level(),
215 }
216 }
217}
218
219#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Eq)]
220#[serde(rename_all = "lowercase")]
221pub enum AnalysisKind {
222 #[default]
223 Features,
224 Embedding,
225}
226
227#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
228pub struct AnalysisSettings {
229 #[serde(default)]
237 pub kind: AnalysisKind,
238 #[serde(default)]
246 pub num_threads: Option<NonZeroUsize>,
247 #[serde(default)]
256 pub model_path: Option<PathBuf>,
257}
258
259impl Default for AnalysisSettings {
260 #[inline]
261 fn default() -> Self {
262 Self {
263 kind: AnalysisKind::default(),
264 num_threads: None,
265 model_path: None,
266 }
267 }
268}
269
270#[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)]
271#[serde(rename_all = "lowercase")]
272pub enum ClusterAlgorithm {
273 KMeans,
274 #[default]
275 GMM,
276}
277
#[cfg(feature = "analysis")]
impl From<ClusterAlgorithm> for mecomp_analysis::clustering::ClusteringMethod {
    /// Maps the config-level algorithm choice onto the analysis crate's clustering method.
    #[inline]
    fn from(algorithm: ClusterAlgorithm) -> Self {
        use ClusterAlgorithm as Algo;

        match algorithm {
            Algo::GMM => Self::GaussianMixtureModel,
            Algo::KMeans => Self::KMeans,
        }
    }
}
288
289#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Eq)]
290#[serde(rename_all = "lowercase")]
291pub enum ProjectionMethod {
292 #[default]
293 None,
294 TSne,
295 Pca,
296}
297
#[cfg(feature = "analysis")]
impl From<ProjectionMethod> for mecomp_analysis::clustering::ProjectionMethod {
    /// Maps the config-level projection choice onto the analysis crate's equivalent variant.
    #[inline]
    fn from(method: ProjectionMethod) -> Self {
        use ProjectionMethod as Proj;

        match method {
            Proj::Pca => Self::Pca,
            Proj::TSne => Self::TSne,
            Proj::None => Self::None,
        }
    }
}
309
310#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq)]
311pub struct ReclusterSettings {
312 #[serde(default = "default_gap_statistic_reference_datasets")]
318 pub gap_statistic_reference_datasets: u32,
319 #[serde(default = "default_max_clusters")]
324 pub max_clusters: usize,
325 #[serde(default)]
328 pub algorithm: ClusterAlgorithm,
329 #[serde(default)]
333 pub projection_method: ProjectionMethod,
334}
335
/// Default number of gap-statistic reference datasets.
const fn default_gap_statistic_reference_datasets() -> u32 {
    const REFERENCE_DATASETS: u32 = 50;
    REFERENCE_DATASETS
}
339
/// Default upper bound on the number of clusters.
///
/// Debug builds use a smaller value (16) than release builds (24).
const fn default_max_clusters() -> usize {
    if cfg!(debug_assertions) {
        16
    } else {
        24
    }
}
346
347impl Default for ReclusterSettings {
348 #[inline]
349 fn default() -> Self {
350 Self {
351 gap_statistic_reference_datasets: default_gap_statistic_reference_datasets(),
352 max_clusters: default_max_clusters(),
353 algorithm: ClusterAlgorithm::default(),
354 projection_method: ProjectionMethod::default(),
355 }
356 }
357}
358
359#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
360pub struct TuiSettings {
361 #[serde(default = "default_radio_count")]
364 pub radio_count: u32,
365 #[serde(default)]
374 pub colors: TuiColorScheme,
375}
376
/// Default radio size for the TUI.
const fn default_radio_count() -> u32 {
    const RADIO_COUNT: u32 = 20;
    RADIO_COUNT
}
380
381impl Default for TuiSettings {
382 #[inline]
383 fn default() -> Self {
384 Self {
385 radio_count: default_radio_count(),
386 colors: TuiColorScheme::default(),
387 }
388 }
389}
390
391#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Default)]
392pub struct TuiColorScheme {
393 pub app_border: Option<String>,
395 pub app_border_text: Option<String>,
396 pub border_unfocused: Option<String>,
398 pub border_focused: Option<String>,
399 pub popup_border: Option<String>,
401 pub text_normal: Option<String>,
403 pub text_highlight: Option<String>,
404 pub text_highlight_alt: Option<String>,
405 pub gauge_filled: Option<String>,
407 pub gauge_unfilled: Option<String>,
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use rstest::rstest;

    /// Writes `contents` to a fresh temporary config file and loads it with
    /// [`Settings::init`] without CLI overrides.
    ///
    /// The temporary directory is removed on return. `init` has already read the file by
    /// then, so nothing refers to it afterwards.
    fn load_settings(contents: &str) -> Settings {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(&config_path, contents).unwrap();
        Settings::init(config_path, None, None).unwrap()
    }

    /// Expected `[daemon]` section for the file-based tests below. Only
    /// `protected_artist_names` varies between them.
    fn expected_daemon(protected_artist_names: OneOrMany<String>) -> DaemonSettings {
        DaemonSettings {
            rpc_port: 6600,
            library_paths: ["/Music".into()].into(),
            artist_separator: vec!["; ".into()].into(),
            protected_artist_names,
            genre_separator: Some(", ".into()),
            conflict_resolution: MetadataConflictResolution::Overwrite,
            log_level: log::LevelFilter::Debug,
        }
    }

    /// Expected `[reclustering]` section for the file-based tests below.
    fn expected_reclustering() -> ReclusterSettings {
        ReclusterSettings {
            gap_statistic_reference_datasets: 50,
            max_clusters: 24,
            algorithm: ClusterAlgorithm::GMM,
            projection_method: ProjectionMethod::None,
        }
    }

    #[rstest]
    #[case(Vec::<String>::new())]
    #[case("")]
    fn test_de_artist_separator_empty<'de, D>(#[case] input: D)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let deserializer = input.into_deserializer();
        let result: Result<OneOrMany<String>, _> = de_artist_separator(deserializer);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[rstest]
    #[case(vec![" & "], String::from(" & ").into())]
    #[case(" & ", String::from(" & ").into())]
    #[case(vec![" & ", "; "], vec![String::from(" & "), String::from("; ")].into())]
    #[case(vec!["", " & ", "", "; "], vec![String::from(" & "), String::from("; ")].into())]
    fn test_de_artist_separator<'de, D>(#[case] input: D, #[case] expected: OneOrMany<String>)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let deserializer = input.into_deserializer();
        let result: Result<OneOrMany<String>, _> = de_artist_separator(deserializer);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), expected);
    }

    #[test]
    fn test_init_config() {
        let settings = load_settings(
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"

[tui]
radio_count = 21
[tui.colors]
app_border = "PINK_900"
app_border_text = "PINK_300"
border_unfocused = "RED_900"
border_focused = "RED_200"
popup_border = "LIGHT_BLUE_500"
text_normal = "WHITE"
text_highlight = "RED_600"
text_highlight_alt = "RED_200"
gauge_filled = "WHITE"
gauge_unfilled = "BLACK"
"#,
        );

        let expected = Settings {
            daemon: expected_daemon(OneOrMany::None),
            analysis: AnalysisSettings::default(),
            reclustering: expected_reclustering(),
            tui: TuiSettings {
                radio_count: 21,
                colors: TuiColorScheme {
                    app_border: Some("PINK_900".into()),
                    app_border_text: Some("PINK_300".into()),
                    border_unfocused: Some("RED_900".into()),
                    border_focused: Some("RED_200".into()),
                    popup_border: Some("LIGHT_BLUE_500".into()),
                    text_normal: Some("WHITE".into()),
                    text_highlight: Some("RED_600".into()),
                    text_highlight_alt: Some("RED_200".into()),
                    gauge_filled: Some("WHITE".into()),
                    gauge_unfilled: Some("BLACK".into()),
                },
            },
        };

        assert_eq!(settings, expected);
    }

    /// Without a `[tui.colors]` table, every color override stays unset.
    #[test]
    fn test_tui_colors_unset() {
        let settings = load_settings(
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"

[tui]
radio_count = 21
"#,
        );

        let expected = Settings {
            daemon: expected_daemon(OneOrMany::None),
            analysis: AnalysisSettings::default(),
            reclustering: expected_reclustering(),
            tui: TuiSettings {
                radio_count: 21,
                colors: TuiColorScheme::default(),
            },
        };

        assert_eq!(settings, expected);
    }

    /// `protected_artist_names` is read verbatim, even when it contains a separator.
    #[test]
    fn test_artist_names_to_not_split() {
        let settings = load_settings(
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
protected_artist_names = ["Foo & Bar"]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"

[tui]
radio_count = 21
"#,
        );

        let expected = Settings {
            daemon: expected_daemon("Foo & Bar".to_string().into()),
            analysis: AnalysisSettings::default(),
            reclustering: expected_reclustering(),
            tui: TuiSettings {
                radio_count: 21,
                colors: TuiColorScheme::default(),
            },
        };

        assert_eq!(settings, expected);
    }

    /// CLI-provided port and log level replace the values from the file.
    #[test]
    fn test_cli_overrides_take_precedence() {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(
            &config_path,
            r#"
[daemon]
rpc_port = 6600
log_level = "debug"
"#,
        )
        .unwrap();

        let settings =
            Settings::init(config_path, Some(1234), Some(log::LevelFilter::Trace)).unwrap();

        assert_eq!(settings.daemon.rpc_port, 1234);
        assert_eq!(settings.daemon.log_level, log::LevelFilter::Trace);
    }

    /// An unrecognised log level falls back to the default instead of failing.
    #[test]
    fn test_invalid_log_level_falls_back_to_default() {
        let settings = load_settings(
            r#"
[daemon]
log_level = "not-a-level"
"#,
        );

        assert_eq!(settings.daemon.log_level, default_log_level());
    }

    /// `~`-prefixed library paths are expanded; absolute paths are kept as-is.
    #[test]
    fn test_library_paths_tilde_expansion() {
        let settings = load_settings(
            r#"
[daemon]
library_paths = ["~/Music", "/srv/music"]
"#,
        );

        let expected: Box<[PathBuf]> = [
            PathBuf::from(shellexpand::tilde("~/Music").into_owned()),
            PathBuf::from("/srv/music"),
        ]
        .into();

        assert_eq!(settings.daemon.library_paths, expected);
    }

    #[test]
    fn test_default_config_works() {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(&config_path, DEFAULT_CONFIG).unwrap();

        let settings = Settings::init(config_path, None, None);

        assert!(settings.is_ok(), "Error: {:?}", settings.err());
    }
}