use crate::balancer::BalancingStrategy;
use crate::circuit_breaker::CircuitBreaker;
use crate::error::{NetError, NetResult};
use crate::pool::{ConnectionPool, ConnectionPoolBuilder, PoolConfig, PoolStats};
use crate::proto::aql::aql_service_client::AqlServiceClient;
use crate::proto::aql::{
BatchRequest, BatchResponse, HealthCheckRequest, HealthCheckResponse, QueryRequest,
QueryResponse,
};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tonic::codec::CompressionEncoding;
use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
/// Compression algorithm selectable for gRPC messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionAlgorithm {
    /// No compression; this is the `Default` variant.
    #[default]
    Identity,
    /// Gzip compression (maps to tonic's `CompressionEncoding::Gzip`).
    Gzip,
}
/// Message-compression settings for the gRPC client.
///
/// An encoding is only requested when `enabled` is `true` and `algorithm` is not
/// [`CompressionAlgorithm::Identity`].
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Master switch; when `false` no encoding is requested regardless of `algorithm`.
    pub enabled: bool,
    /// Algorithm to use when `enabled` is `true`.
    pub algorithm: CompressionAlgorithm,
}
impl Default for CompressionConfig {
fn default() -> Self {
Self {
enabled: false,
algorithm: CompressionAlgorithm::Identity,
}
}
}
impl CompressionConfig {
    /// Returns a configuration with gzip compression enabled.
    pub fn gzip() -> Self {
        Self {
            enabled: true,
            algorithm: CompressionAlgorithm::Gzip,
        }
    }

    /// Maps this configuration to the tonic encoding to negotiate, if any.
    ///
    /// Returns `None` when compression is disabled or when the algorithm is
    /// [`CompressionAlgorithm::Identity`].
    // The only non-test caller is compiled under `#[cfg(feature = "compression")]`,
    // so without that feature this private method would trigger `dead_code`.
    #[cfg_attr(not(feature = "compression"), allow(dead_code))]
    fn to_tonic_encoding(&self) -> Option<CompressionEncoding> {
        match (self.enabled, self.algorithm) {
            (true, CompressionAlgorithm::Gzip) => Some(CompressionEncoding::Gzip),
            (false, _) | (true, CompressionAlgorithm::Identity) => None,
        }
    }
}
/// File-based TLS settings for the client, converted into tonic's
/// [`ClientTlsConfig`] by `TlsClientConfig::build_tonic_tls_config`.
#[derive(Debug, Clone, Default)]
pub struct TlsClientConfig {
    /// PEM file whose contents are installed via `ClientTlsConfig::ca_certificate`.
    pub ca_cert_path: Option<PathBuf>,
    /// PEM client certificate for mutual TLS; must be set together with
    /// `client_key_path` (checked by `validate`).
    pub client_cert_path: Option<PathBuf>,
    /// PEM private key paired with `client_cert_path`.
    pub client_key_path: Option<PathBuf>,
    /// When set, passed to `ClientTlsConfig::domain_name`.
    pub domain_name: Option<String>,
    /// Requests that server certificate verification be skipped.
    ///
    /// NOTE(review): `build_tonic_tls_config` does not read this flag, so setting
    /// it does not change the produced TLS configuration.
    pub skip_verification: bool,
}
impl TlsClientConfig {
pub fn new() -> Self {
Self::default()
}
pub fn with_ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
self.ca_cert_path = Some(path.into());
self
}
pub fn with_client_identity(
mut self,
cert_path: impl Into<PathBuf>,
key_path: impl Into<PathBuf>,
) -> Self {
self.client_cert_path = Some(cert_path.into());
self.client_key_path = Some(key_path.into());
self
}
pub fn with_domain_name(mut self, domain: impl Into<String>) -> Self {
self.domain_name = Some(domain.into());
self
}
pub fn with_skip_verification(mut self, skip: bool) -> Self {
self.skip_verification = skip;
self
}
pub fn validate(&self) -> NetResult<()> {
match (&self.client_cert_path, &self.client_key_path) {
(Some(_), None) => {
return Err(NetError::TlsError(
"Client certificate specified without client key".to_string(),
));
}
(None, Some(_)) => {
return Err(NetError::TlsError(
"Client key specified without client certificate".to_string(),
));
}
_ => {}
}
Ok(())
}
pub fn build_tonic_tls_config(&self) -> NetResult<ClientTlsConfig> {
self.validate()?;
let mut tls_config = ClientTlsConfig::new();
if let Some(ref ca_path) = self.ca_cert_path {
let ca_pem = std::fs::read(ca_path).map_err(|e| {
NetError::TlsError(format!(
"Failed to read CA certificate file '{}': {}",
ca_path.display(),
e
))
})?;
let ca_cert = Certificate::from_pem(ca_pem);
tls_config = tls_config.ca_certificate(ca_cert);
}
if let Some(ref cert_path) = self.client_cert_path {
let key_path = self.client_key_path.as_ref().ok_or_else(|| {
NetError::TlsError("Client key path missing for mTLS".to_string())
})?;
let cert_pem = std::fs::read(cert_path).map_err(|e| {
NetError::TlsError(format!(
"Failed to read client certificate file '{}': {}",
cert_path.display(),
e
))
})?;
let key_pem = std::fs::read(key_path).map_err(|e| {
NetError::TlsError(format!(
"Failed to read client key file '{}': {}",
key_path.display(),
e
))
})?;
let identity = Identity::from_pem(cert_pem, key_pem);
tls_config = tls_config.identity(identity);
}
if let Some(ref domain) = self.domain_name {
tls_config = tls_config.domain_name(domain.clone());
}
Ok(tls_config)
}
}
/// Decides which failed attempts are eligible for a retry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RetryPolicy {
    /// Never retry.
    Never,
    /// Retry on any error.
    OnError,
    /// Retry only when `NetError::is_retryable` returns `true` (the default).
    #[default]
    OnTransient,
}
impl RetryPolicy {
pub fn should_retry(&self, error: &NetError) -> bool {
match self {
RetryPolicy::Never => false,
RetryPolicy::OnError => true,
RetryPolicy::OnTransient => error.is_retryable(),
}
}
}
/// Retry and exponential-backoff settings.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retries after the first attempt.
    pub max_retries: usize,
    /// Backoff before the first retry; scaled by `backoff_multiplier` per retry.
    pub initial_backoff: Duration,
    /// Upper bound applied to every computed backoff.
    pub max_backoff: Duration,
    /// Growth factor applied to the backoff for each successive retry.
    pub backoff_multiplier: f64,
    /// Which errors are eligible for a retry.
    pub policy: RetryPolicy,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
initial_backoff: Duration::from_millis(100),
max_backoff: Duration::from_secs(10),
backoff_multiplier: 2.0,
policy: RetryPolicy::OnTransient,
}
}
}
impl RetryConfig {
    /// Computes how long to wait before retry number `attempt` (0-based).
    ///
    /// The delay is `initial_backoff * backoff_multiplier^attempt`, capped at
    /// `max_backoff`. It is then stretched by a jitter factor in `[0, 0.25)` and
    /// capped at `max_backoff` again.
    ///
    /// The arithmetic is done in nanoseconds, so sub-millisecond backoffs are not
    /// truncated to zero. The jitter factor is derived deterministically from
    /// `attempt` (fractional part of `attempt * 0.618033988`). Every caller
    /// therefore computes the same delay for the same attempt.
    pub fn backoff_duration(&self, attempt: usize) -> Duration {
        let initial_ns = self.initial_backoff.as_nanos() as f64;
        let max_ns = self.max_backoff.as_nanos() as f64;
        // Saturate instead of wrapping: `attempt as i32` would turn huge attempt
        // counts into negative exponents and therefore tiny backoffs.
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let base_ns = initial_ns * self.backoff_multiplier.powi(exponent);
        let capped_ns = base_ns.min(max_ns);
        let jitter_factor = ((attempt as f64 * 0.618033988) % 1.0) * 0.25;
        let final_ns = (capped_ns * (1.0 + jitter_factor)).min(max_ns);
        // Float-to-int `as` casts saturate: NaN/negative -> 0, overflow -> u64::MAX.
        Duration::from_nanos(final_ns as u64)
    }

    /// Returns a configuration that never retries.
    pub fn no_retry() -> Self {
        Self {
            max_retries: 0,
            policy: RetryPolicy::Never,
            ..Default::default()
        }
    }
}
/// Top-level configuration for [`AqlClient`].
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Connection timeout. `AqlClientBuilder::connect_timeout` also forwards this
    /// value to the pool builder.
    pub connect_timeout: Duration,
    /// Per-request timeout setting.
    /// NOTE(review): confirm where this value is enforced.
    pub request_timeout: Duration,
    /// Whether keep-alive is enabled.
    /// NOTE(review): the builder in this file does not forward this to the pool.
    pub keep_alive: bool,
    /// Keep-alive interval (same caveat as `keep_alive`).
    pub keep_alive_interval: Duration,
    /// Connection-pool settings; `AqlClient::with_config` passes these to the pool.
    pub pool: PoolConfig,
    /// Retry/backoff settings.
    pub retry: RetryConfig,
    /// gRPC message compression settings.
    pub compression: CompressionConfig,
    /// Optional TLS settings; `None` creates the pool without TLS.
    pub tls: Option<TlsClientConfig>,
}
impl Default for ClientConfig {
fn default() -> Self {
Self {
connect_timeout: Duration::from_secs(10),
request_timeout: Duration::from_secs(30),
keep_alive: true,
keep_alive_interval: Duration::from_secs(60),
pool: PoolConfig::default(),
retry: RetryConfig::default(),
compression: CompressionConfig::default(),
tls: None,
}
}
}
/// gRPC client for the AQL service, backed by a shared [`ConnectionPool`].
pub struct AqlClient {
    /// Connection pool that supplies channels to the configured endpoints.
    pool: Arc<ConnectionPool>,
    /// The configuration this client was created with.
    config: ClientConfig,
    /// Client-side breaker checked by `execute_with_retry`; `None` when disabled.
    circuit_breaker: Option<CircuitBreaker>,
}
impl AqlClient {
    /// Creates a client using [`ClientConfig::default()`].
    ///
    /// # Panics
    ///
    /// Panics if [`AqlClient::with_config`] rejects the default configuration.
    /// `with_config` only fails while building TLS settings, and the default
    /// configuration has none, so this is treated as an invariant.
    pub fn new() -> Self {
        Self::with_config(ClientConfig::default())
            .expect("Default ClientConfig should always be valid")
    }

    /// Creates a client from an explicit [`ClientConfig`].
    ///
    /// A client-side [`CircuitBreaker`] is created when
    /// `config.pool.enable_circuit_breaker` is `true`. When `config.tls` is set,
    /// its tonic TLS settings are built and handed to the connection pool.
    ///
    /// # Errors
    ///
    /// Returns the error from [`TlsClientConfig::build_tonic_tls_config`] when the
    /// TLS settings are inconsistent or a PEM file cannot be read.
    pub fn with_config(config: ClientConfig) -> NetResult<Self> {
        let cb = if config.pool.enable_circuit_breaker {
            Some(CircuitBreaker::new())
        } else {
            None
        };
        let pool = if let Some(ref tls_cfg) = config.tls {
            let tonic_tls = tls_cfg.build_tonic_tls_config()?;
            ConnectionPool::with_tls(config.pool.clone(), tonic_tls)
        } else {
            ConnectionPool::new(config.pool.clone())
        };
        Ok(Self {
            pool: Arc::new(pool),
            config,
            circuit_breaker: cb,
        })
    }

    /// Returns a builder for configuring a client step by step.
    pub fn builder() -> AqlClientBuilder {
        AqlClientBuilder::new()
    }

    /// Registers an endpoint with the connection pool.
    pub fn add_endpoint(&self, id: String, address: String) {
        self.pool.add_endpoint(id, address);
    }

    /// Registers an endpoint together with a `weight`, forwarded to the pool.
    pub fn add_endpoint_with_weight(&self, id: String, address: String, weight: u32) {
        self.pool.add_endpoint_with_weight(id, address, weight);
    }

    /// Removes an endpoint from the pool, returning whether one was removed.
    pub fn remove_endpoint(&self, endpoint_id: &str) -> bool {
        self.pool.remove_endpoint(endpoint_id)
    }

    /// Checks out a pooled connection and wraps its channel in a service client.
    ///
    /// With the `compression` feature enabled, an encoding resolved from
    /// [`CompressionConfig`] is used both for sending and accepting messages.
    /// Without the feature, the compression settings are not applied.
    ///
    /// # Errors
    ///
    /// Propagates any error from the pool's `get_connection`.
    pub async fn get_service_client(&self) -> NetResult<AqlServiceClient<Channel>> {
        let conn = self.pool.get_connection().await?;
        let client = AqlServiceClient::new(conn.channel().clone());
        // Shadowing (instead of `let mut`) keeps builds without the `compression`
        // feature free of an `unused_mut` warning.
        #[cfg(feature = "compression")]
        let client = match self.config.compression.to_tonic_encoding() {
            Some(encoding) => client.send_compressed(encoding).accept_compressed(encoding),
            None => client,
        };
        Ok(client)
    }

    /// Executes a query, retrying per [`AqlClient::execute_with_retry`].
    ///
    /// The request is cloned for every attempt.
    pub async fn execute_query(&self, request: QueryRequest) -> NetResult<QueryResponse> {
        self.execute_with_retry(|mut client| {
            let req = request.clone();
            Box::pin(async move {
                client
                    .execute_query(req)
                    .await
                    .map(|resp| resp.into_inner())
                    .map_err(NetError::from)
            })
        })
        .await
    }

    /// Executes a batch request, retrying per [`AqlClient::execute_with_retry`].
    ///
    /// The request is cloned for every attempt.
    pub async fn execute_batch(&self, request: BatchRequest) -> NetResult<BatchResponse> {
        self.execute_with_retry(|mut client| {
            let req = request.clone();
            Box::pin(async move {
                client
                    .execute_batch(req)
                    .await
                    .map(|resp| resp.into_inner())
                    .map_err(NetError::from)
            })
        })
        .await
    }

    /// Sends a health-check request for `service`, retrying per
    /// [`AqlClient::execute_with_retry`].
    pub async fn health_check(&self, service: Option<String>) -> NetResult<HealthCheckResponse> {
        self.execute_with_retry(|mut client| {
            let svc = service.clone();
            Box::pin(async move {
                let request = HealthCheckRequest { service: svc };
                client
                    .health_check(request)
                    .await
                    .map(|resp| resp.into_inner())
                    .map_err(NetError::from)
            })
        })
        .await
    }

    /// Runs `operation` against a freshly checked-out service client and retries
    /// failed attempts according to [`RetryConfig`].
    ///
    /// Each attempt does the following:
    /// 1. Asks the client-side circuit breaker (if any) for permission, and
    ///    returns its error immediately when the request is rejected.
    /// 2. Obtains a client via [`AqlClient::get_service_client`].
    /// 3. Runs `operation`, bounded by `ClientConfig::request_timeout`.
    ///
    /// Every success or failure (including failing to obtain a client) is
    /// recorded on the circuit breaker. A failed attempt is retried after
    /// sleeping for [`RetryConfig::backoff_duration`], but only while fewer than
    /// `max_retries` retries have been made and the policy accepts the error.
    ///
    /// # Errors
    ///
    /// Returns the circuit-breaker rejection or the error of the final attempt.
    /// An attempt that exceeds `request_timeout` fails with `NetError::Timeout`.
    pub async fn execute_with_retry<F, T>(&self, operation: F) -> NetResult<T>
    where
        F: Fn(
            AqlServiceClient<Channel>,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<T>> + Send>>,
        T: Send + 'static,
    {
        let retry_config = &self.config.retry;
        let request_timeout = self.config.request_timeout;
        let mut attempt: usize = 0;
        loop {
            if let Some(ref cb) = self.circuit_breaker {
                cb.is_request_allowed()?;
            }
            // Failing to obtain a client is treated exactly like a failed RPC.
            let outcome = match self.get_service_client().await {
                Ok(client) => {
                    match tokio::time::timeout(request_timeout, operation(client)).await {
                        Ok(result) => result,
                        Err(_elapsed) => Err(NetError::Timeout(format!(
                            "Request did not complete within {request_timeout:?}"
                        ))),
                    }
                }
                Err(e) => Err(e),
            };
            match outcome {
                Ok(value) => {
                    if let Some(ref cb) = self.circuit_breaker {
                        cb.record_success();
                    }
                    return Ok(value);
                }
                Err(e) => {
                    if let Some(ref cb) = self.circuit_breaker {
                        cb.record_failure();
                    }
                    let retries_left = attempt < retry_config.max_retries;
                    if !(retries_left && retry_config.policy.should_retry(&e)) {
                        return Err(e);
                    }
                    tokio::time::sleep(retry_config.backoff_duration(attempt)).await;
                    attempt += 1;
                }
            }
        }
    }

    /// Returns the connection pool's statistics.
    pub fn pool_stats(&self) -> PoolStats {
        self.pool.stats()
    }

    /// Returns circuit-breaker statistics as reported by the connection pool.
    ///
    /// NOTE(review): these come from `self.pool`, not from the client-side breaker
    /// that `execute_with_retry` consults. Confirm this is intended.
    pub fn circuit_breaker_stats(&self) -> Option<crate::circuit_breaker::CircuitBreakerStats> {
        self.pool.circuit_breaker_stats()
    }

    /// Returns the retry configuration in use.
    pub fn retry_config(&self) -> &RetryConfig {
        &self.config.retry
    }

    /// Returns the compression configuration in use.
    pub fn compression_config(&self) -> &CompressionConfig {
        &self.config.compression
    }

    /// Returns the TLS configuration, if one was provided.
    pub fn tls_config(&self) -> Option<&TlsClientConfig> {
        self.config.tls.as_ref()
    }

    /// Delegates to the connection pool's `drain`.
    pub async fn drain(&self) -> NetResult<()> {
        self.pool.drain().await
    }

    /// Consumes the client and shuts down its connection pool.
    ///
    /// # Errors
    ///
    /// Returns `NetError::ServerInternal` if the pool is still shared with other
    /// owners. Otherwise returns the result of the pool's `shutdown`.
    pub async fn shutdown(self) -> NetResult<()> {
        Arc::try_unwrap(self.pool)
            .map_err(|_| {
                NetError::ServerInternal("Cannot shutdown: pool still has references".to_string())
            })?
            .shutdown()
            .await
    }
}
impl Default for AqlClient {
fn default() -> Self {
Self::new()
}
}
/// Fluent builder for [`AqlClient`].
pub struct AqlClientBuilder {
    /// Client configuration accumulated by the setters.
    config: ClientConfig,
    /// Pool builder kept in step with the pool-related setters.
    pool_builder: ConnectionPoolBuilder,
    /// Breaker created by `circuit_breaker(true)`; `build` creates one if needed.
    circuit_breaker: Option<CircuitBreaker>,
    /// TLS settings accumulated by the TLS-related setters.
    tls_client_config: Option<TlsClientConfig>,
}
impl AqlClientBuilder {
pub fn new() -> Self {
Self {
config: ClientConfig::default(),
pool_builder: ConnectionPoolBuilder::new(),
circuit_breaker: None,
tls_client_config: None,
}
}
pub fn connect_timeout(mut self, timeout: Duration) -> Self {
self.config.connect_timeout = timeout;
self.pool_builder = self.pool_builder.connect_timeout(timeout);
self
}
pub fn request_timeout(mut self, timeout: Duration) -> Self {
self.config.request_timeout = timeout;
self
}
pub fn keep_alive(mut self, enabled: bool) -> Self {
self.config.keep_alive = enabled;
self
}
pub fn keep_alive_interval(mut self, interval: Duration) -> Self {
self.config.keep_alive_interval = interval;
self
}
pub fn min_pool_size(mut self, size: usize) -> Self {
self.config.pool.min_size = size;
self.pool_builder = self.pool_builder.min_size(size);
self
}
pub fn max_pool_size(mut self, size: usize) -> Self {
self.config.pool.max_size = size;
self.pool_builder = self.pool_builder.max_size(size);
self
}
pub fn idle_timeout(mut self, timeout: Duration) -> Self {
self.config.pool.idle_timeout = timeout;
self.pool_builder = self.pool_builder.idle_timeout(timeout);
self
}
pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
self.config.pool.max_lifetime = lifetime;
self.pool_builder = self.pool_builder.max_lifetime(lifetime);
self
}
pub fn health_check_interval(mut self, interval: Duration) -> Self {
self.config.pool.health_check_interval = interval;
self.pool_builder = self.pool_builder.health_check_interval(interval);
self
}
pub fn balancing_strategy(mut self, strategy: BalancingStrategy) -> Self {
self.config.pool.balancing_strategy = strategy;
self.pool_builder = self.pool_builder.balancing_strategy(strategy);
self
}
pub fn circuit_breaker(mut self, enabled: bool) -> Self {
self.config.pool.enable_circuit_breaker = enabled;
self.pool_builder = self.pool_builder.circuit_breaker(enabled);
if enabled {
self.circuit_breaker = Some(CircuitBreaker::new());
} else {
self.circuit_breaker = None;
}
self
}
pub fn with_retry(mut self, retry_config: RetryConfig) -> Self {
self.config.retry = retry_config;
self
}
pub fn with_compression(mut self, compression_config: CompressionConfig) -> Self {
self.config.compression = compression_config;
self
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.config.request_timeout = timeout;
self
}
pub fn with_tls_config(mut self, tls_config: TlsClientConfig) -> Self {
self.tls_client_config = Some(tls_config);
self
}
pub fn with_ca_cert(mut self, ca_cert_path: impl Into<PathBuf>) -> Self {
let config = self
.tls_client_config
.take()
.unwrap_or_default()
.with_ca_cert(ca_cert_path);
self.tls_client_config = Some(config);
self
}
pub fn with_mtls(
mut self,
cert_path: impl Into<PathBuf>,
key_path: impl Into<PathBuf>,
) -> Self {
let config = self
.tls_client_config
.take()
.unwrap_or_default()
.with_client_identity(cert_path, key_path);
self.tls_client_config = Some(config);
self
}
pub fn with_tls_domain(mut self, domain: impl Into<String>) -> Self {
let config = self
.tls_client_config
.take()
.unwrap_or_default()
.with_domain_name(domain);
self.tls_client_config = Some(config);
self
}
pub fn with_tls_skip_verification(mut self) -> Self {
let config = self
.tls_client_config
.take()
.unwrap_or_default()
.with_skip_verification(true);
self.tls_client_config = Some(config);
self
}
pub fn add_endpoint(mut self, id: String, address: String) -> Self {
self.pool_builder = self.pool_builder.add_endpoint(id, address);
self
}
pub fn add_endpoint_with_weight(mut self, id: String, address: String, weight: u32) -> Self {
self.pool_builder = self
.pool_builder
.add_endpoint_with_weight(id, address, weight);
self
}
pub fn build(self) -> NetResult<AqlClient> {
let pool_builder = if let Some(ref tls_cfg) = self.tls_client_config {
let tonic_tls = tls_cfg.build_tonic_tls_config()?;
self.pool_builder.tls_config(tonic_tls)
} else {
self.pool_builder
};
let pool = pool_builder.build();
let cb = if self.config.pool.enable_circuit_breaker {
self.circuit_breaker.or_else(|| Some(CircuitBreaker::new()))
} else {
None
};
let mut config = self.config;
config.tls = self.tls_client_config;
Ok(AqlClient {
pool: Arc::new(pool),
config,
circuit_breaker: cb,
})
}
}
impl Default for AqlClientBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    // Unit tests for the client configuration types, builder and TLS helpers.
    // None of these tests contact a live server.
    use super::*;

    #[test]
    fn test_client_config_default() {
        let config = ClientConfig::default();
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.request_timeout, Duration::from_secs(30));
        assert!(config.keep_alive);
        assert_eq!(config.retry.max_retries, 3);
        assert_eq!(config.retry.initial_backoff, Duration::from_millis(100));
        assert_eq!(config.retry.max_backoff, Duration::from_secs(10));
        assert!((config.retry.backoff_multiplier - 2.0).abs() < f64::EPSILON);
        assert!(!config.compression.enabled);
        assert!(config.tls.is_none());
    }

    #[tokio::test]
    async fn test_client_creation() {
        let config = ClientConfig::default();
        let _client = AqlClient::with_config(config).expect("default config should be valid");
    }

    #[tokio::test]
    async fn test_with_config_fails_for_unreadable_ca_cert() {
        let missing_ca = std::env::temp_dir().join("nonexistent_with_config_ca_amaters.pem");
        let config = ClientConfig {
            tls: Some(TlsClientConfig::new().with_ca_cert(missing_ca)),
            ..Default::default()
        };
        // `AqlClient` is not `Debug`, so match instead of `expect_err`.
        match AqlClient::with_config(config) {
            Ok(_) => panic!("with_config should fail when the CA certificate cannot be read"),
            Err(err) => assert!(
                err.to_string().contains("Failed to read CA certificate"),
                "Error should mention CA cert read failure: {err}"
            ),
        }
    }

    #[tokio::test]
    async fn test_client_builder() {
        let client = AqlClient::builder()
            .connect_timeout(Duration::from_secs(5))
            .request_timeout(Duration::from_secs(15))
            .min_pool_size(3)
            .max_pool_size(15)
            .balancing_strategy(BalancingStrategy::RoundRobin)
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string())
            .add_endpoint("ep2".to_string(), "localhost:50052".to_string())
            .build()
            .expect("builder should succeed without TLS");
        let stats = client.pool_stats();
        assert_eq!(stats.active_connections, 0);
    }

    #[tokio::test]
    async fn test_client_add_remove_endpoint() {
        let client = AqlClient::new();
        client.add_endpoint("ep1".to_string(), "localhost:50051".to_string());
        client.add_endpoint("ep2".to_string(), "localhost:50052".to_string());
        assert!(client.remove_endpoint("ep1"));
        assert!(!client.remove_endpoint("ep3"));
    }

    #[tokio::test]
    async fn test_client_pool_stats() {
        let client = AqlClient::builder()
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string())
            .build()
            .expect("builder should succeed");
        let stats = client.pool_stats();
        assert_eq!(stats.total_connections, 0);
    }

    #[tokio::test]
    async fn test_client_drain() {
        let client = AqlClient::builder()
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string())
            .build()
            .expect("builder should succeed");
        let result = client.drain().await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_retry_config_defaults() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_backoff, Duration::from_millis(100));
        assert_eq!(config.max_backoff, Duration::from_secs(10));
        assert!((config.backoff_multiplier - 2.0).abs() < f64::EPSILON);
        assert_eq!(config.policy, RetryPolicy::OnTransient);
    }

    #[test]
    fn test_retry_config_no_retry() {
        let config = RetryConfig::no_retry();
        assert_eq!(config.max_retries, 0);
        assert_eq!(config.policy, RetryPolicy::Never);
    }

    #[test]
    fn test_retry_config_custom() {
        let config = RetryConfig {
            max_retries: 5,
            initial_backoff: Duration::from_millis(200),
            max_backoff: Duration::from_secs(30),
            backoff_multiplier: 3.0,
            policy: RetryPolicy::OnError,
        };
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_backoff, Duration::from_millis(200));
        assert_eq!(config.max_backoff, Duration::from_secs(30));
        assert!((config.backoff_multiplier - 3.0).abs() < f64::EPSILON);
        assert_eq!(config.policy, RetryPolicy::OnError);
    }

    #[test]
    fn test_backoff_duration_exponential_growth() {
        let config = RetryConfig {
            initial_backoff: Duration::from_millis(100),
            backoff_multiplier: 2.0,
            max_backoff: Duration::from_secs(60),
            ..Default::default()
        };
        let d0 = config.backoff_duration(0);
        let d1 = config.backoff_duration(1);
        let d2 = config.backoff_duration(2);
        assert!(d0.as_millis() >= 100, "d0 should be >= 100ms, got {d0:?}");
        assert!(
            d0.as_millis() <= 130,
            "d0 should be <= 130ms (jitter), got {d0:?}"
        );
        assert!(d1.as_millis() >= 200, "d1 should be >= 200ms, got {d1:?}");
        assert!(
            d1.as_millis() <= 260,
            "d1 should be <= 260ms (jitter), got {d1:?}"
        );
        assert!(d2.as_millis() >= 400, "d2 should be >= 400ms, got {d2:?}");
        assert!(
            d2.as_millis() <= 520,
            "d2 should be <= 520ms (jitter), got {d2:?}"
        );
    }

    #[test]
    fn test_backoff_duration_capped_at_max() {
        let config = RetryConfig {
            initial_backoff: Duration::from_secs(1),
            backoff_multiplier: 10.0,
            max_backoff: Duration::from_secs(5),
            ..Default::default()
        };
        let d0 = config.backoff_duration(0);
        assert!(d0.as_millis() >= 1000);
        assert!(d0.as_millis() <= 1300);
        let d2 = config.backoff_duration(2);
        assert!(
            d2.as_millis() <= 5000,
            "Should be capped at max_backoff, got {d2:?}"
        );
    }

    #[test]
    fn test_backoff_duration_with_multiplier_one() {
        let config = RetryConfig {
            initial_backoff: Duration::from_millis(500),
            backoff_multiplier: 1.0,
            max_backoff: Duration::from_secs(60),
            ..Default::default()
        };
        let d0 = config.backoff_duration(0);
        let d1 = config.backoff_duration(1);
        let d2 = config.backoff_duration(2);
        assert!(d0.as_millis() >= 500 && d0.as_millis() <= 650);
        assert!(d1.as_millis() >= 500 && d1.as_millis() <= 650);
        assert!(d2.as_millis() >= 500 && d2.as_millis() <= 650);
    }

    #[test]
    fn test_retry_policy_never() {
        let policy = RetryPolicy::Never;
        assert!(!policy.should_retry(&NetError::Timeout("test".to_string())));
        assert!(!policy.should_retry(&NetError::ServerUnavailable("test".to_string())));
        assert!(!policy.should_retry(&NetError::InvalidRequest("test".to_string())));
    }

    #[test]
    fn test_retry_policy_on_error() {
        let policy = RetryPolicy::OnError;
        assert!(policy.should_retry(&NetError::Timeout("test".to_string())));
        assert!(policy.should_retry(&NetError::ServerUnavailable("test".to_string())));
        assert!(policy.should_retry(&NetError::InvalidRequest("test".to_string())));
        assert!(policy.should_retry(&NetError::AuthFailed("test".to_string())));
    }

    #[test]
    fn test_retry_policy_on_transient() {
        let policy = RetryPolicy::OnTransient;
        assert!(policy.should_retry(&NetError::Timeout("test".to_string())));
        assert!(policy.should_retry(&NetError::ConnectionRefused("test".to_string())));
        assert!(policy.should_retry(&NetError::ConnectionReset("test".to_string())));
        assert!(policy.should_retry(&NetError::ServerUnavailable("test".to_string())));
        assert!(policy.should_retry(&NetError::ServerOverloaded("test".to_string())));
        assert!(!policy.should_retry(&NetError::InvalidRequest("test".to_string())));
        assert!(!policy.should_retry(&NetError::AuthFailed("test".to_string())));
        assert!(!policy.should_retry(&NetError::InsufficientPermissions("test".to_string())));
        assert!(!policy.should_retry(&NetError::MalformedMessage("test".to_string())));
        assert!(!policy.should_retry(&NetError::ServerInternal("test".to_string())));
    }

    #[test]
    fn test_retry_policy_default_is_on_transient() {
        let policy = RetryPolicy::default();
        assert_eq!(policy, RetryPolicy::OnTransient);
    }

    #[test]
    fn test_compression_config_default() {
        let config = CompressionConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.algorithm, CompressionAlgorithm::Identity);
        assert!(config.to_tonic_encoding().is_none());
    }

    #[test]
    fn test_compression_config_gzip() {
        let config = CompressionConfig::gzip();
        assert!(config.enabled);
        assert_eq!(config.algorithm, CompressionAlgorithm::Gzip);
        assert!(config.to_tonic_encoding().is_some());
    }

    #[test]
    fn test_compression_identity_returns_none() {
        let config = CompressionConfig {
            enabled: true,
            algorithm: CompressionAlgorithm::Identity,
        };
        assert!(config.to_tonic_encoding().is_none());
    }

    #[test]
    fn test_compression_disabled_returns_none() {
        let config = CompressionConfig {
            enabled: false,
            algorithm: CompressionAlgorithm::Gzip,
        };
        assert!(config.to_tonic_encoding().is_none());
    }

    #[tokio::test]
    async fn test_builder_with_retry() {
        let retry = RetryConfig {
            max_retries: 5,
            initial_backoff: Duration::from_millis(50),
            max_backoff: Duration::from_secs(5),
            backoff_multiplier: 1.5,
            policy: RetryPolicy::OnError,
        };
        let client = AqlClient::builder()
            .with_retry(retry)
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string())
            .build()
            .expect("builder should succeed");
        assert_eq!(client.retry_config().max_retries, 5);
        assert_eq!(
            client.retry_config().initial_backoff,
            Duration::from_millis(50)
        );
        assert_eq!(client.retry_config().policy, RetryPolicy::OnError);
    }

    #[tokio::test]
    async fn test_builder_with_compression() {
        let client = AqlClient::builder()
            .with_compression(CompressionConfig::gzip())
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string())
            .build()
            .expect("builder should succeed");
        assert!(client.compression_config().enabled);
        assert_eq!(
            client.compression_config().algorithm,
            CompressionAlgorithm::Gzip
        );
    }

    #[tokio::test]
    async fn test_builder_with_timeout() {
        let client = AqlClient::builder()
            .with_timeout(Duration::from_secs(60))
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string())
            .build()
            .expect("builder should succeed");
        assert_eq!(client.config.request_timeout, Duration::from_secs(60));
    }

    #[tokio::test]
    async fn test_builder_full_chain() {
        let client = AqlClient::builder()
            .connect_timeout(Duration::from_secs(5))
            .request_timeout(Duration::from_secs(15))
            .keep_alive(true)
            .keep_alive_interval(Duration::from_secs(30))
            .min_pool_size(2)
            .max_pool_size(20)
            .idle_timeout(Duration::from_secs(120))
            .max_lifetime(Duration::from_secs(600))
            .health_check_interval(Duration::from_secs(10))
            .balancing_strategy(BalancingStrategy::RoundRobin)
            .circuit_breaker(true)
            .with_retry(RetryConfig {
                max_retries: 5,
                policy: RetryPolicy::OnTransient,
                ..Default::default()
            })
            .with_compression(CompressionConfig::gzip())
            .with_timeout(Duration::from_secs(20))
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string())
            .add_endpoint_with_weight("ep2".to_string(), "localhost:50052".to_string(), 3)
            .build()
            .expect("builder should succeed");
        assert_eq!(client.config.connect_timeout, Duration::from_secs(5));
        assert_eq!(client.config.request_timeout, Duration::from_secs(20));
        assert!(client.config.keep_alive);
        assert_eq!(client.retry_config().max_retries, 5);
        assert!(client.compression_config().enabled);
    }

    // Exercises `CircuitBreaker` directly (not through `AqlClient`): once the
    // failure threshold is reached the breaker opens and rejects requests.
    #[tokio::test]
    async fn test_circuit_breaker_blocks_retries() {
        use crate::circuit_breaker::CircuitBreakerConfig;
        let cb_config = CircuitBreakerConfig {
            failure_threshold: 2,
            ..Default::default()
        };
        let cb = CircuitBreaker::with_config(cb_config);
        cb.is_request_allowed().ok();
        cb.record_failure();
        cb.is_request_allowed().ok();
        cb.record_failure();
        assert_eq!(cb.state(), crate::circuit_breaker::CircuitState::Open);
        assert!(cb.is_request_allowed().is_err());
    }

    #[tokio::test]
    async fn test_circuit_breaker_enabled_in_builder() {
        let client = AqlClient::builder()
            .circuit_breaker(true)
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string())
            .build()
            .expect("builder should succeed");
        assert!(client.circuit_breaker.is_some());
    }

    #[tokio::test]
    async fn test_circuit_breaker_disabled_in_builder() {
        let client = AqlClient::builder()
            .circuit_breaker(false)
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string())
            .build()
            .expect("builder should succeed");
        assert!(client.circuit_breaker.is_none());
    }

    #[tokio::test]
    async fn test_default_client_has_circuit_breaker() {
        let client = AqlClient::new();
        assert!(client.circuit_breaker.is_some());
    }

    #[test]
    fn test_compression_algorithm_default() {
        let algo = CompressionAlgorithm::default();
        assert_eq!(algo, CompressionAlgorithm::Identity);
    }

    #[test]
    fn test_compression_algorithm_variants() {
        assert_ne!(CompressionAlgorithm::Gzip, CompressionAlgorithm::Identity);
    }

    #[tokio::test]
    async fn test_builder_default() {
        let builder = AqlClientBuilder::default();
        let client = builder.build().expect("default builder should succeed");
        assert_eq!(client.config.connect_timeout, Duration::from_secs(10));
        assert_eq!(client.config.request_timeout, Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_client_default() {
        let client = AqlClient::default();
        assert_eq!(client.config.connect_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_tls_config_default() {
        let config = TlsClientConfig::default();
        assert!(config.ca_cert_path.is_none());
        assert!(config.client_cert_path.is_none());
        assert!(config.client_key_path.is_none());
        assert!(config.domain_name.is_none());
        assert!(!config.skip_verification);
    }

    #[test]
    fn test_tls_config_with_ca_cert() {
        let config = TlsClientConfig::new().with_ca_cert("/tmp/ca.pem");
        assert_eq!(config.ca_cert_path, Some(PathBuf::from("/tmp/ca.pem")));
        assert!(config.client_cert_path.is_none());
        assert!(config.client_key_path.is_none());
    }

    #[test]
    fn test_tls_config_with_mtls() {
        let config =
            TlsClientConfig::new().with_client_identity("/tmp/client.pem", "/tmp/client.key");
        assert!(config.ca_cert_path.is_none());
        assert_eq!(
            config.client_cert_path,
            Some(PathBuf::from("/tmp/client.pem"))
        );
        assert_eq!(
            config.client_key_path,
            Some(PathBuf::from("/tmp/client.key"))
        );
    }

    #[test]
    fn test_tls_config_missing_key() {
        let config = TlsClientConfig {
            client_cert_path: Some(PathBuf::from("/tmp/client.pem")),
            client_key_path: None,
            ..Default::default()
        };
        let err = config.validate().expect_err("should be TlsError");
        assert!(
            err.to_string().contains("without client key"),
            "Error should mention missing key: {}",
            err
        );
    }

    #[test]
    fn test_tls_config_missing_cert() {
        let config = TlsClientConfig {
            client_cert_path: None,
            client_key_path: Some(PathBuf::from("/tmp/client.key")),
            ..Default::default()
        };
        let err = config.validate().expect_err("should be TlsError");
        assert!(
            err.to_string().contains("without client certificate"),
            "Error should mention missing cert: {}",
            err
        );
    }

    #[tokio::test]
    async fn test_builder_with_ca_cert() {
        let builder = AqlClient::builder()
            .with_ca_cert("/tmp/test-ca.pem")
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string());
        let tls_cfg = builder
            .tls_client_config
            .as_ref()
            .expect("TLS config should be set");
        assert_eq!(
            tls_cfg.ca_cert_path,
            Some(PathBuf::from("/tmp/test-ca.pem"))
        );
    }

    #[tokio::test]
    async fn test_builder_with_mtls() {
        let builder = AqlClient::builder()
            .with_mtls("/tmp/client.pem", "/tmp/client.key")
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string());
        let tls_cfg = builder
            .tls_client_config
            .as_ref()
            .expect("TLS config should be set");
        assert_eq!(
            tls_cfg.client_cert_path,
            Some(PathBuf::from("/tmp/client.pem"))
        );
        assert_eq!(
            tls_cfg.client_key_path,
            Some(PathBuf::from("/tmp/client.key"))
        );
    }

    #[tokio::test]
    async fn test_builder_build_fails_for_unreadable_client_cert() {
        let tmp_dir = std::env::temp_dir();
        let result = AqlClient::builder()
            .with_mtls(
                tmp_dir.join("nonexistent_builder_client_cert_amaters.pem"),
                tmp_dir.join("nonexistent_builder_client_key_amaters.key"),
            )
            .build();
        // `AqlClient` is not `Debug`, so match instead of `expect_err`.
        match result {
            Ok(_) => panic!("build should fail when the client certificate cannot be read"),
            Err(err) => assert!(
                err.to_string().contains("Failed to read client certificate"),
                "Error should mention client cert read failure: {err}"
            ),
        }
    }

    #[tokio::test]
    async fn test_builder_with_full_config() {
        let tls_config = TlsClientConfig::new()
            .with_ca_cert("/tmp/ca.pem")
            .with_client_identity("/tmp/client.pem", "/tmp/client.key")
            .with_domain_name("example.com");
        let builder = AqlClient::builder()
            .with_tls_config(tls_config)
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string());
        let tls_cfg = builder
            .tls_client_config
            .as_ref()
            .expect("TLS config should be set");
        assert_eq!(tls_cfg.ca_cert_path, Some(PathBuf::from("/tmp/ca.pem")));
        assert_eq!(
            tls_cfg.client_cert_path,
            Some(PathBuf::from("/tmp/client.pem"))
        );
        assert_eq!(
            tls_cfg.client_key_path,
            Some(PathBuf::from("/tmp/client.key"))
        );
        assert_eq!(tls_cfg.domain_name, Some("example.com".to_string()));
    }

    #[test]
    fn test_tls_config_domain_name() {
        let config = TlsClientConfig::new().with_domain_name("my.server.com");
        assert_eq!(config.domain_name, Some("my.server.com".to_string()));
    }

    #[test]
    fn test_tls_config_skip_verification() {
        let config = TlsClientConfig::new().with_skip_verification(true);
        assert!(config.skip_verification);
        let config2 = TlsClientConfig::new().with_skip_verification(false);
        assert!(!config2.skip_verification);
    }

    #[test]
    fn test_tls_integration_invalid_cert() {
        let tmp_dir = std::env::temp_dir();
        let fake_cert_path = tmp_dir.join("nonexistent_test_cert_amaters.pem");
        let config = TlsClientConfig::new().with_ca_cert(&fake_cert_path);
        assert!(config.validate().is_ok());
        let err = config
            .build_tonic_tls_config()
            .expect_err("should fail for missing file");
        assert!(
            err.to_string().contains("Failed to read CA certificate"),
            "Error should mention CA cert read failure: {}",
            err
        );
    }

    #[test]
    fn test_tls_config_validate_valid_configs() {
        assert!(TlsClientConfig::default().validate().is_ok());
        assert!(
            TlsClientConfig::new()
                .with_ca_cert("/tmp/ca.pem")
                .validate()
                .is_ok()
        );
        assert!(
            TlsClientConfig::new()
                .with_ca_cert("/tmp/ca.pem")
                .with_client_identity("/tmp/client.pem", "/tmp/client.key")
                .validate()
                .is_ok()
        );
    }

    #[tokio::test]
    async fn test_builder_with_tls_domain() {
        let builder = AqlClient::builder()
            .with_tls_domain("custom.domain.io")
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string());
        let tls_cfg = builder
            .tls_client_config
            .as_ref()
            .expect("TLS config should be set");
        assert_eq!(tls_cfg.domain_name, Some("custom.domain.io".to_string()));
    }

    #[tokio::test]
    async fn test_builder_with_tls_skip_verification() {
        let builder = AqlClient::builder()
            .with_tls_skip_verification()
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string());
        let tls_cfg = builder
            .tls_client_config
            .as_ref()
            .expect("TLS config should be set");
        assert!(tls_cfg.skip_verification);
    }

    #[test]
    fn test_tls_config_chaining() {
        let config = TlsClientConfig::new()
            .with_ca_cert("/tmp/ca.pem")
            .with_client_identity("/tmp/cert.pem", "/tmp/key.pem")
            .with_domain_name("example.com")
            .with_skip_verification(false);
        assert_eq!(config.ca_cert_path, Some(PathBuf::from("/tmp/ca.pem")));
        assert_eq!(
            config.client_cert_path,
            Some(PathBuf::from("/tmp/cert.pem"))
        );
        assert_eq!(config.client_key_path, Some(PathBuf::from("/tmp/key.pem")));
        assert_eq!(config.domain_name, Some("example.com".to_string()));
        assert!(!config.skip_verification);
        assert!(config.validate().is_ok());
    }
}