Skip to main content

forge_client/
reconnect.rs

1//! Auto-reconnecting client decorator for transport resilience.
2//!
3//! Wraps an [`McpClient`] and automatically reconnects when a
4//! [`TransportDead`](forge_error::DispatchError::TransportDead) error is detected.
5
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::Arc;
8use std::time::Duration;
9
10use forge_sandbox::{ResourceDispatcher, ToolDispatcher};
11use serde_json::Value;
12use tokio::sync::{Mutex, RwLock};
13
14use crate::{McpClient, TransportConfig};
15
/// A client wrapper that auto-reconnects on transport death.
///
/// Implements both [`ToolDispatcher`] and [`ResourceDispatcher`] by delegating
/// to an inner [`McpClient`]. When a `TransportDead` error is detected, the
/// client is transparently replaced with a fresh connection.
///
/// Uses a `RwLock` for the inner client so that in-flight read operations
/// (tool calls, resource reads) can proceed concurrently, while reconnection
/// briefly takes a write lock.
pub struct ReconnectingClient {
    /// Server name: passed to `McpClient::connect` when reconnecting, attached
    /// to log events, and reported in the `TransportDead` error returned when
    /// reconnection fails.
    name: String,
    /// Transport settings handed to `McpClient::connect` for every new connection.
    transport_config: TransportConfig,
    /// The client currently in use. Callers clone the `Arc` out under a read
    /// lock; a successful reconnection replaces it under a write lock.
    inner: RwLock<Arc<McpClient>>,
    /// Set (via compare-and-swap) while a reconnection attempt is in progress
    /// so that only one task attempts to reconnect at a time.
    reconnecting: AtomicBool,
    /// Upper bound applied to the backoff delay when it grows.
    max_backoff: Duration,
    /// Delay to sleep before the next reconnection attempt; grows after a
    /// failed attempt and is reset after a successful one.
    current_backoff: Mutex<Duration>,
}
33
34impl ReconnectingClient {
35    /// Create a new reconnecting wrapper around an existing client.
36    pub fn new(
37        name: String,
38        transport_config: TransportConfig,
39        client: Arc<McpClient>,
40        max_backoff: Duration,
41    ) -> Self {
42        Self {
43            name,
44            transport_config,
45            inner: RwLock::new(client),
46            reconnecting: AtomicBool::new(false),
47            max_backoff,
48            current_backoff: Mutex::new(Duration::from_secs(1)),
49        }
50    }
51
52    /// Attempt to reconnect, returning true on success.
53    ///
54    /// Uses CAS on `reconnecting` to ensure only one reconnection attempt
55    /// proceeds at a time. Other callers wait for the reconnection to complete.
56    async fn try_reconnect(&self) -> bool {
57        // CAS: only one task reconnects at a time
58        if self
59            .reconnecting
60            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
61            .is_err()
62        {
63            // Another task is already reconnecting — wait for it
64            tracing::debug!(server = %self.name, "waiting for concurrent reconnection");
65            // Brief sleep to let the other reconnection attempt complete
66            tokio::time::sleep(Duration::from_millis(100)).await;
67            // Check if the reconnection flag was cleared (success)
68            return !self.reconnecting.load(Ordering::SeqCst);
69        }
70
71        // Apply backoff delay
72        let backoff = {
73            let guard = self.current_backoff.lock().await;
74            *guard
75        };
76        tracing::info!(
77            server = %self.name,
78            backoff_ms = backoff.as_millis(),
79            "attempting reconnection after backoff"
80        );
81        tokio::time::sleep(backoff).await;
82
83        // Try to create a new connection
84        match McpClient::connect(self.name.clone(), &self.transport_config).await {
85            Ok(new_client) => {
86                tracing::info!(server = %self.name, "reconnection successful");
87                // Swap the client under write lock
88                {
89                    let mut inner = self.inner.write().await;
90                    *inner = Arc::new(new_client);
91                }
92                // Reset backoff on success
93                {
94                    let mut guard = self.current_backoff.lock().await;
95                    *guard = Duration::from_secs(1);
96                }
97                self.reconnecting.store(false, Ordering::SeqCst);
98                true
99            }
100            Err(e) => {
101                tracing::warn!(
102                    server = %self.name,
103                    error = %e,
104                    "reconnection failed"
105                );
106                // Increase backoff (exponential, capped)
107                {
108                    let mut guard = self.current_backoff.lock().await;
109                    *guard = (*guard * 2).min(self.max_backoff);
110                }
111                self.reconnecting.store(false, Ordering::SeqCst);
112                false
113            }
114        }
115    }
116
117    /// Get a clone of the current inner client.
118    async fn current_client(&self) -> Arc<McpClient> {
119        self.inner.read().await.clone()
120    }
121}
122
123#[async_trait::async_trait]
124impl ToolDispatcher for ReconnectingClient {
125    async fn call_tool(
126        &self,
127        server: &str,
128        tool: &str,
129        args: Value,
130    ) -> Result<Value, forge_error::DispatchError> {
131        let client = self.current_client().await;
132        let result = client.call_tool(server, tool, args.clone()).await;
133
134        match result {
135            Err(forge_error::DispatchError::TransportDead { .. }) => {
136                tracing::warn!(
137                    server = %self.name,
138                    tool = %tool,
139                    "transport dead, attempting reconnection"
140                );
141                if self.try_reconnect().await {
142                    // Retry once with the new client
143                    let new_client = self.current_client().await;
144                    new_client.call_tool(server, tool, args).await
145                } else {
146                    Err(forge_error::DispatchError::TransportDead {
147                        server: self.name.clone(),
148                        reason: "reconnection failed".into(),
149                    })
150                }
151            }
152            other => other,
153        }
154    }
155}
156
157#[async_trait::async_trait]
158impl ResourceDispatcher for ReconnectingClient {
159    async fn read_resource(
160        &self,
161        server: &str,
162        uri: &str,
163    ) -> Result<Value, forge_error::DispatchError> {
164        let client = self.current_client().await;
165        let result = ResourceDispatcher::read_resource(client.as_ref(), server, uri).await;
166
167        match result {
168            Err(forge_error::DispatchError::TransportDead { .. }) => {
169                tracing::warn!(
170                    server = %self.name,
171                    uri = %uri,
172                    "transport dead, attempting reconnection"
173                );
174                if self.try_reconnect().await {
175                    let new_client = self.current_client().await;
176                    ResourceDispatcher::read_resource(new_client.as_ref(), server, uri).await
177                } else {
178                    Err(forge_error::DispatchError::TransportDead {
179                        server: self.name.clone(),
180                        reason: "reconnection failed".into(),
181                    })
182                }
183            }
184            other => other,
185        }
186    }
187}