Skip to main content

ati/proxy/
client.rs

//! Proxy client — forwards tool calls to an external ATI proxy server.
//!
//! When `ATI_PROXY_URL` is set, `ati run <tool>` sends `tool_name` + `args`
//! to the proxy. Authentication is via a JWT in the `Authorization` header,
//! read from the `ATI_SESSION_TOKEN` environment variable.
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10use std::time::Duration;
11use thiserror::Error;
12
13#[derive(Error, Debug)]
14pub enum ProxyError {
15    #[error("Proxy request failed: {0}")]
16    Request(#[from] reqwest::Error),
17    #[error("Proxy error ({status}): {body}")]
18    ProxyResponse { status: u16, body: String },
19    #[error("Invalid proxy URL: {0}")]
20    InvalidUrl(String),
21    #[error("Proxy returned invalid response: {0}")]
22    InvalidResponse(String),
23}
24
25/// Request payload sent to the proxy server's /call endpoint.
26#[derive(Debug, Serialize)]
27pub struct ProxyCallRequest {
28    pub tool_name: String,
29    /// Tool arguments — JSON object for HTTP/MCP tools, or JSON array for CLI tools.
30    pub args: Value,
31}
32
33/// Response payload from the proxy server.
34#[derive(Debug, Deserialize)]
35pub struct ProxyCallResponse {
36    pub result: Value,
37    #[serde(default)]
38    pub error: Option<String>,
39}
40
41/// Request payload for the proxy's /help endpoint.
42#[derive(Debug, Serialize)]
43pub struct ProxyHelpRequest {
44    pub query: String,
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub tool: Option<String>,
47}
48
49/// Response from the proxy's /help endpoint.
50#[derive(Debug, Deserialize)]
51pub struct ProxyHelpResponse {
52    pub content: String,
53    #[serde(default)]
54    pub error: Option<String>,
55}
56
/// Timeout, in seconds, applied to every HTTP client built in this module
/// (two minutes).
const PROXY_TIMEOUT_SECS: u64 = 2 * 60;
58
59/// Build an HTTP request builder with JWT Bearer auth from ATI_SESSION_TOKEN.
60fn build_proxy_request(
61    client: &Client,
62    method: reqwest::Method,
63    url: &str,
64) -> reqwest::RequestBuilder {
65    let mut req = client.request(method, url);
66    if let Ok(token) = std::env::var("ATI_SESSION_TOKEN") {
67        if !token.is_empty() {
68            req = req.header("Authorization", format!("Bearer {token}"));
69        }
70    }
71    req
72}
73
74/// Execute a tool call via the proxy server.
75///
76/// POST {proxy_url}/call with JSON body: { tool_name, args }
77/// Scopes are carried inside the JWT — not in the request body.
78///
79/// `args` carries key-value pairs for HTTP/MCP tools.
80/// `raw_args`, if provided, is sent as an array in the `args` field for CLI tools.
81pub async fn call_tool(
82    proxy_url: &str,
83    tool_name: &str,
84    args: &HashMap<String, Value>,
85    raw_args: Option<&[String]>,
86) -> Result<Value, ProxyError> {
87    let client = Client::builder()
88        .timeout(Duration::from_secs(PROXY_TIMEOUT_SECS))
89        .build()?;
90
91    let url = format!("{}/call", proxy_url.trim_end_matches('/'));
92
93    // If raw_args are provided (CLI tool), send them as a JSON array in `args`.
94    // Otherwise send the key-value map.
95    let args_value = match raw_args {
96        Some(raw) if !raw.is_empty() => {
97            Value::Array(raw.iter().map(|s| Value::String(s.clone())).collect())
98        }
99        _ => serde_json::to_value(args).unwrap_or(Value::Object(serde_json::Map::new())),
100    };
101
102    let payload = ProxyCallRequest {
103        tool_name: tool_name.to_string(),
104        args: args_value,
105    };
106
107    let response = build_proxy_request(&client, reqwest::Method::POST, &url)
108        .json(&payload)
109        .send()
110        .await?;
111    let status = response.status();
112
113    if !status.is_success() {
114        let body = response.text().await.unwrap_or_else(|_| "empty".into());
115        return Err(ProxyError::ProxyResponse {
116            status: status.as_u16(),
117            body,
118        });
119    }
120
121    let body: ProxyCallResponse = response
122        .json()
123        .await
124        .map_err(|e| ProxyError::InvalidResponse(e.to_string()))?;
125
126    if let Some(err) = body.error {
127        return Err(ProxyError::ProxyResponse {
128            status: 200,
129            body: err,
130        });
131    }
132
133    Ok(body.result)
134}
135
136/// Forward a raw MCP JSON-RPC message via the proxy's /mcp endpoint.
137pub async fn call_mcp(
138    proxy_url: &str,
139    method: &str,
140    params: Option<Value>,
141) -> Result<Value, ProxyError> {
142    use std::sync::atomic::{AtomicU64, Ordering};
143    static MCP_ID: AtomicU64 = AtomicU64::new(1);
144
145    let id = MCP_ID.fetch_add(1, Ordering::SeqCst);
146    let msg = serde_json::json!({
147        "jsonrpc": "2.0",
148        "id": id,
149        "method": method,
150        "params": params,
151    });
152
153    let client = Client::builder()
154        .timeout(Duration::from_secs(PROXY_TIMEOUT_SECS))
155        .build()?;
156
157    let url = format!("{}/mcp", proxy_url.trim_end_matches('/'));
158
159    let response = build_proxy_request(&client, reqwest::Method::POST, &url)
160        .json(&msg)
161        .send()
162        .await?;
163    let status = response.status();
164
165    if status == reqwest::StatusCode::ACCEPTED {
166        return Ok(Value::Null);
167    }
168
169    if !status.is_success() {
170        let body = response.text().await.unwrap_or_else(|_| "empty".into());
171        return Err(ProxyError::ProxyResponse {
172            status: status.as_u16(),
173            body,
174        });
175    }
176
177    let body: Value = response
178        .json()
179        .await
180        .map_err(|e| ProxyError::InvalidResponse(e.to_string()))?;
181
182    if let Some(err) = body.get("error") {
183        let message = err
184            .get("message")
185            .and_then(|m| m.as_str())
186            .unwrap_or("MCP proxy error");
187        return Err(ProxyError::ProxyResponse {
188            status: 200,
189            body: message.to_string(),
190        });
191    }
192
193    Ok(body.get("result").cloned().unwrap_or(Value::Null))
194}
195
196/// Fetch skill list from the proxy server.
197pub async fn list_skills(
198    proxy_url: &str,
199    query_params: &str,
200) -> Result<serde_json::Value, ProxyError> {
201    let client = Client::builder()
202        .timeout(Duration::from_secs(PROXY_TIMEOUT_SECS))
203        .build()?;
204
205    let url = if query_params.is_empty() {
206        format!("{}/skills", proxy_url.trim_end_matches('/'))
207    } else {
208        format!("{}/skills?{query_params}", proxy_url.trim_end_matches('/'))
209    };
210
211    let response = build_proxy_request(&client, reqwest::Method::GET, &url)
212        .send()
213        .await?;
214    let status = response.status();
215
216    if !status.is_success() {
217        let body = response.text().await.unwrap_or_else(|_| "empty".into());
218        return Err(ProxyError::ProxyResponse {
219            status: status.as_u16(),
220            body,
221        });
222    }
223
224    response
225        .json()
226        .await
227        .map_err(|e| ProxyError::InvalidResponse(e.to_string()))
228}
229
230/// Fetch a skill's detail from the proxy server.
231pub async fn get_skill(
232    proxy_url: &str,
233    name: &str,
234    query_params: &str,
235) -> Result<serde_json::Value, ProxyError> {
236    let client = Client::builder()
237        .timeout(Duration::from_secs(PROXY_TIMEOUT_SECS))
238        .build()?;
239
240    let url = if query_params.is_empty() {
241        format!("{}/skills/{name}", proxy_url.trim_end_matches('/'))
242    } else {
243        format!(
244            "{}/skills/{name}?{query_params}",
245            proxy_url.trim_end_matches('/')
246        )
247    };
248
249    let response = build_proxy_request(&client, reqwest::Method::GET, &url)
250        .send()
251        .await?;
252    let status = response.status();
253
254    if !status.is_success() {
255        let body = response.text().await.unwrap_or_else(|_| "empty".into());
256        return Err(ProxyError::ProxyResponse {
257            status: status.as_u16(),
258            body,
259        });
260    }
261
262    response
263        .json()
264        .await
265        .map_err(|e| ProxyError::InvalidResponse(e.to_string()))
266}
267
268/// Resolve skills for given scopes via the proxy.
269pub async fn resolve_skills(
270    proxy_url: &str,
271    scopes: &serde_json::Value,
272) -> Result<serde_json::Value, ProxyError> {
273    let client = Client::builder()
274        .timeout(Duration::from_secs(PROXY_TIMEOUT_SECS))
275        .build()?;
276
277    let url = format!("{}/skills/resolve", proxy_url.trim_end_matches('/'));
278
279    let response = build_proxy_request(&client, reqwest::Method::POST, &url)
280        .json(scopes)
281        .send()
282        .await?;
283    let status = response.status();
284
285    if !status.is_success() {
286        let body = response.text().await.unwrap_or_else(|_| "empty".into());
287        return Err(ProxyError::ProxyResponse {
288            status: status.as_u16(),
289            body,
290        });
291    }
292
293    response
294        .json()
295        .await
296        .map_err(|e| ProxyError::InvalidResponse(e.to_string()))
297}
298
299/// Execute an LLM help query via the proxy server.
300pub async fn call_help(
301    proxy_url: &str,
302    query: &str,
303    tool: Option<&str>,
304) -> Result<String, ProxyError> {
305    let client = Client::builder()
306        .timeout(Duration::from_secs(PROXY_TIMEOUT_SECS))
307        .build()?;
308
309    let url = format!("{}/help", proxy_url.trim_end_matches('/'));
310
311    let payload = ProxyHelpRequest {
312        query: query.to_string(),
313        tool: tool.map(|t| t.to_string()),
314    };
315
316    let response = build_proxy_request(&client, reqwest::Method::POST, &url)
317        .json(&payload)
318        .send()
319        .await?;
320    let status = response.status();
321
322    if !status.is_success() {
323        let body = response.text().await.unwrap_or_else(|_| "empty".into());
324        return Err(ProxyError::ProxyResponse {
325            status: status.as_u16(),
326            body,
327        });
328    }
329
330    let body: ProxyHelpResponse = response
331        .json()
332        .await
333        .map_err(|e| ProxyError::InvalidResponse(e.to_string()))?;
334
335    if let Some(err) = body.error {
336        return Err(ProxyError::ProxyResponse {
337            status: 200,
338            body: err,
339        });
340    }
341
342    Ok(body.content)
343}