use anyhow::{anyhow, Context, Result};
use directories::{ProjectDirs, UserDirs};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
/// `tracing` target attached to the load/save log events emitted by this module.
const TRACE_TARGET: &str = "studio_worker::config";
/// Persistent worker configuration, stored as TOML on disk.
///
/// When deserialising, `api_base_url`, `vram_threshold_gb` and `auto_start` are required keys.
/// Every other field falls back to a default when absent. Keys this struct does not declare
/// are ignored, so files written by older versions still load.
///
/// `Debug` is implemented by hand (see below) so that `auth_token` and `registration_secret`
/// are redacted. That makes `{:?}` output of a `Config` safe to log.
#[derive(Clone, Serialize, Deserialize)]
pub struct Config {
    /// Base URL of the studio service (default `https://studio.minis.gg/`).
    pub api_base_url: String,
    /// Worker identifier; `None` means the worker is not registered yet.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub worker_id: Option<String>,
    /// Authentication token. Secret: redacted in `Debug` output, never logged.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub auth_token: Option<String>,
    /// VRAM threshold in gigabytes (default 12.0).
    /// NOTE(review): how the threshold is applied is decided by the consumer, not this module.
    pub vram_threshold_gb: f32,
    /// Start-up flag (default `true`); presumably controls automatic work start — confirm.
    pub auto_start: bool,
    /// Whether automatic updates are enabled (default `true` when the key is missing).
    #[serde(default = "default_auto_update_enabled")]
    pub auto_update_enabled: bool,
    /// Seconds between update checks (default 1800 when the key is missing).
    #[serde(default = "default_auto_update_interval")]
    pub auto_update_interval_secs: u64,
    /// Release feed URL polled for updates (defaults to the project's GitHub releases API).
    #[serde(default = "default_auto_update_feed")]
    pub auto_update_feed: String,
    /// Whether prerelease builds are eligible for auto-update (default `false`).
    #[serde(default)]
    pub auto_update_prerelease: bool,
    /// Directory for model files. A leading `~` is expanded to the home directory on load.
    #[serde(default = "default_models_root_persisted")]
    pub models_root: PathBuf,
    /// Optional WebSocket reconnect attempt setting; `None` means not configured here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub ws_reconnect_attempts: Option<u32>,
    /// Installation identifier, persisted once assigned.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub install_id: Option<String>,
    /// Identifier of an in-flight registration request, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub registration_request_id: Option<String>,
    /// Registration secret. Secret: redacted in `Debug` output, never logged.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub registration_secret: Option<String>,
}

impl std::fmt::Debug for Config {
    /// Formats every field like the derived impl would, except that the two secret fields
    /// render only as present (`Some("<redacted>")`) or absent (`None`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keeps the Option shape (so "is it set?" stays visible) while hiding the value.
        fn redact(secret: Option<&str>) -> Option<&'static str> {
            secret.map(|_| "<redacted>")
        }
        f.debug_struct("Config")
            .field("api_base_url", &self.api_base_url)
            .field("worker_id", &self.worker_id)
            .field("auth_token", &redact(self.auth_token.as_deref()))
            .field("vram_threshold_gb", &self.vram_threshold_gb)
            .field("auto_start", &self.auto_start)
            .field("auto_update_enabled", &self.auto_update_enabled)
            .field("auto_update_interval_secs", &self.auto_update_interval_secs)
            .field("auto_update_feed", &self.auto_update_feed)
            .field("auto_update_prerelease", &self.auto_update_prerelease)
            .field("models_root", &self.models_root)
            .field("ws_reconnect_attempts", &self.ws_reconnect_attempts)
            .field("install_id", &self.install_id)
            .field("registration_request_id", &self.registration_request_id)
            .field(
                "registration_secret",
                &redact(self.registration_secret.as_deref()),
            )
            .finish()
    }
}
/// Serde fallback for `auto_update_enabled`: updates are on unless the file turns them off.
fn default_auto_update_enabled() -> bool {
    true
}

/// Serde fallback for `auto_update_interval_secs`: check for updates every 30 minutes.
fn default_auto_update_interval() -> u64 {
    30 * 60
}

/// Serde fallback for `auto_update_feed`: the GitHub releases API of this project.
fn default_auto_update_feed() -> String {
    String::from("https://api.github.com/repos/webbertakken/studio-worker/releases")
}
/// Default directory for model files.
///
/// Resolves to `<home>/models`. If the user's home directory cannot be resolved, it falls back
/// to `<system temp dir>/studio-worker-models`.
pub fn default_models_root() -> PathBuf {
    UserDirs::new().map_or_else(
        || std::env::temp_dir().join("studio-worker-models"),
        |user| user.home_dir().join("models"),
    )
}

/// Serde fallback for `models_root`; delegates to `default_models_root`.
fn default_models_root_persisted() -> PathBuf {
    default_models_root()
}
fn expand_home(path: PathBuf) -> PathBuf {
let s = path.to_string_lossy();
if s == "~" {
return UserDirs::new()
.map(|d| d.home_dir().to_path_buf())
.unwrap_or(path);
}
if let Some(rest) = s.strip_prefix("~/") {
if let Some(d) = UserDirs::new() {
return d.home_dir().join(rest);
}
}
path
}
impl Default for Config {
fn default() -> Self {
Self {
api_base_url: "https://studio.minis.gg/".into(),
worker_id: None,
auth_token: None,
vram_threshold_gb: 12.0,
auto_start: true,
auto_update_enabled: default_auto_update_enabled(),
auto_update_interval_secs: default_auto_update_interval(),
auto_update_feed: default_auto_update_feed(),
auto_update_prerelease: false,
models_root: default_models_root(),
ws_reconnect_attempts: None,
install_id: None,
registration_request_id: None,
registration_secret: None,
}
}
}
/// Returns the platform-specific default location of `config.toml`.
///
/// The directory comes from `ProjectDirs` for the `gg` / `minis` / `minis-studio-worker`
/// project triple.
///
/// # Errors
/// Fails with "cannot resolve config directory" when `ProjectDirs` yields no directory.
fn default_config_path() -> Result<PathBuf> {
    match ProjectDirs::from("gg", "minis", "minis-studio-worker") {
        Some(project) => Ok(project.config_dir().join("config.toml")),
        None => Err(anyhow!("cannot resolve config directory")),
    }
}
pub fn resolve_path(override_path: Option<&str>) -> Result<PathBuf> {
if let Some(p) = override_path {
Ok(PathBuf::from(p))
} else {
default_config_path()
}
}
/// Loads the config from `override_path`, or from the platform default location when `None`.
///
/// If no file exists at the resolved path, a default `Config` is written there (creating
/// parent directories) and returned. Otherwise the file is parsed as TOML, ignoring unknown
/// keys, and a leading `~` in `models_root` is expanded to the home directory.
///
/// Returns the config together with the path it was read from or written to.
///
/// # Errors
/// Fails in any of these cases:
/// - the path cannot be resolved;
/// - the file exists but cannot be read;
/// - its contents do not deserialise into a `Config`;
/// - writing the bootstrap default fails.
pub fn load(override_path: Option<&str>) -> Result<(Config, PathBuf)> {
    let path = resolve_path(override_path)?;
    // Read first and treat only `NotFound` as "missing". A separate `path.exists()` check
    // races with the read, and it reports `false` for *any* stat failure (e.g. permission
    // denied on a parent directory). That would send an existing-but-unreadable config down
    // the bootstrap path and try to overwrite it with defaults.
    let text = match std::fs::read_to_string(&path) {
        Ok(text) => text,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
            let cfg = Config::default();
            save(&cfg, &path)?;
            tracing::info!(
                target: TRACE_TARGET,
                op = "load",
                source = "default_created",
                config_path = %path.display(),
                api_base_url = %cfg.api_base_url,
                vram_threshold_gb = cfg.vram_threshold_gb,
                auto_start = cfg.auto_start,
                models_root = %cfg.models_root.display(),
                "config file missing — bootstrapped defaults"
            );
            return Ok((cfg, path));
        }
        Err(e) => {
            tracing::warn!(
                target: TRACE_TARGET,
                op = "load",
                config_path = %path.display(),
                error = %e,
                "failed to read config file"
            );
            return Err(e).with_context(|| format!("reading {}", path.display()));
        }
    };
    let mut cfg: Config = match toml::from_str(&text) {
        Ok(cfg) => cfg,
        Err(e) => {
            // Log the parser's message too, consistent with the read-failure branch above.
            tracing::warn!(
                target: TRACE_TARGET,
                op = "load",
                config_path = %path.display(),
                error = %e,
                "config file is not valid TOML"
            );
            // Include the real path: with an override it need not be named `config.toml`.
            return Err(e)
                .with_context(|| format!("parsing config.toml at {}", path.display()));
        }
    };
    cfg.models_root = expand_home(std::mem::take(&mut cfg.models_root));
    // Secrets are reported only as present/absent.
    tracing::debug!(
        target: TRACE_TARGET,
        op = "load",
        source = "existing_file",
        config_path = %path.display(),
        api_base_url = %cfg.api_base_url,
        vram_threshold_gb = cfg.vram_threshold_gb,
        auto_start = cfg.auto_start,
        models_root = %cfg.models_root.display(),
        worker_id = cfg.worker_id.as_deref().unwrap_or("(unregistered)"),
        has_auth_token = cfg.auth_token.is_some(),
        "loaded config from disk"
    );
    Ok((cfg, path))
}
/// Persists `cfg` to `path` atomically and logs the outcome under this module's trace target.
///
/// Secrets are written to the file but are never included in the log events.
///
/// # Errors
/// Returns the `write_config` error unchanged, after logging it at `warn`.
pub fn save(cfg: &Config, path: &Path) -> Result<()> {
    match write_config(cfg, path) {
        Ok(bytes) => {
            tracing::debug!(
                target: TRACE_TARGET,
                op = "save",
                config_path = %path.display(),
                vram_threshold_gb = cfg.vram_threshold_gb,
                auto_start = cfg.auto_start,
                models_root = %cfg.models_root.display(),
                bytes = bytes,
                "persisted config to disk"
            );
            Ok(())
        }
        Err(e) => {
            tracing::warn!(
                target: TRACE_TARGET,
                op = "save",
                config_path = %path.display(),
                // `{:#}` renders the whole anyhow context chain (e.g. "creating /dir:
                // Permission denied"). Plain `%e` shows only the outermost context and
                // hides the root cause.
                error = %format!("{e:#}"),
                "failed to persist config to disk"
            );
            Err(e)
        }
    }
}
fn write_config(cfg: &Config, path: &Path) -> Result<usize> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)
.with_context(|| format!("creating {}", parent.display()))?;
}
let text = toml::to_string_pretty(cfg).with_context(|| "serialising config")?;
let bytes = text.len();
write_atomic(path, text.as_bytes())?;
Ok(bytes)
}
fn write_atomic(path: &Path, bytes: &[u8]) -> Result<()> {
use std::io::Write as _;
let dir = match path.parent() {
Some(p) if !p.as_os_str().is_empty() => p,
_ => Path::new("."),
};
let mut tmp = tempfile::NamedTempFile::new_in(dir)
.with_context(|| format!("creating temp file in {}", dir.display()))?;
tmp.write_all(bytes)
.with_context(|| "writing temp config")?;
tmp.as_file()
.sync_all()
.with_context(|| "flushing temp config to disk")?;
tmp.persist(path)
.map_err(|e| anyhow!("atomically replacing {}: {}", path.display(), e.error))?;
Ok(())
}
/// Reference-counted handle to the live configuration, guarded by the `parking_lot` mutex
/// imported at the top of this module (`lock()` returns the guard directly).
pub type SharedConfig = std::sync::Arc<Mutex<Config>>;

/// Moves `cfg` into a new [`SharedConfig`] handle.
pub fn shared(cfg: Config) -> SharedConfig {
    let guarded = Mutex::new(cfg);
    SharedConfig::new(guarded)
}
#[cfg(test)]
mod tests {
//! Unit tests for config defaults, path resolution, load/save round-trips, legacy-file
//! tolerance, `~` expansion, and the on-disk guarantees of `save` (owner-only permissions on
//! Unix, atomic replacement without temp-file litter).
use super::*;
use tempfile::tempdir;
// Pins the values produced by `Config::default()`. `models_root` depends on whether a home
// directory is resolvable, so either fallback form is accepted.
#[test]
fn default_values_are_sensible() {
let cfg = Config::default();
assert_eq!(cfg.api_base_url, "https://studio.minis.gg/");
assert!(cfg.auto_start);
assert!(cfg.auto_update_enabled);
assert_eq!(cfg.auto_update_interval_secs, 1800);
assert!(!cfg.auto_update_prerelease);
assert!(cfg.auto_update_feed.contains("webbertakken/studio-worker"));
assert_eq!(cfg.vram_threshold_gb, 12.0);
assert!(cfg.worker_id.is_none());
assert!(cfg.auth_token.is_none());
let m = cfg.models_root.to_string_lossy().to_string();
assert!(m.ends_with("models") || m.contains("studio-worker-models"));
}
// An explicit override is returned verbatim; no file needs to exist.
#[test]
fn resolve_path_uses_override_when_provided() {
let path = resolve_path(Some("/tmp/test-config.toml")).unwrap();
assert_eq!(path, PathBuf::from("/tmp/test-config.toml"));
}
// The default directory layout differs per platform, so only the project-name fragment and
// the file name are checked.
#[test]
fn resolve_path_defaults_when_no_override() {
let path = resolve_path(None).unwrap();
let s = path.to_string_lossy();
assert!(
s.contains("minis-studio-worker") || s.contains("minis.gg.minis-studio-worker"),
"unexpected default path: {s}"
);
assert!(s.ends_with("config.toml"));
}
// Loading a missing file bootstraps defaults and writes them out, creating the missing
// `sub/` parent directory along the way.
#[test]
fn load_creates_default_when_file_missing() {
let dir = tempdir().unwrap();
let path = dir.path().join("sub").join("config.toml");
let path_str = path.to_string_lossy().to_string();
let (cfg, returned_path) = load(Some(&path_str)).unwrap();
assert_eq!(returned_path, path);
assert_eq!(cfg.api_base_url, "https://studio.minis.gg/");
assert!(path.exists());
}
// save followed by load preserves identity, secrets, non-default numbers/flags and an
// absolute models_root.
#[test]
fn round_trip_via_save_and_load_preserves_fields() {
let dir = tempdir().unwrap();
let path = dir.path().join("config.toml");
let cfg = Config {
worker_id: Some("w-123".into()),
auth_token: Some("tok-xyz".into()),
vram_threshold_gb: 24.0,
auto_update_prerelease: true,
models_root: PathBuf::from("/tmp/test-models"),
..Config::default()
};
save(&cfg, &path).unwrap();
let path_str = path.to_string_lossy().to_string();
let (loaded, _) = load(Some(&path_str)).unwrap();
assert_eq!(loaded.api_base_url, cfg.api_base_url);
assert_eq!(loaded.worker_id, cfg.worker_id);
assert_eq!(loaded.auth_token, cfg.auth_token);
assert_eq!(loaded.vram_threshold_gb, cfg.vram_threshold_gb);
assert_eq!(loaded.auto_update_prerelease, cfg.auto_update_prerelease);
assert_eq!(loaded.models_root, cfg.models_root);
}
// shared() hands back a lockable handle holding the given config.
#[test]
fn shared_wraps_in_arc_mutex() {
let cfg = Config::default();
let shared = shared(cfg.clone());
let guard = shared.lock();
assert_eq!(guard.api_base_url, cfg.api_base_url);
}
// Syntactically invalid TOML is reported as an error carrying the "parsing config.toml"
// context rather than being replaced by defaults.
#[test]
fn load_returns_error_on_malformed_toml() {
let dir = tempdir().unwrap();
let path = dir.path().join("config.toml");
std::fs::write(&path, "this :: is = not = toml = :").unwrap();
let path_str = path.to_string_lossy().to_string();
let err = load(Some(&path_str)).unwrap_err();
assert!(err.to_string().contains("parsing config.toml"));
}
// Keys from older config versions (engine, engines, auto_enabled, label) are ignored, not
// rejected. Keys missing from the file (the auto-update settings) fall back to serde defaults.
#[test]
fn load_strips_legacy_engine_fields_silently() {
let dir = tempdir().unwrap();
let path = dir.path().join("config.toml");
let legacy = r#"
api_base_url = "https://example.invalid"
vram_threshold_gb = 8.0
auto_start = true
engine = "multi"
engines = ["llama", "synthetic"]
auto_enabled = false
label = "alice's rig"
"#;
std::fs::write(&path, legacy).unwrap();
let (cfg, _) = load(Some(&path.to_string_lossy())).unwrap();
assert_eq!(cfg.api_base_url, "https://example.invalid");
assert_eq!(cfg.vram_threshold_gb, 8.0);
}
// A `~/...` models_root written by hand is expanded to an absolute path on load.
#[test]
fn load_expands_leading_tilde_in_models_root() {
let dir = tempdir().unwrap();
let path = dir.path().join("config.toml");
let raw = r#"
api_base_url = "https://x.invalid"
vram_threshold_gb = 4.0
auto_start = true
auto_update_enabled = false
auto_update_interval_secs = 1
auto_update_feed = "https://x.invalid"
auto_update_prerelease = false
models_root = "~/models-test"
"#;
std::fs::write(&path, raw).unwrap();
let (cfg, _) = load(Some(&path.to_string_lossy())).unwrap();
assert!(
cfg.models_root.is_absolute(),
"~/ should expand to an absolute path, got {}",
cfg.models_root.display()
);
assert!(cfg.models_root.ends_with("models-test"));
}
// Paths without a leading `~` pass through expand_home unchanged.
#[test]
fn expand_home_leaves_absolute_paths_alone() {
let p = PathBuf::from("/tmp/anywhere");
assert_eq!(expand_home(p.clone()), p);
}
// A bare `~` becomes the home directory, or stays as-is where no home can be resolved.
#[test]
fn expand_home_handles_bare_tilde() {
let expanded = expand_home(PathBuf::from("~"));
assert!(
expanded.is_absolute() || expanded == Path::new("~"),
"bare ~ expands to home (or stays put on weird boxes), got {}",
expanded.display()
);
}
// The config carries auth and registration secrets, so on Unix the saved file must have no
// group/world permission bits.
#[cfg(unix)]
#[test]
fn save_writes_config_owner_only_because_it_holds_secrets() {
use std::os::unix::fs::PermissionsExt;
let dir = tempdir().unwrap();
let path = dir.path().join("config.toml");
let cfg = Config {
auth_token: Some("super-secret-token".into()),
registration_secret: Some("reg-secret".into()),
..Config::default()
};
save(&cfg, &path).unwrap();
let mode = std::fs::metadata(&path).unwrap().permissions().mode();
assert_eq!(
mode & 0o077,
0,
"secrets-bearing config must not be group/world-accessible; got mode {mode:o}"
);
}
// Saving a smaller config over a larger one replaces it completely: no stale trailing bytes,
// no leftover fields, and no temp files left beside the target.
#[test]
fn save_atomically_replaces_existing_config_without_temp_litter() {
let dir = tempdir().unwrap();
let path = dir.path().join("config.toml");
let big = Config {
api_base_url: "https://a-very-long-host-name.example.invalid/studio/".into(),
worker_id: Some("worker-with-a-longish-id-000000".into()),
..Config::default()
};
save(&big, &path).unwrap();
let small = Config {
api_base_url: "https://x/".into(),
..Config::default()
};
save(&small, &path).unwrap();
let (loaded, _) = load(Some(&path.to_string_lossy())).unwrap();
assert_eq!(loaded.api_base_url, "https://x/");
assert!(
loaded.worker_id.is_none(),
"a replacing save must not leave the previous worker_id behind"
);
let names: Vec<String> = std::fs::read_dir(dir.path())
.unwrap()
.map(|e| e.unwrap().file_name().to_string_lossy().to_string())
.collect();
assert_eq!(
names,
vec!["config.toml".to_string()],
"atomic save must leave only the target file, found: {names:?}"
);
}
}