use super::SunoClient;
use crate::auth::{self, AuthState};
use crate::core::CliError;
pub(super) async fn refresh_state_if_needed(
client: &reqwest::Client,
auth: &mut AuthState,
) -> Result<(), CliError> {
auth::refresh_state_if_needed(client, auth).await
}
impl SunoClient {
pub(crate) async fn refresh_jwt_after_auth_failure(
&self,
failed_jwt: Option<String>,
) -> Result<(), CliError> {
let mut auth = {
let auth = self.auth.lock().expect("auth mutex poisoned");
if !should_refresh_after_auth_failure(auth.jwt.as_deref(), failed_jwt.as_deref()) {
return Ok(());
}
auth.clone()
};
auth::refresh_state_for_retry(&self.client, &mut auth).await?;
{
let mut current_auth = self.auth.lock().expect("auth mutex poisoned");
if should_replace_auth_after_refresh(
current_auth.jwt.as_deref(),
failed_jwt.as_deref(),
auth.jwt.as_deref(),
) {
*current_auth = auth;
}
}
Ok(())
}
pub(crate) async fn try_refresh_jwt_for_challenge_recheck(&self) -> Result<bool, CliError> {
let (request_jwt, has_refresh_material) = {
let auth = self.auth.lock().expect("auth mutex poisoned");
(auth.jwt.clone(), auth.clerk_client_cookie.is_some())
};
if !has_refresh_material {
return Ok(false);
}
match self.refresh_jwt_after_auth_failure(request_jwt).await {
Ok(()) => Ok(true),
Err(CliError::AuthExpired) => Ok(false),
Err(error) => Err(error),
}
}
fn current_jwt(&self) -> Option<String> {
self.auth.lock().expect("auth mutex poisoned").jwt.clone()
}
pub(crate) async fn with_auth_retry<F, Fut, T>(&self, mut f: F) -> Result<T, CliError>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T, CliError>>,
{
let request_jwt = self.current_jwt();
match f().await {
Err(CliError::AuthExpired) => {
self.refresh_jwt_after_auth_failure(request_jwt).await?;
f().await
}
other => other,
}
}
}
fn should_refresh_after_auth_failure(current_jwt: Option<&str>, failed_jwt: Option<&str>) -> bool {
current_jwt == failed_jwt
}
fn should_replace_auth_after_refresh(
current_jwt: Option<&str>,
failed_jwt: Option<&str>,
refreshed_jwt: Option<&str>,
) -> bool {
current_jwt == failed_jwt || current_jwt == refreshed_jwt
}
#[cfg(test)]
mod tests {
use super::{should_refresh_after_auth_failure, should_replace_auth_after_refresh};
#[test]
fn retry_refresh_runs_when_auth_still_matches_failed_jwt() {
assert!(should_refresh_after_auth_failure(
Some("old-jwt"),
Some("old-jwt")
));
assert!(should_refresh_after_auth_failure(None, None));
}
#[test]
fn retry_refresh_skips_when_another_task_already_updated_jwt() {
assert!(!should_refresh_after_auth_failure(
Some("new-jwt"),
Some("old-jwt")
));
assert!(!should_refresh_after_auth_failure(Some("new-jwt"), None));
}
#[test]
fn retry_refresh_does_not_overwrite_newer_auth_state() {
assert!(should_replace_auth_after_refresh(
Some("old-jwt"),
Some("old-jwt"),
Some("new-jwt")
));
assert!(should_replace_auth_after_refresh(
Some("new-jwt"),
Some("old-jwt"),
Some("new-jwt")
));
assert!(!should_replace_auth_after_refresh(
Some("newer-jwt"),
Some("old-jwt"),
Some("new-jwt")
));
}
}