1use std::sync::Arc;
12use std::time::Duration;
13
14use envconfig::Envconfig;
15use objectiveai_sdk::mcp::Client;
16
17use crate::session_manager::SessionManager;
18use crate::{AppState, mcp};
19
20#[derive(Envconfig)]
21struct EnvConfigBuilder {
22 #[envconfig(from = "ADDRESS")]
23 address: Option<String>,
24 #[envconfig(from = "PORT")]
25 port: Option<u16>,
26 #[envconfig(from = "USER_AGENT")]
27 user_agent: Option<String>,
28 #[envconfig(from = "HTTP_REFERER")]
29 http_referer: Option<String>,
30 #[envconfig(from = "X_TITLE")]
31 x_title: Option<String>,
32 #[envconfig(from = "MCP_CONNECT_TIMEOUT")]
33 mcp_connect_timeout: Option<u64>,
34 #[envconfig(from = "MCP_CALL_TIMEOUT")]
35 mcp_call_timeout: Option<u64>,
36 #[envconfig(from = "MCP_BACKOFF_CURRENT_INTERVAL")]
37 mcp_backoff_current_interval: Option<u64>,
38 #[envconfig(from = "MCP_BACKOFF_INITIAL_INTERVAL")]
39 mcp_backoff_initial_interval: Option<u64>,
40 #[envconfig(from = "MCP_BACKOFF_RANDOMIZATION_FACTOR")]
41 mcp_backoff_randomization_factor: Option<f64>,
42 #[envconfig(from = "MCP_BACKOFF_MULTIPLIER")]
43 mcp_backoff_multiplier: Option<f64>,
44 #[envconfig(from = "MCP_BACKOFF_MAX_INTERVAL")]
45 mcp_backoff_max_interval: Option<u64>,
46 #[envconfig(from = "MCP_BACKOFF_MAX_ELAPSED_TIME")]
47 mcp_backoff_max_elapsed_time: Option<u64>,
48 #[envconfig(from = "MCP_ENCRYPTION_KEY")]
61 mcp_encryption_key: Option<String>,
62 #[envconfig(from = "SUPPRESS_OUTPUT")]
63 suppress_output: Option<String>,
64 #[envconfig(from = "OBJECTIVEAI_LOGS_DIR")]
68 logs_dir: Option<String>,
69}
70
71impl EnvConfigBuilder {
72 fn build(self) -> ConfigBuilder {
73 ConfigBuilder {
74 address: self.address,
75 port: self.port,
76 user_agent: self.user_agent,
77 http_referer: self.http_referer,
78 x_title: self.x_title,
79 mcp_connect_timeout: self.mcp_connect_timeout,
80 mcp_call_timeout: self.mcp_call_timeout,
81 mcp_backoff_current_interval: self.mcp_backoff_current_interval,
82 mcp_backoff_initial_interval: self.mcp_backoff_initial_interval,
83 mcp_backoff_randomization_factor: self.mcp_backoff_randomization_factor,
84 mcp_backoff_multiplier: self.mcp_backoff_multiplier,
85 mcp_backoff_max_interval: self.mcp_backoff_max_interval,
86 mcp_backoff_max_elapsed_time: self.mcp_backoff_max_elapsed_time,
87 mcp_encryption_key: match self.mcp_encryption_key.as_deref() {
88 Some(s) => match crate::session_manager::parse_key_env(s) {
89 Ok(opt) => opt,
90 Err(e) => {
91 tracing::error!(error = %e, "MCP_ENCRYPTION_KEY parse failed; falling back to ephemeral key");
92 None
93 }
94 },
95 None => None,
96 },
97 suppress_output: self.suppress_output.map(|v| {
98 matches!(v.to_ascii_lowercase().as_str(), "1" | "true" | "yes" | "on")
99 }),
100 logs_dir: self.logs_dir,
101 }
102 }
103}
104
/// Optional configuration overrides for the proxy.
///
/// Every `None` field is replaced with its default by
/// [`ConfigBuilder::build`]. Construct one by hand (it implements
/// `Default`) or from environment variables through its [`Envconfig`]
/// implementation.
#[derive(Default)]
pub struct ConfigBuilder {
    /// Listener bind address; defaults to `"0.0.0.0"`.
    pub address: Option<String>,
    /// Listener port; defaults to `3000`.
    pub port: Option<u16>,
    /// User agent handed to the MCP client; defaults to
    /// `objectiveai-mcp-proxy/<crate version>`.
    pub user_agent: Option<String>,
    /// Referer handed to the MCP client; defaults to `"https://objectiveai.dev"`.
    pub http_referer: Option<String>,
    /// Title handed to the MCP client; defaults to `"ObjectiveAI MCP Proxy"`.
    pub x_title: Option<String>,
    /// MCP connect timeout in milliseconds; defaults to `60000`.
    pub mcp_connect_timeout: Option<u64>,
    /// MCP call timeout in milliseconds; defaults to `60000`.
    pub mcp_call_timeout: Option<u64>,
    /// Backoff current interval in milliseconds; defaults to `100`.
    pub mcp_backoff_current_interval: Option<u64>,
    /// Backoff initial interval in milliseconds; defaults to `100`.
    pub mcp_backoff_initial_interval: Option<u64>,
    /// Backoff randomization factor; defaults to `0.5`.
    pub mcp_backoff_randomization_factor: Option<f64>,
    /// Backoff multiplier; defaults to `1.5`.
    pub mcp_backoff_multiplier: Option<f64>,
    /// Backoff maximum interval in milliseconds; defaults to `1000`.
    pub mcp_backoff_max_interval: Option<u64>,
    /// Backoff maximum elapsed time in milliseconds; defaults to `40000`.
    pub mcp_backoff_max_elapsed_time: Option<u64>,
    /// 32-byte session encryption key. When `None`, `setup` uses
    /// `SessionManager::with_ephemeral_key` instead.
    pub mcp_encryption_key: Option<[u8; 32]>,
    /// Whether `run` skips printing the listening address; defaults to `false`.
    pub suppress_output: Option<bool>,
    /// Directory passed to `crate::logging::ProxyLogger`. When `None`, no
    /// logging middleware is installed.
    pub logs_dir: Option<String>,
}
129
130impl Envconfig for ConfigBuilder {
131 #[allow(deprecated)]
132 fn init() -> Result<Self, envconfig::Error> {
133 EnvConfigBuilder::init().map(|e| e.build())
134 }
135
136 fn init_from_env() -> Result<Self, envconfig::Error> {
137 EnvConfigBuilder::init_from_env().map(|e| e.build())
138 }
139
140 fn init_from_hashmap(
141 hashmap: &std::collections::HashMap<String, String>,
142 ) -> Result<Self, envconfig::Error> {
143 EnvConfigBuilder::init_from_hashmap(hashmap).map(|e| e.build())
144 }
145}
146
147impl ConfigBuilder {
148 pub fn build(self) -> Config {
149 Config {
150 address: self.address.unwrap_or_else(|| "0.0.0.0".to_string()),
151 port: self.port.unwrap_or(3000),
152 user_agent: self
153 .user_agent
154 .unwrap_or_else(|| format!("objectiveai-mcp-proxy/{}", env!("CARGO_PKG_VERSION"))),
155 http_referer: self
156 .http_referer
157 .unwrap_or_else(|| "https://objectiveai.dev".to_string()),
158 x_title: self
159 .x_title
160 .unwrap_or_else(|| "ObjectiveAI MCP Proxy".to_string()),
161 mcp_connect_timeout: self.mcp_connect_timeout.unwrap_or(60000),
165 mcp_call_timeout: self.mcp_call_timeout.unwrap_or(60000),
166 mcp_backoff_current_interval: self.mcp_backoff_current_interval.unwrap_or(100),
167 mcp_backoff_initial_interval: self.mcp_backoff_initial_interval.unwrap_or(100),
168 mcp_backoff_randomization_factor: self.mcp_backoff_randomization_factor.unwrap_or(0.5),
169 mcp_backoff_multiplier: self.mcp_backoff_multiplier.unwrap_or(1.5),
170 mcp_backoff_max_interval: self.mcp_backoff_max_interval.unwrap_or(1000),
171 mcp_backoff_max_elapsed_time: self.mcp_backoff_max_elapsed_time.unwrap_or(40000),
172 mcp_encryption_key: self.mcp_encryption_key,
173 suppress_output: self.suppress_output.unwrap_or(false),
174 logs_dir: self.logs_dir.map(std::path::PathBuf::from),
175 }
176 }
177}
178
/// Fully resolved proxy configuration, normally produced by
/// [`ConfigBuilder::build`] and consumed by [`setup`] / [`run`].
///
/// All timeouts and backoff intervals are in milliseconds; [`setup`]
/// converts them with `Duration::from_millis`.
pub struct Config {
    /// Listener bind address, combined with `port` as `"{address}:{port}"`.
    pub address: String,
    /// Listener port.
    pub port: u16,
    /// User agent handed to the MCP client.
    pub user_agent: String,
    /// Referer handed to the MCP client.
    pub http_referer: String,
    /// Title handed to the MCP client.
    pub x_title: String,
    /// MCP connect timeout (ms).
    pub mcp_connect_timeout: u64,
    /// MCP call timeout (ms).
    pub mcp_call_timeout: u64,
    /// Backoff current interval (ms).
    pub mcp_backoff_current_interval: u64,
    /// Backoff initial interval (ms).
    pub mcp_backoff_initial_interval: u64,
    /// Backoff randomization factor.
    pub mcp_backoff_randomization_factor: f64,
    /// Backoff multiplier.
    pub mcp_backoff_multiplier: f64,
    /// Backoff maximum interval (ms).
    pub mcp_backoff_max_interval: u64,
    /// Backoff maximum elapsed time (ms).
    pub mcp_backoff_max_elapsed_time: u64,
    /// Session encryption key; `None` selects an ephemeral key in [`setup`].
    pub mcp_encryption_key: Option<[u8; 32]>,
    /// When `true`, [`run`] does not print the listening address.
    pub suppress_output: bool,
    /// Request log directory; `None` disables the logging middleware.
    pub logs_dir: Option<std::path::PathBuf>,
}
200
201pub async fn setup(
202 config: Config,
203 queue_delegate: Option<std::sync::Arc<dyn crate::QueueDelegate>>,
204) -> std::io::Result<(tokio::net::TcpListener, axum::Router)> {
205 let Config {
206 address,
207 port,
208 user_agent,
209 http_referer,
210 x_title,
211 mcp_connect_timeout,
212 mcp_call_timeout,
213 mcp_backoff_current_interval,
214 mcp_backoff_initial_interval,
215 mcp_backoff_randomization_factor,
216 mcp_backoff_multiplier,
217 mcp_backoff_max_interval,
218 mcp_backoff_max_elapsed_time,
219 mcp_encryption_key,
220 suppress_output: _,
221 logs_dir,
222 } = config;
223
224 let client = Client::new(
225 reqwest::Client::new(),
226 user_agent,
227 x_title,
228 http_referer,
229 Duration::from_millis(mcp_connect_timeout),
230 Duration::from_millis(mcp_backoff_current_interval),
231 Duration::from_millis(mcp_backoff_initial_interval),
232 mcp_backoff_randomization_factor,
233 mcp_backoff_multiplier,
234 Duration::from_millis(mcp_backoff_max_interval),
235 Duration::from_millis(mcp_backoff_max_elapsed_time),
236 Duration::from_millis(mcp_call_timeout),
237 );
238
239 let sessions = match mcp_encryption_key {
240 Some(key) => SessionManager::new(key),
241 None => SessionManager::with_ephemeral_key(),
242 };
243 let state = AppState {
244 sessions: Arc::new(sessions),
245 client: Arc::new(client),
246 queue_delegate,
247 };
248
249 let router = axum::Router::new()
250 .route(
251 "/",
252 axum::routing::post(mcp::handle_post)
253 .get(mcp::handle_get)
254 .delete(mcp::handle_delete),
255 )
256 .with_state(state);
257
258 let router = match logs_dir {
262 Some(dir) => router.layer(axum::middleware::from_fn_with_state(
263 std::sync::Arc::new(crate::logging::ProxyLogger::new(dir)),
264 crate::logging::log_layer,
265 )),
266 None => router,
267 };
268
269 let listener = tokio::net::TcpListener::bind(format!("{address}:{port}")).await?;
270
271 Ok((listener, router))
272}
273
274pub async fn serve(listener: tokio::net::TcpListener, app: axum::Router) -> std::io::Result<()> {
275 axum::serve(listener, app).await
276}
277
278pub async fn run(config: Config) -> std::io::Result<()> {
279 let suppress_output = config.suppress_output;
280 let (listener, app) = setup(config, None).await?;
282 if !suppress_output {
283 let addr = listener.local_addr()?;
284 eprintln!("listening on {addr}");
285 }
286 serve(listener, app).await
287}