use std::collections::HashMap;
use anyhow::{Context, Result};
use suno_core::{ClerkAuth, Config, EffectiveSettings, FlagOverrides};
use crate::cli::args::{AuthArgs, AuthCommand, AuthRefreshArgs, GlobalArgs};
use crate::cli::desired::ExitCode;
use crate::cli::run;
use crate::http::ReqwestHttp;
/// Entry point for the `auth` command group: routes to the handler for the chosen subcommand.
///
/// The result of the selected handler is returned unchanged.
pub async fn run_auth(global: &GlobalArgs, args: &AuthArgs) -> Result<ExitCode> {
    match args.command {
        AuthCommand::Refresh(ref refresh_args) => refresh_accounts(global, refresh_args).await,
    }
}
/// Re-authenticates every targeted account and reports each outcome on stderr.
///
/// Targets are chosen by `resolve_targets`. A configuration-loading failure or a
/// target-resolution failure ends the command early with an exit code wrapped in `Ok`.
/// A failure for one account does not stop the remaining accounts from being processed;
/// the final exit code is the one `worse` selects across all of them.
///
/// # Errors
///
/// Returns `Err` only when the HTTP client cannot be constructed.
async fn refresh_accounts(global: &GlobalArgs, refresh: &AuthRefreshArgs) -> Result<ExitCode> {
    // `std::env::vars()` panics as soon as it meets a variable whose name or value is not
    // valid Unicode. Walk the OS-level view instead and skip entries that cannot be
    // represented as `String`, so an unrelated malformed variable cannot crash the command.
    let env: HashMap<String, String> = std::env::vars_os()
        .filter_map(|(key, value)| Some((key.into_string().ok()?, value.into_string().ok()?)))
        .collect();
    let flags = FlagOverrides {
        token: global.token.clone(),
        ..FlagOverrides::default()
    };
    let config = match run::load_config_reported(global.config.as_deref()) {
        Ok(config) => config,
        // The helper hands back the exit code to use; pass it through unchanged.
        Err(code) => return Ok(code),
    };
    let resolved = match resolve_targets(config.as_ref(), global, refresh, &env, &flags) {
        Ok(resolved) => resolved,
        Err(message) => {
            eprintln!("error: {message}");
            return Ok(ExitCode::Config);
        }
    };
    let http = ReqwestHttp::new().context("failed to build the HTTP client")?;
    let mut worst = ExitCode::Ok;
    for (label, settings) in resolved {
        // An account without a token is recorded as a config error, then the loop moves on.
        let Some(token) = settings.token else {
            eprintln!("error: no token for account '{label}'; pass --token or set it in config");
            worst = worse(worst, ExitCode::Config);
            continue;
        };
        let mut auth = ClerkAuth::new(&token);
        match auth.authenticate(&http).await {
            Ok(_) => {
                // NOTE(review): if the quiet level is -1, this condition never suppresses the
                // message and `>= 0` may have been intended — confirm against `verbosity()`.
                if global.verbosity() >= -1 {
                    eprintln!("Re-authenticated '{label}' as {}", auth.display_name());
                }
            }
            Err(err) => worst = worse(worst, run::report_auth_failure(&label, &err)),
        }
    }
    Ok(worst)
}
/// Works out which accounts a refresh should touch, paired with their resolved settings.
///
/// Selection order:
/// 1. an explicitly named account (`refresh.account`) — always a single target;
/// 2. `global.all` — every configured account, ordered by label;
/// 3. otherwise whatever `run::single_account` selects.
///
/// # Errors
///
/// Returns a human-readable message when the named account cannot be resolved, when
/// `--all` is used without a config file or with no configured accounts, or when
/// resolving any individual account fails.
fn resolve_targets(
    config: Option<&Config>,
    global: &GlobalArgs,
    refresh: &AuthRefreshArgs,
    env: &HashMap<String, String>,
    flags: &FlagOverrides,
) -> std::result::Result<Vec<(String, EffectiveSettings)>, String> {
    if let Some(name) = &refresh.account {
        let settings = resolve_named(config, name, env, flags)?;
        return Ok(vec![(name.clone(), settings)]);
    }

    if !global.all {
        let target = run::single_account(config, global, flags, env)?;
        return Ok(vec![target]);
    }

    let cfg = config.ok_or_else(|| "--all requires a config file".to_owned())?;
    let mut names: Vec<String> = cfg.accounts.keys().cloned().collect();
    names.sort();
    if names.is_empty() {
        return Err("no accounts are configured".to_owned());
    }

    // Stop at the first account that fails to resolve.
    let mut targets = Vec::with_capacity(names.len());
    for name in names {
        let settings = cfg
            .resolve(&name, None, env, flags)
            .map_err(|err| err.to_string())?;
        targets.push((name, settings));
    }
    Ok(targets)
}
/// Resolves the settings for one account that the user named explicitly.
///
/// # Errors
///
/// Returns a message when there is no config at all, when `label` is not among the
/// configured accounts (the message lists the configured labels in sorted order), or
/// when the config's own resolution step fails.
fn resolve_named(
    config: Option<&Config>,
    label: &str,
    env: &HashMap<String, String>,
    flags: &FlagOverrides,
) -> std::result::Result<EffectiveSettings, String> {
    let Some(cfg) = config else {
        return Err(format!("account '{label}' not found: no config file"));
    };

    if cfg.accounts.contains_key(label) {
        return cfg
            .resolve(label, None, env, flags)
            .map_err(|err| err.to_string());
    }

    // Unknown label: list what is configured so the user can spot a typo.
    let mut known: Vec<&str> = cfg.accounts.keys().map(String::as_str).collect();
    known.sort_unstable();
    Err(format!(
        "account '{label}' not found in config (configured: {})",
        known.join(", ")
    ))
}
/// Picks the exit code with the larger numeric `code()`.
///
/// When both codes are equal the second argument is returned.
fn worse(a: ExitCode, b: ExitCode) -> ExitCode {
    if a.code() > b.code() {
        a
    } else {
        b
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an inline TOML fixture, panicking if it is rejected.
    fn parse(toml: &str) -> Config {
        Config::from_toml(toml).unwrap()
    }

    /// Builds refresh arguments naming `account`, or naming nothing for `None`.
    fn refresh_for(account: Option<&str>) -> AuthRefreshArgs {
        AuthRefreshArgs {
            account: account.map(str::to_owned),
        }
    }

    /// Default global arguments with only the `all` flag switched on.
    fn all_accounts() -> GlobalArgs {
        GlobalArgs {
            all: true,
            ..Default::default()
        }
    }

    #[test]
    fn named_account_resolves_from_config() {
        let config = parse("[accounts.alice]\ntoken = \"t\"\n");

        let targets = resolve_targets(
            Some(&config),
            &GlobalArgs::default(),
            &refresh_for(Some("alice")),
            &HashMap::new(),
            &FlagOverrides::default(),
        )
        .unwrap();

        assert_eq!(targets.len(), 1);
        let (label, settings) = &targets[0];
        assert_eq!(label.as_str(), "alice");
        assert_eq!(settings.token.as_deref(), Some("t"));
    }

    #[test]
    fn unknown_named_account_errors() {
        let config = parse("[accounts.alice]\ntoken = \"t\"\n");

        let message = resolve_targets(
            Some(&config),
            &GlobalArgs::default(),
            &refresh_for(Some("bob")),
            &HashMap::new(),
            &FlagOverrides::default(),
        )
        .unwrap_err();

        assert!(message.contains("not found"));
    }

    #[test]
    fn all_resolves_every_account_sorted() {
        // The fixture lists `bob` before `alice`; the assertion expects them sorted.
        let config = parse("[accounts.bob]\ntoken=\"b\"\n[accounts.alice]\ntoken=\"a\"\n");

        let targets = resolve_targets(
            Some(&config),
            &all_accounts(),
            &refresh_for(None),
            &HashMap::new(),
            &FlagOverrides::default(),
        )
        .unwrap();

        let labels: Vec<&str> = targets.iter().map(|target| target.0.as_str()).collect();
        assert_eq!(labels, ["alice", "bob"]);
    }

    #[test]
    fn all_without_config_errors() {
        let message = resolve_targets(
            None,
            &all_accounts(),
            &refresh_for(None),
            &HashMap::new(),
            &FlagOverrides::default(),
        )
        .unwrap_err();

        assert!(message.contains("--all requires"));
    }

    #[test]
    fn worse_picks_higher_code() {
        // (first, second, expected) triples covering both argument orders and a tie.
        let cases = [
            (ExitCode::Ok, ExitCode::Auth, ExitCode::Auth),
            (ExitCode::Auth, ExitCode::Ok, ExitCode::Auth),
            (ExitCode::Config, ExitCode::Config, ExitCode::Config),
        ];
        for (first, second, expected) in cases {
            assert_eq!(worse(first, second), expected);
        }
    }
}