1use config::{Config, ConfigError, Environment, File};
7use one_or_many::OneOrMany;
8use serde::Deserialize;
9
10use std::{path::PathBuf, str::FromStr};
11
12use mecomp_storage::util::MetadataConflictResolution;
13
/// Contents of the default `Mecomp.toml`, embedded at compile time from `../Mecomp.toml`
/// (relative to this source file).
///
/// Written out by `Settings::get_config_path` when no config file exists yet.
pub static DEFAULT_CONFIG: &str = include_str!("../Mecomp.toml");
15
16#[derive(Clone, Debug, Deserialize, Default, PartialEq, Eq)]
17pub struct Settings {
18 #[serde(default)]
20 pub daemon: DaemonSettings,
21 #[serde(default)]
23 pub reclustering: ReclusterSettings,
24 #[serde(default)]
26 pub tui: TuiSettings,
27}
28
29impl Settings {
30 #[inline]
45 pub fn init(
46 config: PathBuf,
47 port: Option<u16>,
48 log_level: Option<log::LevelFilter>,
49 ) -> Result<Self, ConfigError> {
50 let s = Config::builder()
51 .add_source(File::from(config))
52 .add_source(Environment::with_prefix("MECOMP"))
53 .build()?;
54
55 let mut settings: Self = s.try_deserialize()?;
56
57 for path in &mut settings.daemon.library_paths {
58 *path = shellexpand::tilde(&path.to_string_lossy())
59 .into_owned()
60 .into();
61 }
62
63 if let Some(port) = port {
64 settings.daemon.rpc_port = port;
65 }
66
67 if let Some(log_level) = log_level {
68 settings.daemon.log_level = log_level;
69 }
70
71 Ok(settings)
72 }
73
74 #[inline]
83 pub fn get_config_path() -> Result<PathBuf, std::io::Error> {
84 match crate::get_config_dir() {
85 Ok(config_dir) => {
86 if !config_dir.exists() {
88 std::fs::create_dir_all(&config_dir)?;
89 }
90 let config_file = config_dir.join("Mecomp.toml");
91
92 if !config_file.exists() {
93 std::fs::write(&config_file, DEFAULT_CONFIG)?;
94 }
95
96 Ok(config_file)
97 }
98 Err(e) => {
99 eprintln!("Error: {e}");
100 Err(std::io::Error::new(
101 std::io::ErrorKind::NotFound,
102 "Unable to find the config directory for mecomp.",
103 ))
104 }
105 }
106 }
107}
108
109#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
110pub struct DaemonSettings {
111 #[serde(default = "default_port")]
114 pub rpc_port: u16,
115 #[serde(default = "default_library_paths")]
117 pub library_paths: Box<[PathBuf]>,
118 #[serde(default, deserialize_with = "de_artist_separator")]
132 pub artist_separator: OneOrMany<String>,
133 #[serde(default)]
134 pub genre_separator: Option<String>,
135 #[serde(default)]
139 pub conflict_resolution: MetadataConflictResolution,
140 #[serde(default = "default_log_level")]
143 #[serde(deserialize_with = "de_log_level")]
144 pub log_level: log::LevelFilter,
145}
146
147fn de_artist_separator<'de, D>(deserializer: D) -> Result<OneOrMany<String>, D::Error>
148where
149 D: serde::Deserializer<'de>,
150{
151 let v = OneOrMany::<String>::deserialize(deserializer)?
152 .into_iter()
153 .filter(|s| !s.is_empty())
154 .collect::<OneOrMany<String>>();
155 if v.is_empty() {
156 Ok(OneOrMany::None)
157 } else {
158 Ok(v)
159 }
160}
161
162fn de_log_level<'de, D>(deserializer: D) -> Result<log::LevelFilter, D::Error>
163where
164 D: serde::Deserializer<'de>,
165{
166 let s = String::deserialize(deserializer)?;
167 Ok(log::LevelFilter::from_str(&s).unwrap_or_else(|_| default_log_level()))
168}
169
/// Default port of the daemon's RPC server.
const fn default_port() -> u16 {
    6_600
}
173
174fn default_library_paths() -> Box<[PathBuf]> {
175 vec![shellexpand::tilde("~/Music/").into_owned().into()].into_boxed_slice()
176}
177
178const fn default_log_level() -> log::LevelFilter {
179 log::LevelFilter::Info
180}
181
182impl Default for DaemonSettings {
183 #[inline]
184 fn default() -> Self {
185 Self {
186 rpc_port: default_port(),
187 library_paths: default_library_paths(),
188 artist_separator: OneOrMany::None,
189 genre_separator: None,
190 conflict_resolution: MetadataConflictResolution::Overwrite,
191 log_level: default_log_level(),
192 }
193 }
194}
195
196#[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)]
197#[serde(rename_all = "lowercase")]
198pub enum ClusterAlgorithm {
199 KMeans,
200 #[default]
201 GMM,
202}
203
204impl From<ClusterAlgorithm> for mecomp_analysis::clustering::ClusteringMethod {
205 #[inline]
206 fn from(algo: ClusterAlgorithm) -> Self {
207 match algo {
208 ClusterAlgorithm::KMeans => Self::KMeans,
209 ClusterAlgorithm::GMM => Self::GaussianMixtureModel,
210 }
211 }
212}
213
214#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq)]
215pub struct ReclusterSettings {
216 #[serde(default = "default_gap_statistic_reference_datasets")]
222 pub gap_statistic_reference_datasets: usize,
223 #[serde(default = "default_max_clusters")]
228 pub max_clusters: usize,
229 #[serde(default)]
232 pub algorithm: ClusterAlgorithm,
233}
234
/// Default number of reference datasets used for the gap statistic.
const fn default_gap_statistic_reference_datasets() -> usize {
    50_usize
}
238
/// Default cluster upper bound: 16 when built with debug assertions, 24 otherwise.
const fn default_max_clusters() -> usize {
    if cfg!(debug_assertions) {
        16
    } else {
        24
    }
}
245
246impl Default for ReclusterSettings {
247 #[inline]
248 fn default() -> Self {
249 Self {
250 gap_statistic_reference_datasets: default_gap_statistic_reference_datasets(),
251 max_clusters: default_max_clusters(),
252 algorithm: ClusterAlgorithm::default(),
253 }
254 }
255}
256
257#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
258pub struct TuiSettings {
259 #[serde(default = "default_radio_count")]
262 pub radio_count: u32,
263}
264
/// Default value of `TuiSettings::radio_count`.
const fn default_radio_count() -> u32 {
    20_u32
}
268
269impl Default for TuiSettings {
270 #[inline]
271 fn default() -> Self {
272 Self {
273 radio_count: default_radio_count(),
274 }
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use rstest::rstest;

    /// Write `contents` to `config.toml` inside a fresh temporary directory.
    ///
    /// The returned `TempDir` must be kept alive while the file is in use; dropping it
    /// deletes the directory.
    fn write_temp_config(contents: &str) -> (tempfile::TempDir, PathBuf) {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(&config_path, contents).unwrap();
        (temp_dir, config_path)
    }

    #[rstest]
    #[case(Vec::<String>::new())]
    #[case("")]
    fn test_de_artist_separator_empty<'de, D>(#[case] input: D)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let deserializer = input.into_deserializer();
        let result: Result<OneOrMany<String>, _> = de_artist_separator(deserializer);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[rstest]
    #[case(vec![" & "], String::from(" & ").into())]
    #[case(" & ", String::from(" & ").into())]
    #[case(vec![" & ", "; "], vec![String::from(" & "), String::from("; ")].into())]
    #[case(vec!["", " & ", "", "; "], vec![String::from(" & "), String::from("; ")].into())]
    fn test_de_artist_separator<'de, D>(#[case] input: D, #[case] expected: OneOrMany<String>)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let deserializer = input.into_deserializer();
        let result: Result<OneOrMany<String>, _> = de_artist_separator(deserializer);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), expected);
    }

    #[test]
    fn test_init_config() {
        // Values differ from the defaults wherever possible. A key that is silently ignored,
        // and filled in from its default instead, therefore makes this test fail.
        let (_temp_dir, config_path) = write_temp_config(
            r#"
[daemon]
rpc_port = 6601
library_paths = ["/Music"]
artist_separator = ["; "]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"

[reclustering]
gap_statistic_reference_datasets = 40
max_clusters = 32
algorithm = "kmeans"

[tui]
radio_count = 21
"#,
        );

        let expected = Settings {
            daemon: DaemonSettings {
                rpc_port: 6601,
                library_paths: ["/Music".into()].into(),
                artist_separator: vec!["; ".into()].into(),
                genre_separator: Some(", ".into()),
                conflict_resolution: MetadataConflictResolution::Overwrite,
                log_level: log::LevelFilter::Debug,
            },
            reclustering: ReclusterSettings {
                gap_statistic_reference_datasets: 40,
                max_clusters: 32,
                algorithm: ClusterAlgorithm::KMeans,
            },
            tui: TuiSettings { radio_count: 21 },
        };

        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings, expected);
    }

    #[test]
    fn test_init_cli_overrides() {
        let (_temp_dir, config_path) =
            write_temp_config("[daemon]\nrpc_port = 6601\nlog_level = \"debug\"\n");

        let settings =
            Settings::init(config_path, Some(7000), Some(log::LevelFilter::Warn)).unwrap();

        assert_eq!(settings.daemon.rpc_port, 7000);
        assert_eq!(settings.daemon.log_level, log::LevelFilter::Warn);
    }

    #[test]
    fn test_invalid_log_level_falls_back_to_default() {
        let (_temp_dir, config_path) = write_temp_config("[daemon]\nlog_level = \"verbose\"\n");

        let settings = Settings::init(config_path, None, None).unwrap();

        assert_eq!(settings.daemon.log_level, default_log_level());
    }

    #[test]
    fn test_default_config_works() {
        let (_temp_dir, config_path) = write_temp_config(DEFAULT_CONFIG);

        let settings = Settings::init(config_path, None, None);

        assert!(settings.is_ok(), "Error: {:?}", settings.err());
    }
}