use crate::core::gpu_allocation::GpuAllocationStrategy;
use crate::paths::get_config_dir;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
/// Top-level gflow configuration, as assembled by [`load_config`].
///
/// Every section is optional on input and falls back to its `Default` when absent.
/// When serializing, `timezone` is omitted if unset, and `notifications` / `projects`
/// are omitted if they equal their defaults.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Config {
    /// Daemon settings (host, port, GPU selection and allocation strategy).
    #[serde(default)]
    pub daemon: DaemonConfig,
    /// Optional timezone name.
    // NOTE(review): the string format is not validated here — confirm what consumers expect.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timezone: Option<String>,
    /// Notification / webhook delivery settings.
    #[serde(default, skip_serializing_if = "NotificationsConfig::is_default")]
    pub notifications: NotificationsConfig,
    /// Project bookkeeping settings.
    #[serde(default, skip_serializing_if = "ProjectsConfig::is_default")]
    pub projects: ProjectsConfig,
}
/// Settings for the gflow daemon.
///
/// Missing fields fall back to host `"localhost"`, port `59000`, no explicit GPU list,
/// and `GpuAllocationStrategy::default()`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DaemonConfig {
    /// Daemon host; defaults to `"localhost"`.
    #[serde(default = "default_host")]
    pub host: String,
    /// Daemon port; defaults to `59000`.
    #[serde(default = "default_port")]
    pub port: u16,
    /// Optional list of GPU ids; omitted from serialized output when `None`.
    // NOTE(review): presumably restricts which GPUs the daemon may hand out — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gpus: Option<Vec<u32>>,
    /// Strategy used for GPU allocation; defaults to `GpuAllocationStrategy::default()`.
    #[serde(default)]
    pub gpu_allocation_strategy: GpuAllocationStrategy,
}
/// Notification settings, including the list of webhook targets.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct NotificationsConfig {
    /// Whether notifications are enabled; defaults to `false`.
    #[serde(default)]
    pub enabled: bool,
    /// Webhook targets; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub webhooks: Vec<WebhookConfig>,
    /// Upper bound on concurrent deliveries; defaults to 16 and is omitted from
    /// serialized output while it still has that default value.
    #[serde(
        default = "default_max_concurrent_deliveries",
        skip_serializing_if = "is_default_max_concurrent_deliveries"
    )]
    pub max_concurrent_deliveries: usize,
}
impl Default for NotificationsConfig {
    /// Notifications disabled, no webhooks, and the default concurrency limit —
    /// the same values serde uses for missing fields.
    fn default() -> Self {
        let max_concurrent_deliveries = default_max_concurrent_deliveries();
        Self {
            enabled: false,
            webhooks: Vec::new(),
            max_concurrent_deliveries,
        }
    }
}
impl NotificationsConfig {
    /// `skip_serializing_if` predicate for [`Config::notifications`].
    ///
    /// Returns `true` when notifications are disabled, no webhooks are configured,
    /// and `max_concurrent_deliveries` equals its default.
    fn is_default(value: &Self) -> bool {
        // Exhaustive destructuring: adding a field makes this fail to compile until the
        // predicate is updated to account for it.
        let Self {
            enabled,
            webhooks,
            max_concurrent_deliveries,
        } = value;
        !*enabled
            && webhooks.is_empty()
            && *max_concurrent_deliveries == default_max_concurrent_deliveries()
    }
}
/// Default value for `NotificationsConfig::max_concurrent_deliveries`.
fn default_max_concurrent_deliveries() -> usize {
    const DEFAULT_LIMIT: usize = 16;
    DEFAULT_LIMIT
}
/// `skip_serializing_if` predicate: `true` when `v` equals the default concurrency limit.
///
/// Takes `&usize` because serde passes the field by reference.
fn is_default_max_concurrent_deliveries(v: &usize) -> bool {
    let default_limit = default_max_concurrent_deliveries();
    default_limit == *v
}
/// Project-related settings.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct ProjectsConfig {
    /// Known project names; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub known_projects: Vec<String>,
    /// Defaults to `false`.
    // NOTE(review): presumably, when true, a project must be specified — confirm with callers.
    #[serde(default)]
    pub require_project: bool,
}
impl ProjectsConfig {
    /// `skip_serializing_if` predicate for [`Config::projects`].
    ///
    /// Returns `true` when no projects are listed and `require_project` is off.
    fn is_default(value: &Self) -> bool {
        // Exhaustive destructuring so new fields cannot be silently ignored here.
        let Self {
            known_projects,
            require_project,
        } = value;
        !(*require_project || !known_projects.is_empty())
    }
}
/// A single webhook notification target.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct WebhookConfig {
    /// Target URL. Required: there is no serde default for this field.
    pub url: String,
    /// Event names to deliver; defaults to `["*"]` (presumably "all events" — confirm).
    #[serde(default = "default_webhook_events")]
    pub events: Vec<String>,
    /// Optional user filter; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub filter_users: Option<Vec<String>>,
    /// Extra header name/value pairs (presumably added to each request — confirm);
    /// omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub headers: HashMap<String, String>,
    /// Timeout in seconds; defaults to 10.
    #[serde(default = "default_webhook_timeout_secs")]
    pub timeout_secs: u64,
    /// Maximum retry count; defaults to 3.
    #[serde(default = "default_webhook_max_retries")]
    pub max_retries: u32,
}
/// Default `WebhookConfig::events`: a single `"*"` entry.
fn default_webhook_events() -> Vec<String> {
    let wildcard = String::from("*");
    vec![wildcard]
}
/// Default `WebhookConfig::timeout_secs`, in seconds.
fn default_webhook_timeout_secs() -> u64 {
    const TIMEOUT_SECS: u64 = 10;
    TIMEOUT_SECS
}
/// Default `WebhookConfig::max_retries`.
fn default_webhook_max_retries() -> u32 {
    const MAX_RETRIES: u32 = 3;
    MAX_RETRIES
}
/// Default `DaemonConfig::host`.
fn default_host() -> String {
    String::from("localhost")
}
/// Default `DaemonConfig::port`.
fn default_port() -> u16 {
    const PORT: u16 = 59_000;
    PORT
}
impl Default for DaemonConfig {
    /// Builds the daemon defaults from the same helpers the serde field defaults use.
    /// A missing `daemon` section (via `Config`'s `#[serde(default)]`) therefore agrees
    /// with a present-but-empty one.
    fn default() -> Self {
        let (host, port) = (default_host(), default_port());
        Self {
            host,
            port,
            gpus: None,
            gpu_allocation_strategy: Default::default(),
        }
    }
}
/// Loads the gflow configuration by layering sources, later ones overriding earlier ones.
///
/// 1. `gflow.toml` in the directory returned by `get_config_dir()`, if that lookup
///    succeeds and the file exists.
/// 2. The file at `config_path`, if given and it exists. A missing file only prints a
///    warning to stderr; loading continues without it.
/// 3. Environment variables with the `GFLOW` prefix, split on `_` into nested keys
///    (e.g. `GFLOW_DAEMON_PORT` -> `daemon.port`). Values are parsed into typed values
///    where possible, and `daemon.gpus` is parsed as a comma-separated list.
///
/// NOTE(review): because `_` is also the nesting separator, fields whose names contain an
/// underscore (e.g. `daemon.gpu_allocation_strategy`) presumably cannot be set from the
/// environment — confirm against the `config` crate's `Environment` behavior.
///
/// # Errors
///
/// Returns a `config::ConfigError` if any source fails to load or parse, or if the merged
/// settings cannot be deserialized into [`Config`].
pub fn load_config(config_path: Option<&PathBuf>) -> Result<Config, config::ConfigError> {
    let mut sources: Vec<PathBuf> = Vec::new();

    // Lowest priority: the per-user default file, silently skipped if unavailable.
    let default_file = get_config_dir()
        .ok()
        .map(|dir| dir.join("gflow.toml"))
        .filter(|path| path.exists());
    sources.extend(default_file);

    // Explicit file overrides the default one; a bad path is reported but not fatal.
    match config_path {
        Some(config_path) if config_path.exists() => sources.push(config_path.clone()),
        Some(config_path) => eprintln!("Warning: Config file {config_path:?} not found."),
        None => {}
    }

    let mut builder = config::Config::builder();
    for path in &sources {
        builder = builder.add_source(config::File::from(path.as_path()));
    }

    // Highest priority: environment overrides.
    let environment = config::Environment::with_prefix("GFLOW")
        .separator("_")
        .try_parsing(true)
        .list_separator(",")
        .with_list_parse_key("daemon.gpus");

    builder.add_source(environment).build()?.try_deserialize()
}