use std::fmt;
use std::future::Future;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use serde_json::json;
use tokio::sync::Mutex as AsyncMutex;
use tracing::{error, info, warn};
use crate::db;
use crate::db::mediation_events::MediationEventKind;
/// Authorization status tracked by an [`AuthRetryHandle`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthState {
/// An authorization check succeeded (either the initial check or a later
/// retry attempt).
Authorized,
/// No successful check yet, or authorization was reported lost through
/// `AuthRetryHandle::signal_auth_lost`.
Unauthorized,
/// The background retry loop exhausted its attempt or time budget and
/// stopped without recovering.
Terminated,
}
/// Cloneable handle onto the shared authorization state.
///
/// All clones share one `Arc<Mutex<AuthState>>`, so a change made through any
/// clone (or by the background retry task, which is handed the same `Arc`) is
/// observed by every other clone.
#[derive(Clone)]
pub struct AuthRetryHandle {
    state: Arc<Mutex<AuthState>>,
}

impl AuthRetryHandle {
    /// Builds a handle whose state starts as [`AuthState::Authorized`].
    pub fn new_authorized() -> Self {
        Self::with_state(AuthState::Authorized)
    }

    /// Builds a handle around a fresh shared cell initialised to `state`.
    fn with_state(state: AuthState) -> Self {
        let cell = Mutex::new(state);
        Self {
            state: Arc::new(cell),
        }
    }

    /// Test-only constructor that allows starting from an arbitrary state.
    #[cfg(test)]
    pub(crate) fn with_state_for_testing(state: AuthState) -> Self {
        Self::with_state(state)
    }

    /// Returns a snapshot of the current state.
    ///
    /// # Panics
    /// Panics if the state mutex is poisoned.
    pub fn current_state(&self) -> AuthState {
        let guard = self.lock_state();
        *guard
    }

    /// Returns `true` only while the state is [`AuthState::Authorized`].
    pub fn is_authorized(&self) -> bool {
        self.current_state() == AuthState::Authorized
    }

    /// Downgrades `Authorized` to `Unauthorized`; every other state is left
    /// untouched.
    ///
    /// This method only flips the stored state; it does not itself start a
    /// retry loop.
    ///
    /// # Panics
    /// Panics if the state mutex is poisoned.
    pub fn signal_auth_lost(&self) {
        let mut current = self.lock_state();
        if *current == AuthState::Authorized {
            *current = AuthState::Unauthorized;
        }
    }

    /// Unconditionally overwrites the stored state with `new_state`.
    fn set_state(&self, new_state: AuthState) {
        *self.lock_state() = new_state;
    }

    /// Locks the shared state cell, panicking if the mutex is poisoned.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, AuthState> {
        self.state.lock().expect("auth-retry state mutex poisoned")
    }
}

/// Debug output shows a snapshot of the current state rather than the raw
/// `Arc<Mutex<..>>` plumbing.
impl fmt::Debug for AuthRetryHandle {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let snapshot = self.current_state();
        let mut out = f.debug_struct("AuthRetryHandle");
        out.field("state", &snapshot);
        out.finish()
    }
}
/// Error returned by an authorization check; wraps a human-readable message.
#[derive(Debug)]
pub struct AuthCheckError(pub String);

impl fmt::Display for AuthCheckError {
    /// Writes the wrapped message verbatim (formatter width/padding flags are
    /// not applied).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(message) = self;
        write!(f, "{}", message)
    }
}

impl std::error::Error for AuthCheckError {}
/// Authorization probe used by the production entry point
/// `ensure_authorized_or_enter_loop`.
///
/// As written, this ignores all three arguments (hence the `_` prefixes) and
/// unconditionally returns `Ok(())`.
///
/// NOTE(review): this looks like a placeholder. While it always succeeds, the
/// production entry point returns an `Authorized` handle immediately and the
/// retry loop is never spawned — confirm whether a real check is still to be
/// implemented.
async fn check_authorization(
_client: &nostr_sdk::Client,
_serbero_keys: &nostr_sdk::Keys,
_mostro_pubkey: &nostr_sdk::PublicKey,
) -> std::result::Result<(), AuthCheckError> {
Ok(())
}
/// Tunables for the background authorization retry loop.
#[derive(Debug, Clone, Copy)]
struct LoopConfig {
    /// Sleep before the first retry; doubled after each failed retry.
    initial_delay: Duration,
    /// Upper bound applied to the doubled delay.
    max_interval: Duration,
    /// Attempt number (the initial check counts as attempt 1) at which the
    /// loop gives up.
    max_attempts: u32,
    /// Bound on the accumulated sleep time after which the loop gives up.
    max_total: Duration,
}

impl LoopConfig {
    /// Settings used outside of tests: start at one minute, cap each sleep at
    /// one hour, and stop after 24 attempts or one day of accumulated sleep.
    const fn production() -> Self {
        const ONE_MINUTE: Duration = Duration::from_secs(60);
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);
        const ONE_DAY: Duration = Duration::from_secs(24 * 60 * 60);

        Self {
            initial_delay: ONE_MINUTE,
            max_interval: ONE_HOUR,
            max_attempts: 24,
            max_total: ONE_DAY,
        }
    }
}
/// Returns the next backoff delay: `current` doubled (saturating on
/// overflow), but never more than `cap`.
fn next_delay(current: Duration, cap: Duration) -> Duration {
    current.saturating_mul(2).min(cap)
}
/// Production entry point: runs the authorization check once and, if it
/// fails, hands off to a background retry loop.
///
/// Wraps `check_authorization` in a reusable checker closure and delegates to
/// `ensure_authorized_or_enter_loop_inner` with `LoopConfig::production()`.
/// The returned handle is `Authorized` when the first check succeeds and
/// `Unauthorized` otherwise.
pub async fn ensure_authorized_or_enter_loop(
    conn: Arc<AsyncMutex<rusqlite::Connection>>,
    client: nostr_sdk::Client,
    serbero_keys: nostr_sdk::Keys,
    mostro_pubkey: nostr_sdk::PublicKey,
) -> AuthRetryHandle {
    // The checker is called repeatedly and each returned future must be
    // `'static`, so every call gets its own copies of the inputs.
    let production_checker = move || {
        let (client, keys, pubkey) = (client.clone(), serbero_keys.clone(), mostro_pubkey);
        async move { check_authorization(&client, &keys, &pubkey).await }
    };
    let config = LoopConfig::production();
    ensure_authorized_or_enter_loop_inner(conn, Arc::new(production_checker), config).await
}
async fn ensure_authorized_or_enter_loop_inner<C, Fut>(
conn: Arc<AsyncMutex<rusqlite::Connection>>,
checker: Arc<C>,
config: LoopConfig,
) -> AuthRetryHandle
where
C: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = std::result::Result<(), AuthCheckError>> + Send + 'static,
{
let handle = AuthRetryHandle::with_state(AuthState::Unauthorized);
match checker().await {
Ok(()) => {
handle.set_state(AuthState::Authorized);
return handle;
}
Err(e) => {
let now = current_ts_secs();
let payload = json!({ "attempt": 1, "error": e.to_string() }).to_string();
if let Err(db_err) =
record_auth_event(&conn, MediationEventKind::AuthRetryAttempt, &payload, now).await
{
warn!(error = %db_err, "failed to record auth_retry_attempt (initial)");
}
warn!(attempt = 1, error = %e, "solver authorization check failed; entering retry loop");
}
}
let state = Arc::clone(&handle.state);
tokio::spawn(run_retry_loop(state, conn, checker, config));
handle
}
async fn run_retry_loop<C, Fut>(
state: Arc<Mutex<AuthState>>,
conn: Arc<AsyncMutex<rusqlite::Connection>>,
checker: Arc<C>,
config: LoopConfig,
) where
C: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = std::result::Result<(), AuthCheckError>> + Send + 'static,
{
let mut current_delay = config.initial_delay;
let mut cumulative = Duration::ZERO;
let mut attempt: u32 = 1;
loop {
tokio::time::sleep(current_delay).await;
cumulative = cumulative.saturating_add(current_delay);
attempt += 1;
match checker().await {
Ok(()) => {
{
let mut guard = state.lock().expect("auth-retry state mutex poisoned");
*guard = AuthState::Authorized;
}
let payload = json!({ "attempt": attempt }).to_string();
let now = current_ts_secs();
if let Err(db_err) =
record_auth_event(&conn, MediationEventKind::AuthRetryRecovered, &payload, now)
.await
{
warn!(error = %db_err, "failed to record auth_retry_recovered");
}
info!(attempt = attempt, "solver auth retry recovered");
return;
}
Err(e) => {
let payload = json!({ "attempt": attempt, "error": e.to_string() }).to_string();
let now = current_ts_secs();
if let Err(db_err) =
record_auth_event(&conn, MediationEventKind::AuthRetryAttempt, &payload, now)
.await
{
warn!(error = %db_err, "failed to record auth_retry_attempt");
}
warn!(attempt = attempt, error = %e, "solver auth retry attempt failed");
}
}
if attempt >= config.max_attempts || cumulative >= config.max_total {
{
let mut guard = state.lock().expect("auth-retry state mutex poisoned");
*guard = AuthState::Terminated;
}
let payload = json!({
"attempt": attempt,
"cumulative_secs": cumulative.as_secs(),
})
.to_string();
let now = current_ts_secs();
if let Err(db_err) = record_auth_event(
&conn,
MediationEventKind::AuthRetryTerminated,
&payload,
now,
)
.await
{
warn!(error = %db_err, "failed to record auth_retry_terminated");
}
error!(
attempt = attempt,
cumulative_secs = cumulative.as_secs(),
"solver auth retry loop terminated without recovery"
);
return;
}
current_delay = next_delay(current_delay, config.max_interval);
}
}
/// Records one auth-retry event through `db::mediation_events::record_event`.
///
/// Acquires the async connection lock for the duration of the call and
/// forwards `kind`, `payload_json` and `occurred_at`; the remaining four
/// arguments of `record_event` are passed as `None`.
///
/// # Errors
/// Propagates any error returned by `record_event`; the value it returns on
/// success is discarded.
async fn record_auth_event(
conn: &Arc<AsyncMutex<rusqlite::Connection>>,
kind: MediationEventKind,
payload_json: &str,
occurred_at: i64,
) -> crate::error::Result<()> {
// The guard is held until the function returns, covering the whole write.
let guard = conn.lock().await;
db::mediation_events::record_event(
&guard,
kind,
None,
payload_json,
None,
None,
None,
occurred_at,
)?;
Ok(())
}
/// Returns the current wall-clock time as whole seconds relative to the Unix
/// epoch.
///
/// This never panics. A clock set before the epoch yields a negative value
/// (the offset before the epoch) instead of aborting the caller, which
/// matters because this is called from the detached retry task, where a
/// panic would silently end the loop. Values that do not fit in an `i64` are
/// clamped rather than wrapped.
fn current_ts_secs() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    // Saturating u64 -> i64 conversion; a plain `as` cast would wrap negative.
    let clamp = |secs: u64| -> i64 { secs.min(i64::MAX as u64) as i64 };

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => clamp(since_epoch.as_secs()),
        // `duration()` is how far the clock sits *before* the epoch.
        Err(before_epoch) => -clamp(before_epoch.duration().as_secs()),
    }
}
// Unit tests for the auth-retry state machine. The async tests drive
// `ensure_authorized_or_enter_loop_inner` with mock checkers and a small
// `LoopConfig`; the loop tests run with the tokio clock paused.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
use crate::db::migrations::run_migrations;
use crate::db::open_in_memory;
/// Opens an in-memory database, runs the migrations on it, and wraps it the
/// way the code under test expects.
fn fresh_conn() -> Arc<AsyncMutex<rusqlite::Connection>> {
let mut conn = open_in_memory().unwrap();
run_migrations(&mut conn).unwrap();
Arc::new(AsyncMutex::new(conn))
}
/// Small budget for tests: 1s initial delay, 4s cap, 5 attempts, and a total
/// time budget large enough that the attempt limit is hit first.
fn tight_config() -> LoopConfig {
LoopConfig {
initial_delay: Duration::from_secs(1),
max_interval: Duration::from_secs(4),
max_attempts: 5,
max_total: Duration::from_secs(3_600),
}
}
/// Polls `handle` until it reports `want`, advancing the tokio clock by 5s
/// and yielding between polls so the spawned retry task can run. Panics if
/// the state is not reached within `max_advances` polls. Callers are the
/// tests declared with `start_paused = true`.
async fn wait_until(handle: &AuthRetryHandle, want: AuthState, max_advances: u32) {
for _ in 0..max_advances {
if handle.current_state() == want {
return;
}
tokio::time::advance(Duration::from_secs(5)).await;
tokio::task::yield_now().await;
}
panic!(
"state never reached {want:?} (last observed: {:?})",
handle.current_state()
);
}
/// Counts the rows in `mediation_events` whose `kind` column equals `kind`.
async fn count_events(
conn: &Arc<AsyncMutex<rusqlite::Connection>>,
kind: MediationEventKind,
) -> i64 {
let guard = conn.lock().await;
guard
.query_row(
"SELECT COUNT(*) FROM mediation_events WHERE kind = ?1",
[kind.as_str()],
|r| r.get::<_, i64>(0),
)
.unwrap()
}
// A checker that succeeds on the first call yields an Authorized handle and
// records no attempt events.
#[tokio::test]
async fn check_succeeds_immediately_returns_authorized() {
let conn = fresh_conn();
let checker = Arc::new(|| async { Ok::<(), AuthCheckError>(()) });
let handle =
ensure_authorized_or_enter_loop_inner(Arc::clone(&conn), checker, tight_config()).await;
assert_eq!(handle.current_state(), AuthState::Authorized);
assert!(handle.is_authorized());
let attempts = count_events(&conn, MediationEventKind::AuthRetryAttempt).await;
assert_eq!(attempts, 0);
}
// The mock checker fails on its first three calls (the initial check plus
// two retries) and succeeds on the fourth, so the loop must recover.
#[tokio::test(start_paused = true)]
async fn loop_recovers_after_n_failures() {
let conn = fresh_conn();
let counter = Arc::new(AtomicU32::new(0));
let checker = {
let counter = Arc::clone(&counter);
Arc::new(move || {
let counter = Arc::clone(&counter);
async move {
// `n` is the zero-based index of this call.
let n = counter.fetch_add(1, Ordering::SeqCst);
if n < 3 {
Err(AuthCheckError(format!("mock failure #{n}")))
} else {
Ok(())
}
}
})
};
let handle =
ensure_authorized_or_enter_loop_inner(Arc::clone(&conn), checker, tight_config()).await;
// The initial check failed, so the handle starts out Unauthorized.
assert_eq!(handle.current_state(), AuthState::Unauthorized);
wait_until(&handle, AuthState::Authorized, 20).await;
let attempts = count_events(&conn, MediationEventKind::AuthRetryAttempt).await;
let recovered = count_events(&conn, MediationEventKind::AuthRetryRecovered).await;
let terminated = count_events(&conn, MediationEventKind::AuthRetryTerminated).await;
assert_eq!(recovered, 1, "exactly one recovery event expected");
assert_eq!(terminated, 0, "must not also emit a terminated event");
assert!(attempts >= 3, "expected >=3 attempt events, got {attempts}");
}
// With a checker that always fails, the loop must stop at `max_attempts`,
// leave the state Terminated, and emit exactly one terminated event.
#[tokio::test(start_paused = true)]
async fn loop_terminates_after_max_attempts() {
let conn = fresh_conn();
let checker =
Arc::new(|| async { Err::<(), _>(AuthCheckError("mock always fails".into())) });
let cfg = tight_config();
let handle = ensure_authorized_or_enter_loop_inner(Arc::clone(&conn), checker, cfg).await;
assert_eq!(handle.current_state(), AuthState::Unauthorized);
wait_until(&handle, AuthState::Terminated, 40).await;
let terminated = count_events(&conn, MediationEventKind::AuthRetryTerminated).await;
assert_eq!(
terminated, 1,
"exactly one auth_retry_terminated event must be emitted"
);
let recovered = count_events(&conn, MediationEventKind::AuthRetryRecovered).await;
assert_eq!(recovered, 0);
// One attempt row for the initial check plus one per retry in the loop.
let attempts = count_events(&conn, MediationEventKind::AuthRetryAttempt).await;
assert_eq!(
attempts as u32, cfg.max_attempts,
"expected exactly max_attempts auth_retry_attempt rows"
);
// Read the attempt number back out of the terminated event's JSON payload.
let terminated_attempt: i64 = {
let guard = conn.lock().await;
guard
.query_row(
"SELECT json_extract(payload_json, '$.attempt')
FROM mediation_events WHERE kind = 'auth_retry_terminated'",
[],
|r| r.get(0),
)
.unwrap()
};
assert_eq!(
terminated_attempt as u32, cfg.max_attempts,
"terminated payload must report the final attempt, not max+1"
);
}
// Starting from 60s with a 3600s cap, each step doubles until the cap is
// reached and then stays there.
#[test]
fn backoff_doubles_up_to_cap() {
let cap = Duration::from_secs(3600);
let mut d = Duration::from_secs(60);
let expected = [120, 240, 480, 960, 1920, 3600, 3600, 3600];
for want in expected {
d = next_delay(d, cap);
assert_eq!(d, Duration::from_secs(want), "unexpected delay step");
}
}
}