1use std::sync::Arc;
12use std::time::Duration;
13
14use envconfig::Envconfig;
15use objectiveai_sdk::mcp::Client;
16
17use crate::session_manager::SessionManager;
18use crate::{AppState, mcp};
19
/// Raw configuration read from environment variables via the `envconfig` derive.
///
/// Every field is an `Option`, so an unset variable simply becomes `None`; defaults are
/// applied later by `ConfigBuilder::build`. `MCP_ENCRYPTION_KEY` and `SUPPRESS_OUTPUT` are
/// kept as raw strings here and are only decoded in `EnvConfigBuilder::build`.
#[derive(Envconfig)]
struct EnvConfigBuilder {
    /// Host/interface for the HTTP listener (bound as `"{address}:{port}"` in `setup`).
    #[envconfig(from = "ADDRESS")]
    address: Option<String>,
    /// TCP port for the HTTP listener.
    #[envconfig(from = "PORT")]
    port: Option<u16>,
    /// User agent string handed to `objectiveai_sdk::mcp::Client::new`.
    #[envconfig(from = "USER_AGENT")]
    user_agent: Option<String>,
    /// Referer string handed to `objectiveai_sdk::mcp::Client::new`.
    #[envconfig(from = "HTTP_REFERER")]
    http_referer: Option<String>,
    /// Title string handed to `objectiveai_sdk::mcp::Client::new`.
    #[envconfig(from = "X_TITLE")]
    x_title: Option<String>,
    // The following time values are milliseconds (`setup` wraps them in
    // `Duration::from_millis` before passing them to the MCP client).
    #[envconfig(from = "MCP_CONNECT_TIMEOUT")]
    mcp_connect_timeout: Option<u64>,
    #[envconfig(from = "MCP_CALL_TIMEOUT")]
    mcp_call_timeout: Option<u64>,
    // Backoff tuning forwarded to the MCP client; presumably an exponential-backoff
    // policy (TODO confirm against `objectiveai_sdk::mcp::Client`).
    #[envconfig(from = "MCP_BACKOFF_CURRENT_INTERVAL")]
    mcp_backoff_current_interval: Option<u64>,
    #[envconfig(from = "MCP_BACKOFF_INITIAL_INTERVAL")]
    mcp_backoff_initial_interval: Option<u64>,
    #[envconfig(from = "MCP_BACKOFF_RANDOMIZATION_FACTOR")]
    mcp_backoff_randomization_factor: Option<f64>,
    #[envconfig(from = "MCP_BACKOFF_MULTIPLIER")]
    mcp_backoff_multiplier: Option<f64>,
    #[envconfig(from = "MCP_BACKOFF_MAX_INTERVAL")]
    mcp_backoff_max_interval: Option<u64>,
    #[envconfig(from = "MCP_BACKOFF_MAX_ELAPSED_TIME")]
    mcp_backoff_max_elapsed_time: Option<u64>,
    /// Raw encryption key text; decoded with `crate::session_manager::parse_key_env`.
    #[envconfig(from = "MCP_ENCRYPTION_KEY")]
    mcp_encryption_key: Option<String>,
    /// Raw flag text; interpreted as a boolean in `EnvConfigBuilder::build`.
    #[envconfig(from = "SUPPRESS_OUTPUT")]
    suppress_output: Option<String>,
}
65
66impl EnvConfigBuilder {
67 fn build(self) -> ConfigBuilder {
68 ConfigBuilder {
69 address: self.address,
70 port: self.port,
71 user_agent: self.user_agent,
72 http_referer: self.http_referer,
73 x_title: self.x_title,
74 mcp_connect_timeout: self.mcp_connect_timeout,
75 mcp_call_timeout: self.mcp_call_timeout,
76 mcp_backoff_current_interval: self.mcp_backoff_current_interval,
77 mcp_backoff_initial_interval: self.mcp_backoff_initial_interval,
78 mcp_backoff_randomization_factor: self.mcp_backoff_randomization_factor,
79 mcp_backoff_multiplier: self.mcp_backoff_multiplier,
80 mcp_backoff_max_interval: self.mcp_backoff_max_interval,
81 mcp_backoff_max_elapsed_time: self.mcp_backoff_max_elapsed_time,
82 mcp_encryption_key: match self.mcp_encryption_key.as_deref() {
83 Some(s) => match crate::session_manager::parse_key_env(s) {
84 Ok(opt) => opt,
85 Err(e) => {
86 tracing::error!(error = %e, "MCP_ENCRYPTION_KEY parse failed; falling back to ephemeral key");
87 None
88 }
89 },
90 None => None,
91 },
92 suppress_output: self.suppress_output.map(|v| {
93 matches!(v.to_ascii_lowercase().as_str(), "1" | "true" | "yes" | "on")
94 }),
95 }
96 }
97}
98
/// Optional configuration overrides that [`ConfigBuilder::build`] resolves into a [`Config`].
///
/// `ConfigBuilder::default()` leaves every field `None`, meaning "use the built-in default".
/// It can also be populated from the environment through its `Envconfig` impl. Time values
/// are milliseconds (`setup` converts them with `Duration::from_millis`).
#[derive(Default)]
pub struct ConfigBuilder {
    pub address: Option<String>,
    pub port: Option<u16>,
    pub user_agent: Option<String>,
    pub http_referer: Option<String>,
    pub x_title: Option<String>,
    pub mcp_connect_timeout: Option<u64>,
    pub mcp_call_timeout: Option<u64>,
    pub mcp_backoff_current_interval: Option<u64>,
    pub mcp_backoff_initial_interval: Option<u64>,
    pub mcp_backoff_randomization_factor: Option<f64>,
    pub mcp_backoff_multiplier: Option<f64>,
    pub mcp_backoff_max_interval: Option<u64>,
    pub mcp_backoff_max_elapsed_time: Option<u64>,
    /// Already-decoded 32-byte key; it has no default and stays `None` in the built `Config`.
    pub mcp_encryption_key: Option<[u8; 32]>,
    pub suppress_output: Option<bool>,
}
120
121impl Envconfig for ConfigBuilder {
122 #[allow(deprecated)]
123 fn init() -> Result<Self, envconfig::Error> {
124 EnvConfigBuilder::init().map(|e| e.build())
125 }
126
127 fn init_from_env() -> Result<Self, envconfig::Error> {
128 EnvConfigBuilder::init_from_env().map(|e| e.build())
129 }
130
131 fn init_from_hashmap(
132 hashmap: &std::collections::HashMap<String, String>,
133 ) -> Result<Self, envconfig::Error> {
134 EnvConfigBuilder::init_from_hashmap(hashmap).map(|e| e.build())
135 }
136}
137
138impl ConfigBuilder {
139 pub fn build(self) -> Config {
140 Config {
141 address: self.address.unwrap_or_else(|| "0.0.0.0".to_string()),
142 port: self.port.unwrap_or(3000),
143 user_agent: self
144 .user_agent
145 .unwrap_or_else(|| format!("objectiveai-mcp-proxy/{}", env!("CARGO_PKG_VERSION"))),
146 http_referer: self
147 .http_referer
148 .unwrap_or_else(|| "https://objectiveai.dev".to_string()),
149 x_title: self
150 .x_title
151 .unwrap_or_else(|| "ObjectiveAI MCP Proxy".to_string()),
152 mcp_connect_timeout: self.mcp_connect_timeout.unwrap_or(30000),
156 mcp_call_timeout: self.mcp_call_timeout.unwrap_or(30000),
157 mcp_backoff_current_interval: self.mcp_backoff_current_interval.unwrap_or(100),
158 mcp_backoff_initial_interval: self.mcp_backoff_initial_interval.unwrap_or(100),
159 mcp_backoff_randomization_factor: self.mcp_backoff_randomization_factor.unwrap_or(0.5),
160 mcp_backoff_multiplier: self.mcp_backoff_multiplier.unwrap_or(1.5),
161 mcp_backoff_max_interval: self.mcp_backoff_max_interval.unwrap_or(1000),
162 mcp_backoff_max_elapsed_time: self.mcp_backoff_max_elapsed_time.unwrap_or(40000),
163 mcp_encryption_key: self.mcp_encryption_key,
164 suppress_output: self.suppress_output.unwrap_or(false),
165 }
166 }
167}
168
/// Fully resolved server configuration, consumed by `setup` and `run`.
///
/// All `u64` time values are milliseconds; `setup` converts them with
/// `Duration::from_millis` before handing them to `objectiveai_sdk::mcp::Client::new`.
pub struct Config {
    /// Host/interface to bind; combined with `port` as `"{address}:{port}"`.
    pub address: String,
    /// TCP port to bind.
    pub port: u16,
    /// Forwarded to `objectiveai_sdk::mcp::Client::new`.
    pub user_agent: String,
    /// Forwarded to `objectiveai_sdk::mcp::Client::new`.
    pub http_referer: String,
    /// Forwarded to `objectiveai_sdk::mcp::Client::new`.
    pub x_title: String,
    pub mcp_connect_timeout: u64,
    pub mcp_call_timeout: u64,
    // Backoff tuning for the MCP client; presumably an exponential-backoff policy
    // (TODO confirm against `objectiveai_sdk::mcp::Client`).
    pub mcp_backoff_current_interval: u64,
    pub mcp_backoff_initial_interval: u64,
    pub mcp_backoff_randomization_factor: f64,
    pub mcp_backoff_multiplier: f64,
    pub mcp_backoff_max_interval: u64,
    pub mcp_backoff_max_elapsed_time: u64,
    /// Key passed to `SessionManager::new`; `None` makes `setup` call
    /// `SessionManager::with_ephemeral_key()` instead.
    pub mcp_encryption_key: Option<[u8; 32]>,
    /// When `true`, `run` does not print its "listening on" line to stderr.
    pub suppress_output: bool,
}
187
188pub async fn setup(config: Config) -> std::io::Result<(tokio::net::TcpListener, axum::Router)> {
189 let Config {
190 address,
191 port,
192 user_agent,
193 http_referer,
194 x_title,
195 mcp_connect_timeout,
196 mcp_call_timeout,
197 mcp_backoff_current_interval,
198 mcp_backoff_initial_interval,
199 mcp_backoff_randomization_factor,
200 mcp_backoff_multiplier,
201 mcp_backoff_max_interval,
202 mcp_backoff_max_elapsed_time,
203 mcp_encryption_key,
204 suppress_output: _,
205 } = config;
206
207 let client = Client::new(
208 reqwest::Client::new(),
209 user_agent,
210 x_title,
211 http_referer,
212 Duration::from_millis(mcp_connect_timeout),
213 Duration::from_millis(mcp_backoff_current_interval),
214 Duration::from_millis(mcp_backoff_initial_interval),
215 mcp_backoff_randomization_factor,
216 mcp_backoff_multiplier,
217 Duration::from_millis(mcp_backoff_max_interval),
218 Duration::from_millis(mcp_backoff_max_elapsed_time),
219 Duration::from_millis(mcp_call_timeout),
220 );
221
222 let sessions = match mcp_encryption_key {
223 Some(key) => SessionManager::new(key),
224 None => SessionManager::with_ephemeral_key(),
225 };
226 let state = AppState {
227 sessions: Arc::new(sessions),
228 client: Arc::new(client),
229 };
230
231 let router = axum::Router::new()
232 .route(
233 "/",
234 axum::routing::post(mcp::handle_post)
235 .get(mcp::handle_get)
236 .delete(mcp::handle_delete),
237 )
238 .route("/notify", axum::routing::post(mcp::handle_notify))
239 .with_state(state);
240
241 let listener = tokio::net::TcpListener::bind(format!("{address}:{port}")).await?;
242
243 Ok((listener, router))
244}
245
246pub async fn serve(listener: tokio::net::TcpListener, app: axum::Router) -> std::io::Result<()> {
247 axum::serve(listener, app).await
248}
249
250pub async fn run(config: Config) -> std::io::Result<()> {
251 let suppress_output = config.suppress_output;
252 let (listener, app) = setup(config).await?;
253 if !suppress_output {
254 let addr = listener.local_addr()?;
255 eprintln!("listening on {addr}");
256 }
257 serve(listener, app).await
258}