1use std::collections::{HashMap, HashSet};
7use std::convert::Infallible;
8use std::path::PathBuf;
9use std::sync::mpsc as std_mpsc;
10use std::time::Duration;
11
12use notify_debouncer_mini::{DebouncedEventKind, new_debouncer};
13use tokio::process::Command;
14use tower::util::BoxCloneService;
15use tower_mcp::proxy::{BackendService, McpProxy};
16use tower_mcp::{RouterRequest, RouterResponse};
17
18use crate::config::{BackendConfig, ProxyConfig, TransportType};
19
20pub fn spawn_config_watcher(
22 config_path: PathBuf,
23 proxy: McpProxy,
24 #[cfg(feature = "discovery")] discovery_index: Option<(
25 crate::discovery::SharedDiscoveryIndex,
26 String,
27 )>,
28) {
29 std::thread::spawn(move || {
30 let rt = tokio::runtime::Builder::new_current_thread()
31 .enable_all()
32 .build()
33 .expect("hot reload runtime");
34 rt.block_on(watch_loop(
35 config_path,
36 proxy,
37 #[cfg(feature = "discovery")]
38 discovery_index,
39 ));
40 });
41}
42
43async fn watch_loop(
44 config_path: PathBuf,
45 proxy: McpProxy,
46 #[cfg(feature = "discovery")] discovery_index: Option<(
47 crate::discovery::SharedDiscoveryIndex,
48 String,
49 )>,
50) {
51 let (tx, rx) = std_mpsc::channel();
52
53 let mut debouncer = match new_debouncer(Duration::from_secs(2), tx) {
54 Ok(d) => d,
55 Err(e) => {
56 tracing::error!(error = %e, "Failed to create file watcher, hot reload disabled");
57 return;
58 }
59 };
60
61 if let Err(e) = debouncer
62 .watcher()
63 .watch(&config_path, notify::RecursiveMode::NonRecursive)
64 {
65 tracing::error!(
66 path = %config_path.display(),
67 error = %e,
68 "Failed to watch config file, hot reload disabled"
69 );
70 return;
71 }
72
73 tracing::info!(path = %config_path.display(), "Hot reload watching config file");
74
75 let mut backend_fingerprints: HashMap<String, String> = {
77 if let Ok(config) = ProxyConfig::load(&config_path) {
78 config
79 .backends
80 .iter()
81 .map(|b| (b.name.clone(), config_fingerprint(b)))
82 .collect()
83 } else {
84 HashMap::new()
85 }
86 };
87
88 loop {
89 let events = match rx.recv() {
91 Ok(Ok(events)) => events,
92 Ok(Err(e)) => {
93 tracing::warn!(error = %e, "File watcher error");
94 continue;
95 }
96 Err(_) => {
97 tracing::info!("File watcher channel closed, stopping hot reload");
98 break;
99 }
100 };
101
102 let has_write = events
104 .iter()
105 .any(|e| matches!(e.kind, DebouncedEventKind::Any));
106 if !has_write {
107 continue;
108 }
109
110 tracing::info!("Config file changed, reloading backends");
111
112 let mut new_config = match ProxyConfig::load(&config_path) {
113 Ok(c) => c,
114 Err(e) => {
115 tracing::warn!(error = %e, "Failed to parse updated config, skipping reload");
116 continue;
117 }
118 };
119 new_config.resolve_env_vars();
120
121 let new_fingerprints: HashMap<String, String> = new_config
122 .backends
123 .iter()
124 .map(|b| (b.name.clone(), config_fingerprint(b)))
125 .collect();
126
127 let old_names: HashSet<&String> = backend_fingerprints.keys().collect();
128 let new_names: HashSet<&String> = new_fingerprints.keys().collect();
129
130 for removed in old_names.difference(&new_names) {
132 tracing::info!(backend = %removed, "Removing backend via hot reload");
133 if proxy.remove_backend(removed).await {
134 tracing::info!(backend = %removed, "Backend removed");
135 } else {
136 tracing::warn!(backend = %removed, "Backend not found for removal");
137 }
138 }
139
140 for backend in &new_config.backends {
142 if backend_fingerprints.contains_key(&backend.name) {
143 let old_fp = &backend_fingerprints[&backend.name];
145 let new_fp = &new_fingerprints[&backend.name];
146
147 if old_fp != new_fp {
148 tracing::info!(
149 backend = %backend.name,
150 "Backend config changed, replacing via hot reload"
151 );
152
153 proxy.remove_backend(&backend.name).await;
155 if let Err(e) = add_backend(&proxy, backend).await {
156 tracing::error!(
157 backend = %backend.name,
158 error = %e,
159 "Failed to replace backend via hot reload"
160 );
161 } else {
162 tracing::info!(backend = %backend.name, "Backend replaced");
163 }
164 }
165 continue;
166 }
167
168 tracing::info!(
169 name = %backend.name,
170 transport = ?backend.transport,
171 "Adding new backend via hot reload"
172 );
173
174 if let Err(e) = add_backend(&proxy, backend).await {
175 tracing::error!(
176 backend = %backend.name,
177 error = %e,
178 "Failed to add backend via hot reload"
179 );
180 } else {
181 tracing::info!(backend = %backend.name, "Backend added via hot reload");
182 }
183 }
184
185 backend_fingerprints = new_fingerprints;
187
188 #[cfg(feature = "discovery")]
190 if let Some((ref index, ref separator)) = discovery_index {
191 let mut proxy_clone = proxy.clone();
192 crate::discovery::reindex(index, &mut proxy_clone, separator).await;
193 }
194 }
195}
196
197fn config_fingerprint(backend: &BackendConfig) -> String {
200 toml::to_string(backend).unwrap_or_default()
201}
202
203async fn add_backend(proxy: &McpProxy, backend: &BackendConfig) -> anyhow::Result<()> {
205 let has_middleware = backend.timeout.is_some()
206 || backend.circuit_breaker.is_some()
207 || backend.rate_limit.is_some()
208 || backend.concurrency.is_some()
209 || backend.retry.is_some()
210 || backend.hedging.is_some()
211 || backend.outlier_detection.is_some();
212
213 match backend.transport {
214 TransportType::Stdio => {
215 let command = backend
216 .command
217 .as_deref()
218 .ok_or_else(|| anyhow::anyhow!("stdio backend requires 'command'"))?;
219 let args: Vec<&str> = backend.args.iter().map(|s| s.as_str()).collect();
220
221 let mut cmd = Command::new(command);
222 cmd.args(&args);
223 for (key, value) in &backend.env {
224 cmd.env(key, value);
225 }
226
227 let transport =
228 tower_mcp::client::StdioClientTransport::spawn_command(&mut cmd).await?;
229
230 if has_middleware {
231 let layer = build_backend_layer(backend);
232 proxy
233 .add_backend_with_layer(&backend.name, transport, layer)
234 .await
235 .map_err(|e| anyhow::anyhow!("{}", e))?;
236 } else {
237 proxy
238 .add_backend(&backend.name, transport)
239 .await
240 .map_err(|e| anyhow::anyhow!("{}", e))?;
241 }
242 }
243 TransportType::Http => {
244 let url = backend
245 .url
246 .as_deref()
247 .ok_or_else(|| anyhow::anyhow!("http backend requires 'url'"))?;
248 let mut transport = tower_mcp::client::HttpClientTransport::new(url);
249 if let Some(token) = &backend.bearer_token {
250 transport = transport.bearer_token(token);
251 }
252
253 if has_middleware {
254 let layer = build_backend_layer(backend);
255 proxy
256 .add_backend_with_layer(&backend.name, transport, layer)
257 .await
258 .map_err(|e| anyhow::anyhow!("{}", e))?;
259 } else {
260 proxy
261 .add_backend(&backend.name, transport)
262 .await
263 .map_err(|e| anyhow::anyhow!("{}", e))?;
264 }
265 }
266 #[cfg(feature = "websocket")]
267 TransportType::Websocket => {
268 let url = backend
269 .url
270 .as_deref()
271 .ok_or_else(|| anyhow::anyhow!("websocket backend requires 'url'"))?;
272 let transport = if let Some(token) = &backend.bearer_token {
273 crate::ws_transport::WebSocketClientTransport::connect_with_bearer_token(url, token)
274 .await?
275 } else {
276 crate::ws_transport::WebSocketClientTransport::connect(url).await?
277 };
278
279 if has_middleware {
280 let layer = build_backend_layer(backend);
281 proxy
282 .add_backend_with_layer(&backend.name, transport, layer)
283 .await
284 .map_err(|e| anyhow::anyhow!("{}", e))?;
285 } else {
286 proxy
287 .add_backend(&backend.name, transport)
288 .await
289 .map_err(|e| anyhow::anyhow!("{}", e))?;
290 }
291 }
292 #[cfg(not(feature = "websocket"))]
293 TransportType::Websocket => {
294 anyhow::bail!(
295 "WebSocket transport requires the 'websocket' feature. \
296 Rebuild with: cargo install mcp-proxy --features websocket"
297 );
298 }
299 }
300
301 if has_middleware {
302 tracing::info!(
303 backend = %backend.name,
304 timeout = backend.timeout.is_some(),
305 circuit_breaker = backend.circuit_breaker.is_some(),
306 rate_limit = backend.rate_limit.is_some(),
307 concurrency = backend.concurrency.is_some(),
308 "Per-backend middleware applied to hot-reloaded backend"
309 );
310 }
311
312 Ok(())
313}
314
315struct BackendMiddlewareLayer {
320 build_fn: Box<
321 dyn Fn(BackendService) -> BoxCloneService<RouterRequest, RouterResponse, Infallible> + Send,
322 >,
323}
324
325impl tower::Layer<BackendService> for BackendMiddlewareLayer {
326 type Service = BoxCloneService<RouterRequest, RouterResponse, Infallible>;
327
328 fn layer(&self, inner: BackendService) -> Self::Service {
329 (self.build_fn)(inner)
330 }
331}
332
333fn build_backend_layer(backend: &BackendConfig) -> BackendMiddlewareLayer {
338 let retry_config = backend.retry.clone();
339 let concurrency = backend.concurrency.as_ref().map(|cc| cc.max_concurrent);
340 let rate_limit = backend
341 .rate_limit
342 .as_ref()
343 .map(|rl| (rl.requests, rl.period_seconds));
344 let timeout_secs = backend.timeout.as_ref().map(|t| t.seconds);
345 let circuit_breaker = backend.circuit_breaker.as_ref().map(|cb| {
346 (
347 cb.failure_rate_threshold,
348 cb.minimum_calls,
349 cb.wait_duration_seconds,
350 cb.permitted_calls_in_half_open,
351 )
352 });
353 let hedging = backend.hedging.clone();
354 let outlier = backend.outlier_detection.clone();
355 let name = backend.name.clone();
356
357 BackendMiddlewareLayer {
358 build_fn: Box::new(move |inner: BackendService| {
359 let mut svc: BoxCloneService<RouterRequest, RouterResponse, Infallible> =
360 BoxCloneService::new(inner);
361
362 if let Some(ref retry_cfg) = retry_config {
364 let layer = crate::retry::build_retry_layer(retry_cfg, &name);
365 let retried = tower::Layer::layer(&layer, svc);
366 svc = BoxCloneService::new(retried);
367 }
368
369 if let Some(ref hedge_cfg) = hedging {
371 let delay = Duration::from_millis(hedge_cfg.delay_ms);
372 let max_attempts = hedge_cfg.max_hedges + 1;
373 let layer = if delay.is_zero() {
374 tower_resilience::hedge::HedgeLayer::builder()
375 .no_delay()
376 .max_hedged_attempts(max_attempts)
377 .name(format!("{}-hedge", name))
378 .build()
379 } else {
380 tower_resilience::hedge::HedgeLayer::builder()
381 .delay(delay)
382 .max_hedged_attempts(max_attempts)
383 .name(format!("{}-hedge", name))
384 .build()
385 };
386 let hedged = tower::Layer::layer(&layer, svc);
387 svc = BoxCloneService::new(tower_mcp::CatchError::new(hedged));
388 }
389
390 if let Some(max) = concurrency {
392 let limited =
393 tower::Layer::layer(&tower::limit::ConcurrencyLimitLayer::new(max), svc);
394 svc = BoxCloneService::new(tower_mcp::CatchError::new(limited));
395 }
396
397 if let Some((requests, period_seconds)) = rate_limit {
399 let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
400 .limit_for_period(requests)
401 .refresh_period(Duration::from_secs(period_seconds))
402 .name(format!("{}-ratelimit", name))
403 .build();
404 let limited = tower::Layer::layer(&layer, svc);
405 svc = BoxCloneService::new(tower_mcp::CatchError::new(limited));
406 }
407
408 if let Some(seconds) = timeout_secs {
410 let limited = tower::Layer::layer(
411 &tower::timeout::TimeoutLayer::new(Duration::from_secs(seconds)),
412 svc,
413 );
414 svc = BoxCloneService::new(tower_mcp::CatchError::new(limited));
415 }
416
417 if let Some((failure_rate, min_calls, wait_secs, half_open)) = circuit_breaker {
419 let layer = tower_resilience::circuitbreaker::CircuitBreakerLayer::builder()
420 .failure_rate_threshold(failure_rate)
421 .minimum_number_of_calls(min_calls)
422 .wait_duration_in_open(Duration::from_secs(wait_secs))
423 .permitted_calls_in_half_open(half_open)
424 .name(format!("{}-cb", name))
425 .build();
426 let limited = tower::Layer::layer(&layer, svc);
427 svc = BoxCloneService::new(tower_mcp::CatchError::new(limited));
428 }
429
430 if let Some(ref od_config) = outlier {
432 let detector = crate::outlier::OutlierDetector::new(od_config.max_ejection_percent);
435 let layer = crate::outlier::OutlierDetectionLayer::new(
436 name.clone(),
437 od_config.clone(),
438 detector,
439 );
440 let od_svc = tower::Layer::layer(&layer, svc);
441 svc = BoxCloneService::new(od_svc);
442 }
443
444 svc
445 }),
446 }
447}