use std::time::Duration;
use progenitor_client::Error as RawError;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use crate::types::ErrorResponse;
use crate::Client;
/// Unified error type for this client wrapper.
///
/// Values are produced either by converting a `progenitor_client::Error<ErrorResponse>`
/// (see the `From` impl in this file) or directly by the client builder.
#[derive(Debug)]
pub enum BrokkrError {
/// The server replied with a typed error body; carries that body and the HTTP status.
Api(ErrorResponse, reqwest::StatusCode),
/// A `reqwest` failure: communication, connection upgrade, response-body read,
/// or construction of the underlying `reqwest::Client`.
Transport(reqwest::Error),
/// A response that could not be interpreted: either its payload failed to
/// deserialize or it is not described by the API specification.
UnexpectedResponse {
/// HTTP status of the response, when the originating error exposes one.
status: Option<reqwest::StatusCode>,
/// Human-readable explanation of what was unexpected.
detail: String,
},
/// The request (or client configuration) was rejected before being sent.
InvalidRequest(String),
}
impl BrokkrError {
    /// Returns the HTTP status associated with this error, if any.
    ///
    /// `Transport` delegates to `reqwest::Error::status`; `InvalidRequest`
    /// never carries a status.
    pub fn status(&self) -> Option<reqwest::StatusCode> {
        match self {
            Self::Api(_, status) => Some(*status),
            Self::Transport(e) => e.status(),
            Self::UnexpectedResponse { status, .. } => *status,
            Self::InvalidRequest(_) => None,
        }
    }

    /// Returns the machine-readable error code from the API error body.
    ///
    /// Only `Api` errors carry a code; every other variant yields `None`.
    pub fn code(&self) -> Option<&str> {
        match self {
            Self::Api(body, _) => Some(&body.code),
            // Listed explicitly so that adding a variant forces a decision here.
            Self::Transport(_) | Self::UnexpectedResponse { .. } | Self::InvalidRequest(_) => None,
        }
    }

    /// Reports whether repeating the same request could plausibly succeed.
    ///
    /// * `Transport` errors are retryable unless reqwest classifies them as
    ///   builder errors (e.g. a malformed URL or bad client configuration):
    ///   those are deterministic and would fail identically on every attempt.
    /// * `Api` and `UnexpectedResponse` errors are retryable when their status
    ///   is accepted by `is_retryable_status`.
    /// * Everything else is not retryable.
    pub fn is_retryable(&self) -> bool {
        match self {
            Self::Transport(e) => !e.is_builder(),
            Self::Api(_, status)
            | Self::UnexpectedResponse {
                status: Some(status),
                ..
            } => is_retryable_status(*status),
            Self::UnexpectedResponse { status: None, .. } | Self::InvalidRequest(_) => false,
        }
    }
}
/// Human-readable rendering of each error variant.
impl std::fmt::Display for BrokkrError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidRequest(msg) => write!(f, "invalid request: {msg}"),
            Self::Transport(e) => write!(f, "transport error: {e}"),
            // The status is only printed when one is known.
            Self::UnexpectedResponse {
                status: None,
                detail,
            } => write!(f, "unexpected response: {detail}"),
            Self::UnexpectedResponse {
                status: Some(s),
                detail,
            } => write!(f, "unexpected response ({}): {}", s.as_u16(), detail),
            // "<numeric status> <code>: <message>"
            Self::Api(body, status) => {
                write!(f, "{} {}: {}", status.as_u16(), body.code, body.message)
            }
        }
    }
}
impl std::error::Error for BrokkrError {}
/// Maps every variant of the generated client's error onto [`BrokkrError`].
impl From<RawError<ErrorResponse>> for BrokkrError {
    fn from(err: RawError<ErrorResponse>) -> Self {
        match err {
            // Both carry a plain message and end up as `InvalidRequest`.
            RawError::InvalidRequest(msg) | RawError::Custom(msg) => Self::InvalidRequest(msg),
            // Any variant wrapping a `reqwest::Error` becomes `Transport`.
            RawError::CommunicationError(e)
            | RawError::InvalidUpgrade(e)
            | RawError::ResponseBodyError(e) => Self::Transport(e),
            RawError::ErrorResponse(response) => {
                // Read the status first: `into_inner` consumes the response value.
                let status = response.status();
                let body = response.into_inner();
                Self::Api(body, status)
            }
            RawError::UnexpectedResponse(response) => {
                let status = Some(response.status());
                Self::UnexpectedResponse {
                    status,
                    detail: String::from("response not described in OpenAPI spec"),
                }
            }
            RawError::InvalidResponsePayload(payload, e) => {
                let detail = format!(
                    "payload deserialization failed: {e} ({} bytes)",
                    payload.len()
                );
                Self::UnexpectedResponse {
                    status: None,
                    detail,
                }
            }
        }
    }
}
/// Returns `true` for the HTTP statuses this client retries:
/// 408, 429, 502, 503 and 504.
fn is_retryable_status(status: reqwest::StatusCode) -> bool {
    const RETRYABLE: [u16; 5] = [408, 429, 502, 503, 504];
    RETRYABLE.contains(&status.as_u16())
}
/// Accumulates configuration for a `BrokkrClient` before it is built.
pub struct BrokkrClientBuilder {
    /// Base URL handed to the generated API client.
    base_url: String,
    /// Optional credential sent as the `Authorization` header value.
    token: Option<String>,
    /// Overall per-request timeout.
    request_timeout: Duration,
    /// Timeout for establishing a connection.
    connect_timeout: Duration,
    /// Maximum number of retries performed after the first attempt.
    max_retries: u32,
    /// Delay before the first retry.
    initial_backoff: Duration,
}

// Implemented by hand rather than derived so that the credential never
// appears in logs or panic messages that format the builder with `{:?}`.
impl std::fmt::Debug for BrokkrClientBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BrokkrClientBuilder")
            .field("base_url", &self.base_url)
            // Only reveal whether a token is set, never its value.
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .field("request_timeout", &self.request_timeout)
            .field("connect_timeout", &self.connect_timeout)
            .field("max_retries", &self.max_retries)
            .field("initial_backoff", &self.initial_backoff)
            .finish()
    }
}
impl BrokkrClientBuilder {
    /// Creates a builder with the defaults: no token, 30 s request timeout,
    /// 10 s connect timeout, 3 retries and a 200 ms initial backoff.
    fn new(base_url: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            token: None,
            request_timeout: Duration::from_secs(30),
            connect_timeout: Duration::from_secs(10),
            max_retries: 3,
            initial_backoff: Duration::from_millis(200),
        }
    }

    /// Sets the value sent verbatim in the `Authorization` header of every request.
    pub fn token(mut self, token: impl Into<String>) -> Self {
        self.token = Some(token.into());
        self
    }

    /// Sets the overall timeout applied to each request.
    pub fn request_timeout(mut self, timeout: Duration) -> Self {
        self.request_timeout = timeout;
        self
    }

    /// Sets the timeout for establishing a connection.
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Sets how many times a retryable failure is retried after the first attempt.
    pub fn max_retries(mut self, max: u32) -> Self {
        self.max_retries = max;
        self
    }

    /// Sets the delay before the first retry.
    pub fn initial_backoff(mut self, initial: Duration) -> Self {
        self.initial_backoff = initial;
        self
    }

    /// Builds the configured [`BrokkrClient`].
    ///
    /// # Errors
    ///
    /// * [`BrokkrError::InvalidRequest`] when the token is not a valid HTTP
    ///   header value (for example, it contains a newline).
    /// * [`BrokkrError::Transport`] when the underlying `reqwest::Client`
    ///   cannot be constructed.
    pub fn build(self) -> Result<BrokkrClient, BrokkrError> {
        let mut headers = HeaderMap::new();
        if let Some(token) = &self.token {
            let mut value = HeaderValue::from_str(token).map_err(|e| {
                BrokkrError::InvalidRequest(format!("invalid token header value: {e}"))
            })?;
            // Mark the credential as sensitive so it is not revealed when the
            // header value is Debug-formatted.
            value.set_sensitive(true);
            headers.insert(AUTHORIZATION, value);
        }
        let reqwest_client = reqwest::Client::builder()
            .default_headers(headers)
            .connect_timeout(self.connect_timeout)
            .timeout(self.request_timeout)
            .build()
            .map_err(BrokkrError::Transport)?;
        let inner = Client::new_with_client(&self.base_url, reqwest_client);
        Ok(BrokkrClient {
            inner,
            max_retries: self.max_retries,
            initial_backoff: self.initial_backoff,
        })
    }
}
/// Wrapper around the generated API `Client` that adds a retry policy.
///
/// Instances are created through `BrokkrClient::builder`.
#[derive(Debug, Clone)]
pub struct BrokkrClient {
/// The generated API client that requests are issued through.
inner: Client,
/// Maximum number of retries performed by `retry` after the first attempt.
max_retries: u32,
/// Delay before the first retry; `retry` grows it on each later attempt.
initial_backoff: Duration,
}
impl BrokkrClient {
    /// Starts building a client that targets `base_url`.
    pub fn builder(base_url: impl Into<String>) -> BrokkrClientBuilder {
        BrokkrClientBuilder::new(base_url)
    }

    /// Gives direct access to the generated API client.
    pub fn api(&self) -> &Client {
        &self.inner
    }

    /// Runs `op`, retrying it while it fails with a retryable error.
    ///
    /// `op` is called at most `max_retries + 1` times. Before retry number
    /// `n` (counting from zero) the task sleeps for `initial_backoff * 2^n`,
    /// capped at 10 seconds.
    ///
    /// # Errors
    ///
    /// Returns the last error from `op` as soon as it is not retryable
    /// (per [`BrokkrError::is_retryable`]) or the retry budget is exhausted.
    pub async fn retry<F, Fut, T>(&self, mut op: F) -> Result<T, BrokkrError>
    where
        F: FnMut(&Client) -> Fut,
        Fut: std::future::Future<Output = Result<T, BrokkrError>>,
    {
        /// Upper bound for a single backoff sleep.
        const MAX_BACKOFF: Duration = Duration::from_secs(10);

        let mut attempt: u32 = 0;
        loop {
            match op(&self.inner).await {
                Ok(value) => return Ok(value),
                Err(err) if !err.is_retryable() || attempt >= self.max_retries => {
                    return Err(err);
                }
                Err(_) => {
                    // `1 << attempt` overflows once `attempt >= 32` (a panic in
                    // debug builds, a wrapped multiplier in release builds), so
                    // saturate the multiplier; the cap below bounds the delay.
                    let factor = 1u32.checked_shl(attempt).unwrap_or(u32::MAX);
                    let backoff = self
                        .initial_backoff
                        .saturating_mul(factor)
                        .min(MAX_BACKOFF);
                    tokio::time::sleep(backoff).await;
                    // Cannot overflow: this arm is only reached while
                    // `attempt < max_retries <= u32::MAX`.
                    attempt += 1;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Base URL shared by every test client.
    const BASE_URL: &str = "http://localhost:3000/api/v1";

    /// Builds an `Api` error with the given code, message and status.
    fn api_error(code: &str, message: &str, status: reqwest::StatusCode) -> BrokkrError {
        BrokkrError::Api(
            ErrorResponse {
                code: code.to_string(),
                message: message.to_string(),
                details: None,
            },
            status,
        )
    }

    #[test]
    fn builder_constructs_without_token() {
        use progenitor_client::ClientInfo;
        let client = BrokkrClient::builder(BASE_URL)
            .build()
            .expect("builder should succeed");
        assert_eq!(client.api().baseurl(), BASE_URL);
    }

    #[test]
    fn builder_accepts_token_and_timeouts() {
        let builder = BrokkrClient::builder(BASE_URL)
            .token("bk_admin_test_token")
            .request_timeout(Duration::from_secs(5))
            .connect_timeout(Duration::from_secs(2))
            .max_retries(5)
            .initial_backoff(Duration::from_millis(50));
        let client = builder.build().expect("builder should succeed");
        assert_eq!(client.max_retries, 5);
        assert_eq!(client.initial_backoff, Duration::from_millis(50));
    }

    #[test]
    fn invalid_token_header_is_rejected() {
        // CR/LF are not permitted inside an HTTP header value.
        let outcome = BrokkrClient::builder(BASE_URL)
            .token("invalid\nheader\rvalue")
            .build();
        assert!(matches!(outcome, Err(BrokkrError::InvalidRequest(_))));
    }

    #[test]
    fn error_code_extracted_from_api_response() {
        let err = api_error(
            "agent_not_found",
            "agent not found",
            reqwest::StatusCode::NOT_FOUND,
        );
        assert_eq!(err.code(), Some("agent_not_found"));
        assert_eq!(err.status(), Some(reqwest::StatusCode::NOT_FOUND));
        assert!(!err.is_retryable());
    }

    #[test]
    fn retryable_classification() {
        for status in [408u16, 429, 502, 503, 504] {
            let code = reqwest::StatusCode::from_u16(status).unwrap();
            let err = api_error("transient", "x", code);
            assert!(err.is_retryable(), "{status} should be retryable");
        }
        for status in [400u16, 401, 403, 404, 409, 422, 500, 501] {
            let code = reqwest::StatusCode::from_u16(status).unwrap();
            let err = api_error("non_transient", "x", code);
            assert!(!err.is_retryable(), "{status} should NOT be retryable");
        }
    }

    #[tokio::test(start_paused = true)]
    async fn retry_stops_after_max_attempts() {
        let client = BrokkrClient::builder(BASE_URL)
            .max_retries(2)
            .initial_backoff(Duration::from_millis(1))
            .build()
            .unwrap();
        let attempts = Arc::new(AtomicU32::new(0));
        let outcome: Result<(), BrokkrError> = client
            .retry(|_| {
                let attempts = Arc::clone(&attempts);
                async move {
                    attempts.fetch_add(1, Ordering::SeqCst);
                    Err(api_error(
                        "transient",
                        "service unavailable",
                        reqwest::StatusCode::SERVICE_UNAVAILABLE,
                    ))
                }
            })
            .await;
        assert!(outcome.is_err());
        // One initial call plus two retries.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test(start_paused = true)]
    async fn retry_returns_immediately_on_non_retryable() {
        let client = BrokkrClient::builder(BASE_URL)
            .max_retries(5)
            .build()
            .unwrap();
        let attempts = Arc::new(AtomicU32::new(0));
        let outcome: Result<(), BrokkrError> = client
            .retry(|_| {
                let attempts = Arc::clone(&attempts);
                async move {
                    attempts.fetch_add(1, Ordering::SeqCst);
                    Err(api_error(
                        "agent_not_found",
                        "x",
                        reqwest::StatusCode::NOT_FOUND,
                    ))
                }
            })
            .await;
        assert!(outcome.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }
}