1use std::sync::Arc;
12use std::time::Duration;
13
14use envconfig::Envconfig;
15use objectiveai_sdk::mcp::Client;
16
17use crate::session_manager::SessionManager;
18use crate::{AppState, mcp};
19
// Retry/backoff tuning forwarded verbatim to the MCP client built in `setup`.
// All `_MS` values are milliseconds (they are wrapped in `Duration::from_millis`).

/// Starting backoff interval, in milliseconds.
const BACKOFF_INITIAL_INTERVAL_MS: u64 = 100;
/// Randomization (jitter) factor handed to the client's backoff policy.
const BACKOFF_RANDOMIZATION_FACTOR: f64 = 0.5;
/// Growth factor applied between successive backoff intervals.
const BACKOFF_MULTIPLIER: f64 = 1.5;
/// Upper bound on a single backoff interval, in milliseconds.
const BACKOFF_MAX_INTERVAL_MS: u64 = 1_000;
/// Default cap on total retry time, in milliseconds, when
/// `MCP_BACKOFF_MAX_ELAPSED_TIME` is not set.
const BACKOFF_MAX_ELAPSED_TIME_DEFAULT_MS: u64 = 60_000;
29
/// Raw configuration as read by the `Envconfig` derive (from the process
/// environment or a hash map). Every field is optional and is read from the
/// variable named in its `#[envconfig(from = ...)]` attribute.
///
/// Converted into a [`ConfigBuilder`] by `EnvConfigBuilder::build`, which is
/// where `MCP_ENCRYPTION_KEY` and `SUPPRESS_OUTPUT` are interpreted.
#[derive(Envconfig)]
struct EnvConfigBuilder {
    // Listener bind address.
    #[envconfig(from = "ADDRESS")]
    address: Option<String>,
    // Listener TCP port.
    #[envconfig(from = "PORT")]
    port: Option<u16>,
    // User agent forwarded to the MCP client.
    #[envconfig(from = "USER_AGENT")]
    user_agent: Option<String>,
    // Referer value forwarded to the MCP client.
    #[envconfig(from = "HTTP_REFERER")]
    http_referer: Option<String>,
    // Title value forwarded to the MCP client.
    #[envconfig(from = "X_TITLE")]
    x_title: Option<String>,
    // MCP connect timeout in milliseconds.
    #[envconfig(from = "MCP_CONNECT_TIMEOUT")]
    mcp_connect_timeout: Option<u64>,
    // MCP call timeout in milliseconds.
    #[envconfig(from = "MCP_CALL_TIMEOUT")]
    mcp_call_timeout: Option<u64>,
    // Cap on total backoff/retry time in milliseconds.
    #[envconfig(from = "MCP_BACKOFF_MAX_ELAPSED_TIME")]
    mcp_backoff_max_elapsed_time: Option<u64>,
    // Raw session encryption key text; parsed later by
    // `crate::session_manager::parse_key_env`.
    #[envconfig(from = "MCP_ENCRYPTION_KEY")]
    mcp_encryption_key: Option<String>,
    // Free-form flag text; interpreted as a boolean in `build`.
    #[envconfig(from = "SUPPRESS_OUTPUT")]
    suppress_output: Option<String>,
    // Directory for the request-logging middleware; unset disables it.
    #[envconfig(from = "OBJECTIVEAI_LOGS_DIR")]
    logs_dir: Option<String>,
}
70
71impl EnvConfigBuilder {
72 fn build(self) -> ConfigBuilder {
73 ConfigBuilder {
74 address: self.address,
75 port: self.port,
76 user_agent: self.user_agent,
77 http_referer: self.http_referer,
78 x_title: self.x_title,
79 mcp_connect_timeout: self.mcp_connect_timeout,
80 mcp_call_timeout: self.mcp_call_timeout,
81 mcp_backoff_max_elapsed_time: self.mcp_backoff_max_elapsed_time,
82 mcp_encryption_key: match self.mcp_encryption_key.as_deref() {
83 Some(s) => match crate::session_manager::parse_key_env(s) {
84 Ok(opt) => opt,
85 Err(e) => {
86 tracing::error!(error = %e, "MCP_ENCRYPTION_KEY parse failed; falling back to ephemeral key");
87 None
88 }
89 },
90 None => None,
91 },
92 suppress_output: self.suppress_output.map(|v| {
93 matches!(v.to_ascii_lowercase().as_str(), "1" | "true" | "yes" | "on")
94 }),
95 logs_dir: self.logs_dir,
96 }
97 }
98}
99
/// Optional overrides for every [`Config`] field.
///
/// `None` means "use the default" that `ConfigBuilder::build` applies.
/// `ConfigBuilder::default()` leaves every field unset. The `Envconfig` impl
/// fills it from the environment instead.
#[derive(Default)]
pub struct ConfigBuilder {
    /// Listener bind address.
    pub address: Option<String>,
    /// Listener TCP port.
    pub port: Option<u16>,
    /// User agent forwarded to the MCP client.
    pub user_agent: Option<String>,
    /// Referer value forwarded to the MCP client.
    pub http_referer: Option<String>,
    /// Title value forwarded to the MCP client.
    pub x_title: Option<String>,
    /// MCP connect timeout in milliseconds.
    pub mcp_connect_timeout: Option<u64>,
    /// MCP call timeout in milliseconds.
    pub mcp_call_timeout: Option<u64>,
    /// Cap on total backoff/retry time in milliseconds.
    pub mcp_backoff_max_elapsed_time: Option<u64>,
    /// 32-byte session encryption key; `None` stays `None` in [`Config`].
    pub mcp_encryption_key: Option<[u8; 32]>,
    /// Whether `run` should skip printing the bound address.
    pub suppress_output: Option<bool>,
    /// Directory for the request-logging middleware; `None` disables it.
    pub logs_dir: Option<String>,
}
119
120impl Envconfig for ConfigBuilder {
121 #[allow(deprecated)]
122 fn init() -> Result<Self, envconfig::Error> {
123 EnvConfigBuilder::init().map(|e| e.build())
124 }
125
126 fn init_from_env() -> Result<Self, envconfig::Error> {
127 EnvConfigBuilder::init_from_env().map(|e| e.build())
128 }
129
130 fn init_from_hashmap(
131 hashmap: &std::collections::HashMap<String, String>,
132 ) -> Result<Self, envconfig::Error> {
133 EnvConfigBuilder::init_from_hashmap(hashmap).map(|e| e.build())
134 }
135}
136
137impl ConfigBuilder {
138 pub fn build(self) -> Config {
139 Config {
140 address: self.address.unwrap_or_else(|| "0.0.0.0".to_string()),
141 port: self.port.unwrap_or(3000),
142 user_agent: self
143 .user_agent
144 .unwrap_or_else(|| format!("objectiveai-mcp-proxy/{}", env!("CARGO_PKG_VERSION"))),
145 http_referer: self
146 .http_referer
147 .unwrap_or_else(|| "https://objectiveai.dev".to_string()),
148 x_title: self
149 .x_title
150 .unwrap_or_else(|| "ObjectiveAI MCP Proxy".to_string()),
151 mcp_connect_timeout: self.mcp_connect_timeout.unwrap_or(60000),
155 mcp_call_timeout: self.mcp_call_timeout.unwrap_or(60000),
156 mcp_backoff_max_elapsed_time: self.mcp_backoff_max_elapsed_time.unwrap_or(BACKOFF_MAX_ELAPSED_TIME_DEFAULT_MS),
157 mcp_encryption_key: self.mcp_encryption_key,
158 suppress_output: self.suppress_output.unwrap_or(false),
159 logs_dir: self.logs_dir.map(std::path::PathBuf::from),
160 }
161 }
162}
163
/// Fully resolved proxy configuration.
///
/// Produced by `ConfigBuilder::build` and consumed by `setup` (and `run`).
pub struct Config {
    /// Listener bind address; `setup` binds `"{address}:{port}"`.
    pub address: String,
    /// Listener TCP port.
    pub port: u16,
    /// User agent passed to the MCP `Client::new`.
    pub user_agent: String,
    /// Referer value passed to the MCP `Client::new`.
    pub http_referer: String,
    /// Title value passed to the MCP `Client::new`.
    pub x_title: String,
    /// MCP connect timeout in milliseconds.
    pub mcp_connect_timeout: u64,
    /// MCP call timeout in milliseconds.
    pub mcp_call_timeout: u64,
    /// Cap on total backoff/retry time in milliseconds.
    pub mcp_backoff_max_elapsed_time: u64,
    /// Session encryption key; `None` makes `setup` use
    /// `SessionManager::with_ephemeral_key`.
    pub mcp_encryption_key: Option<[u8; 32]>,
    /// When `true`, `run` does not print the bound address to stderr.
    pub suppress_output: bool,
    /// When set, `setup` wraps the router in the request-logging middleware.
    pub logs_dir: Option<std::path::PathBuf>,
}
180
181pub async fn setup(
182 config: Config,
183 queue_delegate: Option<std::sync::Arc<dyn crate::QueueDelegate>>,
184 reverse_channel: Option<crate::ReverseChannel>,
185) -> std::io::Result<(tokio::net::TcpListener, axum::Router)> {
186 let Config {
187 address,
188 port,
189 user_agent,
190 http_referer,
191 x_title,
192 mcp_connect_timeout,
193 mcp_call_timeout,
194 mcp_backoff_max_elapsed_time,
195 mcp_encryption_key,
196 suppress_output: _,
197 logs_dir,
198 } = config;
199
200 let client = Client::new(
201 reqwest::Client::new(),
202 user_agent,
203 x_title,
204 http_referer,
205 Duration::from_millis(mcp_connect_timeout),
206 Duration::from_millis(BACKOFF_INITIAL_INTERVAL_MS),
207 Duration::from_millis(BACKOFF_INITIAL_INTERVAL_MS),
208 BACKOFF_RANDOMIZATION_FACTOR,
209 BACKOFF_MULTIPLIER,
210 Duration::from_millis(BACKOFF_MAX_INTERVAL_MS),
211 Duration::from_millis(mcp_backoff_max_elapsed_time),
212 Duration::from_millis(mcp_call_timeout),
213 );
214
215 let sessions = match mcp_encryption_key {
216 Some(key) => SessionManager::new(key),
217 None => SessionManager::with_ephemeral_key(),
218 };
219 let state = AppState {
220 sessions: Arc::new(sessions),
221 client: Arc::new(client),
222 queue_delegate,
223 reverse_channel,
224 };
225
226 let router = axum::Router::new()
227 .route(
228 "/",
229 axum::routing::post(mcp::handle_post)
230 .get(mcp::handle_get)
231 .delete(mcp::handle_delete),
232 )
233 .with_state(state);
234
235 let router = match logs_dir {
239 Some(dir) => router.layer(axum::middleware::from_fn_with_state(
240 std::sync::Arc::new(crate::logging::ProxyLogger::new(dir)),
241 crate::logging::log_layer,
242 )),
243 None => router,
244 };
245
246 let listener = tokio::net::TcpListener::bind(format!("{address}:{port}")).await?;
247
248 Ok((listener, router))
249}
250
251pub async fn serve(listener: tokio::net::TcpListener, app: axum::Router) -> std::io::Result<()> {
252 axum::serve(listener, app).await
253}
254
255pub async fn run(config: Config) -> std::io::Result<()> {
256 let suppress_output = config.suppress_output;
257 let (listener, app) = setup(config, None, None).await?;
259 if !suppress_output {
260 let addr = listener.local_addr()?;
261 eprintln!("listening on {addr}");
262 }
263 serve(listener, app).await
264}