1use crate::{Client, Error, MessageData, Result};
30use bytes::Bytes;
31use std::collections::HashMap;
32use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
33use std::sync::Arc;
34use std::time::{Duration, Instant};
35use tokio::sync::{Mutex, RwLock, Semaphore};
36use tokio::time::{sleep, timeout};
37use tracing::{debug, info, warn};
38
/// Settings for [`ResilientClient`]: server list, pooling, retry backoff,
/// circuit breaking, timeouts and background health checks.
#[derive(Debug, Clone)]
pub struct ResilientClientConfig {
    /// Server addresses; one connection pool is created per entry.
    pub bootstrap_servers: Vec<String>,
    /// Semaphore size of each pool and the maximum number of idle connections it keeps.
    pub pool_size: usize,
    /// Total number of attempts made per request (the first try included).
    pub retry_max_attempts: u32,
    /// Backoff delay used for the first retry, before jitter.
    pub retry_initial_delay: Duration,
    /// Upper bound applied to the backoff delay before jitter is added.
    pub retry_max_delay: Duration,
    /// Factor by which the backoff delay grows with each attempt.
    pub retry_multiplier: f64,
    /// Number of recorded failures at which a server's circuit breaker opens.
    pub circuit_breaker_threshold: u32,
    /// Time that must pass after the last failure before an open breaker lets a request through.
    pub circuit_breaker_timeout: Duration,
    /// Number of successes in the half-open state needed to close the breaker.
    pub circuit_breaker_success_threshold: u32,
    /// Time limit for establishing a new connection.
    pub connection_timeout: Duration,
    /// Time limit for a single attempt of an operation.
    pub request_timeout: Duration,
    /// Pause between two rounds of background health checks.
    pub health_check_interval: Duration,
    /// Whether the background health-check task is started at all.
    pub health_check_enabled: bool,
}
73
74impl Default for ResilientClientConfig {
75 fn default() -> Self {
76 Self {
77 bootstrap_servers: vec!["localhost:9092".to_string()],
78 pool_size: 5,
79 retry_max_attempts: 3,
80 retry_initial_delay: Duration::from_millis(100),
81 retry_max_delay: Duration::from_secs(10),
82 retry_multiplier: 2.0,
83 circuit_breaker_threshold: 5,
84 circuit_breaker_timeout: Duration::from_secs(30),
85 circuit_breaker_success_threshold: 2,
86 connection_timeout: Duration::from_secs(10),
87 request_timeout: Duration::from_secs(30),
88 health_check_interval: Duration::from_secs(30),
89 health_check_enabled: true,
90 }
91 }
92}
93
94impl ResilientClientConfig {
95 pub fn builder() -> ResilientClientConfigBuilder {
97 ResilientClientConfigBuilder::default()
98 }
99}
100
101#[derive(Default)]
103pub struct ResilientClientConfigBuilder {
104 config: ResilientClientConfig,
105}
106
107impl ResilientClientConfigBuilder {
108 pub fn bootstrap_servers(mut self, servers: Vec<String>) -> Self {
110 self.config.bootstrap_servers = servers;
111 self
112 }
113
114 pub fn pool_size(mut self, size: usize) -> Self {
116 self.config.pool_size = size;
117 self
118 }
119
120 pub fn retry_max_attempts(mut self, attempts: u32) -> Self {
122 self.config.retry_max_attempts = attempts;
123 self
124 }
125
126 pub fn retry_initial_delay(mut self, delay: Duration) -> Self {
128 self.config.retry_initial_delay = delay;
129 self
130 }
131
132 pub fn retry_max_delay(mut self, delay: Duration) -> Self {
134 self.config.retry_max_delay = delay;
135 self
136 }
137
138 pub fn retry_multiplier(mut self, multiplier: f64) -> Self {
140 self.config.retry_multiplier = multiplier;
141 self
142 }
143
144 pub fn circuit_breaker_threshold(mut self, threshold: u32) -> Self {
146 self.config.circuit_breaker_threshold = threshold;
147 self
148 }
149
150 pub fn circuit_breaker_timeout(mut self, timeout: Duration) -> Self {
152 self.config.circuit_breaker_timeout = timeout;
153 self
154 }
155
156 pub fn connection_timeout(mut self, timeout: Duration) -> Self {
158 self.config.connection_timeout = timeout;
159 self
160 }
161
162 pub fn request_timeout(mut self, timeout: Duration) -> Self {
164 self.config.request_timeout = timeout;
165 self
166 }
167
168 pub fn health_check_enabled(mut self, enabled: bool) -> Self {
170 self.config.health_check_enabled = enabled;
171 self
172 }
173
174 pub fn health_check_interval(mut self, interval: Duration) -> Self {
176 self.config.health_check_interval = interval;
177 self
178 }
179
180 pub fn build(self) -> ResilientClientConfig {
182 self.config
183 }
184}
185
/// Observable state of a per-server circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests are allowed through.
    Closed,
    /// Requests are rejected until the breaker timeout has elapsed.
    Open,
    /// Requests are allowed through as probes; enough successes close the breaker.
    HalfOpen,
}
197
198struct CircuitBreaker {
200 state: AtomicU32,
201 failure_count: AtomicU32,
202 success_count: AtomicU32,
203 last_failure: RwLock<Option<Instant>>,
204 config: Arc<ResilientClientConfig>,
205}
206
207impl CircuitBreaker {
208 fn new(config: Arc<ResilientClientConfig>) -> Self {
209 Self {
210 state: AtomicU32::new(0), failure_count: AtomicU32::new(0),
212 success_count: AtomicU32::new(0),
213 last_failure: RwLock::new(None),
214 config,
215 }
216 }
217
218 fn get_state(&self) -> CircuitState {
219 match self.state.load(Ordering::SeqCst) {
220 0 => CircuitState::Closed,
221 1 => CircuitState::Open,
222 _ => CircuitState::HalfOpen,
223 }
224 }
225
226 async fn allow_request(&self) -> bool {
227 match self.get_state() {
228 CircuitState::Closed => true,
229 CircuitState::Open => {
230 let last_failure = self.last_failure.read().await;
231 if let Some(t) = *last_failure {
232 if t.elapsed() > self.config.circuit_breaker_timeout {
233 self.state.store(2, Ordering::SeqCst); self.success_count.store(0, Ordering::SeqCst);
235 return true;
236 }
237 }
238 false
239 }
240 CircuitState::HalfOpen => true,
241 }
242 }
243
244 async fn record_success(&self) {
245 self.failure_count.store(0, Ordering::SeqCst);
246
247 if self.get_state() == CircuitState::HalfOpen {
248 let count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
249 if count >= self.config.circuit_breaker_success_threshold {
250 self.state.store(0, Ordering::SeqCst); debug!("Circuit breaker closed after {} successes", count);
252 }
253 }
254 }
255
256 async fn record_failure(&self) {
257 let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
258 *self.last_failure.write().await = Some(Instant::now());
259
260 if count >= self.config.circuit_breaker_threshold {
261 self.state.store(1, Ordering::SeqCst); warn!("Circuit breaker opened after {} failures", count);
263 }
264 }
265}
266
267struct PooledConnection {
273 client: Client,
274 created_at: Instant,
275 last_used: Instant,
276}
277
278struct ConnectionPool {
280 addr: String,
281 connections: Mutex<Vec<PooledConnection>>,
282 semaphore: Semaphore,
283 config: Arc<ResilientClientConfig>,
284 circuit_breaker: CircuitBreaker,
285}
286
287impl ConnectionPool {
288 fn new(addr: String, config: Arc<ResilientClientConfig>) -> Self {
289 Self {
290 addr,
291 connections: Mutex::new(Vec::new()),
292 semaphore: Semaphore::new(config.pool_size),
293 circuit_breaker: CircuitBreaker::new(config.clone()),
294 config,
295 }
296 }
297
298 async fn get(&self) -> Result<PooledConnection> {
299 if !self.circuit_breaker.allow_request().await {
301 return Err(Error::CircuitBreakerOpen(self.addr.clone()));
302 }
303
304 let _permit = self
306 .semaphore
307 .acquire()
308 .await
309 .map_err(|_| Error::ConnectionError("Pool exhausted".to_string()))?;
310
311 {
313 let mut connections = self.connections.lock().await;
314 if let Some(mut conn) = connections.pop() {
315 conn.last_used = Instant::now();
316 return Ok(conn);
317 }
318 }
319
320 let client = timeout(self.config.connection_timeout, Client::connect(&self.addr))
322 .await
323 .map_err(|_| Error::ConnectionError(format!("Connection timeout to {}", self.addr)))?
324 .map_err(|e| {
325 Error::ConnectionError(format!("Failed to connect to {}: {}", self.addr, e))
326 })?;
327
328 Ok(PooledConnection {
329 client,
330 created_at: Instant::now(),
331 last_used: Instant::now(),
332 })
333 }
334
335 async fn put(&self, conn: PooledConnection) {
336 if conn.created_at.elapsed() < Duration::from_secs(300) {
338 let mut connections = self.connections.lock().await;
339 if connections.len() < self.config.pool_size {
340 connections.push(conn);
341 }
342 }
343 }
344
345 async fn record_success(&self) {
346 self.circuit_breaker.record_success().await;
347 }
348
349 async fn record_failure(&self) {
350 self.circuit_breaker.record_failure().await;
351 }
352
353 fn circuit_state(&self) -> CircuitState {
354 self.circuit_breaker.get_state()
355 }
356}
357
358pub struct ResilientClient {
364 pools: HashMap<String, Arc<ConnectionPool>>,
365 config: Arc<ResilientClientConfig>,
366 current_server: AtomicU64,
367 total_requests: AtomicU64,
368 total_failures: AtomicU64,
369 _health_check_handle: Option<tokio::task::JoinHandle<()>>,
370}
371
372impl ResilientClient {
373 pub async fn new(config: ResilientClientConfig) -> Result<Self> {
375 if config.bootstrap_servers.is_empty() {
376 return Err(Error::ConnectionError(
377 "No bootstrap servers configured".to_string(),
378 ));
379 }
380
381 let config = Arc::new(config);
382 let mut pools = HashMap::new();
383
384 for server in &config.bootstrap_servers {
385 let pool = Arc::new(ConnectionPool::new(server.clone(), config.clone()));
386 pools.insert(server.clone(), pool);
387 }
388
389 info!(
390 "Resilient client initialized with {} servers, pool size {}",
391 config.bootstrap_servers.len(),
392 config.pool_size
393 );
394
395 let mut client = Self {
396 pools,
397 config: config.clone(),
398 current_server: AtomicU64::new(0),
399 total_requests: AtomicU64::new(0),
400 total_failures: AtomicU64::new(0),
401 _health_check_handle: None,
402 };
403
404 if config.health_check_enabled {
406 let pools_clone: HashMap<String, Arc<ConnectionPool>> = client
407 .pools
408 .iter()
409 .map(|(k, v)| (k.clone(), v.clone()))
410 .collect();
411 let interval = config.health_check_interval;
412
413 let handle = tokio::spawn(async move {
414 loop {
415 sleep(interval).await;
416 for (addr, pool) in &pools_clone {
417 if let Ok(mut conn) = pool.get().await {
418 match conn.client.ping().await {
419 Ok(()) => {
420 pool.record_success().await;
421 debug!("Health check passed for {}", addr);
422 }
423 Err(e) => {
424 pool.record_failure().await;
425 warn!("Health check failed for {}: {}", addr, e);
426 }
427 }
428 pool.put(conn).await;
429 }
430 }
431 }
432 });
433
434 client._health_check_handle = Some(handle);
435 }
436
437 Ok(client)
438 }
439
440 async fn execute_with_retry<F, T, Fut>(&self, operation: F) -> Result<T>
442 where
443 F: Fn(PooledConnection) -> Fut + Clone,
444 Fut: std::future::Future<Output = (PooledConnection, Result<T>)>,
445 {
446 self.total_requests.fetch_add(1, Ordering::Relaxed);
447 let servers: Vec<_> = self.config.bootstrap_servers.clone();
448 let num_servers = servers.len();
449
450 for attempt in 0..self.config.retry_max_attempts {
451 let server_idx =
453 (self.current_server.fetch_add(1, Ordering::Relaxed) as usize) % num_servers;
454 let server = &servers[server_idx];
455
456 let pool = match self.pools.get(server) {
457 Some(p) => p,
458 None => continue,
459 };
460
461 if pool.circuit_state() == CircuitState::Open {
463 debug!("Skipping {} (circuit breaker open)", server);
464 continue;
465 }
466
467 let conn = match pool.get().await {
469 Ok(c) => c,
470 Err(e) => {
471 warn!("Failed to get connection from {}: {}", server, e);
472 pool.record_failure().await;
473 continue;
474 }
475 };
476
477 let result = timeout(self.config.request_timeout, (operation.clone())(conn)).await;
479
480 match result {
481 Ok((conn, Ok(value))) => {
482 pool.record_success().await;
483 pool.put(conn).await;
484 return Ok(value);
485 }
486 Ok((conn, Err(e))) => {
487 self.total_failures.fetch_add(1, Ordering::Relaxed);
488 pool.record_failure().await;
489
490 if is_retryable_error(&e) && attempt < self.config.retry_max_attempts - 1 {
492 let delay = calculate_backoff(
493 attempt,
494 self.config.retry_initial_delay,
495 self.config.retry_max_delay,
496 self.config.retry_multiplier,
497 );
498 warn!(
499 "Retryable error on attempt {}: {}. Retrying in {:?}",
500 attempt + 1,
501 e,
502 delay
503 );
504 pool.put(conn).await;
505 sleep(delay).await;
506 continue;
507 }
508
509 return Err(e);
510 }
511 Err(_) => {
512 self.total_failures.fetch_add(1, Ordering::Relaxed);
513 pool.record_failure().await;
514 warn!("Request timeout to {}", server);
515
516 if attempt < self.config.retry_max_attempts - 1 {
517 let delay = calculate_backoff(
518 attempt,
519 self.config.retry_initial_delay,
520 self.config.retry_max_delay,
521 self.config.retry_multiplier,
522 );
523 sleep(delay).await;
524 }
525 }
526 }
527 }
528
529 Err(Error::ConnectionError(format!(
530 "All {} retry attempts exhausted",
531 self.config.retry_max_attempts
532 )))
533 }
534
535 pub async fn publish(&self, topic: impl Into<String>, value: impl Into<Bytes>) -> Result<u64> {
537 let topic = topic.into();
538 let value = value.into();
539
540 self.execute_with_retry(move |mut conn| {
541 let topic = topic.clone();
542 let value = value.clone();
543 async move {
544 let result = conn.client.publish(&topic, value).await;
545 (conn, result)
546 }
547 })
548 .await
549 }
550
551 pub async fn publish_with_key(
553 &self,
554 topic: impl Into<String>,
555 key: Option<impl Into<Bytes>>,
556 value: impl Into<Bytes>,
557 ) -> Result<u64> {
558 let topic = topic.into();
559 let key: Option<Bytes> = key.map(|k| k.into());
560 let value = value.into();
561
562 self.execute_with_retry(move |mut conn| {
563 let topic = topic.clone();
564 let key = key.clone();
565 let value = value.clone();
566 async move {
567 let result = conn.client.publish_with_key(&topic, key, value).await;
568 (conn, result)
569 }
570 })
571 .await
572 }
573
574 pub async fn consume(
576 &self,
577 topic: impl Into<String>,
578 partition: u32,
579 offset: u64,
580 max_messages: usize,
581 ) -> Result<Vec<MessageData>> {
582 let topic = topic.into();
583
584 self.execute_with_retry(move |mut conn| {
585 let topic = topic.clone();
586 async move {
587 let result = conn
588 .client
589 .consume(&topic, partition, offset, max_messages)
590 .await;
591 (conn, result)
592 }
593 })
594 .await
595 }
596
597 pub async fn create_topic(
599 &self,
600 name: impl Into<String>,
601 partitions: Option<u32>,
602 ) -> Result<u32> {
603 let name = name.into();
604
605 self.execute_with_retry(move |mut conn| {
606 let name = name.clone();
607 async move {
608 let result = conn.client.create_topic(&name, partitions).await;
609 (conn, result)
610 }
611 })
612 .await
613 }
614
615 pub async fn list_topics(&self) -> Result<Vec<String>> {
617 self.execute_with_retry(|mut conn| async move {
618 let result = conn.client.list_topics().await;
619 (conn, result)
620 })
621 .await
622 }
623
624 pub async fn delete_topic(&self, name: impl Into<String>) -> Result<()> {
626 let name = name.into();
627
628 self.execute_with_retry(move |mut conn| {
629 let name = name.clone();
630 async move {
631 let result = conn.client.delete_topic(&name).await;
632 (conn, result)
633 }
634 })
635 .await
636 }
637
638 pub async fn commit_offset(
640 &self,
641 consumer_group: impl Into<String>,
642 topic: impl Into<String>,
643 partition: u32,
644 offset: u64,
645 ) -> Result<()> {
646 let consumer_group = consumer_group.into();
647 let topic = topic.into();
648
649 self.execute_with_retry(move |mut conn| {
650 let consumer_group = consumer_group.clone();
651 let topic = topic.clone();
652 async move {
653 let result = conn
654 .client
655 .commit_offset(&consumer_group, &topic, partition, offset)
656 .await;
657 (conn, result)
658 }
659 })
660 .await
661 }
662
663 pub async fn get_offset(
665 &self,
666 consumer_group: impl Into<String>,
667 topic: impl Into<String>,
668 partition: u32,
669 ) -> Result<Option<u64>> {
670 let consumer_group = consumer_group.into();
671 let topic = topic.into();
672
673 self.execute_with_retry(move |mut conn| {
674 let consumer_group = consumer_group.clone();
675 let topic = topic.clone();
676 async move {
677 let result = conn
678 .client
679 .get_offset(&consumer_group, &topic, partition)
680 .await;
681 (conn, result)
682 }
683 })
684 .await
685 }
686
687 pub async fn get_offset_bounds(
689 &self,
690 topic: impl Into<String>,
691 partition: u32,
692 ) -> Result<(u64, u64)> {
693 let topic = topic.into();
694
695 self.execute_with_retry(move |mut conn| {
696 let topic = topic.clone();
697 async move {
698 let result = conn.client.get_offset_bounds(&topic, partition).await;
699 (conn, result)
700 }
701 })
702 .await
703 }
704
705 pub async fn get_metadata(&self, topic: impl Into<String>) -> Result<(String, u32)> {
707 let topic = topic.into();
708
709 self.execute_with_retry(move |mut conn| {
710 let topic = topic.clone();
711 async move {
712 let result = conn.client.get_metadata(&topic).await;
713 (conn, result)
714 }
715 })
716 .await
717 }
718
719 pub async fn ping(&self) -> Result<()> {
721 self.execute_with_retry(|mut conn| async move {
722 let result = conn.client.ping().await;
723 (conn, result)
724 })
725 .await
726 }
727
728 pub fn stats(&self) -> ClientStats {
730 let pools: Vec<_> = self
731 .pools
732 .iter()
733 .map(|(addr, pool)| ServerStats {
734 address: addr.clone(),
735 circuit_state: pool.circuit_state(),
736 })
737 .collect();
738
739 ClientStats {
740 total_requests: self.total_requests.load(Ordering::Relaxed),
741 total_failures: self.total_failures.load(Ordering::Relaxed),
742 servers: pools,
743 }
744 }
745}
746
747#[derive(Debug, Clone)]
749pub struct ClientStats {
750 pub total_requests: u64,
751 pub total_failures: u64,
752 pub servers: Vec<ServerStats>,
753}
754
755#[derive(Debug, Clone)]
757pub struct ServerStats {
758 pub address: String,
759 pub circuit_state: CircuitState,
760}
761
762fn is_retryable_error(error: &Error) -> bool {
768 matches!(
769 error,
770 Error::ConnectionError(_) | Error::IoError(_) | Error::CircuitBreakerOpen(_)
771 )
772}
773
774fn calculate_backoff(
776 attempt: u32,
777 initial_delay: Duration,
778 max_delay: Duration,
779 multiplier: f64,
780) -> Duration {
781 let base_delay = initial_delay.as_millis() as f64 * multiplier.powi(attempt as i32);
782 let capped_delay = base_delay.min(max_delay.as_millis() as f64);
783
784 let jitter = (rand_simple() * 0.5 - 0.25) * capped_delay;
786 let final_delay = (capped_delay + jitter).max(0.0);
787
788 Duration::from_millis(final_delay as u64)
789}
790
/// Returns a pseudo-random value in `[0.0, 1.0)` using only the standard library.
///
/// The value comes from hashing nothing with a freshly created
/// `RandomState`, whose keys are randomly seeded and differ between
/// instances. Unlike sampling the wall clock's sub-microsecond digits, this
/// cannot panic when the system clock is set before the Unix epoch and does
/// not collapse to a constant on clocks with coarse resolution. It is meant
/// for retry jitter, not for anything security-sensitive.
fn rand_simple() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let bits = RandomState::new().build_hasher().finish();

    // Keep 53 bits so the integer is exactly representable as an `f64` and
    // the quotient stays strictly below 1.0.
    (bits >> 11) as f64 / (1u64 << 53) as f64
}
800
#[cfg(test)]
mod tests {
    use super::*;

    /// Values passed to the builder end up in the matching config fields.
    #[test]
    fn test_config_builder() {
        let servers = vec!["server1:9092".to_string(), "server2:9092".to_string()];
        let built = ResilientClientConfig::builder()
            .bootstrap_servers(servers)
            .pool_size(10)
            .retry_max_attempts(5)
            .circuit_breaker_threshold(10)
            .connection_timeout(Duration::from_secs(5))
            .build();

        assert_eq!(built.bootstrap_servers.len(), 2);
        assert_eq!(built.pool_size, 10);
        assert_eq!(built.retry_max_attempts, 5);
        assert_eq!(built.circuit_breaker_threshold, 10);
        assert_eq!(built.connection_timeout, Duration::from_secs(5));
    }

    /// Backoff doubles per attempt, stays within the jitter band, and is capped.
    #[test]
    fn test_calculate_backoff() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_secs(10);

        // Attempt 0: 100 ms base, give or take 25%.
        let first = calculate_backoff(0, initial, max, 2.0).as_millis();
        assert!(first >= 75 && first <= 125);

        // Attempt 1: 200 ms base, give or take 25%.
        let second = calculate_backoff(1, initial, max, 2.0).as_millis();
        assert!(second >= 150 && second <= 250);

        // Far-out attempt: capped at `max`, plus at most 25% jitter.
        let capped = calculate_backoff(20, initial, max, 2.0);
        assert!(capped <= max + Duration::from_millis(2500));
    }

    /// Transport-level errors are retryable; protocol and server errors are not.
    #[test]
    fn test_is_retryable_error() {
        let connection = Error::ConnectionError("test".into());
        let breaker = Error::CircuitBreakerOpen("test".into());
        assert!(is_retryable_error(&connection));
        assert!(is_retryable_error(&breaker));

        assert!(!is_retryable_error(&Error::InvalidResponse));
        assert!(!is_retryable_error(&Error::ServerError("test".into())));
    }

    /// A freshly created breaker starts out closed.
    #[test]
    fn test_circuit_state() {
        let breaker = CircuitBreaker::new(Arc::new(ResilientClientConfig::default()));

        assert_eq!(breaker.get_state(), CircuitState::Closed);
    }
}