Skip to main content

ati/proxy/
server.rs

1/// ATI proxy server — holds API keys and executes tool calls on behalf of sandbox agents.
2///
3/// Authentication: ES256-signed JWT (or HS256 fallback). The JWT carries identity,
4/// scopes, and expiry. No more static tokens or unsigned scope lists.
5///
6/// Usage: `ati proxy --port 8080 [--ati-dir ~/.ati]`
7use axum::{
8    body::Body,
9    extract::{Extension, Query, State},
10    http::{Request as HttpRequest, StatusCode},
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13    routing::{get, post},
14    Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::{ManifestRegistry, Provider, Tool};
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::sentry_scope;
32use crate::core::skill::{self, SkillRegistry};
33use crate::core::skillati::{RemoteSkillMeta, SkillAtiClient, SkillAtiError};
34
35/// Shared state for the proxy server.
36pub struct ProxyState {
37    pub registry: ManifestRegistry,
38    pub skill_registry: SkillRegistry,
39    pub keyring: Keyring,
40    /// JWT validation config (None = auth disabled / dev mode).
41    pub jwt_config: Option<JwtConfig>,
42    /// Pre-computed JWKS JSON for the /.well-known/jwks.json endpoint.
43    pub jwks_json: Option<Value>,
44    /// Shared cache for dynamically generated auth credentials.
45    pub auth_cache: AuthCache,
46    /// Per-provider upstream-URL allowlists, compiled lazily on first
47    /// per-request validation and cached for the process lifetime. Keyed
48    /// by provider name. `None` value = keyring entry missing; we cache
49    /// negatives too to avoid repeating the lookup on every request.
50    /// Operator hot-reload of an allowlist requires proxy restart (same
51    /// constraint as every other keyring entry today). See issue #124.
52    ///
53    /// Entries are pre-parsed `UpstreamAllowEntry` structs (scheme + host
54    /// glob + canonical path) rather than raw URL glob patterns. Glob over
55    /// raw URL strings is unsafe — `*` can cross `.`, `#`, `?`, `:` and other
56    /// URL delimiters, letting a sandbox-crafted URL satisfy a pattern while
57    /// `reqwest` connects to a different host. Greptile P0/P1 on #124.
58    pub upstream_url_allowlists:
59        Arc<std::sync::Mutex<std::collections::HashMap<String, Option<Vec<UpstreamAllowEntry>>>>>,
60    /// Lazy-discovered MCP `inputSchema` cache for sandbox-supplied-URL
61    /// providers (issue #135).
62    ///
63    /// For providers with `handler = "mcp"` + `mcp_url_env` set, the URL
64    /// isn't proxy-resident so we cannot discover tools at boot; operators
65    /// must declare tool *names* statically in the manifest, but the
66    /// *schemas* are gaps unless they mirror them by hand. This cache
67    /// holds schemas fetched on demand via `tools/list` against the
68    /// sandbox-supplied upstream when an `ati tool info` / `tools/list`
69    /// request asks for one and the static entry has `input_schema: None`.
70    ///
71    /// Keyed by `(provider_name, upstream_url)` so different sandboxes
72    /// pointed at different upstreams each get their own cache entry.
73    /// Value is `Option<HashMap<tool_name, inputSchema>>`:
74    ///   - `Some(map)` — fetched successfully; serve `map[tool_name]`.
75    ///   - `None` — fetched and the dial / list failed; remember the
76    ///     negative result so we don't re-attempt on every subsequent
77    ///     request. Cache lives for the process lifetime (matching the
78    ///     allowlist cache; operator hot-reload requires restart).
79    ///
80    /// Lookup never blocks a response: a discovery miss falls back to
81    /// the static `Tool.input_schema` (which is `None` in the exact case
82    /// this cache exists to handle), preserving today's behavior.
83    pub lazy_schema_cache: LazySchemaCache,
84}
85
86/// Type alias for the issue #135 lazy-schema cache. Keyed by
87/// `(provider_name, upstream_url)`. The value is `Option<map>` so the
88/// negative result (dial failed / list failed) can be cached too —
89/// avoids retrying a known-bad upstream on every subsequent request.
90/// Named so clippy's `type_complexity` lint doesn't fire on the field.
91pub type LazySchemaCache = Arc<
92    std::sync::Mutex<std::collections::HashMap<(String, String), Option<HashMap<String, Value>>>>,
93>;
94
95/// One parsed allowlist entry. Allowlist CSV entries are operator-authored
96/// URL templates like `https://parcha-tools-*.grep.ai/mcp`. We parse them
97/// into (scheme, host_label_patterns, path) at load time so per-request
98/// validation can match each component separately:
99/// - **scheme**: exact string match (case-insensitive per RFC 3986)
100/// - **host_label_patterns**: per-DNS-label globs; the runtime URL's host
101///   is split on `.` and each label matched against the corresponding
102///   pattern entry. `*` can never bridge labels because labels are matched
103///   independently.
104/// - **path**: exact match against the URL's path; query/fragment ignored
105///
106/// This is the secure shape; raw URL globs let `*` cross `.`/`#`/`?`/`:`
107/// (Greptile P0/P1 on #124). Per-label matching closes those bypasses.
108#[derive(Debug, Clone)]
109pub struct UpstreamAllowEntry {
110    /// `"https"` or `"http"` — operator-declared scheme. Exact match required.
111    pub scheme: String,
112    /// One glob per DNS label, e.g. `parcha-tools-*.grep.ai` →
113    /// `["parcha-tools-*", "grep", "ai"]`. The runtime URL's lowercased
114    /// host is split on `.` and label-by-label matched against this Vec.
115    /// Length must match exactly (no `**`-style multi-label wildcards).
116    pub host_label_patterns: Vec<glob::Pattern>,
117    /// Exact path match required. Empty operator path becomes `"/"`.
118    pub path: String,
119}
120
121/// Outcome of `X-Ati-Upstream-Url` validation for a single request. The
122/// proxy handler turns `Reject` into an HTTP response and `Allow` into a
123/// param threaded down to `mcp_client::execute_with_gen`. See issue #124.
124#[derive(Debug)]
125enum UpstreamOverride {
126    /// No header sent, no override needed. The MCP client will use
127    /// `provider.mcp_url` as today.
128    None,
129    /// Header sent, validated against the operator allowlist. The MCP
130    /// client should dial this URL instead of `provider.mcp_url`.
131    Allow(String),
132    /// Header sent, but the request must be rejected. Carries the HTTP
133    /// status and a human-readable message for the response body.
134    Reject(StatusCode, String),
135}
136
137/// Resolve the optional sandbox-supplied upstream URL for a single
138/// `/call` request. Six cases per the issue #124 spec:
139///   1. Header absent + `mcp_url_env` absent → `None`.
140///   2. Header present + `mcp_url_env` absent → 400 reject (fail loud).
141///   3. Header absent + `mcp_url_env` present → `None` (mcp_client falls
142///      back to `provider.mcp_url`).
143///   4. Header present + `mcp_url_env` present + keyring has no allowlist
144///      → 403 reject (fail closed; operator must opt in).
145///   5. Header present + URL doesn't match any glob → 403 reject.
146///   6. Header present + URL matches → `Allow(url)`.
147fn resolve_upstream_override(
148    state: &ProxyState,
149    provider: &crate::core::manifest::Provider,
150    header_value: Option<&str>,
151) -> UpstreamOverride {
152    let provider_accepts_override = provider.mcp_url_env.is_some();
153    match (header_value, provider_accepts_override) {
154        (None, _) => UpstreamOverride::None,
155        (Some(_), false) => UpstreamOverride::Reject(
156            StatusCode::BAD_REQUEST,
157            format!(
158                "X-Ati-Upstream-Url sent for provider '{}' which does not declare mcp_url_env",
159                provider.name
160            ),
161        ),
162        (Some(url), true) => {
163            let allowlist_key = format!("{}_allowed_urls", provider.name);
164            // Lazy-build the per-provider patterns on first use; cache for
165            // process lifetime (operator hot-reload requires restart, same
166            // as every other keyring entry today).
167            let entries = {
168                let mut cache = state.upstream_url_allowlists.lock().unwrap();
169                if !cache.contains_key(&provider.name) {
170                    let compiled: Option<Vec<UpstreamAllowEntry>> = state
171                        .keyring
172                        .get(&allowlist_key)
173                        .and_then(|csv| match build_url_allowlist(csv) {
174                            Ok(set) => set,
175                            Err(e) => {
176                                tracing::warn!(
177                                    provider = %provider.name,
178                                    error = %e,
179                                    "failed to compile upstream URL allowlist; treating as missing"
180                                );
181                                None
182                            }
183                        });
184                    cache.insert(provider.name.clone(), compiled);
185                }
186                cache.get(&provider.name).cloned().flatten()
187            };
188            let Some(entries) = entries else {
189                return UpstreamOverride::Reject(
190                    StatusCode::FORBIDDEN,
191                    format!(
192                        "Provider '{}' has no upstream URL allowlist configured (set ATI_KEY_{}_ALLOWED_URLS on the proxy)",
193                        provider.name,
194                        provider.name.to_uppercase()
195                    ),
196                );
197            };
198            // Parse the sandbox-supplied URL with `url::Url` (which canonicalizes
199            // scheme/host/path and surfaces fragment/query separately). Glob over
200            // the canonical host only — never the raw URL string — so attackers
201            // can't smuggle a different host via `.`, `#`, `?`, `:`, etc.
202            let parsed = match url::Url::parse(url) {
203                Ok(u) => u,
204                Err(e) => {
205                    return UpstreamOverride::Reject(
206                        StatusCode::BAD_REQUEST,
207                        format!("X-Ati-Upstream-Url '{url}' is not a valid URL: {e}"),
208                    );
209                }
210            };
211            if matches_allowlist(&parsed, &entries) {
212                UpstreamOverride::Allow(url.to_string())
213            } else {
214                UpstreamOverride::Reject(
215                    StatusCode::FORBIDDEN,
216                    format!(
217                        "Upstream URL '{url}' not in provider '{}'s allowlist",
218                        provider.name
219                    ),
220                )
221            }
222        }
223    }
224}
225
226/// Issue #135 helper: extract a validated sandbox-supplied upstream URL
227/// from request headers for a given provider, or `None` if absent/rejected.
228///
229/// Wraps `resolve_upstream_override` for the REST tool-info path. REST
230/// callers can't surface a structured rejection like `/call` does
231/// (those are JSON-RPC errors); they just need "do we have a URL to
232/// discover from, yes or no". A rejected URL silently becomes `None`
233/// here — the handler will fall back to the static schema (today's
234/// behavior), which is exactly the graceful-degradation contract the
235/// issue acceptance criteria asks for.
236fn extract_upstream_url(
237    state: &ProxyState,
238    provider: &Provider,
239    headers: &axum::http::HeaderMap,
240) -> Option<String> {
241    let header_value = headers
242        .get("x-ati-upstream-url")
243        .and_then(|v| v.to_str().ok())
244        .map(|s| s.trim())
245        .filter(|s| !s.is_empty());
246
247    match resolve_upstream_override(state, provider, header_value) {
248        UpstreamOverride::Allow(url) => Some(url),
249        UpstreamOverride::None => None,
250        UpstreamOverride::Reject(status, msg) => {
251            tracing::debug!(
252                provider = %provider.name,
253                status = status.as_u16(),
254                reason = %msg,
255                "lazy schema discovery: upstream URL rejected, falling back to static schema"
256            );
257            None
258        }
259    }
260}
261
262/// Issue #135 helper: lazily fetch the per-tool `inputSchema` map for a
263/// sandbox-supplied-URL MCP provider whose static manifest entries declare
264/// tool names but no schemas. Results are cached per `(provider, upstream_url)`
265/// for the process lifetime — both positive (the schema map) and negative
266/// (the upstream was unreachable / the dial failed) — so a steady-state of
267/// `ati tool info` calls only pays the discovery cost once per sandbox URL.
268///
269/// Returns `None` when:
270///   - the provider isn't `handler = "mcp"` with `mcp_url_env` set (we have
271///     nothing to discover from);
272///   - no validated upstream URL is in scope for this request;
273///   - or the discovery attempt failed and the cache has the negative result.
274///
275/// Never propagates an error to the caller: a discovery miss falls back to
276/// today's behavior (return `Tool.input_schema`, which is `None` in the case
277/// this helper exists to handle), so a flaky upstream cannot turn a working
278/// `tool info` endpoint into a 5xx.
279async fn lazy_fetch_schemas(
280    state: &ProxyState,
281    provider: &Provider,
282    upstream_url: &str,
283) -> Option<HashMap<String, Value>> {
284    let key = (provider.name.clone(), upstream_url.to_string());
285
286    // Cache hit (positive or negative) — return without dialing.
287    {
288        let cache = state.lazy_schema_cache.lock().ok()?;
289        if let Some(cached) = cache.get(&key) {
290            return cached.clone();
291        }
292    }
293
294    // Cache miss: dial the upstream, run tools/list, build the schema map.
295    // We deliberately use the same `connect_with_gen` path the /call handler
296    // takes (issue #124), so allowlist semantics and transport selection
297    // stay identical — no second code path to keep in sync.
298    let discovered: Option<HashMap<String, Value>> = match mcp_client::McpClient::connect_with_gen(
299        provider,
300        &state.keyring,
301        None,
302        Some(&state.auth_cache),
303        Some(upstream_url),
304    )
305    .await
306    {
307        Ok(client) => {
308            let result = client.list_tools().await;
309            client.disconnect().await;
310            match result {
311                Ok(tools) => {
312                    let mut map = HashMap::new();
313                    for t in tools {
314                        if let Some(schema) = t.input_schema {
315                            map.insert(t.name, schema);
316                        }
317                    }
318                    Some(map)
319                }
320                Err(e) => {
321                    tracing::warn!(
322                        provider = %provider.name,
323                        upstream = %upstream_url,
324                        error = %e,
325                        "lazy schema discovery: tools/list failed"
326                    );
327                    None
328                }
329            }
330        }
331        Err(e) => {
332            tracing::warn!(
333                provider = %provider.name,
334                upstream = %upstream_url,
335                error = %e,
336                "lazy schema discovery: MCP connect failed"
337            );
338            None
339        }
340    };
341
342    // Cache both positive and negative results to avoid re-dialing on every
343    // subsequent tool-info request — operators who fix a misconfigured
344    // upstream URL will need to restart the proxy (same constraint as the
345    // allowlist cache).
346    if let Ok(mut cache) = state.lazy_schema_cache.lock() {
347        cache.insert(key, discovered.clone());
348    }
349    discovered
350}
351
352/// Returns true iff `url`'s scheme + host + path matches at least one
353/// allowlist entry. Each component checked separately:
354/// - **scheme**: exact string match (case-insensitive per RFC 3986)
355/// - **host**: split on `.` into DNS labels; matched label-by-label against
356///   the entry's `host_label_patterns`. Label count must match exactly so
357///   `*.grep.ai` can't be satisfied by `evil.com.grep.ai` (3 labels vs 4).
358/// - **path**: exact string match (no globs). Query + fragment ignored.
359///
360/// Hostnames are lowercased before comparison (DNS case-insensitive). Userinfo
361/// (`user:pass@`) is rejected outright; allowlist patterns never carry it.
362/// Default ports (`:443` for https, `:80` for http) are accepted but only
363/// when the operator pattern itself doesn't pin a port; otherwise the entry
364/// would reject any URL because `url::Url` strips default ports.
365fn matches_allowlist(url: &url::Url, entries: &[UpstreamAllowEntry]) -> bool {
366    if !url.username().is_empty() || url.password().is_some() {
367        return false;
368    }
369    let Some(host) = url.host_str() else {
370        return false;
371    };
372    let host_lower = host.to_ascii_lowercase();
373    let host_labels: Vec<&str> = host_lower.split('.').collect();
374    let path = url.path();
375    for entry in entries {
376        if !entry.scheme.eq_ignore_ascii_case(url.scheme()) {
377            continue;
378        }
379        // Label-count check is the per-DNS-label-boundary constraint:
380        // pattern `*.grep.ai` (3 labels) cannot match `evil.com.grep.ai`
381        // (4 labels). This is the core security primitive that closes the
382        // `*`-crosses-`.` bypass Greptile flagged on #124.
383        if entry.host_label_patterns.len() != host_labels.len() {
384            continue;
385        }
386        if !entry
387            .host_label_patterns
388            .iter()
389            .zip(host_labels.iter())
390            .all(|(pat, label)| pat.matches(label))
391        {
392            continue;
393        }
394        if entry.path != path {
395            continue;
396        }
397        return true;
398    }
399    false
400}
401
402/// Compile a CSV of operator URL patterns into pre-parsed
403/// [`UpstreamAllowEntry`] structs. Each pattern is a URL template like
404/// `https://parcha-tools-*.grep.ai/mcp` — parsed with `url::Url`, then the
405/// host is split into DNS labels and each label compiled as a separate
406/// glob component. Returns `Ok(None)` for empty input / all-whitespace
407/// (caller treats as "no allowlist", same as missing keyring entry).
408///
409/// ## Why per-label host compilation
410///
411/// `glob::Pattern` doesn't natively treat `.` as a separator. Naive
412/// `Pattern::new("parcha-tools-*.grep.ai")` would let `*` swallow
413/// `evil.com` from `parcha-tools-evil.com.grep.ai`. To prevent that we
414/// reassemble the host pattern label-by-label after validating each label
415/// glob contains no embedded `.` itself. The composed `host_pattern` then
416/// uses `*` semantics that stop at label boundaries.
417///
418/// ## Why exact path match
419///
420/// Glob over paths invites the same boundary-crossing problems. MCP
421/// endpoints are conventionally `/mcp` or `/` — operators write the literal
422/// path, sandbox URLs must match exactly. Query strings and URL fragments
423/// are stripped before comparison.
424///
425/// ## Errors
426///
427/// - `pattern not a valid URL` — operator entry isn't parseable as a URL
428/// - `pattern missing host` — e.g., `https:///mcp`
429/// - `pattern host label contains '.'` — would let `*` cross DNS boundaries
430/// - `pattern has userinfo` — credentials in allowlist entries are refused
431fn build_url_allowlist(csv: &str) -> Result<Option<Vec<UpstreamAllowEntry>>, String> {
432    let mut entries = Vec::new();
433    for raw in csv.split(',') {
434        let pat = raw.trim();
435        if pat.is_empty() {
436            continue;
437        }
438        let parsed = url::Url::parse(pat)
439            .map_err(|e| format!("upstream allowlist pattern '{pat}' is not a valid URL: {e}"))?;
440        if !parsed.username().is_empty() || parsed.password().is_some() {
441            return Err(format!(
442                "upstream allowlist pattern '{pat}' must not include userinfo"
443            ));
444        }
445        let host = parsed
446            .host_str()
447            .ok_or_else(|| format!("upstream allowlist pattern '{pat}' has no host"))?;
448        // Reject `**` outright — it's a glob escape hatch.
449        if host.contains("**") {
450            return Err(format!(
451                "upstream allowlist pattern '{pat}' must not contain '**' (use single '*' per DNS label)"
452            ));
453        }
454        let host_lower = host.to_ascii_lowercase();
455        let mut host_label_patterns = Vec::new();
456        for label in host_lower.split('.') {
457            if label.is_empty() {
458                return Err(format!(
459                    "upstream allowlist pattern '{pat}' has empty DNS label"
460                ));
461            }
462            let p = glob::Pattern::new(label).map_err(|e| {
463                format!(
464                    "upstream allowlist pattern '{pat}' has invalid host label glob '{label}': {e}"
465                )
466            })?;
467            host_label_patterns.push(p);
468        }
469        let path = if parsed.path().is_empty() {
470            "/".to_string()
471        } else {
472            parsed.path().to_string()
473        };
474        entries.push(UpstreamAllowEntry {
475            scheme: parsed.scheme().to_string(),
476            host_label_patterns,
477            path,
478        });
479    }
480    if entries.is_empty() {
481        return Ok(None);
482    }
483    Ok(Some(entries))
484}
485
486// --- Request/Response types ---
487
488#[derive(Debug, Deserialize)]
489pub struct CallRequest {
490    pub tool_name: String,
491    /// Tool arguments — accepts a JSON object (key-value pairs) for HTTP/MCP/OpenAPI tools,
492    /// or a JSON array of strings / a single string for CLI tools.
493    /// The proxy auto-detects the handler type and routes accordingly.
494    #[serde(default = "default_args")]
495    pub args: Value,
496    /// Deprecated: use `args` with an array value instead.
497    /// Kept for backward compatibility — if present, takes precedence for CLI tools.
498    #[serde(default)]
499    pub raw_args: Option<Vec<String>>,
500}
501
502fn default_args() -> Value {
503    Value::Object(serde_json::Map::new())
504}
505
506impl CallRequest {
507    /// Extract args as a HashMap for HTTP/MCP/OpenAPI tools.
508    /// If `args` is a JSON object, returns its entries.
509    /// If `args` is something else (array, string), returns an empty map.
510    fn args_as_map(&self) -> HashMap<String, Value> {
511        match &self.args {
512            Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
513            _ => HashMap::new(),
514        }
515    }
516
517    /// Extract positional args for CLI tools.
518    /// Priority: explicit `raw_args` field > `args` array > `args` string > `args._positional` > empty.
519    fn args_as_positional(&self) -> Vec<String> {
520        // Backward compat: explicit raw_args wins
521        if let Some(ref raw) = self.raw_args {
522            return raw.clone();
523        }
524        match &self.args {
525            // ["pr", "list", "--repo", "X"]
526            Value::Array(arr) => arr
527                .iter()
528                .map(|v| match v {
529                    Value::String(s) => s.clone(),
530                    other => other.to_string(),
531                })
532                .collect(),
533            // "pr list --repo X"
534            Value::String(s) => s.split_whitespace().map(String::from).collect(),
535            // {"_positional": ["pr", "list"]} or {"--key": "value"} converted to CLI flags
536            Value::Object(map) => {
537                if let Some(Value::Array(pos)) = map.get("_positional") {
538                    return pos
539                        .iter()
540                        .map(|v| match v {
541                            Value::String(s) => s.clone(),
542                            other => other.to_string(),
543                        })
544                        .collect();
545                }
546                // Convert map entries to --key value pairs
547                let mut result = Vec::new();
548                for (k, v) in map {
549                    result.push(format!("--{k}"));
550                    match v {
551                        Value::String(s) => result.push(s.clone()),
552                        Value::Bool(true) => {} // flag, no value needed
553                        other => result.push(other.to_string()),
554                    }
555                }
556                result
557            }
558            _ => Vec::new(),
559        }
560    }
561}
562
563#[derive(Debug, Serialize)]
564pub struct CallResponse {
565    pub result: Value,
566    #[serde(skip_serializing_if = "Option::is_none")]
567    pub error: Option<String>,
568}
569
570#[derive(Debug, Deserialize)]
571pub struct HelpRequest {
572    pub query: String,
573    #[serde(default)]
574    pub tool: Option<String>,
575}
576
577#[derive(Debug, Serialize)]
578pub struct HelpResponse {
579    pub content: String,
580    #[serde(skip_serializing_if = "Option::is_none")]
581    pub error: Option<String>,
582}
583
584#[derive(Debug, Serialize)]
585pub struct HealthResponse {
586    pub status: String,
587    pub version: String,
588    pub tools: usize,
589    pub providers: usize,
590    pub skills: usize,
591    pub auth: String,
592}
593
594// --- Skill endpoint types ---
595
596#[derive(Debug, Deserialize)]
597pub struct SkillsQuery {
598    #[serde(default)]
599    pub category: Option<String>,
600    #[serde(default)]
601    pub provider: Option<String>,
602    #[serde(default)]
603    pub tool: Option<String>,
604    #[serde(default)]
605    pub search: Option<String>,
606}
607
608#[derive(Debug, Deserialize)]
609pub struct SkillDetailQuery {
610    #[serde(default)]
611    pub meta: Option<bool>,
612    #[serde(default)]
613    pub refs: Option<bool>,
614}
615
616#[derive(Debug, Deserialize)]
617pub struct SkillResolveRequest {
618    pub scopes: Vec<String>,
619    /// When true, include SKILL.md content in each resolved skill.
620    #[serde(default)]
621    pub include_content: bool,
622}
623
624#[derive(Debug, Deserialize)]
625pub struct SkillBundleBatchRequest {
626    pub names: Vec<String>,
627}
628
629#[derive(Debug, Deserialize, Default)]
630pub struct SkillAtiCatalogQuery {
631    #[serde(default)]
632    pub search: Option<String>,
633}
634
635#[derive(Debug, Deserialize, Default)]
636pub struct SkillAtiResourcesQuery {
637    #[serde(default)]
638    pub prefix: Option<String>,
639}
640
641#[derive(Debug, Deserialize)]
642pub struct SkillAtiFileQuery {
643    pub path: String,
644}
645
646// --- Tool endpoint types ---
647
648#[derive(Debug, Deserialize)]
649pub struct ToolsQuery {
650    #[serde(default)]
651    pub provider: Option<String>,
652    #[serde(default)]
653    pub search: Option<String>,
654}
655
656// --- Handlers ---
657
658fn scopes_for_request(claims: Option<&TokenClaims>, state: &ProxyState) -> ScopeConfig {
659    match claims {
660        Some(claims) => ScopeConfig::from_jwt(claims),
661        None if state.jwt_config.is_none() => ScopeConfig::unrestricted(),
662        None => ScopeConfig {
663            scopes: Vec::new(),
664            sub: String::new(),
665            expires_at: 0,
666            rate_config: None,
667        },
668    }
669}
670
671fn visible_tools_for_scopes<'a>(
672    state: &'a ProxyState,
673    scopes: &ScopeConfig,
674) -> Vec<(&'a Provider, &'a Tool)> {
675    crate::core::scope::filter_tools_by_scope(state.registry.list_public_tools(), scopes)
676}
677
678fn visible_skill_names(
679    state: &ProxyState,
680    scopes: &ScopeConfig,
681) -> std::collections::HashSet<String> {
682    skill::visible_skills(&state.skill_registry, &state.registry, scopes)
683        .into_iter()
684        .map(|skill| skill.name.clone())
685        .collect()
686}
687
688/// Compute the set of remote (SkillATI-registry) skill names that the caller's
689/// scopes grant access to.
690///
691/// Mirrors the scope cascade in `skill::resolve_skills` — explicit `skill:X`
692/// scopes, `tool:Y` scopes resolved to the tool's covering skills (including
693/// provider/category bindings) — but against a remote catalog whose skills
694/// are **not** present in the local filesystem `SkillRegistry`.
695///
696/// Without this, proxies running `ATI_SKILL_REGISTRY=gcs://...` with an empty
697/// local skills directory return 404 for every remote skill, because the
698/// visibility gate only consults `state.skill_registry` (see issue #59).
699fn visible_remote_skill_names(
700    state: &ProxyState,
701    scopes: &ScopeConfig,
702    catalog: &[RemoteSkillMeta],
703) -> std::collections::HashSet<String> {
704    let mut visible: std::collections::HashSet<String> = std::collections::HashSet::new();
705    if catalog.is_empty() {
706        return visible;
707    }
708    if scopes.is_wildcard() {
709        for entry in catalog {
710            visible.insert(entry.name.clone());
711        }
712        return visible;
713    }
714
715    // Collect allowed tool/provider/category identifiers from the caller's scopes.
716    // 1. Direct `tool:X` scopes (including wildcards) → walk against the public
717    //    tool registry to collect concrete (provider, tool) pairs.
718    let allowed_tool_pairs: Vec<(String, String)> =
719        crate::core::scope::filter_tools_by_scope(state.registry.list_public_tools(), scopes)
720            .into_iter()
721            .map(|(p, t)| (p.name.clone(), t.name.clone()))
722            .collect();
723    let allowed_tool_names: std::collections::HashSet<&str> =
724        allowed_tool_pairs.iter().map(|(_, t)| t.as_str()).collect();
725    let allowed_provider_names: std::collections::HashSet<&str> =
726        allowed_tool_pairs.iter().map(|(p, _)| p.as_str()).collect();
727    let allowed_categories: std::collections::HashSet<String> = state
728        .registry
729        .list_providers()
730        .into_iter()
731        .filter(|p| allowed_provider_names.contains(p.name.as_str()))
732        .filter_map(|p| p.category.clone())
733        .collect();
734
735    // Explicit `skill:X` scopes → include X if present in the remote catalog.
736    for scope in &scopes.scopes {
737        if let Some(skill_name) = scope.strip_prefix("skill:") {
738            if catalog.iter().any(|e| e.name == skill_name) {
739                visible.insert(skill_name.to_string());
740            }
741        }
742    }
743
744    // Tool/provider/category cascade → include a remote skill if any of its
745    // `tools`, `providers`, or `categories` bindings match a scope-allowed
746    // tool/provider/category.
747    for entry in catalog {
748        if entry
749            .tools
750            .iter()
751            .any(|t| allowed_tool_names.contains(t.as_str()))
752            || entry
753                .providers
754                .iter()
755                .any(|p| allowed_provider_names.contains(p.as_str()))
756            || entry
757                .categories
758                .iter()
759                .any(|c| allowed_categories.contains(c))
760        {
761            visible.insert(entry.name.clone());
762        }
763    }
764
765    visible
766}
767
768/// Union of local + remote visible skill names, computed on demand. The
769/// remote catalog is fetched lazily (and is cached inside `SkillAtiClient`
770/// after the first call on the hot path).
771async fn visible_skill_names_with_remote(
772    state: &ProxyState,
773    scopes: &ScopeConfig,
774    client: &SkillAtiClient,
775) -> Result<std::collections::HashSet<String>, SkillAtiError> {
776    let mut names = visible_skill_names(state, scopes);
777    let catalog = client.catalog().await?;
778    let remote = visible_remote_skill_names(state, scopes, &catalog);
779    names.extend(remote);
780    Ok(names)
781}
782
783async fn handle_call(
784    State(state): State<Arc<ProxyState>>,
785    req: HttpRequest<Body>,
786) -> impl IntoResponse {
787    // Extract JWT claims from request extensions (set by auth middleware)
788    let claims = req.extensions().get::<TokenClaims>().cloned();
789    // Stamp Sentry's scope with the request identity so every event raised
790    // during this request — exceptions, the dynamic-title upstream-error
791    // bucket, panics — is searchable by sub / sandbox_id / job_id. The
792    // scope is per-thread and gets cleared by subsequent `with_scope`
793    // pushes inside the report helpers; the tags we set here ride along
794    // as the bottom layer.
795    if let Some(ref c) = claims {
796        sentry_scope::set_jwt_sentry_scope(c);
797    }
798    // Raw inbound bearer (Bearer-token path only — dev/no-JWT paths leave
799    // this empty). Used by `GenContext.jwt_token` so auth_generator scripts
800    // can forward the sandbox's identity to an upstream MCP. See issue #115.
801    let bearer_token: String = req
802        .extensions()
803        .get::<BearerToken>()
804        .map(|b| b.0.clone())
805        .unwrap_or_default();
806
807    // Pull the sandbox-supplied upstream URL out of the inbound headers
808    // before `into_body()` consumes them. Validation against the operator
809    // allowlist happens after tool resolution — we need the provider to
810    // know which keyring allowlist entry to consult. See issue #124.
811    let upstream_url_header: Option<String> = req
812        .headers()
813        .get("x-ati-upstream-url")
814        .and_then(|v| v.to_str().ok())
815        .map(|s| s.trim().to_string())
816        .filter(|s| !s.is_empty());
817
818    // Parse request body. The ceiling must accommodate the worst-case upload
819    // payload: `file_manager::MAX_UPLOAD_BYTES` of raw bytes, base64-inflated
820    // (~1.34×), plus a few KB of JSON framing. Anti-abuse is enforced
821    // downstream by per-tool limits (`max_bytes` on downloads, `MAX_UPLOAD_BYTES`
822    // on uploads) and by JWT scope + rate limits — this is just the outer
823    // wire cap.
824    let body_bytes = match axum::body::to_bytes(req.into_body(), max_call_body_bytes()).await {
825        Ok(b) => b,
826        Err(e) => {
827            return (
828                StatusCode::BAD_REQUEST,
829                Json(CallResponse {
830                    result: Value::Null,
831                    error: Some(format!("Failed to read request body: {e}")),
832                }),
833            );
834        }
835    };
836
837    let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
838        Ok(r) => r,
839        Err(e) => {
840            return (
841                StatusCode::UNPROCESSABLE_ENTITY,
842                Json(CallResponse {
843                    result: Value::Null,
844                    error: Some(format!("Invalid request: {e}")),
845                }),
846            );
847        }
848    };
849
850    tracing::debug!(
851        tool = %call_req.tool_name,
852        args = ?call_req.args,
853        "POST /call"
854    );
855
856    // Look up tool in registry.
857    // If not found, try converting underscore format (finnhub_quote) to colon (finnhub:quote).
858    let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
859        Some(pt) => pt,
860        None => {
861            // Try underscore → colon conversion at each underscore position.
862            // "finnhub_quote" → try "finnhub:quote"
863            // "test_api_get_data" → try "test:api_get_data", "test_api:get_data"
864            let mut resolved = None;
865            for (idx, _) in call_req.tool_name.match_indices('_') {
866                let candidate = format!(
867                    "{}:{}",
868                    &call_req.tool_name[..idx],
869                    &call_req.tool_name[idx + 1..]
870                );
871                if let Some(pt) = state.registry.get_tool(&candidate) {
872                    tracing::debug!(
873                        original = %call_req.tool_name,
874                        resolved = %candidate,
875                        "resolved underscore tool name to colon format"
876                    );
877                    resolved = Some(pt);
878                    break;
879                }
880            }
881
882            match resolved {
883                Some(pt) => pt,
884                None => {
885                    return (
886                        StatusCode::NOT_FOUND,
887                        Json(CallResponse {
888                            result: Value::Null,
889                            error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
890                        }),
891                    );
892                }
893            }
894        }
895    };
896
897    // Scope enforcement from JWT claims
898    if let Some(tool_scope) = &tool.scope {
899        let scopes = match &claims {
900            Some(c) => ScopeConfig::from_jwt(c),
901            None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), // Dev mode
902            None => {
903                return (
904                    StatusCode::FORBIDDEN,
905                    Json(CallResponse {
906                        result: Value::Null,
907                        error: Some("Authentication required — no JWT provided".into()),
908                    }),
909                );
910            }
911        };
912
913        if !scopes.is_allowed(tool_scope) {
914            return (
915                StatusCode::FORBIDDEN,
916                Json(CallResponse {
917                    result: Value::Null,
918                    error: Some(format!(
919                        "Access denied: '{}' is not in your scopes",
920                        tool.name
921                    )),
922                }),
923            );
924        }
925    }
926
927    // Rate limit check
928    {
929        let scopes = match &claims {
930            Some(c) => ScopeConfig::from_jwt(c),
931            None => ScopeConfig::unrestricted(),
932        };
933        if let Some(ref rate_config) = scopes.rate_config {
934            if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
935                return (
936                    StatusCode::TOO_MANY_REQUESTS,
937                    Json(CallResponse {
938                        result: Value::Null,
939                        error: Some(format!("{e}")),
940                    }),
941                );
942            }
943        }
944    }
945
946    // Build auth generator context from JWT claims
947    let gen_ctx = GenContext {
948        jwt_sub: claims
949            .as_ref()
950            .map(|c| c.sub.clone())
951            .unwrap_or_else(|| "dev".into()),
952        jwt_scope: claims
953            .as_ref()
954            .map(|c| c.scope.clone())
955            .unwrap_or_else(|| "*".into()),
956        tool_name: call_req.tool_name.clone(),
957        timestamp: crate::core::jwt::now_secs(),
958        jwt_token: bearer_token.clone(),
959    };
960
961    // Validate any sandbox-supplied X-Ati-Upstream-Url against the operator's
962    // per-provider allowlist (issue #124). Reject paths exit early with a
963    // clean status; the Allow path threads through to mcp_client below.
964    let override_mcp_url: Option<String> =
965        match resolve_upstream_override(&state, provider, upstream_url_header.as_deref()) {
966            UpstreamOverride::None => None,
967            UpstreamOverride::Allow(url) => Some(url),
968            UpstreamOverride::Reject(status, msg) => {
969                tracing::warn!(
970                    provider = %provider.name,
971                    tool = %call_req.tool_name,
972                    status = status.as_u16(),
973                    reason = %msg,
974                    "rejecting sandbox-supplied upstream URL"
975                );
976                return (
977                    status,
978                    Json(CallResponse {
979                        result: Value::Null,
980                        error: Some(msg),
981                    }),
982                );
983            }
984        };
985
986    // Execute tool call — dispatch based on handler type, with timing for audit
987    let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
988    let job_id = claims
989        .as_ref()
990        .and_then(|c| c.job_id.clone())
991        .unwrap_or_default();
992    let sandbox_id = claims
993        .as_ref()
994        .and_then(|c| c.sandbox_id.clone())
995        .unwrap_or_default();
996    tracing::info!(
997        tool = %call_req.tool_name,
998        agent = %agent_sub,
999        job_id = %job_id,
1000        sandbox_id = %sandbox_id,
1001        "tool call"
1002    );
1003    let start = std::time::Instant::now();
1004
1005    let response = match provider.handler.as_str() {
1006        "mcp" => {
1007            let args_map = call_req.args_as_map();
1008            match mcp_client::execute_with_gen(
1009                provider,
1010                &call_req.tool_name,
1011                &args_map,
1012                &state.keyring,
1013                Some(&gen_ctx),
1014                Some(&state.auth_cache),
1015                override_mcp_url.as_deref(),
1016            )
1017            .await
1018            {
1019                Ok(result) => (
1020                    StatusCode::OK,
1021                    Json(CallResponse {
1022                        result,
1023                        error: None,
1024                    }),
1025                ),
1026                Err(e) => {
1027                    // Use the registry-resolved provider name (not the
1028                    // tool_name) — the old `split_tool_name` lied when the
1029                    // tool name had no `:` separator. The error_message
1030                    // path also gets the underlying error's `to_string`
1031                    // for the Sentry "extra"; the full source chain is
1032                    // attached separately via `capture_error_with_scope`
1033                    // so the Sentry issue page shows real exception data.
1034                    let (provider_name, operation_id) =
1035                        sentry_scope::provider_and_op(&provider.name, &call_req.tool_name);
1036                    let msg = e.to_string();
1037                    sentry_scope::report_upstream_error(
1038                        &provider_name,
1039                        &operation_id,
1040                        0,
1041                        502,
1042                        None,
1043                        Some(&msg),
1044                    );
1045                    sentry_scope::capture_error_with_scope(
1046                        &e,
1047                        &provider_name,
1048                        &operation_id,
1049                        0,
1050                        502,
1051                        None,
1052                        Some(&msg),
1053                    );
1054                    (
1055                        StatusCode::BAD_GATEWAY,
1056                        Json(CallResponse {
1057                            result: Value::Null,
1058                            error: Some(format!("MCP error: {e}")),
1059                        }),
1060                    )
1061                }
1062            }
1063        }
1064        "cli" => {
1065            let positional = call_req.args_as_positional();
1066            match crate::core::cli_executor::execute_with_gen(
1067                provider,
1068                &positional,
1069                &state.keyring,
1070                Some(&gen_ctx),
1071                Some(&state.auth_cache),
1072            )
1073            .await
1074            {
1075                Ok(result) => (
1076                    StatusCode::OK,
1077                    Json(CallResponse {
1078                        result,
1079                        error: None,
1080                    }),
1081                ),
1082                Err(e) => {
1083                    // Same registry-resolved provider treatment as the
1084                    // MCP arm — and the structured error attached via
1085                    // `capture_error_with_scope` so Sentry sees the real
1086                    // exception chain instead of just the flat title.
1087                    let (provider_name, operation_id) =
1088                        sentry_scope::provider_and_op(&provider.name, &call_req.tool_name);
1089                    let msg = e.to_string();
1090                    sentry_scope::report_upstream_error(
1091                        &provider_name,
1092                        &operation_id,
1093                        0,
1094                        502,
1095                        None,
1096                        Some(&msg),
1097                    );
1098                    sentry_scope::capture_error_with_scope(
1099                        &e,
1100                        &provider_name,
1101                        &operation_id,
1102                        0,
1103                        502,
1104                        None,
1105                        Some(&msg),
1106                    );
1107                    (
1108                        StatusCode::BAD_GATEWAY,
1109                        Json(CallResponse {
1110                            result: Value::Null,
1111                            error: Some(format!("CLI error: {e}")),
1112                        }),
1113                    )
1114                }
1115            }
1116        }
1117        "file_manager" => {
1118            let args_map = call_req.args_as_map();
1119            match dispatch_file_manager(&call_req.tool_name, &args_map, provider, &state.keyring)
1120                .await
1121            {
1122                Ok(result) => (
1123                    StatusCode::OK,
1124                    Json(CallResponse {
1125                        result,
1126                        error: None,
1127                    }),
1128                ),
1129                Err((status, msg)) => (
1130                    status,
1131                    Json(CallResponse {
1132                        result: Value::Null,
1133                        error: Some(msg),
1134                    }),
1135                ),
1136            }
1137        }
1138        _ => {
1139            let args_map = call_req.args_as_map();
1140            let raw_response = match http::execute_tool_with_gen(
1141                provider,
1142                tool,
1143                &args_map,
1144                &state.keyring,
1145                Some(&gen_ctx),
1146                Some(&state.auth_cache),
1147            )
1148            .await
1149            {
1150                Ok(resp) => resp,
1151                Err(http::HttpError::NoRecordsFound { status }) => {
1152                    // Legit empty upstream result — not an error. Return a
1153                    // clean empty object so callers can distinguish from a
1154                    // failed call and move on without paging Sentry.
1155                    let duration = start.elapsed();
1156                    tracing::info!(
1157                        tool = %call_req.tool_name,
1158                        upstream_status = status,
1159                        "upstream returned no records"
1160                    );
1161                    write_proxy_audit(&call_req, &agent_sub, claims.as_ref(), duration, None);
1162                    return (
1163                        StatusCode::OK,
1164                        Json(CallResponse {
1165                            result: serde_json::json!({ "records": [] }),
1166                            error: None,
1167                        }),
1168                    );
1169                }
1170                Err(e) => {
1171                    let duration = start.elapsed();
1172                    // Use the registry-resolved provider name so the
1173                    // Sentry tags actually identify the provider — the
1174                    // old `split_tool_name` returned the whole flat tool
1175                    // name as the "provider" and "unknown" as the
1176                    // operation for tools like `pdl_person_enrichment`.
1177                    let (provider_name, operation_id) =
1178                        sentry_scope::provider_and_op(&provider.name, &call_req.tool_name);
1179                    let (upstream_status, error_type, error_message) = match &e {
1180                        http::HttpError::ApiError {
1181                            status,
1182                            error_type,
1183                            error_message,
1184                            ..
1185                        } => (*status, error_type.clone(), error_message.clone()),
1186                        _ => (0u16, None, Some(e.to_string())),
1187                    };
1188                    sentry_scope::report_upstream_error(
1189                        &provider_name,
1190                        &operation_id,
1191                        upstream_status,
1192                        502,
1193                        error_type.as_deref(),
1194                        error_message.as_deref(),
1195                    );
1196                    // Attach the structured error (with source chain) as
1197                    // a separate exception event so the Sentry issue
1198                    // page shows real Rust frames + Display, not just
1199                    // the scrubbed body. Quota / rate-limited classes
1200                    // short-circuit inside the helper.
1201                    sentry_scope::capture_error_with_scope(
1202                        &e,
1203                        &provider_name,
1204                        &operation_id,
1205                        upstream_status,
1206                        502,
1207                        error_type.as_deref(),
1208                        error_message.as_deref(),
1209                    );
1210                    write_proxy_audit(
1211                        &call_req,
1212                        &agent_sub,
1213                        claims.as_ref(),
1214                        duration,
1215                        Some(&e.to_string()),
1216                    );
1217                    return (
1218                        StatusCode::BAD_GATEWAY,
1219                        Json(CallResponse {
1220                            result: Value::Null,
1221                            error: Some(format!("Upstream API error: {e}")),
1222                        }),
1223                    );
1224                }
1225            };
1226
1227            let processed = match response::process_response(&raw_response, tool.response.as_ref())
1228            {
1229                Ok(p) => p,
1230                Err(e) => {
1231                    let duration = start.elapsed();
1232                    write_proxy_audit(
1233                        &call_req,
1234                        &agent_sub,
1235                        claims.as_ref(),
1236                        duration,
1237                        Some(&e.to_string()),
1238                    );
1239                    return (
1240                        StatusCode::INTERNAL_SERVER_ERROR,
1241                        Json(CallResponse {
1242                            result: raw_response,
1243                            error: Some(format!("Response processing error: {e}")),
1244                        }),
1245                    );
1246                }
1247            };
1248
1249            (
1250                StatusCode::OK,
1251                Json(CallResponse {
1252                    result: processed,
1253                    error: None,
1254                }),
1255            )
1256        }
1257    };
1258
1259    let duration = start.elapsed();
1260    let error_msg = response.1.error.as_deref();
1261    write_proxy_audit(&call_req, &agent_sub, claims.as_ref(), duration, error_msg);
1262
1263    response
1264}
1265
1266async fn handle_help(
1267    State(state): State<Arc<ProxyState>>,
1268    claims: Option<Extension<TokenClaims>>,
1269    Json(req): Json<HelpRequest>,
1270) -> impl IntoResponse {
1271    tracing::debug!(query = %req.query, tool = ?req.tool, "POST /help");
1272
1273    let claims = claims.map(|Extension(claims)| claims);
1274    let scopes = scopes_for_request(claims.as_ref(), &state);
1275
1276    let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
1277        Some(pt) => pt,
1278        None => {
1279            return (
1280                StatusCode::SERVICE_UNAVAILABLE,
1281                Json(HelpResponse {
1282                    content: String::new(),
1283                    error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
1284                }),
1285            );
1286        }
1287    };
1288
1289    let api_key = match llm_provider
1290        .auth_key_name
1291        .as_deref()
1292        .and_then(|k| state.keyring.get(k))
1293    {
1294        Some(key) => key.to_string(),
1295        None => {
1296            return (
1297                StatusCode::SERVICE_UNAVAILABLE,
1298                Json(HelpResponse {
1299                    content: String::new(),
1300                    error: Some("LLM API key not found in keyring".into()),
1301                }),
1302            );
1303        }
1304    };
1305
1306    let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
1307    let local_skills_section = if resolved_skills.is_empty() {
1308        String::new()
1309    } else {
1310        format!(
1311            "## Available Skills (methodology guides)\n{}",
1312            skill::build_skill_context(&resolved_skills)
1313        )
1314    };
1315    let remote_query = req
1316        .tool
1317        .as_ref()
1318        .map(|tool| format!("{tool} {}", req.query))
1319        .unwrap_or_else(|| req.query.clone());
1320    let remote_skills_section =
1321        build_remote_skillati_section(&state.keyring, &remote_query, 12).await;
1322    let skills_section = merge_help_skill_sections(&[local_skills_section, remote_skills_section]);
1323
1324    // Build system prompt — scoped or unscoped
1325    let visible_tools = visible_tools_for_scopes(&state, &scopes);
1326    let system_prompt = if let Some(ref tool_name) = req.tool {
1327        // Scoped mode: narrow tools to the specified tool or provider
1328        match build_scoped_prompt(tool_name, &visible_tools, &skills_section) {
1329            Some(prompt) => prompt,
1330            None => {
1331                return (
1332                    StatusCode::FORBIDDEN,
1333                    Json(HelpResponse {
1334                        content: String::new(),
1335                        error: Some(format!(
1336                            "Scope '{tool_name}' is not visible in your current scopes."
1337                        )),
1338                    }),
1339                );
1340            }
1341        }
1342    } else {
1343        let tools_context = build_tool_context(&visible_tools);
1344        HELP_SYSTEM_PROMPT
1345            .replace("{tools}", &tools_context)
1346            .replace("{skills_section}", &skills_section)
1347    };
1348
1349    let request_body = serde_json::json!({
1350        "model": "zai-glm-4.7",
1351        "messages": [
1352            {"role": "system", "content": system_prompt},
1353            {"role": "user", "content": req.query}
1354        ],
1355        "max_completion_tokens": 1536,
1356        "temperature": 0.3
1357    });
1358
1359    let client = reqwest::Client::new();
1360    let url = format!(
1361        "{}{}",
1362        llm_provider.base_url.trim_end_matches('/'),
1363        llm_tool.endpoint
1364    );
1365
1366    let response = match client
1367        .post(&url)
1368        .bearer_auth(&api_key)
1369        .json(&request_body)
1370        .send()
1371        .await
1372    {
1373        Ok(r) => r,
1374        Err(e) => {
1375            return (
1376                StatusCode::BAD_GATEWAY,
1377                Json(HelpResponse {
1378                    content: String::new(),
1379                    error: Some(format!("LLM request failed: {e}")),
1380                }),
1381            );
1382        }
1383    };
1384
1385    if !response.status().is_success() {
1386        let status = response.status();
1387        let body = response.text().await.unwrap_or_default();
1388        return (
1389            StatusCode::BAD_GATEWAY,
1390            Json(HelpResponse {
1391                content: String::new(),
1392                error: Some(format!("LLM API error ({status}): {body}")),
1393            }),
1394        );
1395    }
1396
1397    let body: Value = match response.json().await {
1398        Ok(b) => b,
1399        Err(e) => {
1400            return (
1401                StatusCode::INTERNAL_SERVER_ERROR,
1402                Json(HelpResponse {
1403                    content: String::new(),
1404                    error: Some(format!("Failed to parse LLM response: {e}")),
1405                }),
1406            );
1407        }
1408    };
1409
1410    let content = body
1411        .pointer("/choices/0/message/content")
1412        .and_then(|c| c.as_str())
1413        .unwrap_or("No response from LLM")
1414        .to_string();
1415
1416    (
1417        StatusCode::OK,
1418        Json(HelpResponse {
1419            content,
1420            error: None,
1421        }),
1422    )
1423}
1424
1425async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
1426    let auth = if state.jwt_config.is_some() {
1427        "jwt"
1428    } else {
1429        "disabled"
1430    };
1431
1432    Json(HealthResponse {
1433        status: "ok".into(),
1434        version: env!("CARGO_PKG_VERSION").into(),
1435        tools: state.registry.list_public_tools().len(),
1436        providers: state.registry.list_providers().len(),
1437        skills: state.skill_registry.skill_count(),
1438        auth: auth.into(),
1439    })
1440}
1441
1442/// GET /.well-known/jwks.json — serves the public key for JWT validation.
1443async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
1444    match &state.jwks_json {
1445        Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
1446        None => (
1447            StatusCode::NOT_FOUND,
1448            Json(serde_json::json!({"error": "JWKS not configured"})),
1449        ),
1450    }
1451}
1452
1453// ---------------------------------------------------------------------------
1454// POST /mcp — MCP JSON-RPC proxy endpoint
1455// ---------------------------------------------------------------------------
1456
1457async fn handle_mcp(
1458    State(state): State<Arc<ProxyState>>,
1459    claims: Option<Extension<TokenClaims>>,
1460    bearer: Option<Extension<BearerToken>>,
1461    headers: axum::http::HeaderMap,
1462    Json(msg): Json<Value>,
1463) -> impl IntoResponse {
1464    let claims = claims.map(|Extension(claims)| claims);
1465    // Raw inbound bearer for `GenContext.jwt_token` (issue #115). The
1466    // Bearer-token path always populates this when a JWT is configured;
1467    // dev mode and Ati-Key paths leave it None and we fall back to "".
1468    let bearer_token: String = bearer.map(|Extension(b)| b.0).unwrap_or_default();
1469    // Sandbox-supplied upstream URL for MCP tools/call (issue #124). The
1470    // proxy validates against the per-provider allowlist before honouring.
1471    let upstream_url_header: Option<String> = headers
1472        .get("x-ati-upstream-url")
1473        .and_then(|v| v.to_str().ok())
1474        .map(|s| s.trim().to_string())
1475        .filter(|s| !s.is_empty());
1476    let scopes = scopes_for_request(claims.as_ref(), &state);
1477    let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
1478    let id = msg.get("id").cloned();
1479    tracing::info!(
1480        %method,
1481        agent = claims.as_ref().map(|c| c.sub.as_str()).unwrap_or(""),
1482        job_id = claims.as_ref().and_then(|c| c.job_id.as_deref()).unwrap_or(""),
1483        sandbox_id = claims.as_ref().and_then(|c| c.sandbox_id.as_deref()).unwrap_or(""),
1484        "mcp call"
1485    );
1486
1487    match method {
1488        "initialize" => {
1489            let result = serde_json::json!({
1490                "protocolVersion": "2025-03-26",
1491                "capabilities": {
1492                    "tools": { "listChanged": false }
1493                },
1494                "serverInfo": {
1495                    "name": "ati-proxy",
1496                    "version": env!("CARGO_PKG_VERSION")
1497                }
1498            });
1499            jsonrpc_success(id, result)
1500        }
1501
1502        "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
1503
1504        "tools/list" => {
1505            let visible_tools = visible_tools_for_scopes(&state, &scopes);
1506
1507            // Issue #135: lazily discover schemas for sandbox-supplied-URL
1508            // MCP providers whose static manifest entries lack input_schema.
1509            // Same per-provider batching as GET /tools so a multi-tool
1510            // provider only triggers one upstream tools/list per request.
1511            // An entry is recorded for the provider on every decision-path
1512            // (including URL-missing and discovery-failed) so the
1513            // `contains_key` short-circuit fires for sibling tools.
1514            // Greptile #136 finding.
1515            let mut schemas_by_provider: HashMap<String, HashMap<String, Value>> = HashMap::new();
1516            for (provider, tool) in visible_tools.iter() {
1517                if tool.input_schema.is_some() {
1518                    continue;
1519                }
1520                if provider.handler != "mcp" || provider.mcp_url_env.is_none() {
1521                    continue;
1522                }
1523                if schemas_by_provider.contains_key(&provider.name) {
1524                    continue;
1525                }
1526                let map = match extract_upstream_url(&state, provider, &headers) {
1527                    Some(url) => lazy_fetch_schemas(&state, provider, &url)
1528                        .await
1529                        .unwrap_or_default(),
1530                    None => HashMap::new(),
1531                };
1532                schemas_by_provider.insert(provider.name.clone(), map);
1533            }
1534
1535            let mcp_tools: Vec<Value> = visible_tools
1536                .iter()
1537                .map(|(provider, tool)| {
1538                    let schema = tool.input_schema.clone().or_else(|| {
1539                        schemas_by_provider
1540                            .get(&provider.name)
1541                            .and_then(|m| m.get(&tool.name).cloned())
1542                    });
1543                    serde_json::json!({
1544                        "name": tool.name,
1545                        "description": tool.description,
1546                        "inputSchema": schema.unwrap_or(serde_json::json!({
1547                            "type": "object",
1548                            "properties": {}
1549                        }))
1550                    })
1551                })
1552                .collect();
1553
1554            let result = serde_json::json!({
1555                "tools": mcp_tools,
1556            });
1557            jsonrpc_success(id, result)
1558        }
1559
1560        "tools/call" => {
1561            let params = msg.get("params").cloned().unwrap_or(Value::Null);
1562            let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
1563            let arguments: HashMap<String, Value> = params
1564                .get("arguments")
1565                .and_then(|a| serde_json::from_value(a.clone()).ok())
1566                .unwrap_or_default();
1567
1568            if tool_name.is_empty() {
1569                return jsonrpc_error(id, -32602, "Missing tool name in params.name");
1570            }
1571
1572            let (provider, _tool) = match state.registry.get_tool(tool_name) {
1573                Some(pt) => pt,
1574                None => {
1575                    return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
1576                }
1577            };
1578
1579            if let Some(tool_scope) = &_tool.scope {
1580                if !scopes.is_allowed(tool_scope) {
1581                    return jsonrpc_error(
1582                        id,
1583                        -32001,
1584                        &format!("Access denied: '{}' is not in your scopes", _tool.name),
1585                    );
1586                }
1587            }
1588
1589            tracing::debug!(%tool_name, provider = %provider.name, "MCP tools/call");
1590
1591            // Validate any sandbox-supplied X-Ati-Upstream-Url against the
1592            // operator's per-provider allowlist (issue #124). For MCP JSON-RPC
1593            // we encode rejections as JSON-RPC error codes:
1594            //   - -32602 (invalid params) for header-without-mcp_url_env
1595            //   - -32001 (access denied) for allowlist misses
1596            let override_mcp_url: Option<String> =
1597                match resolve_upstream_override(&state, provider, upstream_url_header.as_deref()) {
1598                    UpstreamOverride::None => None,
1599                    UpstreamOverride::Allow(url) => Some(url),
1600                    UpstreamOverride::Reject(status, msg) => {
1601                        let code = if status == StatusCode::BAD_REQUEST {
1602                            -32602
1603                        } else {
1604                            -32001
1605                        };
1606                        tracing::warn!(
1607                            provider = %provider.name,
1608                            tool = %tool_name,
1609                            status = status.as_u16(),
1610                            reason = %msg,
1611                            "rejecting sandbox-supplied upstream URL on /mcp"
1612                        );
1613                        return jsonrpc_error(id, code, &msg);
1614                    }
1615                };
1616
1617            let mcp_gen_ctx = GenContext {
1618                jwt_sub: claims
1619                    .as_ref()
1620                    .map(|claims| claims.sub.clone())
1621                    .unwrap_or_else(|| "dev".into()),
1622                jwt_scope: claims
1623                    .as_ref()
1624                    .map(|claims| claims.scope.clone())
1625                    .unwrap_or_else(|| "*".into()),
1626                tool_name: tool_name.to_string(),
1627                timestamp: crate::core::jwt::now_secs(),
1628                jwt_token: bearer_token.clone(),
1629            };
1630
1631            let result = if provider.is_mcp() {
1632                mcp_client::execute_with_gen(
1633                    provider,
1634                    tool_name,
1635                    &arguments,
1636                    &state.keyring,
1637                    Some(&mcp_gen_ctx),
1638                    Some(&state.auth_cache),
1639                    override_mcp_url.as_deref(),
1640                )
1641                .await
1642            } else if provider.is_cli() {
1643                // Convert arguments map to CLI-style args for MCP passthrough
1644                let raw: Vec<String> = arguments
1645                    .iter()
1646                    .flat_map(|(k, v)| {
1647                        let val = match v {
1648                            Value::String(s) => s.clone(),
1649                            other => other.to_string(),
1650                        };
1651                        vec![format!("--{k}"), val]
1652                    })
1653                    .collect();
1654                crate::core::cli_executor::execute_with_gen(
1655                    provider,
1656                    &raw,
1657                    &state.keyring,
1658                    Some(&mcp_gen_ctx),
1659                    Some(&state.auth_cache),
1660                )
1661                .await
1662                .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
1663            } else {
1664                match http::execute_tool_with_gen(
1665                    provider,
1666                    _tool,
1667                    &arguments,
1668                    &state.keyring,
1669                    Some(&mcp_gen_ctx),
1670                    Some(&state.auth_cache),
1671                )
1672                .await
1673                {
1674                    Ok(val) => Ok(val),
1675                    Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
1676                }
1677            };
1678
1679            match result {
1680                Ok(value) => {
1681                    let text = match &value {
1682                        Value::String(s) => s.clone(),
1683                        other => serde_json::to_string_pretty(other).unwrap_or_default(),
1684                    };
1685                    let mcp_result = serde_json::json!({
1686                        "content": [{"type": "text", "text": text}],
1687                        "isError": false,
1688                    });
1689                    jsonrpc_success(id, mcp_result)
1690                }
1691                Err(e) => {
1692                    let mcp_result = serde_json::json!({
1693                        "content": [{"type": "text", "text": format!("Error: {e}")}],
1694                        "isError": true,
1695                    });
1696                    jsonrpc_success(id, mcp_result)
1697                }
1698            }
1699        }
1700
1701        _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
1702    }
1703}
1704
1705fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
1706    (
1707        StatusCode::OK,
1708        Json(serde_json::json!({
1709            "jsonrpc": "2.0",
1710            "id": id,
1711            "result": result,
1712        })),
1713    )
1714}
1715
1716fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
1717    (
1718        StatusCode::OK,
1719        Json(serde_json::json!({
1720            "jsonrpc": "2.0",
1721            "id": id,
1722            "error": {
1723                "code": code,
1724                "message": message,
1725            }
1726        })),
1727    )
1728}
1729
1730// ---------------------------------------------------------------------------
1731// Tool endpoints
1732// ---------------------------------------------------------------------------
1733
1734/// GET /tools — list available tools with optional filters.
1735async fn handle_tools_list(
1736    State(state): State<Arc<ProxyState>>,
1737    claims: Option<Extension<TokenClaims>>,
1738    headers: axum::http::HeaderMap,
1739    axum::extract::Query(query): axum::extract::Query<ToolsQuery>,
1740) -> impl IntoResponse {
1741    tracing::debug!(
1742        provider = ?query.provider,
1743        search = ?query.search,
1744        "GET /tools"
1745    );
1746
1747    let claims = claims.map(|Extension(claims)| claims);
1748    let scopes = scopes_for_request(claims.as_ref(), &state);
1749    let all_tools = visible_tools_for_scopes(&state, &scopes);
1750
1751    let filtered: Vec<&(&Provider, &Tool)> = all_tools
1752        .iter()
1753        .filter(|(provider, tool)| {
1754            if let Some(ref p) = query.provider {
1755                if provider.name != *p {
1756                    return false;
1757                }
1758            }
1759            if let Some(ref q) = query.search {
1760                let q = q.to_lowercase();
1761                let name_match = tool.name.to_lowercase().contains(&q);
1762                let desc_match = tool.description.to_lowercase().contains(&q);
1763                let tag_match = tool.tags.iter().any(|t| t.to_lowercase().contains(&q));
1764                if !name_match && !desc_match && !tag_match {
1765                    return false;
1766                }
1767            }
1768            true
1769        })
1770        .collect();
1771
1772    // Issue #135: group filtered tools by provider so we run at most one
1773    // `tools/list` dial per provider per request, then serve schemas from
1774    // the resulting map. Without this, a 4-tool MCP provider would trigger
1775    // 4 redundant discoveries on the first request. Caching helps after
1776    // the first hit, but the in-request batching avoids 4 lock-and-dial
1777    // cycles on a cold cache.
1778    //
1779    // We record an entry for the provider on EVERY path through the per-
1780    // provider decision (including the URL-missing and discovery-failed
1781    // paths) so the `contains_key` short-circuit fires for every sibling
1782    // tool. An empty `HashMap` value is the sentinel for "tried and got
1783    // nothing"; the lookup loop below treats it the same as a true miss.
1784    // Greptile #136 finding.
1785    let mut schemas_by_provider: HashMap<String, HashMap<String, Value>> = HashMap::new();
1786    for (provider, tool) in filtered.iter().copied() {
1787        if tool.input_schema.is_some() {
1788            continue;
1789        }
1790        if provider.handler != "mcp" || provider.mcp_url_env.is_none() {
1791            continue;
1792        }
1793        if schemas_by_provider.contains_key(&provider.name) {
1794            continue;
1795        }
1796        let map = match extract_upstream_url(&state, provider, &headers) {
1797            Some(upstream_url) => lazy_fetch_schemas(&state, provider, &upstream_url)
1798                .await
1799                .unwrap_or_default(),
1800            None => HashMap::new(),
1801        };
1802        schemas_by_provider.insert(provider.name.clone(), map);
1803    }
1804
1805    let tools: Vec<Value> = filtered
1806        .into_iter()
1807        .map(|(provider, tool)| {
1808            let input_schema = tool.input_schema.clone().or_else(|| {
1809                schemas_by_provider
1810                    .get(&provider.name)
1811                    .and_then(|m| m.get(&tool.name).cloned())
1812            });
1813            serde_json::json!({
1814                "name": tool.name,
1815                "description": tool.description,
1816                "provider": provider.name,
1817                "method": format!("{:?}", tool.method),
1818                "tags": tool.tags,
1819                "skills": provider.skills,
1820                "input_schema": input_schema,
1821            })
1822        })
1823        .collect();
1824
1825    (StatusCode::OK, Json(Value::Array(tools)))
1826}
1827
1828/// GET /tools/:name — get detailed info about a specific tool.
1829async fn handle_tool_info(
1830    State(state): State<Arc<ProxyState>>,
1831    claims: Option<Extension<TokenClaims>>,
1832    headers: axum::http::HeaderMap,
1833    axum::extract::Path(name): axum::extract::Path<String>,
1834) -> impl IntoResponse {
1835    tracing::debug!(tool = %name, "GET /tools/:name");
1836
1837    let claims = claims.map(|Extension(claims)| claims);
1838    let scopes = scopes_for_request(claims.as_ref(), &state);
1839
1840    match state
1841        .registry
1842        .get_tool(&name)
1843        .filter(|(_, tool)| match &tool.scope {
1844            Some(scope) => scopes.is_allowed(scope),
1845            None => true,
1846        }) {
1847        Some((provider, tool)) => {
1848            // Merge skills from manifest + SkillRegistry (tool binding + provider binding)
1849            let mut skills: Vec<String> = provider.skills.clone();
1850            for s in state.skill_registry.skills_for_tool(&tool.name) {
1851                if !skills.contains(&s.name) {
1852                    skills.push(s.name.clone());
1853                }
1854            }
1855            for s in state.skill_registry.skills_for_provider(&provider.name) {
1856                if !skills.contains(&s.name) {
1857                    skills.push(s.name.clone());
1858                }
1859            }
1860
1861            // Issue #135: when the static manifest entry has no
1862            // input_schema AND this is a sandbox-supplied-URL MCP
1863            // provider AND the request carried a valid upstream URL,
1864            // lazily discover the schema via tools/list against that
1865            // upstream. Cached per (provider, url) for the process.
1866            let input_schema = resolve_lazy_input_schema(&state, provider, tool, &headers).await;
1867
1868            (
1869                StatusCode::OK,
1870                Json(serde_json::json!({
1871                    "name": tool.name,
1872                    "description": tool.description,
1873                    "provider": provider.name,
1874                    "method": format!("{:?}", tool.method),
1875                    "endpoint": tool.endpoint,
1876                    "tags": tool.tags,
1877                    "hint": tool.hint,
1878                    "skills": skills,
1879                    "input_schema": input_schema,
1880                    "scope": tool.scope,
1881                })),
1882            )
1883        }
1884        None => (
1885            StatusCode::NOT_FOUND,
1886            Json(serde_json::json!({"error": format!("Tool '{name}' not found")})),
1887        ),
1888    }
1889}
1890
1891/// Resolve the effective `input_schema` for a tool response, with the
1892/// issue #135 lazy-discovery overlay. Returns the static schema when
1893/// present, otherwise tries one MCP `tools/list` against the sandbox-
1894/// supplied upstream (if the request carries one for this provider),
1895/// otherwise falls back to `None` (today's behavior).
1896async fn resolve_lazy_input_schema(
1897    state: &ProxyState,
1898    provider: &Provider,
1899    tool: &Tool,
1900    headers: &axum::http::HeaderMap,
1901) -> Option<Value> {
1902    // Fast path — static schema present, no discovery needed. This is the
1903    // overwhelming majority of calls (every OpenAPI tool, every hand-
1904    // written HTTP tool, every MCP tool that bothered to mirror the
1905    // schema). The expensive path runs only for the issue #135 case:
1906    // mcp_url_env-driven providers whose operators chose to NOT mirror
1907    // the schema by hand.
1908    if let Some(ref schema) = tool.input_schema {
1909        return Some(schema.clone());
1910    }
1911    // Only mcp providers with mcp_url_env have a chance of discovery.
1912    if provider.handler != "mcp" || provider.mcp_url_env.is_none() {
1913        return None;
1914    }
1915    let upstream_url = extract_upstream_url(state, provider, headers)?;
1916    let schemas = lazy_fetch_schemas(state, provider, &upstream_url).await?;
1917    schemas.get(&tool.name).cloned()
1918}
1919
1920// ---------------------------------------------------------------------------
1921// Skill endpoints
1922// ---------------------------------------------------------------------------
1923
1924async fn handle_skills_list(
1925    State(state): State<Arc<ProxyState>>,
1926    claims: Option<Extension<TokenClaims>>,
1927    axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
1928) -> impl IntoResponse {
1929    tracing::debug!(
1930        category = ?query.category,
1931        provider = ?query.provider,
1932        tool = ?query.tool,
1933        search = ?query.search,
1934        "GET /skills"
1935    );
1936
1937    let claims = claims.map(|Extension(claims)| claims);
1938    let scopes = scopes_for_request(claims.as_ref(), &state);
1939    let visible_names = visible_skill_names(&state, &scopes);
1940
1941    let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
1942        state
1943            .skill_registry
1944            .search(search_query)
1945            .into_iter()
1946            .filter(|skill| visible_names.contains(&skill.name))
1947            .collect()
1948    } else if let Some(cat) = &query.category {
1949        state
1950            .skill_registry
1951            .skills_for_category(cat)
1952            .into_iter()
1953            .filter(|skill| visible_names.contains(&skill.name))
1954            .collect()
1955    } else if let Some(prov) = &query.provider {
1956        state
1957            .skill_registry
1958            .skills_for_provider(prov)
1959            .into_iter()
1960            .filter(|skill| visible_names.contains(&skill.name))
1961            .collect()
1962    } else if let Some(t) = &query.tool {
1963        state
1964            .skill_registry
1965            .skills_for_tool(t)
1966            .into_iter()
1967            .filter(|skill| visible_names.contains(&skill.name))
1968            .collect()
1969    } else {
1970        state
1971            .skill_registry
1972            .list_skills()
1973            .iter()
1974            .filter(|skill| visible_names.contains(&skill.name))
1975            .collect()
1976    };
1977
1978    let json: Vec<Value> = skills
1979        .iter()
1980        .map(|s| {
1981            serde_json::json!({
1982                "name": s.name,
1983                "version": s.version,
1984                "description": s.description,
1985                "tools": s.tools,
1986                "providers": s.providers,
1987                "categories": s.categories,
1988                "hint": s.hint,
1989            })
1990        })
1991        .collect();
1992
1993    (StatusCode::OK, Json(Value::Array(json)))
1994}
1995
1996async fn handle_skill_detail(
1997    State(state): State<Arc<ProxyState>>,
1998    claims: Option<Extension<TokenClaims>>,
1999    axum::extract::Path(name): axum::extract::Path<String>,
2000    axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
2001) -> impl IntoResponse {
2002    tracing::debug!(%name, meta = ?query.meta, refs = ?query.refs, "GET /skills/:name");
2003
2004    let claims = claims.map(|Extension(claims)| claims);
2005    let scopes = scopes_for_request(claims.as_ref(), &state);
2006    let visible_names = visible_skill_names(&state, &scopes);
2007
2008    let skill_meta = match state
2009        .skill_registry
2010        .get_skill(&name)
2011        .filter(|skill| visible_names.contains(&skill.name))
2012    {
2013        Some(s) => s,
2014        None => {
2015            return (
2016                StatusCode::NOT_FOUND,
2017                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
2018            );
2019        }
2020    };
2021
2022    if query.meta.unwrap_or(false) {
2023        return (
2024            StatusCode::OK,
2025            Json(serde_json::json!({
2026                "name": skill_meta.name,
2027                "version": skill_meta.version,
2028                "description": skill_meta.description,
2029                "author": skill_meta.author,
2030                "tools": skill_meta.tools,
2031                "providers": skill_meta.providers,
2032                "categories": skill_meta.categories,
2033                "keywords": skill_meta.keywords,
2034                "hint": skill_meta.hint,
2035                "depends_on": skill_meta.depends_on,
2036                "suggests": skill_meta.suggests,
2037                "license": skill_meta.license,
2038                "compatibility": skill_meta.compatibility,
2039                "allowed_tools": skill_meta.allowed_tools,
2040                "format": skill_meta.format,
2041            })),
2042        );
2043    }
2044
2045    let content = match state.skill_registry.read_content(&name) {
2046        Ok(c) => c,
2047        Err(e) => {
2048            return (
2049                StatusCode::INTERNAL_SERVER_ERROR,
2050                Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
2051            );
2052        }
2053    };
2054
2055    let mut response = serde_json::json!({
2056        "name": skill_meta.name,
2057        "version": skill_meta.version,
2058        "description": skill_meta.description,
2059        "content": content,
2060    });
2061
2062    if query.refs.unwrap_or(false) {
2063        if let Ok(refs) = state.skill_registry.list_references(&name) {
2064            response["references"] = serde_json::json!(refs);
2065        }
2066    }
2067
2068    (StatusCode::OK, Json(response))
2069}
2070
2071/// GET /skills/:name/bundle — return all files in a skill directory.
2072/// Response: `{"name": "...", "files": {"SKILL.md": "...", "scripts/generate.sh": "...", ...}}`
2073/// Binary files are base64-encoded; text files are returned as-is.
2074async fn handle_skill_bundle(
2075    State(state): State<Arc<ProxyState>>,
2076    claims: Option<Extension<TokenClaims>>,
2077    axum::extract::Path(name): axum::extract::Path<String>,
2078) -> impl IntoResponse {
2079    tracing::debug!(skill = %name, "GET /skills/:name/bundle");
2080
2081    let claims = claims.map(|Extension(claims)| claims);
2082    let scopes = scopes_for_request(claims.as_ref(), &state);
2083    let visible_names = visible_skill_names(&state, &scopes);
2084    if !visible_names.contains(&name) {
2085        return (
2086            StatusCode::NOT_FOUND,
2087            Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
2088        );
2089    }
2090
2091    let files = match state.skill_registry.bundle_files(&name) {
2092        Ok(f) => f,
2093        Err(_) => {
2094            return (
2095                StatusCode::NOT_FOUND,
2096                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
2097            );
2098        }
2099    };
2100
2101    // Convert bytes to strings (UTF-8 text) or base64 for binary files
2102    let mut file_map = serde_json::Map::new();
2103    for (path, data) in &files {
2104        match std::str::from_utf8(data) {
2105            Ok(text) => {
2106                file_map.insert(path.clone(), Value::String(text.to_string()));
2107            }
2108            Err(_) => {
2109                // Binary file — base64 encode
2110                use base64::Engine;
2111                let encoded = base64::engine::general_purpose::STANDARD.encode(data);
2112                file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
2113            }
2114        }
2115    }
2116
2117    (
2118        StatusCode::OK,
2119        Json(serde_json::json!({
2120            "name": name,
2121            "files": file_map,
2122        })),
2123    )
2124}
2125
2126/// POST /skills/bundle — return all files for multiple skills in one response.
2127/// Request: `{"names": ["fal-generate", "compliance-screening"]}`
2128/// Response: `{"skills": {...}, "missing": [...]}`
2129async fn handle_skills_bundle_batch(
2130    State(state): State<Arc<ProxyState>>,
2131    claims: Option<Extension<TokenClaims>>,
2132    Json(req): Json<SkillBundleBatchRequest>,
2133) -> impl IntoResponse {
2134    const MAX_BATCH: usize = 50;
2135    if req.names.len() > MAX_BATCH {
2136        return (
2137            StatusCode::BAD_REQUEST,
2138            Json(
2139                serde_json::json!({"error": format!("batch size {} exceeds limit of {MAX_BATCH}", req.names.len())}),
2140            ),
2141        );
2142    }
2143
2144    tracing::debug!(names = ?req.names, "POST /skills/bundle");
2145
2146    let claims = claims.map(|Extension(claims)| claims);
2147    let scopes = scopes_for_request(claims.as_ref(), &state);
2148    let visible_names = visible_skill_names(&state, &scopes);
2149
2150    let mut result = serde_json::Map::new();
2151    let mut missing: Vec<String> = Vec::new();
2152
2153    for name in &req.names {
2154        if !visible_names.contains(name) {
2155            missing.push(name.clone());
2156            continue;
2157        }
2158        let files = match state.skill_registry.bundle_files(name) {
2159            Ok(f) => f,
2160            Err(_) => {
2161                missing.push(name.clone());
2162                continue;
2163            }
2164        };
2165
2166        let mut file_map = serde_json::Map::new();
2167        for (path, data) in &files {
2168            match std::str::from_utf8(data) {
2169                Ok(text) => {
2170                    file_map.insert(path.clone(), Value::String(text.to_string()));
2171                }
2172                Err(_) => {
2173                    use base64::Engine;
2174                    let encoded = base64::engine::general_purpose::STANDARD.encode(data);
2175                    file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
2176                }
2177            }
2178        }
2179
2180        result.insert(name.clone(), serde_json::json!({ "files": file_map }));
2181    }
2182
2183    (
2184        StatusCode::OK,
2185        Json(serde_json::json!({ "skills": result, "missing": missing })),
2186    )
2187}
2188
2189async fn handle_skills_resolve(
2190    State(state): State<Arc<ProxyState>>,
2191    claims: Option<Extension<TokenClaims>>,
2192    Json(req): Json<SkillResolveRequest>,
2193) -> impl IntoResponse {
2194    tracing::debug!(scopes = ?req.scopes, include_content = req.include_content, "POST /skills/resolve");
2195
2196    let include_content = req.include_content;
2197    let request_scopes = ScopeConfig {
2198        scopes: req.scopes,
2199        sub: String::new(),
2200        expires_at: 0,
2201        rate_config: None,
2202    };
2203    let claims = claims.map(|Extension(claims)| claims);
2204    let caller_scopes = scopes_for_request(claims.as_ref(), &state);
2205    let visible_names = visible_skill_names(&state, &caller_scopes);
2206
2207    let resolved: Vec<&skill::SkillMeta> =
2208        skill::resolve_skills(&state.skill_registry, &state.registry, &request_scopes)
2209            .into_iter()
2210            .filter(|skill| visible_names.contains(&skill.name))
2211            .collect();
2212
2213    let json: Vec<Value> = resolved
2214        .iter()
2215        .map(|s| {
2216            let mut entry = serde_json::json!({
2217                "name": s.name,
2218                "version": s.version,
2219                "description": s.description,
2220                "tools": s.tools,
2221                "providers": s.providers,
2222                "categories": s.categories,
2223            });
2224            if include_content {
2225                if let Ok(content) = state.skill_registry.read_content(&s.name) {
2226                    entry["content"] = Value::String(content);
2227                }
2228            }
2229            entry
2230        })
2231        .collect();
2232
2233    (StatusCode::OK, Json(Value::Array(json)))
2234}
2235
2236fn skillati_client(keyring: &Keyring) -> Result<SkillAtiClient, SkillAtiError> {
2237    match SkillAtiClient::from_env(keyring)? {
2238        Some(client) => Ok(client),
2239        None => Err(SkillAtiError::NotConfigured),
2240    }
2241}
2242
2243async fn handle_skillati_catalog(
2244    State(state): State<Arc<ProxyState>>,
2245    claims: Option<Extension<TokenClaims>>,
2246    Query(query): Query<SkillAtiCatalogQuery>,
2247) -> impl IntoResponse {
2248    tracing::debug!(search = ?query.search, "GET /skillati/catalog");
2249
2250    let client = match skillati_client(&state.keyring) {
2251        Ok(client) => client,
2252        Err(err) => return skillati_error_response(err),
2253    };
2254
2255    let claims = claims.map(|Extension(c)| c);
2256    let scopes = scopes_for_request(claims.as_ref(), &state);
2257
2258    match client.catalog().await {
2259        Ok(catalog) => {
2260            // Union of local + remote visibility. Merging here (instead of
2261            // calling visible_skill_names_with_remote, which would re-fetch)
2262            // avoids a redundant catalog request on the hot path.
2263            let mut visible_names = visible_skill_names(&state, &scopes);
2264            visible_names.extend(visible_remote_skill_names(&state, &scopes, &catalog));
2265
2266            let mut skills: Vec<_> = catalog
2267                .into_iter()
2268                .filter(|s| visible_names.contains(&s.name))
2269                .collect();
2270            if let Some(search) = query.search.as_deref() {
2271                skills = SkillAtiClient::filter_catalog(&skills, search, 25);
2272            }
2273            (
2274                StatusCode::OK,
2275                Json(serde_json::json!({
2276                    "skills": skills,
2277                })),
2278            )
2279        }
2280        Err(err) => skillati_error_response(err),
2281    }
2282}
2283
2284async fn handle_skillati_read(
2285    State(state): State<Arc<ProxyState>>,
2286    claims: Option<Extension<TokenClaims>>,
2287    axum::extract::Path(name): axum::extract::Path<String>,
2288) -> impl IntoResponse {
2289    tracing::debug!(%name, "GET /skillati/:name");
2290
2291    let client = match skillati_client(&state.keyring) {
2292        Ok(client) => client,
2293        Err(err) => return skillati_error_response(err),
2294    };
2295
2296    let claims = claims.map(|Extension(c)| c);
2297    let scopes = scopes_for_request(claims.as_ref(), &state);
2298    let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
2299        Ok(v) => v,
2300        Err(err) => return skillati_error_response(err),
2301    };
2302    if !visible_names.contains(&name) {
2303        return skillati_error_response(SkillAtiError::SkillNotFound(name));
2304    }
2305
2306    match client.read_skill(&name).await {
2307        Ok(activation) => (StatusCode::OK, Json(serde_json::json!(activation))),
2308        Err(err) => skillati_error_response(err),
2309    }
2310}
2311
2312async fn handle_skillati_resources(
2313    State(state): State<Arc<ProxyState>>,
2314    claims: Option<Extension<TokenClaims>>,
2315    axum::extract::Path(name): axum::extract::Path<String>,
2316    Query(query): Query<SkillAtiResourcesQuery>,
2317) -> impl IntoResponse {
2318    tracing::debug!(%name, prefix = ?query.prefix, "GET /skillati/:name/resources");
2319
2320    let client = match skillati_client(&state.keyring) {
2321        Ok(client) => client,
2322        Err(err) => return skillati_error_response(err),
2323    };
2324
2325    let claims = claims.map(|Extension(c)| c);
2326    let scopes = scopes_for_request(claims.as_ref(), &state);
2327    let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
2328        Ok(v) => v,
2329        Err(err) => return skillati_error_response(err),
2330    };
2331    if !visible_names.contains(&name) {
2332        return skillati_error_response(SkillAtiError::SkillNotFound(name));
2333    }
2334
2335    match client.list_resources(&name, query.prefix.as_deref()).await {
2336        Ok(resources) => (
2337            StatusCode::OK,
2338            Json(serde_json::json!({
2339                "name": name,
2340                "prefix": query.prefix,
2341                "resources": resources,
2342            })),
2343        ),
2344        Err(err) => skillati_error_response(err),
2345    }
2346}
2347
2348async fn handle_skillati_file(
2349    State(state): State<Arc<ProxyState>>,
2350    claims: Option<Extension<TokenClaims>>,
2351    axum::extract::Path(name): axum::extract::Path<String>,
2352    Query(query): Query<SkillAtiFileQuery>,
2353) -> impl IntoResponse {
2354    tracing::debug!(%name, path = %query.path, "GET /skillati/:name/file");
2355
2356    let client = match skillati_client(&state.keyring) {
2357        Ok(client) => client,
2358        Err(err) => return skillati_error_response(err),
2359    };
2360
2361    let claims = claims.map(|Extension(c)| c);
2362    let scopes = scopes_for_request(claims.as_ref(), &state);
2363    let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
2364        Ok(v) => v,
2365        Err(err) => return skillati_error_response(err),
2366    };
2367    if !visible_names.contains(&name) {
2368        return skillati_error_response(SkillAtiError::SkillNotFound(name));
2369    }
2370
2371    match client.read_path(&name, &query.path).await {
2372        Ok(file) => (StatusCode::OK, Json(serde_json::json!(file))),
2373        Err(err) => skillati_error_response(err),
2374    }
2375}
2376
2377async fn handle_skillati_refs(
2378    State(state): State<Arc<ProxyState>>,
2379    claims: Option<Extension<TokenClaims>>,
2380    axum::extract::Path(name): axum::extract::Path<String>,
2381) -> impl IntoResponse {
2382    tracing::debug!(%name, "GET /skillati/:name/refs");
2383
2384    let client = match skillati_client(&state.keyring) {
2385        Ok(client) => client,
2386        Err(err) => return skillati_error_response(err),
2387    };
2388
2389    let claims = claims.map(|Extension(c)| c);
2390    let scopes = scopes_for_request(claims.as_ref(), &state);
2391    let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
2392        Ok(v) => v,
2393        Err(err) => return skillati_error_response(err),
2394    };
2395    if !visible_names.contains(&name) {
2396        return skillati_error_response(SkillAtiError::SkillNotFound(name));
2397    }
2398
2399    match client.list_references(&name).await {
2400        Ok(references) => (
2401            StatusCode::OK,
2402            Json(serde_json::json!({
2403                "name": name,
2404                "references": references,
2405            })),
2406        ),
2407        Err(err) => skillati_error_response(err),
2408    }
2409}
2410
2411async fn handle_skillati_ref(
2412    State(state): State<Arc<ProxyState>>,
2413    claims: Option<Extension<TokenClaims>>,
2414    axum::extract::Path((name, reference)): axum::extract::Path<(String, String)>,
2415) -> impl IntoResponse {
2416    tracing::debug!(%name, %reference, "GET /skillati/:name/ref/:reference");
2417
2418    let client = match skillati_client(&state.keyring) {
2419        Ok(client) => client,
2420        Err(err) => return skillati_error_response(err),
2421    };
2422
2423    let claims = claims.map(|Extension(c)| c);
2424    let scopes = scopes_for_request(claims.as_ref(), &state);
2425    let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
2426        Ok(v) => v,
2427        Err(err) => return skillati_error_response(err),
2428    };
2429    if !visible_names.contains(&name) {
2430        return skillati_error_response(SkillAtiError::SkillNotFound(name));
2431    }
2432
2433    match client.read_reference(&name, &reference).await {
2434        Ok(content) => (
2435            StatusCode::OK,
2436            Json(serde_json::json!({
2437                "name": name,
2438                "reference": reference,
2439                "content": content,
2440            })),
2441        ),
2442        Err(err) => skillati_error_response(err),
2443    }
2444}
2445
2446fn skillati_error_response(err: SkillAtiError) -> (StatusCode, Json<Value>) {
2447    let status = match &err {
2448        SkillAtiError::NotConfigured
2449        | SkillAtiError::UnsupportedRegistry(_)
2450        | SkillAtiError::MissingCredentials(_)
2451        | SkillAtiError::ProxyUrlRequired => StatusCode::SERVICE_UNAVAILABLE,
2452        SkillAtiError::SkillNotFound(_) | SkillAtiError::PathNotFound { .. } => {
2453            StatusCode::NOT_FOUND
2454        }
2455        SkillAtiError::InvalidPath(_) => StatusCode::BAD_REQUEST,
2456        SkillAtiError::Gcs(_)
2457        | SkillAtiError::ProxyRequest(_)
2458        | SkillAtiError::ProxyResponse(_) => StatusCode::BAD_GATEWAY,
2459    };
2460
2461    (
2462        status,
2463        Json(serde_json::json!({
2464            "error": err.to_string(),
2465        })),
2466    )
2467}
2468
2469// --- Auth middleware ---
2470
2471/// JWT authentication middleware.
2472///
2473/// - /health and /.well-known/jwks.json → skip auth
2474/// - JWT configured → validate Bearer token, attach claims to request extensions
2475/// - No JWT configured → allow all (dev mode)
2476async fn auth_middleware(
2477    State(state): State<Arc<ProxyState>>,
2478    mut req: HttpRequest<Body>,
2479    next: Next,
2480) -> Result<Response, StatusCode> {
2481    let path = req.uri().path();
2482
2483    // Skip auth for public endpoints
2484    if path == "/health" || path == "/.well-known/jwks.json" {
2485        return Ok(next.run(req).await);
2486    }
2487
2488    // If no JWT configured, allow all (dev mode)
2489    let jwt_config = match &state.jwt_config {
2490        Some(c) => c,
2491        None => return Ok(next.run(req).await),
2492    };
2493
2494    // Extract Authorization: Bearer <token>. We own the String so the
2495    // `req.headers()` immutable borrow ends before we touch
2496    // `req.extensions_mut()` below — needed because we now stash the raw
2497    // token in the BearerToken extension (issue #115).
2498    let token_owned: String = match req
2499        .headers()
2500        .get("authorization")
2501        .and_then(|v| v.to_str().ok())
2502    {
2503        Some(header) if header.starts_with("Bearer ") => header[7..].to_string(),
2504        _ => return Err(StatusCode::UNAUTHORIZED),
2505    };
2506
2507    // Validate JWT
2508    match jwt::validate(&token_owned, jwt_config) {
2509        Ok(claims) => {
2510            tracing::debug!(sub = %claims.sub, scopes = %claims.scope, "JWT validated");
2511            // Stash the *raw* bearer alongside the parsed claims so handlers
2512            // can forward it to auth_generator scripts via `${JWT_TOKEN}`.
2513            // We deliberately store this AFTER `validate()` succeeds — an
2514            // invalid/expired token never lands in extensions and so can
2515            // never reach a generator. See issue #115.
2516            req.extensions_mut().insert(BearerToken(token_owned));
2517            req.extensions_mut().insert(claims);
2518            Ok(next.run(req).await)
2519        }
2520        Err(e) => {
2521            tracing::debug!(error = %e, "JWT validation failed");
2522            Err(StatusCode::UNAUTHORIZED)
2523        }
2524    }
2525}
2526
/// Per-request extension carrying the *raw* inbound bearer (the `<token>`
/// from `Authorization: Bearer <token>`). Inserted by `auth_middleware` only
/// after `jwt::validate` succeeds, so an expired/forged token never leaks
/// into this extension. Consumed by handlers building a `GenContext` so an
/// auth_generator can forward the calling sandbox's identity to an upstream
/// MCP via `${JWT_TOKEN}` (issue #115).
///
/// Not present when JWT auth is disabled (dev mode), or on `/health` and
/// `/.well-known/jwks.json`, because `auth_middleware` forwards those requests
/// without inserting it. Handlers should therefore read it as optional.
#[derive(Debug, Clone)]
pub struct BearerToken(pub String);
2535
2536// --- Router builder ---
2537
2538/// Build the axum Router from a pre-constructed ProxyState.
2539/// Outer body-size ceiling for `POST /call`. Large enough to carry the worst
2540/// case `file_manager:upload` payload (`MAX_UPLOAD_BYTES` of raw bytes,
2541/// base64-inflated ~4/3×, plus a few KB of JSON framing).
2542///
2543/// Per-tool limits (`max_bytes`, `MAX_UPLOAD_BYTES`) plus JWT scopes + rate
2544/// limits are the real gates — this is just the outermost wrapper check.
2545fn max_call_body_bytes() -> usize {
2546    (crate::core::file_manager::MAX_UPLOAD_BYTES as usize)
2547        .saturating_mul(4)
2548        .saturating_div(3)
2549        .saturating_add(8 * 1024)
2550}
2551
2552pub fn build_router(state: Arc<ProxyState>) -> Router {
2553    use axum::extract::DefaultBodyLimit;
2554
2555    Router::new()
2556        .route("/call", post(handle_call))
2557        .route("/help", post(handle_help))
2558        .route("/mcp", post(handle_mcp))
2559        .route("/tools", get(handle_tools_list))
2560        .route("/tools/{name}", get(handle_tool_info))
2561        .route("/skills", get(handle_skills_list))
2562        .route("/skills/resolve", post(handle_skills_resolve))
2563        .route("/skills/bundle", post(handle_skills_bundle_batch))
2564        .route("/skills/{name}", get(handle_skill_detail))
2565        .route("/skills/{name}/bundle", get(handle_skill_bundle))
2566        .route("/skillati/catalog", get(handle_skillati_catalog))
2567        .route("/skillati/{name}", get(handle_skillati_read))
2568        .route("/skillati/{name}/resources", get(handle_skillati_resources))
2569        .route("/skillati/{name}/file", get(handle_skillati_file))
2570        .route("/skillati/{name}/refs", get(handle_skillati_refs))
2571        .route("/skillati/{name}/ref/{reference}", get(handle_skillati_ref))
2572        .route("/health", get(handle_health))
2573        .route("/.well-known/jwks.json", get(handle_jwks))
2574        // Raise axum's default 2 MB body-extractor limit so request bodies
2575        // carrying base64-encoded upload payloads aren't rejected before the
2576        // handler runs. `handle_call` still enforces its own
2577        // `max_call_body_bytes()` cap when streaming the body to bytes.
2578        .layer(DefaultBodyLimit::max(max_call_body_bytes()))
2579        .layer(middleware::from_fn_with_state(
2580            state.clone(),
2581            auth_middleware,
2582        ))
2583        .with_state(state)
2584}
2585
2586// --- Server startup ---
2587
2588/// Start the proxy server.
2589pub async fn run(
2590    port: u16,
2591    bind_addr: Option<String>,
2592    ati_dir: PathBuf,
2593    _verbose: bool,
2594    env_keys: bool,
2595) -> Result<(), Box<dyn std::error::Error>> {
2596    // Load manifests
2597    let manifests_dir = ati_dir.join("manifests");
2598    let mut registry = ManifestRegistry::load(&manifests_dir)?;
2599    let provider_count = registry.list_providers().len();
2600
2601    // Load keyring
2602    let keyring_source;
2603    let keyring = if env_keys {
2604        // --env-keys: scan ATI_KEY_* environment variables
2605        let kr = Keyring::from_env();
2606        let key_names = kr.key_names();
2607        tracing::info!(
2608            count = key_names.len(),
2609            "loaded API keys from ATI_KEY_* env vars"
2610        );
2611        for name in &key_names {
2612            tracing::debug!(key = %name, "env key loaded");
2613        }
2614        keyring_source = "env-vars (ATI_KEY_*)";
2615        kr
2616    } else {
2617        // Cascade: keyring.enc (sealed) → keyring.enc (persistent) → credentials → empty
2618        let keyring_path = ati_dir.join("keyring.enc");
2619        if keyring_path.exists() {
2620            if let Ok(kr) = Keyring::load(&keyring_path) {
2621                keyring_source = "keyring.enc (sealed key)";
2622                kr
2623            } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
2624                keyring_source = "keyring.enc (persistent key)";
2625                kr
2626            } else {
2627                tracing::warn!("keyring.enc exists but could not be decrypted");
2628                keyring_source = "empty (decryption failed)";
2629                Keyring::empty()
2630            }
2631        } else {
2632            let creds_path = ati_dir.join("credentials");
2633            if creds_path.exists() {
2634                match Keyring::load_credentials(&creds_path) {
2635                    Ok(kr) => {
2636                        keyring_source = "credentials (plaintext)";
2637                        kr
2638                    }
2639                    Err(e) => {
2640                        tracing::warn!(error = %e, "failed to load credentials");
2641                        keyring_source = "empty (credentials error)";
2642                        Keyring::empty()
2643                    }
2644                }
2645            } else {
2646                tracing::warn!("no keyring.enc or credentials found — running without API keys");
2647                tracing::warn!("tools requiring authentication will fail");
2648                keyring_source = "empty (no auth)";
2649                Keyring::empty()
2650            }
2651        }
2652    };
2653
2654    // Discover MCP tools at startup so they appear in GET /tools.
2655    // Runs concurrently across providers with 30s per-provider timeout.
2656    mcp_client::discover_all_mcp_tools(&mut registry, &keyring).await;
2657
2658    let tool_count = registry.list_public_tools().len();
2659
2660    // Log MCP and OpenAPI providers
2661    let mcp_providers: Vec<(String, String)> = registry
2662        .list_mcp_providers()
2663        .iter()
2664        .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
2665        .collect();
2666    let mcp_count = mcp_providers.len();
2667    let openapi_providers: Vec<String> = registry
2668        .list_openapi_providers()
2669        .iter()
2670        .map(|p| p.name.clone())
2671        .collect();
2672    let openapi_count = openapi_providers.len();
2673
2674    // Load installed/local skill registry only.
2675    let skills_dir = ati_dir.join("skills");
2676    let skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
2677        tracing::warn!(error = %e, "failed to load skills");
2678        SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
2679    });
2680
2681    if let Ok(registry_url) = std::env::var("ATI_SKILL_REGISTRY") {
2682        if registry_url.strip_prefix("gcs://").is_some() {
2683            tracing::info!(
2684                registry = %registry_url,
2685                "SkillATI remote registry configured for lazy reads"
2686            );
2687        } else {
2688            tracing::warn!(url = %registry_url, "SkillATI only supports gcs:// registries");
2689        }
2690    }
2691
2692    let skill_count = skill_registry.skill_count();
2693
2694    // Load JWT config from environment
2695    let jwt_config = match jwt::config_from_env() {
2696        Ok(config) => config,
2697        Err(e) => {
2698            tracing::warn!(error = %e, "JWT config error");
2699            None
2700        }
2701    };
2702
2703    let auth_status = if jwt_config.is_some() {
2704        "JWT enabled"
2705    } else {
2706        "DISABLED (no JWT keys configured)"
2707    };
2708
2709    // Build JWKS for the endpoint
2710    let jwks_json = jwt_config.as_ref().and_then(|config| {
2711        config
2712            .public_key_pem
2713            .as_ref()
2714            .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
2715    });
2716
2717    let state = Arc::new(ProxyState {
2718        registry,
2719        skill_registry,
2720        keyring,
2721        jwt_config,
2722        jwks_json,
2723        auth_cache: AuthCache::new(),
2724        upstream_url_allowlists: std::sync::Arc::new(std::sync::Mutex::new(
2725            std::collections::HashMap::new(),
2726        )),
2727        lazy_schema_cache: std::sync::Arc::new(std::sync::Mutex::new(
2728            std::collections::HashMap::new(),
2729        )),
2730    });
2731
2732    let app = build_router(state);
2733
2734    let addr: SocketAddr = if let Some(ref bind) = bind_addr {
2735        format!("{bind}:{port}").parse()?
2736    } else {
2737        SocketAddr::from(([127, 0, 0, 1], port))
2738    };
2739
2740    tracing::info!(
2741        version = env!("CARGO_PKG_VERSION"),
2742        %addr,
2743        auth = auth_status,
2744        ati_dir = %ati_dir.display(),
2745        tools = tool_count,
2746        providers = provider_count,
2747        mcp = mcp_count,
2748        openapi = openapi_count,
2749        skills = skill_count,
2750        keyring = keyring_source,
2751        "ATI proxy server starting"
2752    );
2753    for (name, transport) in &mcp_providers {
2754        tracing::info!(provider = %name, transport = %transport, "MCP provider");
2755    }
2756    for name in &openapi_providers {
2757        tracing::info!(provider = %name, "OpenAPI provider");
2758    }
2759
2760    let listener = tokio::net::TcpListener::bind(addr).await?;
2761    axum::serve(listener, app).await?;
2762
2763    Ok(())
2764}
2765
2766/// Dispatch a `file_manager:*` tool call. Returns either a JSON payload or an
2767/// (HTTP status, message) error for the caller to forward.
2768async fn dispatch_file_manager(
2769    tool_name: &str,
2770    args: &HashMap<String, Value>,
2771    provider: &Provider,
2772    keyring: &Keyring,
2773) -> Result<Value, (StatusCode, String)> {
2774    use crate::core::file_manager::{self, DownloadArgs, FileManagerError, UploadArgs};
2775
2776    // One mapping, derived from FileManagerError::http_status, so adding an
2777    // error variant can't silently regress one handler while the other updates.
2778    let to_resp = |e: FileManagerError| {
2779        let status =
2780            StatusCode::from_u16(e.http_status()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
2781        (status, e.to_string())
2782    };
2783
2784    match tool_name {
2785        "file_manager:download" => {
2786            let parsed = DownloadArgs::from_value(args).map_err(to_resp)?;
2787            let result = file_manager::fetch_bytes(&parsed).await.map_err(to_resp)?;
2788            Ok(file_manager::build_download_response(&result))
2789        }
2790        "file_manager:upload" => {
2791            let parsed = UploadArgs::from_wire(args).map_err(to_resp)?;
2792            file_manager::upload_to_destination(
2793                parsed,
2794                &provider.upload_destinations,
2795                provider.upload_default_destination.as_deref(),
2796                keyring,
2797            )
2798            .await
2799            .map_err(to_resp)
2800        }
2801        other => Err((
2802            StatusCode::NOT_FOUND,
2803            format!("Unknown file_manager tool: '{other}'"),
2804        )),
2805    }
2806}
2807
2808fn write_proxy_audit(
2809    call_req: &CallRequest,
2810    agent_sub: &str,
2811    claims: Option<&TokenClaims>,
2812    duration: std::time::Duration,
2813    error: Option<&str>,
2814) {
2815    let entry = crate::core::audit::AuditEntry {
2816        ts: chrono::Utc::now().to_rfc3339(),
2817        tool: call_req.tool_name.clone(),
2818        args: crate::core::audit::sanitize_args(&call_req.args),
2819        status: if error.is_some() {
2820            crate::core::audit::AuditStatus::Error
2821        } else {
2822            crate::core::audit::AuditStatus::Ok
2823        },
2824        duration_ms: duration.as_millis() as u64,
2825        agent_sub: agent_sub.to_string(),
2826        job_id: claims.and_then(|c| c.job_id.clone()),
2827        sandbox_id: claims.and_then(|c| c.sandbox_id.clone()),
2828        error: error.map(|s| s.to_string()),
2829        exit_code: None,
2830    };
2831    let _ = crate::core::audit::append(&entry);
2832}
2833
2834// --- Helpers ---
2835
/// System prompt template for the help assistant. Presumably this is the
/// unscoped variant; `build_scoped_prompt` builds the tool/provider-scoped
/// prompts.
///
/// It contains the literal placeholders `{tools}` and `{skills_section}`.
/// Because this is a raw string and not a `format!` template, the caller
/// presumably substitutes them (e.g. via `str::replace`) — confirm at the call
/// site. Any edit to this text changes what the model is told.
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, tell the agent to load them using the Skill tool (e.g., `skill: "research-financial-data"`)

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
2851
2852async fn build_remote_skillati_section(keyring: &Keyring, query: &str, limit: usize) -> String {
2853    let client = match SkillAtiClient::from_env(keyring) {
2854        Ok(Some(client)) => client,
2855        Ok(None) => return String::new(),
2856        Err(err) => {
2857            tracing::warn!(error = %err, "failed to initialize SkillATI catalog for proxy help");
2858            return String::new();
2859        }
2860    };
2861
2862    let catalog = match client.catalog().await {
2863        Ok(catalog) => catalog,
2864        Err(err) => {
2865            tracing::warn!(error = %err, "failed to load SkillATI catalog for proxy help");
2866            return String::new();
2867        }
2868    };
2869
2870    let matched = SkillAtiClient::filter_catalog(&catalog, query, limit);
2871    if matched.is_empty() {
2872        return String::new();
2873    }
2874
2875    render_remote_skillati_section(&matched, catalog.len())
2876}
2877
2878fn render_remote_skillati_section(skills: &[RemoteSkillMeta], total_catalog: usize) -> String {
2879    let mut section = String::from("## Remote Skills Available Via SkillATI\n\n");
2880    section.push_str(
2881        "These skills are available. Load them using the Skill tool (e.g., `skill: \"skill-name\"`).\n\n",
2882    );
2883
2884    for skill in skills {
2885        section.push_str(&format!("- **{}**: {}\n", skill.name, skill.description));
2886    }
2887
2888    if total_catalog > skills.len() {
2889        section.push_str(&format!(
2890            "\nOnly the most relevant {} remote skills are shown here.\n",
2891            skills.len()
2892        ));
2893    }
2894
2895    section
2896}
2897
/// Join help-prompt skill sections into one block.
///
/// Each section is trimmed and whitespace-only sections are dropped. The
/// survivors are joined with a blank line (`"\n\n"`). Returns `""` when
/// nothing survives.
fn merge_help_skill_sections(sections: &[String]) -> String {
    let mut kept: Vec<&str> = Vec::with_capacity(sections.len());
    for section in sections {
        let body = section.trim();
        if !body.is_empty() {
            kept.push(body);
        }
    }
    kept.join("\n\n")
}
2912
2913fn build_tool_context(
2914    tools: &[(
2915        &crate::core::manifest::Provider,
2916        &crate::core::manifest::Tool,
2917    )],
2918) -> String {
2919    let mut summaries = Vec::new();
2920    for (provider, tool) in tools {
2921        let mut summary = if let Some(cat) = &provider.category {
2922            format!(
2923                "- **{}** (provider: {}, category: {}): {}",
2924                tool.name, provider.name, cat, tool.description
2925            )
2926        } else {
2927            format!(
2928                "- **{}** (provider: {}): {}",
2929                tool.name, provider.name, tool.description
2930            )
2931        };
2932        if !tool.tags.is_empty() {
2933            summary.push_str(&format!("\n  Tags: {}", tool.tags.join(", ")));
2934        }
2935        // CLI tools: show passthrough usage
2936        if provider.is_cli() && tool.input_schema.is_none() {
2937            let cmd = provider.cli_command.as_deref().unwrap_or("?");
2938            summary.push_str(&format!(
2939                "\n  Usage: `ati run {} -- <args>`  (passthrough to `{}`)",
2940                tool.name, cmd
2941            ));
2942        } else if let Some(schema) = &tool.input_schema {
2943            if let Some(props) = schema.get("properties") {
2944                if let Some(obj) = props.as_object() {
2945                    let params: Vec<String> = obj
2946                        .iter()
2947                        .filter(|(_, v)| {
2948                            v.get("x-ati-param-location").is_none()
2949                                || v.get("description").is_some()
2950                        })
2951                        .map(|(k, v)| {
2952                            let type_str =
2953                                v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
2954                            let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
2955                            format!("    --{k} ({type_str}): {desc}")
2956                        })
2957                        .collect();
2958                    if !params.is_empty() {
2959                        summary.push_str("\n  Parameters:\n");
2960                        summary.push_str(&params.join("\n"));
2961                    }
2962                }
2963            }
2964        }
2965        summaries.push(summary);
2966    }
2967    summaries.join("\n\n")
2968}
2969
2970/// Build a scoped system prompt for a specific tool or provider.
2971///
2972/// Returns None if the scope_name doesn't match any tool or provider.
2973fn build_scoped_prompt(
2974    scope_name: &str,
2975    visible_tools: &[(&Provider, &Tool)],
2976    skills_section: &str,
2977) -> Option<String> {
2978    // Check if scope_name is a tool
2979    if let Some((provider, tool)) = visible_tools
2980        .iter()
2981        .find(|(_, tool)| tool.name == scope_name)
2982    {
2983        let mut details = format!(
2984            "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
2985            tool.name, provider.name, provider.handler, tool.description
2986        );
2987        if let Some(cat) = &provider.category {
2988            details.push_str(&format!("**Category**: {}\n", cat));
2989        }
2990        if provider.is_cli() {
2991            let cmd = provider.cli_command.as_deref().unwrap_or("?");
2992            details.push_str(&format!(
2993                "\n**Usage**: `ati run {} -- <args>`  (passthrough to `{}`)\n",
2994                tool.name, cmd
2995            ));
2996        } else if let Some(schema) = &tool.input_schema {
2997            if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
2998                let required: Vec<String> = schema
2999                    .get("required")
3000                    .and_then(|r| r.as_array())
3001                    .map(|arr| {
3002                        arr.iter()
3003                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
3004                            .collect()
3005                    })
3006                    .unwrap_or_default();
3007                details.push_str("\n**Parameters**:\n");
3008                for (key, val) in props {
3009                    let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
3010                    let desc = val
3011                        .get("description")
3012                        .and_then(|d| d.as_str())
3013                        .unwrap_or("");
3014                    let req = if required.contains(key) {
3015                        " **(required)**"
3016                    } else {
3017                        ""
3018                    };
3019                    details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
3020                }
3021            }
3022        }
3023
3024        let prompt = format!(
3025            "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
3026            ## Tool Details\n{}\n\n{}\n\n\
3027            Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
3028            tool.name, details, skills_section
3029        );
3030        return Some(prompt);
3031    }
3032
3033    // Check if scope_name is a provider
3034    let tools: Vec<(&Provider, &Tool)> = visible_tools
3035        .iter()
3036        .copied()
3037        .filter(|(provider, _)| provider.name == scope_name)
3038        .collect();
3039    if !tools.is_empty() {
3040        let tools_context = build_tool_context(&tools);
3041        let prompt = format!(
3042            "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
3043            ## Tools in provider `{}`\n{}\n\n{}\n\n\
3044            Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
3045            scope_name, scope_name, tools_context, skills_section
3046        );
3047        return Some(prompt);
3048    }
3049
3050    None
3051}