1use crate::{Client, Error, MessageData, Result};
30use bytes::Bytes;
31use std::collections::HashMap;
32use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
33use std::sync::Arc;
34use std::time::{Duration, Instant};
35use tokio::sync::{Mutex, RwLock, Semaphore};
36use tokio::time::{sleep, timeout};
37use tracing::{debug, info, warn};
38
/// Tunable settings for the resilient client: server list, pooling,
/// retry/backoff, circuit breaking, timeouts, and health checking.
#[derive(Debug, Clone)]
pub struct ResilientClientConfig {
    /// Broker addresses tried in round-robin order (e.g. "host:port").
    pub bootstrap_servers: Vec<String>,
    /// Maximum number of live connections per server.
    pub pool_size: usize,
    /// Total attempts per logical request before giving up.
    pub retry_max_attempts: u32,
    /// Backoff delay before the first retry.
    pub retry_initial_delay: Duration,
    /// Upper bound on any single backoff delay.
    pub retry_max_delay: Duration,
    /// Factor the backoff delay grows by after each attempt.
    pub retry_multiplier: f64,
    /// Consecutive failures that trip a server's circuit breaker open.
    pub circuit_breaker_threshold: u32,
    /// Cool-down before an open circuit allows a probe request.
    pub circuit_breaker_timeout: Duration,
    /// Probe successes required in half-open state to close the circuit.
    pub circuit_breaker_success_threshold: u32,
    /// Timeout for establishing a new connection.
    pub connection_timeout: Duration,
    /// Timeout for one request/response round trip.
    pub request_timeout: Duration,
    /// Interval between background health-check pings.
    pub health_check_interval: Duration,
    /// Whether the background health-check task is spawned.
    pub health_check_enabled: bool,
}
73
74impl Default for ResilientClientConfig {
75 fn default() -> Self {
76 Self {
77 bootstrap_servers: vec!["localhost:9092".to_string()],
78 pool_size: 5,
79 retry_max_attempts: 3,
80 retry_initial_delay: Duration::from_millis(100),
81 retry_max_delay: Duration::from_secs(10),
82 retry_multiplier: 2.0,
83 circuit_breaker_threshold: 5,
84 circuit_breaker_timeout: Duration::from_secs(30),
85 circuit_breaker_success_threshold: 2,
86 connection_timeout: Duration::from_secs(10),
87 request_timeout: Duration::from_secs(30),
88 health_check_interval: Duration::from_secs(30),
89 health_check_enabled: true,
90 }
91 }
92}
93
94impl ResilientClientConfig {
95 pub fn builder() -> ResilientClientConfigBuilder {
97 ResilientClientConfigBuilder::default()
98 }
99}
100
/// Fluent builder for `ResilientClientConfig`; starts from the defaults.
#[derive(Default)]
pub struct ResilientClientConfigBuilder {
    // Accumulates settings; handed out unchanged by `build()`.
    config: ResilientClientConfig,
}
106
107impl ResilientClientConfigBuilder {
108 pub fn bootstrap_servers(mut self, servers: Vec<String>) -> Self {
110 self.config.bootstrap_servers = servers;
111 self
112 }
113
114 pub fn pool_size(mut self, size: usize) -> Self {
116 self.config.pool_size = size;
117 self
118 }
119
120 pub fn retry_max_attempts(mut self, attempts: u32) -> Self {
122 self.config.retry_max_attempts = attempts;
123 self
124 }
125
126 pub fn retry_initial_delay(mut self, delay: Duration) -> Self {
128 self.config.retry_initial_delay = delay;
129 self
130 }
131
132 pub fn retry_max_delay(mut self, delay: Duration) -> Self {
134 self.config.retry_max_delay = delay;
135 self
136 }
137
138 pub fn retry_multiplier(mut self, multiplier: f64) -> Self {
140 self.config.retry_multiplier = multiplier;
141 self
142 }
143
144 pub fn circuit_breaker_threshold(mut self, threshold: u32) -> Self {
146 self.config.circuit_breaker_threshold = threshold;
147 self
148 }
149
150 pub fn circuit_breaker_timeout(mut self, timeout: Duration) -> Self {
152 self.config.circuit_breaker_timeout = timeout;
153 self
154 }
155
156 pub fn connection_timeout(mut self, timeout: Duration) -> Self {
158 self.config.connection_timeout = timeout;
159 self
160 }
161
162 pub fn request_timeout(mut self, timeout: Duration) -> Self {
164 self.config.request_timeout = timeout;
165 self
166 }
167
168 pub fn health_check_enabled(mut self, enabled: bool) -> Self {
170 self.config.health_check_enabled = enabled;
171 self
172 }
173
174 pub fn health_check_interval(mut self, interval: Duration) -> Self {
176 self.config.health_check_interval = interval;
177 self
178 }
179
180 pub fn build(self) -> ResilientClientConfig {
182 self.config
183 }
184}
185
/// Observable state of a per-server circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation; requests flow through.
    Closed,
    /// Failing fast; requests are rejected until the cool-down elapses.
    Open,
    /// Probing: trial requests are allowed to test recovery.
    HalfOpen,
}
197
/// Circuit breaker guarding a single server, driven by atomics so state
/// reads never block.
struct CircuitBreaker {
    // Packed CircuitState: 0 = Closed, 1 = Open, anything else = HalfOpen.
    state: AtomicU32,
    // Consecutive failures observed (reset to 0 on any success).
    failure_count: AtomicU32,
    // Consecutive probe successes observed while HalfOpen.
    success_count: AtomicU32,
    // Most recent failure time; drives the Open -> HalfOpen cool-down.
    last_failure: RwLock<Option<Instant>>,
    config: Arc<ResilientClientConfig>,
}
206
207impl CircuitBreaker {
208 fn new(config: Arc<ResilientClientConfig>) -> Self {
209 Self {
210 state: AtomicU32::new(0), failure_count: AtomicU32::new(0),
212 success_count: AtomicU32::new(0),
213 last_failure: RwLock::new(None),
214 config,
215 }
216 }
217
218 fn get_state(&self) -> CircuitState {
219 match self.state.load(Ordering::SeqCst) {
220 0 => CircuitState::Closed,
221 1 => CircuitState::Open,
222 _ => CircuitState::HalfOpen,
223 }
224 }
225
226 async fn allow_request(&self) -> bool {
227 match self.get_state() {
228 CircuitState::Closed => true,
229 CircuitState::Open => {
230 let last_failure = self.last_failure.read().await;
231 if let Some(t) = *last_failure {
232 if t.elapsed() > self.config.circuit_breaker_timeout {
233 self.state.store(2, Ordering::SeqCst); self.success_count.store(0, Ordering::SeqCst);
235 return true;
236 }
237 }
238 false
239 }
240 CircuitState::HalfOpen => true,
241 }
242 }
243
244 async fn record_success(&self) {
245 self.failure_count.store(0, Ordering::SeqCst);
246
247 if self.get_state() == CircuitState::HalfOpen {
248 let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
249 if count >= self.config.circuit_breaker_success_threshold {
250 self.state.store(0, Ordering::SeqCst); debug!("Circuit breaker closed after {} successes", count);
252 }
253 }
254 }
255
256 async fn record_failure(&self) {
257 let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
258 *self.last_failure.write().await = Some(Instant::now());
259
260 if count >= self.config.circuit_breaker_threshold {
261 self.state.store(1, Ordering::SeqCst); warn!("Circuit breaker opened after {} failures", count);
263 }
264 }
265}
266
/// A checked-out client connection plus the pool bookkeeping that travels
/// with it.
struct PooledConnection {
    client: Client,
    // When the connection was dialed; used to age out connections (see
    // `ConnectionPool::put`, which discards after 5 minutes).
    created_at: Instant,
    // Last checkout time.
    last_used: Instant,
    // Owning this permit counts the connection against the pool's capacity
    // for its whole lifetime; dropping the connection releases the permit.
    _permit: tokio::sync::OwnedSemaphorePermit,
}
283
/// Bounded pool of connections to one server, fronted by a circuit breaker.
struct ConnectionPool {
    addr: String,
    // Idle connections available for reuse (LIFO).
    connections: Mutex<Vec<PooledConnection>>,
    // Caps total live connections (idle + checked out) at `config.pool_size`.
    semaphore: Arc<Semaphore>,
    config: Arc<ResilientClientConfig>,
    circuit_breaker: CircuitBreaker,
}
292
293impl ConnectionPool {
294 fn new(addr: String, config: Arc<ResilientClientConfig>) -> Self {
295 Self {
296 addr,
297 connections: Mutex::new(Vec::new()),
298 semaphore: Arc::new(Semaphore::new(config.pool_size)),
299 circuit_breaker: CircuitBreaker::new(config.clone()),
300 config,
301 }
302 }
303
304 async fn get(&self) -> Result<PooledConnection> {
305 if !self.circuit_breaker.allow_request().await {
307 return Err(Error::CircuitBreakerOpen(self.addr.clone()));
308 }
309
310 let permit = self
312 .semaphore
313 .clone()
314 .acquire_owned()
315 .await
316 .map_err(|_| Error::ConnectionError("Pool exhausted".to_string()))?;
317
318 {
320 let mut connections = self.connections.lock().await;
321 if let Some(mut conn) = connections.pop() {
322 conn.last_used = Instant::now();
323 conn._permit = permit;
324 return Ok(conn);
325 }
326 }
327
328 let client = timeout(self.config.connection_timeout, Client::connect(&self.addr))
330 .await
331 .map_err(|_| Error::ConnectionError(format!("Connection timeout to {}", self.addr)))?
332 .map_err(|e| {
333 Error::ConnectionError(format!("Failed to connect to {}: {}", self.addr, e))
334 })?;
335
336 Ok(PooledConnection {
337 client,
338 created_at: Instant::now(),
339 last_used: Instant::now(),
340 _permit: permit,
341 })
342 }
343
344 async fn put(&self, conn: PooledConnection) {
345 if conn.created_at.elapsed() < Duration::from_secs(300) {
347 let mut connections = self.connections.lock().await;
348 if connections.len() < self.config.pool_size {
349 connections.push(conn);
350 }
351 }
352 }
353
354 async fn record_success(&self) {
355 self.circuit_breaker.record_success().await;
356 }
357
358 async fn record_failure(&self) {
359 self.circuit_breaker.record_failure().await;
360 }
361
362 fn circuit_state(&self) -> CircuitState {
363 self.circuit_breaker.get_state()
364 }
365}
366
/// High-availability client wrapper: round-robins across servers, pools
/// connections, retries with jittered exponential backoff, and trips
/// per-server circuit breakers. Optionally runs a background health-check
/// task that pings each server periodically.
pub struct ResilientClient {
    // One connection pool per bootstrap server, keyed by address.
    pools: HashMap<String, Arc<ConnectionPool>>,
    config: Arc<ResilientClientConfig>,
    // Monotonic counter used for round-robin server selection.
    current_server: AtomicU64,
    // Lifetime totals, exposed via `stats()`.
    total_requests: AtomicU64,
    total_failures: AtomicU64,
    // Background health-check task; aborted in Drop.
    _health_check_handle: Option<tokio::task::JoinHandle<()>>,
}
380
impl ResilientClient {
    /// Build a client from `config`, creating one connection pool per
    /// bootstrap server and, when enabled, spawning the background
    /// health-check task. Connections are dialed lazily on first use.
    ///
    /// # Errors
    /// `Error::ConnectionError` if `bootstrap_servers` is empty.
    pub async fn new(config: ResilientClientConfig) -> Result<Self> {
        if config.bootstrap_servers.is_empty() {
            return Err(Error::ConnectionError(
                "No bootstrap servers configured".to_string(),
            ));
        }

        let config = Arc::new(config);
        let mut pools = HashMap::new();

        for server in &config.bootstrap_servers {
            let pool = Arc::new(ConnectionPool::new(server.clone(), config.clone()));
            pools.insert(server.clone(), pool);
        }

        info!(
            "Resilient client initialized with {} servers, pool size {}",
            config.bootstrap_servers.len(),
            config.pool_size
        );

        let mut client = Self {
            pools,
            config: config.clone(),
            current_server: AtomicU64::new(0),
            total_requests: AtomicU64::new(0),
            total_failures: AtomicU64::new(0),
            _health_check_handle: None,
        };

        if config.health_check_enabled {
            // The task owns Arc clones of the pools so it stays valid even
            // while the client itself is borrowed elsewhere.
            let pools_clone: HashMap<String, Arc<ConnectionPool>> = client
                .pools
                .iter()
                .map(|(k, v)| (k.clone(), v.clone()))
                .collect();
            let interval = config.health_check_interval;

            let handle = tokio::spawn(async move {
                loop {
                    sleep(interval).await;
                    for (addr, pool) in &pools_clone {
                        // `get()` may dial a new connection; if the pool's
                        // breaker is open the server is skipped this round.
                        if let Ok(mut conn) = pool.get().await {
                            match conn.client.ping().await {
                                Ok(()) => {
                                    pool.record_success().await;
                                    debug!("Health check passed for {}", addr);
                                }
                                Err(e) => {
                                    pool.record_failure().await;
                                    warn!("Health check failed for {}: {}", addr, e);
                                }
                            }
                            pool.put(conn).await;
                        }
                    }
                }
            });

            client._health_check_handle = Some(handle);
        }

        Ok(client)
    }

    /// Core request loop shared by every public operation.
    ///
    /// `operation` receives a checked-out connection and must return it
    /// alongside the result so the loop can return it to the pool (we cannot
    /// borrow the connection across the closure's await points otherwise).
    ///
    /// Per attempt: pick the next server round-robin, skip it if its breaker
    /// is open, run the operation under `request_timeout`, and classify the
    /// outcome — success returns immediately; a retryable error or timeout
    /// backs off (jittered exponential) and tries again; a non-retryable
    /// error is returned to the caller at once (its connection is dropped,
    /// not pooled). Note that skipped/unreachable servers also consume an
    /// attempt.
    async fn execute_with_retry<F, T, Fut>(&self, operation: F) -> Result<T>
    where
        F: Fn(PooledConnection) -> Fut + Clone,
        Fut: std::future::Future<Output = (PooledConnection, Result<T>)>,
    {
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        let servers: Vec<_> = self.config.bootstrap_servers.clone();
        // Non-zero: `new()` rejects an empty server list, so the modulo
        // below cannot divide by zero.
        let num_servers = servers.len();

        for attempt in 0..self.config.retry_max_attempts {
            // Round-robin across servers via a shared atomic counter.
            let server_idx =
                (self.current_server.fetch_add(1, Ordering::Relaxed) as usize) % num_servers;
            let server = &servers[server_idx];

            let pool = match self.pools.get(server) {
                Some(p) => p,
                None => continue,
            };

            // Cheap pre-check; `pool.get()` would also reject, but this
            // avoids taking a semaphore permit for a known-dead server.
            if pool.circuit_state() == CircuitState::Open {
                debug!("Skipping {} (circuit breaker open)", server);
                continue;
            }

            let conn = match pool.get().await {
                Ok(c) => c,
                Err(e) => {
                    warn!("Failed to get connection from {}: {}", server, e);
                    pool.record_failure().await;
                    continue;
                }
            };

            // On timeout the future (and the connection inside it) is
            // dropped, which releases the connection's pool permit.
            let result = timeout(self.config.request_timeout, (operation.clone())(conn)).await;

            match result {
                Ok((conn, Ok(value))) => {
                    pool.record_success().await;
                    pool.put(conn).await;
                    return Ok(value);
                }
                Ok((conn, Err(e))) => {
                    self.total_failures.fetch_add(1, Ordering::Relaxed);
                    pool.record_failure().await;

                    if is_retryable_error(&e) && attempt < self.config.retry_max_attempts - 1 {
                        let delay = calculate_backoff(
                            attempt,
                            self.config.retry_initial_delay,
                            self.config.retry_max_delay,
                            self.config.retry_multiplier,
                        );
                        warn!(
                            "Retryable error on attempt {}: {}. Retrying in {:?}",
                            attempt + 1,
                            e,
                            delay
                        );
                        pool.put(conn).await;
                        sleep(delay).await;
                        continue;
                    }

                    // Non-retryable (or attempts exhausted): the connection
                    // is dropped here rather than pooled.
                    return Err(e);
                }
                Err(_) => {
                    // `timeout` elapsed before the operation completed.
                    self.total_failures.fetch_add(1, Ordering::Relaxed);
                    pool.record_failure().await;
                    warn!("Request timeout to {}", server);

                    if attempt < self.config.retry_max_attempts - 1 {
                        let delay = calculate_backoff(
                            attempt,
                            self.config.retry_initial_delay,
                            self.config.retry_max_delay,
                            self.config.retry_multiplier,
                        );
                        sleep(delay).await;
                    }
                }
            }
        }

        Err(Error::ConnectionError(format!(
            "All {} retry attempts exhausted",
            self.config.retry_max_attempts
        )))
    }

    /// Publish `value` to `topic` with retry/failover. Returns the u64 from
    /// `Client::publish` — presumably the assigned offset; confirm against
    /// the `Client` API.
    pub async fn publish(&self, topic: impl Into<String>, value: impl Into<Bytes>) -> Result<u64> {
        let topic = topic.into();
        let value = value.into();

        self.execute_with_retry(move |mut conn| {
            // Clone per attempt: the closure may run multiple times.
            let topic = topic.clone();
            let value = value.clone();
            async move {
                let result = conn.client.publish(&topic, value).await;
                (conn, result)
            }
        })
        .await
    }

    /// Publish with an optional partitioning key, with retry/failover.
    pub async fn publish_with_key(
        &self,
        topic: impl Into<String>,
        key: Option<impl Into<Bytes>>,
        value: impl Into<Bytes>,
    ) -> Result<u64> {
        let topic = topic.into();
        let key: Option<Bytes> = key.map(|k| k.into());
        let value = value.into();

        self.execute_with_retry(move |mut conn| {
            let topic = topic.clone();
            let key = key.clone();
            let value = value.clone();
            async move {
                let result = conn.client.publish_with_key(&topic, key, value).await;
                (conn, result)
            }
        })
        .await
    }

    /// Consume up to `max_messages` starting at `offset`, using the server's
    /// default isolation level.
    pub async fn consume(
        &self,
        topic: impl Into<String>,
        partition: u32,
        offset: u64,
        max_messages: usize,
    ) -> Result<Vec<MessageData>> {
        self.consume_with_isolation(topic, partition, offset, max_messages, None)
            .await
    }

    /// Consume with an explicit isolation level (`None` = server default;
    /// the u8 encoding is defined by `Client::consume_with_isolation` —
    /// here `Some(1)` is used to mean read-committed, see
    /// `consume_read_committed`).
    pub async fn consume_with_isolation(
        &self,
        topic: impl Into<String>,
        partition: u32,
        offset: u64,
        max_messages: usize,
        isolation_level: Option<u8>,
    ) -> Result<Vec<MessageData>> {
        let topic = topic.into();

        self.execute_with_retry(move |mut conn| {
            let topic = topic.clone();
            async move {
                let result = conn
                    .client
                    .consume_with_isolation(
                        &topic,
                        partition,
                        offset,
                        max_messages,
                        isolation_level,
                    )
                    .await;
                (conn, result)
            }
        })
        .await
    }

    /// Consume with read-committed isolation (isolation level `1`).
    pub async fn consume_read_committed(
        &self,
        topic: impl Into<String>,
        partition: u32,
        offset: u64,
        max_messages: usize,
    ) -> Result<Vec<MessageData>> {
        self.consume_with_isolation(topic, partition, offset, max_messages, Some(1))
            .await
    }

    /// Create a topic; `partitions: None` lets the server choose. Returns
    /// the u32 from `Client::create_topic` — presumably the partition count;
    /// confirm against the `Client` API.
    pub async fn create_topic(
        &self,
        name: impl Into<String>,
        partitions: Option<u32>,
    ) -> Result<u32> {
        let name = name.into();

        self.execute_with_retry(move |mut conn| {
            let name = name.clone();
            async move {
                let result = conn.client.create_topic(&name, partitions).await;
                (conn, result)
            }
        })
        .await
    }

    /// List all topic names, with retry/failover.
    pub async fn list_topics(&self) -> Result<Vec<String>> {
        self.execute_with_retry(|mut conn| async move {
            let result = conn.client.list_topics().await;
            (conn, result)
        })
        .await
    }

    /// Delete a topic by name, with retry/failover. NOTE(review): deletes
    /// are retried like reads; if the server's delete is not idempotent a
    /// retry after an ambiguous failure could surface a spurious error.
    pub async fn delete_topic(&self, name: impl Into<String>) -> Result<()> {
        let name = name.into();

        self.execute_with_retry(move |mut conn| {
            let name = name.clone();
            async move {
                let result = conn.client.delete_topic(&name).await;
                (conn, result)
            }
        })
        .await
    }

    /// Commit a consumer-group offset for a topic partition.
    pub async fn commit_offset(
        &self,
        consumer_group: impl Into<String>,
        topic: impl Into<String>,
        partition: u32,
        offset: u64,
    ) -> Result<()> {
        let consumer_group = consumer_group.into();
        let topic = topic.into();

        self.execute_with_retry(move |mut conn| {
            let consumer_group = consumer_group.clone();
            let topic = topic.clone();
            async move {
                let result = conn
                    .client
                    .commit_offset(&consumer_group, &topic, partition, offset)
                    .await;
                (conn, result)
            }
        })
        .await
    }

    /// Fetch the committed offset for a consumer group on a partition;
    /// `None` when no offset has been committed.
    pub async fn get_offset(
        &self,
        consumer_group: impl Into<String>,
        topic: impl Into<String>,
        partition: u32,
    ) -> Result<Option<u64>> {
        let consumer_group = consumer_group.into();
        let topic = topic.into();

        self.execute_with_retry(move |mut conn| {
            let consumer_group = consumer_group.clone();
            let topic = topic.clone();
            async move {
                let result = conn
                    .client
                    .get_offset(&consumer_group, &topic, partition)
                    .await;
                (conn, result)
            }
        })
        .await
    }

    /// Fetch the offset bounds for a partition — presumably (earliest,
    /// latest); confirm against `Client::get_offset_bounds`.
    pub async fn get_offset_bounds(
        &self,
        topic: impl Into<String>,
        partition: u32,
    ) -> Result<(u64, u64)> {
        let topic = topic.into();

        self.execute_with_retry(move |mut conn| {
            let topic = topic.clone();
            async move {
                let result = conn.client.get_offset_bounds(&topic, partition).await;
                (conn, result)
            }
        })
        .await
    }

    /// Fetch topic metadata as returned by `Client::get_metadata` (the
    /// meaning of the (String, u32) pair is defined there).
    pub async fn get_metadata(&self, topic: impl Into<String>) -> Result<(String, u32)> {
        let topic = topic.into();

        self.execute_with_retry(move |mut conn| {
            let topic = topic.clone();
            async move {
                let result = conn.client.get_metadata(&topic).await;
                (conn, result)
            }
        })
        .await
    }

    /// Round-trip liveness check against whichever server is selected next.
    pub async fn ping(&self) -> Result<()> {
        self.execute_with_retry(|mut conn| async move {
            let result = conn.client.ping().await;
            (conn, result)
        })
        .await
    }

    /// Snapshot of lifetime request/failure counters and each server's
    /// current circuit-breaker state.
    pub fn stats(&self) -> ClientStats {
        let pools: Vec<_> = self
            .pools
            .iter()
            .map(|(addr, pool)| ServerStats {
                address: addr.clone(),
                circuit_state: pool.circuit_state(),
            })
            .collect();

        ClientStats {
            total_requests: self.total_requests.load(Ordering::Relaxed),
            total_failures: self.total_failures.load(Ordering::Relaxed),
            servers: pools,
        }
    }
}
800
801impl Drop for ResilientClient {
802 fn drop(&mut self) {
803 if let Some(handle) = self._health_check_handle.take() {
804 handle.abort();
805 }
806 }
807}
808
/// Point-in-time client metrics returned by `ResilientClient::stats`.
#[derive(Debug, Clone)]
pub struct ClientStats {
    /// Logical requests started (each may span several retry attempts).
    pub total_requests: u64,
    /// Failed attempts observed (errors and timeouts), across all retries.
    pub total_failures: u64,
    /// Per-server circuit-breaker snapshot.
    pub servers: Vec<ServerStats>,
}
816
/// Circuit-breaker snapshot for a single server.
#[derive(Debug, Clone)]
pub struct ServerStats {
    /// Server address as configured in `bootstrap_servers`.
    pub address: String,
    /// Breaker state at the time `stats()` was called.
    pub circuit_state: CircuitState,
}
823
824fn is_retryable_error(error: &Error) -> bool {
830 matches!(
831 error,
832 Error::ConnectionError(_) | Error::IoError(_) | Error::CircuitBreakerOpen(_)
833 )
834}
835
/// Exponential backoff with jitter: `initial_delay * multiplier^attempt`,
/// capped at `max_delay`, then perturbed by +/-25% so concurrent retries
/// spread out instead of stampeding.
fn calculate_backoff(
    attempt: u32,
    initial_delay: Duration,
    max_delay: Duration,
    multiplier: f64,
) -> Duration {
    let growth = multiplier.powi(attempt as i32);
    let raw_ms = initial_delay.as_millis() as f64 * growth;
    let capped_ms = raw_ms.min(max_delay.as_millis() as f64);

    // Jitter in [-0.25, +0.25) of the capped delay; clamp at zero.
    let jitter_ms = (rand_simple() * 0.5 - 0.25) * capped_ms;
    let final_ms = (capped_ms + jitter_ms).max(0.0);

    Duration::from_millis(final_ms as u64)
}

/// Cheap pseudo-random value in [0, 1), derived from the wall clock's
/// sub-second nanoseconds. Good enough for retry jitter; not suitable for
/// anything needing real randomness.
fn rand_simple() -> f64 {
    use std::time::SystemTime;
    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap()
        .subsec_nanos();
    (nanos % 1000) as f64 / 1000.0
}
862
#[cfg(test)]
mod tests {
    use super::*;

    // Builder overrides exactly the fields that were set.
    #[test]
    fn test_config_builder() {
        let config = ResilientClientConfig::builder()
            .bootstrap_servers(vec!["server1:9092".to_string(), "server2:9092".to_string()])
            .pool_size(10)
            .retry_max_attempts(5)
            .circuit_breaker_threshold(10)
            .connection_timeout(Duration::from_secs(5))
            .build();

        assert_eq!(config.bootstrap_servers.len(), 2);
        assert_eq!(config.pool_size, 10);
        assert_eq!(config.retry_max_attempts, 5);
        assert_eq!(config.circuit_breaker_threshold, 10);
        assert_eq!(config.connection_timeout, Duration::from_secs(5));
    }

    // Backoff grows exponentially and stays inside the +/-25% jitter band.
    #[test]
    fn test_calculate_backoff() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_secs(10);

        // Attempt 0: 100ms base -> [75, 125].
        let delay = calculate_backoff(0, initial, max, 2.0);
        assert!(delay.as_millis() >= 75 && delay.as_millis() <= 125);

        // Attempt 1: 200ms base -> [150, 250].
        let delay = calculate_backoff(1, initial, max, 2.0);
        assert!(delay.as_millis() >= 150 && delay.as_millis() <= 250);

        // Very large attempts are capped at max (+25% jitter headroom).
        let delay = calculate_backoff(20, initial, max, 2.0);
        assert!(delay <= max + Duration::from_millis(2500));
    }

    // Connection-level errors retry; protocol/server errors do not.
    #[test]
    fn test_is_retryable_error() {
        assert!(is_retryable_error(&Error::ConnectionError("test".into())));
        assert!(is_retryable_error(&Error::CircuitBreakerOpen(
            "test".into()
        )));
        assert!(!is_retryable_error(&Error::InvalidResponse));
        assert!(!is_retryable_error(&Error::ServerError("test".into())));
    }

    // A fresh breaker decodes its packed state as Closed.
    #[test]
    fn test_circuit_state() {
        let config = Arc::new(ResilientClientConfig::default());
        let cb = CircuitBreaker::new(config);

        assert_eq!(cb.get_state(), CircuitState::Closed);
    }

    // Closed breaker admits requests.
    #[tokio::test]
    async fn test_circuit_breaker_starts_closed() {
        let config = Arc::new(ResilientClientConfig::default());
        let cb = CircuitBreaker::new(config);

        assert_eq!(cb.get_state(), CircuitState::Closed);
        assert!(cb.allow_request().await);
    }

    // The breaker opens exactly at the configured failure threshold.
    #[tokio::test]
    async fn test_circuit_breaker_opens_after_threshold_failures() {
        let config = Arc::new(
            ResilientClientConfig::builder()
                .circuit_breaker_threshold(3)
                .build(),
        );
        let cb = CircuitBreaker::new(config);

        assert_eq!(cb.get_state(), CircuitState::Closed);

        // Two failures: still below the threshold of 3.
        cb.record_failure().await;
        assert_eq!(cb.get_state(), CircuitState::Closed);
        cb.record_failure().await;
        assert_eq!(cb.get_state(), CircuitState::Closed);

        // Third failure trips the breaker; requests are now rejected.
        cb.record_failure().await;
        assert_eq!(cb.get_state(), CircuitState::Open);
        assert!(!cb.allow_request().await);
    }

    // Any success wipes the failure streak while Closed.
    #[tokio::test]
    async fn test_circuit_breaker_success_resets_failure_count() {
        let config = Arc::new(
            ResilientClientConfig::builder()
                .circuit_breaker_threshold(3)
                .build(),
        );
        let cb = CircuitBreaker::new(config);

        cb.record_failure().await;
        cb.record_failure().await;
        assert_eq!(cb.failure_count.load(Ordering::SeqCst), 2);

        cb.record_success().await;
        assert_eq!(cb.failure_count.load(Ordering::SeqCst), 0);
        assert_eq!(cb.get_state(), CircuitState::Closed);
    }

    // After the cool-down, allow_request() flips Open -> HalfOpen and
    // admits a probe.
    #[tokio::test]
    async fn test_circuit_breaker_half_open_after_timeout() {
        let config = Arc::new(
            ResilientClientConfig::builder()
                .circuit_breaker_threshold(1)
                .circuit_breaker_timeout(Duration::from_millis(50))
                .build(),
        );
        let cb = CircuitBreaker::new(config);

        cb.record_failure().await;
        assert_eq!(cb.get_state(), CircuitState::Open);
        assert!(!cb.allow_request().await);

        // Sleep past the 50ms cool-down.
        tokio::time::sleep(Duration::from_millis(100)).await;

        assert!(cb.allow_request().await);
        assert_eq!(cb.get_state(), CircuitState::HalfOpen);
    }

    // Default success threshold is 2: the breaker closes on the second
    // consecutive probe success, not the first.
    #[tokio::test]
    async fn test_circuit_breaker_closes_after_success_threshold() {
        let config = Arc::new(
            ResilientClientConfig::builder()
                .circuit_breaker_threshold(1)
                .circuit_breaker_timeout(Duration::from_millis(10))
                .build(),
        );
        let cb = CircuitBreaker::new(config);

        cb.record_failure().await;
        assert_eq!(cb.get_state(), CircuitState::Open);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(cb.allow_request().await);
        assert_eq!(cb.get_state(), CircuitState::HalfOpen);

        cb.record_success().await;
        assert_eq!(cb.get_state(), CircuitState::HalfOpen);

        cb.record_success().await;
        assert_eq!(cb.get_state(), CircuitState::Closed);
    }

    // A failed probe while HalfOpen re-opens the breaker.
    #[tokio::test]
    async fn test_circuit_breaker_failure_in_half_open_reopens() {
        let config = Arc::new(
            ResilientClientConfig::builder()
                .circuit_breaker_threshold(1)
                .circuit_breaker_timeout(Duration::from_millis(10))
                .build(),
        );
        let cb = CircuitBreaker::new(config);

        cb.record_failure().await;
        assert_eq!(cb.get_state(), CircuitState::Open);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(cb.allow_request().await);
        assert_eq!(cb.get_state(), CircuitState::HalfOpen);

        cb.record_failure().await;
        assert_eq!(cb.get_state(), CircuitState::Open);
    }

    // Defaults documented by the Default impl.
    #[test]
    fn test_pool_config_defaults() {
        let config = ResilientClientConfig::default();
        assert_eq!(config.pool_size, 5);
        assert_eq!(config.retry_max_attempts, 3);
        assert_eq!(config.circuit_breaker_threshold, 5);
        assert_eq!(config.circuit_breaker_success_threshold, 2);
    }

    // Pool construction only; no server is dialed (connections are lazy).
    #[tokio::test]
    async fn test_pool_semaphore_limits_concurrent_connections() {
        let config = Arc::new(ResilientClientConfig::builder().pool_size(2).build());
        let pool = ConnectionPool::new("localhost:9999".to_string(), config);

        assert_eq!(pool.addr, "localhost:9999");
    }

    // Cap holds for arbitrarily large attempt counts (+25% jitter of 1s).
    #[test]
    fn test_backoff_respects_max_delay() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_secs(1);

        for attempt in 10..20 {
            let delay = calculate_backoff(attempt, initial, max, 2.0);
            assert!(delay <= max + Duration::from_millis(250));
        }
    }

    // Successive attempts grow; halved comparisons absorb the jitter.
    #[test]
    fn test_backoff_exponential_growth() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_secs(100);

        let delay0 = calculate_backoff(0, initial, max, 2.0);
        let delay1 = calculate_backoff(1, initial, max, 2.0);
        let delay2 = calculate_backoff(2, initial, max, 2.0);

        assert!(delay1 > delay0 / 2);
        assert!(delay2 > delay1 / 2);
    }

    // Plain data-shape check on the stats structs.
    #[test]
    fn test_client_stats_structure() {
        let stats = ClientStats {
            total_requests: 100,
            total_failures: 5,
            servers: vec![
                ServerStats {
                    address: "server1:9092".to_string(),
                    circuit_state: CircuitState::Closed,
                },
                ServerStats {
                    address: "server2:9092".to_string(),
                    circuit_state: CircuitState::Open,
                },
            ],
        };

        assert_eq!(stats.total_requests, 100);
        assert_eq!(stats.total_failures, 5);
        assert_eq!(stats.servers.len(), 2);
        assert_eq!(stats.servers[0].circuit_state, CircuitState::Closed);
        assert_eq!(stats.servers[1].circuit_state, CircuitState::Open);
    }
}
1136}