// assay-lua 0.10.4
//
// General-purpose enhanced Lua runtime. Batteries-included scripting, automation, and web services.
// Documentation
use anyhow::{Context, Result, bail};
use serde::Deserialize;
use std::collections::HashMap;
use std::path::Path;
use std::time::Duration;

/// Fully converted configuration, as returned by [`parse`] and [`load`].
///
/// Duration fields have already been converted from their string form with
/// [`parse_duration`].
#[derive(Debug, Clone)]
pub struct Config {
    /// Parsed from the top-level `timeout` key; `"120s"` when the key is omitted.
    pub timeout: Duration,
    /// Top-level `retries` key; `3` when the key is omitted.
    pub retries: u32,
    /// Parsed from the top-level `backoff` key; `"5s"` when the key is omitted.
    pub backoff: Duration,
    /// Top-level `parallel` key; `false` when the key is omitted.
    pub parallel: bool,
    /// Entries of the `checks` list, in the same order as in the YAML document.
    pub checks: Vec<CheckConfig>,
}

/// One converted entry of the config's `checks` list.
///
/// Which optional fields matter depends on [`CheckType`]. [`parse`] copies them
/// through without cross-field validation (e.g. an `http` check without a `url`
/// is accepted here).
#[derive(Debug, Clone)]
pub struct CheckConfig {
    /// The check's `name` key (required).
    pub name: String,
    /// Parsed from the check's `type` key.
    pub check_type: CheckType,
    /// Optional `url` key.
    pub url: Option<String>,
    /// Optional `expect` block; `None` when the key is omitted.
    pub expect: Option<ExpectConfig>,
    /// Optional `query` key (NOTE(review): presumably only used by prometheus checks — confirm).
    pub query: Option<String>,
    /// Optional `file` key (NOTE(review): presumably the script path for script checks — confirm).
    pub file: Option<String>,
    /// `follow_redirects` key; `true` when the key is omitted.
    pub follow_redirects: bool,
    /// `env` mapping; empty when the key is omitted.
    pub env: HashMap<String, String>,
}

/// Kind of check, parsed from a check's `type` string by [`parse`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CheckType {
    /// `type: http`
    Http,
    /// `type: prometheus`
    Prometheus,
    /// `type: script`
    Script,
}

/// Converted `expect` block of a check. Every field is optional; `Default` leaves all
/// of them `None`.
#[derive(Debug, Clone, Default)]
pub struct ExpectConfig {
    /// `expect.status` (NOTE(review): presumably the expected HTTP status code — confirm).
    pub status: Option<u16>,
    /// `expect.json` key.
    pub json: Option<String>,
    /// `expect.body` key.
    pub body: Option<String>,
    /// `expect.min` (NOTE(review): presumably a lower bound on a numeric result — confirm).
    pub min: Option<f64>,
    /// `expect.max` (NOTE(review): presumably an upper bound on a numeric result — confirm).
    pub max: Option<f64>,
}

/// Serde mirror of the top-level YAML document, before any conversion.
///
/// Durations are kept as strings here; [`parse`] converts them with [`parse_duration`].
#[derive(Deserialize)]
struct RawConfig {
    // Omitted keys are filled in by the named default functions below.
    #[serde(default = "default_timeout")]
    timeout: String,
    #[serde(default = "default_retries")]
    retries: u32,
    #[serde(default = "default_backoff")]
    backoff: String,
    // `bool::default()`, i.e. `false`, when omitted.
    #[serde(default)]
    parallel: bool,
    // No `#[serde(default)]`: the `checks` key is required.
    checks: Vec<RawCheck>,
}

// Serde `default = "..."` hooks. The raw structs refer to these functions by name in
// their `#[serde(default = ...)]` attributes, so the names must stay in sync.

/// Timeout string used when the config has no `timeout` key.
const DEFAULT_TIMEOUT: &str = "120s";
/// Retry count used when the config has no `retries` key.
const DEFAULT_RETRIES: u32 = 3;
/// Backoff string used when the config has no `backoff` key.
const DEFAULT_BACKOFF: &str = "5s";
/// Redirect behaviour used when a check has no `follow_redirects` key.
const DEFAULT_FOLLOW_REDIRECTS: bool = true;

fn default_timeout() -> String {
    DEFAULT_TIMEOUT.to_owned()
}

fn default_retries() -> u32 {
    DEFAULT_RETRIES
}

fn default_backoff() -> String {
    DEFAULT_BACKOFF.to_owned()
}

fn default_follow_redirects() -> bool {
    DEFAULT_FOLLOW_REDIRECTS
}

/// Serde mirror of one entry in the YAML `checks` list.
///
/// The `type` string is resolved into a [`CheckType`] by [`parse`].
#[derive(Deserialize)]
struct RawCheck {
    name: String,
    // `type` is a Rust keyword, so the YAML key is renamed onto this field.
    #[serde(rename = "type")]
    check_type: String,
    url: Option<String>,
    expect: Option<RawExpect>,
    // `true` when omitted, via `default_follow_redirects`.
    #[serde(default = "default_follow_redirects")]
    follow_redirects: bool,
    query: Option<String>,
    file: Option<String>,
    // Empty map when omitted.
    #[serde(default)]
    env: HashMap<String, String>,
}

/// Serde mirror of a check's `expect` block; copied field-for-field into
/// [`ExpectConfig`] by [`parse`].
#[derive(Deserialize)]
struct RawExpect {
    status: Option<u16>,
    json: Option<String>,
    body: Option<String>,
    min: Option<f64>,
    max: Option<f64>,
}

pub fn parse_duration(s: &str) -> Result<Duration> {
    let s = s.trim();
    if let Some(ms) = s.strip_suffix("ms") {
        let val: u64 = ms.parse().context("invalid milliseconds value")?;
        return Ok(Duration::from_millis(val));
    }
    if let Some(secs) = s.strip_suffix('s') {
        let val: u64 = secs.parse().context("invalid seconds value")?;
        return Ok(Duration::from_secs(val));
    }
    if let Some(mins) = s.strip_suffix('m') {
        let val: u64 = mins.parse().context("invalid minutes value")?;
        return Ok(Duration::from_secs(val * 60));
    }
    bail!("unsupported duration format: {s:?} (use e.g. '120s', '5m', '500ms')")
}

/// Reads the config file at `path` and hands its text to [`parse`].
///
/// # Errors
///
/// Fails if the file cannot be read as UTF-8 text (the error names the path), or with
/// whatever error [`parse`] reports for the contents.
pub fn load(path: &Path) -> Result<Config> {
    std::fs::read_to_string(path)
        .with_context(|| format!("reading config {}", path.display()))
        .and_then(|text| parse(&text))
}

pub fn parse(yaml: &str) -> Result<Config> {
    let raw: RawConfig = serde_yml::from_str(yaml).context("parsing YAML config")?;

    let timeout = parse_duration(&raw.timeout).context("parsing timeout")?;
    let backoff = parse_duration(&raw.backoff).context("parsing backoff")?;

    let checks = raw
        .checks
        .into_iter()
        .map(|c| {
            let check_type = match c.check_type.as_str() {
                "http" => CheckType::Http,
                "prometheus" => CheckType::Prometheus,
                "script" => CheckType::Script,
                other => bail!("unknown check type: {other:?}"),
            };
            Ok(CheckConfig {
                name: c.name,
                check_type,
                url: c.url,
                expect: c.expect.map(|e| ExpectConfig {
                    status: e.status,
                    json: e.json,
                    body: e.body,
                    min: e.min,
                    max: e.max,
                }),
                follow_redirects: c.follow_redirects,
                query: c.query,
                file: c.file,
                env: c.env,
            })
        })
        .collect::<Result<Vec<_>>>()?;

    Ok(Config {
        timeout,
        retries: raw.retries,
        backoff,
        parallel: raw.parallel,
        checks,
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    // --- duration parsing -------------------------------------------------------

    #[test]
    fn test_parse_duration_seconds() {
        let parsed = parse_duration("120s").unwrap();
        assert_eq!(parsed, Duration::from_secs(120));
    }

    #[test]
    fn test_parse_duration_minutes() {
        let parsed = parse_duration("5m").unwrap();
        assert_eq!(parsed, Duration::from_secs(300));
    }

    #[test]
    fn test_parse_duration_millis() {
        let parsed = parse_duration("500ms").unwrap();
        assert_eq!(parsed, Duration::from_millis(500));
    }

    // --- whole-document parsing -------------------------------------------------

    /// One HTTP check and one script check, with every top-level key set explicitly.
    const SAMPLE_YAML: &str = r#"
timeout: 30s
retries: 2
backoff: 3s
parallel: false
checks:
  - name: health
    type: http
    url: http://localhost/health
    expect:
      status: 200
  - name: custom
    type: script
    file: /checks/verify.lua
    env:
      FOO: bar
"#;

    #[test]
    fn test_parse_config() {
        let cfg = parse(SAMPLE_YAML).unwrap();

        assert_eq!(cfg.timeout, Duration::from_secs(30));
        assert_eq!(cfg.retries, 2);
        assert_eq!(cfg.checks.len(), 2);

        let kinds: Vec<&CheckType> = cfg.checks.iter().map(|check| &check.check_type).collect();
        assert_eq!(*kinds[0], CheckType::Http);
        assert_eq!(*kinds[1], CheckType::Script);
    }
}