sunox 0.0.2

Generate AI music from your terminal via direct Suno web workflows
use figment::{
    Figment,
    providers::{Format, Serialized, Toml},
};
use serde::{Deserialize, Serialize};

use super::CliError;

/// Effective runtime settings for the CLI.
///
/// Values are assembled by the loaders on this type from the built-in
/// defaults, the TOML config file, environment variables and explicit
/// `key=value` overrides.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AppConfig {
    /// Model key used by default (`chirp-fenix` unless configured otherwise).
    pub default_model: String,
    /// Polling interval in seconds (defaults to 5).
    pub poll_interval_secs: u64,
    /// Polling timeout in seconds (defaults to 600).
    pub poll_timeout_secs: u64,
    /// Output directory, kept as a plain string (defaults to `.`).
    pub output_dir: String,
}

impl Default for AppConfig {
    fn default() -> Self {
        Self {
            default_model: "chirp-fenix".into(),
            poll_interval_secs: 5,
            poll_timeout_secs: 600,
            output_dir: ".".into(),
        }
    }
}

impl AppConfig {
    /// Loads the effective configuration from the default config file and the
    /// process environment.
    ///
    /// Only `SUNO_`-prefixed environment variables are forwarded to
    /// [`AppConfig::load_from_path`]; every variable it recognizes carries
    /// that prefix.
    ///
    /// # Errors
    ///
    /// Returns `CliError::Config` when a `SUNO_`-prefixed variable holds a
    /// value that is not valid UTF-8, and propagates any error from
    /// [`AppConfig::load_from_path`].
    pub fn load() -> Result<Self, CliError> {
        // `std::env::vars()` panics while iterating if *any* variable in the
        // environment is not valid Unicode, so walk `vars_os()` instead and
        // convert only the entries this module can actually consume.
        let mut vars = Vec::new();
        for (key, value) in std::env::vars_os() {
            let Ok(key) = key.into_string() else {
                continue;
            };
            if !key.starts_with("SUNO_") {
                continue;
            }
            let value = value.into_string().map_err(|_| {
                CliError::Config(format!("environment variable `{key}` is not valid UTF-8"))
            })?;
            vars.push((key, value));
        }
        Self::load_from_path(Self::path(), vars)
    }

    /// Loads the configuration like [`AppConfig::load`] and then applies the
    /// given `key=value` overrides on top, in order.
    ///
    /// # Errors
    ///
    /// Propagates loading errors and returns `CliError::Config` for a
    /// malformed override, an unknown key or an invalid value.
    pub fn load_with_overrides(overrides: &[String]) -> Result<Self, CliError> {
        let mut config = Self::load()?;
        config.apply_overrides(overrides)?;
        Ok(config)
    }

    /// Builds a configuration from the defaults, an optional TOML file and an
    /// explicit set of environment-style `(name, value)` pairs.
    ///
    /// Precedence, lowest to highest: built-in defaults, the file at `path`
    /// (when `Some`), then the recognized `SUNO_*` entries in `vars`.
    ///
    /// # Errors
    ///
    /// Returns `CliError::Config` when the merged data cannot be extracted
    /// into an `AppConfig`, when an environment value is invalid, or when the
    /// resulting `default_model` is not a known model alias or key.
    pub(crate) fn load_from_path<I>(
        path: Option<std::path::PathBuf>,
        vars: I,
    ) -> Result<Self, CliError>
    where
        I: IntoIterator<Item = (String, String)>,
    {
        let mut figment = Figment::new().merge(Serialized::defaults(AppConfig::default()));
        if let Some(path) = path {
            figment = figment.merge(Toml::file(path));
        }
        let mut config: AppConfig = figment
            .extract()
            .map_err(|e| CliError::Config(format!("parse config: {e}")))?;
        config.apply_env_overrides(vars)?;
        // A model read from the file has not been through alias normalization
        // yet (environment values already have, and normalizing twice is a
        // no-op). Doing it after the environment pass lets a valid
        // `SUNO_DEFAULT_MODEL` take precedence over a stale file entry.
        config.default_model = normalize_model_key(&config.default_model)?;
        Ok(config)
    }

    /// Location of the user's `config.toml`, or `None` when
    /// `ProjectDirs::from` cannot provide project directories.
    pub fn path() -> Option<std::path::PathBuf> {
        directories::ProjectDirs::from("com", "sunox", "sunox")
            .map(|dirs| dirs.config_dir().join("config.toml"))
    }

    /// Validates `key`/`value`, writes the updated setting to the config file
    /// and returns the freshly reloaded configuration.
    ///
    /// The parent directory of the config file is created if necessary.
    ///
    /// # Errors
    ///
    /// Returns `CliError::Config` when the config path cannot be resolved,
    /// the key or value is invalid, or serialization fails; I/O failures are
    /// propagated, as are errors from the final [`AppConfig::load`].
    pub fn set_persisted(key: &str, value: &str) -> Result<Self, CliError> {
        let path =
            Self::path().ok_or_else(|| CliError::Config("could not resolve config path".into()))?;
        let mut stored = StoredConfig::load(&path)?;
        stored.set(key, value)?;
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let data = toml::to_string_pretty(&stored)
            .map_err(|e| CliError::Config(format!("serialize config: {e}")))?;
        std::fs::write(path, data)?;
        Self::load()
    }

    /// Applies the recognized `SUNO_*` entries from `vars`; all other names
    /// are ignored.
    ///
    /// # Errors
    ///
    /// Returns `CliError::Config` for an unknown model or a non-numeric value
    /// in one of the numeric settings.
    fn apply_env_overrides<I>(&mut self, vars: I) -> Result<(), CliError>
    where
        I: IntoIterator<Item = (String, String)>,
    {
        for (key, value) in vars {
            match key.as_str() {
                "SUNO_DEFAULT_MODEL" => self.default_model = normalize_model_key(&value)?,
                "SUNO_POLL_INTERVAL_SECS" => {
                    self.poll_interval_secs = parse_u64("SUNO_POLL_INTERVAL_SECS", &value)?;
                }
                "SUNO_POLL_TIMEOUT_SECS" => {
                    self.poll_timeout_secs = parse_u64("SUNO_POLL_TIMEOUT_SECS", &value)?;
                }
                "SUNO_OUTPUT_DIR" => self.output_dir = value,
                _ => {}
            }
        }
        Ok(())
    }

    /// Applies `key=value` overrides in order; later entries win.
    ///
    /// Keys and values are trimmed, and one pair of matching surrounding
    /// quotes is removed from the value.
    ///
    /// # Errors
    ///
    /// Returns `CliError::Config` when an entry has no `=`, names an unknown
    /// key or carries an invalid value.
    fn apply_overrides(&mut self, overrides: &[String]) -> Result<(), CliError> {
        for override_value in overrides {
            let (key, value) = override_value.split_once('=').ok_or_else(|| {
                CliError::Config(format!(
                    "config override `{override_value}` must use key=value syntax"
                ))
            })?;
            self.set_value(key.trim(), normalize_override_value(value.trim()))?;
        }
        Ok(())
    }

    /// Sets a single setting by its config-file key name.
    ///
    /// # Errors
    ///
    /// Returns `CliError::Config` for an unknown key, an unknown model or a
    /// non-numeric value in one of the numeric settings.
    fn set_value(&mut self, key: &str, value: String) -> Result<(), CliError> {
        match key {
            "default_model" => self.default_model = normalize_model_key(&value)?,
            "poll_interval_secs" => self.poll_interval_secs = parse_u64(key, &value)?,
            "poll_timeout_secs" => self.poll_timeout_secs = parse_u64(key, &value)?,
            "output_dir" => self.output_dir = value,
            _ => {
                return Err(CliError::Config(format!(
                    "unknown config key `{key}`; valid keys: default_model, poll_interval_secs, poll_timeout_secs, output_dir"
                )));
            }
        }
        Ok(())
    }
}

/// On-disk form of the configuration: only the settings that were explicitly
/// persisted.
///
/// Every field is optional and `None` fields are left out when serializing,
/// so the written file contains just the values that have been set.
#[derive(Debug, Default, Deserialize, Serialize)]
struct StoredConfig {
    /// Persisted model key, if one was set.
    #[serde(skip_serializing_if = "Option::is_none")]
    default_model: Option<String>,
    /// Persisted polling interval in seconds, if one was set.
    #[serde(skip_serializing_if = "Option::is_none")]
    poll_interval_secs: Option<u64>,
    /// Persisted polling timeout in seconds, if one was set.
    #[serde(skip_serializing_if = "Option::is_none")]
    poll_timeout_secs: Option<u64>,
    /// Persisted output directory, if one was set.
    #[serde(skip_serializing_if = "Option::is_none")]
    output_dir: Option<String>,
}

impl StoredConfig {
    /// Reads the persisted configuration at `path`.
    ///
    /// A file that does not exist yields an empty `StoredConfig`.
    ///
    /// # Errors
    ///
    /// Propagates any other I/O error from reading the file and returns
    /// `CliError::Config` when the contents are not valid TOML for this type.
    fn load(path: &std::path::Path) -> Result<Self, CliError> {
        // Read first and inspect the error instead of probing with
        // `Path::exists()`: that avoids a check-then-use race and keeps
        // failures other than "not found" (e.g. permission denied) visible
        // rather than silently treating them as "no config yet".
        let data = match std::fs::read_to_string(path) {
            Ok(data) => data,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
                return Ok(Self::default());
            }
            Err(err) => return Err(err.into()),
        };
        toml::from_str(&data).map_err(|e| CliError::Config(format!("parse config: {e}")))
    }

    /// Validates `value` for `key` and records it as a persisted setting.
    ///
    /// Model names are stored in normalized form and numeric settings are
    /// parsed as `u64`.
    ///
    /// # Errors
    ///
    /// Returns `CliError::Config` for an unknown key, an unknown model or a
    /// non-numeric value in one of the numeric settings.
    fn set(&mut self, key: &str, value: &str) -> Result<(), CliError> {
        match key {
            "default_model" => self.default_model = Some(normalize_model_key(value)?),
            "poll_interval_secs" => self.poll_interval_secs = Some(parse_u64(key, value)?),
            "poll_timeout_secs" => self.poll_timeout_secs = Some(parse_u64(key, value)?),
            "output_dir" => self.output_dir = Some(value.to_string()),
            _ => {
                return Err(CliError::Config(format!(
                    "unknown config key `{key}`; valid keys: default_model, poll_interval_secs, poll_timeout_secs, output_dir"
                )));
            }
        }
        Ok(())
    }
}

/// Maps a user-facing model name to the Suno API model key.
///
/// Accepts either a CLI version alias (such as `v5.5`) or an API key (such as
/// `chirp-fenix`) and always returns the API key. Matching is exact.
///
/// # Errors
///
/// Returns `CliError::Config` when `value` is neither a known alias nor a
/// known API key.
fn normalize_model_key(value: &str) -> Result<String, CliError> {
    // (CLI version alias, API model key) pairs; either spelling is accepted.
    const MODELS: [(&str, &str); 8] = [
        ("v5.5", "chirp-fenix"),
        ("v5", "chirp-crow"),
        ("v4.5+", "chirp-bluejay"),
        ("v4.5", "chirp-auk"),
        ("v4", "chirp-v4"),
        ("v3.5", "chirp-v3-5"),
        ("v3", "chirp-v3-0"),
        ("v2", "chirp-v2-xxl-alpha"),
    ];

    for (alias, api_key) in MODELS {
        if value == alias || value == api_key {
            return Ok(api_key.to_string());
        }
    }

    Err(CliError::Config(format!(
        "unknown model `{value}`; use a CLI model version such as v5.5 or a Suno API model key such as chirp-fenix"
    )))
}

/// Removes one pair of matching surrounding quotes from an override value.
///
/// Double quotes are tried first, then single quotes. A value that is not
/// wrapped in a matching pair is returned unchanged.
fn normalize_override_value(value: &str) -> String {
    for quote in ['"', '\''] {
        let unquoted = value
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote));
        if let Some(inner) = unquoted {
            return inner.to_string();
        }
    }
    value.to_string()
}

/// Parses `value` as a `u64` for the setting named `key`.
///
/// # Errors
///
/// Returns `CliError::Config` naming `key` when `value` is not an unsigned
/// integer.
fn parse_u64(key: &str, value: &str) -> Result<u64, CliError> {
    match value.parse::<u64>() {
        Ok(number) => Ok(number),
        Err(_) => Err(CliError::Config(format!(
            "config key `{key}` expects an unsigned integer"
        ))),
    }
}

#[cfg(test)]
mod tests {
    use super::{AppConfig, StoredConfig};

    // A version alias passed to `set` is stored as the API model key.
    #[test]
    fn stored_config_sets_known_string_key() {
        let mut config = StoredConfig::default();

        config.set("default_model", "v5.5").expect("set config");

        assert_eq!(config.default_model.as_deref(), Some("chirp-fenix"));
    }

    // Model names that are neither an alias nor an API key are refused.
    #[test]
    fn stored_config_rejects_unknown_default_model() {
        let mut config = StoredConfig::default();

        let err = config
            .set("default_model", "unknown-model")
            .expect_err("unknown model");

        assert!(err.to_string().contains("unknown model"));
    }

    // Numeric settings are parsed from their string form.
    #[test]
    fn stored_config_parses_numeric_keys() {
        let mut config = StoredConfig::default();

        config.set("poll_timeout_secs", "900").expect("set config");

        assert_eq!(config.poll_timeout_secs, Some(900));
    }

    // Keys outside the known set produce an error.
    #[test]
    fn stored_config_rejects_unknown_keys() {
        let mut config = StoredConfig::default();

        let err = config.set("missing", "value").expect_err("unknown key");

        assert!(err.to_string().contains("unknown config key"));
    }

    // Each recognized `SUNO_*` variable updates the matching field.
    #[test]
    fn env_overrides_support_underscored_config_keys() {
        let mut config = AppConfig::default();

        config
            .apply_env_overrides([
                ("SUNO_DEFAULT_MODEL".to_string(), "v5".to_string()),
                ("SUNO_POLL_INTERVAL_SECS".to_string(), "9".to_string()),
                ("SUNO_POLL_TIMEOUT_SECS".to_string(), "777".to_string()),
                (
                    "SUNO_OUTPUT_DIR".to_string(),
                    "/tmp/suno-output".to_string(),
                ),
            ])
            .expect("env overrides");

        assert_eq!(config.default_model, "chirp-crow");
        assert_eq!(config.poll_interval_secs, 9);
        assert_eq!(config.poll_timeout_secs, 777);
        assert_eq!(config.output_dir, "/tmp/suno-output");
    }

    // An unknown model supplied through the environment is refused.
    #[test]
    fn env_override_rejects_unknown_default_model() {
        let mut config = AppConfig::default();

        let err = config
            .apply_env_overrides([(
                "SUNO_DEFAULT_MODEL".to_string(),
                "unknown-model".to_string(),
            )])
            .expect_err("unknown model");

        assert!(err.to_string().contains("unknown model"));
    }

    // A config file whose value has the wrong type surfaces as a parse error.
    #[test]
    fn load_from_path_reports_invalid_toml() {
        let path = std::env::temp_dir().join(format!(
            "sunox-invalid-config-{}-{}.toml",
            std::process::id(),
            "core"
        ));
        std::fs::write(&path, "poll_timeout_secs = \"slow\"").expect("write config");

        let result = AppConfig::load_from_path(Some(path.clone()), []);
        // Remove the fixture before any assertion can panic, so a failing
        // expectation does not leave the temporary file behind.
        let _ = std::fs::remove_file(&path);

        let err = result.expect_err("invalid config");
        assert!(err.to_string().contains("parse config"));
    }
}