1use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10use std::time::Duration;
11use thiserror::Error;
12
13#[derive(Error, Debug)]
14pub enum ProxyError {
15 #[error("Proxy request failed: {0}")]
16 Request(#[from] reqwest::Error),
17 #[error("Proxy error ({status}): {body}")]
18 ProxyResponse { status: u16, body: String },
19 #[error("Invalid proxy URL: {0}")]
20 InvalidUrl(String),
21 #[error("Proxy returned invalid response: {0}")]
22 InvalidResponse(String),
23}
24
25#[derive(Debug, Serialize)]
27pub struct ProxyCallRequest {
28 pub tool_name: String,
29 pub args: Value,
31}
32
33#[derive(Debug, Deserialize)]
35pub struct ProxyCallResponse {
36 pub result: Value,
37 #[serde(default)]
38 pub error: Option<String>,
39}
40
41#[derive(Debug, Serialize)]
43pub struct ProxyHelpRequest {
44 pub query: String,
45 #[serde(skip_serializing_if = "Option::is_none")]
46 pub tool: Option<String>,
47}
48
49#[derive(Debug, Deserialize)]
51pub struct ProxyHelpResponse {
52 pub content: String,
53 #[serde(default)]
54 pub error: Option<String>,
55}
56
/// Whole-request timeout, in seconds, for every HTTP client this module builds (two minutes).
const PROXY_TIMEOUT_SECS: u64 = 2 * 60;
58
59fn build_proxy_request(
61 client: &Client,
62 method: reqwest::Method,
63 url: &str,
64) -> reqwest::RequestBuilder {
65 let mut req = client.request(method, url);
66 if let Ok(token) = std::env::var("ATI_SESSION_TOKEN") {
67 if !token.is_empty() {
68 req = req.header("Authorization", format!("Bearer {token}"));
69 }
70 }
71 req
72}
73
74pub async fn call_tool(
82 proxy_url: &str,
83 tool_name: &str,
84 args: &HashMap<String, Value>,
85 raw_args: Option<&[String]>,
86) -> Result<Value, ProxyError> {
87 let client = Client::builder()
88 .timeout(Duration::from_secs(PROXY_TIMEOUT_SECS))
89 .build()?;
90
91 let url = format!("{}/call", proxy_url.trim_end_matches('/'));
92
93 let args_value = match raw_args {
96 Some(raw) if !raw.is_empty() => {
97 Value::Array(raw.iter().map(|s| Value::String(s.clone())).collect())
98 }
99 _ => serde_json::to_value(args).unwrap_or(Value::Object(serde_json::Map::new())),
100 };
101
102 let payload = ProxyCallRequest {
103 tool_name: tool_name.to_string(),
104 args: args_value,
105 };
106
107 let response = build_proxy_request(&client, reqwest::Method::POST, &url)
108 .json(&payload)
109 .send()
110 .await?;
111 let status = response.status();
112
113 if !status.is_success() {
114 let body = response.text().await.unwrap_or_else(|_| "empty".into());
115 return Err(ProxyError::ProxyResponse {
116 status: status.as_u16(),
117 body,
118 });
119 }
120
121 let body: ProxyCallResponse = response
122 .json()
123 .await
124 .map_err(|e| ProxyError::InvalidResponse(e.to_string()))?;
125
126 if let Some(err) = body.error {
127 return Err(ProxyError::ProxyResponse {
128 status: 200,
129 body: err,
130 });
131 }
132
133 Ok(body.result)
134}
135
136pub async fn call_mcp(
138 proxy_url: &str,
139 method: &str,
140 params: Option<Value>,
141) -> Result<Value, ProxyError> {
142 use std::sync::atomic::{AtomicU64, Ordering};
143 static MCP_ID: AtomicU64 = AtomicU64::new(1);
144
145 let id = MCP_ID.fetch_add(1, Ordering::SeqCst);
146 let msg = serde_json::json!({
147 "jsonrpc": "2.0",
148 "id": id,
149 "method": method,
150 "params": params,
151 });
152
153 let client = Client::builder()
154 .timeout(Duration::from_secs(PROXY_TIMEOUT_SECS))
155 .build()?;
156
157 let url = format!("{}/mcp", proxy_url.trim_end_matches('/'));
158
159 let response = build_proxy_request(&client, reqwest::Method::POST, &url)
160 .json(&msg)
161 .send()
162 .await?;
163 let status = response.status();
164
165 if status == reqwest::StatusCode::ACCEPTED {
166 return Ok(Value::Null);
167 }
168
169 if !status.is_success() {
170 let body = response.text().await.unwrap_or_else(|_| "empty".into());
171 return Err(ProxyError::ProxyResponse {
172 status: status.as_u16(),
173 body,
174 });
175 }
176
177 let body: Value = response
178 .json()
179 .await
180 .map_err(|e| ProxyError::InvalidResponse(e.to_string()))?;
181
182 if let Some(err) = body.get("error") {
183 let message = err
184 .get("message")
185 .and_then(|m| m.as_str())
186 .unwrap_or("MCP proxy error");
187 return Err(ProxyError::ProxyResponse {
188 status: 200,
189 body: message.to_string(),
190 });
191 }
192
193 Ok(body.get("result").cloned().unwrap_or(Value::Null))
194}
195
196pub async fn list_skills(
198 proxy_url: &str,
199 query_params: &str,
200) -> Result<serde_json::Value, ProxyError> {
201 let client = Client::builder()
202 .timeout(Duration::from_secs(PROXY_TIMEOUT_SECS))
203 .build()?;
204
205 let url = if query_params.is_empty() {
206 format!("{}/skills", proxy_url.trim_end_matches('/'))
207 } else {
208 format!("{}/skills?{query_params}", proxy_url.trim_end_matches('/'))
209 };
210
211 let response = build_proxy_request(&client, reqwest::Method::GET, &url)
212 .send()
213 .await?;
214 let status = response.status();
215
216 if !status.is_success() {
217 let body = response.text().await.unwrap_or_else(|_| "empty".into());
218 return Err(ProxyError::ProxyResponse {
219 status: status.as_u16(),
220 body,
221 });
222 }
223
224 response
225 .json()
226 .await
227 .map_err(|e| ProxyError::InvalidResponse(e.to_string()))
228}
229
230pub async fn get_skill(
232 proxy_url: &str,
233 name: &str,
234 query_params: &str,
235) -> Result<serde_json::Value, ProxyError> {
236 let client = Client::builder()
237 .timeout(Duration::from_secs(PROXY_TIMEOUT_SECS))
238 .build()?;
239
240 let url = if query_params.is_empty() {
241 format!("{}/skills/{name}", proxy_url.trim_end_matches('/'))
242 } else {
243 format!(
244 "{}/skills/{name}?{query_params}",
245 proxy_url.trim_end_matches('/')
246 )
247 };
248
249 let response = build_proxy_request(&client, reqwest::Method::GET, &url)
250 .send()
251 .await?;
252 let status = response.status();
253
254 if !status.is_success() {
255 let body = response.text().await.unwrap_or_else(|_| "empty".into());
256 return Err(ProxyError::ProxyResponse {
257 status: status.as_u16(),
258 body,
259 });
260 }
261
262 response
263 .json()
264 .await
265 .map_err(|e| ProxyError::InvalidResponse(e.to_string()))
266}
267
268pub async fn resolve_skills(
270 proxy_url: &str,
271 scopes: &serde_json::Value,
272) -> Result<serde_json::Value, ProxyError> {
273 let client = Client::builder()
274 .timeout(Duration::from_secs(PROXY_TIMEOUT_SECS))
275 .build()?;
276
277 let url = format!("{}/skills/resolve", proxy_url.trim_end_matches('/'));
278
279 let response = build_proxy_request(&client, reqwest::Method::POST, &url)
280 .json(scopes)
281 .send()
282 .await?;
283 let status = response.status();
284
285 if !status.is_success() {
286 let body = response.text().await.unwrap_or_else(|_| "empty".into());
287 return Err(ProxyError::ProxyResponse {
288 status: status.as_u16(),
289 body,
290 });
291 }
292
293 response
294 .json()
295 .await
296 .map_err(|e| ProxyError::InvalidResponse(e.to_string()))
297}
298
299pub async fn call_help(
301 proxy_url: &str,
302 query: &str,
303 tool: Option<&str>,
304) -> Result<String, ProxyError> {
305 let client = Client::builder()
306 .timeout(Duration::from_secs(PROXY_TIMEOUT_SECS))
307 .build()?;
308
309 let url = format!("{}/help", proxy_url.trim_end_matches('/'));
310
311 let payload = ProxyHelpRequest {
312 query: query.to_string(),
313 tool: tool.map(|t| t.to_string()),
314 };
315
316 let response = build_proxy_request(&client, reqwest::Method::POST, &url)
317 .json(&payload)
318 .send()
319 .await?;
320 let status = response.status();
321
322 if !status.is_success() {
323 let body = response.text().await.unwrap_or_else(|_| "empty".into());
324 return Err(ProxyError::ProxyResponse {
325 status: status.as_u16(),
326 body,
327 });
328 }
329
330 let body: ProxyHelpResponse = response
331 .json()
332 .await
333 .map_err(|e| ProxyError::InvalidResponse(e.to_string()))?;
334
335 if let Some(err) = body.error {
336 return Err(ProxyError::ProxyResponse {
337 status: 200,
338 body: err,
339 });
340 }
341
342 Ok(body.content)
343}