1use crate::rpc::protocol::{CoordinatorServiceClient, ShardServiceClient};
8use crate::types::ShardId;
9use std::collections::HashMap;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use std::time::Duration;
13use tarpc::client::Config;
14use tokio::sync::RwLock;
15use tokio_serde::formats::Bincode;
16use tracing::{debug, error, info, warn};
17
/// Default upper bound, in milliseconds, on establishing a TCP connection.
const DEFAULT_CONNECT_TIMEOUT_MS: u64 = 5_000;

/// Default number of connection attempts made by the `*_with_retry` helpers.
const DEFAULT_RETRY_ATTEMPTS: u32 = 3;

/// Default pause, in milliseconds, between consecutive connection attempts.
const DEFAULT_RETRY_DELAY_MS: u64 = 500;
26
/// Tunable settings used when establishing RPC client connections.
///
/// Start from [`ClientConfig::default`] and override individual fields with
/// struct-update syntax, e.g. `ClientConfig { retry_attempts: 5, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientConfig {
    /// Upper bound on how long a single TCP connect may take.
    pub connect_timeout: Duration,
    /// Number of connection attempts made by the `*_with_retry` helpers.
    pub retry_attempts: u32,
    /// Pause between consecutive connection attempts.
    pub retry_delay: Duration,
    /// Passed to tarpc as the client's `max_in_flight_requests` limit.
    pub max_pending_requests: usize,
}
39
40impl Default for ClientConfig {
41 fn default() -> Self {
42 Self {
43 connect_timeout: Duration::from_millis(DEFAULT_CONNECT_TIMEOUT_MS),
44 retry_attempts: DEFAULT_RETRY_ATTEMPTS,
45 retry_delay: Duration::from_millis(DEFAULT_RETRY_DELAY_MS),
46 max_pending_requests: 100,
47 }
48 }
49}
50
51pub async fn connect_to_shard(addr: SocketAddr) -> Result<ShardServiceClient, std::io::Error> {
75 debug!("Connecting to shard at {}", addr);
76 let transport = tarpc::serde_transport::tcp::connect(addr, Bincode::default).await?;
77 let client = ShardServiceClient::new(Config::default(), transport).spawn();
78 info!("Connected to shard at {}", addr);
79 Ok(client)
80}
81
82pub async fn connect_to_shard_with_config(
93 addr: SocketAddr,
94 config: &ClientConfig,
95) -> Result<ShardServiceClient, std::io::Error> {
96 debug!("Connecting to shard at {} with custom config", addr);
97
98 let transport = tokio::time::timeout(
99 config.connect_timeout,
100 tarpc::serde_transport::tcp::connect(addr, Bincode::default),
101 )
102 .await
103 .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "connection timeout"))??;
104
105 let mut tarpc_config = Config::default();
106 tarpc_config.max_in_flight_requests = config.max_pending_requests;
107
108 let client = ShardServiceClient::new(tarpc_config, transport).spawn();
109 info!("Connected to shard at {}", addr);
110 Ok(client)
111}
112
113pub async fn connect_to_coordinator(
137 addr: SocketAddr,
138) -> Result<CoordinatorServiceClient, std::io::Error> {
139 debug!("Connecting to coordinator at {}", addr);
140 let transport = tarpc::serde_transport::tcp::connect(addr, Bincode::default).await?;
141 let client = CoordinatorServiceClient::new(Config::default(), transport).spawn();
142 info!("Connected to coordinator at {}", addr);
143 Ok(client)
144}
145
146pub async fn connect_to_coordinator_with_config(
157 addr: SocketAddr,
158 config: &ClientConfig,
159) -> Result<CoordinatorServiceClient, std::io::Error> {
160 debug!("Connecting to coordinator at {} with custom config", addr);
161
162 let transport = tokio::time::timeout(
163 config.connect_timeout,
164 tarpc::serde_transport::tcp::connect(addr, Bincode::default),
165 )
166 .await
167 .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "connection timeout"))??;
168
169 let mut tarpc_config = Config::default();
170 tarpc_config.max_in_flight_requests = config.max_pending_requests;
171
172 let client = CoordinatorServiceClient::new(tarpc_config, transport).spawn();
173 info!("Connected to coordinator at {}", addr);
174 Ok(client)
175}
176
177pub async fn connect_to_shard_with_retry(
191 addr: SocketAddr,
192 config: &ClientConfig,
193) -> Result<ShardServiceClient, std::io::Error> {
194 let mut last_error = None;
195
196 for attempt in 0..config.retry_attempts {
197 if attempt > 0 {
198 warn!(
199 "Retry attempt {} connecting to shard at {}",
200 attempt + 1,
201 addr
202 );
203 tokio::time::sleep(config.retry_delay).await;
204 }
205
206 match connect_to_shard_with_config(addr, config).await {
207 Ok(client) => {
208 if attempt > 0 {
209 info!(
210 "Successfully connected to shard at {} after {} attempts",
211 addr,
212 attempt + 1
213 );
214 }
215 return Ok(client);
216 }
217 Err(e) => {
218 warn!("Failed to connect to shard at {}: {}", addr, e);
219 last_error = Some(e);
220 }
221 }
222 }
223
224 error!(
225 "Failed to connect to shard at {} after {} attempts",
226 addr, config.retry_attempts
227 );
228 Err(last_error.unwrap_or_else(|| {
229 std::io::Error::new(std::io::ErrorKind::NotConnected, "connection failed")
230 }))
231}
232
233pub async fn connect_to_coordinator_with_retry(
244 addr: SocketAddr,
245 config: &ClientConfig,
246) -> Result<CoordinatorServiceClient, std::io::Error> {
247 let mut last_error = None;
248
249 for attempt in 0..config.retry_attempts {
250 if attempt > 0 {
251 warn!(
252 "Retry attempt {} connecting to coordinator at {}",
253 attempt + 1,
254 addr
255 );
256 tokio::time::sleep(config.retry_delay).await;
257 }
258
259 match connect_to_coordinator_with_config(addr, config).await {
260 Ok(client) => {
261 if attempt > 0 {
262 info!(
263 "Successfully connected to coordinator at {} after {} attempts",
264 addr,
265 attempt + 1
266 );
267 }
268 return Ok(client);
269 }
270 Err(e) => {
271 warn!("Failed to connect to coordinator at {}: {}", addr, e);
272 last_error = Some(e);
273 }
274 }
275 }
276
277 error!(
278 "Failed to connect to coordinator at {} after {} attempts",
279 addr, config.retry_attempts
280 );
281 Err(last_error.unwrap_or_else(|| {
282 std::io::Error::new(std::io::ErrorKind::NotConnected, "connection failed")
283 }))
284}
285
286pub struct ShardClientPool {
307 addresses: Arc<RwLock<HashMap<ShardId, SocketAddr>>>,
309 clients: Arc<RwLock<HashMap<ShardId, ShardServiceClient>>>,
311 config: ClientConfig,
313}
314
315impl ShardClientPool {
316 pub fn new() -> Self {
318 Self {
319 addresses: Arc::new(RwLock::new(HashMap::new())),
320 clients: Arc::new(RwLock::new(HashMap::new())),
321 config: ClientConfig::default(),
322 }
323 }
324
325 pub fn with_config(config: ClientConfig) -> Self {
327 Self {
328 addresses: Arc::new(RwLock::new(HashMap::new())),
329 clients: Arc::new(RwLock::new(HashMap::new())),
330 config,
331 }
332 }
333
334 pub async fn register_shard(&self, shard_id: ShardId, addr: SocketAddr) {
339 let mut addresses = self.addresses.write().await;
340 addresses.insert(shard_id, addr);
341 debug!("Registered shard {:?} at {}", shard_id, addr);
342 }
343
344 pub async fn unregister_shard(&self, shard_id: ShardId) {
346 let mut addresses = self.addresses.write().await;
347 addresses.remove(&shard_id);
348
349 let mut clients = self.clients.write().await;
350 clients.remove(&shard_id);
351 debug!("Unregistered shard {:?}", shard_id);
352 }
353
354 pub async fn get_client(
362 &self,
363 shard_id: ShardId,
364 ) -> Result<ShardServiceClient, std::io::Error> {
365 {
367 let clients = self.clients.read().await;
368 if let Some(client) = clients.get(&shard_id) {
369 return Ok(client.clone());
370 }
371 }
372
373 let addr = {
375 let addresses = self.addresses.read().await;
376 addresses.get(&shard_id).copied()
377 };
378
379 let addr = addr.ok_or_else(|| {
380 std::io::Error::new(
381 std::io::ErrorKind::NotFound,
382 format!("shard {:?} not registered", shard_id),
383 )
384 })?;
385
386 let client = connect_to_shard_with_retry(addr, &self.config).await?;
388
389 {
391 let mut clients = self.clients.write().await;
392 clients.insert(shard_id, client.clone());
393 }
394
395 Ok(client)
396 }
397
398 pub async fn get_all_clients(&self) -> Vec<(ShardId, ShardServiceClient)> {
403 let addresses: Vec<_> = {
404 let addresses = self.addresses.read().await;
405 addresses.iter().map(|(&id, &addr)| (id, addr)).collect()
406 };
407
408 let mut results = Vec::with_capacity(addresses.len());
409 for (shard_id, _) in addresses {
410 match self.get_client(shard_id).await {
411 Ok(client) => results.push((shard_id, client)),
412 Err(e) => warn!("Failed to get client for shard {:?}: {}", shard_id, e),
413 }
414 }
415
416 results
417 }
418
419 pub async fn has_shard(&self, shard_id: ShardId) -> bool {
421 let addresses = self.addresses.read().await;
422 addresses.contains_key(&shard_id)
423 }
424
425 pub async fn shard_count(&self) -> usize {
427 let addresses = self.addresses.read().await;
428 addresses.len()
429 }
430
431 pub async fn cached_connection_count(&self) -> usize {
433 let clients = self.clients.read().await;
434 clients.len()
435 }
436
437 pub async fn clear_cache(&self) {
441 let mut clients = self.clients.write().await;
442 clients.clear();
443 debug!("Cleared connection cache");
444 }
445
446 pub async fn invalidate_client(&self, shard_id: ShardId) {
450 let mut clients = self.clients.write().await;
451 clients.remove(&shard_id);
452 debug!("Invalidated cached client for shard {:?}", shard_id);
453 }
454}
455
456impl Default for ShardClientPool {
457 fn default() -> Self {
458 Self::new()
459 }
460}
461
462impl Clone for ShardClientPool {
463 fn clone(&self) -> Self {
464 Self {
465 addresses: Arc::clone(&self.addresses),
466 clients: Arc::clone(&self.clients),
467 config: self.config.clone(),
468 }
469 }
470}
471
472pub struct CoordinatorClient {
477 addr: SocketAddr,
479 client: Arc<RwLock<Option<CoordinatorServiceClient>>>,
481 config: ClientConfig,
483}
484
485impl CoordinatorClient {
486 pub fn new(addr: SocketAddr) -> Self {
491 Self {
492 addr,
493 client: Arc::new(RwLock::new(None)),
494 config: ClientConfig::default(),
495 }
496 }
497
498 pub fn with_config(addr: SocketAddr, config: ClientConfig) -> Self {
500 Self {
501 addr,
502 client: Arc::new(RwLock::new(None)),
503 config,
504 }
505 }
506
507 pub async fn get(&self) -> Result<CoordinatorServiceClient, std::io::Error> {
509 {
511 let client = self.client.read().await;
512 if let Some(ref c) = *client {
513 return Ok(c.clone());
514 }
515 }
516
517 let new_client = connect_to_coordinator_with_retry(self.addr, &self.config).await?;
519
520 {
522 let mut client = self.client.write().await;
523 *client = Some(new_client.clone());
524 }
525
526 Ok(new_client)
527 }
528
529 pub async fn reconnect(&self) -> Result<CoordinatorServiceClient, std::io::Error> {
533 {
535 let mut client = self.client.write().await;
536 *client = None;
537 }
538
539 self.get().await
541 }
542
543 pub async fn invalidate(&self) {
545 let mut client = self.client.write().await;
546 *client = None;
547 debug!("Invalidated coordinator client");
548 }
549
550 pub fn addr(&self) -> SocketAddr {
552 self.addr
553 }
554}
555
556impl Clone for CoordinatorClient {
557 fn clone(&self) -> Self {
558 Self {
559 addr: self.addr,
560 client: Arc::clone(&self.client),
561 config: self.config.clone(),
562 }
563 }
564}
565
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_config_default() {
        let config = ClientConfig::default();
        assert_eq!(
            config.connect_timeout.as_millis(),
            DEFAULT_CONNECT_TIMEOUT_MS as u128
        );
        assert_eq!(config.retry_attempts, DEFAULT_RETRY_ATTEMPTS);
        assert_eq!(
            config.retry_delay.as_millis(),
            DEFAULT_RETRY_DELAY_MS as u128
        );
        assert_eq!(config.max_pending_requests, 100);
    }

    #[tokio::test]
    async fn test_shard_client_pool_register() {
        let pool = ShardClientPool::new();

        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        pool.register_shard(ShardId::new(0), addr).await;

        assert!(pool.has_shard(ShardId::new(0)).await);
        assert!(!pool.has_shard(ShardId::new(1)).await);
        assert_eq!(pool.shard_count().await, 1);
    }

    #[tokio::test]
    async fn test_shard_client_pool_reregister_same_shard() {
        let pool = ShardClientPool::new();

        let first: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let second: SocketAddr = "127.0.0.1:8081".parse().unwrap();
        pool.register_shard(ShardId::new(0), first).await;
        pool.register_shard(ShardId::new(0), second).await;

        // Re-registering replaces the address instead of adding a new entry.
        assert!(pool.has_shard(ShardId::new(0)).await);
        assert_eq!(pool.shard_count().await, 1);
        assert_eq!(pool.cached_connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_shard_client_pool_unregister() {
        let pool = ShardClientPool::new();

        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        pool.register_shard(ShardId::new(0), addr).await;
        pool.unregister_shard(ShardId::new(0)).await;

        assert!(!pool.has_shard(ShardId::new(0)).await);
        assert_eq!(pool.shard_count().await, 0);
    }

    #[tokio::test]
    async fn test_shard_client_pool_get_client_not_registered() {
        let pool = ShardClientPool::new();

        let result = pool.get_client(ShardId::new(0)).await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), std::io::ErrorKind::NotFound);
    }

    #[tokio::test]
    async fn test_shard_client_pool_clear_cache() {
        let pool = ShardClientPool::new();

        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        pool.register_shard(ShardId::new(0), addr).await;

        // Registration alone never connects, so nothing is cached yet.
        assert_eq!(pool.cached_connection_count().await, 0);

        pool.clear_cache().await;
        assert_eq!(pool.cached_connection_count().await, 0);
        // Clearing the cache keeps registrations.
        assert!(pool.has_shard(ShardId::new(0)).await);
    }

    #[tokio::test]
    async fn test_shard_client_pool_invalidate_client() {
        let pool = ShardClientPool::new();

        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        pool.register_shard(ShardId::new(0), addr).await;

        pool.invalidate_client(ShardId::new(0)).await;
        assert_eq!(pool.cached_connection_count().await, 0);
        assert!(pool.has_shard(ShardId::new(0)).await);
    }

    #[tokio::test]
    async fn test_shard_client_pool_clone() {
        let pool = ShardClientPool::new();

        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        pool.register_shard(ShardId::new(0), addr).await;

        let pool_clone = pool.clone();
        assert!(pool_clone.has_shard(ShardId::new(0)).await);

        // Clones share state, so a registration through one is visible via the other.
        pool_clone
            .register_shard(ShardId::new(1), "127.0.0.1:8081".parse().unwrap())
            .await;
        assert!(pool.has_shard(ShardId::new(1)).await);
    }

    #[test]
    fn test_coordinator_client_new() {
        let addr: SocketAddr = "127.0.0.1:8000".parse().unwrap();
        let client = CoordinatorClient::new(addr);

        assert_eq!(client.addr(), addr);
    }

    #[tokio::test]
    async fn test_coordinator_client_invalidate() {
        let addr: SocketAddr = "127.0.0.1:8000".parse().unwrap();
        let client = CoordinatorClient::new(addr);

        // Invalidating an unconnected client is a harmless no-op.
        client.invalidate().await;
        assert_eq!(client.addr(), addr);
    }

    #[tokio::test]
    async fn test_coordinator_client_clone() {
        let addr: SocketAddr = "127.0.0.1:8000".parse().unwrap();
        let client = CoordinatorClient::new(addr);
        let client_clone = client.clone();

        assert_eq!(client_clone.addr(), addr);
    }

    #[tokio::test]
    async fn test_shard_client_pool_with_config() {
        let config = ClientConfig {
            connect_timeout: Duration::from_secs(10),
            retry_attempts: 5,
            retry_delay: Duration::from_millis(200),
            max_pending_requests: 50,
        };

        let pool = ShardClientPool::with_config(config);
        assert_eq!(pool.shard_count().await, 0);
    }

    #[test]
    fn test_coordinator_client_with_config() {
        let addr: SocketAddr = "127.0.0.1:8000".parse().unwrap();
        let config = ClientConfig {
            connect_timeout: Duration::from_secs(10),
            retry_attempts: 5,
            retry_delay: Duration::from_millis(200),
            max_pending_requests: 50,
        };

        let client = CoordinatorClient::with_config(addr, config);
        assert_eq!(client.addr(), addr);
    }
}