1use config::{Config, ConfigError, Environment, File};
7use one_or_many::OneOrMany;
8use serde::Deserialize;
9
10use std::{path::PathBuf, str::FromStr};
11
12use mecomp_storage::util::MetadataConflictResolution;
13
14pub static DEFAULT_CONFIG: &str = include_str!("../Mecomp.toml");
15
16#[derive(Clone, Debug, Deserialize, Default, PartialEq, Eq)]
17pub struct Settings {
18 #[serde(default)]
20 pub daemon: DaemonSettings,
21 #[serde(default)]
23 pub reclustering: ReclusterSettings,
24 #[serde(default)]
26 pub tui: TuiSettings,
27}
28
29impl Settings {
30 #[inline]
45 pub fn init(
46 config: PathBuf,
47 port: Option<u16>,
48 log_level: Option<log::LevelFilter>,
49 ) -> Result<Self, ConfigError> {
50 let s = Config::builder()
51 .add_source(File::from(config))
52 .add_source(Environment::with_prefix("MECOMP"))
53 .build()?;
54
55 let mut settings: Self = s.try_deserialize()?;
56
57 for path in &mut settings.daemon.library_paths {
58 *path = shellexpand::tilde(&path.to_string_lossy())
59 .into_owned()
60 .into();
61 }
62
63 if let Some(port) = port {
64 settings.daemon.rpc_port = port;
65 }
66
67 if let Some(log_level) = log_level {
68 settings.daemon.log_level = log_level;
69 }
70
71 Ok(settings)
72 }
73
74 #[inline]
83 pub fn get_config_path() -> Result<PathBuf, std::io::Error> {
84 match crate::get_config_dir() {
85 Ok(config_dir) => {
86 if !config_dir.exists() {
88 std::fs::create_dir_all(&config_dir)?;
89 }
90 let config_file = config_dir.join("Mecomp.toml");
91
92 if !config_file.exists() {
93 std::fs::write(&config_file, DEFAULT_CONFIG)?;
94 }
95
96 Ok(config_file)
97 }
98 Err(e) => {
99 eprintln!("Error: {e}");
100 Err(std::io::Error::new(
101 std::io::ErrorKind::NotFound,
102 "Unable to find the config directory for mecomp.",
103 ))
104 }
105 }
106 }
107}
108
109#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
110pub struct DaemonSettings {
111 #[serde(default = "default_port")]
114 pub rpc_port: u16,
115 #[serde(default = "default_library_paths")]
117 pub library_paths: Box<[PathBuf]>,
118 #[serde(default, deserialize_with = "de_artist_separator")]
132 pub artist_separator: OneOrMany<String>,
133 #[serde(default)]
151 pub protected_artist_names: OneOrMany<String>,
152 #[serde(default)]
153 pub genre_separator: Option<String>,
154 #[serde(default)]
158 pub conflict_resolution: MetadataConflictResolution,
159 #[serde(default = "default_log_level")]
162 #[serde(deserialize_with = "de_log_level")]
163 pub log_level: log::LevelFilter,
164}
165
166fn de_artist_separator<'de, D>(deserializer: D) -> Result<OneOrMany<String>, D::Error>
167where
168 D: serde::Deserializer<'de>,
169{
170 let v = OneOrMany::<String>::deserialize(deserializer)?
171 .into_iter()
172 .filter(|s| !s.is_empty())
173 .collect::<OneOrMany<String>>();
174 if v.is_empty() {
175 Ok(OneOrMany::None)
176 } else {
177 Ok(v)
178 }
179}
180
181fn de_log_level<'de, D>(deserializer: D) -> Result<log::LevelFilter, D::Error>
182where
183 D: serde::Deserializer<'de>,
184{
185 let s = String::deserialize(deserializer)?;
186 Ok(log::LevelFilter::from_str(&s).unwrap_or_else(|_| default_log_level()))
187}
188
/// RPC port used when the configuration does not specify one.
const fn default_port() -> u16 {
    const DEFAULT_RPC_PORT: u16 = 6600;
    DEFAULT_RPC_PORT
}
192
193fn default_library_paths() -> Box<[PathBuf]> {
194 vec![shellexpand::tilde("~/Music/").into_owned().into()].into_boxed_slice()
195}
196
197const fn default_log_level() -> log::LevelFilter {
198 log::LevelFilter::Info
199}
200
201impl Default for DaemonSettings {
202 #[inline]
203 fn default() -> Self {
204 Self {
205 rpc_port: default_port(),
206 library_paths: default_library_paths(),
207 artist_separator: OneOrMany::None,
208 protected_artist_names: OneOrMany::None,
209 genre_separator: None,
210 conflict_resolution: MetadataConflictResolution::Overwrite,
211 log_level: default_log_level(),
212 }
213 }
214}
215
216#[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)]
217#[serde(rename_all = "lowercase")]
218pub enum ClusterAlgorithm {
219 KMeans,
220 #[default]
221 GMM,
222}
223
#[cfg(feature = "analysis")]
impl From<ClusterAlgorithm> for mecomp_analysis::clustering::ClusteringMethod {
    /// Maps the configuration-level algorithm onto the analysis crate's
    /// clustering method (`GMM` becomes `GaussianMixtureModel`).
    #[inline]
    fn from(algorithm: ClusterAlgorithm) -> Self {
        match algorithm {
            ClusterAlgorithm::GMM => Self::GaussianMixtureModel,
            ClusterAlgorithm::KMeans => Self::KMeans,
        }
    }
}
234
235#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Eq)]
236#[serde(rename_all = "lowercase")]
237pub enum ProjectionMethod {
238 #[default]
239 None,
240 TSne,
241 Pca,
242}
243
#[cfg(feature = "analysis")]
impl From<ProjectionMethod> for mecomp_analysis::clustering::ProjectionMethod {
    /// Maps the configuration-level projection method onto the identically
    /// named variant of the analysis crate's enum.
    #[inline]
    fn from(method: ProjectionMethod) -> Self {
        match method {
            ProjectionMethod::Pca => Self::Pca,
            ProjectionMethod::TSne => Self::TSne,
            ProjectionMethod::None => Self::None,
        }
    }
}
255
256#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq)]
257pub struct ReclusterSettings {
258 #[serde(default = "default_gap_statistic_reference_datasets")]
264 pub gap_statistic_reference_datasets: usize,
265 #[serde(default = "default_max_clusters")]
270 pub max_clusters: usize,
271 #[serde(default)]
274 pub algorithm: ClusterAlgorithm,
275 #[serde(default)]
279 pub projection_method: ProjectionMethod,
280}
281
/// Number of gap-statistic reference datasets used when the configuration
/// does not specify one.
const fn default_gap_statistic_reference_datasets() -> usize {
    const DEFAULT_REFERENCE_DATASETS: usize = 50;
    DEFAULT_REFERENCE_DATASETS
}
285
/// Maximum cluster count used when the configuration does not specify one:
/// `16` when debug assertions are enabled, `24` otherwise.
const fn default_max_clusters() -> usize {
    if cfg!(debug_assertions) {
        16
    } else {
        24
    }
}
292
293impl Default for ReclusterSettings {
294 #[inline]
295 fn default() -> Self {
296 Self {
297 gap_statistic_reference_datasets: default_gap_statistic_reference_datasets(),
298 max_clusters: default_max_clusters(),
299 algorithm: ClusterAlgorithm::default(),
300 projection_method: ProjectionMethod::default(),
301 }
302 }
303}
304
305#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
306pub struct TuiSettings {
307 #[serde(default = "default_radio_count")]
310 pub radio_count: u32,
311}
312
/// Radio count used when the configuration does not specify one.
const fn default_radio_count() -> u32 {
    const DEFAULT_RADIO_COUNT: u32 = 20;
    DEFAULT_RADIO_COUNT
}
316
317impl Default for TuiSettings {
318 #[inline]
319 fn default() -> Self {
320 Self {
321 radio_count: default_radio_count(),
322 }
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use rstest::rstest;

    /// Transparent wrapper that routes deserialization through
    /// `de_artist_separator`.
    #[derive(Debug, PartialEq, Eq, Deserialize)]
    #[serde(transparent)]
    struct ArtistSeparatorTest {
        #[serde(deserialize_with = "de_artist_separator")]
        artist_separator: OneOrMany<String>,
    }

    /// Writes `contents` to `config.toml` inside a fresh temporary directory.
    ///
    /// The returned directory guard must stay alive for as long as the file is
    /// needed, because dropping it removes the directory.
    fn write_temp_config(contents: &str) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("config.toml");
        std::fs::write(&path, contents).unwrap();
        (dir, path)
    }

    /// Settings expected from the TOML fixtures below, which differ only in
    /// `protected_artist_names`.
    fn expected_settings(protected_artist_names: OneOrMany<String>) -> Settings {
        Settings {
            daemon: DaemonSettings {
                rpc_port: 6600,
                library_paths: ["/Music".into()].into(),
                artist_separator: vec!["; ".into()].into(),
                protected_artist_names,
                genre_separator: Some(", ".into()),
                conflict_resolution: MetadataConflictResolution::Overwrite,
                log_level: log::LevelFilter::Debug,
            },
            reclustering: ReclusterSettings {
                gap_statistic_reference_datasets: 50,
                max_clusters: 24,
                algorithm: ClusterAlgorithm::GMM,
                projection_method: ProjectionMethod::None,
            },
            tui: TuiSettings { radio_count: 21 },
        }
    }

    /// Inputs with no non-empty separator must deserialize to an empty value.
    #[rstest]
    #[case(Vec::<String>::new())]
    #[case("")]
    fn test_de_artist_separator_empty<'de, D>(#[case] input: D)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let separators = de_artist_separator(input.into_deserializer())
            .expect("input without separators should still deserialize");
        assert!(separators.is_empty());
    }

    /// Non-empty separators are kept in order; empty ones are dropped.
    #[rstest]
    #[case(vec![" & "], String::from(" & ").into())]
    #[case(" & ", String::from(" & ").into())]
    #[case(vec![" & ", "; "], vec![String::from(" & "), String::from("; ")].into())]
    #[case(vec!["", " & ", "", "; "], vec![String::from(" & "), String::from("; ")].into())]
    fn test_de_artist_separator<'de, D>(#[case] input: D, #[case] expected: OneOrMany<String>)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let separators = de_artist_separator(input.into_deserializer())
            .expect("separator input should deserialize");
        assert_eq!(separators, expected);
    }

    #[test]
    fn test_init_config() {
        let (_guard, config_path) = write_temp_config(
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"

[tui]
radio_count = 21
            "#,
        );

        let expected = expected_settings(OneOrMany::None);
        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings, expected);
    }

    #[test]
    fn test_artist_names_to_not_split() {
        let (_guard, config_path) = write_temp_config(
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
protected_artist_names = ["Foo & Bar"]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"

[tui]
radio_count = 21
            "#,
        );

        let expected = expected_settings(OneOrMany::One("Foo & Bar".into()));
        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings, expected);
    }

    /// The bundled default configuration must load without errors.
    #[test]
    fn test_default_config_works() {
        let (_guard, config_path) = write_temp_config(DEFAULT_CONFIG);

        let settings = Settings::init(config_path, None, None);

        assert!(settings.is_ok(), "Error: {:?}", settings.err());
    }
}