gflow 0.4.15

A lightweight, single-node job scheduler written in Rust.
Documentation
use crate::core::gpu_allocation::GpuAllocationStrategy;
use crate::paths::get_config_dir;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;

/// Top-level gflow configuration.
///
/// Every section is optional on input (`#[serde(default)]`); sections that
/// carry an `is_default` check are left out of serialized output when they
/// still hold their default values.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Config {
    /// Daemon settings (host, port, GPU selection and polling).
    #[serde(default)]
    pub daemon: DaemonConfig,
    /// Timezone for displaying and parsing times
    /// (e.g., "Asia/Shanghai", "America/Los_Angeles", "UTC").
    /// When unset, the local timezone is used.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timezone: Option<String>,
    /// Webhook/notification settings for gflowd.
    #[serde(default, skip_serializing_if = "NotificationsConfig::is_default")]
    pub notifications: NotificationsConfig,
    /// Project tracking settings.
    #[serde(default, skip_serializing_if = "ProjectsConfig::is_default")]
    pub projects: ProjectsConfig,
}

/// Daemon settings: address plus GPU selection and polling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaemonConfig {
    /// Daemon host; filled from `default_host()` (`"localhost"`) when absent.
    #[serde(default = "default_host")]
    pub host: String,
    /// Daemon port; filled from `default_port()` (`59000`) when absent.
    #[serde(default = "default_port")]
    pub port: u16,
    /// Limits which GPUs the scheduler can use (`None` = all GPUs).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gpus: Option<Vec<u32>>,
    /// GPU assignment strategy used when selecting from the available GPUs.
    #[serde(default)]
    pub gpu_allocation_strategy: GpuAllocationStrategy,
    /// How often, in seconds, to poll NVML for GPU occupancy updates.
    /// Omitted from serialized output while it equals the default.
    #[serde(
        default = "default_gpu_poll_interval_secs",
        skip_serializing_if = "is_default_gpu_poll_interval_secs"
    )]
    pub gpu_poll_interval_secs: u64,
}

/// Notification settings: a master switch, the configured endpoints, and a
/// cap on concurrent deliveries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NotificationsConfig {
    /// Enables the notification system (default: false).
    #[serde(default)]
    pub enabled: bool,
    /// Webhook endpoints; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub webhooks: Vec<WebhookConfig>,
    /// Email endpoints; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub emails: Vec<EmailConfig>,
    /// Limit on concurrent notification deliveries across all endpoints.
    /// Omitted from serialized output while it equals the default.
    #[serde(
        default = "default_max_concurrent_deliveries",
        skip_serializing_if = "is_default_max_concurrent_deliveries"
    )]
    pub max_concurrent_deliveries: usize,
}

impl Default for NotificationsConfig {
    /// Notifications disabled, no endpoints, and the default delivery
    /// concurrency limit.
    fn default() -> Self {
        let max_concurrent_deliveries = default_max_concurrent_deliveries();
        Self {
            enabled: false,
            webhooks: Vec::new(),
            emails: Vec::new(),
            max_concurrent_deliveries,
        }
    }
}

impl NotificationsConfig {
    /// Returns `true` when every field still holds its default value.
    ///
    /// Referenced by `skip_serializing_if` on `Config::notifications` so an
    /// untouched section is not written out.
    fn is_default(value: &Self) -> bool {
        // Exhaustive destructuring: adding a field forces this check to be revisited.
        let Self {
            enabled,
            webhooks,
            emails,
            max_concurrent_deliveries,
        } = value;

        !*enabled
            && webhooks.is_empty()
            && emails.is_empty()
            && *max_concurrent_deliveries == default_max_concurrent_deliveries()
    }
}

/// Default limit on concurrent notification deliveries.
fn default_max_concurrent_deliveries() -> usize {
    const LIMIT: usize = 16;
    LIMIT
}

/// serde `skip_serializing_if` predicate: `true` when `v` equals the default
/// concurrent-delivery limit.
fn is_default_max_concurrent_deliveries(v: &usize) -> bool {
    default_max_concurrent_deliveries() == *v
}

/// Project tracking settings.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ProjectsConfig {
    /// Known/allowed project codes; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub known_projects: Vec<String>,
    /// Requires a project to be specified for all jobs.
    #[serde(default)]
    pub require_project: bool,
}

impl ProjectsConfig {
    /// Returns `true` when no projects are listed and a project is not required.
    ///
    /// Referenced by `skip_serializing_if` on `Config::projects`.
    fn is_default(value: &Self) -> bool {
        let Self {
            known_projects,
            require_project,
        } = value;

        known_projects.is_empty() && !*require_project
    }
}

/// A single webhook notification endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebhookConfig {
    /// Endpoint URL (required; there is no serde default).
    pub url: String,
    /// Events to subscribe to. Supports `"*"` (all), which is also the default.
    ///
    /// Examples: `["job_completed", "job_failed"]`, `["*"]`
    #[serde(default = "default_webhook_events")]
    pub events: Vec<String>,
    /// Optional: only notify for specific users (job submitter / reservation owner).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub filter_users: Option<Vec<String>>,
    /// Optional: custom HTTP headers (e.g., Authorization).
    /// Omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub headers: HashMap<String, String>,
    /// Optional: per-delivery timeout in seconds (default: 10).
    #[serde(default = "default_webhook_timeout_secs")]
    pub timeout_secs: u64,
    /// Optional: number of retries after the initial attempt (default: 3).
    #[serde(default = "default_webhook_max_retries")]
    pub max_retries: u32,
}

/// A single email notification endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmailConfig {
    /// SMTP connection URL (e.g. "smtps://user:pass@smtp.example.com:465").
    pub smtp_url: String,
    /// From mailbox (supports display name syntax like "gflow <noreply@example.com>").
    pub from: String,
    /// Recipient mailboxes; an absent key deserializes to an empty list.
    #[serde(default)]
    pub to: Vec<String>,
    /// Events to subscribe to. Supports `"*"` (all), which is also the default.
    #[serde(default = "default_email_events")]
    pub events: Vec<String>,
    /// Optional: only notify for specific users (job submitter / reservation owner).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub filter_users: Option<Vec<String>>,
    /// Optional subject prefix, e.g. "[gflow-prod]".
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub subject_prefix: Option<String>,
    /// Optional: per-delivery timeout in seconds (default: 10).
    #[serde(default = "default_email_timeout_secs")]
    pub timeout_secs: u64,
    /// Optional: number of retries after the initial attempt (default: 3).
    #[serde(default = "default_email_max_retries")]
    pub max_retries: u32,
}

/// Default webhook event subscription: the `"*"` wildcard (all events).
fn default_webhook_events() -> Vec<String> {
    vec![String::from("*")]
}

/// Default per-delivery webhook timeout, in seconds.
fn default_webhook_timeout_secs() -> u64 {
    const TIMEOUT_SECS: u64 = 10;
    TIMEOUT_SECS
}

/// Default number of webhook retries after the initial attempt.
fn default_webhook_max_retries() -> u32 {
    const RETRIES: u32 = 3;
    RETRIES
}

/// Default email event subscription: the `"*"` wildcard (all events).
fn default_email_events() -> Vec<String> {
    vec![String::from("*")]
}

/// Default per-delivery email timeout, in seconds.
fn default_email_timeout_secs() -> u64 {
    const TIMEOUT_SECS: u64 = 10;
    TIMEOUT_SECS
}

/// Default number of email retries after the initial attempt.
fn default_email_max_retries() -> u32 {
    const RETRIES: u32 = 3;
    RETRIES
}

/// Default value for `DaemonConfig::host`.
fn default_host() -> String {
    String::from("localhost")
}

/// Default value for `DaemonConfig::port`.
fn default_port() -> u16 {
    59_000
}

/// Default value, in seconds, for `DaemonConfig::gpu_poll_interval_secs`.
fn default_gpu_poll_interval_secs() -> u64 {
    const INTERVAL_SECS: u64 = 10;
    INTERVAL_SECS
}

/// serde `skip_serializing_if` predicate: `true` when `v` equals the default
/// GPU poll interval.
fn is_default_gpu_poll_interval_secs(v: &u64) -> bool {
    default_gpu_poll_interval_secs() == *v
}

impl Default for DaemonConfig {
    /// Builds the daemon settings from the same `default_*` helpers that the
    /// serde field defaults use, with no GPU restriction.
    fn default() -> Self {
        let host = default_host();
        let port = default_port();
        let gpu_poll_interval_secs = default_gpu_poll_interval_secs();

        Self {
            host,
            port,
            gpus: None,
            gpu_allocation_strategy: GpuAllocationStrategy::default(),
            gpu_poll_interval_secs,
        }
    }
}

/// Loads the gflow configuration from layered sources.
///
/// Sources are added in this order:
/// 1. `gflow.toml` inside the directory returned by `get_config_dir()`, if
///    that lookup succeeds and the file exists;
/// 2. the file at `config_path`, if given and it exists (a warning is printed
///    to stderr when it does not);
/// 3. the `GFLOW`-prefixed environment source from `environment_source`.
///
/// # Errors
///
/// Returns a `config::ConfigError` when building the layered configuration or
/// deserializing it into [`Config`] fails.
pub fn load_config(config_path: Option<&PathBuf>) -> Result<Config, config::ConfigError> {
    let mut files: Vec<PathBuf> = Vec::new();

    // Default config file
    if let Ok(config_dir) = get_config_dir() {
        let candidate = config_dir.join("gflow.toml");
        if candidate.exists() {
            files.push(candidate);
        }
    }

    // User-provided config file (should override defaults)
    match config_path {
        Some(config_path) if config_path.exists() => files.push(config_path.clone()),
        Some(config_path) => eprintln!("Warning: Config file {config_path:?} not found."),
        None => {}
    }

    let mut builder = config::Config::builder();
    for file in &files {
        builder = builder.add_source(config::File::from(file.as_path()));
    }

    // The environment source is added last, after every file source.
    let layered = builder.add_source(environment_source(None)).build()?;
    layered.try_deserialize()
}

/// Builds the environment-variable configuration source.
///
/// Variables use the `GFLOW` prefix followed by `_`, and `__` separates
/// nested keys (e.g. `GFLOW_DAEMON__GPU_POLL_INTERVAL_SECS`). Value parsing is
/// enabled, and `daemon.gpus` is registered as a list key split on `,`.
///
/// `source` is forwarded to `config::Environment::source`; `load_config`
/// passes `None`, while the unit tests pass an in-memory map.
fn environment_source(source: Option<config::Map<String, String>>) -> config::Environment {
    // Key naming: prefix and separators.
    let env = config::Environment::with_prefix("GFLOW")
        .prefix_separator("_")
        .separator("__");

    // Variable source and value parsing.
    let env = env.source(source).try_parsing(true);

    // List handling for the GPU list key.
    env.list_separator(",").with_list_parse_key("daemon.gpus")
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a `Config` from an in-memory environment holding exactly
    /// one variable.
    ///
    /// Panics if the source itself fails to build; deserialization errors are
    /// returned so individual tests can inspect them.
    fn config_from_env_var(name: &str, value: &str) -> Result<Config, config::ConfigError> {
        let mut vars = config::Map::new();
        vars.insert(name.to_string(), value.to_string());

        config::Config::builder()
            .add_source(environment_source(Some(vars)))
            .build()
            .unwrap()
            .try_deserialize::<Config>()
    }

    #[test]
    fn environment_source_applies_gpu_allocation_strategy() {
        let parsed =
            config_from_env_var("GFLOW_DAEMON__GPU_ALLOCATION_STRATEGY", "random").unwrap();

        assert_eq!(
            parsed.daemon.gpu_allocation_strategy,
            GpuAllocationStrategy::Random
        );
    }

    #[test]
    fn environment_source_applies_gpu_poll_interval() {
        let parsed = config_from_env_var("GFLOW_DAEMON__GPU_POLL_INTERVAL_SECS", "3").unwrap();

        assert_eq!(parsed.daemon.gpu_poll_interval_secs, 3);
    }

    #[test]
    fn environment_source_rejects_invalid_gpu_poll_interval() {
        let failure =
            config_from_env_var("GFLOW_DAEMON__GPU_POLL_INTERVAL_SECS", "abc").unwrap_err();

        assert!(failure.to_string().contains("invalid type"));
    }

    #[test]
    fn environment_source_does_not_treat_single_underscore_as_nested_separator() {
        // A single underscore after `DAEMON` must not address `daemon.*`,
        // so the poll interval keeps its default of 10.
        let parsed = config_from_env_var("GFLOW_DAEMON_GPU_POLL_INTERVAL_SECS", "3").unwrap();

        assert_eq!(parsed.daemon.gpu_poll_interval_secs, 10);
    }
}