//! `mcp_compressor_core/server/backend.rs`: configuration types and argument
//! parsing for upstream (backend) MCP servers.

1use std::collections::{BTreeMap, HashMap};
2use std::path::PathBuf;
3use std::str::FromStr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use axum::http::{HeaderName, HeaderValue};
8
9use crate::Error;
10
/// How an upstream MCP server is reached.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendTransport {
    /// A local command is spawned and MCP is spoken over its stdio.
    Stdio,
    /// A remote streamable HTTP MCP endpoint is connected to.
    StreamableHttp,
}
19
/// Authentication strategy for a remote upstream MCP server.
///
/// The default is [`BackendAuthMode::Auto`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum BackendAuthMode {
    /// Match Python parity: explicit `Authorization` headers are used as-is;
    /// otherwise native OAuth should be attempted for remote HTTP backends.
    #[default]
    Auto,
    /// Use explicit backend headers only; never start OAuth.
    ExplicitHeaders,
    /// Force native OAuth.
    OAuth,
}
37
38pub type HeaderProvider = Arc<dyn Fn() -> Result<BTreeMap<String, String>, Error> + Send + Sync>;
39
40/// Configuration for one upstream MCP server.
41#[derive(Clone)]
42pub struct BackendServerConfig {
43    pub name: String,
44    pub command: String,
45    pub args: Vec<String>,
46    pub env: HashMap<String, String>,
47    pub cwd: Option<PathBuf>,
48    pub timeout: Option<Duration>,
49    pub transport: BackendTransport,
50    pub headers: HashMap<String, String>,
51    pub header_provider: Option<HeaderProvider>,
52    pub auth_mode: BackendAuthMode,
53    pub oauth_app_name: Option<String>,
54}
55
56impl BackendServerConfig {
57    pub fn new(
58        name: impl Into<String>,
59        command: impl Into<String>,
60        args: impl IntoIterator<Item = impl Into<String>>,
61    ) -> Self {
62        let command = command.into();
63        let transport = if is_http_url(&command) {
64            BackendTransport::StreamableHttp
65        } else {
66            BackendTransport::Stdio
67        };
68        let raw_args = args.into_iter().map(Into::into).collect::<Vec<_>>();
69        let parsed_args = parse_backend_args(raw_args, transport);
70        Self {
71            name: name.into(),
72            command,
73            args: parsed_args.args,
74            env: parsed_args.env,
75            cwd: parsed_args.cwd,
76            timeout: parsed_args.timeout,
77            transport,
78            headers: parsed_args.headers,
79            header_provider: None,
80            auth_mode: parsed_args.auth_mode,
81            oauth_app_name: None,
82        }
83    }
84
85    pub fn with_env(
86        mut self,
87        env: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
88    ) -> Self {
89        self.env = env.into_iter().map(|(k, v)| (k.into(), v.into())).collect();
90        self
91    }
92
93    pub fn with_cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
94        self.cwd = Some(cwd.into());
95        self
96    }
97
98    pub fn with_timeout(mut self, timeout: Duration) -> Self {
99        self.timeout = Some(timeout);
100        self
101    }
102
103    pub fn with_headers(
104        mut self,
105        headers: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
106    ) -> Self {
107        self.headers = headers
108            .into_iter()
109            .map(|(name, value)| (name.into(), value.into()))
110            .collect();
111        self
112    }
113
114    pub fn with_auth_mode(mut self, auth_mode: BackendAuthMode) -> Self {
115        self.auth_mode = auth_mode;
116        self
117    }
118
119    pub fn with_header_provider(mut self, provider: HeaderProvider) -> Self {
120        self.header_provider = Some(provider);
121        self
122    }
123
124    pub fn with_oauth_app_name(mut self, app_name: impl Into<String>) -> Self {
125        self.oauth_app_name = Some(app_name.into());
126        self
127    }
128
129    pub fn has_dynamic_headers(&self) -> bool {
130        self.header_provider.is_some()
131    }
132
133    pub fn has_authorization_header(&self) -> bool {
134        self.headers
135            .keys()
136            .any(|name| name.eq_ignore_ascii_case("authorization"))
137    }
138
139    pub fn should_use_oauth(&self) -> bool {
140        self.transport == BackendTransport::StreamableHttp
141            && match self.auth_mode {
142                BackendAuthMode::Auto => {
143                    !self.has_authorization_header() && !self.has_dynamic_headers()
144                }
145                BackendAuthMode::ExplicitHeaders => false,
146                BackendAuthMode::OAuth => true,
147            }
148    }
149}
150
151pub fn backend_http_headers(
152    backend: &BackendServerConfig,
153) -> Result<HashMap<HeaderName, HeaderValue>, Error> {
154    backend
155        .headers
156        .iter()
157        .map(|(name, value)| {
158            let name = HeaderName::from_str(name).map_err(|error| {
159                Error::Config(format!("invalid HTTP header name {name:?}: {error}"))
160            })?;
161            let value = HeaderValue::from_str(value).map_err(|error| {
162                Error::Config(format!("invalid HTTP header value for {name:?}: {error}"))
163            })?;
164            Ok((name, value))
165        })
166        .collect()
167}
168
169#[derive(Debug, Default)]
170struct ParsedBackendArgs {
171    args: Vec<String>,
172    env: HashMap<String, String>,
173    cwd: Option<PathBuf>,
174    timeout: Option<Duration>,
175    headers: HashMap<String, String>,
176    auth_mode: BackendAuthMode,
177}
178
179fn parse_backend_args(args: Vec<String>, transport: BackendTransport) -> ParsedBackendArgs {
180    let mut parsed = ParsedBackendArgs {
181        auth_mode: BackendAuthMode::Auto,
182        ..Default::default()
183    };
184    let mut index = 0;
185    while index < args.len() {
186        let arg = &args[index];
187        if arg == "-H" || arg == "--header" {
188            if let Some(header) = args.get(index + 1) {
189                if transport == BackendTransport::StreamableHttp {
190                    if let Some((name, value)) = parse_header_arg(header) {
191                        parsed.headers.insert(name, value);
192                    } else {
193                        parsed.args.push(arg.clone());
194                        parsed.args.push(header.clone());
195                    }
196                } else {
197                    parsed.args.push(arg.clone());
198                    parsed.args.push(header.clone());
199                }
200                index += 2;
201            } else {
202                parsed.args.push(arg.clone());
203                index += 1;
204            }
205        } else if let Some(header) = arg
206            .strip_prefix("-H=")
207            .or_else(|| arg.strip_prefix("--header="))
208        {
209            if transport == BackendTransport::StreamableHttp {
210                if let Some((name, value)) = parse_header_arg(header) {
211                    parsed.headers.insert(name, value);
212                } else {
213                    parsed.args.push(arg.clone());
214                }
215            } else {
216                parsed.args.push(arg.clone());
217            }
218            index += 1;
219        } else if let Some(cwd) = arg.strip_prefix("--cwd=") {
220            parsed.cwd = Some(PathBuf::from(cwd));
221            index += 1;
222        } else if arg == "--cwd" {
223            if let Some(cwd) = args.get(index + 1) {
224                parsed.cwd = Some(PathBuf::from(cwd));
225                index += 2;
226            } else {
227                parsed.args.push(arg.clone());
228                index += 1;
229            }
230        } else if arg == "-e" || arg == "--env" {
231            if let Some(env) = args.get(index + 1) {
232                if let Some((key, value)) = parse_key_value_arg(env) {
233                    parsed.env.insert(key, interpolate_env(&value));
234                } else {
235                    parsed.args.push(arg.clone());
236                    parsed.args.push(env.clone());
237                }
238                index += 2;
239            } else {
240                parsed.args.push(arg.clone());
241                index += 1;
242            }
243        } else if let Some(env) = arg
244            .strip_prefix("-e=")
245            .or_else(|| arg.strip_prefix("--env="))
246        {
247            if let Some((key, value)) = parse_key_value_arg(env) {
248                parsed.env.insert(key, interpolate_env(&value));
249            } else {
250                parsed.args.push(arg.clone());
251            }
252            index += 1;
253        } else if arg == "-t" || arg == "--timeout" {
254            if let Some(timeout) = args.get(index + 1) {
255                if let Ok(seconds) = timeout.parse::<f64>() {
256                    if seconds.is_finite() && seconds > 0.0 {
257                        parsed.timeout = Some(Duration::from_secs_f64(seconds));
258                    } else {
259                        parsed.args.push(arg.clone());
260                        parsed.args.push(timeout.clone());
261                    }
262                } else {
263                    parsed.args.push(arg.clone());
264                    parsed.args.push(timeout.clone());
265                }
266                index += 2;
267            } else {
268                parsed.args.push(arg.clone());
269                index += 1;
270            }
271        } else if let Some(timeout) = arg
272            .strip_prefix("-t=")
273            .or_else(|| arg.strip_prefix("--timeout="))
274        {
275            if let Ok(seconds) = timeout.parse::<f64>() {
276                if seconds.is_finite() && seconds > 0.0 {
277                    parsed.timeout = Some(Duration::from_secs_f64(seconds));
278                } else {
279                    parsed.args.push(arg.clone());
280                }
281            } else {
282                parsed.args.push(arg.clone());
283            }
284            index += 1;
285        } else if let Some(mode) = arg.strip_prefix("--auth=") {
286            match mode {
287                "explicit-headers" | "headers" | "none" => {
288                    parsed.auth_mode = BackendAuthMode::ExplicitHeaders;
289                }
290                "oauth" => {
291                    parsed.auth_mode = BackendAuthMode::OAuth;
292                }
293                _ => parsed.args.push(arg.clone()),
294            }
295            index += 1;
296        } else if arg == "--auth" {
297            if let Some(mode) = args.get(index + 1) {
298                match mode.as_str() {
299                    "explicit-headers" | "headers" | "none" => {
300                        parsed.auth_mode = BackendAuthMode::ExplicitHeaders;
301                    }
302                    "oauth" => {
303                        parsed.auth_mode = BackendAuthMode::OAuth;
304                    }
305                    _ => {
306                        parsed.args.push(arg.clone());
307                        parsed.args.push(mode.clone());
308                    }
309                }
310                index += 2;
311            } else {
312                parsed.args.push(arg.clone());
313                index += 1;
314            }
315        } else {
316            parsed.args.push(arg.clone());
317            index += 1;
318        }
319    }
320    parsed
321}
322
323fn parse_header_arg(header: &str) -> Option<(String, String)> {
324    let (name, value) = header.split_once('=').or_else(|| header.split_once(':'))?;
325    let name = name.trim();
326    let value = value.trim();
327    if name.is_empty() || value.is_empty() {
328        return None;
329    }
330    Some((name.to_string(), interpolate_env(value)))
331}
332
/// Splits `KEY=VALUE` at the first `=`.
///
/// The key is trimmed and must not be empty; the value is returned verbatim
/// and may be empty. Returns `None` when there is no `=` or the key is blank.
fn parse_key_value_arg(value: &str) -> Option<(String, String)> {
    match value.split_once('=') {
        Some((raw_key, raw_value)) if !raw_key.trim().is_empty() => {
            Some((raw_key.trim().to_owned(), raw_value.to_owned()))
        }
        _ => None,
    }
}
341
/// Expands `${NAME}` references using the process environment.
///
/// A reference whose variable cannot be read is left in place as `${NAME}`.
/// Substituted text is not rescanned, and a `${` without a closing `}` is
/// copied through literally.
fn interpolate_env(value: &str) -> String {
    let mut output = String::with_capacity(value.len());
    let mut rest = value;
    while let Some(start) = rest.find("${") {
        let after_open = &rest[start + 2..];
        // No closing brace anywhere ahead: the remainder is literal text.
        let Some(name_len) = after_open.find('}') else {
            break;
        };
        output.push_str(&rest[..start]);
        let name = &after_open[..name_len];
        match std::env::var(name) {
            Ok(resolved) => output.push_str(&resolved),
            Err(_) => {
                output.push_str("${");
                output.push_str(name);
                output.push('}');
            }
        }
        rest = &after_open[name_len + 1..];
    }
    output.push_str(rest);
    output
}
360
/// Returns `true` when `value` starts with an `http://` or `https://` scheme.
///
/// The scheme is matched ASCII case-insensitively, because URI schemes are
/// case-insensitive; `HTTPS://host` is a URL, not a command to spawn.
fn is_http_url(value: &str) -> bool {
    ["http://", "https://"].iter().any(|scheme| {
        // `get` yields `None` when `value` is too short or the cut would fall
        // inside a multi-byte character; neither case can be a match.
        value
            .get(..scheme.len())
            .map_or(false, |prefix| prefix.eq_ignore_ascii_case(scheme))
    })
}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// URL shared by every remote-backend test.
    const REMOTE_URL: &str = "https://example.test/mcp";

    /// Builds the HTTP backend used throughout these tests.
    fn remote<const N: usize>(args: [&str; N]) -> BackendServerConfig {
        BackendServerConfig::new("remote", REMOTE_URL, args)
    }

    /// Builds the stdio backend used throughout these tests.
    fn local<const N: usize>(args: [&str; N]) -> BackendServerConfig {
        BackendServerConfig::new("local", "python", args)
    }

    /// Looks up a static header value by its exact name.
    fn header<'a>(backend: &'a BackendServerConfig, name: &str) -> Option<&'a str> {
        backend.headers.get(name).map(String::as_str)
    }

    #[test]
    fn http_backend_url_parses_curl_style_headers_after_separator() {
        let backend = remote(["-H", "Authorization=Basic token", "--header", "X-Test=yes"]);

        assert_eq!(backend.transport, BackendTransport::StreamableHttp);
        assert_eq!(backend.args, Vec::<String>::new());
        assert_eq!(header(&backend, "Authorization"), Some("Basic token"));
        assert_eq!(header(&backend, "X-Test"), Some("yes"));
    }

    #[test]
    fn http_backend_url_parses_equals_header_forms() {
        let backend = remote(["-H=Authorization=Bearer token", "--header=X-Test=yes"]);

        assert_eq!(backend.args, Vec::<String>::new());
        assert_eq!(header(&backend, "Authorization"), Some("Bearer token"));
        assert_eq!(header(&backend, "X-Test"), Some("yes"));
    }

    #[test]
    fn http_backend_header_values_preserve_missing_environment_variables() {
        let placeholder = "Bearer ${MCP_COMPRESSOR_MISSING_TEST_TOKEN}";
        let backend = remote([
            "-H",
            "Authorization=Bearer ${MCP_COMPRESSOR_MISSING_TEST_TOKEN}",
        ]);

        assert_eq!(header(&backend, "Authorization"), Some(placeholder));
    }

    #[test]
    fn remote_http_auto_auth_uses_oauth_without_authorization_header() {
        assert!(remote([]).should_use_oauth());
    }

    #[test]
    fn remote_http_auto_auth_skips_oauth_with_authorization_header() {
        let backend = remote(["-H", "Authorization=Basic token"]);

        assert!(backend.has_authorization_header());
        assert!(!backend.should_use_oauth());
    }

    #[test]
    fn http_backend_url_parses_auth_mode_args() {
        let explicit = remote(["--auth", "explicit-headers"]);
        assert_eq!(explicit.auth_mode, BackendAuthMode::ExplicitHeaders);
        assert_eq!(explicit.args, Vec::<String>::new());

        let oauth = remote(["--auth=oauth"]);
        assert_eq!(oauth.auth_mode, BackendAuthMode::OAuth);
        assert_eq!(oauth.args, Vec::<String>::new());
    }

    #[test]
    fn explicit_headers_auth_mode_skips_oauth_without_authorization_header() {
        let backend = remote([]).with_auth_mode(BackendAuthMode::ExplicitHeaders);

        assert!(!backend.should_use_oauth());
    }

    #[test]
    fn forced_oauth_auth_mode_uses_oauth_even_with_authorization_header() {
        let backend =
            remote(["-H", "Authorization=Basic token"]).with_auth_mode(BackendAuthMode::OAuth);

        assert!(backend.should_use_oauth());
    }

    #[test]
    fn stdio_backend_never_uses_oauth() {
        assert!(!local(["server.py"]).should_use_oauth());
    }

    #[test]
    fn backend_args_parse_cwd_env_and_timeout_after_separator() {
        let backend = local([
            "server.py",
            "--cwd",
            "/tmp/example",
            "-e",
            "FOO=bar",
            "--env=EMPTY=",
            "-t",
            "2.5",
        ]);

        assert_eq!(backend.args, ["server.py"]);
        assert_eq!(backend.cwd, Some(PathBuf::from("/tmp/example")));
        assert_eq!(backend.env.get("FOO").map(String::as_str), Some("bar"));
        assert_eq!(backend.env.get("EMPTY").map(String::as_str), Some(""));
        assert_eq!(backend.timeout, Some(Duration::from_secs_f64(2.5)));
    }

    #[test]
    fn backend_args_preserve_invalid_timeout_for_backend_validation() {
        let backend = local(["server.py", "--timeout", "0"]);

        assert_eq!(backend.args, ["server.py", "--timeout", "0"]);
        assert!(backend.timeout.is_none());
    }

    #[test]
    fn http_backend_url_preserves_unrecognized_args_for_validation() {
        let backend = remote(["--unknown", "value", "-H"]);

        assert_eq!(backend.args, ["--unknown", "value", "-H"]);
        assert!(backend.headers.is_empty());
    }
}
523}