Skip to main content

mcp_proxy/
reload.rs

//! Hot reload: watch the config file for changes and manage backends dynamically.
//!
//! Supports adding, removing, and replacing backends at runtime when the config
//! file changes. Uses per-backend content fingerprints (the TOML serialization
//! of each backend's config) to detect modifications.
5
6use std::collections::{HashMap, HashSet};
7use std::convert::Infallible;
8use std::path::PathBuf;
9use std::sync::mpsc as std_mpsc;
10use std::time::Duration;
11
12use notify_debouncer_mini::{DebouncedEventKind, new_debouncer};
13use tokio::process::Command;
14use tower::util::BoxCloneService;
15use tower_mcp::proxy::{BackendService, McpProxy};
16use tower_mcp::{RouterRequest, RouterResponse};
17
18use crate::config::{BackendConfig, ProxyConfig, TransportType};
19
20/// Spawn a background task that watches the config file and adds new backends.
21pub fn spawn_config_watcher(
22    config_path: PathBuf,
23    proxy: McpProxy,
24    #[cfg(feature = "discovery")] discovery_index: Option<(
25        crate::discovery::SharedDiscoveryIndex,
26        String,
27    )>,
28) {
29    std::thread::spawn(move || {
30        let rt = tokio::runtime::Builder::new_current_thread()
31            .enable_all()
32            .build()
33            .expect("hot reload runtime");
34        rt.block_on(watch_loop(
35            config_path,
36            proxy,
37            #[cfg(feature = "discovery")]
38            discovery_index,
39        ));
40    });
41}
42
43async fn watch_loop(
44    config_path: PathBuf,
45    proxy: McpProxy,
46    #[cfg(feature = "discovery")] discovery_index: Option<(
47        crate::discovery::SharedDiscoveryIndex,
48        String,
49    )>,
50) {
51    let (tx, rx) = std_mpsc::channel();
52
53    let mut debouncer = match new_debouncer(Duration::from_secs(2), tx) {
54        Ok(d) => d,
55        Err(e) => {
56            tracing::error!(error = %e, "Failed to create file watcher, hot reload disabled");
57            return;
58        }
59    };
60
61    if let Err(e) = debouncer
62        .watcher()
63        .watch(&config_path, notify::RecursiveMode::NonRecursive)
64    {
65        tracing::error!(
66            path = %config_path.display(),
67            error = %e,
68            "Failed to watch config file, hot reload disabled"
69        );
70        return;
71    }
72
73    tracing::info!(path = %config_path.display(), "Hot reload watching config file");
74
75    // Track known backends and their config fingerprints for change detection
76    let mut backend_fingerprints: HashMap<String, String> = {
77        if let Ok(config) = ProxyConfig::load(&config_path) {
78            config
79                .backends
80                .iter()
81                .map(|b| (b.name.clone(), config_fingerprint(b)))
82                .collect()
83        } else {
84            HashMap::new()
85        }
86    };
87
88    loop {
89        // Block until a file event arrives (this is a std::sync channel)
90        let events = match rx.recv() {
91            Ok(Ok(events)) => events,
92            Ok(Err(e)) => {
93                tracing::warn!(error = %e, "File watcher error");
94                continue;
95            }
96            Err(_) => {
97                tracing::info!("File watcher channel closed, stopping hot reload");
98                break;
99            }
100        };
101
102        // Only process write events
103        let has_write = events
104            .iter()
105            .any(|e| matches!(e.kind, DebouncedEventKind::Any));
106        if !has_write {
107            continue;
108        }
109
110        tracing::info!("Config file changed, reloading backends");
111
112        let mut new_config = match ProxyConfig::load(&config_path) {
113            Ok(c) => c,
114            Err(e) => {
115                tracing::warn!(error = %e, "Failed to parse updated config, skipping reload");
116                continue;
117            }
118        };
119        new_config.resolve_env_vars();
120
121        let new_fingerprints: HashMap<String, String> = new_config
122            .backends
123            .iter()
124            .map(|b| (b.name.clone(), config_fingerprint(b)))
125            .collect();
126
127        let old_names: HashSet<&String> = backend_fingerprints.keys().collect();
128        let new_names: HashSet<&String> = new_fingerprints.keys().collect();
129
130        // Remove backends that are no longer in config
131        for removed in old_names.difference(&new_names) {
132            tracing::info!(backend = %removed, "Removing backend via hot reload");
133            if proxy.remove_backend(removed).await {
134                tracing::info!(backend = %removed, "Backend removed");
135            } else {
136                tracing::warn!(backend = %removed, "Backend not found for removal");
137            }
138        }
139
140        // Add new backends
141        for backend in &new_config.backends {
142            if backend_fingerprints.contains_key(&backend.name) {
143                // Existing backend -- check for modification
144                let old_fp = &backend_fingerprints[&backend.name];
145                let new_fp = &new_fingerprints[&backend.name];
146
147                if old_fp != new_fp {
148                    tracing::info!(
149                        backend = %backend.name,
150                        "Backend config changed, replacing via hot reload"
151                    );
152
153                    // Remove old, add new
154                    proxy.remove_backend(&backend.name).await;
155                    if let Err(e) = add_backend(&proxy, backend).await {
156                        tracing::error!(
157                            backend = %backend.name,
158                            error = %e,
159                            "Failed to replace backend via hot reload"
160                        );
161                    } else {
162                        tracing::info!(backend = %backend.name, "Backend replaced");
163                    }
164                }
165                continue;
166            }
167
168            tracing::info!(
169                name = %backend.name,
170                transport = ?backend.transport,
171                "Adding new backend via hot reload"
172            );
173
174            if let Err(e) = add_backend(&proxy, backend).await {
175                tracing::error!(
176                    backend = %backend.name,
177                    error = %e,
178                    "Failed to add backend via hot reload"
179                );
180            } else {
181                tracing::info!(backend = %backend.name, "Backend added via hot reload");
182            }
183        }
184
185        // Update fingerprints to reflect current state
186        backend_fingerprints = new_fingerprints;
187
188        // Re-index discovery if enabled
189        #[cfg(feature = "discovery")]
190        if let Some((ref index, ref separator)) = discovery_index {
191            let mut proxy_clone = proxy.clone();
192            crate::discovery::reindex(index, &mut proxy_clone, separator).await;
193        }
194    }
195}
196
197/// Generate a fingerprint for a backend config to detect changes.
198/// Uses TOML serialization for a stable, content-based comparison.
199fn config_fingerprint(backend: &BackendConfig) -> String {
200    toml::to_string(backend).unwrap_or_default()
201}
202
203/// Connect and add a single backend to the proxy, including per-backend middleware.
204async fn add_backend(proxy: &McpProxy, backend: &BackendConfig) -> anyhow::Result<()> {
205    let has_middleware = backend.timeout.is_some()
206        || backend.circuit_breaker.is_some()
207        || backend.rate_limit.is_some()
208        || backend.concurrency.is_some()
209        || backend.retry.is_some()
210        || backend.hedging.is_some()
211        || backend.outlier_detection.is_some();
212
213    match backend.transport {
214        TransportType::Stdio => {
215            let command = backend
216                .command
217                .as_deref()
218                .ok_or_else(|| anyhow::anyhow!("stdio backend requires 'command'"))?;
219            let args: Vec<&str> = backend.args.iter().map(|s| s.as_str()).collect();
220
221            let mut cmd = Command::new(command);
222            cmd.args(&args);
223            for (key, value) in &backend.env {
224                cmd.env(key, value);
225            }
226
227            let transport =
228                tower_mcp::client::StdioClientTransport::spawn_command(&mut cmd).await?;
229
230            if has_middleware {
231                let layer = build_backend_layer(backend);
232                proxy
233                    .add_backend_with_layer(&backend.name, transport, layer)
234                    .await
235                    .map_err(|e| anyhow::anyhow!("{}", e))?;
236            } else {
237                proxy
238                    .add_backend(&backend.name, transport)
239                    .await
240                    .map_err(|e| anyhow::anyhow!("{}", e))?;
241            }
242        }
243        TransportType::Http => {
244            let url = backend
245                .url
246                .as_deref()
247                .ok_or_else(|| anyhow::anyhow!("http backend requires 'url'"))?;
248            let mut transport = tower_mcp::client::HttpClientTransport::new(url);
249            if let Some(token) = &backend.bearer_token {
250                transport = transport.bearer_token(token);
251            }
252
253            if has_middleware {
254                let layer = build_backend_layer(backend);
255                proxy
256                    .add_backend_with_layer(&backend.name, transport, layer)
257                    .await
258                    .map_err(|e| anyhow::anyhow!("{}", e))?;
259            } else {
260                proxy
261                    .add_backend(&backend.name, transport)
262                    .await
263                    .map_err(|e| anyhow::anyhow!("{}", e))?;
264            }
265        }
266        #[cfg(feature = "websocket")]
267        TransportType::Websocket => {
268            let url = backend
269                .url
270                .as_deref()
271                .ok_or_else(|| anyhow::anyhow!("websocket backend requires 'url'"))?;
272            let transport = if let Some(token) = &backend.bearer_token {
273                crate::ws_transport::WebSocketClientTransport::connect_with_bearer_token(url, token)
274                    .await?
275            } else {
276                crate::ws_transport::WebSocketClientTransport::connect(url).await?
277            };
278
279            if has_middleware {
280                let layer = build_backend_layer(backend);
281                proxy
282                    .add_backend_with_layer(&backend.name, transport, layer)
283                    .await
284                    .map_err(|e| anyhow::anyhow!("{}", e))?;
285            } else {
286                proxy
287                    .add_backend(&backend.name, transport)
288                    .await
289                    .map_err(|e| anyhow::anyhow!("{}", e))?;
290            }
291        }
292        #[cfg(not(feature = "websocket"))]
293        TransportType::Websocket => {
294            anyhow::bail!(
295                "WebSocket transport requires the 'websocket' feature. \
296                 Rebuild with: cargo install mcp-proxy --features websocket"
297            );
298        }
299    }
300
301    if has_middleware {
302        tracing::info!(
303            backend = %backend.name,
304            timeout = backend.timeout.is_some(),
305            circuit_breaker = backend.circuit_breaker.is_some(),
306            rate_limit = backend.rate_limit.is_some(),
307            concurrency = backend.concurrency.is_some(),
308            "Per-backend middleware applied to hot-reloaded backend"
309        );
310    }
311
312    Ok(())
313}
314
315/// A type-erasing layer that builds the full per-backend middleware stack.
316///
317/// Uses `BoxCloneService` to erase the composed middleware types, allowing
318/// arbitrary combinations of optional layers.
319struct BackendMiddlewareLayer {
320    build_fn: Box<
321        dyn Fn(BackendService) -> BoxCloneService<RouterRequest, RouterResponse, Infallible> + Send,
322    >,
323}
324
325impl tower::Layer<BackendService> for BackendMiddlewareLayer {
326    type Service = BoxCloneService<RouterRequest, RouterResponse, Infallible>;
327
328    fn layer(&self, inner: BackendService) -> Self::Service {
329        (self.build_fn)(inner)
330    }
331}
332
333/// Build a type-erased layer for per-backend middleware from config.
334///
335/// Layers are applied inner to outer:
336/// retry -> concurrency -> rate limit -> timeout -> circuit breaker -> outlier detection.
337fn build_backend_layer(backend: &BackendConfig) -> BackendMiddlewareLayer {
338    let retry_config = backend.retry.clone();
339    let concurrency = backend.concurrency.as_ref().map(|cc| cc.max_concurrent);
340    let rate_limit = backend
341        .rate_limit
342        .as_ref()
343        .map(|rl| (rl.requests, rl.period_seconds));
344    let timeout_secs = backend.timeout.as_ref().map(|t| t.seconds);
345    let circuit_breaker = backend.circuit_breaker.as_ref().map(|cb| {
346        (
347            cb.failure_rate_threshold,
348            cb.minimum_calls,
349            cb.wait_duration_seconds,
350            cb.permitted_calls_in_half_open,
351        )
352    });
353    let hedging = backend.hedging.clone();
354    let outlier = backend.outlier_detection.clone();
355    let name = backend.name.clone();
356
357    BackendMiddlewareLayer {
358        build_fn: Box::new(move |inner: BackendService| {
359            let mut svc: BoxCloneService<RouterRequest, RouterResponse, Infallible> =
360                BoxCloneService::new(inner);
361
362            // Retry (innermost)
363            if let Some(ref retry_cfg) = retry_config {
364                let layer = crate::retry::build_retry_layer(retry_cfg, &name);
365                let retried = tower::Layer::layer(&layer, svc);
366                svc = BoxCloneService::new(retried);
367            }
368
369            // Hedging
370            if let Some(ref hedge_cfg) = hedging {
371                let delay = Duration::from_millis(hedge_cfg.delay_ms);
372                let max_attempts = hedge_cfg.max_hedges + 1;
373                let layer = if delay.is_zero() {
374                    tower_resilience::hedge::HedgeLayer::builder()
375                        .no_delay()
376                        .max_hedged_attempts(max_attempts)
377                        .name(format!("{}-hedge", name))
378                        .build()
379                } else {
380                    tower_resilience::hedge::HedgeLayer::builder()
381                        .delay(delay)
382                        .max_hedged_attempts(max_attempts)
383                        .name(format!("{}-hedge", name))
384                        .build()
385                };
386                let hedged = tower::Layer::layer(&layer, svc);
387                svc = BoxCloneService::new(tower_mcp::CatchError::new(hedged));
388            }
389
390            // Concurrency limit
391            if let Some(max) = concurrency {
392                let limited =
393                    tower::Layer::layer(&tower::limit::ConcurrencyLimitLayer::new(max), svc);
394                svc = BoxCloneService::new(tower_mcp::CatchError::new(limited));
395            }
396
397            // Rate limit
398            if let Some((requests, period_seconds)) = rate_limit {
399                let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
400                    .limit_for_period(requests)
401                    .refresh_period(Duration::from_secs(period_seconds))
402                    .name(format!("{}-ratelimit", name))
403                    .build();
404                let limited = tower::Layer::layer(&layer, svc);
405                svc = BoxCloneService::new(tower_mcp::CatchError::new(limited));
406            }
407
408            // Timeout
409            if let Some(seconds) = timeout_secs {
410                let limited = tower::Layer::layer(
411                    &tower::timeout::TimeoutLayer::new(Duration::from_secs(seconds)),
412                    svc,
413                );
414                svc = BoxCloneService::new(tower_mcp::CatchError::new(limited));
415            }
416
417            // Circuit breaker
418            if let Some((failure_rate, min_calls, wait_secs, half_open)) = circuit_breaker {
419                let layer = tower_resilience::circuitbreaker::CircuitBreakerLayer::builder()
420                    .failure_rate_threshold(failure_rate)
421                    .minimum_number_of_calls(min_calls)
422                    .wait_duration_in_open(Duration::from_secs(wait_secs))
423                    .permitted_calls_in_half_open(half_open)
424                    .name(format!("{}-cb", name))
425                    .build();
426                let limited = tower::Layer::layer(&layer, svc);
427                svc = BoxCloneService::new(tower_mcp::CatchError::new(limited));
428            }
429
430            // Outlier detection (outermost)
431            if let Some(ref od_config) = outlier {
432                // Hot-reloaded backends get their own detector (single-backend scope).
433                // The main proxy build path uses a shared detector across all backends.
434                let detector = crate::outlier::OutlierDetector::new(od_config.max_ejection_percent);
435                let layer = crate::outlier::OutlierDetectionLayer::new(
436                    name.clone(),
437                    od_config.clone(),
438                    detector,
439                );
440                let od_svc = tower::Layer::layer(&layer, svc);
441                svc = BoxCloneService::new(od_svc);
442            }
443
444            svc
445        }),
446    }
447}