use crate::core::gpu_allocation::GpuAllocationStrategy;
use crate::paths::get_config_dir;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
/// Top-level gflow configuration, the type produced by [`load_config`].
///
/// Every section may be omitted from the input and falls back to its default.
/// When serializing, `timezone`, `notifications` and `projects` are skipped if
/// they are unset / still equal to their defaults; `daemon` is always written.
#[derive(Deserialize, Serialize, Debug, Default, Clone)]
pub struct Config {
    /// Daemon settings (host, port, GPU selection and polling).
    #[serde(default)]
    pub daemon: DaemonConfig,
    /// Optional timezone string; its format is not validated here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timezone: Option<String>,
    /// Notification delivery settings (webhooks / email).
    #[serde(default, skip_serializing_if = "NotificationsConfig::is_default")]
    pub notifications: NotificationsConfig,
    /// Project-related settings.
    #[serde(default, skip_serializing_if = "ProjectsConfig::is_default")]
    pub projects: ProjectsConfig,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct DaemonConfig {
#[serde(default = "default_host")]
pub host: String,
#[serde(default = "default_port")]
pub port: u16,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub gpus: Option<Vec<u32>>,
#[serde(default)]
pub gpu_allocation_strategy: GpuAllocationStrategy,
#[serde(default = "default_gpu_poll_interval_secs")]
#[serde(skip_serializing_if = "is_default_gpu_poll_interval_secs")]
pub gpu_poll_interval_secs: u64,
}
/// Settings for outbound notifications via webhooks and email.
///
/// Missing fields fall back to the values of `NotificationsConfig::default()`.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct NotificationsConfig {
    /// Whether notifications are enabled; `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
    /// Webhook targets; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub webhooks: Vec<WebhookConfig>,
    /// Email targets; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub emails: Vec<EmailConfig>,
    /// Maximum number of concurrent deliveries (enforced by consumers outside
    /// this file); defaults to 16 and is omitted from output at that value.
    #[serde(
        default = "default_max_concurrent_deliveries",
        skip_serializing_if = "is_default_max_concurrent_deliveries"
    )]
    pub max_concurrent_deliveries: usize,
}
impl Default for NotificationsConfig {
    /// Notifications disabled, with no targets and the default concurrency limit.
    fn default() -> Self {
        let max_concurrent_deliveries = default_max_concurrent_deliveries();
        Self {
            enabled: false,
            webhooks: Vec::new(),
            emails: Vec::new(),
            max_concurrent_deliveries,
        }
    }
}
impl NotificationsConfig {
    /// Returns `true` when `value` equals `NotificationsConfig::default()`
    /// field-for-field.
    ///
    /// Used as the `skip_serializing_if` predicate for `Config::notifications`.
    /// The exhaustive destructuring makes a newly added field a compile error
    /// here until it is accounted for.
    fn is_default(value: &Self) -> bool {
        let Self {
            enabled,
            webhooks,
            emails,
            max_concurrent_deliveries,
        } = value;
        !*enabled
            && webhooks.is_empty()
            && emails.is_empty()
            && *max_concurrent_deliveries == default_max_concurrent_deliveries()
    }
}
/// Serde default for `NotificationsConfig::max_concurrent_deliveries`.
fn default_max_concurrent_deliveries() -> usize {
    const DEFAULT: usize = 16;
    DEFAULT
}

/// `skip_serializing_if` predicate: true when `v` equals the default limit.
fn is_default_max_concurrent_deliveries(v: &usize) -> bool {
    default_max_concurrent_deliveries() == *v
}
/// Project-related settings; everything defaults to empty / `false`.
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
pub struct ProjectsConfig {
    /// Known project names (how they are used lives outside this file);
    /// omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub known_projects: Vec<String>,
    /// Whether a project is required (enforced outside this file); `false`
    /// when omitted.
    #[serde(default)]
    pub require_project: bool,
}
impl ProjectsConfig {
    /// Returns `true` when `value` equals `ProjectsConfig::default()`.
    ///
    /// Used as the `skip_serializing_if` predicate for `Config::projects`.
    fn is_default(value: &Self) -> bool {
        let Self {
            known_projects,
            require_project,
        } = value;
        known_projects.is_empty() && !*require_project
    }
}
/// A single webhook notification target.
///
/// Only `url` is required; all other fields have serde defaults.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct WebhookConfig {
    /// Target URL (required).
    pub url: String,
    /// Event names to send; defaults to `["*"]`.
    #[serde(default = "default_webhook_events")]
    pub events: Vec<String>,
    /// Optional list of users to filter on; omitted from output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub filter_users: Option<Vec<String>>,
    /// Extra headers (presumably sent with each request); omitted from output
    /// when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub headers: HashMap<String, String>,
    /// Timeout in seconds; defaults to 10.
    #[serde(default = "default_webhook_timeout_secs")]
    pub timeout_secs: u64,
    /// Maximum number of retries; defaults to 3.
    #[serde(default = "default_webhook_max_retries")]
    pub max_retries: u32,
}
/// A single email notification target.
///
/// `smtp_url` and `from` are required; all other fields have serde defaults.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct EmailConfig {
    /// SMTP connection URL (required).
    pub smtp_url: String,
    /// Sender address (required).
    pub from: String,
    /// Recipient addresses; empty when omitted.
    #[serde(default)]
    pub to: Vec<String>,
    /// Event names to send; defaults to `["*"]`.
    #[serde(default = "default_email_events")]
    pub events: Vec<String>,
    /// Optional list of users to filter on; omitted from output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub filter_users: Option<Vec<String>>,
    /// Optional subject prefix; omitted from output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub subject_prefix: Option<String>,
    /// Timeout in seconds; defaults to 10.
    #[serde(default = "default_email_timeout_secs")]
    pub timeout_secs: u64,
    /// Maximum number of retries; defaults to 3.
    #[serde(default = "default_email_max_retries")]
    pub max_retries: u32,
}
/// Serde default for `WebhookConfig::events`: a single `"*"` entry.
fn default_webhook_events() -> Vec<String> {
    vec![String::from("*")]
}

/// Serde default for `WebhookConfig::timeout_secs`.
fn default_webhook_timeout_secs() -> u64 {
    10
}

/// Serde default for `WebhookConfig::max_retries`.
fn default_webhook_max_retries() -> u32 {
    3
}

/// Serde default for `EmailConfig::events`: a single `"*"` entry.
fn default_email_events() -> Vec<String> {
    vec![String::from("*")]
}

/// Serde default for `EmailConfig::timeout_secs`.
fn default_email_timeout_secs() -> u64 {
    10
}

/// Serde default for `EmailConfig::max_retries`.
fn default_email_max_retries() -> u32 {
    3
}

/// Serde default for `DaemonConfig::host`.
fn default_host() -> String {
    String::from("localhost")
}

/// Serde default for `DaemonConfig::port`.
fn default_port() -> u16 {
    59_000
}

/// Serde default for `DaemonConfig::gpu_poll_interval_secs`.
fn default_gpu_poll_interval_secs() -> u64 {
    10
}

/// `skip_serializing_if` predicate: true when `v` equals the default interval.
fn is_default_gpu_poll_interval_secs(v: &u64) -> bool {
    default_gpu_poll_interval_secs() == *v
}
impl Default for DaemonConfig {
    /// Uses the same defaults that serde applies to missing fields:
    /// `localhost:59000`, no explicit GPU list, the default allocation
    /// strategy, and a 10-second poll interval.
    fn default() -> Self {
        let host = default_host();
        let port = default_port();
        Self {
            host,
            port,
            gpus: None,
            gpu_allocation_strategy: GpuAllocationStrategy::default(),
            gpu_poll_interval_secs: default_gpu_poll_interval_secs(),
        }
    }
}
/// Loads the gflow configuration by layering sources, later ones overriding
/// earlier ones:
///
/// 1. `gflow.toml` inside the directory returned by `get_config_dir`, if that
///    lookup succeeds and the file exists.
/// 2. The file at `config_path`, if given and it exists. A missing explicit
///    path only prints a warning to stderr and is otherwise ignored.
/// 3. `GFLOW_`-prefixed environment variables (see `environment_source`).
///
/// # Errors
///
/// Returns a [`config::ConfigError`] if a source cannot be read or parsed, or
/// if the merged values do not deserialize into [`Config`].
pub fn load_config(config_path: Option<&PathBuf>) -> Result<Config, config::ConfigError> {
    let default_file = get_config_dir()
        .ok()
        .map(|dir| dir.join("gflow.toml"))
        .filter(|path| path.exists());

    let explicit_file = match config_path {
        Some(path) if path.exists() => Some(path.clone()),
        Some(path) => {
            eprintln!("Warning: Config file {path:?} not found.");
            None
        }
        None => None,
    };

    let mut builder = config::Config::builder();
    for file in default_file.iter().chain(explicit_file.iter()) {
        builder = builder.add_source(config::File::from(file.as_path()));
    }

    builder
        .add_source(environment_source(None))
        .build()?
        .try_deserialize()
}
/// Builds the environment-variable source used by [`load_config`].
///
/// Variables use the `GFLOW_` prefix, and a double underscore separates
/// nesting levels. For example `GFLOW_DAEMON__GPU_POLL_INTERVAL_SECS` maps to
/// `daemon.gpu_poll_interval_secs`, while a single underscore stays part of
/// the key name. Values are type-parsed (`try_parsing`), and `daemon.gpus` is
/// split on commas into a list.
///
/// `source` replaces the variables that are read (the tests pass an in-memory
/// map). `None` keeps the `config` crate's default of reading the process
/// environment.
fn environment_source(source: Option<config::Map<String, String>>) -> config::Environment {
    const PREFIX: &str = "GFLOW";
    const NESTING_SEPARATOR: &str = "__";
    const LIST_SEPARATOR: &str = ",";

    let base = config::Environment::with_prefix(PREFIX)
        .prefix_separator("_")
        .separator(NESTING_SEPARATOR)
        .source(source);

    base.try_parsing(true)
        .list_separator(LIST_SEPARATOR)
        .with_list_parse_key("daemon.gpus")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a `Config` from `environment_source` fed with a single
    /// in-memory variable instead of the real process environment.
    ///
    /// Building the source must succeed. The deserialization result is
    /// returned so tests can check both success and failure.
    fn config_from_env(name: &str, value: &str) -> Result<Config, config::ConfigError> {
        let mut vars = config::Map::new();
        vars.insert(name.to_string(), value.to_string());
        config::Config::builder()
            .add_source(environment_source(Some(vars)))
            .build()
            .unwrap()
            .try_deserialize::<Config>()
    }

    #[test]
    fn environment_source_applies_gpu_allocation_strategy() {
        let config =
            config_from_env("GFLOW_DAEMON__GPU_ALLOCATION_STRATEGY", "random").unwrap();
        assert_eq!(
            config.daemon.gpu_allocation_strategy,
            GpuAllocationStrategy::Random
        );
    }

    #[test]
    fn environment_source_applies_gpu_poll_interval() {
        let config = config_from_env("GFLOW_DAEMON__GPU_POLL_INTERVAL_SECS", "3").unwrap();
        assert_eq!(config.daemon.gpu_poll_interval_secs, 3);
    }

    #[test]
    fn environment_source_rejects_invalid_gpu_poll_interval() {
        let error = config_from_env("GFLOW_DAEMON__GPU_POLL_INTERVAL_SECS", "abc").unwrap_err();
        assert!(error.to_string().contains("invalid type"));
    }

    #[test]
    fn environment_source_does_not_treat_single_underscore_as_nested_separator() {
        // A single underscore is not a nesting separator, so this variable
        // must not reach `daemon.gpu_poll_interval_secs`; the default stays.
        let config = config_from_env("GFLOW_DAEMON_GPU_POLL_INTERVAL_SECS", "3").unwrap();
        assert_eq!(config.daemon.gpu_poll_interval_secs, 10);
    }
}