use std::time::Duration;
use tokio::sync::oneshot;
use crate::client::Client;
use crate::error::{Error, Result};
/// Summary of the gateway's brokerage-session state, as returned by
/// [`Client::auth_status`].
///
/// Built from the generated `bezant_api::BrokerageSessionStatus` payload; any
/// boolean flag the payload omits is reported as `false`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct AuthStatus {
    /// The payload's `authenticated` flag (`false` when absent).
    pub authenticated: bool,
    /// The payload's `connected` flag (`false` when absent).
    pub connected: bool,
    /// The payload's `competing` flag (`false` when absent).
    /// NOTE(review): presumably signals another session competing for the
    /// brokerage login — confirm against the gateway documentation.
    pub competing: bool,
    /// The payload's optional `message` field, passed through unchanged.
    pub message: Option<String>,
}
impl From<bezant_api::BrokerageSessionStatus> for AuthStatus {
    /// Flattens the generated payload, treating every missing flag as `false`
    /// and passing `message` through as-is.
    fn from(raw: bezant_api::BrokerageSessionStatus) -> Self {
        let authenticated = raw.authenticated.unwrap_or_default();
        let connected = raw.connected.unwrap_or_default();
        let competing = raw.competing.unwrap_or_default();
        let message = raw.message;
        Self {
            authenticated,
            connected,
            competing,
            message,
        }
    }
}
/// Payload of a [`Client::tickle`] call whose response was the `Ok` variant.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TickleResponse {
    /// The `session` value from a `Successful` tickle payload; `None` when the
    /// payload is the `Failed` variant or carries no session.
    pub session: Option<String>,
    /// The full, unmodified tickle payload.
    pub raw: bezant_api::TickleResponse,
}
impl Client {
    /// Queries the gateway's brokerage-session status via `POST iserver/auth/status`.
    ///
    /// The request has an empty body and carries `Origin`/`Referer` headers that
    /// point at the gateway root URL.
    ///
    /// # Errors
    ///
    /// - [`Error::UrlNotABase`] if the configured base URL cannot take path segments.
    /// - [`Error::Http`] if the request could not be sent.
    /// - [`Error::NotAuthenticated`] on `401 Unauthorized`, or on a redirect whose
    ///   `Location` points at `/sso/login` or `/sso/dispatcher`.
    /// - [`Error::UpstreamStatus`] on any other redirect or non-2xx status.
    /// - [`Error::Decode`] if the body is not a valid session-status payload.
    #[tracing::instrument(skip(self), level = "debug")]
    pub async fn auth_status(&self) -> Result<AuthStatus> {
        let mut url = self.base_url().clone();
        url.path_segments_mut()
            .map_err(|()| Error::UrlNotABase {
                url: self.base_url().to_string(),
            })?
            // A base such as `.../v1/api/` ends in an empty segment; drop it so we
            // build `.../v1/api/iserver/...` instead of `.../v1/api//iserver/...`.
            .pop_if_empty()
            .push("iserver")
            .push("auth")
            .push("status");
        url.set_query(None);

        // NOTE(review): the gateway presumably rejects POSTs that do not look
        // same-origin, hence the explicit Origin/Referer headers — confirm.
        let gateway_origin = self
            .gateway_root_url()
            .as_str()
            .trim_end_matches('/')
            .to_owned();
        let resp = self
            .http()
            .post(url.clone())
            .header(reqwest::header::CONTENT_LENGTH, "0")
            .header(reqwest::header::ORIGIN, &gateway_origin)
            .header(reqwest::header::REFERER, format!("{gateway_origin}/"))
            .send()
            .await
            .map_err(Error::Http)?;

        let status = resp.status();
        if status == reqwest::StatusCode::UNAUTHORIZED {
            return Err(Error::NotAuthenticated);
        }
        if status.is_redirection() {
            let location = resp
                .headers()
                .get(reqwest::header::LOCATION)
                .and_then(|v| v.to_str().ok())
                .unwrap_or_default()
                .to_owned();
            // Match the SSO paths case-insensitively, but keep the original casing
            // in the diagnostic preview (URL paths are case-sensitive).
            let location_lc = location.to_ascii_lowercase();
            if location_lc.contains("/sso/login") || location_lc.contains("/sso/dispatcher") {
                return Err(Error::NotAuthenticated);
            }
            return Err(Error::UpstreamStatus {
                endpoint: "iserver/auth/status",
                status: status.as_u16(),
                body_preview: Some(format!("redirect to: {location}")),
            });
        }
        if !status.is_success() {
            return Err(Error::UpstreamStatus {
                endpoint: "iserver/auth/status",
                status: status.as_u16(),
                body_preview: None,
            });
        }

        let parsed: bezant_api::BrokerageSessionStatus =
            resp.json().await.map_err(|e| Error::Decode {
                // Report the URL that was actually requested rather than re-joining
                // strings, which could double a slash.
                endpoint: format!("POST {url}"),
                status: status.as_u16(),
                message: e.to_string(),
            })?;
        Ok(AuthStatus::from(parsed))
    }

    /// Sends one session "tickle" through the generated API client.
    ///
    /// On an `Ok` response the session value is extracted from a `Successful`
    /// payload (`None` for `Failed`), and the full payload is kept in `raw`.
    ///
    /// # Errors
    ///
    /// - Any error from the underlying API call, converted into [`Error`].
    /// - [`Error::NotAuthenticated`] on an `Unauthorized` response.
    /// - [`Error::Unknown`] on an unrecognised response.
    #[tracing::instrument(skip(self), level = "debug")]
    pub async fn tickle(&self) -> Result<TickleResponse> {
        let resp = self
            .api()
            .get_session_token(bezant_api::GetSessionTokenRequest::default())
            .await?;
        match resp {
            bezant_api::GetSessionTokenResponse::Ok(payload) => {
                let session = match &payload {
                    bezant_api::TickleResponse::Successful(s) => s.session.clone(),
                    bezant_api::TickleResponse::Failed(_) => None,
                };
                Ok(TickleResponse {
                    session,
                    raw: payload,
                })
            }
            bezant_api::GetSessionTokenResponse::Unauthorized => Err(Error::NotAuthenticated),
            bezant_api::GetSessionTokenResponse::Unknown => Err(Error::Unknown {
                endpoint: "iserver/auth/tickle",
            }),
        }
    }

    /// Checks that the brokerage session is both connected and authenticated.
    ///
    /// # Errors
    ///
    /// - Any error from [`Client::auth_status`].
    /// - [`Error::NoSession`] if the status reports `connected == false`
    ///   (checked first).
    /// - [`Error::NotAuthenticated`] if it reports `authenticated == false`.
    #[tracing::instrument(skip(self), level = "debug")]
    pub async fn health(&self) -> Result<AuthStatus> {
        let status = self.auth_status().await?;
        if !status.connected {
            return Err(Error::NoSession);
        }
        if !status.authenticated {
            return Err(Error::NotAuthenticated);
        }
        Ok(status)
    }

    /// Spawns a background task that calls [`Client::tickle`] every `interval`.
    ///
    /// The first tickle fires one full `interval` after spawning. Failures are
    /// logged at `warn` level and do not end the loop. If a tickle overruns the
    /// interval, the next one is delayed rather than fired in a catch-up burst.
    /// The task exits once the returned [`KeepaliveHandle`] is stopped or
    /// dropped, including while a tickle request is still in flight.
    ///
    /// # Panics
    ///
    /// Panics if `interval` is zero, or if called outside a Tokio runtime
    /// (`tokio::spawn` requires one).
    #[must_use]
    pub fn spawn_keepalive(&self, interval: Duration) -> KeepaliveHandle {
        use tracing::Instrument;

        // `tokio::time::interval` panics on a zero period. Fail here, where the
        // caller sees it, instead of silently inside the detached task.
        assert!(!interval.is_zero(), "keepalive interval must be non-zero");

        let client = self.clone();
        let (tx, mut rx) = oneshot::channel::<()>();
        let span = tracing::info_span!("bezant_keepalive", interval_secs = interval.as_secs());
        let join = tokio::spawn(
            async move {
                let mut ticker = tokio::time::interval(interval);
                // A keepalive should never fire a burst of requests to "catch up".
                ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
                // The first tick completes immediately; consume it so the first
                // tickle happens one interval from now.
                ticker.tick().await;
                loop {
                    tokio::select! {
                        _ = &mut rx => break,
                        _ = ticker.tick() => {}
                    }
                    // Race the request against shutdown as well, so `stop()` is not
                    // held hostage by a slow or hung gateway.
                    tokio::select! {
                        _ = &mut rx => break,
                        result = client.tickle() => {
                            if let Err(e) = result {
                                tracing::warn!(error = %e, "bezant keepalive tickle failed");
                            }
                        }
                    }
                }
            }
            .instrument(span),
        );
        KeepaliveHandle {
            shutdown: Some(tx),
            join: Some(join),
        }
    }
}
/// Handle to the background task started by [`Client::spawn_keepalive`].
///
/// Dropping the handle sends the shutdown signal but does not wait for the
/// task. Call [`KeepaliveHandle::stop`] to send the signal and then await the
/// task's completion.
#[derive(Debug)]
pub struct KeepaliveHandle {
    /// Shutdown signal for the task; `None` once it has been taken by `stop`
    /// or `Drop`.
    shutdown: Option<oneshot::Sender<()>>,
    /// Join handle for the task; `None` once `stop` has taken it to await it.
    join: Option<tokio::task::JoinHandle<()>>,
}
impl KeepaliveHandle {
pub async fn stop(mut self) -> Result<()> {
if let Some(tx) = self.shutdown.take() {
let _ = tx.send(());
}
if let Some(join) = self.join.take() {
join.await
.map_err(|e| Error::Other(format!("keepalive task panicked: {e}")))?;
}
Ok(())
}
}
impl Drop for KeepaliveHandle {
    /// Fires the shutdown signal without waiting for the task to exit.
    fn drop(&mut self) {
        // After an explicit `stop()` the sender is already gone and this is a no-op.
        let pending = self.shutdown.take();
        if let Some(sender) = pending {
            // Best effort: an error just means the task has already finished.
            let _ignored = sender.send(());
        }
    }
}
#[cfg(test)]
mod keepalive_tests {
    use super::*;
    use std::time::Duration;

    /// Builds a client against `127.0.0.1:1`. With the 60s intervals used
    /// below, every test finishes before any tickle would fire.
    fn dummy_client() -> Client {
        Client::builder("http://127.0.0.1:1/v1/api")
            .build()
            .expect("client")
    }

    #[tokio::test]
    async fn stop_sends_shutdown_and_joins_cleanly() {
        let client = dummy_client();
        let handle = client.spawn_keepalive(Duration::from_secs(60));
        tokio::time::sleep(Duration::from_millis(50)).await;
        tokio::time::timeout(Duration::from_secs(5), handle.stop())
            .await
            .expect("stop returns promptly")
            .expect("clean stop");
    }

    #[tokio::test]
    async fn drop_sends_shutdown_signal() {
        let client = dummy_client();
        let mut handle = client.spawn_keepalive(Duration::from_secs(60));
        tokio::time::sleep(Duration::from_millis(50)).await;
        // Take the JoinHandle out so the task can still be observed once the
        // KeepaliveHandle itself has been dropped.
        let join = handle.join.take().expect("join handle present");
        drop(handle);
        tokio::time::timeout(Duration::from_secs(5), join)
            .await
            .expect("task exits after the handle is dropped")
            .expect("task did not panic");
    }
}