Skip to main content

objectiveai_mcp_proxy/
run.rs

1//! ObjectiveAI MCP proxy server.
2//!
3//! Multiplexes a downstream MCP client across one or more upstream MCP
4//! servers selected per-request via `X-MCP-Servers` /
5//! `X-MCP-Headers`.
6//!
7//! Mirrors the `objectiveai-api` `run.rs` shape so other crates can
8//! `use objectiveai_mcp_proxy::{ConfigBuilder, run}` and spawn the
9//! server in-process without going through the binary.
10
11use std::sync::Arc;
12use std::time::Duration;
13
14use envconfig::Envconfig;
15use objectiveai_sdk::mcp::Client;
16
17use crate::session_manager::SessionManager;
18use crate::{AppState, mcp};
19
20#[derive(Envconfig)]
21struct EnvConfigBuilder {
22    #[envconfig(from = "ADDRESS")]
23    address: Option<String>,
24    #[envconfig(from = "PORT")]
25    port: Option<u16>,
26    #[envconfig(from = "USER_AGENT")]
27    user_agent: Option<String>,
28    #[envconfig(from = "HTTP_REFERER")]
29    http_referer: Option<String>,
30    #[envconfig(from = "X_TITLE")]
31    x_title: Option<String>,
32    #[envconfig(from = "MCP_CONNECT_TIMEOUT")]
33    mcp_connect_timeout: Option<u64>,
34    #[envconfig(from = "MCP_CALL_TIMEOUT")]
35    mcp_call_timeout: Option<u64>,
36    #[envconfig(from = "MCP_BACKOFF_CURRENT_INTERVAL")]
37    mcp_backoff_current_interval: Option<u64>,
38    #[envconfig(from = "MCP_BACKOFF_INITIAL_INTERVAL")]
39    mcp_backoff_initial_interval: Option<u64>,
40    #[envconfig(from = "MCP_BACKOFF_RANDOMIZATION_FACTOR")]
41    mcp_backoff_randomization_factor: Option<f64>,
42    #[envconfig(from = "MCP_BACKOFF_MULTIPLIER")]
43    mcp_backoff_multiplier: Option<f64>,
44    #[envconfig(from = "MCP_BACKOFF_MAX_INTERVAL")]
45    mcp_backoff_max_interval: Option<u64>,
46    #[envconfig(from = "MCP_BACKOFF_MAX_ELAPSED_TIME")]
47    mcp_backoff_max_elapsed_time: Option<u64>,
48    /// Base64-encoded 32-byte key. Used to AEAD-encrypt the proxy
49    /// session id payload (per-upstream `Mcp-Session-Id` +
50    /// `Authorization` + custom headers).
51    ///
52    /// Rotation: set a new key, restart the proxy. All outstanding
53    /// session ids minted under the old key become 401s; clients
54    /// re-initialize.
55    ///
56    /// Unset → the proxy generates one ephemeral 32-byte key on
57    /// startup. Sessions minted by such a process can't be decoded by
58    /// any other process or after a restart — which is fine for tests
59    /// and dev but bad for production.
60    #[envconfig(from = "MCP_ENCRYPTION_KEY")]
61    mcp_encryption_key: Option<String>,
62    #[envconfig(from = "SUPPRESS_OUTPUT")]
63    suppress_output: Option<String>,
64}
65
66impl EnvConfigBuilder {
67    fn build(self) -> ConfigBuilder {
68        ConfigBuilder {
69            address: self.address,
70            port: self.port,
71            user_agent: self.user_agent,
72            http_referer: self.http_referer,
73            x_title: self.x_title,
74            mcp_connect_timeout: self.mcp_connect_timeout,
75            mcp_call_timeout: self.mcp_call_timeout,
76            mcp_backoff_current_interval: self.mcp_backoff_current_interval,
77            mcp_backoff_initial_interval: self.mcp_backoff_initial_interval,
78            mcp_backoff_randomization_factor: self.mcp_backoff_randomization_factor,
79            mcp_backoff_multiplier: self.mcp_backoff_multiplier,
80            mcp_backoff_max_interval: self.mcp_backoff_max_interval,
81            mcp_backoff_max_elapsed_time: self.mcp_backoff_max_elapsed_time,
82            mcp_encryption_key: match self.mcp_encryption_key.as_deref() {
83                Some(s) => match crate::session_manager::parse_key_env(s) {
84                    Ok(opt) => opt,
85                    Err(e) => {
86                        tracing::error!(error = %e, "MCP_ENCRYPTION_KEY parse failed; falling back to ephemeral key");
87                        None
88                    }
89                },
90                None => None,
91            },
92            suppress_output: self.suppress_output.map(|v| {
93                matches!(v.to_ascii_lowercase().as_str(), "1" | "true" | "yes" | "on")
94            }),
95        }
96    }
97}
98
/// Partially-specified proxy configuration.
///
/// Every field is optional, and [`ConfigBuilder::build`] fills unset ones
/// with defaults. Construct it directly (struct-update syntax with
/// `..Default::default()` works) or load it from the environment through
/// its `Envconfig` impl.
///
/// `Debug` is intentionally not derived: `mcp_encryption_key` holds key
/// material.
#[derive(Default)]
pub struct ConfigBuilder {
    /// Host/interface to bind.
    pub address: Option<String>,
    /// TCP port to bind.
    pub port: Option<u16>,
    /// Passed through to the upstream MCP client.
    pub user_agent: Option<String>,
    /// Passed through to the upstream MCP client.
    pub http_referer: Option<String>,
    /// Passed through to the upstream MCP client.
    pub x_title: Option<String>,
    /// Milliseconds.
    pub mcp_connect_timeout: Option<u64>,
    /// Milliseconds.
    pub mcp_call_timeout: Option<u64>,
    /// Milliseconds.
    pub mcp_backoff_current_interval: Option<u64>,
    /// Milliseconds.
    pub mcp_backoff_initial_interval: Option<u64>,
    /// Unitless backoff jitter factor.
    pub mcp_backoff_randomization_factor: Option<f64>,
    /// Unitless backoff growth factor.
    pub mcp_backoff_multiplier: Option<f64>,
    /// Milliseconds.
    pub mcp_backoff_max_interval: Option<u64>,
    /// Milliseconds.
    pub mcp_backoff_max_elapsed_time: Option<u64>,
    /// 256-bit AEAD key for session-id encryption.
    ///
    /// `None` means the proxy generates one ephemeral key per process.
    /// Corresponds to the `MCP_ENCRYPTION_KEY` env var (base64).
    pub mcp_encryption_key: Option<[u8; 32]>,
    /// Suppress the startup "listening on" line.
    pub suppress_output: Option<bool>,
}
120
121impl Envconfig for ConfigBuilder {
122    #[allow(deprecated)]
123    fn init() -> Result<Self, envconfig::Error> {
124        EnvConfigBuilder::init().map(|e| e.build())
125    }
126
127    fn init_from_env() -> Result<Self, envconfig::Error> {
128        EnvConfigBuilder::init_from_env().map(|e| e.build())
129    }
130
131    fn init_from_hashmap(
132        hashmap: &std::collections::HashMap<String, String>,
133    ) -> Result<Self, envconfig::Error> {
134        EnvConfigBuilder::init_from_hashmap(hashmap).map(|e| e.build())
135    }
136}
137
138impl ConfigBuilder {
139    pub fn build(self) -> Config {
140        Config {
141            address: self.address.unwrap_or_else(|| "0.0.0.0".to_string()),
142            port: self.port.unwrap_or(3000),
143            user_agent: self
144                .user_agent
145                .unwrap_or_else(|| format!("objectiveai-mcp-proxy/{}", env!("CARGO_PKG_VERSION"))),
146            http_referer: self
147                .http_referer
148                .unwrap_or_else(|| "https://objectiveai.dev".to_string()),
149            x_title: self
150                .x_title
151                .unwrap_or_else(|| "ObjectiveAI MCP Proxy".to_string()),
152            // Defaults match `objectiveai-api/src/run.rs` so the same
153            // env vars produce the same effective config when read by
154            // either binary independently.
155            mcp_connect_timeout: self.mcp_connect_timeout.unwrap_or(30000),
156            mcp_call_timeout: self.mcp_call_timeout.unwrap_or(30000),
157            mcp_backoff_current_interval: self.mcp_backoff_current_interval.unwrap_or(100),
158            mcp_backoff_initial_interval: self.mcp_backoff_initial_interval.unwrap_or(100),
159            mcp_backoff_randomization_factor: self.mcp_backoff_randomization_factor.unwrap_or(0.5),
160            mcp_backoff_multiplier: self.mcp_backoff_multiplier.unwrap_or(1.5),
161            mcp_backoff_max_interval: self.mcp_backoff_max_interval.unwrap_or(1000),
162            mcp_backoff_max_elapsed_time: self.mcp_backoff_max_elapsed_time.unwrap_or(40000),
163            mcp_encryption_key: self.mcp_encryption_key,
164            suppress_output: self.suppress_output.unwrap_or(false),
165        }
166    }
167}
168
/// Fully-resolved proxy configuration, normally produced by
/// [`ConfigBuilder::build`] and consumed by [`setup`] / [`run`].
pub struct Config {
    /// Host/interface the listener binds to (combined with `port`).
    pub address: String,
    /// TCP port the listener binds to.
    pub port: u16,
    /// Passed to the upstream MCP `Client`.
    pub user_agent: String,
    /// Passed to the upstream MCP `Client`.
    pub http_referer: String,
    /// Passed to the upstream MCP `Client`.
    pub x_title: String,
    /// Milliseconds.
    pub mcp_connect_timeout: u64,
    /// Milliseconds.
    pub mcp_call_timeout: u64,
    /// Milliseconds.
    pub mcp_backoff_current_interval: u64,
    /// Milliseconds.
    pub mcp_backoff_initial_interval: u64,
    /// Unitless backoff jitter factor.
    pub mcp_backoff_randomization_factor: f64,
    /// Unitless backoff growth factor.
    pub mcp_backoff_multiplier: f64,
    /// Milliseconds.
    pub mcp_backoff_max_interval: u64,
    /// Milliseconds.
    pub mcp_backoff_max_elapsed_time: u64,
    /// Session-id encryption key. `None` makes [`setup`] generate one
    /// ephemeral key for this process.
    pub mcp_encryption_key: Option<[u8; 32]>,
    /// When `true`, [`run`] skips printing its "listening on" line.
    pub suppress_output: bool,
}
187
188pub async fn setup(config: Config) -> std::io::Result<(tokio::net::TcpListener, axum::Router)> {
189    let Config {
190        address,
191        port,
192        user_agent,
193        http_referer,
194        x_title,
195        mcp_connect_timeout,
196        mcp_call_timeout,
197        mcp_backoff_current_interval,
198        mcp_backoff_initial_interval,
199        mcp_backoff_randomization_factor,
200        mcp_backoff_multiplier,
201        mcp_backoff_max_interval,
202        mcp_backoff_max_elapsed_time,
203        mcp_encryption_key,
204        suppress_output: _,
205    } = config;
206
207    let client = Client::new(
208        reqwest::Client::new(),
209        user_agent,
210        x_title,
211        http_referer,
212        Duration::from_millis(mcp_connect_timeout),
213        Duration::from_millis(mcp_backoff_current_interval),
214        Duration::from_millis(mcp_backoff_initial_interval),
215        mcp_backoff_randomization_factor,
216        mcp_backoff_multiplier,
217        Duration::from_millis(mcp_backoff_max_interval),
218        Duration::from_millis(mcp_backoff_max_elapsed_time),
219        Duration::from_millis(mcp_call_timeout),
220    );
221
222    let sessions = match mcp_encryption_key {
223        Some(key) => SessionManager::new(key),
224        None => SessionManager::with_ephemeral_key(),
225    };
226    let state = AppState {
227        sessions: Arc::new(sessions),
228        client: Arc::new(client),
229    };
230
231    let router = axum::Router::new()
232        .route(
233            "/",
234            axum::routing::post(mcp::handle_post)
235                .get(mcp::handle_get)
236                .delete(mcp::handle_delete),
237        )
238        .route("/notify", axum::routing::post(mcp::handle_notify))
239        .with_state(state);
240
241    let listener = tokio::net::TcpListener::bind(format!("{address}:{port}")).await?;
242
243    Ok((listener, router))
244}
245
246pub async fn serve(listener: tokio::net::TcpListener, app: axum::Router) -> std::io::Result<()> {
247    axum::serve(listener, app).await
248}
249
250pub async fn run(config: Config) -> std::io::Result<()> {
251    let suppress_output = config.suppress_output;
252    let (listener, app) = setup(config).await?;
253    if !suppress_output {
254        let addr = listener.local_addr()?;
255        eprintln!("listening on {addr}");
256    }
257    serve(listener, app).await
258}