1use std::collections::{BTreeMap, HashMap};
2use std::path::PathBuf;
3use std::str::FromStr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use axum::http::{HeaderName, HeaderValue};
8
9use crate::Error;
10
/// Transport used to reach a backend server.
///
/// `BackendServerConfig::new` picks the variant from the command string.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BackendTransport {
    /// Chosen when the command is not an `http://` / `https://` URL.
    Stdio,
    /// Chosen when the command is an `http://` / `https://` URL.
    StreamableHttp,
}
19
/// How an HTTP backend is authenticated.
///
/// `BackendServerConfig::should_use_oauth` consults this mode. Stdio backends
/// never use OAuth regardless of the mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BackendAuthMode {
    /// Use OAuth only when neither a static `Authorization` header nor a dynamic
    /// header provider is configured. This is the default.
    #[default]
    Auto,
    /// Never use OAuth and rely on the configured headers.
    ///
    /// Selected by `--auth explicit-headers`, `--auth headers` or `--auth none`.
    ExplicitHeaders,
    /// Always use OAuth, even when an `Authorization` header is configured.
    ///
    /// Selected by `--auth oauth`.
    OAuth,
}
37
38pub type HeaderProvider = Arc<dyn Fn() -> Result<BTreeMap<String, String>, Error> + Send + Sync>;
39
40#[derive(Clone)]
42pub struct BackendServerConfig {
43 pub name: String,
44 pub command: String,
45 pub args: Vec<String>,
46 pub env: HashMap<String, String>,
47 pub cwd: Option<PathBuf>,
48 pub timeout: Option<Duration>,
49 pub transport: BackendTransport,
50 pub headers: HashMap<String, String>,
51 pub header_provider: Option<HeaderProvider>,
52 pub auth_mode: BackendAuthMode,
53 pub oauth_app_name: Option<String>,
54}
55
56impl BackendServerConfig {
57 pub fn new(
58 name: impl Into<String>,
59 command: impl Into<String>,
60 args: impl IntoIterator<Item = impl Into<String>>,
61 ) -> Self {
62 let command = command.into();
63 let transport = if is_http_url(&command) {
64 BackendTransport::StreamableHttp
65 } else {
66 BackendTransport::Stdio
67 };
68 let raw_args = args.into_iter().map(Into::into).collect::<Vec<_>>();
69 let parsed_args = parse_backend_args(raw_args, transport);
70 Self {
71 name: name.into(),
72 command,
73 args: parsed_args.args,
74 env: parsed_args.env,
75 cwd: parsed_args.cwd,
76 timeout: parsed_args.timeout,
77 transport,
78 headers: parsed_args.headers,
79 header_provider: None,
80 auth_mode: parsed_args.auth_mode,
81 oauth_app_name: None,
82 }
83 }
84
85 pub fn with_env(
86 mut self,
87 env: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
88 ) -> Self {
89 self.env = env.into_iter().map(|(k, v)| (k.into(), v.into())).collect();
90 self
91 }
92
93 pub fn with_cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
94 self.cwd = Some(cwd.into());
95 self
96 }
97
98 pub fn with_timeout(mut self, timeout: Duration) -> Self {
99 self.timeout = Some(timeout);
100 self
101 }
102
103 pub fn with_headers(
104 mut self,
105 headers: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
106 ) -> Self {
107 self.headers = headers
108 .into_iter()
109 .map(|(name, value)| (name.into(), value.into()))
110 .collect();
111 self
112 }
113
114 pub fn with_auth_mode(mut self, auth_mode: BackendAuthMode) -> Self {
115 self.auth_mode = auth_mode;
116 self
117 }
118
119 pub fn with_header_provider(mut self, provider: HeaderProvider) -> Self {
120 self.header_provider = Some(provider);
121 self
122 }
123
124 pub fn with_oauth_app_name(mut self, app_name: impl Into<String>) -> Self {
125 self.oauth_app_name = Some(app_name.into());
126 self
127 }
128
129 pub fn has_dynamic_headers(&self) -> bool {
130 self.header_provider.is_some()
131 }
132
133 pub fn has_authorization_header(&self) -> bool {
134 self.headers
135 .keys()
136 .any(|name| name.eq_ignore_ascii_case("authorization"))
137 }
138
139 pub fn should_use_oauth(&self) -> bool {
140 self.transport == BackendTransport::StreamableHttp
141 && match self.auth_mode {
142 BackendAuthMode::Auto => {
143 !self.has_authorization_header() && !self.has_dynamic_headers()
144 }
145 BackendAuthMode::ExplicitHeaders => false,
146 BackendAuthMode::OAuth => true,
147 }
148 }
149}
150
151pub fn backend_http_headers(
152 backend: &BackendServerConfig,
153) -> Result<HashMap<HeaderName, HeaderValue>, Error> {
154 backend
155 .headers
156 .iter()
157 .map(|(name, value)| {
158 let name = HeaderName::from_str(name).map_err(|error| {
159 Error::Config(format!("invalid HTTP header name {name:?}: {error}"))
160 })?;
161 let value = HeaderValue::from_str(value).map_err(|error| {
162 Error::Config(format!("invalid HTTP header value for {name:?}: {error}"))
163 })?;
164 Ok((name, value))
165 })
166 .collect()
167}
168
169#[derive(Debug, Default)]
170struct ParsedBackendArgs {
171 args: Vec<String>,
172 env: HashMap<String, String>,
173 cwd: Option<PathBuf>,
174 timeout: Option<Duration>,
175 headers: HashMap<String, String>,
176 auth_mode: BackendAuthMode,
177}
178
179fn parse_backend_args(args: Vec<String>, transport: BackendTransport) -> ParsedBackendArgs {
180 let mut parsed = ParsedBackendArgs {
181 auth_mode: BackendAuthMode::Auto,
182 ..Default::default()
183 };
184 let mut index = 0;
185 while index < args.len() {
186 let arg = &args[index];
187 if arg == "-H" || arg == "--header" {
188 if let Some(header) = args.get(index + 1) {
189 if transport == BackendTransport::StreamableHttp {
190 if let Some((name, value)) = parse_header_arg(header) {
191 parsed.headers.insert(name, value);
192 } else {
193 parsed.args.push(arg.clone());
194 parsed.args.push(header.clone());
195 }
196 } else {
197 parsed.args.push(arg.clone());
198 parsed.args.push(header.clone());
199 }
200 index += 2;
201 } else {
202 parsed.args.push(arg.clone());
203 index += 1;
204 }
205 } else if let Some(header) = arg
206 .strip_prefix("-H=")
207 .or_else(|| arg.strip_prefix("--header="))
208 {
209 if transport == BackendTransport::StreamableHttp {
210 if let Some((name, value)) = parse_header_arg(header) {
211 parsed.headers.insert(name, value);
212 } else {
213 parsed.args.push(arg.clone());
214 }
215 } else {
216 parsed.args.push(arg.clone());
217 }
218 index += 1;
219 } else if let Some(cwd) = arg.strip_prefix("--cwd=") {
220 parsed.cwd = Some(PathBuf::from(cwd));
221 index += 1;
222 } else if arg == "--cwd" {
223 if let Some(cwd) = args.get(index + 1) {
224 parsed.cwd = Some(PathBuf::from(cwd));
225 index += 2;
226 } else {
227 parsed.args.push(arg.clone());
228 index += 1;
229 }
230 } else if arg == "-e" || arg == "--env" {
231 if let Some(env) = args.get(index + 1) {
232 if let Some((key, value)) = parse_key_value_arg(env) {
233 parsed.env.insert(key, interpolate_env(&value));
234 } else {
235 parsed.args.push(arg.clone());
236 parsed.args.push(env.clone());
237 }
238 index += 2;
239 } else {
240 parsed.args.push(arg.clone());
241 index += 1;
242 }
243 } else if let Some(env) = arg
244 .strip_prefix("-e=")
245 .or_else(|| arg.strip_prefix("--env="))
246 {
247 if let Some((key, value)) = parse_key_value_arg(env) {
248 parsed.env.insert(key, interpolate_env(&value));
249 } else {
250 parsed.args.push(arg.clone());
251 }
252 index += 1;
253 } else if arg == "-t" || arg == "--timeout" {
254 if let Some(timeout) = args.get(index + 1) {
255 if let Ok(seconds) = timeout.parse::<f64>() {
256 if seconds.is_finite() && seconds > 0.0 {
257 parsed.timeout = Some(Duration::from_secs_f64(seconds));
258 } else {
259 parsed.args.push(arg.clone());
260 parsed.args.push(timeout.clone());
261 }
262 } else {
263 parsed.args.push(arg.clone());
264 parsed.args.push(timeout.clone());
265 }
266 index += 2;
267 } else {
268 parsed.args.push(arg.clone());
269 index += 1;
270 }
271 } else if let Some(timeout) = arg
272 .strip_prefix("-t=")
273 .or_else(|| arg.strip_prefix("--timeout="))
274 {
275 if let Ok(seconds) = timeout.parse::<f64>() {
276 if seconds.is_finite() && seconds > 0.0 {
277 parsed.timeout = Some(Duration::from_secs_f64(seconds));
278 } else {
279 parsed.args.push(arg.clone());
280 }
281 } else {
282 parsed.args.push(arg.clone());
283 }
284 index += 1;
285 } else if let Some(mode) = arg.strip_prefix("--auth=") {
286 match mode {
287 "explicit-headers" | "headers" | "none" => {
288 parsed.auth_mode = BackendAuthMode::ExplicitHeaders;
289 }
290 "oauth" => {
291 parsed.auth_mode = BackendAuthMode::OAuth;
292 }
293 _ => parsed.args.push(arg.clone()),
294 }
295 index += 1;
296 } else if arg == "--auth" {
297 if let Some(mode) = args.get(index + 1) {
298 match mode.as_str() {
299 "explicit-headers" | "headers" | "none" => {
300 parsed.auth_mode = BackendAuthMode::ExplicitHeaders;
301 }
302 "oauth" => {
303 parsed.auth_mode = BackendAuthMode::OAuth;
304 }
305 _ => {
306 parsed.args.push(arg.clone());
307 parsed.args.push(mode.clone());
308 }
309 }
310 index += 2;
311 } else {
312 parsed.args.push(arg.clone());
313 index += 1;
314 }
315 } else {
316 parsed.args.push(arg.clone());
317 index += 1;
318 }
319 }
320 parsed
321}
322
323fn parse_header_arg(header: &str) -> Option<(String, String)> {
324 let (name, value) = header.split_once('=').or_else(|| header.split_once(':'))?;
325 let name = name.trim();
326 let value = value.trim();
327 if name.is_empty() || value.is_empty() {
328 return None;
329 }
330 Some((name.to_string(), interpolate_env(value)))
331}
332
/// Splits `KEY=VALUE` at the first `=`.
///
/// The key is trimmed and must be non-empty. The value is kept verbatim and
/// may be empty. Returns `None` when there is no `=` or the trimmed key is
/// empty.
fn parse_key_value_arg(value: &str) -> Option<(String, String)> {
    value.split_once('=').and_then(|(raw_key, raw_value)| {
        let key = raw_key.trim();
        (!key.is_empty()).then(|| (key.to_owned(), raw_value.to_owned()))
    })
}
341
/// Expands `${NAME}` references in `value` from the process environment.
///
/// A reference whose variable cannot be read (`std::env::var` returns an error)
/// is kept verbatim as `${NAME}`. The first `${` without a closing `}` stops
/// expansion; the rest of the input is copied unchanged.
fn interpolate_env(value: &str) -> String {
    let mut expanded = String::with_capacity(value.len());
    let mut remaining = value;
    while let Some(open) = remaining.find("${") {
        let tail = &remaining[open + 2..];
        let Some(close) = tail.find('}') else {
            break;
        };
        let name = &tail[..close];
        expanded.push_str(&remaining[..open]);
        expanded.push_str(&std::env::var(name).unwrap_or_else(|_| format!("${{{name}}}")));
        remaining = &tail[close + 1..];
    }
    expanded.push_str(remaining);
    expanded
}
360
/// Returns `true` when `value` starts with an `http://` or `https://` scheme.
///
/// The scheme is matched ASCII case-insensitively, because URL schemes are
/// case-insensitive (RFC 3986 §3.1). `HTTPS://host/mcp` is therefore
/// recognised as well.
fn is_http_url(value: &str) -> bool {
    ["http://", "https://"].iter().any(|scheme| {
        // `get` yields `None` instead of panicking when the prefix length does
        // not land on a UTF-8 character boundary.
        matches!(value.get(..scheme.len()), Some(prefix) if prefix.eq_ignore_ascii_case(scheme))
    })
}
364
#[cfg(test)]
mod tests {
    use super::*;

    const REMOTE_URL: &str = "https://example.test/mcp";

    /// HTTP backend named `remote` pointing at `REMOTE_URL`.
    fn remote(args: &[&str]) -> BackendServerConfig {
        BackendServerConfig::new("remote", REMOTE_URL, args.iter().copied())
    }

    /// Stdio backend named `local` that runs `python`.
    fn local(args: &[&str]) -> BackendServerConfig {
        BackendServerConfig::new("local", "python", args.iter().copied())
    }

    #[test]
    fn http_backend_url_parses_curl_style_headers_after_separator() {
        let backend = remote(&["-H", "Authorization=Basic token", "--header", "X-Test=yes"]);

        assert_eq!(backend.transport, BackendTransport::StreamableHttp);
        assert!(backend.args.is_empty());
        assert_eq!(backend.headers["Authorization"], "Basic token");
        assert_eq!(backend.headers["X-Test"], "yes");
    }

    #[test]
    fn http_backend_url_parses_equals_header_forms() {
        let backend = remote(&["-H=Authorization=Bearer token", "--header=X-Test=yes"]);

        assert!(backend.args.is_empty());
        assert_eq!(backend.headers["Authorization"], "Bearer token");
        assert_eq!(backend.headers["X-Test"], "yes");
    }

    #[test]
    fn http_backend_header_values_preserve_missing_environment_variables() {
        let header = "Authorization=Bearer ${MCP_COMPRESSOR_MISSING_TEST_TOKEN}";
        let backend = remote(&["-H", header]);

        assert_eq!(
            backend.headers["Authorization"],
            "Bearer ${MCP_COMPRESSOR_MISSING_TEST_TOKEN}"
        );
    }

    #[test]
    fn remote_http_auto_auth_uses_oauth_without_authorization_header() {
        assert!(remote(&[]).should_use_oauth());
    }

    #[test]
    fn remote_http_auto_auth_skips_oauth_with_authorization_header() {
        let backend = remote(&["-H", "Authorization=Basic token"]);

        assert!(backend.has_authorization_header());
        assert!(!backend.should_use_oauth());
    }

    #[test]
    fn http_backend_url_parses_auth_mode_args() {
        let explicit = remote(&["--auth", "explicit-headers"]);
        let oauth = remote(&["--auth=oauth"]);

        assert_eq!(explicit.auth_mode, BackendAuthMode::ExplicitHeaders);
        assert!(explicit.args.is_empty());
        assert_eq!(oauth.auth_mode, BackendAuthMode::OAuth);
        assert!(oauth.args.is_empty());
    }

    #[test]
    fn explicit_headers_auth_mode_skips_oauth_without_authorization_header() {
        let backend = remote(&[]).with_auth_mode(BackendAuthMode::ExplicitHeaders);

        assert!(!backend.should_use_oauth());
    }

    #[test]
    fn forced_oauth_auth_mode_uses_oauth_even_with_authorization_header() {
        let backend =
            remote(&["-H", "Authorization=Basic token"]).with_auth_mode(BackendAuthMode::OAuth);

        assert!(backend.should_use_oauth());
    }

    #[test]
    fn stdio_backend_never_uses_oauth() {
        assert!(!local(&["server.py"]).should_use_oauth());
    }

    #[test]
    fn backend_args_parse_cwd_env_and_timeout_after_separator() {
        let backend = local(&[
            "server.py",
            "--cwd",
            "/tmp/example",
            "-e",
            "FOO=bar",
            "--env=EMPTY=",
            "-t",
            "2.5",
        ]);

        assert_eq!(backend.args, ["server.py"]);
        assert_eq!(
            backend.cwd.as_deref(),
            Some(std::path::Path::new("/tmp/example"))
        );
        assert_eq!(backend.env["FOO"], "bar");
        assert_eq!(backend.env["EMPTY"], "");
        assert_eq!(backend.timeout, Some(Duration::from_secs_f64(2.5)));
    }

    #[test]
    fn backend_args_preserve_invalid_timeout_for_backend_validation() {
        let backend = local(&["server.py", "--timeout", "0"]);

        assert_eq!(backend.args, ["server.py", "--timeout", "0"]);
        assert_eq!(backend.timeout, None);
    }

    #[test]
    fn http_backend_url_preserves_unrecognized_args_for_validation() {
        let backend = remote(&["--unknown", "value", "-H"]);

        assert_eq!(backend.args, ["--unknown", "value", "-H"]);
        assert!(backend.headers.is_empty());
    }
}