Skip to main content

ati/proxy/
server.rs

//! ATI proxy server — holds API keys and executes tool calls on behalf of sandbox agents.
//!
//! Authentication: ES256-signed JWT (or HS256 fallback). The JWT carries identity,
//! scopes, and expiry; there are no static tokens or unsigned scope lists.
//!
//! Usage: `ati proxy --port 8080 [--ati-dir ~/.ati]`
7use axum::{
8    body::Body,
9    extract::{Extension, Query, State},
10    http::{Request as HttpRequest, StatusCode},
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13    routing::{get, post},
14    Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::{ManifestRegistry, Provider, Tool};
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::sentry_scope;
32use crate::core::skill::{self, SkillRegistry};
33use crate::core::skillati::{RemoteSkillMeta, SkillAtiClient, SkillAtiError};
34
35/// Shared state for the proxy server.
36pub struct ProxyState {
37    pub registry: ManifestRegistry,
38    pub skill_registry: SkillRegistry,
39    pub keyring: Keyring,
40    /// JWT validation config (None = auth disabled / dev mode).
41    pub jwt_config: Option<JwtConfig>,
42    /// Pre-computed JWKS JSON for the /.well-known/jwks.json endpoint.
43    pub jwks_json: Option<Value>,
44    /// Shared cache for dynamically generated auth credentials.
45    pub auth_cache: AuthCache,
46    /// Per-provider upstream-URL allowlists, compiled lazily on first
47    /// per-request validation and cached for the process lifetime. Keyed
48    /// by provider name. `None` value = keyring entry missing; we cache
49    /// negatives too to avoid repeating the lookup on every request.
50    /// Operator hot-reload of an allowlist requires proxy restart (same
51    /// constraint as every other keyring entry today). See issue #124.
52    ///
53    /// Entries are pre-parsed `UpstreamAllowEntry` structs (scheme + host
54    /// glob + canonical path) rather than raw URL glob patterns. Glob over
55    /// raw URL strings is unsafe — `*` can cross `.`, `#`, `?`, `:` and other
56    /// URL delimiters, letting a sandbox-crafted URL satisfy a pattern while
57    /// `reqwest` connects to a different host. Greptile P0/P1 on #124.
58    pub upstream_url_allowlists:
59        Arc<std::sync::Mutex<std::collections::HashMap<String, Option<Vec<UpstreamAllowEntry>>>>>,
60}
61
62/// One parsed allowlist entry. Allowlist CSV entries are operator-authored
63/// URL templates like `https://parcha-tools-*.grep.ai/mcp`. We parse them
64/// into (scheme, host_label_patterns, path) at load time so per-request
65/// validation can match each component separately:
66/// - **scheme**: exact string match (case-insensitive per RFC 3986)
67/// - **host_label_patterns**: per-DNS-label globs; the runtime URL's host
68///   is split on `.` and each label matched against the corresponding
69///   pattern entry. `*` can never bridge labels because labels are matched
70///   independently.
71/// - **path**: exact match against the URL's path; query/fragment ignored
72///
73/// This is the secure shape; raw URL globs let `*` cross `.`/`#`/`?`/`:`
74/// (Greptile P0/P1 on #124). Per-label matching closes those bypasses.
75#[derive(Debug, Clone)]
76pub struct UpstreamAllowEntry {
77    /// `"https"` or `"http"` — operator-declared scheme. Exact match required.
78    pub scheme: String,
79    /// One glob per DNS label, e.g. `parcha-tools-*.grep.ai` →
80    /// `["parcha-tools-*", "grep", "ai"]`. The runtime URL's lowercased
81    /// host is split on `.` and label-by-label matched against this Vec.
82    /// Length must match exactly (no `**`-style multi-label wildcards).
83    pub host_label_patterns: Vec<glob::Pattern>,
84    /// Exact path match required. Empty operator path becomes `"/"`.
85    pub path: String,
86}
87
88/// Outcome of `X-Ati-Upstream-Url` validation for a single request. The
89/// proxy handler turns `Reject` into an HTTP response and `Allow` into a
90/// param threaded down to `mcp_client::execute_with_gen`. See issue #124.
91#[derive(Debug)]
92enum UpstreamOverride {
93    /// No header sent, no override needed. The MCP client will use
94    /// `provider.mcp_url` as today.
95    None,
96    /// Header sent, validated against the operator allowlist. The MCP
97    /// client should dial this URL instead of `provider.mcp_url`.
98    Allow(String),
99    /// Header sent, but the request must be rejected. Carries the HTTP
100    /// status and a human-readable message for the response body.
101    Reject(StatusCode, String),
102}
103
104/// Resolve the optional sandbox-supplied upstream URL for a single
105/// `/call` request. Six cases per the issue #124 spec:
106///   1. Header absent + `mcp_url_env` absent → `None`.
107///   2. Header present + `mcp_url_env` absent → 400 reject (fail loud).
108///   3. Header absent + `mcp_url_env` present → `None` (mcp_client falls
109///      back to `provider.mcp_url`).
110///   4. Header present + `mcp_url_env` present + keyring has no allowlist
111///      → 403 reject (fail closed; operator must opt in).
112///   5. Header present + URL doesn't match any glob → 403 reject.
113///   6. Header present + URL matches → `Allow(url)`.
114fn resolve_upstream_override(
115    state: &ProxyState,
116    provider: &crate::core::manifest::Provider,
117    header_value: Option<&str>,
118) -> UpstreamOverride {
119    let provider_accepts_override = provider.mcp_url_env.is_some();
120    match (header_value, provider_accepts_override) {
121        (None, _) => UpstreamOverride::None,
122        (Some(_), false) => UpstreamOverride::Reject(
123            StatusCode::BAD_REQUEST,
124            format!(
125                "X-Ati-Upstream-Url sent for provider '{}' which does not declare mcp_url_env",
126                provider.name
127            ),
128        ),
129        (Some(url), true) => {
130            let allowlist_key = format!("{}_allowed_urls", provider.name);
131            // Lazy-build the per-provider patterns on first use; cache for
132            // process lifetime (operator hot-reload requires restart, same
133            // as every other keyring entry today).
134            let entries = {
135                let mut cache = state.upstream_url_allowlists.lock().unwrap();
136                if !cache.contains_key(&provider.name) {
137                    let compiled: Option<Vec<UpstreamAllowEntry>> = state
138                        .keyring
139                        .get(&allowlist_key)
140                        .and_then(|csv| match build_url_allowlist(csv) {
141                            Ok(set) => set,
142                            Err(e) => {
143                                tracing::warn!(
144                                    provider = %provider.name,
145                                    error = %e,
146                                    "failed to compile upstream URL allowlist; treating as missing"
147                                );
148                                None
149                            }
150                        });
151                    cache.insert(provider.name.clone(), compiled);
152                }
153                cache.get(&provider.name).cloned().flatten()
154            };
155            let Some(entries) = entries else {
156                return UpstreamOverride::Reject(
157                    StatusCode::FORBIDDEN,
158                    format!(
159                        "Provider '{}' has no upstream URL allowlist configured (set ATI_KEY_{}_ALLOWED_URLS on the proxy)",
160                        provider.name,
161                        provider.name.to_uppercase()
162                    ),
163                );
164            };
165            // Parse the sandbox-supplied URL with `url::Url` (which canonicalizes
166            // scheme/host/path and surfaces fragment/query separately). Glob over
167            // the canonical host only — never the raw URL string — so attackers
168            // can't smuggle a different host via `.`, `#`, `?`, `:`, etc.
169            let parsed = match url::Url::parse(url) {
170                Ok(u) => u,
171                Err(e) => {
172                    return UpstreamOverride::Reject(
173                        StatusCode::BAD_REQUEST,
174                        format!("X-Ati-Upstream-Url '{url}' is not a valid URL: {e}"),
175                    );
176                }
177            };
178            if matches_allowlist(&parsed, &entries) {
179                UpstreamOverride::Allow(url.to_string())
180            } else {
181                UpstreamOverride::Reject(
182                    StatusCode::FORBIDDEN,
183                    format!(
184                        "Upstream URL '{url}' not in provider '{}'s allowlist",
185                        provider.name
186                    ),
187                )
188            }
189        }
190    }
191}
192
193/// Returns true iff `url`'s scheme + host + path matches at least one
194/// allowlist entry. Each component checked separately:
195/// - **scheme**: exact string match (case-insensitive per RFC 3986)
196/// - **host**: split on `.` into DNS labels; matched label-by-label against
197///   the entry's `host_label_patterns`. Label count must match exactly so
198///   `*.grep.ai` can't be satisfied by `evil.com.grep.ai` (3 labels vs 4).
199/// - **path**: exact string match (no globs). Query + fragment ignored.
200///
201/// Hostnames are lowercased before comparison (DNS case-insensitive). Userinfo
202/// (`user:pass@`) is rejected outright; allowlist patterns never carry it.
203/// Default ports (`:443` for https, `:80` for http) are accepted but only
204/// when the operator pattern itself doesn't pin a port; otherwise the entry
205/// would reject any URL because `url::Url` strips default ports.
206fn matches_allowlist(url: &url::Url, entries: &[UpstreamAllowEntry]) -> bool {
207    if !url.username().is_empty() || url.password().is_some() {
208        return false;
209    }
210    let Some(host) = url.host_str() else {
211        return false;
212    };
213    let host_lower = host.to_ascii_lowercase();
214    let host_labels: Vec<&str> = host_lower.split('.').collect();
215    let path = url.path();
216    for entry in entries {
217        if !entry.scheme.eq_ignore_ascii_case(url.scheme()) {
218            continue;
219        }
220        // Label-count check is the per-DNS-label-boundary constraint:
221        // pattern `*.grep.ai` (3 labels) cannot match `evil.com.grep.ai`
222        // (4 labels). This is the core security primitive that closes the
223        // `*`-crosses-`.` bypass Greptile flagged on #124.
224        if entry.host_label_patterns.len() != host_labels.len() {
225            continue;
226        }
227        if !entry
228            .host_label_patterns
229            .iter()
230            .zip(host_labels.iter())
231            .all(|(pat, label)| pat.matches(label))
232        {
233            continue;
234        }
235        if entry.path != path {
236            continue;
237        }
238        return true;
239    }
240    false
241}
242
243/// Compile a CSV of operator URL patterns into pre-parsed
244/// [`UpstreamAllowEntry`] structs. Each pattern is a URL template like
245/// `https://parcha-tools-*.grep.ai/mcp` — parsed with `url::Url`, then the
246/// host is split into DNS labels and each label compiled as a separate
247/// glob component. Returns `Ok(None)` for empty input / all-whitespace
248/// (caller treats as "no allowlist", same as missing keyring entry).
249///
250/// ## Why per-label host compilation
251///
252/// `glob::Pattern` doesn't natively treat `.` as a separator. Naive
253/// `Pattern::new("parcha-tools-*.grep.ai")` would let `*` swallow
254/// `evil.com` from `parcha-tools-evil.com.grep.ai`. To prevent that we
255/// reassemble the host pattern label-by-label after validating each label
256/// glob contains no embedded `.` itself. The composed `host_pattern` then
257/// uses `*` semantics that stop at label boundaries.
258///
259/// ## Why exact path match
260///
261/// Glob over paths invites the same boundary-crossing problems. MCP
262/// endpoints are conventionally `/mcp` or `/` — operators write the literal
263/// path, sandbox URLs must match exactly. Query strings and URL fragments
264/// are stripped before comparison.
265///
266/// ## Errors
267///
268/// - `pattern not a valid URL` — operator entry isn't parseable as a URL
269/// - `pattern missing host` — e.g., `https:///mcp`
270/// - `pattern host label contains '.'` — would let `*` cross DNS boundaries
271/// - `pattern has userinfo` — credentials in allowlist entries are refused
272fn build_url_allowlist(csv: &str) -> Result<Option<Vec<UpstreamAllowEntry>>, String> {
273    let mut entries = Vec::new();
274    for raw in csv.split(',') {
275        let pat = raw.trim();
276        if pat.is_empty() {
277            continue;
278        }
279        let parsed = url::Url::parse(pat)
280            .map_err(|e| format!("upstream allowlist pattern '{pat}' is not a valid URL: {e}"))?;
281        if !parsed.username().is_empty() || parsed.password().is_some() {
282            return Err(format!(
283                "upstream allowlist pattern '{pat}' must not include userinfo"
284            ));
285        }
286        let host = parsed
287            .host_str()
288            .ok_or_else(|| format!("upstream allowlist pattern '{pat}' has no host"))?;
289        // Reject `**` outright — it's a glob escape hatch.
290        if host.contains("**") {
291            return Err(format!(
292                "upstream allowlist pattern '{pat}' must not contain '**' (use single '*' per DNS label)"
293            ));
294        }
295        let host_lower = host.to_ascii_lowercase();
296        let mut host_label_patterns = Vec::new();
297        for label in host_lower.split('.') {
298            if label.is_empty() {
299                return Err(format!(
300                    "upstream allowlist pattern '{pat}' has empty DNS label"
301                ));
302            }
303            let p = glob::Pattern::new(label).map_err(|e| {
304                format!(
305                    "upstream allowlist pattern '{pat}' has invalid host label glob '{label}': {e}"
306                )
307            })?;
308            host_label_patterns.push(p);
309        }
310        let path = if parsed.path().is_empty() {
311            "/".to_string()
312        } else {
313            parsed.path().to_string()
314        };
315        entries.push(UpstreamAllowEntry {
316            scheme: parsed.scheme().to_string(),
317            host_label_patterns,
318            path,
319        });
320    }
321    if entries.is_empty() {
322        return Ok(None);
323    }
324    Ok(Some(entries))
325}
326
327// --- Request/Response types ---
328
329#[derive(Debug, Deserialize)]
330pub struct CallRequest {
331    pub tool_name: String,
332    /// Tool arguments — accepts a JSON object (key-value pairs) for HTTP/MCP/OpenAPI tools,
333    /// or a JSON array of strings / a single string for CLI tools.
334    /// The proxy auto-detects the handler type and routes accordingly.
335    #[serde(default = "default_args")]
336    pub args: Value,
337    /// Deprecated: use `args` with an array value instead.
338    /// Kept for backward compatibility — if present, takes precedence for CLI tools.
339    #[serde(default)]
340    pub raw_args: Option<Vec<String>>,
341}
342
343fn default_args() -> Value {
344    Value::Object(serde_json::Map::new())
345}
346
347impl CallRequest {
348    /// Extract args as a HashMap for HTTP/MCP/OpenAPI tools.
349    /// If `args` is a JSON object, returns its entries.
350    /// If `args` is something else (array, string), returns an empty map.
351    fn args_as_map(&self) -> HashMap<String, Value> {
352        match &self.args {
353            Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
354            _ => HashMap::new(),
355        }
356    }
357
358    /// Extract positional args for CLI tools.
359    /// Priority: explicit `raw_args` field > `args` array > `args` string > `args._positional` > empty.
360    fn args_as_positional(&self) -> Vec<String> {
361        // Backward compat: explicit raw_args wins
362        if let Some(ref raw) = self.raw_args {
363            return raw.clone();
364        }
365        match &self.args {
366            // ["pr", "list", "--repo", "X"]
367            Value::Array(arr) => arr
368                .iter()
369                .map(|v| match v {
370                    Value::String(s) => s.clone(),
371                    other => other.to_string(),
372                })
373                .collect(),
374            // "pr list --repo X"
375            Value::String(s) => s.split_whitespace().map(String::from).collect(),
376            // {"_positional": ["pr", "list"]} or {"--key": "value"} converted to CLI flags
377            Value::Object(map) => {
378                if let Some(Value::Array(pos)) = map.get("_positional") {
379                    return pos
380                        .iter()
381                        .map(|v| match v {
382                            Value::String(s) => s.clone(),
383                            other => other.to_string(),
384                        })
385                        .collect();
386                }
387                // Convert map entries to --key value pairs
388                let mut result = Vec::new();
389                for (k, v) in map {
390                    result.push(format!("--{k}"));
391                    match v {
392                        Value::String(s) => result.push(s.clone()),
393                        Value::Bool(true) => {} // flag, no value needed
394                        other => result.push(other.to_string()),
395                    }
396                }
397                result
398            }
399            _ => Vec::new(),
400        }
401    }
402}
403
404#[derive(Debug, Serialize)]
405pub struct CallResponse {
406    pub result: Value,
407    #[serde(skip_serializing_if = "Option::is_none")]
408    pub error: Option<String>,
409}
410
411#[derive(Debug, Deserialize)]
412pub struct HelpRequest {
413    pub query: String,
414    #[serde(default)]
415    pub tool: Option<String>,
416}
417
418#[derive(Debug, Serialize)]
419pub struct HelpResponse {
420    pub content: String,
421    #[serde(skip_serializing_if = "Option::is_none")]
422    pub error: Option<String>,
423}
424
425#[derive(Debug, Serialize)]
426pub struct HealthResponse {
427    pub status: String,
428    pub version: String,
429    pub tools: usize,
430    pub providers: usize,
431    pub skills: usize,
432    pub auth: String,
433}
434
435// --- Skill endpoint types ---
436
437#[derive(Debug, Deserialize)]
438pub struct SkillsQuery {
439    #[serde(default)]
440    pub category: Option<String>,
441    #[serde(default)]
442    pub provider: Option<String>,
443    #[serde(default)]
444    pub tool: Option<String>,
445    #[serde(default)]
446    pub search: Option<String>,
447}
448
449#[derive(Debug, Deserialize)]
450pub struct SkillDetailQuery {
451    #[serde(default)]
452    pub meta: Option<bool>,
453    #[serde(default)]
454    pub refs: Option<bool>,
455}
456
457#[derive(Debug, Deserialize)]
458pub struct SkillResolveRequest {
459    pub scopes: Vec<String>,
460    /// When true, include SKILL.md content in each resolved skill.
461    #[serde(default)]
462    pub include_content: bool,
463}
464
465#[derive(Debug, Deserialize)]
466pub struct SkillBundleBatchRequest {
467    pub names: Vec<String>,
468}
469
470#[derive(Debug, Deserialize, Default)]
471pub struct SkillAtiCatalogQuery {
472    #[serde(default)]
473    pub search: Option<String>,
474}
475
476#[derive(Debug, Deserialize, Default)]
477pub struct SkillAtiResourcesQuery {
478    #[serde(default)]
479    pub prefix: Option<String>,
480}
481
482#[derive(Debug, Deserialize)]
483pub struct SkillAtiFileQuery {
484    pub path: String,
485}
486
487// --- Tool endpoint types ---
488
489#[derive(Debug, Deserialize)]
490pub struct ToolsQuery {
491    #[serde(default)]
492    pub provider: Option<String>,
493    #[serde(default)]
494    pub search: Option<String>,
495}
496
497// --- Handlers ---
498
499fn scopes_for_request(claims: Option<&TokenClaims>, state: &ProxyState) -> ScopeConfig {
500    match claims {
501        Some(claims) => ScopeConfig::from_jwt(claims),
502        None if state.jwt_config.is_none() => ScopeConfig::unrestricted(),
503        None => ScopeConfig {
504            scopes: Vec::new(),
505            sub: String::new(),
506            expires_at: 0,
507            rate_config: None,
508        },
509    }
510}
511
512fn visible_tools_for_scopes<'a>(
513    state: &'a ProxyState,
514    scopes: &ScopeConfig,
515) -> Vec<(&'a Provider, &'a Tool)> {
516    crate::core::scope::filter_tools_by_scope(state.registry.list_public_tools(), scopes)
517}
518
519fn visible_skill_names(
520    state: &ProxyState,
521    scopes: &ScopeConfig,
522) -> std::collections::HashSet<String> {
523    skill::visible_skills(&state.skill_registry, &state.registry, scopes)
524        .into_iter()
525        .map(|skill| skill.name.clone())
526        .collect()
527}
528
529/// Compute the set of remote (SkillATI-registry) skill names that the caller's
530/// scopes grant access to.
531///
532/// Mirrors the scope cascade in `skill::resolve_skills` — explicit `skill:X`
533/// scopes, `tool:Y` scopes resolved to the tool's covering skills (including
534/// provider/category bindings) — but against a remote catalog whose skills
535/// are **not** present in the local filesystem `SkillRegistry`.
536///
537/// Without this, proxies running `ATI_SKILL_REGISTRY=gcs://...` with an empty
538/// local skills directory return 404 for every remote skill, because the
539/// visibility gate only consults `state.skill_registry` (see issue #59).
540fn visible_remote_skill_names(
541    state: &ProxyState,
542    scopes: &ScopeConfig,
543    catalog: &[RemoteSkillMeta],
544) -> std::collections::HashSet<String> {
545    let mut visible: std::collections::HashSet<String> = std::collections::HashSet::new();
546    if catalog.is_empty() {
547        return visible;
548    }
549    if scopes.is_wildcard() {
550        for entry in catalog {
551            visible.insert(entry.name.clone());
552        }
553        return visible;
554    }
555
556    // Collect allowed tool/provider/category identifiers from the caller's scopes.
557    // 1. Direct `tool:X` scopes (including wildcards) → walk against the public
558    //    tool registry to collect concrete (provider, tool) pairs.
559    let allowed_tool_pairs: Vec<(String, String)> =
560        crate::core::scope::filter_tools_by_scope(state.registry.list_public_tools(), scopes)
561            .into_iter()
562            .map(|(p, t)| (p.name.clone(), t.name.clone()))
563            .collect();
564    let allowed_tool_names: std::collections::HashSet<&str> =
565        allowed_tool_pairs.iter().map(|(_, t)| t.as_str()).collect();
566    let allowed_provider_names: std::collections::HashSet<&str> =
567        allowed_tool_pairs.iter().map(|(p, _)| p.as_str()).collect();
568    let allowed_categories: std::collections::HashSet<String> = state
569        .registry
570        .list_providers()
571        .into_iter()
572        .filter(|p| allowed_provider_names.contains(p.name.as_str()))
573        .filter_map(|p| p.category.clone())
574        .collect();
575
576    // Explicit `skill:X` scopes → include X if present in the remote catalog.
577    for scope in &scopes.scopes {
578        if let Some(skill_name) = scope.strip_prefix("skill:") {
579            if catalog.iter().any(|e| e.name == skill_name) {
580                visible.insert(skill_name.to_string());
581            }
582        }
583    }
584
585    // Tool/provider/category cascade → include a remote skill if any of its
586    // `tools`, `providers`, or `categories` bindings match a scope-allowed
587    // tool/provider/category.
588    for entry in catalog {
589        if entry
590            .tools
591            .iter()
592            .any(|t| allowed_tool_names.contains(t.as_str()))
593            || entry
594                .providers
595                .iter()
596                .any(|p| allowed_provider_names.contains(p.as_str()))
597            || entry
598                .categories
599                .iter()
600                .any(|c| allowed_categories.contains(c))
601        {
602            visible.insert(entry.name.clone());
603        }
604    }
605
606    visible
607}
608
609/// Union of local + remote visible skill names, computed on demand. The
610/// remote catalog is fetched lazily (and is cached inside `SkillAtiClient`
611/// after the first call on the hot path).
612async fn visible_skill_names_with_remote(
613    state: &ProxyState,
614    scopes: &ScopeConfig,
615    client: &SkillAtiClient,
616) -> Result<std::collections::HashSet<String>, SkillAtiError> {
617    let mut names = visible_skill_names(state, scopes);
618    let catalog = client.catalog().await?;
619    let remote = visible_remote_skill_names(state, scopes, &catalog);
620    names.extend(remote);
621    Ok(names)
622}
623
624async fn handle_call(
625    State(state): State<Arc<ProxyState>>,
626    req: HttpRequest<Body>,
627) -> impl IntoResponse {
628    // Extract JWT claims from request extensions (set by auth middleware)
629    let claims = req.extensions().get::<TokenClaims>().cloned();
630    // Raw inbound bearer (Bearer-token path only — dev/no-JWT paths leave
631    // this empty). Used by `GenContext.jwt_token` so auth_generator scripts
632    // can forward the sandbox's identity to an upstream MCP. See issue #115.
633    let bearer_token: String = req
634        .extensions()
635        .get::<BearerToken>()
636        .map(|b| b.0.clone())
637        .unwrap_or_default();
638
639    // Pull the sandbox-supplied upstream URL out of the inbound headers
640    // before `into_body()` consumes them. Validation against the operator
641    // allowlist happens after tool resolution — we need the provider to
642    // know which keyring allowlist entry to consult. See issue #124.
643    let upstream_url_header: Option<String> = req
644        .headers()
645        .get("x-ati-upstream-url")
646        .and_then(|v| v.to_str().ok())
647        .map(|s| s.trim().to_string())
648        .filter(|s| !s.is_empty());
649
650    // Parse request body. The ceiling must accommodate the worst-case upload
651    // payload: `file_manager::MAX_UPLOAD_BYTES` of raw bytes, base64-inflated
652    // (~1.34×), plus a few KB of JSON framing. Anti-abuse is enforced
653    // downstream by per-tool limits (`max_bytes` on downloads, `MAX_UPLOAD_BYTES`
654    // on uploads) and by JWT scope + rate limits — this is just the outer
655    // wire cap.
656    let body_bytes = match axum::body::to_bytes(req.into_body(), max_call_body_bytes()).await {
657        Ok(b) => b,
658        Err(e) => {
659            return (
660                StatusCode::BAD_REQUEST,
661                Json(CallResponse {
662                    result: Value::Null,
663                    error: Some(format!("Failed to read request body: {e}")),
664                }),
665            );
666        }
667    };
668
669    let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
670        Ok(r) => r,
671        Err(e) => {
672            return (
673                StatusCode::UNPROCESSABLE_ENTITY,
674                Json(CallResponse {
675                    result: Value::Null,
676                    error: Some(format!("Invalid request: {e}")),
677                }),
678            );
679        }
680    };
681
682    tracing::debug!(
683        tool = %call_req.tool_name,
684        args = ?call_req.args,
685        "POST /call"
686    );
687
688    // Look up tool in registry.
689    // If not found, try converting underscore format (finnhub_quote) to colon (finnhub:quote).
690    let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
691        Some(pt) => pt,
692        None => {
693            // Try underscore → colon conversion at each underscore position.
694            // "finnhub_quote" → try "finnhub:quote"
695            // "test_api_get_data" → try "test:api_get_data", "test_api:get_data"
696            let mut resolved = None;
697            for (idx, _) in call_req.tool_name.match_indices('_') {
698                let candidate = format!(
699                    "{}:{}",
700                    &call_req.tool_name[..idx],
701                    &call_req.tool_name[idx + 1..]
702                );
703                if let Some(pt) = state.registry.get_tool(&candidate) {
704                    tracing::debug!(
705                        original = %call_req.tool_name,
706                        resolved = %candidate,
707                        "resolved underscore tool name to colon format"
708                    );
709                    resolved = Some(pt);
710                    break;
711                }
712            }
713
714            match resolved {
715                Some(pt) => pt,
716                None => {
717                    return (
718                        StatusCode::NOT_FOUND,
719                        Json(CallResponse {
720                            result: Value::Null,
721                            error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
722                        }),
723                    );
724                }
725            }
726        }
727    };
728
729    // Scope enforcement from JWT claims
730    if let Some(tool_scope) = &tool.scope {
731        let scopes = match &claims {
732            Some(c) => ScopeConfig::from_jwt(c),
733            None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), // Dev mode
734            None => {
735                return (
736                    StatusCode::FORBIDDEN,
737                    Json(CallResponse {
738                        result: Value::Null,
739                        error: Some("Authentication required — no JWT provided".into()),
740                    }),
741                );
742            }
743        };
744
745        if !scopes.is_allowed(tool_scope) {
746            return (
747                StatusCode::FORBIDDEN,
748                Json(CallResponse {
749                    result: Value::Null,
750                    error: Some(format!(
751                        "Access denied: '{}' is not in your scopes",
752                        tool.name
753                    )),
754                }),
755            );
756        }
757    }
758
759    // Rate limit check
760    {
761        let scopes = match &claims {
762            Some(c) => ScopeConfig::from_jwt(c),
763            None => ScopeConfig::unrestricted(),
764        };
765        if let Some(ref rate_config) = scopes.rate_config {
766            if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
767                return (
768                    StatusCode::TOO_MANY_REQUESTS,
769                    Json(CallResponse {
770                        result: Value::Null,
771                        error: Some(format!("{e}")),
772                    }),
773                );
774            }
775        }
776    }
777
778    // Build auth generator context from JWT claims
779    let gen_ctx = GenContext {
780        jwt_sub: claims
781            .as_ref()
782            .map(|c| c.sub.clone())
783            .unwrap_or_else(|| "dev".into()),
784        jwt_scope: claims
785            .as_ref()
786            .map(|c| c.scope.clone())
787            .unwrap_or_else(|| "*".into()),
788        tool_name: call_req.tool_name.clone(),
789        timestamp: crate::core::jwt::now_secs(),
790        jwt_token: bearer_token.clone(),
791    };
792
793    // Validate any sandbox-supplied X-Ati-Upstream-Url against the operator's
794    // per-provider allowlist (issue #124). Reject paths exit early with a
795    // clean status; the Allow path threads through to mcp_client below.
796    let override_mcp_url: Option<String> =
797        match resolve_upstream_override(&state, provider, upstream_url_header.as_deref()) {
798            UpstreamOverride::None => None,
799            UpstreamOverride::Allow(url) => Some(url),
800            UpstreamOverride::Reject(status, msg) => {
801                tracing::warn!(
802                    provider = %provider.name,
803                    tool = %call_req.tool_name,
804                    status = status.as_u16(),
805                    reason = %msg,
806                    "rejecting sandbox-supplied upstream URL"
807                );
808                return (
809                    status,
810                    Json(CallResponse {
811                        result: Value::Null,
812                        error: Some(msg),
813                    }),
814                );
815            }
816        };
817
818    // Execute tool call — dispatch based on handler type, with timing for audit
819    let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
820    let job_id = claims
821        .as_ref()
822        .and_then(|c| c.job_id.clone())
823        .unwrap_or_default();
824    let sandbox_id = claims
825        .as_ref()
826        .and_then(|c| c.sandbox_id.clone())
827        .unwrap_or_default();
828    tracing::info!(
829        tool = %call_req.tool_name,
830        agent = %agent_sub,
831        job_id = %job_id,
832        sandbox_id = %sandbox_id,
833        "tool call"
834    );
835    let start = std::time::Instant::now();
836
837    let response = match provider.handler.as_str() {
838        "mcp" => {
839            let args_map = call_req.args_as_map();
840            match mcp_client::execute_with_gen(
841                provider,
842                &call_req.tool_name,
843                &args_map,
844                &state.keyring,
845                Some(&gen_ctx),
846                Some(&state.auth_cache),
847                override_mcp_url.as_deref(),
848            )
849            .await
850            {
851                Ok(result) => (
852                    StatusCode::OK,
853                    Json(CallResponse {
854                        result,
855                        error: None,
856                    }),
857                ),
858                Err(e) => {
859                    let (provider_name, operation_id) =
860                        sentry_scope::split_tool_name(&call_req.tool_name);
861                    sentry_scope::report_upstream_error(
862                        &provider_name,
863                        &operation_id,
864                        0,
865                        502,
866                        None,
867                        Some(&e.to_string()),
868                    );
869                    (
870                        StatusCode::BAD_GATEWAY,
871                        Json(CallResponse {
872                            result: Value::Null,
873                            error: Some(format!("MCP error: {e}")),
874                        }),
875                    )
876                }
877            }
878        }
879        "cli" => {
880            let positional = call_req.args_as_positional();
881            match crate::core::cli_executor::execute_with_gen(
882                provider,
883                &positional,
884                &state.keyring,
885                Some(&gen_ctx),
886                Some(&state.auth_cache),
887            )
888            .await
889            {
890                Ok(result) => (
891                    StatusCode::OK,
892                    Json(CallResponse {
893                        result,
894                        error: None,
895                    }),
896                ),
897                Err(e) => {
898                    let (provider_name, operation_id) =
899                        sentry_scope::split_tool_name(&call_req.tool_name);
900                    sentry_scope::report_upstream_error(
901                        &provider_name,
902                        &operation_id,
903                        0,
904                        502,
905                        None,
906                        Some(&e.to_string()),
907                    );
908                    (
909                        StatusCode::BAD_GATEWAY,
910                        Json(CallResponse {
911                            result: Value::Null,
912                            error: Some(format!("CLI error: {e}")),
913                        }),
914                    )
915                }
916            }
917        }
918        "file_manager" => {
919            let args_map = call_req.args_as_map();
920            match dispatch_file_manager(&call_req.tool_name, &args_map, provider, &state.keyring)
921                .await
922            {
923                Ok(result) => (
924                    StatusCode::OK,
925                    Json(CallResponse {
926                        result,
927                        error: None,
928                    }),
929                ),
930                Err((status, msg)) => (
931                    status,
932                    Json(CallResponse {
933                        result: Value::Null,
934                        error: Some(msg),
935                    }),
936                ),
937            }
938        }
939        _ => {
940            let args_map = call_req.args_as_map();
941            let raw_response = match http::execute_tool_with_gen(
942                provider,
943                tool,
944                &args_map,
945                &state.keyring,
946                Some(&gen_ctx),
947                Some(&state.auth_cache),
948            )
949            .await
950            {
951                Ok(resp) => resp,
952                Err(http::HttpError::NoRecordsFound { status }) => {
953                    // Legit empty upstream result — not an error. Return a
954                    // clean empty object so callers can distinguish from a
955                    // failed call and move on without paging Sentry.
956                    let duration = start.elapsed();
957                    tracing::info!(
958                        tool = %call_req.tool_name,
959                        upstream_status = status,
960                        "upstream returned no records"
961                    );
962                    write_proxy_audit(&call_req, &agent_sub, claims.as_ref(), duration, None);
963                    return (
964                        StatusCode::OK,
965                        Json(CallResponse {
966                            result: serde_json::json!({ "records": [] }),
967                            error: None,
968                        }),
969                    );
970                }
971                Err(e) => {
972                    let duration = start.elapsed();
973                    let (provider_name, operation_id) =
974                        sentry_scope::split_tool_name(&call_req.tool_name);
975                    let (upstream_status, error_type, error_message) = match &e {
976                        http::HttpError::ApiError {
977                            status,
978                            error_type,
979                            error_message,
980                            ..
981                        } => (*status, error_type.clone(), error_message.clone()),
982                        _ => (0u16, None, Some(e.to_string())),
983                    };
984                    sentry_scope::report_upstream_error(
985                        &provider_name,
986                        &operation_id,
987                        upstream_status,
988                        502,
989                        error_type.as_deref(),
990                        error_message.as_deref(),
991                    );
992                    write_proxy_audit(
993                        &call_req,
994                        &agent_sub,
995                        claims.as_ref(),
996                        duration,
997                        Some(&e.to_string()),
998                    );
999                    return (
1000                        StatusCode::BAD_GATEWAY,
1001                        Json(CallResponse {
1002                            result: Value::Null,
1003                            error: Some(format!("Upstream API error: {e}")),
1004                        }),
1005                    );
1006                }
1007            };
1008
1009            let processed = match response::process_response(&raw_response, tool.response.as_ref())
1010            {
1011                Ok(p) => p,
1012                Err(e) => {
1013                    let duration = start.elapsed();
1014                    write_proxy_audit(
1015                        &call_req,
1016                        &agent_sub,
1017                        claims.as_ref(),
1018                        duration,
1019                        Some(&e.to_string()),
1020                    );
1021                    return (
1022                        StatusCode::INTERNAL_SERVER_ERROR,
1023                        Json(CallResponse {
1024                            result: raw_response,
1025                            error: Some(format!("Response processing error: {e}")),
1026                        }),
1027                    );
1028                }
1029            };
1030
1031            (
1032                StatusCode::OK,
1033                Json(CallResponse {
1034                    result: processed,
1035                    error: None,
1036                }),
1037            )
1038        }
1039    };
1040
1041    let duration = start.elapsed();
1042    let error_msg = response.1.error.as_deref();
1043    write_proxy_audit(&call_req, &agent_sub, claims.as_ref(), duration, error_msg);
1044
1045    response
1046}
1047
1048async fn handle_help(
1049    State(state): State<Arc<ProxyState>>,
1050    claims: Option<Extension<TokenClaims>>,
1051    Json(req): Json<HelpRequest>,
1052) -> impl IntoResponse {
1053    tracing::debug!(query = %req.query, tool = ?req.tool, "POST /help");
1054
1055    let claims = claims.map(|Extension(claims)| claims);
1056    let scopes = scopes_for_request(claims.as_ref(), &state);
1057
1058    let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
1059        Some(pt) => pt,
1060        None => {
1061            return (
1062                StatusCode::SERVICE_UNAVAILABLE,
1063                Json(HelpResponse {
1064                    content: String::new(),
1065                    error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
1066                }),
1067            );
1068        }
1069    };
1070
1071    let api_key = match llm_provider
1072        .auth_key_name
1073        .as_deref()
1074        .and_then(|k| state.keyring.get(k))
1075    {
1076        Some(key) => key.to_string(),
1077        None => {
1078            return (
1079                StatusCode::SERVICE_UNAVAILABLE,
1080                Json(HelpResponse {
1081                    content: String::new(),
1082                    error: Some("LLM API key not found in keyring".into()),
1083                }),
1084            );
1085        }
1086    };
1087
1088    let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
1089    let local_skills_section = if resolved_skills.is_empty() {
1090        String::new()
1091    } else {
1092        format!(
1093            "## Available Skills (methodology guides)\n{}",
1094            skill::build_skill_context(&resolved_skills)
1095        )
1096    };
1097    let remote_query = req
1098        .tool
1099        .as_ref()
1100        .map(|tool| format!("{tool} {}", req.query))
1101        .unwrap_or_else(|| req.query.clone());
1102    let remote_skills_section =
1103        build_remote_skillati_section(&state.keyring, &remote_query, 12).await;
1104    let skills_section = merge_help_skill_sections(&[local_skills_section, remote_skills_section]);
1105
1106    // Build system prompt — scoped or unscoped
1107    let visible_tools = visible_tools_for_scopes(&state, &scopes);
1108    let system_prompt = if let Some(ref tool_name) = req.tool {
1109        // Scoped mode: narrow tools to the specified tool or provider
1110        match build_scoped_prompt(tool_name, &visible_tools, &skills_section) {
1111            Some(prompt) => prompt,
1112            None => {
1113                return (
1114                    StatusCode::FORBIDDEN,
1115                    Json(HelpResponse {
1116                        content: String::new(),
1117                        error: Some(format!(
1118                            "Scope '{tool_name}' is not visible in your current scopes."
1119                        )),
1120                    }),
1121                );
1122            }
1123        }
1124    } else {
1125        let tools_context = build_tool_context(&visible_tools);
1126        HELP_SYSTEM_PROMPT
1127            .replace("{tools}", &tools_context)
1128            .replace("{skills_section}", &skills_section)
1129    };
1130
1131    let request_body = serde_json::json!({
1132        "model": "zai-glm-4.7",
1133        "messages": [
1134            {"role": "system", "content": system_prompt},
1135            {"role": "user", "content": req.query}
1136        ],
1137        "max_completion_tokens": 1536,
1138        "temperature": 0.3
1139    });
1140
1141    let client = reqwest::Client::new();
1142    let url = format!(
1143        "{}{}",
1144        llm_provider.base_url.trim_end_matches('/'),
1145        llm_tool.endpoint
1146    );
1147
1148    let response = match client
1149        .post(&url)
1150        .bearer_auth(&api_key)
1151        .json(&request_body)
1152        .send()
1153        .await
1154    {
1155        Ok(r) => r,
1156        Err(e) => {
1157            return (
1158                StatusCode::BAD_GATEWAY,
1159                Json(HelpResponse {
1160                    content: String::new(),
1161                    error: Some(format!("LLM request failed: {e}")),
1162                }),
1163            );
1164        }
1165    };
1166
1167    if !response.status().is_success() {
1168        let status = response.status();
1169        let body = response.text().await.unwrap_or_default();
1170        return (
1171            StatusCode::BAD_GATEWAY,
1172            Json(HelpResponse {
1173                content: String::new(),
1174                error: Some(format!("LLM API error ({status}): {body}")),
1175            }),
1176        );
1177    }
1178
1179    let body: Value = match response.json().await {
1180        Ok(b) => b,
1181        Err(e) => {
1182            return (
1183                StatusCode::INTERNAL_SERVER_ERROR,
1184                Json(HelpResponse {
1185                    content: String::new(),
1186                    error: Some(format!("Failed to parse LLM response: {e}")),
1187                }),
1188            );
1189        }
1190    };
1191
1192    let content = body
1193        .pointer("/choices/0/message/content")
1194        .and_then(|c| c.as_str())
1195        .unwrap_or("No response from LLM")
1196        .to_string();
1197
1198    (
1199        StatusCode::OK,
1200        Json(HelpResponse {
1201            content,
1202            error: None,
1203        }),
1204    )
1205}
1206
1207async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
1208    let auth = if state.jwt_config.is_some() {
1209        "jwt"
1210    } else {
1211        "disabled"
1212    };
1213
1214    Json(HealthResponse {
1215        status: "ok".into(),
1216        version: env!("CARGO_PKG_VERSION").into(),
1217        tools: state.registry.list_public_tools().len(),
1218        providers: state.registry.list_providers().len(),
1219        skills: state.skill_registry.skill_count(),
1220        auth: auth.into(),
1221    })
1222}
1223
1224/// GET /.well-known/jwks.json — serves the public key for JWT validation.
1225async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
1226    match &state.jwks_json {
1227        Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
1228        None => (
1229            StatusCode::NOT_FOUND,
1230            Json(serde_json::json!({"error": "JWKS not configured"})),
1231        ),
1232    }
1233}
1234
1235// ---------------------------------------------------------------------------
1236// POST /mcp — MCP JSON-RPC proxy endpoint
1237// ---------------------------------------------------------------------------
1238
1239async fn handle_mcp(
1240    State(state): State<Arc<ProxyState>>,
1241    claims: Option<Extension<TokenClaims>>,
1242    bearer: Option<Extension<BearerToken>>,
1243    headers: axum::http::HeaderMap,
1244    Json(msg): Json<Value>,
1245) -> impl IntoResponse {
1246    let claims = claims.map(|Extension(claims)| claims);
1247    // Raw inbound bearer for `GenContext.jwt_token` (issue #115). The
1248    // Bearer-token path always populates this when a JWT is configured;
1249    // dev mode and Ati-Key paths leave it None and we fall back to "".
1250    let bearer_token: String = bearer.map(|Extension(b)| b.0).unwrap_or_default();
1251    // Sandbox-supplied upstream URL for MCP tools/call (issue #124). The
1252    // proxy validates against the per-provider allowlist before honouring.
1253    let upstream_url_header: Option<String> = headers
1254        .get("x-ati-upstream-url")
1255        .and_then(|v| v.to_str().ok())
1256        .map(|s| s.trim().to_string())
1257        .filter(|s| !s.is_empty());
1258    let scopes = scopes_for_request(claims.as_ref(), &state);
1259    let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
1260    let id = msg.get("id").cloned();
1261    tracing::info!(
1262        %method,
1263        agent = claims.as_ref().map(|c| c.sub.as_str()).unwrap_or(""),
1264        job_id = claims.as_ref().and_then(|c| c.job_id.as_deref()).unwrap_or(""),
1265        sandbox_id = claims.as_ref().and_then(|c| c.sandbox_id.as_deref()).unwrap_or(""),
1266        "mcp call"
1267    );
1268
1269    match method {
1270        "initialize" => {
1271            let result = serde_json::json!({
1272                "protocolVersion": "2025-03-26",
1273                "capabilities": {
1274                    "tools": { "listChanged": false }
1275                },
1276                "serverInfo": {
1277                    "name": "ati-proxy",
1278                    "version": env!("CARGO_PKG_VERSION")
1279                }
1280            });
1281            jsonrpc_success(id, result)
1282        }
1283
1284        "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
1285
1286        "tools/list" => {
1287            let visible_tools = visible_tools_for_scopes(&state, &scopes);
1288            let mcp_tools: Vec<Value> = visible_tools
1289                .iter()
1290                .map(|(_provider, tool)| {
1291                    serde_json::json!({
1292                        "name": tool.name,
1293                        "description": tool.description,
1294                        "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
1295                            "type": "object",
1296                            "properties": {}
1297                        }))
1298                    })
1299                })
1300                .collect();
1301
1302            let result = serde_json::json!({
1303                "tools": mcp_tools,
1304            });
1305            jsonrpc_success(id, result)
1306        }
1307
1308        "tools/call" => {
1309            let params = msg.get("params").cloned().unwrap_or(Value::Null);
1310            let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
1311            let arguments: HashMap<String, Value> = params
1312                .get("arguments")
1313                .and_then(|a| serde_json::from_value(a.clone()).ok())
1314                .unwrap_or_default();
1315
1316            if tool_name.is_empty() {
1317                return jsonrpc_error(id, -32602, "Missing tool name in params.name");
1318            }
1319
1320            let (provider, _tool) = match state.registry.get_tool(tool_name) {
1321                Some(pt) => pt,
1322                None => {
1323                    return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
1324                }
1325            };
1326
1327            if let Some(tool_scope) = &_tool.scope {
1328                if !scopes.is_allowed(tool_scope) {
1329                    return jsonrpc_error(
1330                        id,
1331                        -32001,
1332                        &format!("Access denied: '{}' is not in your scopes", _tool.name),
1333                    );
1334                }
1335            }
1336
1337            tracing::debug!(%tool_name, provider = %provider.name, "MCP tools/call");
1338
1339            // Validate any sandbox-supplied X-Ati-Upstream-Url against the
1340            // operator's per-provider allowlist (issue #124). For MCP JSON-RPC
1341            // we encode rejections as JSON-RPC error codes:
1342            //   - -32602 (invalid params) for header-without-mcp_url_env
1343            //   - -32001 (access denied) for allowlist misses
1344            let override_mcp_url: Option<String> =
1345                match resolve_upstream_override(&state, provider, upstream_url_header.as_deref()) {
1346                    UpstreamOverride::None => None,
1347                    UpstreamOverride::Allow(url) => Some(url),
1348                    UpstreamOverride::Reject(status, msg) => {
1349                        let code = if status == StatusCode::BAD_REQUEST {
1350                            -32602
1351                        } else {
1352                            -32001
1353                        };
1354                        tracing::warn!(
1355                            provider = %provider.name,
1356                            tool = %tool_name,
1357                            status = status.as_u16(),
1358                            reason = %msg,
1359                            "rejecting sandbox-supplied upstream URL on /mcp"
1360                        );
1361                        return jsonrpc_error(id, code, &msg);
1362                    }
1363                };
1364
1365            let mcp_gen_ctx = GenContext {
1366                jwt_sub: claims
1367                    .as_ref()
1368                    .map(|claims| claims.sub.clone())
1369                    .unwrap_or_else(|| "dev".into()),
1370                jwt_scope: claims
1371                    .as_ref()
1372                    .map(|claims| claims.scope.clone())
1373                    .unwrap_or_else(|| "*".into()),
1374                tool_name: tool_name.to_string(),
1375                timestamp: crate::core::jwt::now_secs(),
1376                jwt_token: bearer_token.clone(),
1377            };
1378
1379            let result = if provider.is_mcp() {
1380                mcp_client::execute_with_gen(
1381                    provider,
1382                    tool_name,
1383                    &arguments,
1384                    &state.keyring,
1385                    Some(&mcp_gen_ctx),
1386                    Some(&state.auth_cache),
1387                    override_mcp_url.as_deref(),
1388                )
1389                .await
1390            } else if provider.is_cli() {
1391                // Convert arguments map to CLI-style args for MCP passthrough
1392                let raw: Vec<String> = arguments
1393                    .iter()
1394                    .flat_map(|(k, v)| {
1395                        let val = match v {
1396                            Value::String(s) => s.clone(),
1397                            other => other.to_string(),
1398                        };
1399                        vec![format!("--{k}"), val]
1400                    })
1401                    .collect();
1402                crate::core::cli_executor::execute_with_gen(
1403                    provider,
1404                    &raw,
1405                    &state.keyring,
1406                    Some(&mcp_gen_ctx),
1407                    Some(&state.auth_cache),
1408                )
1409                .await
1410                .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
1411            } else {
1412                match http::execute_tool_with_gen(
1413                    provider,
1414                    _tool,
1415                    &arguments,
1416                    &state.keyring,
1417                    Some(&mcp_gen_ctx),
1418                    Some(&state.auth_cache),
1419                )
1420                .await
1421                {
1422                    Ok(val) => Ok(val),
1423                    Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
1424                }
1425            };
1426
1427            match result {
1428                Ok(value) => {
1429                    let text = match &value {
1430                        Value::String(s) => s.clone(),
1431                        other => serde_json::to_string_pretty(other).unwrap_or_default(),
1432                    };
1433                    let mcp_result = serde_json::json!({
1434                        "content": [{"type": "text", "text": text}],
1435                        "isError": false,
1436                    });
1437                    jsonrpc_success(id, mcp_result)
1438                }
1439                Err(e) => {
1440                    let mcp_result = serde_json::json!({
1441                        "content": [{"type": "text", "text": format!("Error: {e}")}],
1442                        "isError": true,
1443                    });
1444                    jsonrpc_success(id, mcp_result)
1445                }
1446            }
1447        }
1448
1449        _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
1450    }
1451}
1452
1453fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
1454    (
1455        StatusCode::OK,
1456        Json(serde_json::json!({
1457            "jsonrpc": "2.0",
1458            "id": id,
1459            "result": result,
1460        })),
1461    )
1462}
1463
1464fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
1465    (
1466        StatusCode::OK,
1467        Json(serde_json::json!({
1468            "jsonrpc": "2.0",
1469            "id": id,
1470            "error": {
1471                "code": code,
1472                "message": message,
1473            }
1474        })),
1475    )
1476}
1477
1478// ---------------------------------------------------------------------------
1479// Tool endpoints
1480// ---------------------------------------------------------------------------
1481
1482/// GET /tools — list available tools with optional filters.
1483async fn handle_tools_list(
1484    State(state): State<Arc<ProxyState>>,
1485    claims: Option<Extension<TokenClaims>>,
1486    axum::extract::Query(query): axum::extract::Query<ToolsQuery>,
1487) -> impl IntoResponse {
1488    tracing::debug!(
1489        provider = ?query.provider,
1490        search = ?query.search,
1491        "GET /tools"
1492    );
1493
1494    let claims = claims.map(|Extension(claims)| claims);
1495    let scopes = scopes_for_request(claims.as_ref(), &state);
1496    let all_tools = visible_tools_for_scopes(&state, &scopes);
1497
1498    let tools: Vec<Value> = all_tools
1499        .iter()
1500        .filter(|(provider, tool)| {
1501            if let Some(ref p) = query.provider {
1502                if provider.name != *p {
1503                    return false;
1504                }
1505            }
1506            if let Some(ref q) = query.search {
1507                let q = q.to_lowercase();
1508                let name_match = tool.name.to_lowercase().contains(&q);
1509                let desc_match = tool.description.to_lowercase().contains(&q);
1510                let tag_match = tool.tags.iter().any(|t| t.to_lowercase().contains(&q));
1511                if !name_match && !desc_match && !tag_match {
1512                    return false;
1513                }
1514            }
1515            true
1516        })
1517        .map(|(provider, tool)| {
1518            serde_json::json!({
1519                "name": tool.name,
1520                "description": tool.description,
1521                "provider": provider.name,
1522                "method": format!("{:?}", tool.method),
1523                "tags": tool.tags,
1524                "skills": provider.skills,
1525                "input_schema": tool.input_schema,
1526            })
1527        })
1528        .collect();
1529
1530    (StatusCode::OK, Json(Value::Array(tools)))
1531}
1532
1533/// GET /tools/:name — get detailed info about a specific tool.
1534async fn handle_tool_info(
1535    State(state): State<Arc<ProxyState>>,
1536    claims: Option<Extension<TokenClaims>>,
1537    axum::extract::Path(name): axum::extract::Path<String>,
1538) -> impl IntoResponse {
1539    tracing::debug!(tool = %name, "GET /tools/:name");
1540
1541    let claims = claims.map(|Extension(claims)| claims);
1542    let scopes = scopes_for_request(claims.as_ref(), &state);
1543
1544    match state
1545        .registry
1546        .get_tool(&name)
1547        .filter(|(_, tool)| match &tool.scope {
1548            Some(scope) => scopes.is_allowed(scope),
1549            None => true,
1550        }) {
1551        Some((provider, tool)) => {
1552            // Merge skills from manifest + SkillRegistry (tool binding + provider binding)
1553            let mut skills: Vec<String> = provider.skills.clone();
1554            for s in state.skill_registry.skills_for_tool(&tool.name) {
1555                if !skills.contains(&s.name) {
1556                    skills.push(s.name.clone());
1557                }
1558            }
1559            for s in state.skill_registry.skills_for_provider(&provider.name) {
1560                if !skills.contains(&s.name) {
1561                    skills.push(s.name.clone());
1562                }
1563            }
1564
1565            (
1566                StatusCode::OK,
1567                Json(serde_json::json!({
1568                    "name": tool.name,
1569                    "description": tool.description,
1570                    "provider": provider.name,
1571                    "method": format!("{:?}", tool.method),
1572                    "endpoint": tool.endpoint,
1573                    "tags": tool.tags,
1574                    "hint": tool.hint,
1575                    "skills": skills,
1576                    "input_schema": tool.input_schema,
1577                    "scope": tool.scope,
1578                })),
1579            )
1580        }
1581        None => (
1582            StatusCode::NOT_FOUND,
1583            Json(serde_json::json!({"error": format!("Tool '{name}' not found")})),
1584        ),
1585    }
1586}
1587
1588// ---------------------------------------------------------------------------
1589// Skill endpoints
1590// ---------------------------------------------------------------------------
1591
1592async fn handle_skills_list(
1593    State(state): State<Arc<ProxyState>>,
1594    claims: Option<Extension<TokenClaims>>,
1595    axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
1596) -> impl IntoResponse {
1597    tracing::debug!(
1598        category = ?query.category,
1599        provider = ?query.provider,
1600        tool = ?query.tool,
1601        search = ?query.search,
1602        "GET /skills"
1603    );
1604
1605    let claims = claims.map(|Extension(claims)| claims);
1606    let scopes = scopes_for_request(claims.as_ref(), &state);
1607    let visible_names = visible_skill_names(&state, &scopes);
1608
1609    let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
1610        state
1611            .skill_registry
1612            .search(search_query)
1613            .into_iter()
1614            .filter(|skill| visible_names.contains(&skill.name))
1615            .collect()
1616    } else if let Some(cat) = &query.category {
1617        state
1618            .skill_registry
1619            .skills_for_category(cat)
1620            .into_iter()
1621            .filter(|skill| visible_names.contains(&skill.name))
1622            .collect()
1623    } else if let Some(prov) = &query.provider {
1624        state
1625            .skill_registry
1626            .skills_for_provider(prov)
1627            .into_iter()
1628            .filter(|skill| visible_names.contains(&skill.name))
1629            .collect()
1630    } else if let Some(t) = &query.tool {
1631        state
1632            .skill_registry
1633            .skills_for_tool(t)
1634            .into_iter()
1635            .filter(|skill| visible_names.contains(&skill.name))
1636            .collect()
1637    } else {
1638        state
1639            .skill_registry
1640            .list_skills()
1641            .iter()
1642            .filter(|skill| visible_names.contains(&skill.name))
1643            .collect()
1644    };
1645
1646    let json: Vec<Value> = skills
1647        .iter()
1648        .map(|s| {
1649            serde_json::json!({
1650                "name": s.name,
1651                "version": s.version,
1652                "description": s.description,
1653                "tools": s.tools,
1654                "providers": s.providers,
1655                "categories": s.categories,
1656                "hint": s.hint,
1657            })
1658        })
1659        .collect();
1660
1661    (StatusCode::OK, Json(Value::Array(json)))
1662}
1663
1664async fn handle_skill_detail(
1665    State(state): State<Arc<ProxyState>>,
1666    claims: Option<Extension<TokenClaims>>,
1667    axum::extract::Path(name): axum::extract::Path<String>,
1668    axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
1669) -> impl IntoResponse {
1670    tracing::debug!(%name, meta = ?query.meta, refs = ?query.refs, "GET /skills/:name");
1671
1672    let claims = claims.map(|Extension(claims)| claims);
1673    let scopes = scopes_for_request(claims.as_ref(), &state);
1674    let visible_names = visible_skill_names(&state, &scopes);
1675
1676    let skill_meta = match state
1677        .skill_registry
1678        .get_skill(&name)
1679        .filter(|skill| visible_names.contains(&skill.name))
1680    {
1681        Some(s) => s,
1682        None => {
1683            return (
1684                StatusCode::NOT_FOUND,
1685                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1686            );
1687        }
1688    };
1689
1690    if query.meta.unwrap_or(false) {
1691        return (
1692            StatusCode::OK,
1693            Json(serde_json::json!({
1694                "name": skill_meta.name,
1695                "version": skill_meta.version,
1696                "description": skill_meta.description,
1697                "author": skill_meta.author,
1698                "tools": skill_meta.tools,
1699                "providers": skill_meta.providers,
1700                "categories": skill_meta.categories,
1701                "keywords": skill_meta.keywords,
1702                "hint": skill_meta.hint,
1703                "depends_on": skill_meta.depends_on,
1704                "suggests": skill_meta.suggests,
1705                "license": skill_meta.license,
1706                "compatibility": skill_meta.compatibility,
1707                "allowed_tools": skill_meta.allowed_tools,
1708                "format": skill_meta.format,
1709            })),
1710        );
1711    }
1712
1713    let content = match state.skill_registry.read_content(&name) {
1714        Ok(c) => c,
1715        Err(e) => {
1716            return (
1717                StatusCode::INTERNAL_SERVER_ERROR,
1718                Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
1719            );
1720        }
1721    };
1722
1723    let mut response = serde_json::json!({
1724        "name": skill_meta.name,
1725        "version": skill_meta.version,
1726        "description": skill_meta.description,
1727        "content": content,
1728    });
1729
1730    if query.refs.unwrap_or(false) {
1731        if let Ok(refs) = state.skill_registry.list_references(&name) {
1732            response["references"] = serde_json::json!(refs);
1733        }
1734    }
1735
1736    (StatusCode::OK, Json(response))
1737}
1738
1739/// GET /skills/:name/bundle — return all files in a skill directory.
1740/// Response: `{"name": "...", "files": {"SKILL.md": "...", "scripts/generate.sh": "...", ...}}`
1741/// Binary files are base64-encoded; text files are returned as-is.
1742async fn handle_skill_bundle(
1743    State(state): State<Arc<ProxyState>>,
1744    claims: Option<Extension<TokenClaims>>,
1745    axum::extract::Path(name): axum::extract::Path<String>,
1746) -> impl IntoResponse {
1747    tracing::debug!(skill = %name, "GET /skills/:name/bundle");
1748
1749    let claims = claims.map(|Extension(claims)| claims);
1750    let scopes = scopes_for_request(claims.as_ref(), &state);
1751    let visible_names = visible_skill_names(&state, &scopes);
1752    if !visible_names.contains(&name) {
1753        return (
1754            StatusCode::NOT_FOUND,
1755            Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1756        );
1757    }
1758
1759    let files = match state.skill_registry.bundle_files(&name) {
1760        Ok(f) => f,
1761        Err(_) => {
1762            return (
1763                StatusCode::NOT_FOUND,
1764                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1765            );
1766        }
1767    };
1768
1769    // Convert bytes to strings (UTF-8 text) or base64 for binary files
1770    let mut file_map = serde_json::Map::new();
1771    for (path, data) in &files {
1772        match std::str::from_utf8(data) {
1773            Ok(text) => {
1774                file_map.insert(path.clone(), Value::String(text.to_string()));
1775            }
1776            Err(_) => {
1777                // Binary file — base64 encode
1778                use base64::Engine;
1779                let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1780                file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1781            }
1782        }
1783    }
1784
1785    (
1786        StatusCode::OK,
1787        Json(serde_json::json!({
1788            "name": name,
1789            "files": file_map,
1790        })),
1791    )
1792}
1793
1794/// POST /skills/bundle — return all files for multiple skills in one response.
1795/// Request: `{"names": ["fal-generate", "compliance-screening"]}`
1796/// Response: `{"skills": {...}, "missing": [...]}`
1797async fn handle_skills_bundle_batch(
1798    State(state): State<Arc<ProxyState>>,
1799    claims: Option<Extension<TokenClaims>>,
1800    Json(req): Json<SkillBundleBatchRequest>,
1801) -> impl IntoResponse {
1802    const MAX_BATCH: usize = 50;
1803    if req.names.len() > MAX_BATCH {
1804        return (
1805            StatusCode::BAD_REQUEST,
1806            Json(
1807                serde_json::json!({"error": format!("batch size {} exceeds limit of {MAX_BATCH}", req.names.len())}),
1808            ),
1809        );
1810    }
1811
1812    tracing::debug!(names = ?req.names, "POST /skills/bundle");
1813
1814    let claims = claims.map(|Extension(claims)| claims);
1815    let scopes = scopes_for_request(claims.as_ref(), &state);
1816    let visible_names = visible_skill_names(&state, &scopes);
1817
1818    let mut result = serde_json::Map::new();
1819    let mut missing: Vec<String> = Vec::new();
1820
1821    for name in &req.names {
1822        if !visible_names.contains(name) {
1823            missing.push(name.clone());
1824            continue;
1825        }
1826        let files = match state.skill_registry.bundle_files(name) {
1827            Ok(f) => f,
1828            Err(_) => {
1829                missing.push(name.clone());
1830                continue;
1831            }
1832        };
1833
1834        let mut file_map = serde_json::Map::new();
1835        for (path, data) in &files {
1836            match std::str::from_utf8(data) {
1837                Ok(text) => {
1838                    file_map.insert(path.clone(), Value::String(text.to_string()));
1839                }
1840                Err(_) => {
1841                    use base64::Engine;
1842                    let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1843                    file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1844                }
1845            }
1846        }
1847
1848        result.insert(name.clone(), serde_json::json!({ "files": file_map }));
1849    }
1850
1851    (
1852        StatusCode::OK,
1853        Json(serde_json::json!({ "skills": result, "missing": missing })),
1854    )
1855}
1856
1857async fn handle_skills_resolve(
1858    State(state): State<Arc<ProxyState>>,
1859    claims: Option<Extension<TokenClaims>>,
1860    Json(req): Json<SkillResolveRequest>,
1861) -> impl IntoResponse {
1862    tracing::debug!(scopes = ?req.scopes, include_content = req.include_content, "POST /skills/resolve");
1863
1864    let include_content = req.include_content;
1865    let request_scopes = ScopeConfig {
1866        scopes: req.scopes,
1867        sub: String::new(),
1868        expires_at: 0,
1869        rate_config: None,
1870    };
1871    let claims = claims.map(|Extension(claims)| claims);
1872    let caller_scopes = scopes_for_request(claims.as_ref(), &state);
1873    let visible_names = visible_skill_names(&state, &caller_scopes);
1874
1875    let resolved: Vec<&skill::SkillMeta> =
1876        skill::resolve_skills(&state.skill_registry, &state.registry, &request_scopes)
1877            .into_iter()
1878            .filter(|skill| visible_names.contains(&skill.name))
1879            .collect();
1880
1881    let json: Vec<Value> = resolved
1882        .iter()
1883        .map(|s| {
1884            let mut entry = serde_json::json!({
1885                "name": s.name,
1886                "version": s.version,
1887                "description": s.description,
1888                "tools": s.tools,
1889                "providers": s.providers,
1890                "categories": s.categories,
1891            });
1892            if include_content {
1893                if let Ok(content) = state.skill_registry.read_content(&s.name) {
1894                    entry["content"] = Value::String(content);
1895                }
1896            }
1897            entry
1898        })
1899        .collect();
1900
1901    (StatusCode::OK, Json(Value::Array(json)))
1902}
1903
1904fn skillati_client(keyring: &Keyring) -> Result<SkillAtiClient, SkillAtiError> {
1905    match SkillAtiClient::from_env(keyring)? {
1906        Some(client) => Ok(client),
1907        None => Err(SkillAtiError::NotConfigured),
1908    }
1909}
1910
1911async fn handle_skillati_catalog(
1912    State(state): State<Arc<ProxyState>>,
1913    claims: Option<Extension<TokenClaims>>,
1914    Query(query): Query<SkillAtiCatalogQuery>,
1915) -> impl IntoResponse {
1916    tracing::debug!(search = ?query.search, "GET /skillati/catalog");
1917
1918    let client = match skillati_client(&state.keyring) {
1919        Ok(client) => client,
1920        Err(err) => return skillati_error_response(err),
1921    };
1922
1923    let claims = claims.map(|Extension(c)| c);
1924    let scopes = scopes_for_request(claims.as_ref(), &state);
1925
1926    match client.catalog().await {
1927        Ok(catalog) => {
1928            // Union of local + remote visibility. Merging here (instead of
1929            // calling visible_skill_names_with_remote, which would re-fetch)
1930            // avoids a redundant catalog request on the hot path.
1931            let mut visible_names = visible_skill_names(&state, &scopes);
1932            visible_names.extend(visible_remote_skill_names(&state, &scopes, &catalog));
1933
1934            let mut skills: Vec<_> = catalog
1935                .into_iter()
1936                .filter(|s| visible_names.contains(&s.name))
1937                .collect();
1938            if let Some(search) = query.search.as_deref() {
1939                skills = SkillAtiClient::filter_catalog(&skills, search, 25);
1940            }
1941            (
1942                StatusCode::OK,
1943                Json(serde_json::json!({
1944                    "skills": skills,
1945                })),
1946            )
1947        }
1948        Err(err) => skillati_error_response(err),
1949    }
1950}
1951
1952async fn handle_skillati_read(
1953    State(state): State<Arc<ProxyState>>,
1954    claims: Option<Extension<TokenClaims>>,
1955    axum::extract::Path(name): axum::extract::Path<String>,
1956) -> impl IntoResponse {
1957    tracing::debug!(%name, "GET /skillati/:name");
1958
1959    let client = match skillati_client(&state.keyring) {
1960        Ok(client) => client,
1961        Err(err) => return skillati_error_response(err),
1962    };
1963
1964    let claims = claims.map(|Extension(c)| c);
1965    let scopes = scopes_for_request(claims.as_ref(), &state);
1966    let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
1967        Ok(v) => v,
1968        Err(err) => return skillati_error_response(err),
1969    };
1970    if !visible_names.contains(&name) {
1971        return skillati_error_response(SkillAtiError::SkillNotFound(name));
1972    }
1973
1974    match client.read_skill(&name).await {
1975        Ok(activation) => (StatusCode::OK, Json(serde_json::json!(activation))),
1976        Err(err) => skillati_error_response(err),
1977    }
1978}
1979
1980async fn handle_skillati_resources(
1981    State(state): State<Arc<ProxyState>>,
1982    claims: Option<Extension<TokenClaims>>,
1983    axum::extract::Path(name): axum::extract::Path<String>,
1984    Query(query): Query<SkillAtiResourcesQuery>,
1985) -> impl IntoResponse {
1986    tracing::debug!(%name, prefix = ?query.prefix, "GET /skillati/:name/resources");
1987
1988    let client = match skillati_client(&state.keyring) {
1989        Ok(client) => client,
1990        Err(err) => return skillati_error_response(err),
1991    };
1992
1993    let claims = claims.map(|Extension(c)| c);
1994    let scopes = scopes_for_request(claims.as_ref(), &state);
1995    let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
1996        Ok(v) => v,
1997        Err(err) => return skillati_error_response(err),
1998    };
1999    if !visible_names.contains(&name) {
2000        return skillati_error_response(SkillAtiError::SkillNotFound(name));
2001    }
2002
2003    match client.list_resources(&name, query.prefix.as_deref()).await {
2004        Ok(resources) => (
2005            StatusCode::OK,
2006            Json(serde_json::json!({
2007                "name": name,
2008                "prefix": query.prefix,
2009                "resources": resources,
2010            })),
2011        ),
2012        Err(err) => skillati_error_response(err),
2013    }
2014}
2015
2016async fn handle_skillati_file(
2017    State(state): State<Arc<ProxyState>>,
2018    claims: Option<Extension<TokenClaims>>,
2019    axum::extract::Path(name): axum::extract::Path<String>,
2020    Query(query): Query<SkillAtiFileQuery>,
2021) -> impl IntoResponse {
2022    tracing::debug!(%name, path = %query.path, "GET /skillati/:name/file");
2023
2024    let client = match skillati_client(&state.keyring) {
2025        Ok(client) => client,
2026        Err(err) => return skillati_error_response(err),
2027    };
2028
2029    let claims = claims.map(|Extension(c)| c);
2030    let scopes = scopes_for_request(claims.as_ref(), &state);
2031    let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
2032        Ok(v) => v,
2033        Err(err) => return skillati_error_response(err),
2034    };
2035    if !visible_names.contains(&name) {
2036        return skillati_error_response(SkillAtiError::SkillNotFound(name));
2037    }
2038
2039    match client.read_path(&name, &query.path).await {
2040        Ok(file) => (StatusCode::OK, Json(serde_json::json!(file))),
2041        Err(err) => skillati_error_response(err),
2042    }
2043}
2044
2045async fn handle_skillati_refs(
2046    State(state): State<Arc<ProxyState>>,
2047    claims: Option<Extension<TokenClaims>>,
2048    axum::extract::Path(name): axum::extract::Path<String>,
2049) -> impl IntoResponse {
2050    tracing::debug!(%name, "GET /skillati/:name/refs");
2051
2052    let client = match skillati_client(&state.keyring) {
2053        Ok(client) => client,
2054        Err(err) => return skillati_error_response(err),
2055    };
2056
2057    let claims = claims.map(|Extension(c)| c);
2058    let scopes = scopes_for_request(claims.as_ref(), &state);
2059    let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
2060        Ok(v) => v,
2061        Err(err) => return skillati_error_response(err),
2062    };
2063    if !visible_names.contains(&name) {
2064        return skillati_error_response(SkillAtiError::SkillNotFound(name));
2065    }
2066
2067    match client.list_references(&name).await {
2068        Ok(references) => (
2069            StatusCode::OK,
2070            Json(serde_json::json!({
2071                "name": name,
2072                "references": references,
2073            })),
2074        ),
2075        Err(err) => skillati_error_response(err),
2076    }
2077}
2078
2079async fn handle_skillati_ref(
2080    State(state): State<Arc<ProxyState>>,
2081    claims: Option<Extension<TokenClaims>>,
2082    axum::extract::Path((name, reference)): axum::extract::Path<(String, String)>,
2083) -> impl IntoResponse {
2084    tracing::debug!(%name, %reference, "GET /skillati/:name/ref/:reference");
2085
2086    let client = match skillati_client(&state.keyring) {
2087        Ok(client) => client,
2088        Err(err) => return skillati_error_response(err),
2089    };
2090
2091    let claims = claims.map(|Extension(c)| c);
2092    let scopes = scopes_for_request(claims.as_ref(), &state);
2093    let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
2094        Ok(v) => v,
2095        Err(err) => return skillati_error_response(err),
2096    };
2097    if !visible_names.contains(&name) {
2098        return skillati_error_response(SkillAtiError::SkillNotFound(name));
2099    }
2100
2101    match client.read_reference(&name, &reference).await {
2102        Ok(content) => (
2103            StatusCode::OK,
2104            Json(serde_json::json!({
2105                "name": name,
2106                "reference": reference,
2107                "content": content,
2108            })),
2109        ),
2110        Err(err) => skillati_error_response(err),
2111    }
2112}
2113
2114fn skillati_error_response(err: SkillAtiError) -> (StatusCode, Json<Value>) {
2115    let status = match &err {
2116        SkillAtiError::NotConfigured
2117        | SkillAtiError::UnsupportedRegistry(_)
2118        | SkillAtiError::MissingCredentials(_)
2119        | SkillAtiError::ProxyUrlRequired => StatusCode::SERVICE_UNAVAILABLE,
2120        SkillAtiError::SkillNotFound(_) | SkillAtiError::PathNotFound { .. } => {
2121            StatusCode::NOT_FOUND
2122        }
2123        SkillAtiError::InvalidPath(_) => StatusCode::BAD_REQUEST,
2124        SkillAtiError::Gcs(_)
2125        | SkillAtiError::ProxyRequest(_)
2126        | SkillAtiError::ProxyResponse(_) => StatusCode::BAD_GATEWAY,
2127    };
2128
2129    (
2130        status,
2131        Json(serde_json::json!({
2132            "error": err.to_string(),
2133        })),
2134    )
2135}
2136
2137// --- Auth middleware ---
2138
2139/// JWT authentication middleware.
2140///
2141/// - /health and /.well-known/jwks.json → skip auth
2142/// - JWT configured → validate Bearer token, attach claims to request extensions
2143/// - No JWT configured → allow all (dev mode)
2144async fn auth_middleware(
2145    State(state): State<Arc<ProxyState>>,
2146    mut req: HttpRequest<Body>,
2147    next: Next,
2148) -> Result<Response, StatusCode> {
2149    let path = req.uri().path();
2150
2151    // Skip auth for public endpoints
2152    if path == "/health" || path == "/.well-known/jwks.json" {
2153        return Ok(next.run(req).await);
2154    }
2155
2156    // If no JWT configured, allow all (dev mode)
2157    let jwt_config = match &state.jwt_config {
2158        Some(c) => c,
2159        None => return Ok(next.run(req).await),
2160    };
2161
2162    // Extract Authorization: Bearer <token>. We own the String so the
2163    // `req.headers()` immutable borrow ends before we touch
2164    // `req.extensions_mut()` below — needed because we now stash the raw
2165    // token in the BearerToken extension (issue #115).
2166    let token_owned: String = match req
2167        .headers()
2168        .get("authorization")
2169        .and_then(|v| v.to_str().ok())
2170    {
2171        Some(header) if header.starts_with("Bearer ") => header[7..].to_string(),
2172        _ => return Err(StatusCode::UNAUTHORIZED),
2173    };
2174
2175    // Validate JWT
2176    match jwt::validate(&token_owned, jwt_config) {
2177        Ok(claims) => {
2178            tracing::debug!(sub = %claims.sub, scopes = %claims.scope, "JWT validated");
2179            // Stash the *raw* bearer alongside the parsed claims so handlers
2180            // can forward it to auth_generator scripts via `${JWT_TOKEN}`.
2181            // We deliberately store this AFTER `validate()` succeeds — an
2182            // invalid/expired token never lands in extensions and so can
2183            // never reach a generator. See issue #115.
2184            req.extensions_mut().insert(BearerToken(token_owned));
2185            req.extensions_mut().insert(claims);
2186            Ok(next.run(req).await)
2187        }
2188        Err(e) => {
2189            tracing::debug!(error = %e, "JWT validation failed");
2190            Err(StatusCode::UNAUTHORIZED)
2191        }
2192    }
2193}
2194
/// Per-request extension carrying the *raw* inbound bearer (the `<token>`
/// from `Authorization: Bearer <token>`). Inserted by `auth_middleware` only
/// after `jwt::validate` succeeds, so an expired/forged token never leaks
/// into this extension. Consumed by handlers building a `GenContext` so an
/// auth_generator can forward the calling sandbox's identity to an upstream
/// MCP via `${JWT_TOKEN}` (issue #115).
///
/// `Debug` is implemented by hand and redacts the token: the value is a live
/// credential, and a derived `Debug` would print it verbatim wherever the
/// extension (or a struct holding it) is logged with `{:?}`. Use `.0` when
/// the raw token is genuinely needed.
#[derive(Clone)]
pub struct BearerToken(pub String);

impl std::fmt::Debug for BearerToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("BearerToken").field(&"<redacted>").finish()
    }
}
2203
2204// --- Router builder ---
2205
2206/// Build the axum Router from a pre-constructed ProxyState.
2207/// Outer body-size ceiling for `POST /call`. Large enough to carry the worst
2208/// case `file_manager:upload` payload (`MAX_UPLOAD_BYTES` of raw bytes,
2209/// base64-inflated ~4/3×, plus a few KB of JSON framing).
2210///
2211/// Per-tool limits (`max_bytes`, `MAX_UPLOAD_BYTES`) plus JWT scopes + rate
2212/// limits are the real gates — this is just the outermost wrapper check.
2213fn max_call_body_bytes() -> usize {
2214    (crate::core::file_manager::MAX_UPLOAD_BYTES as usize)
2215        .saturating_mul(4)
2216        .saturating_div(3)
2217        .saturating_add(8 * 1024)
2218}
2219
2220pub fn build_router(state: Arc<ProxyState>) -> Router {
2221    use axum::extract::DefaultBodyLimit;
2222
2223    Router::new()
2224        .route("/call", post(handle_call))
2225        .route("/help", post(handle_help))
2226        .route("/mcp", post(handle_mcp))
2227        .route("/tools", get(handle_tools_list))
2228        .route("/tools/{name}", get(handle_tool_info))
2229        .route("/skills", get(handle_skills_list))
2230        .route("/skills/resolve", post(handle_skills_resolve))
2231        .route("/skills/bundle", post(handle_skills_bundle_batch))
2232        .route("/skills/{name}", get(handle_skill_detail))
2233        .route("/skills/{name}/bundle", get(handle_skill_bundle))
2234        .route("/skillati/catalog", get(handle_skillati_catalog))
2235        .route("/skillati/{name}", get(handle_skillati_read))
2236        .route("/skillati/{name}/resources", get(handle_skillati_resources))
2237        .route("/skillati/{name}/file", get(handle_skillati_file))
2238        .route("/skillati/{name}/refs", get(handle_skillati_refs))
2239        .route("/skillati/{name}/ref/{reference}", get(handle_skillati_ref))
2240        .route("/health", get(handle_health))
2241        .route("/.well-known/jwks.json", get(handle_jwks))
2242        // Raise axum's default 2 MB body-extractor limit so request bodies
2243        // carrying base64-encoded upload payloads aren't rejected before the
2244        // handler runs. `handle_call` still enforces its own
2245        // `max_call_body_bytes()` cap when streaming the body to bytes.
2246        .layer(DefaultBodyLimit::max(max_call_body_bytes()))
2247        .layer(middleware::from_fn_with_state(
2248            state.clone(),
2249            auth_middleware,
2250        ))
2251        .with_state(state)
2252}
2253
2254// --- Server startup ---
2255
2256/// Start the proxy server.
2257pub async fn run(
2258    port: u16,
2259    bind_addr: Option<String>,
2260    ati_dir: PathBuf,
2261    _verbose: bool,
2262    env_keys: bool,
2263) -> Result<(), Box<dyn std::error::Error>> {
2264    // Load manifests
2265    let manifests_dir = ati_dir.join("manifests");
2266    let mut registry = ManifestRegistry::load(&manifests_dir)?;
2267    let provider_count = registry.list_providers().len();
2268
2269    // Load keyring
2270    let keyring_source;
2271    let keyring = if env_keys {
2272        // --env-keys: scan ATI_KEY_* environment variables
2273        let kr = Keyring::from_env();
2274        let key_names = kr.key_names();
2275        tracing::info!(
2276            count = key_names.len(),
2277            "loaded API keys from ATI_KEY_* env vars"
2278        );
2279        for name in &key_names {
2280            tracing::debug!(key = %name, "env key loaded");
2281        }
2282        keyring_source = "env-vars (ATI_KEY_*)";
2283        kr
2284    } else {
2285        // Cascade: keyring.enc (sealed) → keyring.enc (persistent) → credentials → empty
2286        let keyring_path = ati_dir.join("keyring.enc");
2287        if keyring_path.exists() {
2288            if let Ok(kr) = Keyring::load(&keyring_path) {
2289                keyring_source = "keyring.enc (sealed key)";
2290                kr
2291            } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
2292                keyring_source = "keyring.enc (persistent key)";
2293                kr
2294            } else {
2295                tracing::warn!("keyring.enc exists but could not be decrypted");
2296                keyring_source = "empty (decryption failed)";
2297                Keyring::empty()
2298            }
2299        } else {
2300            let creds_path = ati_dir.join("credentials");
2301            if creds_path.exists() {
2302                match Keyring::load_credentials(&creds_path) {
2303                    Ok(kr) => {
2304                        keyring_source = "credentials (plaintext)";
2305                        kr
2306                    }
2307                    Err(e) => {
2308                        tracing::warn!(error = %e, "failed to load credentials");
2309                        keyring_source = "empty (credentials error)";
2310                        Keyring::empty()
2311                    }
2312                }
2313            } else {
2314                tracing::warn!("no keyring.enc or credentials found — running without API keys");
2315                tracing::warn!("tools requiring authentication will fail");
2316                keyring_source = "empty (no auth)";
2317                Keyring::empty()
2318            }
2319        }
2320    };
2321
2322    // Discover MCP tools at startup so they appear in GET /tools.
2323    // Runs concurrently across providers with 30s per-provider timeout.
2324    mcp_client::discover_all_mcp_tools(&mut registry, &keyring).await;
2325
2326    let tool_count = registry.list_public_tools().len();
2327
2328    // Log MCP and OpenAPI providers
2329    let mcp_providers: Vec<(String, String)> = registry
2330        .list_mcp_providers()
2331        .iter()
2332        .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
2333        .collect();
2334    let mcp_count = mcp_providers.len();
2335    let openapi_providers: Vec<String> = registry
2336        .list_openapi_providers()
2337        .iter()
2338        .map(|p| p.name.clone())
2339        .collect();
2340    let openapi_count = openapi_providers.len();
2341
2342    // Load installed/local skill registry only.
2343    let skills_dir = ati_dir.join("skills");
2344    let skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
2345        tracing::warn!(error = %e, "failed to load skills");
2346        SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
2347    });
2348
2349    if let Ok(registry_url) = std::env::var("ATI_SKILL_REGISTRY") {
2350        if registry_url.strip_prefix("gcs://").is_some() {
2351            tracing::info!(
2352                registry = %registry_url,
2353                "SkillATI remote registry configured for lazy reads"
2354            );
2355        } else {
2356            tracing::warn!(url = %registry_url, "SkillATI only supports gcs:// registries");
2357        }
2358    }
2359
2360    let skill_count = skill_registry.skill_count();
2361
2362    // Load JWT config from environment
2363    let jwt_config = match jwt::config_from_env() {
2364        Ok(config) => config,
2365        Err(e) => {
2366            tracing::warn!(error = %e, "JWT config error");
2367            None
2368        }
2369    };
2370
2371    let auth_status = if jwt_config.is_some() {
2372        "JWT enabled"
2373    } else {
2374        "DISABLED (no JWT keys configured)"
2375    };
2376
2377    // Build JWKS for the endpoint
2378    let jwks_json = jwt_config.as_ref().and_then(|config| {
2379        config
2380            .public_key_pem
2381            .as_ref()
2382            .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
2383    });
2384
2385    let state = Arc::new(ProxyState {
2386        registry,
2387        skill_registry,
2388        keyring,
2389        jwt_config,
2390        jwks_json,
2391        auth_cache: AuthCache::new(),
2392        upstream_url_allowlists: std::sync::Arc::new(std::sync::Mutex::new(
2393            std::collections::HashMap::new(),
2394        )),
2395    });
2396
2397    let app = build_router(state);
2398
2399    let addr: SocketAddr = if let Some(ref bind) = bind_addr {
2400        format!("{bind}:{port}").parse()?
2401    } else {
2402        SocketAddr::from(([127, 0, 0, 1], port))
2403    };
2404
2405    tracing::info!(
2406        version = env!("CARGO_PKG_VERSION"),
2407        %addr,
2408        auth = auth_status,
2409        ati_dir = %ati_dir.display(),
2410        tools = tool_count,
2411        providers = provider_count,
2412        mcp = mcp_count,
2413        openapi = openapi_count,
2414        skills = skill_count,
2415        keyring = keyring_source,
2416        "ATI proxy server starting"
2417    );
2418    for (name, transport) in &mcp_providers {
2419        tracing::info!(provider = %name, transport = %transport, "MCP provider");
2420    }
2421    for name in &openapi_providers {
2422        tracing::info!(provider = %name, "OpenAPI provider");
2423    }
2424
2425    let listener = tokio::net::TcpListener::bind(addr).await?;
2426    axum::serve(listener, app).await?;
2427
2428    Ok(())
2429}
2430
2431/// Dispatch a `file_manager:*` tool call. Returns either a JSON payload or an
2432/// (HTTP status, message) error for the caller to forward.
2433async fn dispatch_file_manager(
2434    tool_name: &str,
2435    args: &HashMap<String, Value>,
2436    provider: &Provider,
2437    keyring: &Keyring,
2438) -> Result<Value, (StatusCode, String)> {
2439    use crate::core::file_manager::{self, DownloadArgs, FileManagerError, UploadArgs};
2440
2441    // One mapping, derived from FileManagerError::http_status, so adding an
2442    // error variant can't silently regress one handler while the other updates.
2443    let to_resp = |e: FileManagerError| {
2444        let status =
2445            StatusCode::from_u16(e.http_status()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
2446        (status, e.to_string())
2447    };
2448
2449    match tool_name {
2450        "file_manager:download" => {
2451            let parsed = DownloadArgs::from_value(args).map_err(to_resp)?;
2452            let result = file_manager::fetch_bytes(&parsed).await.map_err(to_resp)?;
2453            Ok(file_manager::build_download_response(&result))
2454        }
2455        "file_manager:upload" => {
2456            let parsed = UploadArgs::from_wire(args).map_err(to_resp)?;
2457            file_manager::upload_to_destination(
2458                parsed,
2459                &provider.upload_destinations,
2460                provider.upload_default_destination.as_deref(),
2461                keyring,
2462            )
2463            .await
2464            .map_err(to_resp)
2465        }
2466        other => Err((
2467            StatusCode::NOT_FOUND,
2468            format!("Unknown file_manager tool: '{other}'"),
2469        )),
2470    }
2471}
2472
2473fn write_proxy_audit(
2474    call_req: &CallRequest,
2475    agent_sub: &str,
2476    claims: Option<&TokenClaims>,
2477    duration: std::time::Duration,
2478    error: Option<&str>,
2479) {
2480    let entry = crate::core::audit::AuditEntry {
2481        ts: chrono::Utc::now().to_rfc3339(),
2482        tool: call_req.tool_name.clone(),
2483        args: crate::core::audit::sanitize_args(&call_req.args),
2484        status: if error.is_some() {
2485            crate::core::audit::AuditStatus::Error
2486        } else {
2487            crate::core::audit::AuditStatus::Ok
2488        },
2489        duration_ms: duration.as_millis() as u64,
2490        agent_sub: agent_sub.to_string(),
2491        job_id: claims.and_then(|c| c.job_id.clone()),
2492        sandbox_id: claims.and_then(|c| c.sandbox_id.clone()),
2493        error: error.map(|s| s.to_string()),
2494        exit_code: None,
2495    };
2496    let _ = crate::core::audit::append(&entry);
2497}
2498
2499// --- Helpers ---
2500
/// System prompt template for the general (unscoped) help assistant.
///
/// The text contains two literal placeholders, `{tools}` and
/// `{skills_section}`. This is a plain `const &str` (a raw string), not a
/// `format!` literal, so the braces are not interpreted here. The caller must
/// substitute them itself. NOTE(review): presumably done with `str::replace`
/// at the call site; confirm there.
///
/// The prompt asks for short, colleague-style answers with `ati run` examples,
/// and instructs the model to recommend only tools from the injected list.
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, tell the agent to load them using the Skill tool (e.g., `skill: "research-financial-data"`)

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
2516
2517async fn build_remote_skillati_section(keyring: &Keyring, query: &str, limit: usize) -> String {
2518    let client = match SkillAtiClient::from_env(keyring) {
2519        Ok(Some(client)) => client,
2520        Ok(None) => return String::new(),
2521        Err(err) => {
2522            tracing::warn!(error = %err, "failed to initialize SkillATI catalog for proxy help");
2523            return String::new();
2524        }
2525    };
2526
2527    let catalog = match client.catalog().await {
2528        Ok(catalog) => catalog,
2529        Err(err) => {
2530            tracing::warn!(error = %err, "failed to load SkillATI catalog for proxy help");
2531            return String::new();
2532        }
2533    };
2534
2535    let matched = SkillAtiClient::filter_catalog(&catalog, query, limit);
2536    if matched.is_empty() {
2537        return String::new();
2538    }
2539
2540    render_remote_skillati_section(&matched, catalog.len())
2541}
2542
2543fn render_remote_skillati_section(skills: &[RemoteSkillMeta], total_catalog: usize) -> String {
2544    let mut section = String::from("## Remote Skills Available Via SkillATI\n\n");
2545    section.push_str(
2546        "These skills are available. Load them using the Skill tool (e.g., `skill: \"skill-name\"`).\n\n",
2547    );
2548
2549    for skill in skills {
2550        section.push_str(&format!("- **{}**: {}\n", skill.name, skill.description));
2551    }
2552
2553    if total_catalog > skills.len() {
2554        section.push_str(&format!(
2555            "\nOnly the most relevant {} remote skills are shown here.\n",
2556            skills.len()
2557        ));
2558    }
2559
2560    section
2561}
2562
/// Join help-prompt skill sections into one block.
///
/// Each section is trimmed of surrounding whitespace, and sections that are
/// empty after trimming are dropped. The rest are concatenated in order,
/// separated by a single blank line (`"\n\n"`). With no non-empty sections
/// the result is `""`.
fn merge_help_skill_sections(sections: &[String]) -> String {
    let mut merged = String::new();
    for part in sections.iter().map(|s| s.trim()).filter(|s| !s.is_empty()) {
        // Only insert the separator between sections, never before the first.
        if !merged.is_empty() {
            merged.push_str("\n\n");
        }
        merged.push_str(part);
    }
    merged
}
2577
2578fn build_tool_context(
2579    tools: &[(
2580        &crate::core::manifest::Provider,
2581        &crate::core::manifest::Tool,
2582    )],
2583) -> String {
2584    let mut summaries = Vec::new();
2585    for (provider, tool) in tools {
2586        let mut summary = if let Some(cat) = &provider.category {
2587            format!(
2588                "- **{}** (provider: {}, category: {}): {}",
2589                tool.name, provider.name, cat, tool.description
2590            )
2591        } else {
2592            format!(
2593                "- **{}** (provider: {}): {}",
2594                tool.name, provider.name, tool.description
2595            )
2596        };
2597        if !tool.tags.is_empty() {
2598            summary.push_str(&format!("\n  Tags: {}", tool.tags.join(", ")));
2599        }
2600        // CLI tools: show passthrough usage
2601        if provider.is_cli() && tool.input_schema.is_none() {
2602            let cmd = provider.cli_command.as_deref().unwrap_or("?");
2603            summary.push_str(&format!(
2604                "\n  Usage: `ati run {} -- <args>`  (passthrough to `{}`)",
2605                tool.name, cmd
2606            ));
2607        } else if let Some(schema) = &tool.input_schema {
2608            if let Some(props) = schema.get("properties") {
2609                if let Some(obj) = props.as_object() {
2610                    let params: Vec<String> = obj
2611                        .iter()
2612                        .filter(|(_, v)| {
2613                            v.get("x-ati-param-location").is_none()
2614                                || v.get("description").is_some()
2615                        })
2616                        .map(|(k, v)| {
2617                            let type_str =
2618                                v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
2619                            let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
2620                            format!("    --{k} ({type_str}): {desc}")
2621                        })
2622                        .collect();
2623                    if !params.is_empty() {
2624                        summary.push_str("\n  Parameters:\n");
2625                        summary.push_str(&params.join("\n"));
2626                    }
2627                }
2628            }
2629        }
2630        summaries.push(summary);
2631    }
2632    summaries.join("\n\n")
2633}
2634
2635/// Build a scoped system prompt for a specific tool or provider.
2636///
2637/// Returns None if the scope_name doesn't match any tool or provider.
2638fn build_scoped_prompt(
2639    scope_name: &str,
2640    visible_tools: &[(&Provider, &Tool)],
2641    skills_section: &str,
2642) -> Option<String> {
2643    // Check if scope_name is a tool
2644    if let Some((provider, tool)) = visible_tools
2645        .iter()
2646        .find(|(_, tool)| tool.name == scope_name)
2647    {
2648        let mut details = format!(
2649            "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
2650            tool.name, provider.name, provider.handler, tool.description
2651        );
2652        if let Some(cat) = &provider.category {
2653            details.push_str(&format!("**Category**: {}\n", cat));
2654        }
2655        if provider.is_cli() {
2656            let cmd = provider.cli_command.as_deref().unwrap_or("?");
2657            details.push_str(&format!(
2658                "\n**Usage**: `ati run {} -- <args>`  (passthrough to `{}`)\n",
2659                tool.name, cmd
2660            ));
2661        } else if let Some(schema) = &tool.input_schema {
2662            if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
2663                let required: Vec<String> = schema
2664                    .get("required")
2665                    .and_then(|r| r.as_array())
2666                    .map(|arr| {
2667                        arr.iter()
2668                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
2669                            .collect()
2670                    })
2671                    .unwrap_or_default();
2672                details.push_str("\n**Parameters**:\n");
2673                for (key, val) in props {
2674                    let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
2675                    let desc = val
2676                        .get("description")
2677                        .and_then(|d| d.as_str())
2678                        .unwrap_or("");
2679                    let req = if required.contains(key) {
2680                        " **(required)**"
2681                    } else {
2682                        ""
2683                    };
2684                    details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
2685                }
2686            }
2687        }
2688
2689        let prompt = format!(
2690            "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
2691            ## Tool Details\n{}\n\n{}\n\n\
2692            Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
2693            tool.name, details, skills_section
2694        );
2695        return Some(prompt);
2696    }
2697
2698    // Check if scope_name is a provider
2699    let tools: Vec<(&Provider, &Tool)> = visible_tools
2700        .iter()
2701        .copied()
2702        .filter(|(provider, _)| provider.name == scope_name)
2703        .collect();
2704    if !tools.is_empty() {
2705        let tools_context = build_tool_context(&tools);
2706        let prompt = format!(
2707            "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
2708            ## Tools in provider `{}`\n{}\n\n{}\n\n\
2709            Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
2710            scope_name, scope_name, tools_context, skills_section
2711        );
2712        return Some(prompt);
2713    }
2714
2715    None
2716}