use crate::client::backend::{BackendError, BackendType};
use crate::client::config::{ClientConfig, ConnectionConfig, HealthConfig, PoolConfig};
use std::time::Duration;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum BuilderError {
#[error("connection URL is required")]
MissingUrl,
}
/// Fluent builder that collects connection, pool, and health settings for an Athena client.
///
/// Every setter takes the builder by value and returns it, so calls are chained and finished
/// with `build_config`.
pub struct AthenaClientBuilder {
    /// Backend implementation to target; `new` starts from `BackendType::Native`.
    backend: BackendType,
    /// Optional name passed through to the resulting `ClientConfig`.
    client_name: Option<String>,

    // --- connection settings ---
    /// Connection URL; `build_config` fails without it.
    url: Option<String>,
    /// Optional key forwarded to the connection configuration.
    key: Option<String>,
    /// Explicit SSL toggle; when `None` the connection configuration's own value is kept.
    ssl: Option<bool>,
    /// Optional port for the connection.
    port: Option<u16>,
    /// Optional database name for the connection.
    database: Option<String>,

    // --- pool and health settings ---
    /// Pool settings, starting from `PoolConfig::default()`.
    pool: PoolConfig,
    /// Health-tracking settings, starting from `HealthConfig::default()`.
    health: HealthConfig,
}
impl Default for AthenaClientBuilder {
fn default() -> Self {
Self::new()
}
}
impl AthenaClientBuilder {
    /// Creates a builder targeting [`BackendType::Native`].
    ///
    /// No URL, key, client name, SSL flag, port, or database is set. Pool and health settings
    /// start from `PoolConfig::default()` and `HealthConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            backend: BackendType::Native,
            url: None,
            key: None,
            client_name: None,
            ssl: None,
            port: None,
            database: None,
            pool: PoolConfig::default(),
            health: HealthConfig::default(),
        }
    }

    /// Selects the backend implementation (the builder starts with [`BackendType::Native`]).
    #[must_use]
    pub fn backend(mut self, backend: BackendType) -> Self {
        self.backend = backend;
        self
    }

    /// Sets the connection URL. This is required by [`Self::build_config`].
    #[must_use]
    pub fn url(mut self, url: impl Into<String>) -> Self {
        self.url = Some(url.into());
        self
    }

    /// Sets the key forwarded to the connection configuration.
    #[must_use]
    pub fn key(mut self, key: impl Into<String>) -> Self {
        self.key = Some(key.into());
        self
    }

    /// Sets the client name recorded on the resulting [`ClientConfig`].
    #[must_use]
    pub fn client(mut self, client_name: impl Into<String>) -> Self {
        self.client_name = Some(client_name.into());
        self
    }

    /// Explicitly enables or disables SSL, overriding the connection configuration's value.
    #[must_use]
    pub fn ssl(mut self, enabled: bool) -> Self {
        self.ssl = Some(enabled);
        self
    }

    /// Sets the connection port.
    #[must_use]
    pub fn port(mut self, port: u16) -> Self {
        self.port = Some(port);
        self
    }

    /// Sets the database name used by the connection.
    #[must_use]
    pub fn database(mut self, database: impl Into<String>) -> Self {
        self.database = Some(database.into());
        self
    }

    /// Sets `PoolConfig::max_connections`.
    #[must_use]
    pub fn max_connections(mut self, max: u32) -> Self {
        self.pool.max_connections = max;
        self
    }

    /// Sets `PoolConfig::min_connections`.
    #[must_use]
    pub fn min_connections(mut self, min: u32) -> Self {
        self.pool.min_connections = min;
        self
    }

    /// Sets `PoolConfig::connection_timeout`.
    #[must_use]
    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
        self.pool.connection_timeout = timeout;
        self
    }

    /// Sets `PoolConfig::idle_timeout`.
    #[must_use]
    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
        self.pool.idle_timeout = timeout;
        self
    }

    /// Turns health tracking on or off (`HealthConfig::enabled`).
    #[must_use]
    pub fn health_tracking(mut self, enabled: bool) -> Self {
        self.health.enabled = enabled;
        self
    }

    /// Sets `HealthConfig::circuit_breaker_threshold`.
    #[must_use]
    pub fn circuit_breaker_threshold(mut self, threshold: u32) -> Self {
        self.health.circuit_breaker_threshold = threshold;
        self
    }

    /// Sets `HealthConfig::circuit_breaker_timeout`.
    #[must_use]
    pub fn circuit_breaker_timeout(mut self, timeout: Duration) -> Self {
        self.health.circuit_breaker_timeout = timeout;
        self
    }

    /// Sets `HealthConfig::check_interval`.
    #[must_use]
    pub fn health_check_interval(mut self, interval: Duration) -> Self {
        self.health.check_interval = interval;
        self
    }

    /// Consumes the builder and assembles a [`ClientConfig`].
    ///
    /// The connection starts from `ConnectionConfig::new(url)`. Only options that were
    /// explicitly set on the builder (key, SSL, port, database) override its values.
    ///
    /// # Errors
    ///
    /// Returns [`BuilderError::MissingUrl`] if no URL was set, or if the URL is empty or
    /// consists only of whitespace.
    pub fn build_config(self) -> Result<ClientConfig, BuilderError> {
        // An empty URL can never produce a usable connection, so treat it as missing.
        let url: String = self
            .url
            .filter(|url| !url.trim().is_empty())
            .ok_or(BuilderError::MissingUrl)?;
        let mut connection: ConnectionConfig = ConnectionConfig::new(url);

        // Override only what the caller configured, so any value chosen by
        // `ConnectionConfig::new` survives for unset options.
        if let Some(key) = self.key {
            connection.key = Some(key);
        }
        if let Some(ssl) = self.ssl {
            connection.ssl = ssl;
        }
        if let Some(port) = self.port {
            connection.port = Some(port);
        }
        if let Some(database) = self.database {
            connection.database = Some(database);
        }

        Ok(ClientConfig::new(
            self.backend,
            self.client_name,
            connection,
            self.pool,
            self.health,
        ))
    }
}
impl From<BuilderError> for BackendError {
fn from(value: BuilderError) -> Self {
BackendError::Generic(value.to_string())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// URL shared by the tests that need a builder with a connection URL.
    const LOCAL_URL: &str = "http://localhost:4052";

    /// Finishes `builder`, panicking if configuration assembly fails.
    fn build_ok(builder: AthenaClientBuilder) -> ClientConfig {
        builder
            .build_config()
            .expect("builder should succeed with URL")
    }

    #[test]
    fn builder_defaults_backend_to_native() {
        let config = build_ok(AthenaClientBuilder::new().url(LOCAL_URL));
        assert_eq!(config.backend_type, BackendType::Native);
    }

    #[test]
    fn builder_persists_client_name_when_set() {
        let builder = AthenaClientBuilder::new().url(LOCAL_URL).client("reporting");
        let config = build_ok(builder);
        assert_eq!(config.client_name.as_deref(), Some("reporting"));
    }

    #[test]
    fn builder_requires_url() {
        let outcome = AthenaClientBuilder::new().build_config();
        let err = outcome.expect_err("missing URL should fail");
        assert!(matches!(err, BuilderError::MissingUrl));
    }
}