1use config::{Config, ConfigError, Environment, File};
7use one_or_many::OneOrMany;
8use serde::Deserialize;
9
10use std::{path::PathBuf, str::FromStr};
11
12use mecomp_storage::util::MetadataConflictResolution;
13
/// The default `Mecomp.toml` contents, embedded into the binary at compile time.
///
/// `Settings::get_config_path` writes this to disk when no config file exists yet.
pub static DEFAULT_CONFIG: &str = include_str!("../Mecomp.toml");
15
16#[derive(Clone, Debug, Deserialize, Default, PartialEq, Eq)]
17pub struct Settings {
18 #[serde(default)]
20 pub daemon: DaemonSettings,
21 #[serde(default)]
23 pub reclustering: ReclusterSettings,
24 #[serde(default)]
26 pub tui: TuiSettings,
27}
28
29impl Settings {
30 #[inline]
45 pub fn init(
46 config: PathBuf,
47 port: Option<u16>,
48 log_level: Option<log::LevelFilter>,
49 ) -> Result<Self, ConfigError> {
50 let s = Config::builder()
51 .add_source(File::from(config))
52 .add_source(Environment::with_prefix("MECOMP"))
53 .build()?;
54
55 let mut settings: Self = s.try_deserialize()?;
56
57 for path in &mut settings.daemon.library_paths {
58 *path = shellexpand::tilde(&path.to_string_lossy())
59 .into_owned()
60 .into();
61 }
62
63 if let Some(port) = port {
64 settings.daemon.rpc_port = port;
65 }
66
67 if let Some(log_level) = log_level {
68 settings.daemon.log_level = log_level;
69 }
70
71 Ok(settings)
72 }
73
74 #[inline]
83 pub fn get_config_path() -> Result<PathBuf, std::io::Error> {
84 match crate::get_config_dir() {
85 Ok(config_dir) => {
86 if !config_dir.exists() {
88 std::fs::create_dir_all(&config_dir)?;
89 }
90 let config_file = config_dir.join("Mecomp.toml");
91
92 if !config_file.exists() {
93 std::fs::write(&config_file, DEFAULT_CONFIG)?;
94 }
95
96 Ok(config_file)
97 }
98 Err(e) => {
99 eprintln!("Error: {e}");
100 Err(std::io::Error::new(
101 std::io::ErrorKind::NotFound,
102 "Unable to find the config directory for mecomp.",
103 ))
104 }
105 }
106 }
107}
108
109#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
110pub struct DaemonSettings {
111 #[serde(default = "default_port")]
114 pub rpc_port: u16,
115 #[serde(default = "default_library_paths")]
117 pub library_paths: Box<[PathBuf]>,
118 #[serde(default, deserialize_with = "de_artist_separator")]
132 pub artist_separator: OneOrMany<String>,
133 #[serde(default)]
151 pub protected_artist_names: OneOrMany<String>,
152 #[serde(default)]
153 pub genre_separator: Option<String>,
154 #[serde(default)]
158 pub conflict_resolution: MetadataConflictResolution,
159 #[serde(default = "default_log_level")]
162 #[serde(deserialize_with = "de_log_level")]
163 pub log_level: log::LevelFilter,
164}
165
166fn de_artist_separator<'de, D>(deserializer: D) -> Result<OneOrMany<String>, D::Error>
167where
168 D: serde::Deserializer<'de>,
169{
170 let v = OneOrMany::<String>::deserialize(deserializer)?
171 .into_iter()
172 .filter(|s| !s.is_empty())
173 .collect::<OneOrMany<String>>();
174 if v.is_empty() {
175 Ok(OneOrMany::None)
176 } else {
177 Ok(v)
178 }
179}
180
181fn de_log_level<'de, D>(deserializer: D) -> Result<log::LevelFilter, D::Error>
182where
183 D: serde::Deserializer<'de>,
184{
185 let s = String::deserialize(deserializer)?;
186 Ok(log::LevelFilter::from_str(&s).unwrap_or_else(|_| default_log_level()))
187}
188
/// Default RPC port for the daemon.
const fn default_port() -> u16 {
    const DEFAULT_RPC_PORT: u16 = 6600;
    DEFAULT_RPC_PORT
}
192
193fn default_library_paths() -> Box<[PathBuf]> {
194 vec![shellexpand::tilde("~/Music/").into_owned().into()].into_boxed_slice()
195}
196
197const fn default_log_level() -> log::LevelFilter {
198 log::LevelFilter::Info
199}
200
201impl Default for DaemonSettings {
202 #[inline]
203 fn default() -> Self {
204 Self {
205 rpc_port: default_port(),
206 library_paths: default_library_paths(),
207 artist_separator: OneOrMany::None,
208 protected_artist_names: OneOrMany::None,
209 genre_separator: None,
210 conflict_resolution: MetadataConflictResolution::Overwrite,
211 log_level: default_log_level(),
212 }
213 }
214}
215
216#[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)]
217#[serde(rename_all = "lowercase")]
218pub enum ClusterAlgorithm {
219 KMeans,
220 #[default]
221 GMM,
222}
223
224impl From<ClusterAlgorithm> for mecomp_analysis::clustering::ClusteringMethod {
225 #[inline]
226 fn from(algo: ClusterAlgorithm) -> Self {
227 match algo {
228 ClusterAlgorithm::KMeans => Self::KMeans,
229 ClusterAlgorithm::GMM => Self::GaussianMixtureModel,
230 }
231 }
232}
233
234#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Eq)]
235#[serde(rename_all = "lowercase")]
236pub enum ProjectionMethod {
237 #[default]
238 None,
239 TSne,
240 Pca,
241}
242
243impl From<ProjectionMethod> for mecomp_analysis::clustering::ProjectionMethod {
244 #[inline]
245 fn from(proj: ProjectionMethod) -> Self {
246 match proj {
247 ProjectionMethod::None => Self::None,
248 ProjectionMethod::TSne => Self::TSne,
249 ProjectionMethod::Pca => Self::Pca,
250 }
251 }
252}
253
254#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq)]
255pub struct ReclusterSettings {
256 #[serde(default = "default_gap_statistic_reference_datasets")]
262 pub gap_statistic_reference_datasets: usize,
263 #[serde(default = "default_max_clusters")]
268 pub max_clusters: usize,
269 #[serde(default)]
272 pub algorithm: ClusterAlgorithm,
273 #[serde(default)]
277 pub projection_method: ProjectionMethod,
278}
279
/// Default number of gap-statistic reference datasets.
const fn default_gap_statistic_reference_datasets() -> usize {
    50_usize
}
283
/// Default upper bound on the number of clusters: `16` in debug builds and
/// `24` in release builds.
const fn default_max_clusters() -> usize {
    // `cfg!` expands to a constant `bool`, so the branch is resolved at compile
    // time. This keeps the body a single expression instead of two
    // `cfg`-gated `return` statements.
    if cfg!(debug_assertions) {
        16
    } else {
        24
    }
}
290
291impl Default for ReclusterSettings {
292 #[inline]
293 fn default() -> Self {
294 Self {
295 gap_statistic_reference_datasets: default_gap_statistic_reference_datasets(),
296 max_clusters: default_max_clusters(),
297 algorithm: ClusterAlgorithm::default(),
298 projection_method: ProjectionMethod::default(),
299 }
300 }
301}
302
303#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
304pub struct TuiSettings {
305 #[serde(default = "default_radio_count")]
308 pub radio_count: u32,
309}
310
/// Default value for `TuiSettings::radio_count`.
const fn default_radio_count() -> u32 {
    20_u32
}
314
315impl Default for TuiSettings {
316 #[inline]
317 fn default() -> Self {
318 Self {
319 radio_count: default_radio_count(),
320 }
321 }
322}
323
// Unit tests for separator deserialization and for loading full config files.
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use rstest::rstest;

    // NOTE(review): this wrapper is not referenced by any test in this module;
    // it looks like a leftover from an earlier version of these tests. Confirm
    // and remove it.
    #[derive(Debug, PartialEq, Eq, Deserialize)]
    #[serde(transparent)]
    struct ArtistSeparatorTest {
        #[serde(deserialize_with = "de_artist_separator")]
        artist_separator: OneOrMany<String>,
    }

    // An empty list and a lone empty string must both deserialize to an empty
    // separator set.
    #[rstest]
    #[case(Vec::<String>::new())]
    #[case("")]
    fn test_de_artist_separator_empty<'de, D>(#[case] input: D)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let deserializer = input.into_deserializer();
        let result: Result<OneOrMany<String>, _> = de_artist_separator(deserializer);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    // Single strings and lists are accepted, and empty entries inside a list are dropped.
    #[rstest]
    #[case(vec![" & "], String::from(" & ").into())]
    #[case(" & ", String::from(" & ").into())]
    #[case(vec![" & ", "; "], vec![String::from(" & "), String::from("; ")].into())]
    #[case(vec!["", " & ", "", "; "], vec![String::from(" & "), String::from("; ")].into())]
    fn test_de_artist_separator<'de, D>(#[case] input: D, #[case] expected: OneOrMany<String>)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let deserializer = input.into_deserializer();
        let result: Result<OneOrMany<String>, _> = de_artist_separator(deserializer);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), expected);
    }

    // A fully specified config file round-trips into the expected `Settings`.
    // `protected_artist_names` and `projection_method` are omitted from the
    // file, so they must take their defaults.
    #[test]
    fn test_init_config() {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(
            &config_path,
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"

[tui]
radio_count = 21
        "#,
        )
        .unwrap();

        let expected = Settings {
            daemon: DaemonSettings {
                rpc_port: 6600,
                library_paths: ["/Music".into()].into(),
                artist_separator: vec!["; ".into()].into(),
                protected_artist_names: OneOrMany::None,
                genre_separator: Some(", ".into()),
                conflict_resolution: MetadataConflictResolution::Overwrite,
                log_level: log::LevelFilter::Debug,
            },
            reclustering: ReclusterSettings {
                gap_statistic_reference_datasets: 50,
                max_clusters: 24,
                algorithm: ClusterAlgorithm::GMM,
                projection_method: ProjectionMethod::None,
            },
            tui: TuiSettings { radio_count: 21 },
        };

        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings, expected);
    }

    // `protected_artist_names` given as a one-element list is expected to
    // compare equal to `OneOrMany::One`.
    #[test]
    fn test_artist_names_to_not_split() {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(
            &config_path,
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
protected_artist_names = ["Foo & Bar"]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"

[tui]
radio_count = 21
        "#,
        )
        .unwrap();

        let expected = Settings {
            daemon: DaemonSettings {
                rpc_port: 6600,
                library_paths: ["/Music".into()].into(),
                artist_separator: vec!["; ".into()].into(),
                protected_artist_names: OneOrMany::One("Foo & Bar".into()),
                genre_separator: Some(", ".into()),
                conflict_resolution: MetadataConflictResolution::Overwrite,
                log_level: log::LevelFilter::Debug,
            },
            reclustering: ReclusterSettings {
                gap_statistic_reference_datasets: 50,
                max_clusters: 24,
                algorithm: ClusterAlgorithm::GMM,
                projection_method: ProjectionMethod::None,
            },
            tui: TuiSettings { radio_count: 21 },
        };

        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings, expected);
    }

    // The embedded `DEFAULT_CONFIG` must always load without error.
    #[test]
    fn test_default_config_works() {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(&config_path, DEFAULT_CONFIG).unwrap();

        let settings = Settings::init(config_path, None, None);

        assert!(settings.is_ok(), "Error: {:?}", settings.err());
    }
}
477}