use config::{Config, ConfigError, Environment, File};
use one_or_many::OneOrMany;
use serde::Deserialize;
use std::{num::NonZeroUsize, path::PathBuf, str::FromStr};
use mecomp_storage::util::MetadataConflictResolution;
/// Default contents of `Mecomp.toml`, embedded into the binary at compile time by
/// `include_str!` (the path is resolved relative to this source file).
///
/// [`Settings::get_config_path`] writes this text to disk when no config file exists yet.
pub static DEFAULT_CONFIG: &str = include_str!("../Mecomp.toml");
/// Top-level mecomp configuration, as loaded from `Mecomp.toml`.
///
/// Every section is optional: a section that is missing from the config takes the
/// corresponding field of [`Settings::default()`].
#[derive(Clone, Debug, Deserialize, Default, PartialEq, Eq)]
#[serde(default)]
pub struct Settings {
    /// The `[daemon]` section.
    pub daemon: DaemonSettings,
    /// The `[analysis]` section.
    pub analysis: AnalysisSettings,
    /// The `[reclustering]` section.
    pub reclustering: ReclusterSettings,
    /// The `[tui]` section.
    pub tui: TuiSettings,
}
impl Settings {
    /// Load settings from the config file at `config`, layered with environment variables
    /// prefixed with `MECOMP`, then apply the explicit overrides.
    ///
    /// After deserialization, a leading `~` in every `daemon.library_paths` entry is expanded
    /// with `shellexpand::tilde`. When `port` or `log_level` is `Some`, it replaces the value
    /// read from the config sources.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigError`] if the config sources cannot be built, or if the merged
    /// configuration cannot be deserialized into [`Settings`].
    #[inline]
    pub fn init(
        config: PathBuf,
        port: Option<u16>,
        log_level: Option<log::LevelFilter>,
    ) -> Result<Self, ConfigError> {
        let s = Config::builder()
            .add_source(File::from(config))
            .add_source(Environment::with_prefix("MECOMP"))
            .build()?;
        let mut settings: Self = s.try_deserialize()?;

        for path in &mut settings.daemon.library_paths {
            *path = shellexpand::tilde(&path.to_string_lossy())
                .into_owned()
                .into();
        }

        if let Some(port) = port {
            settings.daemon.rpc_port = port;
        }
        if let Some(log_level) = log_level {
            settings.daemon.log_level = log_level;
        }

        Ok(settings)
    }

    /// Return the path of `Mecomp.toml` inside the mecomp config directory.
    ///
    /// The config directory is created if it is missing. If the file does not exist yet,
    /// it is created and filled with [`DEFAULT_CONFIG`]. An existing file is never
    /// overwritten.
    ///
    /// # Errors
    ///
    /// - Returns an error of kind [`std::io::ErrorKind::NotFound`] if the config directory
    ///   cannot be determined. The underlying error is printed to stderr.
    /// - Returns any I/O error raised while creating the directory or writing the default file.
    #[inline]
    pub fn get_config_path() -> Result<PathBuf, std::io::Error> {
        use std::io::Write as _;

        let config_dir = match crate::get_config_dir() {
            Ok(dir) => dir,
            Err(e) => {
                eprintln!("Error: {e}");
                return Err(std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    "Unable to find the config directory for mecomp.",
                ));
            }
        };

        // `create_dir_all` already succeeds when the directory exists, so no `exists()`
        // pre-check is needed.
        std::fs::create_dir_all(&config_dir)?;

        let config_file = config_dir.join("Mecomp.toml");
        // `create_new` makes "create only if missing" one atomic step. With a separate
        // `exists()` check followed by a write, another process that created the file in
        // between would have it truncated and rewritten, possibly while that process is
        // reading it.
        match std::fs::OpenOptions::new()
            .write(true)
            .create_new(true)
            .open(&config_file)
        {
            Ok(mut file) => file.write_all(DEFAULT_CONFIG.as_bytes())?,
            Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {}
            Err(e) => return Err(e),
        }

        Ok(config_file)
    }
}
/// Settings read from the `[daemon]` section of the config.
#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
pub struct DaemonSettings {
    /// RPC port; defaults to 6600.
    #[serde(default = "default_port")]
    pub rpc_port: u16,
    /// Library directories. [`Settings::init`] expands a leading `~` in each entry.
    /// Defaults to a single `~/Music/` entry.
    #[serde(default = "default_library_paths")]
    pub library_paths: Box<[PathBuf]>,
    /// Artist separator(s). Empty strings are dropped while deserializing. A value that is
    /// empty after that becomes `OneOrMany::None`.
    #[serde(default)]
    #[serde(deserialize_with = "de_artist_separator")]
    pub artist_separator: OneOrMany<String>,
    /// Artist names that, presumably, must not be split on `artist_separator`
    /// (confirm with the consumer of this setting).
    #[serde(default)]
    pub protected_artist_names: OneOrMany<String>,
    /// Optional genre separator; `None` when unset.
    #[serde(default)]
    pub genre_separator: Option<String>,
    /// Metadata conflict resolution strategy. When the field is missing, it falls back to
    /// `MetadataConflictResolution::default()`.
    #[serde(default)]
    pub conflict_resolution: MetadataConflictResolution,
    /// Log level filter. Unrecognized strings fall back to `Info`.
    #[serde(default = "default_log_level", deserialize_with = "de_log_level")]
    pub log_level: log::LevelFilter,
}
/// Deserialize one or many artist separators, discarding empty strings.
///
/// If nothing is left after filtering (e.g. the input was `""` or `[]`), the result is
/// `OneOrMany::None`.
fn de_artist_separator<'de, D>(deserializer: D) -> Result<OneOrMany<String>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let raw = OneOrMany::<String>::deserialize(deserializer)?;
    let separators: OneOrMany<String> = raw
        .into_iter()
        .filter(|separator| !separator.is_empty())
        .collect();

    Ok(if separators.is_empty() {
        OneOrMany::None
    } else {
        separators
    })
}
/// Deserialize a log level from its string name.
///
/// Strings that `LevelFilter::from_str` rejects fall back to [`default_log_level`] instead
/// of producing an error.
fn de_log_level<'de, D>(deserializer: D) -> Result<log::LevelFilter, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let level_name = String::deserialize(deserializer)?;
    let level = match log::LevelFilter::from_str(&level_name) {
        Ok(parsed) => parsed,
        Err(_) => default_log_level(),
    };
    Ok(level)
}
/// Default value of `daemon.rpc_port`.
const fn default_port() -> u16 {
    6_600
}
/// Default value of `daemon.library_paths`: a single `~/Music/` entry, with the leading `~`
/// expanded by `shellexpand::tilde`.
fn default_library_paths() -> Box<[PathBuf]> {
    let music_dir = PathBuf::from(shellexpand::tilde("~/Music/").into_owned());
    Box::new([music_dir])
}
const fn default_log_level() -> log::LevelFilter {
log::LevelFilter::Info
}
impl Default for DaemonSettings {
    /// Daemon settings used when the `[daemon]` section is absent.
    ///
    /// NOTE(review): a missing `conflict_resolution` *field* deserializes via
    /// `MetadataConflictResolution::default()`, while this impl hard-codes `Overwrite`.
    /// Confirm the two agree.
    #[inline]
    fn default() -> Self {
        let rpc_port = default_port();
        let library_paths = default_library_paths();
        let log_level = default_log_level();

        Self {
            rpc_port,
            library_paths,
            artist_separator: OneOrMany::None,
            protected_artist_names: OneOrMany::None,
            genre_separator: Option::None,
            conflict_resolution: MetadataConflictResolution::Overwrite,
            log_level,
        }
    }
}
/// Value of `analysis.kind`.
#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Eq)]
pub enum AnalysisKind {
    /// Written as `"features"` in the config; the default.
    #[default]
    #[serde(rename = "features")]
    Features,
    /// Written as `"embedding"` in the config.
    #[serde(rename = "embedding")]
    Embedding,
}
#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
pub struct AnalysisSettings {
#[serde(default)]
pub kind: AnalysisKind,
#[serde(default)]
pub num_threads: Option<NonZeroUsize>,
#[serde(default)]
pub model_path: Option<PathBuf>,
}
impl Default for AnalysisSettings {
#[inline]
fn default() -> Self {
Self {
kind: AnalysisKind::default(),
num_threads: None,
model_path: None,
}
}
}
/// Value of `reclustering.algorithm`.
#[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)]
pub enum ClusterAlgorithm {
    /// Written as `"kmeans"` in the config.
    #[serde(rename = "kmeans")]
    KMeans,
    /// Gaussian mixture model, written as `"gmm"` in the config; the default.
    #[default]
    #[serde(rename = "gmm")]
    GMM,
}
#[cfg(feature = "analysis")]
impl From<ClusterAlgorithm> for mecomp_analysis::clustering::ClusteringMethod {
    /// Translate the config-level [`ClusterAlgorithm`] into the analysis crate's
    /// clustering method.
    #[inline]
    fn from(value: ClusterAlgorithm) -> Self {
        match value {
            ClusterAlgorithm::GMM => Self::GaussianMixtureModel,
            ClusterAlgorithm::KMeans => Self::KMeans,
        }
    }
}
/// Value of `reclustering.projection_method`.
///
/// With the `analysis` feature enabled, it converts into the analysis crate's projection
/// method.
#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Eq)]
pub enum ProjectionMethod {
    /// Written as `"none"` in the config.
    #[serde(rename = "none")]
    None,
    /// Written as `"tsne"` in the config; the default.
    #[default]
    #[serde(rename = "tsne")]
    TSne,
    /// Written as `"pca"` in the config.
    #[serde(rename = "pca")]
    Pca,
}
#[cfg(feature = "analysis")]
impl From<ProjectionMethod> for mecomp_analysis::clustering::ProjectionMethod {
    /// Translate the config-level [`ProjectionMethod`] into the analysis crate's
    /// projection method, variant for variant.
    #[inline]
    fn from(value: ProjectionMethod) -> Self {
        match value {
            ProjectionMethod::Pca => Self::Pca,
            ProjectionMethod::TSne => Self::TSne,
            ProjectionMethod::None => Self::None,
        }
    }
}
/// Settings read from the `[reclustering]` section of the config.
#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq)]
pub struct ReclusterSettings {
    /// Number of reference datasets for the gap statistic; defaults to 50.
    #[serde(default = "default_gap_statistic_reference_datasets")]
    pub gap_statistic_reference_datasets: u32,
    /// Upper bound on the number of clusters; defaults to 24.
    #[serde(default = "default_max_clusters")]
    pub max_clusters: usize,
    /// Clustering algorithm; defaults to [`ClusterAlgorithm::GMM`].
    #[serde(default)]
    pub algorithm: ClusterAlgorithm,
    /// Projection method; defaults to [`ProjectionMethod::TSne`].
    #[serde(default)]
    pub projection_method: ProjectionMethod,
}
/// Default value of `reclustering.gap_statistic_reference_datasets`.
const fn default_gap_statistic_reference_datasets() -> u32 {
    50_u32
}
/// Default value of `reclustering.max_clusters`.
const fn default_max_clusters() -> usize {
    24_usize
}
impl Default for ReclusterSettings {
    /// Build the settings serde would produce if every `[reclustering]` field were missing.
    #[inline]
    fn default() -> Self {
        let gap_statistic_reference_datasets = default_gap_statistic_reference_datasets();
        let max_clusters = default_max_clusters();
        Self {
            gap_statistic_reference_datasets,
            max_clusters,
            algorithm: Default::default(),
            projection_method: Default::default(),
        }
    }
}
/// Settings read from the `[tui]` section of the config.
#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
pub struct TuiSettings {
    /// Count used by the TUI radio feature (presumably how many songs to request;
    /// confirm with the consumer). Defaults to 20.
    #[serde(default = "default_radio_count")]
    pub radio_count: u32,
    /// Color overrides from `[tui.colors]`; every entry is optional.
    #[serde(default)]
    pub colors: TuiColorScheme,
}
/// Default value of `tui.radio_count`.
const fn default_radio_count() -> u32 {
    20_u32
}
impl Default for TuiSettings {
    /// TUI settings with the default radio count and no color overrides.
    #[inline]
    fn default() -> Self {
        let radio_count = default_radio_count();
        let colors = TuiColorScheme::default();
        Self {
            radio_count,
            colors,
        }
    }
}
#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Default)]
pub struct TuiColorScheme {
pub app_border: Option<String>,
pub app_border_text: Option<String>,
pub border_unfocused: Option<String>,
pub border_focused: Option<String>,
pub popup_border: Option<String>,
pub text_normal: Option<String>,
pub text_highlight: Option<String>,
pub text_highlight_alt: Option<String>,
pub gauge_filled: Option<String>,
pub gauge_unfilled: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    /// Write `contents` to a fresh temporary `config.toml` and return `(dir, path)`.
    ///
    /// Keep the returned `TempDir` alive while the path is in use: dropping it deletes the
    /// directory.
    fn write_config(contents: &str) -> (tempfile::TempDir, PathBuf) {
        let temp_dir = tempfile::tempdir().unwrap();
        let config_path = temp_dir.path().join("config.toml");
        std::fs::write(&config_path, contents).unwrap();
        (temp_dir, config_path)
    }

    // Never constructed. Deriving `Deserialize` with `deserialize_with` only compiles if
    // `de_artist_separator` has the signature serde expects, so this is a compile-time check.
    #[derive(Debug, PartialEq, Eq, Deserialize)]
    #[allow(dead_code)]
    #[serde(transparent)]
    struct ArtistSeparatorTest {
        #[serde(deserialize_with = "de_artist_separator")]
        artist_separator: OneOrMany<String>,
    }

    #[rstest]
    #[case(Vec::<String>::new())]
    #[case("")]
    fn test_de_artist_separator_empty<'de, D>(#[case] input: D)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let deserializer = input.into_deserializer();
        let result: Result<OneOrMany<String>, _> = de_artist_separator(deserializer);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[rstest]
    #[case(vec![" & "], String::from(" & ").into())]
    #[case(" & ", String::from(" & ").into())]
    #[case(vec![" & ", "; "], vec![String::from(" & "), String::from("; ")].into())]
    #[case(vec!["", " & ", "", "; "], vec![String::from(" & "), String::from("; ")].into())]
    fn test_de_artist_separator<'de, D>(#[case] input: D, #[case] expected: OneOrMany<String>)
    where
        D: serde::de::IntoDeserializer<'de>,
    {
        let deserializer = input.into_deserializer();
        let result: Result<OneOrMany<String>, _> = de_artist_separator(deserializer);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), expected);
    }

    #[test]
    fn test_init_config() {
        let (_temp_dir, config_path) = write_config(
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"
[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"
[tui]
radio_count = 21
[tui.colors]
app_border = "PINK_900"
app_border_text = "PINK_300"
border_unfocused = "RED_900"
border_focused = "RED_200"
popup_border = "LIGHT_BLUE_500"
text_normal = "WHITE"
text_highlight = "RED_600"
text_highlight_alt = "RED_200"
gauge_filled = "WHITE"
gauge_unfilled = "BLACK"
"#,
        );
        let expected = Settings {
            daemon: DaemonSettings {
                rpc_port: 6600,
                library_paths: ["/Music".into()].into(),
                artist_separator: vec!["; ".into()].into(),
                protected_artist_names: OneOrMany::None,
                genre_separator: Some(", ".into()),
                conflict_resolution: MetadataConflictResolution::Overwrite,
                log_level: log::LevelFilter::Debug,
            },
            analysis: AnalysisSettings::default(),
            reclustering: ReclusterSettings {
                gap_statistic_reference_datasets: 50,
                max_clusters: 24,
                algorithm: ClusterAlgorithm::GMM,
                projection_method: ProjectionMethod::TSne,
            },
            tui: TuiSettings {
                radio_count: 21,
                colors: TuiColorScheme {
                    app_border: Some("PINK_900".into()),
                    app_border_text: Some("PINK_300".into()),
                    border_unfocused: Some("RED_900".into()),
                    border_focused: Some("RED_200".into()),
                    popup_border: Some("LIGHT_BLUE_500".into()),
                    text_normal: Some("WHITE".into()),
                    text_highlight: Some("RED_600".into()),
                    text_highlight_alt: Some("RED_200".into()),
                    gauge_filled: Some("WHITE".into()),
                    gauge_unfilled: Some("BLACK".into()),
                },
            },
        };
        let settings = Settings::init(config_path, None, None).unwrap();
        assert_eq!(settings, expected);
    }

    // Previously this test was an exact copy of `test_artist_names_to_not_split`. It now
    // leaves `protected_artist_names` unset, so it only checks the defaults for missing
    // `[tui.colors]` (and missing protected names).
    #[test]
    fn test_tui_colors_unset() {
        let (_temp_dir, config_path) = write_config(
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"
[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"
[tui]
radio_count = 21
"#,
        );
        let expected = Settings {
            daemon: DaemonSettings {
                rpc_port: 6600,
                library_paths: ["/Music".into()].into(),
                artist_separator: vec!["; ".into()].into(),
                protected_artist_names: OneOrMany::None,
                genre_separator: Some(", ".into()),
                conflict_resolution: MetadataConflictResolution::Overwrite,
                log_level: log::LevelFilter::Debug,
            },
            analysis: AnalysisSettings::default(),
            reclustering: ReclusterSettings {
                gap_statistic_reference_datasets: 50,
                max_clusters: 24,
                algorithm: ClusterAlgorithm::GMM,
                projection_method: ProjectionMethod::TSne,
            },
            tui: TuiSettings {
                radio_count: 21,
                colors: TuiColorScheme::default(),
            },
        };
        let settings = Settings::init(config_path, None, None).unwrap();
        assert_eq!(settings, expected);
    }

    #[test]
    fn test_artist_names_to_not_split() {
        let (_temp_dir, config_path) = write_config(
            r#"
[daemon]
rpc_port = 6600
library_paths = ["/Music"]
artist_separator = ["; "]
protected_artist_names = ["Foo & Bar"]
genre_separator = ", "
conflict_resolution = "overwrite"
log_level = "debug"
[reclustering]
gap_statistic_reference_datasets = 50
max_clusters = 24
algorithm = "gmm"
[tui]
radio_count = 21
"#,
        );
        let expected = Settings {
            daemon: DaemonSettings {
                rpc_port: 6600,
                library_paths: ["/Music".into()].into(),
                artist_separator: vec!["; ".into()].into(),
                protected_artist_names: "Foo & Bar".to_string().into(),
                genre_separator: Some(", ".into()),
                conflict_resolution: MetadataConflictResolution::Overwrite,
                log_level: log::LevelFilter::Debug,
            },
            analysis: AnalysisSettings::default(),
            reclustering: ReclusterSettings {
                gap_statistic_reference_datasets: 50,
                max_clusters: 24,
                algorithm: ClusterAlgorithm::GMM,
                projection_method: ProjectionMethod::TSne,
            },
            tui: TuiSettings {
                radio_count: 21,
                colors: TuiColorScheme::default(),
            },
        };
        let settings = Settings::init(config_path, None, None).unwrap();
        assert_eq!(settings, expected);
    }

    #[test]
    fn test_init_overrides_take_precedence() {
        let (_temp_dir, config_path) = write_config(
            r#"
[daemon]
rpc_port = 6600
log_level = "debug"
"#,
        );
        let settings =
            Settings::init(config_path, Some(1234), Some(log::LevelFilter::Warn)).unwrap();
        assert_eq!(settings.daemon.rpc_port, 1234);
        assert_eq!(settings.daemon.log_level, log::LevelFilter::Warn);
    }

    #[test]
    fn test_invalid_log_level_falls_back_to_default() {
        let (_temp_dir, config_path) = write_config(
            r#"
[daemon]
log_level = "not-a-level"
"#,
        );
        let settings = Settings::init(config_path, None, None).unwrap();
        assert_eq!(settings.daemon.log_level, log::LevelFilter::Info);
    }

    #[test]
    fn test_library_paths_are_tilde_expanded() {
        let (_temp_dir, config_path) = write_config(
            r#"
[daemon]
library_paths = ["~/Music"]
"#,
        );
        let settings = Settings::init(config_path, None, None).unwrap();
        let expected = PathBuf::from(shellexpand::tilde("~/Music").into_owned());
        assert_eq!(
            settings.daemon.library_paths,
            vec![expected].into_boxed_slice()
        );
    }

    #[test]
    fn test_default_config_works() {
        let (_temp_dir, config_path) = write_config(DEFAULT_CONFIG);
        let settings = Settings::init(config_path, None, None);
        assert!(settings.is_ok(), "Error: {:?}", settings.err());
    }
}