use figment::{
Figment,
providers::{Format, Serialized, Toml},
};
use serde::{Deserialize, Serialize};
use super::CliError;
/// Effective CLI configuration, built by layering built-in defaults, the user's
/// `config.toml`, and `SUNO_*` environment variables (see `AppConfig::load`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AppConfig {
    /// Model identifier. Normalised to a Suno API model key (e.g. `chirp-fenix`) when set
    /// through `SUNO_DEFAULT_MODEL`, a `key=value` override, or `set_persisted`.
    pub default_model: String,
    /// Poll interval in seconds (per the field name; the poller itself is not in this file).
    pub poll_interval_secs: u64,
    /// Poll timeout in seconds (per the field name; the poller itself is not in this file).
    pub poll_timeout_secs: u64,
    /// Output directory, kept as a plain string; defaults to `.`.
    pub output_dir: String,
}
impl Default for AppConfig {
fn default() -> Self {
Self {
default_model: "chirp-fenix".into(),
poll_interval_secs: 5,
poll_timeout_secs: 600,
output_dir: ".".into(),
}
}
}
impl AppConfig {
    /// Loads the effective configuration: built-in defaults, then `config.toml` (when a config
    /// directory can be resolved), then recognised `SUNO_*` environment variables.
    ///
    /// # Errors
    /// Returns [`CliError::Config`] when the file cannot be parsed, a setting has an invalid
    /// value, or a recognised `SUNO_*` variable is not valid Unicode.
    pub fn load() -> Result<Self, CliError> {
        Self::load_from_path(Self::path(), Self::unicode_env_vars()?)
    }

    /// Like [`AppConfig::load`], then applies `key=value` overrides on top (highest precedence).
    ///
    /// # Errors
    /// Everything [`AppConfig::load`] can return, plus malformed or invalid overrides.
    pub fn load_with_overrides(overrides: &[String]) -> Result<Self, CliError> {
        let mut config = Self::load()?;
        config.apply_overrides(overrides)?;
        Ok(config)
    }

    /// Builds a config from defaults, the optional TOML file at `path`, and `vars`.
    ///
    /// `vars` stands in for the process environment so callers (and tests) can inject
    /// variables; only the `SUNO_*` keys handled by `apply_env_overrides` have any effect.
    ///
    /// # Errors
    /// Returns [`CliError::Config`] if extraction fails (message prefixed `parse config:`), if an
    /// environment value is invalid, or if the resulting `default_model` is not a known model.
    pub(crate) fn load_from_path<I>(
        path: Option<std::path::PathBuf>,
        vars: I,
    ) -> Result<Self, CliError>
    where
        I: IntoIterator<Item = (String, String)>,
    {
        let mut figment = Figment::new().merge(Serialized::defaults(AppConfig::default()));
        if let Some(path) = path {
            figment = figment.merge(Toml::file(path));
        }
        let mut config: AppConfig = figment
            .extract()
            .map_err(|e| CliError::Config(format!("parse config: {e}")))?;
        config.apply_env_overrides(vars)?;
        // The config file is the only source whose model is not validated on entry (env vars,
        // `key=value` overrides and `set_persisted` all use `normalize_model_key`), so resolve
        // aliases such as `v5.5` and reject unknown models here. Normalisation maps every API
        // key to itself, so a value already normalised from the environment is unchanged.
        config.default_model = normalize_model_key(&config.default_model)?;
        Ok(config)
    }

    /// Location of `config.toml` inside the platform config directory for `sunox`, or `None`
    /// when `ProjectDirs` cannot resolve one.
    pub fn path() -> Option<std::path::PathBuf> {
        directories::ProjectDirs::from("com", "sunox", "sunox")
            .map(|dirs| dirs.config_dir().join("config.toml"))
    }

    /// Persists `key = value` to `config.toml`, then reloads and returns the effective config.
    ///
    /// The value is validated (and model aliases normalised) before anything is written, and
    /// the other recognised keys already in the file are kept. NOTE: the file is rewritten in
    /// full from `StoredConfig`, so TOML comments and unrecognised keys are not preserved.
    ///
    /// # Errors
    /// Fails if the config path cannot be resolved, the existing file cannot be read or parsed,
    /// the key/value is invalid, the directory or file cannot be written, or the reload fails.
    pub fn set_persisted(key: &str, value: &str) -> Result<Self, CliError> {
        let path =
            Self::path().ok_or_else(|| CliError::Config("could not resolve config path".into()))?;
        let mut stored = StoredConfig::load(&path)?;
        stored.set(key, value)?;
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let data = toml::to_string_pretty(&stored)
            .map_err(|e| CliError::Config(format!("serialize config: {e}")))?;
        std::fs::write(path, data)?;
        Self::load()
    }

    /// Snapshots the process environment as UTF-8 `(key, value)` pairs.
    ///
    /// `std::env::vars` panics on the first non-Unicode entry, which would crash the CLI because
    /// of unrelated variables. Instead, entries that cannot influence configuration are skipped,
    /// and an error is reported only when a variable read by `apply_env_overrides` is not
    /// valid Unicode.
    fn unicode_env_vars() -> Result<Vec<(String, String)>, CliError> {
        let mut vars = Vec::new();
        for (key, value) in std::env::vars_os() {
            // A key that is not valid Unicode can never equal one of the `SUNO_*` names.
            let key = match key.into_string() {
                Ok(key) => key,
                Err(_) => continue,
            };
            match value.into_string() {
                Ok(value) => vars.push((key, value)),
                Err(_) if Self::is_env_key(&key) => {
                    return Err(CliError::Config(format!(
                        "environment variable `{key}` is not valid Unicode"
                    )));
                }
                // Irrelevant to configuration; ignore rather than fail.
                Err(_) => {}
            }
        }
        Ok(vars)
    }

    /// Whether `key` is one of the environment variables consumed by `apply_env_overrides`.
    fn is_env_key(key: &str) -> bool {
        matches!(
            key,
            "SUNO_DEFAULT_MODEL"
                | "SUNO_POLL_INTERVAL_SECS"
                | "SUNO_POLL_TIMEOUT_SECS"
                | "SUNO_OUTPUT_DIR"
        )
    }

    /// Applies recognised `SUNO_*` variables from `vars`; every other key is ignored.
    ///
    /// # Errors
    /// Returns [`CliError::Config`] for an unknown model or a non-integer poll setting.
    fn apply_env_overrides<I>(&mut self, vars: I) -> Result<(), CliError>
    where
        I: IntoIterator<Item = (String, String)>,
    {
        for (key, value) in vars {
            match key.as_str() {
                "SUNO_DEFAULT_MODEL" => self.default_model = normalize_model_key(&value)?,
                "SUNO_POLL_INTERVAL_SECS" => {
                    self.poll_interval_secs = parse_u64("SUNO_POLL_INTERVAL_SECS", &value)?;
                }
                "SUNO_POLL_TIMEOUT_SECS" => {
                    self.poll_timeout_secs = parse_u64("SUNO_POLL_TIMEOUT_SECS", &value)?;
                }
                "SUNO_OUTPUT_DIR" => self.output_dir = value,
                _ => {}
            }
        }
        Ok(())
    }

    /// Applies `key=value` overrides in order. Key and value are trimmed, and one layer of
    /// matching surrounding quotes is removed from the value.
    ///
    /// # Errors
    /// Returns [`CliError::Config`] for an entry without `=`, an unknown key, or an invalid value.
    fn apply_overrides(&mut self, overrides: &[String]) -> Result<(), CliError> {
        for override_value in overrides {
            let (key, value) = override_value.split_once('=').ok_or_else(|| {
                CliError::Config(format!(
                    "config override `{override_value}` must use key=value syntax"
                ))
            })?;
            self.set_value(key.trim(), normalize_override_value(value.trim()))?;
        }
        Ok(())
    }

    /// Sets a single field by its config-file key name, validating the value.
    ///
    /// # Errors
    /// Returns [`CliError::Config`] for an unknown key or an invalid value.
    fn set_value(&mut self, key: &str, value: String) -> Result<(), CliError> {
        match key {
            "default_model" => self.default_model = normalize_model_key(&value)?,
            "poll_interval_secs" => self.poll_interval_secs = parse_u64(key, &value)?,
            "poll_timeout_secs" => self.poll_timeout_secs = parse_u64(key, &value)?,
            "output_dir" => self.output_dir = value,
            _ => {
                return Err(CliError::Config(format!(
                    "unknown config key `{key}`; valid keys: default_model, poll_interval_secs, poll_timeout_secs, output_dir"
                )));
            }
        }
        Ok(())
    }
}
/// On-disk shape of `config.toml` as read and rewritten by `AppConfig::set_persisted`.
///
/// Every field is optional and `None` fields are skipped when serialising, so the file only
/// holds keys that were explicitly set; anything absent falls back to the defaults that
/// `AppConfig::load_from_path` merges underneath the file.
#[derive(Debug, Default, Deserialize, Serialize)]
struct StoredConfig {
    /// Model key, already normalised by `StoredConfig::set` before being stored.
    #[serde(skip_serializing_if = "Option::is_none")]
    default_model: Option<String>,
    /// Persisted `poll_interval_secs`, if set.
    #[serde(skip_serializing_if = "Option::is_none")]
    poll_interval_secs: Option<u64>,
    /// Persisted `poll_timeout_secs`, if set.
    #[serde(skip_serializing_if = "Option::is_none")]
    poll_timeout_secs: Option<u64>,
    /// Persisted `output_dir`, stored verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    output_dir: Option<String>,
}
impl StoredConfig {
    /// Reads the persisted settings at `path`, returning an empty config if the file is absent.
    ///
    /// # Errors
    /// Propagates I/O errors other than "not found", and returns [`CliError::Config`] when the
    /// contents are not valid TOML for this schema.
    fn load(path: &std::path::Path) -> Result<Self, CliError> {
        // Read directly and treat only `NotFound` as "no file yet". Checking `path.exists()`
        // first would report other failures (e.g. permission denied) as absence, and would
        // race with the read.
        let data = match std::fs::read_to_string(path) {
            Ok(data) => data,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(Self::default()),
            Err(err) => return Err(err.into()),
        };
        toml::from_str(&data).map_err(|e| CliError::Config(format!("parse config: {e}")))
    }

    /// Validates `value` for `key` and stores it. Model aliases are normalised to API keys;
    /// `output_dir` is stored verbatim.
    ///
    /// # Errors
    /// Returns [`CliError::Config`] for an unknown key, unknown model, or non-integer number.
    fn set(&mut self, key: &str, value: &str) -> Result<(), CliError> {
        match key {
            "default_model" => self.default_model = Some(normalize_model_key(value)?),
            "poll_interval_secs" => self.poll_interval_secs = Some(parse_u64(key, value)?),
            "poll_timeout_secs" => self.poll_timeout_secs = Some(parse_u64(key, value)?),
            "output_dir" => self.output_dir = Some(value.to_string()),
            _ => {
                return Err(CliError::Config(format!(
                    "unknown config key `{key}`; valid keys: default_model, poll_interval_secs, poll_timeout_secs, output_dir"
                )));
            }
        }
        Ok(())
    }
}
/// Resolves a model identifier to its Suno API model key.
///
/// Accepts either a CLI model version (`v5.5`, `v4.5+`, ...) or the API key itself
/// (`chirp-fenix`, ...) and always returns the API key.
///
/// # Errors
/// Returns [`CliError::Config`] when `value` matches neither form.
fn normalize_model_key(value: &str) -> Result<String, CliError> {
    // (CLI version alias, API model key); both columns are accepted as input.
    const MODELS: [(&str, &str); 8] = [
        ("v5.5", "chirp-fenix"),
        ("v5", "chirp-crow"),
        ("v4.5+", "chirp-bluejay"),
        ("v4.5", "chirp-auk"),
        ("v4", "chirp-v4"),
        ("v3.5", "chirp-v3-5"),
        ("v3", "chirp-v3-0"),
        ("v2", "chirp-v2-xxl-alpha"),
    ];
    MODELS
        .iter()
        .find(|(alias, api_key)| value == *alias || value == *api_key)
        .map(|&(_, api_key)| api_key.to_owned())
        .ok_or_else(|| {
            CliError::Config(format!(
                "unknown model `{value}`; use a CLI model version such as v5.5 or a Suno API model key such as chirp-fenix"
            ))
        })
}
/// Removes one layer of matching surrounding quotes (`"..."` or `'...'`) from a `key=value`
/// override value. Anything else, including a lone or mismatched quote, is returned unchanged.
fn normalize_override_value(value: &str) -> String {
    for quote in ['"', '\''] {
        let unquoted = value
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote));
        if let Some(inner) = unquoted {
            return inner.to_string();
        }
    }
    value.to_string()
}
/// Parses `value` as a `u64` for the config setting `key`.
///
/// # Errors
/// Returns [`CliError::Config`] naming the key, the rejected value, and the parse failure
/// reason (e.g. an invalid digit, or a number too large for `u64`), so users can tell which
/// source (env var, file, or override) supplied the bad input.
fn parse_u64(key: &str, value: &str) -> Result<u64, CliError> {
    value.parse::<u64>().map_err(|e| {
        CliError::Config(format!(
            "config key `{key}` expects an unsigned integer, got `{value}` ({e})"
        ))
    })
}
#[cfg(test)]
mod tests {
    use super::{AppConfig, StoredConfig};

    /// Per-process temp path for a config fixture; `tag` keeps parallel tests apart.
    fn temp_config_path(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("sunox-{tag}-config-{}.toml", std::process::id()))
    }

    #[test]
    fn stored_config_sets_known_string_key() {
        let mut config = StoredConfig::default();
        config.set("default_model", "v5.5").expect("set config");
        assert_eq!(config.default_model.as_deref(), Some("chirp-fenix"));
    }

    #[test]
    fn stored_config_rejects_unknown_default_model() {
        let mut config = StoredConfig::default();
        let err = config
            .set("default_model", "unknown-model")
            .expect_err("unknown model");
        assert!(err.to_string().contains("unknown model"));
    }

    #[test]
    fn stored_config_parses_numeric_keys() {
        let mut config = StoredConfig::default();
        config.set("poll_timeout_secs", "900").expect("set config");
        assert_eq!(config.poll_timeout_secs, Some(900));
    }

    #[test]
    fn stored_config_rejects_unknown_keys() {
        let mut config = StoredConfig::default();
        let err = config.set("missing", "value").expect_err("unknown key");
        assert!(err.to_string().contains("unknown config key"));
    }

    #[test]
    fn stored_config_load_treats_missing_file_as_empty() {
        let path = temp_config_path("missing");
        let _ = std::fs::remove_file(&path);
        let config = StoredConfig::load(&path).expect("missing file is not an error");
        assert!(config.default_model.is_none());
        assert!(config.poll_interval_secs.is_none());
        assert!(config.poll_timeout_secs.is_none());
        assert!(config.output_dir.is_none());
    }

    #[test]
    fn env_overrides_support_underscored_config_keys() {
        let mut config = AppConfig::default();
        config
            .apply_env_overrides([
                ("SUNO_DEFAULT_MODEL".to_string(), "v5".to_string()),
                ("SUNO_POLL_INTERVAL_SECS".to_string(), "9".to_string()),
                ("SUNO_POLL_TIMEOUT_SECS".to_string(), "777".to_string()),
                (
                    "SUNO_OUTPUT_DIR".to_string(),
                    "/tmp/suno-output".to_string(),
                ),
            ])
            .expect("env overrides");
        assert_eq!(config.default_model, "chirp-crow");
        assert_eq!(config.poll_interval_secs, 9);
        assert_eq!(config.poll_timeout_secs, 777);
        assert_eq!(config.output_dir, "/tmp/suno-output");
    }

    #[test]
    fn env_override_rejects_unknown_default_model() {
        let mut config = AppConfig::default();
        let err = config
            .apply_env_overrides([(
                "SUNO_DEFAULT_MODEL".to_string(),
                "unknown-model".to_string(),
            )])
            .expect_err("unknown model");
        assert!(err.to_string().contains("unknown model"));
    }

    #[test]
    fn cli_overrides_trim_unquote_and_normalize() {
        let mut config = AppConfig::default();
        config
            .apply_overrides(&[
                "default_model = \"v4.5\"".to_string(),
                "output_dir='/tmp/out dir'".to_string(),
                "poll_interval_secs=3".to_string(),
            ])
            .expect("overrides");
        assert_eq!(config.default_model, "chirp-auk");
        assert_eq!(config.output_dir, "/tmp/out dir");
        assert_eq!(config.poll_interval_secs, 3);
    }

    #[test]
    fn cli_override_requires_key_value_syntax() {
        let mut config = AppConfig::default();
        let err = config
            .apply_overrides(&["default_model".to_string()])
            .expect_err("missing `=`");
        assert!(err.to_string().contains("key=value"));
    }

    #[test]
    fn cli_override_rejects_unknown_key() {
        let mut config = AppConfig::default();
        let err = config
            .apply_overrides(&["colour=blue".to_string()])
            .expect_err("unknown key");
        assert!(err.to_string().contains("unknown config key"));
    }

    #[test]
    fn cli_override_rejects_non_numeric_value() {
        let mut config = AppConfig::default();
        let err = config
            .apply_overrides(&["poll_interval_secs=soon".to_string()])
            .expect_err("non-numeric");
        assert!(err.to_string().contains("expects an unsigned integer"));
    }

    #[test]
    fn load_from_path_layers_env_over_file() {
        let path = temp_config_path("layered");
        std::fs::write(&path, "poll_interval_secs = 7\npoll_timeout_secs = 30\n")
            .expect("write config");
        let result = AppConfig::load_from_path(
            Some(path.clone()),
            [("SUNO_POLL_TIMEOUT_SECS".to_string(), "45".to_string())],
        );
        // Clean up before asserting so a failure does not leak the fixture.
        let _ = std::fs::remove_file(&path);
        let config = result.expect("load config");
        assert_eq!(config.poll_interval_secs, 7);
        assert_eq!(config.poll_timeout_secs, 45);
        assert_eq!(config.default_model, "chirp-fenix");
        assert_eq!(config.output_dir, ".");
    }

    #[test]
    fn load_from_path_reports_invalid_toml() {
        let path = temp_config_path("invalid");
        std::fs::write(&path, "poll_timeout_secs = \"slow\"").expect("write config");
        let result = AppConfig::load_from_path(Some(path.clone()), []);
        // Clean up before asserting so a failure does not leak the fixture.
        let _ = std::fs::remove_file(&path);
        let err = result.expect_err("invalid config");
        assert!(err.to_string().contains("parse config"));
    }
}