// rivven_client/resilient.rs
1//! Production-grade resilient client with connection pooling, retries, and circuit breaker
2//!
3//! # Features
4//!
5//! - **Connection pooling**: Efficiently reuse connections across requests
6//! - **Automatic retries**: Exponential backoff with jitter for transient failures
7//! - **Circuit breaker**: Prevent cascading failures to unhealthy servers
8//! - **Health checking**: Background health monitoring and connection validation
9//! - **Timeouts**: Request and connection timeouts with configurable values
10//!
11//! # Example
12//!
13//! ```rust,ignore
14//! use rivven_client::{ResilientClient, ResilientClientConfig};
15//!
16//! let config = ResilientClientConfig::builder()
17//!     .bootstrap_servers(vec!["localhost:9092".to_string()])
18//!     .pool_size(10)
19//!     .retry_max_attempts(3)
20//!     .circuit_breaker_threshold(5)
21//!     .build();
22//!
23//! let client = ResilientClient::new(config).await?;
24//!
25//! // Auto-retry on transient failures
26//! let offset = client.publish("my-topic", b"hello").await?;
27//! ```
28
29use crate::{Client, Error, MessageData, Result};
30use bytes::Bytes;
31use std::collections::HashMap;
32use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
33use std::sync::Arc;
34use std::time::{Duration, Instant};
35use tokio::sync::{Mutex, RwLock, Semaphore};
36use tokio::time::{sleep, timeout};
37use tracing::{debug, info, warn};
38
39// ============================================================================
40// Configuration
41// ============================================================================
42
/// Configuration for the resilient client
///
/// Defaults (see the `Default` impl) are tuned for local development; use
/// [`ResilientClientConfig::builder`] to customize individual settings.
#[derive(Debug, Clone)]
pub struct ResilientClientConfig {
    /// Bootstrap servers (host:port). Default: `["localhost:9092"]`.
    pub bootstrap_servers: Vec<String>,
    /// Connection pool size per server. Default: 5.
    pub pool_size: usize,
    /// Maximum retry attempts per operation. Default: 3.
    pub retry_max_attempts: u32,
    /// Initial retry delay (grows by `retry_multiplier` each attempt). Default: 100ms.
    pub retry_initial_delay: Duration,
    /// Maximum retry delay (caps the exponential backoff). Default: 10s.
    pub retry_max_delay: Duration,
    /// Retry backoff multiplier. Default: 2.0.
    pub retry_multiplier: f64,
    /// Consecutive failures before the circuit breaker opens. Default: 5.
    pub circuit_breaker_threshold: u32,
    /// How long an open circuit waits before allowing trial requests. Default: 30s.
    pub circuit_breaker_timeout: Duration,
    /// Successes required in the half-open state to close the circuit. Default: 2.
    pub circuit_breaker_success_threshold: u32,
    /// Timeout for establishing a new connection. Default: 10s.
    pub connection_timeout: Duration,
    /// Timeout for a single request. Default: 30s.
    pub request_timeout: Duration,
    /// Interval between background health checks. Default: 30s.
    pub health_check_interval: Duration,
    /// Enable automatic background health checking. Default: true.
    pub health_check_enabled: bool,
}
73
impl Default for ResilientClientConfig {
    /// Defaults suitable for local development: a single localhost broker,
    /// three attempts with 100ms–10s exponential backoff (2x), and a circuit
    /// breaker that opens after 5 failures for 30s.
    fn default() -> Self {
        Self {
            // Cluster and pooling
            bootstrap_servers: vec!["localhost:9092".to_string()],
            pool_size: 5,
            // Retry policy
            retry_max_attempts: 3,
            retry_initial_delay: Duration::from_millis(100),
            retry_max_delay: Duration::from_secs(10),
            retry_multiplier: 2.0,
            // Circuit breaker
            circuit_breaker_threshold: 5,
            circuit_breaker_timeout: Duration::from_secs(30),
            circuit_breaker_success_threshold: 2,
            // Timeouts and health checking
            connection_timeout: Duration::from_secs(10),
            request_timeout: Duration::from_secs(30),
            health_check_interval: Duration::from_secs(30),
            health_check_enabled: true,
        }
    }
}
93
94impl ResilientClientConfig {
95    /// Create a new builder
96    pub fn builder() -> ResilientClientConfigBuilder {
97        ResilientClientConfigBuilder::default()
98    }
99}
100
/// Builder for ResilientClientConfig
///
/// Starts from `ResilientClientConfig::default()`; each setter overrides one
/// field and returns `self` for chaining. Finish with `build()`.
#[derive(Default)]
pub struct ResilientClientConfigBuilder {
    // The config being assembled; fields not touched keep their defaults.
    config: ResilientClientConfig,
}
106
107impl ResilientClientConfigBuilder {
108    /// Set bootstrap servers
109    pub fn bootstrap_servers(mut self, servers: Vec<String>) -> Self {
110        self.config.bootstrap_servers = servers;
111        self
112    }
113
114    /// Set pool size per server
115    pub fn pool_size(mut self, size: usize) -> Self {
116        self.config.pool_size = size;
117        self
118    }
119
120    /// Set maximum retry attempts
121    pub fn retry_max_attempts(mut self, attempts: u32) -> Self {
122        self.config.retry_max_attempts = attempts;
123        self
124    }
125
126    /// Set initial retry delay
127    pub fn retry_initial_delay(mut self, delay: Duration) -> Self {
128        self.config.retry_initial_delay = delay;
129        self
130    }
131
132    /// Set maximum retry delay
133    pub fn retry_max_delay(mut self, delay: Duration) -> Self {
134        self.config.retry_max_delay = delay;
135        self
136    }
137
138    /// Set retry backoff multiplier
139    pub fn retry_multiplier(mut self, multiplier: f64) -> Self {
140        self.config.retry_multiplier = multiplier;
141        self
142    }
143
144    /// Set circuit breaker failure threshold
145    pub fn circuit_breaker_threshold(mut self, threshold: u32) -> Self {
146        self.config.circuit_breaker_threshold = threshold;
147        self
148    }
149
150    /// Set circuit breaker recovery timeout
151    pub fn circuit_breaker_timeout(mut self, timeout: Duration) -> Self {
152        self.config.circuit_breaker_timeout = timeout;
153        self
154    }
155
156    /// Set connection timeout
157    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
158        self.config.connection_timeout = timeout;
159        self
160    }
161
162    /// Set request timeout
163    pub fn request_timeout(mut self, timeout: Duration) -> Self {
164        self.config.request_timeout = timeout;
165        self
166    }
167
168    /// Enable or disable health checking
169    pub fn health_check_enabled(mut self, enabled: bool) -> Self {
170        self.config.health_check_enabled = enabled;
171        self
172    }
173
174    /// Set health check interval
175    pub fn health_check_interval(mut self, interval: Duration) -> Self {
176        self.config.health_check_interval = interval;
177        self
178    }
179
180    /// Build the configuration
181    pub fn build(self) -> ResilientClientConfig {
182        self.config
183    }
184}
185
186// ============================================================================
187// Circuit Breaker
188// ============================================================================
189
/// Circuit breaker states
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation: requests flow through while failures are counted.
    Closed,
    /// Too many failures: requests are rejected until the recovery timeout elapses.
    Open,
    /// Probation after the recovery timeout: trial requests decide whether the
    /// circuit closes again (enough successes) or re-opens (continued failures).
    HalfOpen,
}
197
/// Circuit breaker for a single server
///
/// The state is encoded as an integer in `state` (0 = Closed, 1 = Open,
/// anything else = HalfOpen) so it can be read and written atomically.
struct CircuitBreaker {
    // Encoded `CircuitState`; see `get_state`.
    state: AtomicU32,
    // Failure streak; reset to 0 on any success.
    failure_count: AtomicU32,
    // Successes observed while HalfOpen; compared against
    // `circuit_breaker_success_threshold` to close the circuit.
    success_count: AtomicU32,
    // Timestamp of the most recent failure; drives the Open -> HalfOpen
    // recovery-timeout transition.
    last_failure: RwLock<Option<Instant>>,
    config: Arc<ResilientClientConfig>,
}
206
207impl CircuitBreaker {
208    fn new(config: Arc<ResilientClientConfig>) -> Self {
209        Self {
210            state: AtomicU32::new(0), // Closed
211            failure_count: AtomicU32::new(0),
212            success_count: AtomicU32::new(0),
213            last_failure: RwLock::new(None),
214            config,
215        }
216    }
217
218    fn get_state(&self) -> CircuitState {
219        match self.state.load(Ordering::SeqCst) {
220            0 => CircuitState::Closed,
221            1 => CircuitState::Open,
222            _ => CircuitState::HalfOpen,
223        }
224    }
225
226    async fn allow_request(&self) -> bool {
227        match self.get_state() {
228            CircuitState::Closed => true,
229            CircuitState::Open => {
230                let last_failure = self.last_failure.read().await;
231                if let Some(t) = *last_failure {
232                    if t.elapsed() > self.config.circuit_breaker_timeout {
233                        self.state.store(2, Ordering::SeqCst); // HalfOpen
234                        self.success_count.store(0, Ordering::SeqCst);
235                        return true;
236                    }
237                }
238                false
239            }
240            CircuitState::HalfOpen => true,
241        }
242    }
243
244    async fn record_success(&self) {
245        self.failure_count.store(0, Ordering::SeqCst);
246
247        if self.get_state() == CircuitState::HalfOpen {
248            let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
249            if count >= self.config.circuit_breaker_success_threshold {
250                self.state.store(0, Ordering::SeqCst); // Closed
251                debug!("Circuit breaker closed after {} successes", count);
252            }
253        }
254    }
255
256    async fn record_failure(&self) {
257        let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
258        *self.last_failure.write().await = Some(Instant::now());
259
260        if count >= self.config.circuit_breaker_threshold {
261            self.state.store(1, Ordering::SeqCst); // Open
262            warn!("Circuit breaker opened after {} failures", count);
263        }
264    }
265}
266
267// ============================================================================
268// Connection Pool
269// ============================================================================
270
/// Pooled connection wrapper.
///
/// Holds a semaphore permit from the owning `ConnectionPool`. When this struct
/// is dropped (e.g. on request timeout), the permit is automatically released,
/// preventing permanent pool slot exhaustion.
struct PooledConnection {
    // Underlying protocol client for this connection.
    client: Client,
    // When the connection was established; used to retire stale connections
    // on return to the pool.
    created_at: Instant,
    // Last checkout time; refreshed whenever the connection is reused.
    last_used: Instant,
    /// Semaphore permit — released on drop to return the pool slot.
    _permit: tokio::sync::OwnedSemaphorePermit,
}
283
/// Connection pool for a single server
struct ConnectionPool {
    // Server address (host:port) this pool connects to.
    addr: String,
    // Idle connections available for reuse (LIFO via push/pop).
    connections: Mutex<Vec<PooledConnection>>,
    // Bounds the number of live connections to `config.pool_size`.
    semaphore: Arc<Semaphore>,
    config: Arc<ResilientClientConfig>,
    // Per-server circuit breaker guarding `get()`.
    circuit_breaker: CircuitBreaker,
}
292
293impl ConnectionPool {
294    fn new(addr: String, config: Arc<ResilientClientConfig>) -> Self {
295        Self {
296            addr,
297            connections: Mutex::new(Vec::new()),
298            semaphore: Arc::new(Semaphore::new(config.pool_size)),
299            circuit_breaker: CircuitBreaker::new(config.clone()),
300            config,
301        }
302    }
303
304    async fn get(&self) -> Result<PooledConnection> {
305        // Check circuit breaker
306        if !self.circuit_breaker.allow_request().await {
307            return Err(Error::CircuitBreakerOpen(self.addr.clone()));
308        }
309
310        // Acquire semaphore permit — owned so it can be stored in PooledConnection
311        let permit = self
312            .semaphore
313            .clone()
314            .acquire_owned()
315            .await
316            .map_err(|_| Error::ConnectionError("Pool exhausted".to_string()))?;
317
318        // Try to get existing connection
319        {
320            let mut connections = self.connections.lock().await;
321            if let Some(mut conn) = connections.pop() {
322                conn.last_used = Instant::now();
323                conn._permit = permit;
324                return Ok(conn);
325            }
326        }
327
328        // Create new connection with timeout
329        let client = timeout(self.config.connection_timeout, Client::connect(&self.addr))
330            .await
331            .map_err(|_| Error::ConnectionError(format!("Connection timeout to {}", self.addr)))?
332            .map_err(|e| {
333                Error::ConnectionError(format!("Failed to connect to {}: {}", self.addr, e))
334            })?;
335
336        Ok(PooledConnection {
337            client,
338            created_at: Instant::now(),
339            last_used: Instant::now(),
340            _permit: permit,
341        })
342    }
343
344    async fn put(&self, conn: PooledConnection) {
345        // Only return if connection is healthy
346        if conn.created_at.elapsed() < Duration::from_secs(300) {
347            let mut connections = self.connections.lock().await;
348            if connections.len() < self.config.pool_size {
349                connections.push(conn);
350            }
351        }
352    }
353
354    async fn record_success(&self) {
355        self.circuit_breaker.record_success().await;
356    }
357
358    async fn record_failure(&self) {
359        self.circuit_breaker.record_failure().await;
360    }
361
362    fn circuit_state(&self) -> CircuitState {
363        self.circuit_breaker.get_state()
364    }
365}
366
367// ============================================================================
368// Resilient Client
369// ============================================================================
370
/// Production-grade resilient client with connection pooling and fault tolerance
pub struct ResilientClient {
    // One connection pool (with its own circuit breaker) per bootstrap server.
    pools: HashMap<String, Arc<ConnectionPool>>,
    config: Arc<ResilientClientConfig>,
    // Monotonic counter for round-robin server selection.
    current_server: AtomicU64,
    // Lifetime request counter (exposed via `stats()`).
    total_requests: AtomicU64,
    // Lifetime failure counter (exposed via `stats()`).
    total_failures: AtomicU64,
    // Background health-check task; aborted in `Drop`.
    _health_check_handle: Option<tokio::task::JoinHandle<()>>,
}
380
impl ResilientClient {
    /// Create a new resilient client
    ///
    /// Builds one connection pool (with its own circuit breaker) per
    /// bootstrap server and, when health checking is enabled, spawns a
    /// background task that periodically pings each server. That task is
    /// aborted when the client is dropped.
    ///
    /// # Errors
    /// Returns `Error::ConnectionError` if `bootstrap_servers` is empty.
    pub async fn new(config: ResilientClientConfig) -> Result<Self> {
        if config.bootstrap_servers.is_empty() {
            return Err(Error::ConnectionError(
                "No bootstrap servers configured".to_string(),
            ));
        }

        let config = Arc::new(config);
        let mut pools = HashMap::new();

        for server in &config.bootstrap_servers {
            let pool = Arc::new(ConnectionPool::new(server.clone(), config.clone()));
            pools.insert(server.clone(), pool);
        }

        info!(
            "Resilient client initialized with {} servers, pool size {}",
            config.bootstrap_servers.len(),
            config.pool_size
        );

        let mut client = Self {
            pools,
            config: config.clone(),
            current_server: AtomicU64::new(0),
            total_requests: AtomicU64::new(0),
            total_failures: AtomicU64::new(0),
            _health_check_handle: None,
        };

        // Start health check background task. The task holds its own Arc
        // clones of the pools, so it stays valid even while the client moves.
        if config.health_check_enabled {
            let pools_clone: HashMap<String, Arc<ConnectionPool>> = client
                .pools
                .iter()
                .map(|(k, v)| (k.clone(), v.clone()))
                .collect();
            let interval = config.health_check_interval;

            let handle = tokio::spawn(async move {
                loop {
                    sleep(interval).await;
                    for (addr, pool) in &pools_clone {
                        // A failed `get()` is silently skipped here; the pool's
                        // circuit breaker already tracks connection failures.
                        if let Ok(mut conn) = pool.get().await {
                            match conn.client.ping().await {
                                Ok(()) => {
                                    pool.record_success().await;
                                    debug!("Health check passed for {}", addr);
                                }
                                Err(e) => {
                                    pool.record_failure().await;
                                    warn!("Health check failed for {}: {}", addr, e);
                                }
                            }
                            pool.put(conn).await;
                        }
                    }
                }
            });

            client._health_check_handle = Some(handle);
        }

        Ok(client)
    }

    /// Execute an operation with automatic retries and server failover
    ///
    /// The closure receives a checked-out connection and must return it
    /// together with the operation's result, so the connection can be
    /// recycled into the pool on success or on a retryable error.
    ///
    /// Servers are chosen round-robin; servers whose circuit breaker is open
    /// are skipped. Retryable errors and request timeouts back off
    /// exponentially (with jitter) between attempts. On a request timeout the
    /// in-flight future — and the connection inside it — is dropped, which
    /// releases the pool permit via the permit's `Drop`.
    async fn execute_with_retry<F, T, Fut>(&self, operation: F) -> Result<T>
    where
        F: Fn(PooledConnection) -> Fut + Clone,
        Fut: std::future::Future<Output = (PooledConnection, Result<T>)>,
    {
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        let servers: Vec<_> = self.config.bootstrap_servers.clone();
        let num_servers = servers.len();

        for attempt in 0..self.config.retry_max_attempts {
            // Round-robin server selection with failover
            let server_idx =
                (self.current_server.fetch_add(1, Ordering::Relaxed) as usize) % num_servers;
            let server = &servers[server_idx];

            let pool = match self.pools.get(server) {
                Some(p) => p,
                None => continue,
            };

            // Skip servers with open circuit breaker
            if pool.circuit_state() == CircuitState::Open {
                debug!("Skipping {} (circuit breaker open)", server);
                continue;
            }

            // Get connection from pool
            let conn = match pool.get().await {
                Ok(c) => c,
                Err(e) => {
                    warn!("Failed to get connection from {}: {}", server, e);
                    pool.record_failure().await;
                    continue;
                }
            };

            // Execute operation with timeout
            let result = timeout(self.config.request_timeout, (operation.clone())(conn)).await;

            match result {
                Ok((conn, Ok(value))) => {
                    pool.record_success().await;
                    pool.put(conn).await;
                    return Ok(value);
                }
                Ok((conn, Err(e))) => {
                    self.total_failures.fetch_add(1, Ordering::Relaxed);
                    pool.record_failure().await;

                    // Determine if error is retryable
                    if is_retryable_error(&e) && attempt < self.config.retry_max_attempts - 1 {
                        let delay = calculate_backoff(
                            attempt,
                            self.config.retry_initial_delay,
                            self.config.retry_max_delay,
                            self.config.retry_multiplier,
                        );
                        warn!(
                            "Retryable error on attempt {}: {}. Retrying in {:?}",
                            attempt + 1,
                            e,
                            delay
                        );
                        pool.put(conn).await;
                        sleep(delay).await;
                        continue;
                    }

                    // Non-retryable error (or final attempt): surface it.
                    return Err(e);
                }
                Err(_) => {
                    // Timeout elapsed: the connection was dropped with the
                    // cancelled future, so only record and back off here.
                    self.total_failures.fetch_add(1, Ordering::Relaxed);
                    pool.record_failure().await;
                    warn!("Request timeout to {}", server);

                    if attempt < self.config.retry_max_attempts - 1 {
                        let delay = calculate_backoff(
                            attempt,
                            self.config.retry_initial_delay,
                            self.config.retry_max_delay,
                            self.config.retry_multiplier,
                        );
                        sleep(delay).await;
                    }
                }
            }
        }

        Err(Error::ConnectionError(format!(
            "All {} retry attempts exhausted",
            self.config.retry_max_attempts
        )))
    }

    /// Publish a message to a topic with automatic retries
    pub async fn publish(&self, topic: impl Into<String>, value: impl Into<Bytes>) -> Result<u64> {
        let topic = topic.into();
        let value = value.into();

        self.execute_with_retry(move |mut conn| {
            // Clone per attempt: the closure may run multiple times on retry.
            let topic = topic.clone();
            let value = value.clone();
            async move {
                let result = conn.client.publish(&topic, value).await;
                (conn, result)
            }
        })
        .await
    }

    /// Publish a message with a key
    pub async fn publish_with_key(
        &self,
        topic: impl Into<String>,
        key: Option<impl Into<Bytes>>,
        value: impl Into<Bytes>,
    ) -> Result<u64> {
        let topic = topic.into();
        let key: Option<Bytes> = key.map(|k| k.into());
        let value = value.into();

        self.execute_with_retry(move |mut conn| {
            let topic = topic.clone();
            let key = key.clone();
            let value = value.clone();
            async move {
                let result = conn.client.publish_with_key(&topic, key, value).await;
                (conn, result)
            }
        })
        .await
    }

    /// Consume messages with automatic retries
    ///
    /// Uses read_uncommitted isolation level (default).
    /// For transactional consumers, use [`Self::consume_with_isolation`] or [`Self::consume_read_committed`].
    pub async fn consume(
        &self,
        topic: impl Into<String>,
        partition: u32,
        offset: u64,
        max_messages: usize,
    ) -> Result<Vec<MessageData>> {
        self.consume_with_isolation(topic, partition, offset, max_messages, None)
            .await
    }

    /// Consume messages with specified isolation level and automatic retries
    ///
    /// # Arguments
    /// * `topic` - Topic name
    /// * `partition` - Partition number
    /// * `offset` - Starting offset
    /// * `max_messages` - Maximum messages to return
    /// * `isolation_level` - Transaction isolation level:
    ///   - `None` or `Some(0)` = read_uncommitted (default)
    ///   - `Some(1)` = read_committed (filters aborted transactions)
    pub async fn consume_with_isolation(
        &self,
        topic: impl Into<String>,
        partition: u32,
        offset: u64,
        max_messages: usize,
        isolation_level: Option<u8>,
    ) -> Result<Vec<MessageData>> {
        let topic = topic.into();

        self.execute_with_retry(move |mut conn| {
            let topic = topic.clone();
            async move {
                let result = conn
                    .client
                    .consume_with_isolation(
                        &topic,
                        partition,
                        offset,
                        max_messages,
                        isolation_level,
                    )
                    .await;
                (conn, result)
            }
        })
        .await
    }

    /// Consume messages with read_committed isolation level and automatic retries
    ///
    /// Only returns committed transactional messages; aborted transactions are filtered out.
    pub async fn consume_read_committed(
        &self,
        topic: impl Into<String>,
        partition: u32,
        offset: u64,
        max_messages: usize,
    ) -> Result<Vec<MessageData>> {
        self.consume_with_isolation(topic, partition, offset, max_messages, Some(1))
            .await
    }

    /// Create a topic with automatic retries
    pub async fn create_topic(
        &self,
        name: impl Into<String>,
        partitions: Option<u32>,
    ) -> Result<u32> {
        let name = name.into();

        self.execute_with_retry(move |mut conn| {
            let name = name.clone();
            async move {
                let result = conn.client.create_topic(&name, partitions).await;
                (conn, result)
            }
        })
        .await
    }

    /// List all topics
    pub async fn list_topics(&self) -> Result<Vec<String>> {
        self.execute_with_retry(|mut conn| async move {
            let result = conn.client.list_topics().await;
            (conn, result)
        })
        .await
    }

    /// Delete a topic
    pub async fn delete_topic(&self, name: impl Into<String>) -> Result<()> {
        let name = name.into();

        self.execute_with_retry(move |mut conn| {
            let name = name.clone();
            async move {
                let result = conn.client.delete_topic(&name).await;
                (conn, result)
            }
        })
        .await
    }

    /// Commit consumer offset
    pub async fn commit_offset(
        &self,
        consumer_group: impl Into<String>,
        topic: impl Into<String>,
        partition: u32,
        offset: u64,
    ) -> Result<()> {
        let consumer_group = consumer_group.into();
        let topic = topic.into();

        self.execute_with_retry(move |mut conn| {
            let consumer_group = consumer_group.clone();
            let topic = topic.clone();
            async move {
                let result = conn
                    .client
                    .commit_offset(&consumer_group, &topic, partition, offset)
                    .await;
                (conn, result)
            }
        })
        .await
    }

    /// Get consumer offset
    ///
    /// Returns `None` when no offset has been committed for this
    /// group/topic/partition.
    pub async fn get_offset(
        &self,
        consumer_group: impl Into<String>,
        topic: impl Into<String>,
        partition: u32,
    ) -> Result<Option<u64>> {
        let consumer_group = consumer_group.into();
        let topic = topic.into();

        self.execute_with_retry(move |mut conn| {
            let consumer_group = consumer_group.clone();
            let topic = topic.clone();
            async move {
                let result = conn
                    .client
                    .get_offset(&consumer_group, &topic, partition)
                    .await;
                (conn, result)
            }
        })
        .await
    }

    /// Get offset bounds (earliest, latest)
    pub async fn get_offset_bounds(
        &self,
        topic: impl Into<String>,
        partition: u32,
    ) -> Result<(u64, u64)> {
        let topic = topic.into();

        self.execute_with_retry(move |mut conn| {
            let topic = topic.clone();
            async move {
                let result = conn.client.get_offset_bounds(&topic, partition).await;
                (conn, result)
            }
        })
        .await
    }

    /// Get topic metadata
    pub async fn get_metadata(&self, topic: impl Into<String>) -> Result<(String, u32)> {
        let topic = topic.into();

        self.execute_with_retry(move |mut conn| {
            let topic = topic.clone();
            async move {
                let result = conn.client.get_metadata(&topic).await;
                (conn, result)
            }
        })
        .await
    }

    /// Ping (health check)
    pub async fn ping(&self) -> Result<()> {
        self.execute_with_retry(|mut conn| async move {
            let result = conn.client.ping().await;
            (conn, result)
        })
        .await
    }

    /// Get client statistics
    ///
    /// Snapshot of lifetime request/failure counters plus each server's
    /// current circuit breaker state.
    pub fn stats(&self) -> ClientStats {
        let pools: Vec<_> = self
            .pools
            .iter()
            .map(|(addr, pool)| ServerStats {
                address: addr.clone(),
                circuit_state: pool.circuit_state(),
            })
            .collect();

        ClientStats {
            total_requests: self.total_requests.load(Ordering::Relaxed),
            total_failures: self.total_failures.load(Ordering::Relaxed),
            servers: pools,
        }
    }
}
800
801impl Drop for ResilientClient {
802    fn drop(&mut self) {
803        if let Some(handle) = self._health_check_handle.take() {
804            handle.abort();
805        }
806    }
807}
808
/// Client statistics
#[derive(Debug, Clone)]
pub struct ClientStats {
    /// Total operations attempted over the client's lifetime.
    pub total_requests: u64,
    /// Total failed attempts (errors and timeouts) over the client's lifetime.
    pub total_failures: u64,
    /// Per-server snapshot (address + circuit breaker state).
    pub servers: Vec<ServerStats>,
}
816
/// Per-server statistics
#[derive(Debug, Clone)]
pub struct ServerStats {
    /// Server address (host:port).
    pub address: String,
    /// Circuit breaker state at the time of the snapshot.
    pub circuit_state: CircuitState,
}
823
824// ============================================================================
825// Helper Functions
826// ============================================================================
827
828/// Determine if an error is retryable
829fn is_retryable_error(error: &Error) -> bool {
830    matches!(
831        error,
832        Error::ConnectionError(_) | Error::IoError(_) | Error::CircuitBreakerOpen(_)
833    )
834}
835
836/// Calculate exponential backoff with jitter
837fn calculate_backoff(
838    attempt: u32,
839    initial_delay: Duration,
840    max_delay: Duration,
841    multiplier: f64,
842) -> Duration {
843    let base_delay = initial_delay.as_millis() as f64 * multiplier.powi(attempt as i32);
844    let capped_delay = base_delay.min(max_delay.as_millis() as f64);
845
846    // Add jitter (±25%)
847    let jitter = (rand_simple() * 0.5 - 0.25) * capped_delay;
848    let final_delay = (capped_delay + jitter).max(0.0);
849
850    Duration::from_millis(final_delay as u64)
851}
852
/// Simple pseudo-random number generator in the range [0.0, 1.0).
///
/// Derived from the sub-second portion of the system clock, so it is NOT
/// cryptographically secure and is only suitable for retry jitter.
///
/// Improvements over the original: uses the full nanosecond resolution
/// (the previous `% 1000` collapsed the jitter to 1000 distinct values),
/// and falls back to 0 instead of panicking if the system clock reads
/// before the Unix epoch.
fn rand_simple() -> f64 {
    use std::time::SystemTime;
    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .map(|d| d.subsec_nanos())
        .unwrap_or(0);
    // subsec_nanos() < 1_000_000_000, so the result is always in [0.0, 1.0).
    f64::from(nanos) / 1_000_000_000.0
}
862
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_builder() {
        let cfg = ResilientClientConfig::builder()
            .bootstrap_servers(vec!["server1:9092".to_string(), "server2:9092".to_string()])
            .pool_size(10)
            .retry_max_attempts(5)
            .circuit_breaker_threshold(10)
            .connection_timeout(Duration::from_secs(5))
            .build();

        assert_eq!(cfg.bootstrap_servers.len(), 2);
        assert_eq!(cfg.pool_size, 10);
        assert_eq!(cfg.retry_max_attempts, 5);
        assert_eq!(cfg.circuit_breaker_threshold, 10);
        assert_eq!(cfg.connection_timeout, Duration::from_secs(5));
    }

    #[test]
    fn test_calculate_backoff() {
        let base = Duration::from_millis(100);
        let cap = Duration::from_secs(10);

        // Attempt 0: ~100ms, allowing up to 25% jitter in either direction.
        let backoff = calculate_backoff(0, base, cap, 2.0);
        assert!(backoff.as_millis() >= 75 && backoff.as_millis() <= 125);

        // Attempt 1: roughly doubled (~200ms), same jitter band.
        let backoff = calculate_backoff(1, base, cap, 2.0);
        assert!(backoff.as_millis() >= 150 && backoff.as_millis() <= 250);

        // Very high attempt count: clamped at the cap plus jitter headroom.
        let backoff = calculate_backoff(20, base, cap, 2.0);
        assert!(backoff <= cap + Duration::from_millis(2500));
    }

    #[test]
    fn test_is_retryable_error() {
        // Transient failures should be retried.
        assert!(is_retryable_error(&Error::ConnectionError("test".into())));
        assert!(is_retryable_error(&Error::CircuitBreakerOpen(
            "test".into()
        )));
        // Permanent failures should not.
        assert!(!is_retryable_error(&Error::InvalidResponse));
        assert!(!is_retryable_error(&Error::ServerError("test".into())));
    }

    #[test]
    fn test_circuit_state() {
        let cfg = Arc::new(ResilientClientConfig::default());
        let breaker = CircuitBreaker::new(cfg);

        assert_eq!(breaker.get_state(), CircuitState::Closed);
    }

    // ========================================================================
    // Circuit Breaker State Machine Tests
    // ========================================================================

    #[tokio::test]
    async fn test_circuit_breaker_starts_closed() {
        let cfg = Arc::new(ResilientClientConfig::default());
        let breaker = CircuitBreaker::new(cfg);

        assert_eq!(breaker.get_state(), CircuitState::Closed);
        assert!(breaker.allow_request().await);
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_after_threshold_failures() {
        let cfg = Arc::new(
            ResilientClientConfig::builder()
                .circuit_breaker_threshold(3)
                .build(),
        );
        let breaker = CircuitBreaker::new(cfg);

        // Fresh breaker starts closed.
        assert_eq!(breaker.get_state(), CircuitState::Closed);

        // Failures below the threshold leave the circuit closed.
        breaker.record_failure().await;
        assert_eq!(breaker.get_state(), CircuitState::Closed);
        breaker.record_failure().await;
        assert_eq!(breaker.get_state(), CircuitState::Closed);

        // Hitting the threshold trips the breaker open.
        breaker.record_failure().await;
        assert_eq!(breaker.get_state(), CircuitState::Open);
        assert!(!breaker.allow_request().await);
    }

    #[tokio::test]
    async fn test_circuit_breaker_success_resets_failure_count() {
        let cfg = Arc::new(
            ResilientClientConfig::builder()
                .circuit_breaker_threshold(3)
                .build(),
        );
        let breaker = CircuitBreaker::new(cfg);

        // Accumulate failures below the threshold.
        breaker.record_failure().await;
        breaker.record_failure().await;
        assert_eq!(breaker.failure_count.load(Ordering::SeqCst), 2);

        // A success zeroes the counter and keeps the circuit closed.
        breaker.record_success().await;
        assert_eq!(breaker.failure_count.load(Ordering::SeqCst), 0);
        assert_eq!(breaker.get_state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_after_timeout() {
        let cfg = Arc::new(
            ResilientClientConfig::builder()
                .circuit_breaker_threshold(1)
                .circuit_breaker_timeout(Duration::from_millis(50))
                .build(),
        );
        let breaker = CircuitBreaker::new(cfg);

        // Trip the breaker.
        breaker.record_failure().await;
        assert_eq!(breaker.get_state(), CircuitState::Open);
        assert!(!breaker.allow_request().await);

        // Let the open-state timeout elapse.
        tokio::time::sleep(Duration::from_millis(100)).await;

        // A probe request is now allowed and moves the breaker to half-open.
        assert!(breaker.allow_request().await);
        assert_eq!(breaker.get_state(), CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_circuit_breaker_closes_after_success_threshold() {
        let cfg = Arc::new(
            ResilientClientConfig::builder()
                .circuit_breaker_threshold(1)
                .circuit_breaker_timeout(Duration::from_millis(10))
                .build(),
        );
        // The default success threshold is 2, so two successes are required.
        let breaker = CircuitBreaker::new(cfg);

        // Trip the breaker.
        breaker.record_failure().await;
        assert_eq!(breaker.get_state(), CircuitState::Open);

        // After the timeout, a probe moves it to half-open.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(breaker.allow_request().await);
        assert_eq!(breaker.get_state(), CircuitState::HalfOpen);

        // One success is not enough to close.
        breaker.record_success().await;
        assert_eq!(breaker.get_state(), CircuitState::HalfOpen);

        // The second success closes the circuit.
        breaker.record_success().await;
        assert_eq!(breaker.get_state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_failure_in_half_open_reopens() {
        let cfg = Arc::new(
            ResilientClientConfig::builder()
                .circuit_breaker_threshold(1)
                .circuit_breaker_timeout(Duration::from_millis(10))
                .build(),
        );
        let breaker = CircuitBreaker::new(cfg);

        // Trip the breaker.
        breaker.record_failure().await;
        assert_eq!(breaker.get_state(), CircuitState::Open);

        // After the timeout, a probe moves it to half-open.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(breaker.allow_request().await);
        assert_eq!(breaker.get_state(), CircuitState::HalfOpen);

        // Any failure while half-open trips the breaker again.
        breaker.record_failure().await;
        assert_eq!(breaker.get_state(), CircuitState::Open);
    }

    // ========================================================================
    // Connection Pool Tests
    // ========================================================================

    #[test]
    fn test_pool_config_defaults() {
        let cfg = ResilientClientConfig::default();
        assert_eq!(cfg.pool_size, 5);
        assert_eq!(cfg.retry_max_attempts, 3);
        assert_eq!(cfg.circuit_breaker_threshold, 5);
        assert_eq!(cfg.circuit_breaker_success_threshold, 2);
    }

    #[tokio::test]
    async fn test_pool_semaphore_limits_concurrent_connections() {
        let cfg = Arc::new(ResilientClientConfig::builder().pool_size(2).build());
        let pool = ConnectionPool::new("localhost:9999".to_string(), cfg);

        // Without a live server we can only confirm construction; the
        // semaphore permit count itself is not directly observable here.
        assert_eq!(pool.addr, "localhost:9999");
    }

    // ========================================================================
    // Retry Logic Tests
    // ========================================================================

    #[test]
    fn test_backoff_respects_max_delay() {
        let base = Duration::from_millis(100);
        let cap = Duration::from_secs(1);

        // Large attempt numbers must stay within cap + maximum jitter.
        for attempt in 10..20 {
            let backoff = calculate_backoff(attempt, base, cap, 2.0);
            // Jitter tops out at 25% of the cap, i.e. 250ms.
            assert!(backoff <= cap + Duration::from_millis(250));
        }
    }

    #[test]
    fn test_backoff_exponential_growth() {
        let base = Duration::from_millis(100);
        let cap = Duration::from_secs(100);

        // Sample consecutive attempts (each includes random jitter).
        let first = calculate_backoff(0, base, cap, 2.0);
        let second = calculate_backoff(1, base, cap, 2.0);
        let third = calculate_backoff(2, base, cap, 2.0);

        // Expected centers: ~100ms, ~200ms, ~400ms. The bounds are loose
        // on purpose so jitter cannot cause flakiness.
        assert!(second > first / 2);
        assert!(third > second / 2);
    }

    // ========================================================================
    // Client Statistics Tests
    // ========================================================================

    #[test]
    fn test_client_stats_structure() {
        let snapshot = ClientStats {
            total_requests: 100,
            total_failures: 5,
            servers: vec![
                ServerStats {
                    address: "server1:9092".to_string(),
                    circuit_state: CircuitState::Closed,
                },
                ServerStats {
                    address: "server2:9092".to_string(),
                    circuit_state: CircuitState::Open,
                },
            ],
        };

        assert_eq!(snapshot.total_requests, 100);
        assert_eq!(snapshot.total_failures, 5);
        assert_eq!(snapshot.servers.len(), 2);
        assert_eq!(snapshot.servers[0].circuit_state, CircuitState::Closed);
        assert_eq!(snapshot.servers[1].circuit_state, CircuitState::Open);
    }
}