Skip to main content

ati/proxy/
server.rs

//! ATI proxy server — holds API keys and executes tool calls on behalf of sandbox agents.
//!
//! Authentication: ES256-signed JWT (or HS256 fallback). The JWT carries the caller's
//! identity, scopes, and expiry; static tokens and unsigned scope lists are no longer used.
//!
//! Usage: `ati proxy --port 8080 [--ati-dir ~/.ati]`
7use axum::{
8    body::Body,
9    extract::{Extension, Query, State},
10    http::{Request as HttpRequest, StatusCode},
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13    routing::{get, post},
14    Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::{ManifestRegistry, Provider, Tool};
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::skill::{self, SkillRegistry};
32use crate::core::skillati::{RemoteSkillMeta, SkillAtiClient, SkillAtiError};
33
/// Shared state for the proxy server.
///
/// Handlers receive it as `State<Arc<ProxyState>>`.
pub struct ProxyState {
    /// Tool/provider manifests; used to look up tools and list providers.
    pub registry: ManifestRegistry,
    /// Skills loaded from the local filesystem skills directory.
    pub skill_registry: SkillRegistry,
    /// Secret store that provider API keys are read from (e.g. the LLM key for `/help`).
    pub keyring: Keyring,
    /// JWT validation config (None = auth disabled / dev mode).
    pub jwt_config: Option<JwtConfig>,
    /// Pre-computed JWKS JSON for the /.well-known/jwks.json endpoint.
    pub jwks_json: Option<Value>,
    /// Shared cache for dynamically generated auth credentials.
    pub auth_cache: AuthCache,
}
46
47// --- Request/Response types ---
48
/// Body of `POST /call`: the tool to invoke and its arguments.
#[derive(Debug, Deserialize)]
pub struct CallRequest {
    /// Registry tool name. When the exact name is unknown, `handle_call` also
    /// tries underscore → colon variants (e.g. `finnhub_quote` → `finnhub:quote`).
    pub tool_name: String,
    /// Tool arguments — accepts a JSON object (key-value pairs) for HTTP/MCP/OpenAPI tools,
    /// or a JSON array of strings / a single string for CLI tools.
    /// The proxy auto-detects the handler type and routes accordingly.
    /// Defaults to an empty JSON object when omitted.
    #[serde(default = "default_args")]
    pub args: Value,
    /// Deprecated: use `args` with an array value instead.
    /// Kept for backward compatibility — if present, takes precedence for CLI tools.
    #[serde(default)]
    pub raw_args: Option<Vec<String>>,
}
62
63fn default_args() -> Value {
64    Value::Object(serde_json::Map::new())
65}
66
67impl CallRequest {
68    /// Extract args as a HashMap for HTTP/MCP/OpenAPI tools.
69    /// If `args` is a JSON object, returns its entries.
70    /// If `args` is something else (array, string), returns an empty map.
71    fn args_as_map(&self) -> HashMap<String, Value> {
72        match &self.args {
73            Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
74            _ => HashMap::new(),
75        }
76    }
77
78    /// Extract positional args for CLI tools.
79    /// Priority: explicit `raw_args` field > `args` array > `args` string > `args._positional` > empty.
80    fn args_as_positional(&self) -> Vec<String> {
81        // Backward compat: explicit raw_args wins
82        if let Some(ref raw) = self.raw_args {
83            return raw.clone();
84        }
85        match &self.args {
86            // ["pr", "list", "--repo", "X"]
87            Value::Array(arr) => arr
88                .iter()
89                .map(|v| match v {
90                    Value::String(s) => s.clone(),
91                    other => other.to_string(),
92                })
93                .collect(),
94            // "pr list --repo X"
95            Value::String(s) => s.split_whitespace().map(String::from).collect(),
96            // {"_positional": ["pr", "list"]} or {"--key": "value"} converted to CLI flags
97            Value::Object(map) => {
98                if let Some(Value::Array(pos)) = map.get("_positional") {
99                    return pos
100                        .iter()
101                        .map(|v| match v {
102                            Value::String(s) => s.clone(),
103                            other => other.to_string(),
104                        })
105                        .collect();
106                }
107                // Convert map entries to --key value pairs
108                let mut result = Vec::new();
109                for (k, v) in map {
110                    result.push(format!("--{k}"));
111                    match v {
112                        Value::String(s) => result.push(s.clone()),
113                        Value::Bool(true) => {} // flag, no value needed
114                        other => result.push(other.to_string()),
115                    }
116                }
117                result
118            }
119            _ => Vec::new(),
120        }
121    }
122}
123
/// Body returned by `POST /call`.
///
/// On success `result` holds the tool output. On failure `result` is `null`,
/// except for response-processing failures, where it carries the raw upstream
/// response.
#[derive(Debug, Serialize)]
pub struct CallResponse {
    /// Tool output (see the struct docs for the failure cases).
    pub result: Value,
    /// Error message; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
130
/// Body of `POST /help`: a natural-language question, optionally narrowed to one tool or provider.
#[derive(Debug, Deserialize)]
pub struct HelpRequest {
    /// The question; sent to the LLM as the user message.
    pub query: String,
    /// Optional tool or provider name; when set, the help prompt is scoped to it.
    #[serde(default)]
    pub tool: Option<String>,
}
137
/// Body returned by `POST /help`.
///
/// On success `content` holds the LLM answer. On failure `content` is empty
/// and `error` is set.
#[derive(Debug, Serialize)]
pub struct HelpResponse {
    /// LLM answer text (empty on error).
    pub content: String,
    /// Error message; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
144
/// Body returned by the health-check handler (`handle_health`).
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Always `"ok"` when the handler responds.
    pub status: String,
    /// Crate version (`CARGO_PKG_VERSION`).
    pub version: String,
    /// Number of public tools in the manifest registry.
    pub tools: usize,
    /// Number of providers in the manifest registry.
    pub providers: usize,
    /// Number of skills in the local skill registry.
    pub skills: usize,
    /// `"jwt"` when JWT validation is configured, otherwise `"disabled"`.
    pub auth: String,
}
154
155// --- Skill endpoint types ---
156
/// Optional filters for listing skills; every field may be omitted.
/// NOTE(review): the consuming handler is not visible here — confirm how each filter is applied.
#[derive(Debug, Deserialize)]
pub struct SkillsQuery {
    /// Filter by skill category.
    #[serde(default)]
    pub category: Option<String>,
    /// Filter by provider name.
    #[serde(default)]
    pub provider: Option<String>,
    /// Filter by tool name.
    #[serde(default)]
    pub tool: Option<String>,
    /// Search term.
    #[serde(default)]
    pub search: Option<String>,
}
168
/// Optional flags for a single-skill detail request.
#[derive(Debug, Deserialize)]
pub struct SkillDetailQuery {
    /// Presumably requests skill metadata in the response — confirm against the handler.
    #[serde(default)]
    pub meta: Option<bool>,
    /// Presumably requests the skill's reference files — confirm against the handler.
    #[serde(default)]
    pub refs: Option<bool>,
}
176
/// Request to resolve the skills granted by a list of scopes.
#[derive(Debug, Deserialize)]
pub struct SkillResolveRequest {
    /// Scope strings to resolve (e.g. `skill:X`, `tool:Y`).
    pub scopes: Vec<String>,
    /// When true, include SKILL.md content in each resolved skill.
    #[serde(default)]
    pub include_content: bool,
}
184
/// Request for the bundles of several skills at once.
#[derive(Debug, Deserialize)]
pub struct SkillBundleBatchRequest {
    /// Names of the skills to fetch.
    pub names: Vec<String>,
}
189
/// Query for the remote SkillATI catalog; all fields optional.
#[derive(Debug, Deserialize, Default)]
pub struct SkillAtiCatalogQuery {
    /// Optional search term.
    #[serde(default)]
    pub search: Option<String>,
}
195
/// Query for listing a remote SkillATI skill's resources; all fields optional.
#[derive(Debug, Deserialize, Default)]
pub struct SkillAtiResourcesQuery {
    /// Optional path prefix to narrow the listing.
    #[serde(default)]
    pub prefix: Option<String>,
}
201
/// Query for fetching a single file from a remote SkillATI skill.
#[derive(Debug, Deserialize)]
pub struct SkillAtiFileQuery {
    /// Path of the requested file (required).
    pub path: String,
}
206
207// --- Tool endpoint types ---
208
/// Optional filters for listing tools; every field may be omitted.
#[derive(Debug, Deserialize)]
pub struct ToolsQuery {
    /// Filter by provider name.
    #[serde(default)]
    pub provider: Option<String>,
    /// Search term.
    #[serde(default)]
    pub search: Option<String>,
}
216
217// --- Handlers ---
218
219fn scopes_for_request(claims: Option<&TokenClaims>, state: &ProxyState) -> ScopeConfig {
220    match claims {
221        Some(claims) => ScopeConfig::from_jwt(claims),
222        None if state.jwt_config.is_none() => ScopeConfig::unrestricted(),
223        None => ScopeConfig {
224            scopes: Vec::new(),
225            sub: String::new(),
226            expires_at: 0,
227            rate_config: None,
228        },
229    }
230}
231
232fn visible_tools_for_scopes<'a>(
233    state: &'a ProxyState,
234    scopes: &ScopeConfig,
235) -> Vec<(&'a Provider, &'a Tool)> {
236    crate::core::scope::filter_tools_by_scope(state.registry.list_public_tools(), scopes)
237}
238
239fn visible_skill_names(
240    state: &ProxyState,
241    scopes: &ScopeConfig,
242) -> std::collections::HashSet<String> {
243    skill::visible_skills(&state.skill_registry, &state.registry, scopes)
244        .into_iter()
245        .map(|skill| skill.name.clone())
246        .collect()
247}
248
249/// Compute the set of remote (SkillATI-registry) skill names that the caller's
250/// scopes grant access to.
251///
252/// Mirrors the scope cascade in `skill::resolve_skills` — explicit `skill:X`
253/// scopes, `tool:Y` scopes resolved to the tool's covering skills (including
254/// provider/category bindings) — but against a remote catalog whose skills
255/// are **not** present in the local filesystem `SkillRegistry`.
256///
257/// Without this, proxies running `ATI_SKILL_REGISTRY=gcs://...` with an empty
258/// local skills directory return 404 for every remote skill, because the
259/// visibility gate only consults `state.skill_registry` (see issue #59).
260fn visible_remote_skill_names(
261    state: &ProxyState,
262    scopes: &ScopeConfig,
263    catalog: &[RemoteSkillMeta],
264) -> std::collections::HashSet<String> {
265    let mut visible: std::collections::HashSet<String> = std::collections::HashSet::new();
266    if catalog.is_empty() {
267        return visible;
268    }
269    if scopes.is_wildcard() {
270        for entry in catalog {
271            visible.insert(entry.name.clone());
272        }
273        return visible;
274    }
275
276    // Collect allowed tool/provider/category identifiers from the caller's scopes.
277    // 1. Direct `tool:X` scopes (including wildcards) → walk against the public
278    //    tool registry to collect concrete (provider, tool) pairs.
279    let allowed_tool_pairs: Vec<(String, String)> =
280        crate::core::scope::filter_tools_by_scope(state.registry.list_public_tools(), scopes)
281            .into_iter()
282            .map(|(p, t)| (p.name.clone(), t.name.clone()))
283            .collect();
284    let allowed_tool_names: std::collections::HashSet<&str> =
285        allowed_tool_pairs.iter().map(|(_, t)| t.as_str()).collect();
286    let allowed_provider_names: std::collections::HashSet<&str> =
287        allowed_tool_pairs.iter().map(|(p, _)| p.as_str()).collect();
288    let allowed_categories: std::collections::HashSet<String> = state
289        .registry
290        .list_providers()
291        .into_iter()
292        .filter(|p| allowed_provider_names.contains(p.name.as_str()))
293        .filter_map(|p| p.category.clone())
294        .collect();
295
296    // Explicit `skill:X` scopes → include X if present in the remote catalog.
297    for scope in &scopes.scopes {
298        if let Some(skill_name) = scope.strip_prefix("skill:") {
299            if catalog.iter().any(|e| e.name == skill_name) {
300                visible.insert(skill_name.to_string());
301            }
302        }
303    }
304
305    // Tool/provider/category cascade → include a remote skill if any of its
306    // `tools`, `providers`, or `categories` bindings match a scope-allowed
307    // tool/provider/category.
308    for entry in catalog {
309        if entry
310            .tools
311            .iter()
312            .any(|t| allowed_tool_names.contains(t.as_str()))
313            || entry
314                .providers
315                .iter()
316                .any(|p| allowed_provider_names.contains(p.as_str()))
317            || entry
318                .categories
319                .iter()
320                .any(|c| allowed_categories.contains(c))
321        {
322            visible.insert(entry.name.clone());
323        }
324    }
325
326    visible
327}
328
329/// Union of local + remote visible skill names, computed on demand. The
330/// remote catalog is fetched lazily (and is cached inside `SkillAtiClient`
331/// after the first call on the hot path).
332async fn visible_skill_names_with_remote(
333    state: &ProxyState,
334    scopes: &ScopeConfig,
335    client: &SkillAtiClient,
336) -> Result<std::collections::HashSet<String>, SkillAtiError> {
337    let mut names = visible_skill_names(state, scopes);
338    let catalog = client.catalog().await?;
339    let remote = visible_remote_skill_names(state, scopes, &catalog);
340    names.extend(remote);
341    Ok(names)
342}
343
344async fn handle_call(
345    State(state): State<Arc<ProxyState>>,
346    req: HttpRequest<Body>,
347) -> impl IntoResponse {
348    // Extract JWT claims from request extensions (set by auth middleware)
349    let claims = req.extensions().get::<TokenClaims>().cloned();
350
351    // Parse request body. The ceiling must accommodate the worst-case upload
352    // payload: `file_manager::MAX_UPLOAD_BYTES` of raw bytes, base64-inflated
353    // (~1.34×), plus a few KB of JSON framing. Anti-abuse is enforced
354    // downstream by per-tool limits (`max_bytes` on downloads, `MAX_UPLOAD_BYTES`
355    // on uploads) and by JWT scope + rate limits — this is just the outer
356    // wire cap.
357    let body_bytes = match axum::body::to_bytes(req.into_body(), max_call_body_bytes()).await {
358        Ok(b) => b,
359        Err(e) => {
360            return (
361                StatusCode::BAD_REQUEST,
362                Json(CallResponse {
363                    result: Value::Null,
364                    error: Some(format!("Failed to read request body: {e}")),
365                }),
366            );
367        }
368    };
369
370    let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
371        Ok(r) => r,
372        Err(e) => {
373            return (
374                StatusCode::UNPROCESSABLE_ENTITY,
375                Json(CallResponse {
376                    result: Value::Null,
377                    error: Some(format!("Invalid request: {e}")),
378                }),
379            );
380        }
381    };
382
383    tracing::debug!(
384        tool = %call_req.tool_name,
385        args = ?call_req.args,
386        "POST /call"
387    );
388
389    // Look up tool in registry.
390    // If not found, try converting underscore format (finnhub_quote) to colon (finnhub:quote).
391    let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
392        Some(pt) => pt,
393        None => {
394            // Try underscore → colon conversion at each underscore position.
395            // "finnhub_quote" → try "finnhub:quote"
396            // "test_api_get_data" → try "test:api_get_data", "test_api:get_data"
397            let mut resolved = None;
398            for (idx, _) in call_req.tool_name.match_indices('_') {
399                let candidate = format!(
400                    "{}:{}",
401                    &call_req.tool_name[..idx],
402                    &call_req.tool_name[idx + 1..]
403                );
404                if let Some(pt) = state.registry.get_tool(&candidate) {
405                    tracing::debug!(
406                        original = %call_req.tool_name,
407                        resolved = %candidate,
408                        "resolved underscore tool name to colon format"
409                    );
410                    resolved = Some(pt);
411                    break;
412                }
413            }
414
415            match resolved {
416                Some(pt) => pt,
417                None => {
418                    return (
419                        StatusCode::NOT_FOUND,
420                        Json(CallResponse {
421                            result: Value::Null,
422                            error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
423                        }),
424                    );
425                }
426            }
427        }
428    };
429
430    // Scope enforcement from JWT claims
431    if let Some(tool_scope) = &tool.scope {
432        let scopes = match &claims {
433            Some(c) => ScopeConfig::from_jwt(c),
434            None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), // Dev mode
435            None => {
436                return (
437                    StatusCode::FORBIDDEN,
438                    Json(CallResponse {
439                        result: Value::Null,
440                        error: Some("Authentication required — no JWT provided".into()),
441                    }),
442                );
443            }
444        };
445
446        if !scopes.is_allowed(tool_scope) {
447            return (
448                StatusCode::FORBIDDEN,
449                Json(CallResponse {
450                    result: Value::Null,
451                    error: Some(format!(
452                        "Access denied: '{}' is not in your scopes",
453                        tool.name
454                    )),
455                }),
456            );
457        }
458    }
459
460    // Rate limit check
461    {
462        let scopes = match &claims {
463            Some(c) => ScopeConfig::from_jwt(c),
464            None => ScopeConfig::unrestricted(),
465        };
466        if let Some(ref rate_config) = scopes.rate_config {
467            if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
468                return (
469                    StatusCode::TOO_MANY_REQUESTS,
470                    Json(CallResponse {
471                        result: Value::Null,
472                        error: Some(format!("{e}")),
473                    }),
474                );
475            }
476        }
477    }
478
479    // Build auth generator context from JWT claims
480    let gen_ctx = GenContext {
481        jwt_sub: claims
482            .as_ref()
483            .map(|c| c.sub.clone())
484            .unwrap_or_else(|| "dev".into()),
485        jwt_scope: claims
486            .as_ref()
487            .map(|c| c.scope.clone())
488            .unwrap_or_else(|| "*".into()),
489        tool_name: call_req.tool_name.clone(),
490        timestamp: crate::core::jwt::now_secs(),
491    };
492
493    // Execute tool call — dispatch based on handler type, with timing for audit
494    let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
495    let job_id = claims
496        .as_ref()
497        .and_then(|c| c.job_id.clone())
498        .unwrap_or_default();
499    let sandbox_id = claims
500        .as_ref()
501        .and_then(|c| c.sandbox_id.clone())
502        .unwrap_or_default();
503    tracing::info!(
504        tool = %call_req.tool_name,
505        agent = %agent_sub,
506        job_id = %job_id,
507        sandbox_id = %sandbox_id,
508        "tool call"
509    );
510    let start = std::time::Instant::now();
511
512    let response = match provider.handler.as_str() {
513        "mcp" => {
514            let args_map = call_req.args_as_map();
515            match mcp_client::execute_with_gen(
516                provider,
517                &call_req.tool_name,
518                &args_map,
519                &state.keyring,
520                Some(&gen_ctx),
521                Some(&state.auth_cache),
522            )
523            .await
524            {
525                Ok(result) => (
526                    StatusCode::OK,
527                    Json(CallResponse {
528                        result,
529                        error: None,
530                    }),
531                ),
532                Err(e) => (
533                    StatusCode::BAD_GATEWAY,
534                    Json(CallResponse {
535                        result: Value::Null,
536                        error: Some(format!("MCP error: {e}")),
537                    }),
538                ),
539            }
540        }
541        "cli" => {
542            let positional = call_req.args_as_positional();
543            match crate::core::cli_executor::execute_with_gen(
544                provider,
545                &positional,
546                &state.keyring,
547                Some(&gen_ctx),
548                Some(&state.auth_cache),
549            )
550            .await
551            {
552                Ok(result) => (
553                    StatusCode::OK,
554                    Json(CallResponse {
555                        result,
556                        error: None,
557                    }),
558                ),
559                Err(e) => (
560                    StatusCode::BAD_GATEWAY,
561                    Json(CallResponse {
562                        result: Value::Null,
563                        error: Some(format!("CLI error: {e}")),
564                    }),
565                ),
566            }
567        }
568        "file_manager" => {
569            let args_map = call_req.args_as_map();
570            match dispatch_file_manager(&call_req.tool_name, &args_map, provider, &state.keyring)
571                .await
572            {
573                Ok(result) => (
574                    StatusCode::OK,
575                    Json(CallResponse {
576                        result,
577                        error: None,
578                    }),
579                ),
580                Err((status, msg)) => (
581                    status,
582                    Json(CallResponse {
583                        result: Value::Null,
584                        error: Some(msg),
585                    }),
586                ),
587            }
588        }
589        _ => {
590            let args_map = call_req.args_as_map();
591            let raw_response = match http::execute_tool_with_gen(
592                provider,
593                tool,
594                &args_map,
595                &state.keyring,
596                Some(&gen_ctx),
597                Some(&state.auth_cache),
598            )
599            .await
600            {
601                Ok(resp) => resp,
602                Err(e) => {
603                    let duration = start.elapsed();
604                    write_proxy_audit(
605                        &call_req,
606                        &agent_sub,
607                        claims.as_ref(),
608                        duration,
609                        Some(&e.to_string()),
610                    );
611                    return (
612                        StatusCode::BAD_GATEWAY,
613                        Json(CallResponse {
614                            result: Value::Null,
615                            error: Some(format!("Upstream API error: {e}")),
616                        }),
617                    );
618                }
619            };
620
621            let processed = match response::process_response(&raw_response, tool.response.as_ref())
622            {
623                Ok(p) => p,
624                Err(e) => {
625                    let duration = start.elapsed();
626                    write_proxy_audit(
627                        &call_req,
628                        &agent_sub,
629                        claims.as_ref(),
630                        duration,
631                        Some(&e.to_string()),
632                    );
633                    return (
634                        StatusCode::INTERNAL_SERVER_ERROR,
635                        Json(CallResponse {
636                            result: raw_response,
637                            error: Some(format!("Response processing error: {e}")),
638                        }),
639                    );
640                }
641            };
642
643            (
644                StatusCode::OK,
645                Json(CallResponse {
646                    result: processed,
647                    error: None,
648                }),
649            )
650        }
651    };
652
653    let duration = start.elapsed();
654    let error_msg = response.1.error.as_deref();
655    write_proxy_audit(&call_req, &agent_sub, claims.as_ref(), duration, error_msg);
656
657    response
658}
659
660async fn handle_help(
661    State(state): State<Arc<ProxyState>>,
662    claims: Option<Extension<TokenClaims>>,
663    Json(req): Json<HelpRequest>,
664) -> impl IntoResponse {
665    tracing::debug!(query = %req.query, tool = ?req.tool, "POST /help");
666
667    let claims = claims.map(|Extension(claims)| claims);
668    let scopes = scopes_for_request(claims.as_ref(), &state);
669
670    let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
671        Some(pt) => pt,
672        None => {
673            return (
674                StatusCode::SERVICE_UNAVAILABLE,
675                Json(HelpResponse {
676                    content: String::new(),
677                    error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
678                }),
679            );
680        }
681    };
682
683    let api_key = match llm_provider
684        .auth_key_name
685        .as_deref()
686        .and_then(|k| state.keyring.get(k))
687    {
688        Some(key) => key.to_string(),
689        None => {
690            return (
691                StatusCode::SERVICE_UNAVAILABLE,
692                Json(HelpResponse {
693                    content: String::new(),
694                    error: Some("LLM API key not found in keyring".into()),
695                }),
696            );
697        }
698    };
699
700    let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
701    let local_skills_section = if resolved_skills.is_empty() {
702        String::new()
703    } else {
704        format!(
705            "## Available Skills (methodology guides)\n{}",
706            skill::build_skill_context(&resolved_skills)
707        )
708    };
709    let remote_query = req
710        .tool
711        .as_ref()
712        .map(|tool| format!("{tool} {}", req.query))
713        .unwrap_or_else(|| req.query.clone());
714    let remote_skills_section =
715        build_remote_skillati_section(&state.keyring, &remote_query, 12).await;
716    let skills_section = merge_help_skill_sections(&[local_skills_section, remote_skills_section]);
717
718    // Build system prompt — scoped or unscoped
719    let visible_tools = visible_tools_for_scopes(&state, &scopes);
720    let system_prompt = if let Some(ref tool_name) = req.tool {
721        // Scoped mode: narrow tools to the specified tool or provider
722        match build_scoped_prompt(tool_name, &visible_tools, &skills_section) {
723            Some(prompt) => prompt,
724            None => {
725                return (
726                    StatusCode::FORBIDDEN,
727                    Json(HelpResponse {
728                        content: String::new(),
729                        error: Some(format!(
730                            "Scope '{tool_name}' is not visible in your current scopes."
731                        )),
732                    }),
733                );
734            }
735        }
736    } else {
737        let tools_context = build_tool_context(&visible_tools);
738        HELP_SYSTEM_PROMPT
739            .replace("{tools}", &tools_context)
740            .replace("{skills_section}", &skills_section)
741    };
742
743    let request_body = serde_json::json!({
744        "model": "zai-glm-4.7",
745        "messages": [
746            {"role": "system", "content": system_prompt},
747            {"role": "user", "content": req.query}
748        ],
749        "max_completion_tokens": 1536,
750        "temperature": 0.3
751    });
752
753    let client = reqwest::Client::new();
754    let url = format!(
755        "{}{}",
756        llm_provider.base_url.trim_end_matches('/'),
757        llm_tool.endpoint
758    );
759
760    let response = match client
761        .post(&url)
762        .bearer_auth(&api_key)
763        .json(&request_body)
764        .send()
765        .await
766    {
767        Ok(r) => r,
768        Err(e) => {
769            return (
770                StatusCode::BAD_GATEWAY,
771                Json(HelpResponse {
772                    content: String::new(),
773                    error: Some(format!("LLM request failed: {e}")),
774                }),
775            );
776        }
777    };
778
779    if !response.status().is_success() {
780        let status = response.status();
781        let body = response.text().await.unwrap_or_default();
782        return (
783            StatusCode::BAD_GATEWAY,
784            Json(HelpResponse {
785                content: String::new(),
786                error: Some(format!("LLM API error ({status}): {body}")),
787            }),
788        );
789    }
790
791    let body: Value = match response.json().await {
792        Ok(b) => b,
793        Err(e) => {
794            return (
795                StatusCode::INTERNAL_SERVER_ERROR,
796                Json(HelpResponse {
797                    content: String::new(),
798                    error: Some(format!("Failed to parse LLM response: {e}")),
799                }),
800            );
801        }
802    };
803
804    let content = body
805        .pointer("/choices/0/message/content")
806        .and_then(|c| c.as_str())
807        .unwrap_or("No response from LLM")
808        .to_string();
809
810    (
811        StatusCode::OK,
812        Json(HelpResponse {
813            content,
814            error: None,
815        }),
816    )
817}
818
819async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
820    let auth = if state.jwt_config.is_some() {
821        "jwt"
822    } else {
823        "disabled"
824    };
825
826    Json(HealthResponse {
827        status: "ok".into(),
828        version: env!("CARGO_PKG_VERSION").into(),
829        tools: state.registry.list_public_tools().len(),
830        providers: state.registry.list_providers().len(),
831        skills: state.skill_registry.skill_count(),
832        auth: auth.into(),
833    })
834}
835
836/// GET /.well-known/jwks.json — serves the public key for JWT validation.
837async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
838    match &state.jwks_json {
839        Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
840        None => (
841            StatusCode::NOT_FOUND,
842            Json(serde_json::json!({"error": "JWKS not configured"})),
843        ),
844    }
845}
846
847// ---------------------------------------------------------------------------
848// POST /mcp — MCP JSON-RPC proxy endpoint
849// ---------------------------------------------------------------------------
850
851async fn handle_mcp(
852    State(state): State<Arc<ProxyState>>,
853    claims: Option<Extension<TokenClaims>>,
854    Json(msg): Json<Value>,
855) -> impl IntoResponse {
856    let claims = claims.map(|Extension(claims)| claims);
857    let scopes = scopes_for_request(claims.as_ref(), &state);
858    let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
859    let id = msg.get("id").cloned();
860    tracing::info!(
861        %method,
862        agent = claims.as_ref().map(|c| c.sub.as_str()).unwrap_or(""),
863        job_id = claims.as_ref().and_then(|c| c.job_id.as_deref()).unwrap_or(""),
864        sandbox_id = claims.as_ref().and_then(|c| c.sandbox_id.as_deref()).unwrap_or(""),
865        "mcp call"
866    );
867
868    match method {
869        "initialize" => {
870            let result = serde_json::json!({
871                "protocolVersion": "2025-03-26",
872                "capabilities": {
873                    "tools": { "listChanged": false }
874                },
875                "serverInfo": {
876                    "name": "ati-proxy",
877                    "version": env!("CARGO_PKG_VERSION")
878                }
879            });
880            jsonrpc_success(id, result)
881        }
882
883        "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
884
885        "tools/list" => {
886            let visible_tools = visible_tools_for_scopes(&state, &scopes);
887            let mcp_tools: Vec<Value> = visible_tools
888                .iter()
889                .map(|(_provider, tool)| {
890                    serde_json::json!({
891                        "name": tool.name,
892                        "description": tool.description,
893                        "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
894                            "type": "object",
895                            "properties": {}
896                        }))
897                    })
898                })
899                .collect();
900
901            let result = serde_json::json!({
902                "tools": mcp_tools,
903            });
904            jsonrpc_success(id, result)
905        }
906
907        "tools/call" => {
908            let params = msg.get("params").cloned().unwrap_or(Value::Null);
909            let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
910            let arguments: HashMap<String, Value> = params
911                .get("arguments")
912                .and_then(|a| serde_json::from_value(a.clone()).ok())
913                .unwrap_or_default();
914
915            if tool_name.is_empty() {
916                return jsonrpc_error(id, -32602, "Missing tool name in params.name");
917            }
918
919            let (provider, _tool) = match state.registry.get_tool(tool_name) {
920                Some(pt) => pt,
921                None => {
922                    return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
923                }
924            };
925
926            if let Some(tool_scope) = &_tool.scope {
927                if !scopes.is_allowed(tool_scope) {
928                    return jsonrpc_error(
929                        id,
930                        -32001,
931                        &format!("Access denied: '{}' is not in your scopes", _tool.name),
932                    );
933                }
934            }
935
936            tracing::debug!(%tool_name, provider = %provider.name, "MCP tools/call");
937
938            let mcp_gen_ctx = GenContext {
939                jwt_sub: claims
940                    .as_ref()
941                    .map(|claims| claims.sub.clone())
942                    .unwrap_or_else(|| "dev".into()),
943                jwt_scope: claims
944                    .as_ref()
945                    .map(|claims| claims.scope.clone())
946                    .unwrap_or_else(|| "*".into()),
947                tool_name: tool_name.to_string(),
948                timestamp: crate::core::jwt::now_secs(),
949            };
950
951            let result = if provider.is_mcp() {
952                mcp_client::execute_with_gen(
953                    provider,
954                    tool_name,
955                    &arguments,
956                    &state.keyring,
957                    Some(&mcp_gen_ctx),
958                    Some(&state.auth_cache),
959                )
960                .await
961            } else if provider.is_cli() {
962                // Convert arguments map to CLI-style args for MCP passthrough
963                let raw: Vec<String> = arguments
964                    .iter()
965                    .flat_map(|(k, v)| {
966                        let val = match v {
967                            Value::String(s) => s.clone(),
968                            other => other.to_string(),
969                        };
970                        vec![format!("--{k}"), val]
971                    })
972                    .collect();
973                crate::core::cli_executor::execute_with_gen(
974                    provider,
975                    &raw,
976                    &state.keyring,
977                    Some(&mcp_gen_ctx),
978                    Some(&state.auth_cache),
979                )
980                .await
981                .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
982            } else {
983                match http::execute_tool_with_gen(
984                    provider,
985                    _tool,
986                    &arguments,
987                    &state.keyring,
988                    Some(&mcp_gen_ctx),
989                    Some(&state.auth_cache),
990                )
991                .await
992                {
993                    Ok(val) => Ok(val),
994                    Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
995                }
996            };
997
998            match result {
999                Ok(value) => {
1000                    let text = match &value {
1001                        Value::String(s) => s.clone(),
1002                        other => serde_json::to_string_pretty(other).unwrap_or_default(),
1003                    };
1004                    let mcp_result = serde_json::json!({
1005                        "content": [{"type": "text", "text": text}],
1006                        "isError": false,
1007                    });
1008                    jsonrpc_success(id, mcp_result)
1009                }
1010                Err(e) => {
1011                    let mcp_result = serde_json::json!({
1012                        "content": [{"type": "text", "text": format!("Error: {e}")}],
1013                        "isError": true,
1014                    });
1015                    jsonrpc_success(id, mcp_result)
1016                }
1017            }
1018        }
1019
1020        _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
1021    }
1022}
1023
1024fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
1025    (
1026        StatusCode::OK,
1027        Json(serde_json::json!({
1028            "jsonrpc": "2.0",
1029            "id": id,
1030            "result": result,
1031        })),
1032    )
1033}
1034
1035fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
1036    (
1037        StatusCode::OK,
1038        Json(serde_json::json!({
1039            "jsonrpc": "2.0",
1040            "id": id,
1041            "error": {
1042                "code": code,
1043                "message": message,
1044            }
1045        })),
1046    )
1047}
1048
1049// ---------------------------------------------------------------------------
1050// Tool endpoints
1051// ---------------------------------------------------------------------------
1052
1053/// GET /tools — list available tools with optional filters.
1054async fn handle_tools_list(
1055    State(state): State<Arc<ProxyState>>,
1056    claims: Option<Extension<TokenClaims>>,
1057    axum::extract::Query(query): axum::extract::Query<ToolsQuery>,
1058) -> impl IntoResponse {
1059    tracing::debug!(
1060        provider = ?query.provider,
1061        search = ?query.search,
1062        "GET /tools"
1063    );
1064
1065    let claims = claims.map(|Extension(claims)| claims);
1066    let scopes = scopes_for_request(claims.as_ref(), &state);
1067    let all_tools = visible_tools_for_scopes(&state, &scopes);
1068
1069    let tools: Vec<Value> = all_tools
1070        .iter()
1071        .filter(|(provider, tool)| {
1072            if let Some(ref p) = query.provider {
1073                if provider.name != *p {
1074                    return false;
1075                }
1076            }
1077            if let Some(ref q) = query.search {
1078                let q = q.to_lowercase();
1079                let name_match = tool.name.to_lowercase().contains(&q);
1080                let desc_match = tool.description.to_lowercase().contains(&q);
1081                let tag_match = tool.tags.iter().any(|t| t.to_lowercase().contains(&q));
1082                if !name_match && !desc_match && !tag_match {
1083                    return false;
1084                }
1085            }
1086            true
1087        })
1088        .map(|(provider, tool)| {
1089            serde_json::json!({
1090                "name": tool.name,
1091                "description": tool.description,
1092                "provider": provider.name,
1093                "method": format!("{:?}", tool.method),
1094                "tags": tool.tags,
1095                "skills": provider.skills,
1096                "input_schema": tool.input_schema,
1097            })
1098        })
1099        .collect();
1100
1101    (StatusCode::OK, Json(Value::Array(tools)))
1102}
1103
1104/// GET /tools/:name — get detailed info about a specific tool.
1105async fn handle_tool_info(
1106    State(state): State<Arc<ProxyState>>,
1107    claims: Option<Extension<TokenClaims>>,
1108    axum::extract::Path(name): axum::extract::Path<String>,
1109) -> impl IntoResponse {
1110    tracing::debug!(tool = %name, "GET /tools/:name");
1111
1112    let claims = claims.map(|Extension(claims)| claims);
1113    let scopes = scopes_for_request(claims.as_ref(), &state);
1114
1115    match state
1116        .registry
1117        .get_tool(&name)
1118        .filter(|(_, tool)| match &tool.scope {
1119            Some(scope) => scopes.is_allowed(scope),
1120            None => true,
1121        }) {
1122        Some((provider, tool)) => {
1123            // Merge skills from manifest + SkillRegistry (tool binding + provider binding)
1124            let mut skills: Vec<String> = provider.skills.clone();
1125            for s in state.skill_registry.skills_for_tool(&tool.name) {
1126                if !skills.contains(&s.name) {
1127                    skills.push(s.name.clone());
1128                }
1129            }
1130            for s in state.skill_registry.skills_for_provider(&provider.name) {
1131                if !skills.contains(&s.name) {
1132                    skills.push(s.name.clone());
1133                }
1134            }
1135
1136            (
1137                StatusCode::OK,
1138                Json(serde_json::json!({
1139                    "name": tool.name,
1140                    "description": tool.description,
1141                    "provider": provider.name,
1142                    "method": format!("{:?}", tool.method),
1143                    "endpoint": tool.endpoint,
1144                    "tags": tool.tags,
1145                    "hint": tool.hint,
1146                    "skills": skills,
1147                    "input_schema": tool.input_schema,
1148                    "scope": tool.scope,
1149                })),
1150            )
1151        }
1152        None => (
1153            StatusCode::NOT_FOUND,
1154            Json(serde_json::json!({"error": format!("Tool '{name}' not found")})),
1155        ),
1156    }
1157}
1158
1159// ---------------------------------------------------------------------------
1160// Skill endpoints
1161// ---------------------------------------------------------------------------
1162
1163async fn handle_skills_list(
1164    State(state): State<Arc<ProxyState>>,
1165    claims: Option<Extension<TokenClaims>>,
1166    axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
1167) -> impl IntoResponse {
1168    tracing::debug!(
1169        category = ?query.category,
1170        provider = ?query.provider,
1171        tool = ?query.tool,
1172        search = ?query.search,
1173        "GET /skills"
1174    );
1175
1176    let claims = claims.map(|Extension(claims)| claims);
1177    let scopes = scopes_for_request(claims.as_ref(), &state);
1178    let visible_names = visible_skill_names(&state, &scopes);
1179
1180    let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
1181        state
1182            .skill_registry
1183            .search(search_query)
1184            .into_iter()
1185            .filter(|skill| visible_names.contains(&skill.name))
1186            .collect()
1187    } else if let Some(cat) = &query.category {
1188        state
1189            .skill_registry
1190            .skills_for_category(cat)
1191            .into_iter()
1192            .filter(|skill| visible_names.contains(&skill.name))
1193            .collect()
1194    } else if let Some(prov) = &query.provider {
1195        state
1196            .skill_registry
1197            .skills_for_provider(prov)
1198            .into_iter()
1199            .filter(|skill| visible_names.contains(&skill.name))
1200            .collect()
1201    } else if let Some(t) = &query.tool {
1202        state
1203            .skill_registry
1204            .skills_for_tool(t)
1205            .into_iter()
1206            .filter(|skill| visible_names.contains(&skill.name))
1207            .collect()
1208    } else {
1209        state
1210            .skill_registry
1211            .list_skills()
1212            .iter()
1213            .filter(|skill| visible_names.contains(&skill.name))
1214            .collect()
1215    };
1216
1217    let json: Vec<Value> = skills
1218        .iter()
1219        .map(|s| {
1220            serde_json::json!({
1221                "name": s.name,
1222                "version": s.version,
1223                "description": s.description,
1224                "tools": s.tools,
1225                "providers": s.providers,
1226                "categories": s.categories,
1227                "hint": s.hint,
1228            })
1229        })
1230        .collect();
1231
1232    (StatusCode::OK, Json(Value::Array(json)))
1233}
1234
1235async fn handle_skill_detail(
1236    State(state): State<Arc<ProxyState>>,
1237    claims: Option<Extension<TokenClaims>>,
1238    axum::extract::Path(name): axum::extract::Path<String>,
1239    axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
1240) -> impl IntoResponse {
1241    tracing::debug!(%name, meta = ?query.meta, refs = ?query.refs, "GET /skills/:name");
1242
1243    let claims = claims.map(|Extension(claims)| claims);
1244    let scopes = scopes_for_request(claims.as_ref(), &state);
1245    let visible_names = visible_skill_names(&state, &scopes);
1246
1247    let skill_meta = match state
1248        .skill_registry
1249        .get_skill(&name)
1250        .filter(|skill| visible_names.contains(&skill.name))
1251    {
1252        Some(s) => s,
1253        None => {
1254            return (
1255                StatusCode::NOT_FOUND,
1256                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1257            );
1258        }
1259    };
1260
1261    if query.meta.unwrap_or(false) {
1262        return (
1263            StatusCode::OK,
1264            Json(serde_json::json!({
1265                "name": skill_meta.name,
1266                "version": skill_meta.version,
1267                "description": skill_meta.description,
1268                "author": skill_meta.author,
1269                "tools": skill_meta.tools,
1270                "providers": skill_meta.providers,
1271                "categories": skill_meta.categories,
1272                "keywords": skill_meta.keywords,
1273                "hint": skill_meta.hint,
1274                "depends_on": skill_meta.depends_on,
1275                "suggests": skill_meta.suggests,
1276                "license": skill_meta.license,
1277                "compatibility": skill_meta.compatibility,
1278                "allowed_tools": skill_meta.allowed_tools,
1279                "format": skill_meta.format,
1280            })),
1281        );
1282    }
1283
1284    let content = match state.skill_registry.read_content(&name) {
1285        Ok(c) => c,
1286        Err(e) => {
1287            return (
1288                StatusCode::INTERNAL_SERVER_ERROR,
1289                Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
1290            );
1291        }
1292    };
1293
1294    let mut response = serde_json::json!({
1295        "name": skill_meta.name,
1296        "version": skill_meta.version,
1297        "description": skill_meta.description,
1298        "content": content,
1299    });
1300
1301    if query.refs.unwrap_or(false) {
1302        if let Ok(refs) = state.skill_registry.list_references(&name) {
1303            response["references"] = serde_json::json!(refs);
1304        }
1305    }
1306
1307    (StatusCode::OK, Json(response))
1308}
1309
1310/// GET /skills/:name/bundle — return all files in a skill directory.
1311/// Response: `{"name": "...", "files": {"SKILL.md": "...", "scripts/generate.sh": "...", ...}}`
1312/// Binary files are base64-encoded; text files are returned as-is.
1313async fn handle_skill_bundle(
1314    State(state): State<Arc<ProxyState>>,
1315    claims: Option<Extension<TokenClaims>>,
1316    axum::extract::Path(name): axum::extract::Path<String>,
1317) -> impl IntoResponse {
1318    tracing::debug!(skill = %name, "GET /skills/:name/bundle");
1319
1320    let claims = claims.map(|Extension(claims)| claims);
1321    let scopes = scopes_for_request(claims.as_ref(), &state);
1322    let visible_names = visible_skill_names(&state, &scopes);
1323    if !visible_names.contains(&name) {
1324        return (
1325            StatusCode::NOT_FOUND,
1326            Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1327        );
1328    }
1329
1330    let files = match state.skill_registry.bundle_files(&name) {
1331        Ok(f) => f,
1332        Err(_) => {
1333            return (
1334                StatusCode::NOT_FOUND,
1335                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1336            );
1337        }
1338    };
1339
1340    // Convert bytes to strings (UTF-8 text) or base64 for binary files
1341    let mut file_map = serde_json::Map::new();
1342    for (path, data) in &files {
1343        match std::str::from_utf8(data) {
1344            Ok(text) => {
1345                file_map.insert(path.clone(), Value::String(text.to_string()));
1346            }
1347            Err(_) => {
1348                // Binary file — base64 encode
1349                use base64::Engine;
1350                let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1351                file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1352            }
1353        }
1354    }
1355
1356    (
1357        StatusCode::OK,
1358        Json(serde_json::json!({
1359            "name": name,
1360            "files": file_map,
1361        })),
1362    )
1363}
1364
1365/// POST /skills/bundle — return all files for multiple skills in one response.
1366/// Request: `{"names": ["fal-generate", "compliance-screening"]}`
1367/// Response: `{"skills": {...}, "missing": [...]}`
1368async fn handle_skills_bundle_batch(
1369    State(state): State<Arc<ProxyState>>,
1370    claims: Option<Extension<TokenClaims>>,
1371    Json(req): Json<SkillBundleBatchRequest>,
1372) -> impl IntoResponse {
1373    const MAX_BATCH: usize = 50;
1374    if req.names.len() > MAX_BATCH {
1375        return (
1376            StatusCode::BAD_REQUEST,
1377            Json(
1378                serde_json::json!({"error": format!("batch size {} exceeds limit of {MAX_BATCH}", req.names.len())}),
1379            ),
1380        );
1381    }
1382
1383    tracing::debug!(names = ?req.names, "POST /skills/bundle");
1384
1385    let claims = claims.map(|Extension(claims)| claims);
1386    let scopes = scopes_for_request(claims.as_ref(), &state);
1387    let visible_names = visible_skill_names(&state, &scopes);
1388
1389    let mut result = serde_json::Map::new();
1390    let mut missing: Vec<String> = Vec::new();
1391
1392    for name in &req.names {
1393        if !visible_names.contains(name) {
1394            missing.push(name.clone());
1395            continue;
1396        }
1397        let files = match state.skill_registry.bundle_files(name) {
1398            Ok(f) => f,
1399            Err(_) => {
1400                missing.push(name.clone());
1401                continue;
1402            }
1403        };
1404
1405        let mut file_map = serde_json::Map::new();
1406        for (path, data) in &files {
1407            match std::str::from_utf8(data) {
1408                Ok(text) => {
1409                    file_map.insert(path.clone(), Value::String(text.to_string()));
1410                }
1411                Err(_) => {
1412                    use base64::Engine;
1413                    let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1414                    file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1415                }
1416            }
1417        }
1418
1419        result.insert(name.clone(), serde_json::json!({ "files": file_map }));
1420    }
1421
1422    (
1423        StatusCode::OK,
1424        Json(serde_json::json!({ "skills": result, "missing": missing })),
1425    )
1426}
1427
1428async fn handle_skills_resolve(
1429    State(state): State<Arc<ProxyState>>,
1430    claims: Option<Extension<TokenClaims>>,
1431    Json(req): Json<SkillResolveRequest>,
1432) -> impl IntoResponse {
1433    tracing::debug!(scopes = ?req.scopes, include_content = req.include_content, "POST /skills/resolve");
1434
1435    let include_content = req.include_content;
1436    let request_scopes = ScopeConfig {
1437        scopes: req.scopes,
1438        sub: String::new(),
1439        expires_at: 0,
1440        rate_config: None,
1441    };
1442    let claims = claims.map(|Extension(claims)| claims);
1443    let caller_scopes = scopes_for_request(claims.as_ref(), &state);
1444    let visible_names = visible_skill_names(&state, &caller_scopes);
1445
1446    let resolved: Vec<&skill::SkillMeta> =
1447        skill::resolve_skills(&state.skill_registry, &state.registry, &request_scopes)
1448            .into_iter()
1449            .filter(|skill| visible_names.contains(&skill.name))
1450            .collect();
1451
1452    let json: Vec<Value> = resolved
1453        .iter()
1454        .map(|s| {
1455            let mut entry = serde_json::json!({
1456                "name": s.name,
1457                "version": s.version,
1458                "description": s.description,
1459                "tools": s.tools,
1460                "providers": s.providers,
1461                "categories": s.categories,
1462            });
1463            if include_content {
1464                if let Ok(content) = state.skill_registry.read_content(&s.name) {
1465                    entry["content"] = Value::String(content);
1466                }
1467            }
1468            entry
1469        })
1470        .collect();
1471
1472    (StatusCode::OK, Json(Value::Array(json)))
1473}
1474
1475fn skillati_client(keyring: &Keyring) -> Result<SkillAtiClient, SkillAtiError> {
1476    match SkillAtiClient::from_env(keyring)? {
1477        Some(client) => Ok(client),
1478        None => Err(SkillAtiError::NotConfigured),
1479    }
1480}
1481
1482async fn handle_skillati_catalog(
1483    State(state): State<Arc<ProxyState>>,
1484    claims: Option<Extension<TokenClaims>>,
1485    Query(query): Query<SkillAtiCatalogQuery>,
1486) -> impl IntoResponse {
1487    tracing::debug!(search = ?query.search, "GET /skillati/catalog");
1488
1489    let client = match skillati_client(&state.keyring) {
1490        Ok(client) => client,
1491        Err(err) => return skillati_error_response(err),
1492    };
1493
1494    let claims = claims.map(|Extension(c)| c);
1495    let scopes = scopes_for_request(claims.as_ref(), &state);
1496
1497    match client.catalog().await {
1498        Ok(catalog) => {
1499            // Union of local + remote visibility. Merging here (instead of
1500            // calling visible_skill_names_with_remote, which would re-fetch)
1501            // avoids a redundant catalog request on the hot path.
1502            let mut visible_names = visible_skill_names(&state, &scopes);
1503            visible_names.extend(visible_remote_skill_names(&state, &scopes, &catalog));
1504
1505            let mut skills: Vec<_> = catalog
1506                .into_iter()
1507                .filter(|s| visible_names.contains(&s.name))
1508                .collect();
1509            if let Some(search) = query.search.as_deref() {
1510                skills = SkillAtiClient::filter_catalog(&skills, search, 25);
1511            }
1512            (
1513                StatusCode::OK,
1514                Json(serde_json::json!({
1515                    "skills": skills,
1516                })),
1517            )
1518        }
1519        Err(err) => skillati_error_response(err),
1520    }
1521}
1522
1523async fn handle_skillati_read(
1524    State(state): State<Arc<ProxyState>>,
1525    claims: Option<Extension<TokenClaims>>,
1526    axum::extract::Path(name): axum::extract::Path<String>,
1527) -> impl IntoResponse {
1528    tracing::debug!(%name, "GET /skillati/:name");
1529
1530    let client = match skillati_client(&state.keyring) {
1531        Ok(client) => client,
1532        Err(err) => return skillati_error_response(err),
1533    };
1534
1535    let claims = claims.map(|Extension(c)| c);
1536    let scopes = scopes_for_request(claims.as_ref(), &state);
1537    let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
1538        Ok(v) => v,
1539        Err(err) => return skillati_error_response(err),
1540    };
1541    if !visible_names.contains(&name) {
1542        return skillati_error_response(SkillAtiError::SkillNotFound(name));
1543    }
1544
1545    match client.read_skill(&name).await {
1546        Ok(activation) => (StatusCode::OK, Json(serde_json::json!(activation))),
1547        Err(err) => skillati_error_response(err),
1548    }
1549}
1550
1551async fn handle_skillati_resources(
1552    State(state): State<Arc<ProxyState>>,
1553    claims: Option<Extension<TokenClaims>>,
1554    axum::extract::Path(name): axum::extract::Path<String>,
1555    Query(query): Query<SkillAtiResourcesQuery>,
1556) -> impl IntoResponse {
1557    tracing::debug!(%name, prefix = ?query.prefix, "GET /skillati/:name/resources");
1558
1559    let client = match skillati_client(&state.keyring) {
1560        Ok(client) => client,
1561        Err(err) => return skillati_error_response(err),
1562    };
1563
1564    let claims = claims.map(|Extension(c)| c);
1565    let scopes = scopes_for_request(claims.as_ref(), &state);
1566    let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
1567        Ok(v) => v,
1568        Err(err) => return skillati_error_response(err),
1569    };
1570    if !visible_names.contains(&name) {
1571        return skillati_error_response(SkillAtiError::SkillNotFound(name));
1572    }
1573
1574    match client.list_resources(&name, query.prefix.as_deref()).await {
1575        Ok(resources) => (
1576            StatusCode::OK,
1577            Json(serde_json::json!({
1578                "name": name,
1579                "prefix": query.prefix,
1580                "resources": resources,
1581            })),
1582        ),
1583        Err(err) => skillati_error_response(err),
1584    }
1585}
1586
1587async fn handle_skillati_file(
1588    State(state): State<Arc<ProxyState>>,
1589    claims: Option<Extension<TokenClaims>>,
1590    axum::extract::Path(name): axum::extract::Path<String>,
1591    Query(query): Query<SkillAtiFileQuery>,
1592) -> impl IntoResponse {
1593    tracing::debug!(%name, path = %query.path, "GET /skillati/:name/file");
1594
1595    let client = match skillati_client(&state.keyring) {
1596        Ok(client) => client,
1597        Err(err) => return skillati_error_response(err),
1598    };
1599
1600    let claims = claims.map(|Extension(c)| c);
1601    let scopes = scopes_for_request(claims.as_ref(), &state);
1602    let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
1603        Ok(v) => v,
1604        Err(err) => return skillati_error_response(err),
1605    };
1606    if !visible_names.contains(&name) {
1607        return skillati_error_response(SkillAtiError::SkillNotFound(name));
1608    }
1609
1610    match client.read_path(&name, &query.path).await {
1611        Ok(file) => (StatusCode::OK, Json(serde_json::json!(file))),
1612        Err(err) => skillati_error_response(err),
1613    }
1614}
1615
1616async fn handle_skillati_refs(
1617    State(state): State<Arc<ProxyState>>,
1618    claims: Option<Extension<TokenClaims>>,
1619    axum::extract::Path(name): axum::extract::Path<String>,
1620) -> impl IntoResponse {
1621    tracing::debug!(%name, "GET /skillati/:name/refs");
1622
1623    let client = match skillati_client(&state.keyring) {
1624        Ok(client) => client,
1625        Err(err) => return skillati_error_response(err),
1626    };
1627
1628    let claims = claims.map(|Extension(c)| c);
1629    let scopes = scopes_for_request(claims.as_ref(), &state);
1630    let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
1631        Ok(v) => v,
1632        Err(err) => return skillati_error_response(err),
1633    };
1634    if !visible_names.contains(&name) {
1635        return skillati_error_response(SkillAtiError::SkillNotFound(name));
1636    }
1637
1638    match client.list_references(&name).await {
1639        Ok(references) => (
1640            StatusCode::OK,
1641            Json(serde_json::json!({
1642                "name": name,
1643                "references": references,
1644            })),
1645        ),
1646        Err(err) => skillati_error_response(err),
1647    }
1648}
1649
1650async fn handle_skillati_ref(
1651    State(state): State<Arc<ProxyState>>,
1652    claims: Option<Extension<TokenClaims>>,
1653    axum::extract::Path((name, reference)): axum::extract::Path<(String, String)>,
1654) -> impl IntoResponse {
1655    tracing::debug!(%name, %reference, "GET /skillati/:name/ref/:reference");
1656
1657    let client = match skillati_client(&state.keyring) {
1658        Ok(client) => client,
1659        Err(err) => return skillati_error_response(err),
1660    };
1661
1662    let claims = claims.map(|Extension(c)| c);
1663    let scopes = scopes_for_request(claims.as_ref(), &state);
1664    let visible_names = match visible_skill_names_with_remote(&state, &scopes, &client).await {
1665        Ok(v) => v,
1666        Err(err) => return skillati_error_response(err),
1667    };
1668    if !visible_names.contains(&name) {
1669        return skillati_error_response(SkillAtiError::SkillNotFound(name));
1670    }
1671
1672    match client.read_reference(&name, &reference).await {
1673        Ok(content) => (
1674            StatusCode::OK,
1675            Json(serde_json::json!({
1676                "name": name,
1677                "reference": reference,
1678                "content": content,
1679            })),
1680        ),
1681        Err(err) => skillati_error_response(err),
1682    }
1683}
1684
1685fn skillati_error_response(err: SkillAtiError) -> (StatusCode, Json<Value>) {
1686    let status = match &err {
1687        SkillAtiError::NotConfigured
1688        | SkillAtiError::UnsupportedRegistry(_)
1689        | SkillAtiError::MissingCredentials(_)
1690        | SkillAtiError::ProxyUrlRequired => StatusCode::SERVICE_UNAVAILABLE,
1691        SkillAtiError::SkillNotFound(_) | SkillAtiError::PathNotFound { .. } => {
1692            StatusCode::NOT_FOUND
1693        }
1694        SkillAtiError::InvalidPath(_) => StatusCode::BAD_REQUEST,
1695        SkillAtiError::Gcs(_)
1696        | SkillAtiError::ProxyRequest(_)
1697        | SkillAtiError::ProxyResponse(_) => StatusCode::BAD_GATEWAY,
1698    };
1699
1700    (
1701        status,
1702        Json(serde_json::json!({
1703            "error": err.to_string(),
1704        })),
1705    )
1706}
1707
1708// --- Auth middleware ---
1709
1710/// JWT authentication middleware.
1711///
1712/// - /health and /.well-known/jwks.json → skip auth
1713/// - JWT configured → validate Bearer token, attach claims to request extensions
1714/// - No JWT configured → allow all (dev mode)
1715async fn auth_middleware(
1716    State(state): State<Arc<ProxyState>>,
1717    mut req: HttpRequest<Body>,
1718    next: Next,
1719) -> Result<Response, StatusCode> {
1720    let path = req.uri().path();
1721
1722    // Skip auth for public endpoints
1723    if path == "/health" || path == "/.well-known/jwks.json" {
1724        return Ok(next.run(req).await);
1725    }
1726
1727    // If no JWT configured, allow all (dev mode)
1728    let jwt_config = match &state.jwt_config {
1729        Some(c) => c,
1730        None => return Ok(next.run(req).await),
1731    };
1732
1733    // Extract Authorization: Bearer <token>
1734    let auth_header = req
1735        .headers()
1736        .get("authorization")
1737        .and_then(|v| v.to_str().ok());
1738
1739    let token = match auth_header {
1740        Some(header) if header.starts_with("Bearer ") => &header[7..],
1741        _ => return Err(StatusCode::UNAUTHORIZED),
1742    };
1743
1744    // Validate JWT
1745    match jwt::validate(token, jwt_config) {
1746        Ok(claims) => {
1747            tracing::debug!(sub = %claims.sub, scopes = %claims.scope, "JWT validated");
1748            req.extensions_mut().insert(claims);
1749            Ok(next.run(req).await)
1750        }
1751        Err(e) => {
1752            tracing::debug!(error = %e, "JWT validation failed");
1753            Err(StatusCode::UNAUTHORIZED)
1754        }
1755    }
1756}
1757
1758// --- Router builder ---
1759
1760/// Build the axum Router from a pre-constructed ProxyState.
1761/// Outer body-size ceiling for `POST /call`. Large enough to carry the worst
1762/// case `file_manager:upload` payload (`MAX_UPLOAD_BYTES` of raw bytes,
1763/// base64-inflated ~4/3×, plus a few KB of JSON framing).
1764///
1765/// Per-tool limits (`max_bytes`, `MAX_UPLOAD_BYTES`) plus JWT scopes + rate
1766/// limits are the real gates — this is just the outermost wrapper check.
1767fn max_call_body_bytes() -> usize {
1768    (crate::core::file_manager::MAX_UPLOAD_BYTES as usize)
1769        .saturating_mul(4)
1770        .saturating_div(3)
1771        .saturating_add(8 * 1024)
1772}
1773
1774pub fn build_router(state: Arc<ProxyState>) -> Router {
1775    use axum::extract::DefaultBodyLimit;
1776
1777    Router::new()
1778        .route("/call", post(handle_call))
1779        .route("/help", post(handle_help))
1780        .route("/mcp", post(handle_mcp))
1781        .route("/tools", get(handle_tools_list))
1782        .route("/tools/{name}", get(handle_tool_info))
1783        .route("/skills", get(handle_skills_list))
1784        .route("/skills/resolve", post(handle_skills_resolve))
1785        .route("/skills/bundle", post(handle_skills_bundle_batch))
1786        .route("/skills/{name}", get(handle_skill_detail))
1787        .route("/skills/{name}/bundle", get(handle_skill_bundle))
1788        .route("/skillati/catalog", get(handle_skillati_catalog))
1789        .route("/skillati/{name}", get(handle_skillati_read))
1790        .route("/skillati/{name}/resources", get(handle_skillati_resources))
1791        .route("/skillati/{name}/file", get(handle_skillati_file))
1792        .route("/skillati/{name}/refs", get(handle_skillati_refs))
1793        .route("/skillati/{name}/ref/{reference}", get(handle_skillati_ref))
1794        .route("/health", get(handle_health))
1795        .route("/.well-known/jwks.json", get(handle_jwks))
1796        // Raise axum's default 2 MB body-extractor limit so request bodies
1797        // carrying base64-encoded upload payloads aren't rejected before the
1798        // handler runs. `handle_call` still enforces its own
1799        // `max_call_body_bytes()` cap when streaming the body to bytes.
1800        .layer(DefaultBodyLimit::max(max_call_body_bytes()))
1801        .layer(middleware::from_fn_with_state(
1802            state.clone(),
1803            auth_middleware,
1804        ))
1805        .with_state(state)
1806}
1807
1808// --- Server startup ---
1809
1810/// Start the proxy server.
1811pub async fn run(
1812    port: u16,
1813    bind_addr: Option<String>,
1814    ati_dir: PathBuf,
1815    _verbose: bool,
1816    env_keys: bool,
1817) -> Result<(), Box<dyn std::error::Error>> {
1818    // Load manifests
1819    let manifests_dir = ati_dir.join("manifests");
1820    let mut registry = ManifestRegistry::load(&manifests_dir)?;
1821    let provider_count = registry.list_providers().len();
1822
1823    // Load keyring
1824    let keyring_source;
1825    let keyring = if env_keys {
1826        // --env-keys: scan ATI_KEY_* environment variables
1827        let kr = Keyring::from_env();
1828        let key_names = kr.key_names();
1829        tracing::info!(
1830            count = key_names.len(),
1831            "loaded API keys from ATI_KEY_* env vars"
1832        );
1833        for name in &key_names {
1834            tracing::debug!(key = %name, "env key loaded");
1835        }
1836        keyring_source = "env-vars (ATI_KEY_*)";
1837        kr
1838    } else {
1839        // Cascade: keyring.enc (sealed) → keyring.enc (persistent) → credentials → empty
1840        let keyring_path = ati_dir.join("keyring.enc");
1841        if keyring_path.exists() {
1842            if let Ok(kr) = Keyring::load(&keyring_path) {
1843                keyring_source = "keyring.enc (sealed key)";
1844                kr
1845            } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
1846                keyring_source = "keyring.enc (persistent key)";
1847                kr
1848            } else {
1849                tracing::warn!("keyring.enc exists but could not be decrypted");
1850                keyring_source = "empty (decryption failed)";
1851                Keyring::empty()
1852            }
1853        } else {
1854            let creds_path = ati_dir.join("credentials");
1855            if creds_path.exists() {
1856                match Keyring::load_credentials(&creds_path) {
1857                    Ok(kr) => {
1858                        keyring_source = "credentials (plaintext)";
1859                        kr
1860                    }
1861                    Err(e) => {
1862                        tracing::warn!(error = %e, "failed to load credentials");
1863                        keyring_source = "empty (credentials error)";
1864                        Keyring::empty()
1865                    }
1866                }
1867            } else {
1868                tracing::warn!("no keyring.enc or credentials found — running without API keys");
1869                tracing::warn!("tools requiring authentication will fail");
1870                keyring_source = "empty (no auth)";
1871                Keyring::empty()
1872            }
1873        }
1874    };
1875
1876    // Discover MCP tools at startup so they appear in GET /tools.
1877    // Runs concurrently across providers with 30s per-provider timeout.
1878    mcp_client::discover_all_mcp_tools(&mut registry, &keyring).await;
1879
1880    let tool_count = registry.list_public_tools().len();
1881
1882    // Log MCP and OpenAPI providers
1883    let mcp_providers: Vec<(String, String)> = registry
1884        .list_mcp_providers()
1885        .iter()
1886        .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
1887        .collect();
1888    let mcp_count = mcp_providers.len();
1889    let openapi_providers: Vec<String> = registry
1890        .list_openapi_providers()
1891        .iter()
1892        .map(|p| p.name.clone())
1893        .collect();
1894    let openapi_count = openapi_providers.len();
1895
1896    // Load installed/local skill registry only.
1897    let skills_dir = ati_dir.join("skills");
1898    let skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
1899        tracing::warn!(error = %e, "failed to load skills");
1900        SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
1901    });
1902
1903    if let Ok(registry_url) = std::env::var("ATI_SKILL_REGISTRY") {
1904        if registry_url.strip_prefix("gcs://").is_some() {
1905            tracing::info!(
1906                registry = %registry_url,
1907                "SkillATI remote registry configured for lazy reads"
1908            );
1909        } else {
1910            tracing::warn!(url = %registry_url, "SkillATI only supports gcs:// registries");
1911        }
1912    }
1913
1914    let skill_count = skill_registry.skill_count();
1915
1916    // Load JWT config from environment
1917    let jwt_config = match jwt::config_from_env() {
1918        Ok(config) => config,
1919        Err(e) => {
1920            tracing::warn!(error = %e, "JWT config error");
1921            None
1922        }
1923    };
1924
1925    let auth_status = if jwt_config.is_some() {
1926        "JWT enabled"
1927    } else {
1928        "DISABLED (no JWT keys configured)"
1929    };
1930
1931    // Build JWKS for the endpoint
1932    let jwks_json = jwt_config.as_ref().and_then(|config| {
1933        config
1934            .public_key_pem
1935            .as_ref()
1936            .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
1937    });
1938
1939    let state = Arc::new(ProxyState {
1940        registry,
1941        skill_registry,
1942        keyring,
1943        jwt_config,
1944        jwks_json,
1945        auth_cache: AuthCache::new(),
1946    });
1947
1948    let app = build_router(state);
1949
1950    let addr: SocketAddr = if let Some(ref bind) = bind_addr {
1951        format!("{bind}:{port}").parse()?
1952    } else {
1953        SocketAddr::from(([127, 0, 0, 1], port))
1954    };
1955
1956    tracing::info!(
1957        version = env!("CARGO_PKG_VERSION"),
1958        %addr,
1959        auth = auth_status,
1960        ati_dir = %ati_dir.display(),
1961        tools = tool_count,
1962        providers = provider_count,
1963        mcp = mcp_count,
1964        openapi = openapi_count,
1965        skills = skill_count,
1966        keyring = keyring_source,
1967        "ATI proxy server starting"
1968    );
1969    for (name, transport) in &mcp_providers {
1970        tracing::info!(provider = %name, transport = %transport, "MCP provider");
1971    }
1972    for name in &openapi_providers {
1973        tracing::info!(provider = %name, "OpenAPI provider");
1974    }
1975
1976    let listener = tokio::net::TcpListener::bind(addr).await?;
1977    axum::serve(listener, app).await?;
1978
1979    Ok(())
1980}
1981
1982/// Dispatch a `file_manager:*` tool call. Returns either a JSON payload or an
1983/// (HTTP status, message) error for the caller to forward.
1984async fn dispatch_file_manager(
1985    tool_name: &str,
1986    args: &HashMap<String, Value>,
1987    provider: &Provider,
1988    keyring: &Keyring,
1989) -> Result<Value, (StatusCode, String)> {
1990    use crate::core::file_manager::{self, DownloadArgs, FileManagerError, UploadArgs};
1991
1992    // One mapping, derived from FileManagerError::http_status, so adding an
1993    // error variant can't silently regress one handler while the other updates.
1994    let to_resp = |e: FileManagerError| {
1995        let status =
1996            StatusCode::from_u16(e.http_status()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
1997        (status, e.to_string())
1998    };
1999
2000    match tool_name {
2001        "file_manager:download" => {
2002            let parsed = DownloadArgs::from_value(args).map_err(to_resp)?;
2003            let result = file_manager::fetch_bytes(&parsed).await.map_err(to_resp)?;
2004            Ok(file_manager::build_download_response(&result))
2005        }
2006        "file_manager:upload" => {
2007            let parsed = UploadArgs::from_wire(args).map_err(to_resp)?;
2008            file_manager::upload_to_destination(
2009                parsed,
2010                &provider.upload_destinations,
2011                provider.upload_default_destination.as_deref(),
2012                keyring,
2013            )
2014            .await
2015            .map_err(to_resp)
2016        }
2017        other => Err((
2018            StatusCode::NOT_FOUND,
2019            format!("Unknown file_manager tool: '{other}'"),
2020        )),
2021    }
2022}
2023
2024fn write_proxy_audit(
2025    call_req: &CallRequest,
2026    agent_sub: &str,
2027    claims: Option<&TokenClaims>,
2028    duration: std::time::Duration,
2029    error: Option<&str>,
2030) {
2031    let entry = crate::core::audit::AuditEntry {
2032        ts: chrono::Utc::now().to_rfc3339(),
2033        tool: call_req.tool_name.clone(),
2034        args: crate::core::audit::sanitize_args(&call_req.args),
2035        status: if error.is_some() {
2036            crate::core::audit::AuditStatus::Error
2037        } else {
2038            crate::core::audit::AuditStatus::Ok
2039        },
2040        duration_ms: duration.as_millis() as u64,
2041        agent_sub: agent_sub.to_string(),
2042        job_id: claims.and_then(|c| c.job_id.clone()),
2043        sandbox_id: claims.and_then(|c| c.sandbox_id.clone()),
2044        error: error.map(|s| s.to_string()),
2045        exit_code: None,
2046    };
2047    let _ = crate::core::audit::append(&entry);
2048}
2049
2050// --- Helpers ---
2051
/// General (tool/provider-agnostic) system prompt template for help answers.
///
/// Contains two literal placeholder tokens, `{tools}` and `{skills_section}`.
/// This is a plain `const`, not a `format!` literal, so the caller must
/// substitute them itself (presumably via string replacement — the
/// substitution site is outside this chunk; confirm before editing the tokens).
/// NOTE(review): presumably used by the `/help` handler when the question is
/// not scoped to a single tool or provider — verify against the caller.
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, tell the agent to load them using the Skill tool (e.g., `skill: "research-financial-data"`)

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
2067
2068async fn build_remote_skillati_section(keyring: &Keyring, query: &str, limit: usize) -> String {
2069    let client = match SkillAtiClient::from_env(keyring) {
2070        Ok(Some(client)) => client,
2071        Ok(None) => return String::new(),
2072        Err(err) => {
2073            tracing::warn!(error = %err, "failed to initialize SkillATI catalog for proxy help");
2074            return String::new();
2075        }
2076    };
2077
2078    let catalog = match client.catalog().await {
2079        Ok(catalog) => catalog,
2080        Err(err) => {
2081            tracing::warn!(error = %err, "failed to load SkillATI catalog for proxy help");
2082            return String::new();
2083        }
2084    };
2085
2086    let matched = SkillAtiClient::filter_catalog(&catalog, query, limit);
2087    if matched.is_empty() {
2088        return String::new();
2089    }
2090
2091    render_remote_skillati_section(&matched, catalog.len())
2092}
2093
2094fn render_remote_skillati_section(skills: &[RemoteSkillMeta], total_catalog: usize) -> String {
2095    let mut section = String::from("## Remote Skills Available Via SkillATI\n\n");
2096    section.push_str(
2097        "These skills are available. Load them using the Skill tool (e.g., `skill: \"skill-name\"`).\n\n",
2098    );
2099
2100    for skill in skills {
2101        section.push_str(&format!("- **{}**: {}\n", skill.name, skill.description));
2102    }
2103
2104    if total_catalog > skills.len() {
2105        section.push_str(&format!(
2106            "\nOnly the most relevant {} remote skills are shown here.\n",
2107            skills.len()
2108        ));
2109    }
2110
2111    section
2112}
2113
/// Join the non-blank help-prompt skill sections with a blank line between them.
///
/// Each section is trimmed first. Sections that are empty after trimming are
/// dropped, so the result never has leading, trailing, or doubled separators.
fn merge_help_skill_sections(sections: &[String]) -> String {
    let mut merged = String::new();
    for body in sections.iter().map(|s| s.trim()).filter(|s| !s.is_empty()) {
        if !merged.is_empty() {
            merged.push_str("\n\n");
        }
        merged.push_str(body);
    }
    merged
}
2128
2129fn build_tool_context(
2130    tools: &[(
2131        &crate::core::manifest::Provider,
2132        &crate::core::manifest::Tool,
2133    )],
2134) -> String {
2135    let mut summaries = Vec::new();
2136    for (provider, tool) in tools {
2137        let mut summary = if let Some(cat) = &provider.category {
2138            format!(
2139                "- **{}** (provider: {}, category: {}): {}",
2140                tool.name, provider.name, cat, tool.description
2141            )
2142        } else {
2143            format!(
2144                "- **{}** (provider: {}): {}",
2145                tool.name, provider.name, tool.description
2146            )
2147        };
2148        if !tool.tags.is_empty() {
2149            summary.push_str(&format!("\n  Tags: {}", tool.tags.join(", ")));
2150        }
2151        // CLI tools: show passthrough usage
2152        if provider.is_cli() && tool.input_schema.is_none() {
2153            let cmd = provider.cli_command.as_deref().unwrap_or("?");
2154            summary.push_str(&format!(
2155                "\n  Usage: `ati run {} -- <args>`  (passthrough to `{}`)",
2156                tool.name, cmd
2157            ));
2158        } else if let Some(schema) = &tool.input_schema {
2159            if let Some(props) = schema.get("properties") {
2160                if let Some(obj) = props.as_object() {
2161                    let params: Vec<String> = obj
2162                        .iter()
2163                        .filter(|(_, v)| {
2164                            v.get("x-ati-param-location").is_none()
2165                                || v.get("description").is_some()
2166                        })
2167                        .map(|(k, v)| {
2168                            let type_str =
2169                                v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
2170                            let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
2171                            format!("    --{k} ({type_str}): {desc}")
2172                        })
2173                        .collect();
2174                    if !params.is_empty() {
2175                        summary.push_str("\n  Parameters:\n");
2176                        summary.push_str(&params.join("\n"));
2177                    }
2178                }
2179            }
2180        }
2181        summaries.push(summary);
2182    }
2183    summaries.join("\n\n")
2184}
2185
2186/// Build a scoped system prompt for a specific tool or provider.
2187///
2188/// Returns None if the scope_name doesn't match any tool or provider.
2189fn build_scoped_prompt(
2190    scope_name: &str,
2191    visible_tools: &[(&Provider, &Tool)],
2192    skills_section: &str,
2193) -> Option<String> {
2194    // Check if scope_name is a tool
2195    if let Some((provider, tool)) = visible_tools
2196        .iter()
2197        .find(|(_, tool)| tool.name == scope_name)
2198    {
2199        let mut details = format!(
2200            "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
2201            tool.name, provider.name, provider.handler, tool.description
2202        );
2203        if let Some(cat) = &provider.category {
2204            details.push_str(&format!("**Category**: {}\n", cat));
2205        }
2206        if provider.is_cli() {
2207            let cmd = provider.cli_command.as_deref().unwrap_or("?");
2208            details.push_str(&format!(
2209                "\n**Usage**: `ati run {} -- <args>`  (passthrough to `{}`)\n",
2210                tool.name, cmd
2211            ));
2212        } else if let Some(schema) = &tool.input_schema {
2213            if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
2214                let required: Vec<String> = schema
2215                    .get("required")
2216                    .and_then(|r| r.as_array())
2217                    .map(|arr| {
2218                        arr.iter()
2219                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
2220                            .collect()
2221                    })
2222                    .unwrap_or_default();
2223                details.push_str("\n**Parameters**:\n");
2224                for (key, val) in props {
2225                    let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
2226                    let desc = val
2227                        .get("description")
2228                        .and_then(|d| d.as_str())
2229                        .unwrap_or("");
2230                    let req = if required.contains(key) {
2231                        " **(required)**"
2232                    } else {
2233                        ""
2234                    };
2235                    details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
2236                }
2237            }
2238        }
2239
2240        let prompt = format!(
2241            "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
2242            ## Tool Details\n{}\n\n{}\n\n\
2243            Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
2244            tool.name, details, skills_section
2245        );
2246        return Some(prompt);
2247    }
2248
2249    // Check if scope_name is a provider
2250    let tools: Vec<(&Provider, &Tool)> = visible_tools
2251        .iter()
2252        .copied()
2253        .filter(|(provider, _)| provider.name == scope_name)
2254        .collect();
2255    if !tools.is_empty() {
2256        let tools_context = build_tool_context(&tools);
2257        let prompt = format!(
2258            "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
2259            ## Tools in provider `{}`\n{}\n\n{}\n\n\
2260            Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
2261            scope_name, scope_name, tools_context, skills_section
2262        );
2263        return Some(prompt);
2264    }
2265
2266    None
2267}