use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use reqwest::Client as HttpClient;
use tokio::sync::watch;
use url::Url;
use crate::ClientError;
use crate::auth::AuthProvider;
use crate::state::{ResumeKey, StateStore};
mod builder;
mod helpers;
mod watch_spawn;
pub use builder::AvisoClientBuilder;
pub(crate) use helpers::{
parse_json_response, parse_json_response_optional, validate_path_segment,
};
pub(crate) use watch_spawn::{compute_resume_key, decrement_active_key, increment_active_key};
// Shared map from resume key to a `usize` counter. Because it sits behind an `Arc`, every
// clone of `AvisoClient` sees the same map.
// NOTE(review): presumably the count of live watch streams using each key, maintained by
// `increment_active_key` / `decrement_active_key` in `watch_spawn` — confirm there.
type ActiveResumeKeys = Arc<Mutex<HashMap<ResumeKey, usize>>>;
// Compile-time signature pins. Coercing the re-exported helpers to explicit fn-pointer types
// turns any signature drift in `watch_spawn` into a compile error here.
const _: fn(&ActiveResumeKeys, &ResumeKey, &str) = increment_active_key;
const _: fn(&Url, &crate::watch::WatchRequest) -> crate::Result<ResumeKey> = compute_resume_key;
/// Broadcasts a one-shot "parent dropped" signal over a `watch` channel.
///
/// The channel starts at `false`. When the guard itself is dropped (for example when the last
/// `Arc<DropGuard>` goes away), `true` is published to every receiver that is still alive.
pub(crate) struct DropGuard {
    sender: watch::Sender<bool>,
}
impl DropGuard {
    /// Creates a shared guard together with an initial receiver that observes it.
    pub(super) fn new() -> (Arc<Self>, watch::Receiver<bool>) {
        let (tx, rx) = watch::channel(false);
        let guard = Arc::new(Self { sender: tx });
        (guard, rx)
    }
    /// Returns an additional receiver connected to this guard.
    ///
    /// Per tokio's `watch` semantics, the current value is marked as already seen, so
    /// `changed()` on the new receiver resolves on the next send.
    pub(super) fn subscribe(&self) -> watch::Receiver<bool> {
        let sender = &self.sender;
        sender.subscribe()
    }
}
impl Drop for DropGuard {
    fn drop(&mut self) {
        // `send` only fails when every receiver has already been dropped. Nobody is left to
        // notify in that case, so the error is deliberately discarded.
        _ = self.sender.send(true);
    }
}
/// Single-flight coordinator for credential refreshes.
///
/// `generation` counts successful refreshes; each value is one "epoch". A caller snapshots
/// the epoch with [`RefreshCoordinator::generation`] before using a credential. On an auth
/// failure it passes that snapshot to [`RefreshCoordinator::refresh_once`]. Only the first
/// caller of an epoch actually refreshes; later callers holding the same snapshot find the
/// counter already advanced and return immediately.
#[derive(Debug, Default)]
pub(crate) struct RefreshCoordinator {
    /// Serialises refreshes. It is held across `auth.refresh().await`, hence the async mutex.
    lock: tokio::sync::Mutex<()>,
    /// Number of successful refreshes so far (the current epoch).
    generation: AtomicU64,
}
impl RefreshCoordinator {
    /// Returns the current refresh epoch.
    ///
    /// This is read outside the lock, so it uses `Acquire`. It pairs with the `Release`
    /// increment in `refresh_once`: a caller that observes epoch `n` also observes everything
    /// the refreshing task did before publishing `n`. Later reads (such as fetching the
    /// authorization header) cannot be reordered before this load.
    pub(crate) fn generation(&self) -> u64 {
        self.generation.load(Ordering::Acquire)
    }
    /// Refreshes `auth` unless some other caller already did so since epoch `observed`.
    ///
    /// On success of an actual refresh the epoch advances by one. A failed refresh leaves the
    /// epoch unchanged, so the next caller with the same snapshot tries again. Cancelling this
    /// future mid-refresh releases the lock without advancing the epoch.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `auth.refresh()`.
    pub(crate) async fn refresh_once(
        &self,
        auth: &Arc<dyn AuthProvider>,
        observed: u64,
    ) -> crate::Result<()> {
        let _guard = self.lock.lock().await;
        if self.generation.load(Ordering::Acquire) != observed {
            // Another caller refreshed after `observed` was taken; the credential is already
            // newer than the one that failed.
            return Ok(());
        }
        auth.refresh().await?;
        // Release publishes the refresh's effects to readers that Acquire-load the new epoch.
        self.generation.fetch_add(1, Ordering::Release);
        Ok(())
    }
}
/// HTTP client handle for an Aviso service.
///
/// `Clone` shares state between clones: the refresh coordinator, drop guard, state store and
/// active-resume-key map are all `Arc`s. The drop guard therefore fires only when the last
/// clone is dropped. Construct instances with `AvisoClient::builder()`.
// NOTE(review): fields drop in declaration order, so `parent_drop` (which signals subscribers)
// is dropped after `http`, `base_url`, `auth` and `refresh_coordinator`. Keep that in mind
// before reordering fields.
#[derive(Clone)]
#[non_exhaustive]
pub struct AvisoClient {
    /// Underlying reqwest client used to build every request.
    pub(super) http: HttpClient,
    /// Base URL that relative API paths are joined onto in `endpoint`.
    pub(super) base_url: Url,
    /// Optional provider for the `Authorization` header (see `attach_auth`).
    pub(super) auth: Option<Arc<dyn AuthProvider>>,
    /// Single-flight refresh state, shared by all clones.
    pub(super) refresh_coordinator: Arc<RefreshCoordinator>,
    /// Guard that broadcasts `true` to its subscribers once the last clone is dropped.
    pub(super) parent_drop: Arc<DropGuard>,
    /// NOTE(review): not read in this file; presumably the watch-stream heartbeat period — confirm.
    pub(super) heartbeat_interval: Duration,
    /// NOTE(review): presumably persists watch resume cursors keyed by `ResumeKey` — confirm.
    pub(super) state_store: Option<Arc<dyn StateStore>>,
    /// Shared per-`ResumeKey` counters (see `increment_active_key` / `decrement_active_key`).
    pub(super) active_resume_keys: ActiveResumeKeys,
    /// Flag exposed via `danger_accept_invalid_certs()`; presumably also applied to `http`
    /// by the builder — confirm.
    pub(super) danger_accept_invalid_certs: bool,
    /// NOTE(review): not read in this file; presumably controls flushing the resume cursor
    /// when a watch ends — confirm.
    pub(super) flush_cursor_on_exit: bool,
}
impl std::fmt::Debug for AvisoClient {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut sanitized = self.base_url.clone();
let _ = sanitized.set_username("");
let _ = sanitized.set_password(None);
f.debug_struct("AvisoClient")
.field("base_url", &sanitized.as_str())
.field("auth", &self.auth)
.finish_non_exhaustive()
}
}
impl AvisoClient {
    /// Returns a builder for configuring a new client.
    pub fn builder() -> AvisoClientBuilder {
        AvisoClientBuilder::default()
    }
    /// Base URL against which API paths are resolved.
    #[must_use]
    pub fn base_url(&self) -> &Url {
        &self.base_url
    }
    /// Auth provider used to set the `Authorization` header, if one was configured.
    #[must_use]
    pub fn auth(&self) -> Option<&Arc<dyn AuthProvider>> {
        self.auth.as_ref()
    }
    /// Returns the `danger_accept_invalid_certs` flag this client was built with.
    #[must_use]
    pub fn danger_accept_invalid_certs(&self) -> bool {
        self.danger_accept_invalid_certs
    }
    /// Underlying HTTP client.
    pub(crate) fn http(&self) -> &HttpClient {
        &self.http
    }
    /// Resolves `relative_path` against the base URL with `Url::join` semantics.
    ///
    /// Under those semantics a path starting with `/` replaces the base path entirely, and an
    /// absolute URL replaces the whole base. Without a trailing `/` on the base, its last path
    /// segment is replaced. NOTE(review): the proxy-prefix test relies on the builder
    /// normalising the base URL to end in `/` — confirm in `builder`.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::Config` if the join fails.
    pub(crate) fn endpoint(&self, relative_path: &str) -> crate::Result<Url> {
        self.base_url.join(relative_path).map_err(|e| {
            ClientError::Config(format!(
                "build endpoint url from base {} and path {relative_path:?}: {e}",
                self.base_url
            ))
        })
    }
    /// Adds an `Authorization` header from the configured provider. Without a provider the
    /// builder is returned unchanged.
    ///
    /// # Errors
    ///
    /// Propagates the provider's `authorization_header` error.
    pub(crate) async fn attach_auth(
        &self,
        builder: reqwest::RequestBuilder,
    ) -> crate::Result<reqwest::RequestBuilder> {
        match self.auth() {
            Some(auth) => {
                let value = auth.authorization_header().await?;
                Ok(builder.header(reqwest::header::AUTHORIZATION, value))
            }
            None => Ok(builder),
        }
    }
    /// Sends a request, refreshing credentials and retrying once if the server answers 401.
    ///
    /// `build` is called once per attempt (at most twice), so it must be able to reconstruct
    /// the full request each time. Refreshes are coalesced through the shared
    /// `RefreshCoordinator`: a burst of 401s caused by the same stale credential triggers a
    /// single refresh. The retry's response is returned as-is, even if it is another 401.
    ///
    /// # Errors
    ///
    /// Transport errors (via `ClientError::from`), header errors from the auth provider, and
    /// refresh errors. A failed refresh is returned in place of the original 401.
    pub(crate) async fn send_with_refresh<F>(
        &self,
        mut build: F,
    ) -> crate::Result<reqwest::Response>
    where
        F: FnMut(&HttpClient) -> reqwest::RequestBuilder,
    {
        // Snapshot the refresh epoch BEFORE reading the credential, so the snapshot can never
        // be newer than the token we send. If it were taken afterwards, a refresh completing in
        // between would make a stale-token 401 look like it belonged to the new epoch, and
        // `refresh_once` would refresh a second time.
        let observed = self.refresh_coordinator.generation();
        let first = self.attach_auth(build(self.http())).await?;
        let response = first.send().await.map_err(ClientError::from)?;
        if response.status() != reqwest::StatusCode::UNAUTHORIZED {
            return Ok(response);
        }
        let Some(auth) = self.auth() else {
            // Nothing to refresh; surface the 401 to the caller.
            return Ok(response);
        };
        // Release the first response (and its connection) before the possibly slow refresh.
        drop(response);
        self.refresh_coordinator
            .refresh_once(auth, observed)
            .await?;
        let retry = self.attach_auth(build(self.http())).await?;
        retry.send().await.map_err(ClientError::from)
    }
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code: unwrap and expect on constructor success and on assertion-shaped awaits are the expected diagnostics"
)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;
    use super::AvisoClient;
    use crate::auth::{AuthProvider, Bearer};

    const LOCAL: &str = "http://localhost:8000";

    /// Builds an unauthenticated client against `base`.
    fn local_client(base: &'static str) -> AvisoClient {
        AvisoClient::builder().base_url(base).build().unwrap()
    }

    /// Builds a client against `LOCAL` that authenticates with a static bearer `token`.
    fn bearer_client(token: &'static str) -> AvisoClient {
        let provider: Arc<dyn AuthProvider> = Arc::new(Bearer::new(token).unwrap());
        AvisoClient::builder()
            .base_url(LOCAL)
            .auth(provider)
            .build()
            .unwrap()
    }

    #[test]
    fn auth_defaults_to_none() {
        assert!(local_client(LOCAL).auth().is_none());
    }

    #[test]
    fn auth_is_set_when_provided() {
        assert!(bearer_client("token").auth().is_some());
    }

    #[test]
    fn endpoint_joins_relative_paths() {
        let joined = local_client(LOCAL).endpoint("api/v1/notification").unwrap();
        assert_eq!(joined.as_str(), "http://localhost:8000/api/v1/notification");
    }

    #[test]
    fn endpoint_join_respects_proxy_path_prefix() {
        let joined = local_client("https://gw.example.org/aviso")
            .endpoint("api/v1/notification")
            .unwrap();
        assert_eq!(
            joined.as_str(),
            "https://gw.example.org/aviso/api/v1/notification"
        );
    }

    #[test]
    fn client_is_cheap_to_clone() {
        let original = local_client(LOCAL);
        let _copy = original.clone();
    }

    #[test]
    fn debug_does_not_leak_auth_token() {
        let formatted = format!("{:?}", bearer_client("super-secret-jwt-do-not-leak"));
        assert!(
            !formatted.contains("super-secret-jwt-do-not-leak"),
            "AvisoClient Debug must not leak the auth token: {formatted}"
        );
    }

    #[test]
    fn debug_strips_userinfo_from_base_url() {
        let formatted = format!(
            "{:?}",
            local_client("https://operator:hunter2@aviso.example.org")
        );
        assert!(
            !formatted.contains("hunter2"),
            "AvisoClient Debug must strip password from base_url: {formatted}"
        );
        assert!(
            !formatted.contains("operator"),
            "AvisoClient Debug must strip username from base_url: {formatted}"
        );
    }

    #[tokio::test]
    async fn drop_guard_fires_when_last_clone_drops() {
        let client = local_client(LOCAL);
        let mut receiver = client.parent_drop.subscribe();
        assert!(!*receiver.borrow_and_update());
        let clone = client.clone();
        drop(client);
        assert!(
            !*receiver.borrow_and_update(),
            "guard must NOT fire while another clone is alive"
        );
        drop(clone);
        let changed = tokio::time::timeout(Duration::from_millis(100), receiver.changed())
            .await
            .expect("guard must fire within 100ms of last clone drop");
        assert!(changed.is_ok());
        assert!(*receiver.borrow_and_update());
    }

    #[tokio::test]
    async fn drop_guard_broadcasts_to_multiple_subscribers() {
        let client = local_client(LOCAL);
        let mut subscribers: Vec<_> = (0..3).map(|_| client.parent_drop.subscribe()).collect();
        drop(client);
        for rx in &mut subscribers {
            let observed = tokio::time::timeout(Duration::from_millis(100), rx.changed()).await;
            assert!(
                observed.is_ok_and(|r| r.is_ok()),
                "every subscriber must observe the drop"
            );
            assert!(*rx.borrow_and_update());
        }
    }
}
// Single-flight refresh tests. The first group exercises `RefreshCoordinator` in isolation;
// the last two drive it end-to-end through `AvisoClient` against a wiremock server that
// rejects the "stale" credential with 401 and accepts "fresh".
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code: unwrap and expect on constructor success and assertion-shaped awaits are the expected diagnostics"
)]
mod refresh_single_flight {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::time::Duration;
    use reqwest::header::HeaderValue;
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};
    use super::{AvisoClient, RefreshCoordinator};
    use crate::ClientError;
    use crate::auth::AuthProvider;
    // Provider with a fixed header that counts `refresh` calls. A non-zero `delay` keeps a
    // refresh in flight long enough to overlap with other callers; `fail` makes it return
    // `ClientError::Auth`.
    #[derive(Debug, Default)]
    struct CountingRefresher {
        refreshes: AtomicUsize,
        delay: Duration,
        fail: bool,
    }
    #[async_trait::async_trait]
    impl AuthProvider for CountingRefresher {
        async fn authorization_header(&self) -> crate::Result<HeaderValue> {
            Ok(HeaderValue::from_static("Bearer test"))
        }
        async fn refresh(&self) -> crate::Result<()> {
            // Counted before the delay and the failure, so cancelled and failed attempts
            // are counted too.
            self.refreshes.fetch_add(1, Ordering::SeqCst);
            if !self.delay.is_zero() {
                tokio::time::sleep(self.delay).await;
            }
            if self.fail {
                return Err(ClientError::Auth("refresh failed".into()));
            }
            Ok(())
        }
    }
    // Provider that sends "stale" until its first refresh completes and "fresh" afterwards,
    // matching the two mocks mounted by `mount_rotating_credential_server`.
    #[derive(Debug, Default)]
    struct RotatingCredential {
        refreshes: AtomicUsize,
        refreshed: AtomicBool,
    }
    #[async_trait::async_trait]
    impl AuthProvider for RotatingCredential {
        async fn authorization_header(&self) -> crate::Result<HeaderValue> {
            let token = if self.refreshed.load(Ordering::SeqCst) {
                "fresh"
            } else {
                "stale"
            };
            Ok(HeaderValue::from_static(token))
        }
        async fn refresh(&self) -> crate::Result<()> {
            self.refreshes.fetch_add(1, Ordering::SeqCst);
            // The 50ms sleep keeps the refresh in flight while other 401s arrive.
            tokio::time::sleep(Duration::from_millis(50)).await;
            self.refreshed.store(true, Ordering::SeqCst);
            Ok(())
        }
    }
    // Eight tasks all hold snapshot 0. The first to take the lock refreshes and bumps the
    // epoch to 1; the rest see 1 != 0 and skip.
    #[tokio::test]
    async fn concurrent_callers_of_one_epoch_refresh_once() {
        let coordinator = Arc::new(RefreshCoordinator::default());
        let provider = Arc::new(CountingRefresher {
            delay: Duration::from_millis(50),
            ..CountingRefresher::default()
        });
        let auth: Arc<dyn AuthProvider> = provider.clone();
        let mut handles = Vec::new();
        for _ in 0..8 {
            let coordinator = coordinator.clone();
            let auth = auth.clone();
            handles.push(tokio::spawn(async move {
                coordinator.refresh_once(&auth, 0).await
            }));
        }
        for handle in handles {
            handle.await.unwrap().unwrap();
        }
        assert_eq!(provider.refreshes.load(Ordering::SeqCst), 1);
        assert_eq!(coordinator.generation(), 1);
    }
    // A snapshot matching the current epoch refreshes; a stale snapshot (0 after the epoch
    // reached 2) is a no-op.
    #[tokio::test]
    async fn each_new_epoch_refreshes_again() {
        let coordinator = RefreshCoordinator::default();
        let provider = Arc::new(CountingRefresher::default());
        let auth: Arc<dyn AuthProvider> = provider.clone();
        coordinator.refresh_once(&auth, 0).await.unwrap();
        coordinator.refresh_once(&auth, 1).await.unwrap();
        assert_eq!(provider.refreshes.load(Ordering::SeqCst), 2);
        assert_eq!(coordinator.generation(), 2);
        coordinator.refresh_once(&auth, 0).await.unwrap();
        assert_eq!(provider.refreshes.load(Ordering::SeqCst), 2);
    }
    // A failed refresh does not advance the epoch, so a second caller with the same snapshot
    // retries instead of being coalesced into the failure.
    #[tokio::test]
    async fn failed_refresh_is_not_coalesced_and_keeps_the_epoch() {
        let coordinator = RefreshCoordinator::default();
        let provider = Arc::new(CountingRefresher {
            fail: true,
            ..CountingRefresher::default()
        });
        let auth: Arc<dyn AuthProvider> = provider.clone();
        let err = coordinator.refresh_once(&auth, 0).await.unwrap_err();
        assert!(matches!(err, ClientError::Auth(_)), "got {err:?}");
        assert_eq!(coordinator.generation(), 0);
        let err = coordinator.refresh_once(&auth, 0).await.unwrap_err();
        assert!(matches!(err, ClientError::Auth(_)), "got {err:?}");
        assert_eq!(provider.refreshes.load(Ordering::SeqCst), 2);
    }
    // The 20ms timeout drops `refresh_once` while it is inside the 10s refresh. Dropping the
    // future must release the mutex guard without advancing the epoch.
    #[tokio::test]
    async fn dropped_refresh_does_not_wedge_the_lock() {
        let coordinator = RefreshCoordinator::default();
        let slow = Arc::new(CountingRefresher {
            delay: Duration::from_secs(10),
            ..CountingRefresher::default()
        });
        let slow_auth: Arc<dyn AuthProvider> = slow.clone();
        let dropped = tokio::time::timeout(
            Duration::from_millis(20),
            coordinator.refresh_once(&slow_auth, 0),
        )
        .await;
        assert!(
            dropped.is_err(),
            "the slow refresh must not finish before the deadline, so the timeout drops it mid-refresh"
        );
        assert_eq!(
            coordinator.generation(),
            0,
            "a dropped refresh must not advance the epoch"
        );
        let fast = Arc::new(CountingRefresher::default());
        let fast_auth: Arc<dyn AuthProvider> = fast.clone();
        tokio::time::timeout(
            Duration::from_millis(500),
            coordinator.refresh_once(&fast_auth, 0),
        )
        .await
        .expect("lock must be free after the dropped refresh")
        .unwrap();
        assert_eq!(coordinator.generation(), 1);
        assert_eq!(fast.refreshes.load(Ordering::SeqCst), 1);
    }
    // Mock server for POST /api/v1/notification: 401 for `authorization: stale`, and 200 with
    // a success body for `authorization: fresh`.
    async fn mount_rotating_credential_server() -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/api/v1/notification"))
            .and(header("authorization", "stale"))
            .respond_with(ResponseTemplate::new(401))
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .and(path("/api/v1/notification"))
            .and(header("authorization", "fresh"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "status": "success",
                "request_id": "r",
                "processed_at": "2026-05-17T12:34:56Z",
            })))
            .mount(&server)
            .await;
        server
    }
    // Eight notifications all start with the stale token and get 401s. They must share a
    // single refresh. NOTE(review): the second `notify_many` argument is presumably the
    // concurrency limit — confirm.
    #[tokio::test]
    async fn notify_many_burst_of_401s_refreshes_once() {
        let server = mount_rotating_credential_server().await;
        let provider = Arc::new(RotatingCredential::default());
        let client = AvisoClient::builder()
            .base_url(server.uri())
            .auth(provider.clone() as Arc<dyn AuthProvider>)
            .build()
            .unwrap();
        let requests: Vec<crate::NotificationRequest> = (0..8)
            .map(|i| crate::NotificationRequest::new(format!("e{i}")))
            .collect();
        let results = client.notify_many(&requests, 8).await;
        assert!(results.iter().all(Result::is_ok), "{results:?}");
        assert_eq!(provider.refreshes.load(Ordering::SeqCst), 1);
    }
    // Clones share the `Arc<RefreshCoordinator>`, so 401s on two clones coalesce into a
    // single refresh.
    #[tokio::test]
    async fn cloned_clients_share_the_coordinator() {
        let server = mount_rotating_credential_server().await;
        let provider = Arc::new(RotatingCredential::default());
        let client = AvisoClient::builder()
            .base_url(server.uri())
            .auth(provider.clone() as Arc<dyn AuthProvider>)
            .build()
            .unwrap();
        let clone = client.clone();
        let first = crate::NotificationRequest::new("a");
        let second = crate::NotificationRequest::new("b");
        let (a, b) = tokio::join!(client.notify(&first), clone.notify(&second));
        a.unwrap();
        b.unwrap();
        assert_eq!(provider.refreshes.load(Ordering::SeqCst), 1);
    }
}