forge_client/
reconnect.rs1use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::Arc;
8use std::time::Duration;
9
10use forge_sandbox::{ResourceDispatcher, ToolDispatcher};
11use serde_json::Value;
12use tokio::sync::{Mutex, RwLock};
13
14use crate::{McpClient, TransportConfig};
15
16pub struct ReconnectingClient {
26 name: String,
27 transport_config: TransportConfig,
28 inner: RwLock<Arc<McpClient>>,
29 reconnecting: AtomicBool,
30 max_backoff: Duration,
31 current_backoff: Mutex<Duration>,
32}
33
34impl ReconnectingClient {
35 pub fn new(
37 name: String,
38 transport_config: TransportConfig,
39 client: Arc<McpClient>,
40 max_backoff: Duration,
41 ) -> Self {
42 Self {
43 name,
44 transport_config,
45 inner: RwLock::new(client),
46 reconnecting: AtomicBool::new(false),
47 max_backoff,
48 current_backoff: Mutex::new(Duration::from_secs(1)),
49 }
50 }
51
52 async fn try_reconnect(&self) -> bool {
57 if self
59 .reconnecting
60 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
61 .is_err()
62 {
63 tracing::debug!(server = %self.name, "waiting for concurrent reconnection");
65 tokio::time::sleep(Duration::from_millis(100)).await;
67 return !self.reconnecting.load(Ordering::SeqCst);
69 }
70
71 let backoff = {
73 let guard = self.current_backoff.lock().await;
74 *guard
75 };
76 tracing::info!(
77 server = %self.name,
78 backoff_ms = backoff.as_millis(),
79 "attempting reconnection after backoff"
80 );
81 tokio::time::sleep(backoff).await;
82
83 match McpClient::connect(self.name.clone(), &self.transport_config).await {
85 Ok(new_client) => {
86 tracing::info!(server = %self.name, "reconnection successful");
87 {
89 let mut inner = self.inner.write().await;
90 *inner = Arc::new(new_client);
91 }
92 {
94 let mut guard = self.current_backoff.lock().await;
95 *guard = Duration::from_secs(1);
96 }
97 self.reconnecting.store(false, Ordering::SeqCst);
98 true
99 }
100 Err(e) => {
101 tracing::warn!(
102 server = %self.name,
103 error = %e,
104 "reconnection failed"
105 );
106 {
108 let mut guard = self.current_backoff.lock().await;
109 *guard = (*guard * 2).min(self.max_backoff);
110 }
111 self.reconnecting.store(false, Ordering::SeqCst);
112 false
113 }
114 }
115 }
116
117 async fn current_client(&self) -> Arc<McpClient> {
119 self.inner.read().await.clone()
120 }
121}
122
123#[async_trait::async_trait]
124impl ToolDispatcher for ReconnectingClient {
125 async fn call_tool(
126 &self,
127 server: &str,
128 tool: &str,
129 args: Value,
130 ) -> Result<Value, forge_error::DispatchError> {
131 let client = self.current_client().await;
132 let result = client.call_tool(server, tool, args.clone()).await;
133
134 match result {
135 Err(forge_error::DispatchError::TransportDead { .. }) => {
136 tracing::warn!(
137 server = %self.name,
138 tool = %tool,
139 "transport dead, attempting reconnection"
140 );
141 if self.try_reconnect().await {
142 let new_client = self.current_client().await;
144 new_client.call_tool(server, tool, args).await
145 } else {
146 Err(forge_error::DispatchError::TransportDead {
147 server: self.name.clone(),
148 reason: "reconnection failed".into(),
149 })
150 }
151 }
152 other => other,
153 }
154 }
155}
156
157#[async_trait::async_trait]
158impl ResourceDispatcher for ReconnectingClient {
159 async fn read_resource(
160 &self,
161 server: &str,
162 uri: &str,
163 ) -> Result<Value, forge_error::DispatchError> {
164 let client = self.current_client().await;
165 let result = ResourceDispatcher::read_resource(client.as_ref(), server, uri).await;
166
167 match result {
168 Err(forge_error::DispatchError::TransportDead { .. }) => {
169 tracing::warn!(
170 server = %self.name,
171 uri = %uri,
172 "transport dead, attempting reconnection"
173 );
174 if self.try_reconnect().await {
175 let new_client = self.current_client().await;
176 ResourceDispatcher::read_resource(new_client.as_ref(), server, uri).await
177 } else {
178 Err(forge_error::DispatchError::TransportDead {
179 server: self.name.clone(),
180 reason: "reconnection failed".into(),
181 })
182 }
183 }
184 other => other,
185 }
186 }
187}