Skip to main content

phago_distributed/rpc/
client.rs

1//! tarpc client utilities.
2//!
3//! This module provides client utilities for connecting to remote shards
4//! and the coordinator. It includes connection functions, retry logic,
5//! and a connection pool for efficient client reuse.
6
7use crate::rpc::protocol::{CoordinatorServiceClient, ShardServiceClient};
8use crate::types::ShardId;
9use std::collections::HashMap;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use std::time::Duration;
13use tarpc::client::Config;
14use tokio::sync::RwLock;
15use tokio_serde::formats::Bincode;
16use tracing::{debug, error, info, warn};
17
/// Time budget, in milliseconds, used for a connection attempt when no
/// explicit timeout is configured.
const DEFAULT_CONNECT_TIMEOUT_MS: u64 = 5_000;

/// Number of connection attempts used when none is configured.
const DEFAULT_RETRY_ATTEMPTS: u32 = 3;

/// Pause, in milliseconds, placed between connection attempts when no
/// explicit delay is configured.
const DEFAULT_RETRY_DELAY_MS: u64 = 500;
26
/// Tunable settings applied when opening RPC client connections.
#[derive(Clone, Debug)]
pub struct ClientConfig {
    /// Upper bound on how long a single connection attempt may take.
    pub connect_timeout: Duration,
    /// Number of connection attempts made by the retrying connect helpers.
    pub retry_attempts: u32,
    /// Pause inserted between consecutive connection attempts.
    pub retry_delay: Duration,
    /// Cap on the number of requests that may be pending on one client.
    pub max_pending_requests: usize,
}
39
40impl Default for ClientConfig {
41    fn default() -> Self {
42        Self {
43            connect_timeout: Duration::from_millis(DEFAULT_CONNECT_TIMEOUT_MS),
44            retry_attempts: DEFAULT_RETRY_ATTEMPTS,
45            retry_delay: Duration::from_millis(DEFAULT_RETRY_DELAY_MS),
46            max_pending_requests: 100,
47        }
48    }
49}
50
51/// Create a client connection to a shard.
52///
53/// Establishes a TCP connection to the shard at the given address and
54/// returns a tarpc client for making RPC calls.
55///
56/// # Arguments
57///
58/// * `addr` - The socket address of the shard server
59///
60/// # Errors
61///
62/// Returns an error if the connection cannot be established.
63///
64/// # Example
65///
66/// ```rust,ignore
67/// use phago_distributed::rpc::client::connect_to_shard;
68///
69/// let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
70/// let client = connect_to_shard(addr).await?;
71///
72/// let health = client.health_check(tarpc::context::current()).await?;
73/// ```
74pub async fn connect_to_shard(addr: SocketAddr) -> Result<ShardServiceClient, std::io::Error> {
75    debug!("Connecting to shard at {}", addr);
76    let transport = tarpc::serde_transport::tcp::connect(addr, Bincode::default).await?;
77    let client = ShardServiceClient::new(Config::default(), transport).spawn();
78    info!("Connected to shard at {}", addr);
79    Ok(client)
80}
81
82/// Create a client connection to a shard with custom configuration.
83///
84/// # Arguments
85///
86/// * `addr` - The socket address of the shard server
87/// * `config` - Client configuration
88///
89/// # Errors
90///
91/// Returns an error if the connection cannot be established.
92pub async fn connect_to_shard_with_config(
93    addr: SocketAddr,
94    config: &ClientConfig,
95) -> Result<ShardServiceClient, std::io::Error> {
96    debug!("Connecting to shard at {} with custom config", addr);
97
98    let transport = tokio::time::timeout(
99        config.connect_timeout,
100        tarpc::serde_transport::tcp::connect(addr, Bincode::default),
101    )
102    .await
103    .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "connection timeout"))??;
104
105    let mut tarpc_config = Config::default();
106    tarpc_config.max_in_flight_requests = config.max_pending_requests;
107
108    let client = ShardServiceClient::new(tarpc_config, transport).spawn();
109    info!("Connected to shard at {}", addr);
110    Ok(client)
111}
112
113/// Create a client connection to the coordinator.
114///
115/// Establishes a TCP connection to the coordinator at the given address and
116/// returns a tarpc client for making RPC calls.
117///
118/// # Arguments
119///
120/// * `addr` - The socket address of the coordinator server
121///
122/// # Errors
123///
124/// Returns an error if the connection cannot be established.
125///
126/// # Example
127///
128/// ```rust,ignore
129/// use phago_distributed::rpc::client::connect_to_coordinator;
130///
131/// let addr: SocketAddr = "127.0.0.1:8000".parse().unwrap();
132/// let client = connect_to_coordinator(addr).await?;
133///
134/// let shards = client.list_shards(tarpc::context::current()).await?;
135/// ```
136pub async fn connect_to_coordinator(
137    addr: SocketAddr,
138) -> Result<CoordinatorServiceClient, std::io::Error> {
139    debug!("Connecting to coordinator at {}", addr);
140    let transport = tarpc::serde_transport::tcp::connect(addr, Bincode::default).await?;
141    let client = CoordinatorServiceClient::new(Config::default(), transport).spawn();
142    info!("Connected to coordinator at {}", addr);
143    Ok(client)
144}
145
146/// Create a client connection to the coordinator with custom configuration.
147///
148/// # Arguments
149///
150/// * `addr` - The socket address of the coordinator server
151/// * `config` - Client configuration
152///
153/// # Errors
154///
155/// Returns an error if the connection cannot be established.
156pub async fn connect_to_coordinator_with_config(
157    addr: SocketAddr,
158    config: &ClientConfig,
159) -> Result<CoordinatorServiceClient, std::io::Error> {
160    debug!("Connecting to coordinator at {} with custom config", addr);
161
162    let transport = tokio::time::timeout(
163        config.connect_timeout,
164        tarpc::serde_transport::tcp::connect(addr, Bincode::default),
165    )
166    .await
167    .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "connection timeout"))??;
168
169    let mut tarpc_config = Config::default();
170    tarpc_config.max_in_flight_requests = config.max_pending_requests;
171
172    let client = CoordinatorServiceClient::new(tarpc_config, transport).spawn();
173    info!("Connected to coordinator at {}", addr);
174    Ok(client)
175}
176
177/// Connect to a shard with automatic retry on failure.
178///
179/// Attempts to connect to the shard, retrying on failure according to
180/// the configuration.
181///
182/// # Arguments
183///
184/// * `addr` - The socket address of the shard server
185/// * `config` - Client configuration including retry settings
186///
187/// # Errors
188///
189/// Returns an error if all retry attempts fail.
190pub async fn connect_to_shard_with_retry(
191    addr: SocketAddr,
192    config: &ClientConfig,
193) -> Result<ShardServiceClient, std::io::Error> {
194    let mut last_error = None;
195
196    for attempt in 0..config.retry_attempts {
197        if attempt > 0 {
198            warn!(
199                "Retry attempt {} connecting to shard at {}",
200                attempt + 1,
201                addr
202            );
203            tokio::time::sleep(config.retry_delay).await;
204        }
205
206        match connect_to_shard_with_config(addr, config).await {
207            Ok(client) => {
208                if attempt > 0 {
209                    info!(
210                        "Successfully connected to shard at {} after {} attempts",
211                        addr,
212                        attempt + 1
213                    );
214                }
215                return Ok(client);
216            }
217            Err(e) => {
218                warn!("Failed to connect to shard at {}: {}", addr, e);
219                last_error = Some(e);
220            }
221        }
222    }
223
224    error!(
225        "Failed to connect to shard at {} after {} attempts",
226        addr, config.retry_attempts
227    );
228    Err(last_error.unwrap_or_else(|| {
229        std::io::Error::new(std::io::ErrorKind::NotConnected, "connection failed")
230    }))
231}
232
233/// Connect to the coordinator with automatic retry on failure.
234///
235/// # Arguments
236///
237/// * `addr` - The socket address of the coordinator server
238/// * `config` - Client configuration including retry settings
239///
240/// # Errors
241///
242/// Returns an error if all retry attempts fail.
243pub async fn connect_to_coordinator_with_retry(
244    addr: SocketAddr,
245    config: &ClientConfig,
246) -> Result<CoordinatorServiceClient, std::io::Error> {
247    let mut last_error = None;
248
249    for attempt in 0..config.retry_attempts {
250        if attempt > 0 {
251            warn!(
252                "Retry attempt {} connecting to coordinator at {}",
253                attempt + 1,
254                addr
255            );
256            tokio::time::sleep(config.retry_delay).await;
257        }
258
259        match connect_to_coordinator_with_config(addr, config).await {
260            Ok(client) => {
261                if attempt > 0 {
262                    info!(
263                        "Successfully connected to coordinator at {} after {} attempts",
264                        addr,
265                        attempt + 1
266                    );
267                }
268                return Ok(client);
269            }
270            Err(e) => {
271                warn!("Failed to connect to coordinator at {}: {}", addr, e);
272                last_error = Some(e);
273            }
274        }
275    }
276
277    error!(
278        "Failed to connect to coordinator at {} after {} attempts",
279        addr, config.retry_attempts
280    );
281    Err(last_error.unwrap_or_else(|| {
282        std::io::Error::new(std::io::ErrorKind::NotConnected, "connection failed")
283    }))
284}
285
286/// A pool of shard client connections.
287///
288/// The connection pool maintains a cached set of connections to shards,
289/// creating new connections on demand and reusing existing ones.
290///
291/// # Thread Safety
292///
293/// The pool is thread-safe and can be shared across multiple tasks.
294///
295/// # Example
296///
297/// ```rust,ignore
298/// use phago_distributed::rpc::client::ShardClientPool;
299///
300/// let pool = ShardClientPool::new();
301/// pool.register_shard(ShardId::new(0), "127.0.0.1:8080".parse().unwrap());
302///
303/// let client = pool.get_client(ShardId::new(0)).await?;
304/// let health = client.health_check(tarpc::context::current()).await?;
305/// ```
306pub struct ShardClientPool {
307    /// Mapping from shard ID to address.
308    addresses: Arc<RwLock<HashMap<ShardId, SocketAddr>>>,
309    /// Cached client connections.
310    clients: Arc<RwLock<HashMap<ShardId, ShardServiceClient>>>,
311    /// Client configuration.
312    config: ClientConfig,
313}
314
315impl ShardClientPool {
316    /// Create a new empty connection pool.
317    pub fn new() -> Self {
318        Self {
319            addresses: Arc::new(RwLock::new(HashMap::new())),
320            clients: Arc::new(RwLock::new(HashMap::new())),
321            config: ClientConfig::default(),
322        }
323    }
324
325    /// Create a new connection pool with custom configuration.
326    pub fn with_config(config: ClientConfig) -> Self {
327        Self {
328            addresses: Arc::new(RwLock::new(HashMap::new())),
329            clients: Arc::new(RwLock::new(HashMap::new())),
330            config,
331        }
332    }
333
334    /// Register a shard's address.
335    ///
336    /// This does not establish a connection immediately; connections
337    /// are created lazily when `get_client` is called.
338    pub async fn register_shard(&self, shard_id: ShardId, addr: SocketAddr) {
339        let mut addresses = self.addresses.write().await;
340        addresses.insert(shard_id, addr);
341        debug!("Registered shard {:?} at {}", shard_id, addr);
342    }
343
344    /// Unregister a shard and close any cached connection.
345    pub async fn unregister_shard(&self, shard_id: ShardId) {
346        let mut addresses = self.addresses.write().await;
347        addresses.remove(&shard_id);
348
349        let mut clients = self.clients.write().await;
350        clients.remove(&shard_id);
351        debug!("Unregistered shard {:?}", shard_id);
352    }
353
354    /// Get a client for the specified shard.
355    ///
356    /// Returns a cached client if available, otherwise creates a new connection.
357    ///
358    /// # Errors
359    ///
360    /// Returns an error if the shard is not registered or if the connection fails.
361    pub async fn get_client(
362        &self,
363        shard_id: ShardId,
364    ) -> Result<ShardServiceClient, std::io::Error> {
365        // Check for cached client
366        {
367            let clients = self.clients.read().await;
368            if let Some(client) = clients.get(&shard_id) {
369                return Ok(client.clone());
370            }
371        }
372
373        // Get address
374        let addr = {
375            let addresses = self.addresses.read().await;
376            addresses.get(&shard_id).copied()
377        };
378
379        let addr = addr.ok_or_else(|| {
380            std::io::Error::new(
381                std::io::ErrorKind::NotFound,
382                format!("shard {:?} not registered", shard_id),
383            )
384        })?;
385
386        // Create new connection
387        let client = connect_to_shard_with_retry(addr, &self.config).await?;
388
389        // Cache the client
390        {
391            let mut clients = self.clients.write().await;
392            clients.insert(shard_id, client.clone());
393        }
394
395        Ok(client)
396    }
397
398    /// Get clients for all registered shards.
399    ///
400    /// Attempts to connect to all shards, returning successfully connected clients.
401    /// Failed connections are logged but do not cause the entire operation to fail.
402    pub async fn get_all_clients(&self) -> Vec<(ShardId, ShardServiceClient)> {
403        let addresses: Vec<_> = {
404            let addresses = self.addresses.read().await;
405            addresses.iter().map(|(&id, &addr)| (id, addr)).collect()
406        };
407
408        let mut results = Vec::with_capacity(addresses.len());
409        for (shard_id, _) in addresses {
410            match self.get_client(shard_id).await {
411                Ok(client) => results.push((shard_id, client)),
412                Err(e) => warn!("Failed to get client for shard {:?}: {}", shard_id, e),
413            }
414        }
415
416        results
417    }
418
419    /// Check if a shard is registered.
420    pub async fn has_shard(&self, shard_id: ShardId) -> bool {
421        let addresses = self.addresses.read().await;
422        addresses.contains_key(&shard_id)
423    }
424
425    /// Get the number of registered shards.
426    pub async fn shard_count(&self) -> usize {
427        let addresses = self.addresses.read().await;
428        addresses.len()
429    }
430
431    /// Get the number of cached connections.
432    pub async fn cached_connection_count(&self) -> usize {
433        let clients = self.clients.read().await;
434        clients.len()
435    }
436
437    /// Clear all cached connections.
438    ///
439    /// This forces new connections to be created on the next `get_client` call.
440    pub async fn clear_cache(&self) {
441        let mut clients = self.clients.write().await;
442        clients.clear();
443        debug!("Cleared connection cache");
444    }
445
446    /// Remove a cached connection for a specific shard.
447    ///
448    /// Useful for forcing a reconnection after a failure.
449    pub async fn invalidate_client(&self, shard_id: ShardId) {
450        let mut clients = self.clients.write().await;
451        clients.remove(&shard_id);
452        debug!("Invalidated cached client for shard {:?}", shard_id);
453    }
454}
455
456impl Default for ShardClientPool {
457    fn default() -> Self {
458        Self::new()
459    }
460}
461
462impl Clone for ShardClientPool {
463    fn clone(&self) -> Self {
464        Self {
465            addresses: Arc::clone(&self.addresses),
466            clients: Arc::clone(&self.clients),
467            config: self.config.clone(),
468        }
469    }
470}
471
472/// A coordinator client wrapper with automatic reconnection.
473///
474/// This wrapper maintains a connection to the coordinator and
475/// automatically attempts to reconnect if the connection is lost.
476pub struct CoordinatorClient {
477    /// The coordinator's address.
478    addr: SocketAddr,
479    /// The cached client connection.
480    client: Arc<RwLock<Option<CoordinatorServiceClient>>>,
481    /// Client configuration.
482    config: ClientConfig,
483}
484
485impl CoordinatorClient {
486    /// Create a new coordinator client.
487    ///
488    /// This does not establish a connection immediately; the connection
489    /// is created lazily on the first call to `get`.
490    pub fn new(addr: SocketAddr) -> Self {
491        Self {
492            addr,
493            client: Arc::new(RwLock::new(None)),
494            config: ClientConfig::default(),
495        }
496    }
497
498    /// Create a new coordinator client with custom configuration.
499    pub fn with_config(addr: SocketAddr, config: ClientConfig) -> Self {
500        Self {
501            addr,
502            client: Arc::new(RwLock::new(None)),
503            config,
504        }
505    }
506
507    /// Get the coordinator client, creating a connection if needed.
508    pub async fn get(&self) -> Result<CoordinatorServiceClient, std::io::Error> {
509        // Check for cached client
510        {
511            let client = self.client.read().await;
512            if let Some(ref c) = *client {
513                return Ok(c.clone());
514            }
515        }
516
517        // Create new connection
518        let new_client = connect_to_coordinator_with_retry(self.addr, &self.config).await?;
519
520        // Cache the client
521        {
522            let mut client = self.client.write().await;
523            *client = Some(new_client.clone());
524        }
525
526        Ok(new_client)
527    }
528
529    /// Force a reconnection.
530    ///
531    /// Useful after detecting a connection failure.
532    pub async fn reconnect(&self) -> Result<CoordinatorServiceClient, std::io::Error> {
533        // Clear the cached client
534        {
535            let mut client = self.client.write().await;
536            *client = None;
537        }
538
539        // Get a new connection
540        self.get().await
541    }
542
543    /// Invalidate the cached connection without reconnecting.
544    pub async fn invalidate(&self) {
545        let mut client = self.client.write().await;
546        *client = None;
547        debug!("Invalidated coordinator client");
548    }
549
550    /// Get the coordinator's address.
551    pub fn addr(&self) -> SocketAddr {
552        self.addr
553    }
554}
555
556impl Clone for CoordinatorClient {
557    fn clone(&self) -> Self {
558        Self {
559            addr: self.addr,
560            client: Arc::clone(&self.client),
561            config: self.config.clone(),
562        }
563    }
564}
565
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-default configuration shared by the `with_config` tests.
    fn custom_config() -> ClientConfig {
        ClientConfig {
            connect_timeout: Duration::from_secs(10),
            retry_attempts: 5,
            retry_delay: Duration::from_millis(200),
            max_pending_requests: 50,
        }
    }

    /// Loopback address on the given port. None of these tests connect, so
    /// nothing needs to be listening there.
    fn local_addr(port: u16) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], port))
    }

    #[test]
    fn test_client_config_default() {
        let config = ClientConfig::default();
        assert_eq!(
            config.connect_timeout,
            Duration::from_millis(DEFAULT_CONNECT_TIMEOUT_MS)
        );
        assert_eq!(config.retry_attempts, DEFAULT_RETRY_ATTEMPTS);
        assert_eq!(
            config.retry_delay,
            Duration::from_millis(DEFAULT_RETRY_DELAY_MS)
        );
        assert_eq!(config.max_pending_requests, 100);
    }

    #[tokio::test]
    async fn test_shard_client_pool_register() {
        let pool = ShardClientPool::new();

        pool.register_shard(ShardId::new(0), local_addr(8080)).await;

        assert!(pool.has_shard(ShardId::new(0)).await);
        assert!(!pool.has_shard(ShardId::new(1)).await);
        assert_eq!(pool.shard_count().await, 1);
    }

    #[tokio::test]
    async fn test_shard_client_pool_unregister() {
        let pool = ShardClientPool::new();

        pool.register_shard(ShardId::new(0), local_addr(8080)).await;
        pool.unregister_shard(ShardId::new(0)).await;

        assert!(!pool.has_shard(ShardId::new(0)).await);
        assert_eq!(pool.shard_count().await, 0);
    }

    #[tokio::test]
    async fn test_shard_client_pool_get_client_not_registered() {
        let pool = ShardClientPool::new();

        let result = pool.get_client(ShardId::new(0)).await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), std::io::ErrorKind::NotFound);
    }

    #[tokio::test]
    async fn test_shard_client_pool_clear_cache() {
        let pool = ShardClientPool::new();

        pool.register_shard(ShardId::new(0), local_addr(8080)).await;

        // The cache should be empty initially
        assert_eq!(pool.cached_connection_count().await, 0);

        // Clear should work even when empty, and must not touch registrations
        pool.clear_cache().await;
        assert_eq!(pool.cached_connection_count().await, 0);
        assert!(pool.has_shard(ShardId::new(0)).await);
    }

    #[tokio::test]
    async fn test_shard_client_pool_invalidate_client() {
        let pool = ShardClientPool::new();

        pool.register_shard(ShardId::new(0), local_addr(8080)).await;

        // Invalidate should work even without a cached client, and must not
        // unregister the shard
        pool.invalidate_client(ShardId::new(0)).await;
        assert_eq!(pool.cached_connection_count().await, 0);
        assert!(pool.has_shard(ShardId::new(0)).await);
    }

    #[tokio::test]
    async fn test_shard_client_pool_clone() {
        let pool = ShardClientPool::new();

        pool.register_shard(ShardId::new(0), local_addr(8080)).await;

        let pool_clone = pool.clone();
        assert!(pool_clone.has_shard(ShardId::new(0)).await);

        // Changes should be visible in both
        pool_clone
            .register_shard(ShardId::new(1), local_addr(8081))
            .await;
        assert!(pool.has_shard(ShardId::new(1)).await);
    }

    #[test]
    fn test_coordinator_client_new() {
        let addr = local_addr(8000);
        let client = CoordinatorClient::new(addr);

        assert_eq!(client.addr(), addr);
    }

    #[tokio::test]
    async fn test_coordinator_client_invalidate() {
        let client = CoordinatorClient::new(local_addr(8000));

        // Invalidate should work even without a cached connection
        client.invalidate().await;
        assert!(client.client.read().await.is_none());
    }

    #[tokio::test]
    async fn test_coordinator_client_clone() {
        let addr = local_addr(8000);
        let client = CoordinatorClient::new(addr);
        let client_clone = client.clone();

        assert_eq!(client_clone.addr(), addr);
        // Both handles must share the same cached-connection slot
        assert!(Arc::ptr_eq(&client.client, &client_clone.client));
    }

    #[tokio::test]
    async fn test_shard_client_pool_with_config() {
        let pool = ShardClientPool::with_config(custom_config());

        assert_eq!(pool.shard_count().await, 0);
        // The supplied configuration must be the one the pool keeps
        assert_eq!(pool.config.connect_timeout, Duration::from_secs(10));
        assert_eq!(pool.config.retry_attempts, 5);
        assert_eq!(pool.config.retry_delay, Duration::from_millis(200));
        assert_eq!(pool.config.max_pending_requests, 50);
    }

    #[test]
    fn test_coordinator_client_with_config() {
        let addr = local_addr(8000);

        let client = CoordinatorClient::with_config(addr, custom_config());
        assert_eq!(client.addr(), addr);
        // The supplied configuration must be the one the client keeps
        assert_eq!(client.config.connect_timeout, Duration::from_secs(10));
        assert_eq!(client.config.retry_attempts, 5);
        assert_eq!(client.config.retry_delay, Duration::from_millis(200));
        assert_eq!(client.config.max_pending_requests, 50);
    }
}