Skip to main content

ati/proxy/
server.rs

//! ATI proxy server — holds API keys and executes tool calls on behalf of sandbox agents.
//!
//! Authentication: ES256-signed JWT (HS256 is accepted as a fallback). The JWT carries the
//! caller's identity, scopes, and expiry; static tokens and unsigned scope lists are not
//! supported. When no JWT config is loaded, the proxy runs in dev mode with unrestricted scopes.
//!
//! Usage: `ati proxy --port 8080 [--ati-dir ~/.ati]`
7use axum::{
8    body::Body,
9    extract::{Extension, Query, State},
10    http::{Request as HttpRequest, StatusCode},
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13    routing::{get, post},
14    Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::{ManifestRegistry, Provider, Tool};
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::skill::{self, SkillRegistry};
32use crate::core::skillati::{RemoteSkillMeta, SkillAtiClient, SkillAtiError};
33use crate::core::xai;
34
/// Shared state for the proxy server.
///
/// Handlers receive it as `State<Arc<ProxyState>>`.
pub struct ProxyState {
    /// Tool and provider manifests; handlers use it for tool lookup (`get_tool`) and
    /// listing (`list_public_tools`, `list_providers`).
    pub registry: ManifestRegistry,
    /// Local skills; used to build `/help` skill context and for the `/health` skill count.
    pub skill_registry: SkillRegistry,
    /// Secrets passed to tool executors; also looked up for the `/help` LLM API key.
    pub keyring: Keyring,
    /// JWT validation config (None = auth disabled / dev mode).
    pub jwt_config: Option<JwtConfig>,
    /// Pre-computed JWKS JSON for the /.well-known/jwks.json endpoint.
    /// When `None`, that endpoint responds 404.
    pub jwks_json: Option<Value>,
    /// Shared cache for dynamically generated auth credentials.
    pub auth_cache: AuthCache,
}
47
48// --- Request/Response types ---
49
/// JSON body of `POST /call`.
#[derive(Debug, Deserialize)]
pub struct CallRequest {
    /// Tool to invoke. `handle_call` also accepts an underscore alias for the colon form
    /// (e.g. `finnhub_quote` for `finnhub:quote`).
    pub tool_name: String,
    /// Tool arguments — accepts a JSON object (key-value pairs) for HTTP/MCP/OpenAPI tools,
    /// or a JSON array of strings / a single string for CLI tools.
    /// Routing is decided by the provider's `handler`, not by the shape of this value;
    /// each handler then reads it via `args_as_map` or `args_as_positional`.
    /// Defaults to an empty object (`{}`) when omitted.
    #[serde(default = "default_args")]
    pub args: Value,
    /// Deprecated: use `args` with an array value instead.
    /// Kept for backward compatibility — if present, takes precedence for CLI tools.
    #[serde(default)]
    pub raw_args: Option<Vec<String>>,
}
63
64fn default_args() -> Value {
65    Value::Object(serde_json::Map::new())
66}
67
68impl CallRequest {
69    /// Extract args as a HashMap for HTTP/MCP/OpenAPI tools.
70    /// If `args` is a JSON object, returns its entries.
71    /// If `args` is something else (array, string), returns an empty map.
72    fn args_as_map(&self) -> HashMap<String, Value> {
73        match &self.args {
74            Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
75            _ => HashMap::new(),
76        }
77    }
78
79    /// Extract positional args for CLI tools.
80    /// Priority: explicit `raw_args` field > `args` array > `args` string > `args._positional` > empty.
81    fn args_as_positional(&self) -> Vec<String> {
82        // Backward compat: explicit raw_args wins
83        if let Some(ref raw) = self.raw_args {
84            return raw.clone();
85        }
86        match &self.args {
87            // ["pr", "list", "--repo", "X"]
88            Value::Array(arr) => arr
89                .iter()
90                .map(|v| match v {
91                    Value::String(s) => s.clone(),
92                    other => other.to_string(),
93                })
94                .collect(),
95            // "pr list --repo X"
96            Value::String(s) => s.split_whitespace().map(String::from).collect(),
97            // {"_positional": ["pr", "list"]} or {"--key": "value"} converted to CLI flags
98            Value::Object(map) => {
99                if let Some(Value::Array(pos)) = map.get("_positional") {
100                    return pos
101                        .iter()
102                        .map(|v| match v {
103                            Value::String(s) => s.clone(),
104                            other => other.to_string(),
105                        })
106                        .collect();
107                }
108                // Convert map entries to --key value pairs
109                let mut result = Vec::new();
110                for (k, v) in map {
111                    result.push(format!("--{k}"));
112                    match v {
113                        Value::String(s) => result.push(s.clone()),
114                        Value::Bool(true) => {} // flag, no value needed
115                        other => result.push(other.to_string()),
116                    }
117                }
118                result
119            }
120            _ => Vec::new(),
121        }
122    }
123}
124
/// JSON body returned by `POST /call`.
///
/// On failure `result` is usually `null`; when response post-processing fails, the
/// unprocessed upstream response is returned in `result` alongside `error`.
#[derive(Debug, Serialize)]
pub struct CallResponse {
    pub result: Value,
    /// Human-readable error; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
131
/// JSON body of `POST /help`.
#[derive(Debug, Deserialize)]
pub struct HelpRequest {
    /// The question sent to the LLM as the user message.
    pub query: String,
    /// Optional tool or provider name that narrows the help prompt to that scope.
    #[serde(default)]
    pub tool: Option<String>,
}
138
/// JSON body returned by `POST /help`.
#[derive(Debug, Serialize)]
pub struct HelpResponse {
    /// LLM answer text; empty string on error.
    pub content: String,
    /// Human-readable error; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
145
/// JSON body returned by the health endpoint (`handle_health`).
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Always `"ok"` when the handler responds.
    pub status: String,
    /// Crate version (`CARGO_PKG_VERSION`).
    pub version: String,
    /// Number of public tools in the registry.
    pub tools: usize,
    /// Number of providers in the registry.
    pub providers: usize,
    /// Number of locally loaded skills.
    pub skills: usize,
    /// `"jwt"` when JWT validation is configured, otherwise `"disabled"`.
    pub auth: String,
}
155
156// --- Skill endpoint types ---
157
/// Optional query-string filters for the skills listing endpoint.
/// NOTE(review): the handler is outside this view — confirm how each filter is matched.
#[derive(Debug, Deserialize)]
pub struct SkillsQuery {
    #[serde(default)]
    pub category: Option<String>,
    #[serde(default)]
    pub provider: Option<String>,
    #[serde(default)]
    pub tool: Option<String>,
    #[serde(default)]
    pub search: Option<String>,
}

/// Optional query-string flags for a single-skill detail request.
/// NOTE(review): presumably `meta`/`refs` toggle inclusion of metadata and references —
/// confirm against the handler.
#[derive(Debug, Deserialize)]
pub struct SkillDetailQuery {
    #[serde(default)]
    pub meta: Option<bool>,
    #[serde(default)]
    pub refs: Option<bool>,
}

/// Request body for resolving skills from a list of scope strings.
#[derive(Debug, Deserialize)]
pub struct SkillResolveRequest {
    pub scopes: Vec<String>,
    /// When true, include SKILL.md content in each resolved skill.
    #[serde(default)]
    pub include_content: bool,
}

/// Request body for fetching several skill bundles by name in one call.
#[derive(Debug, Deserialize)]
pub struct SkillBundleBatchRequest {
    pub names: Vec<String>,
}

/// Optional search filter for the SkillATI catalog listing.
#[derive(Debug, Deserialize, Default)]
pub struct SkillAtiCatalogQuery {
    #[serde(default)]
    pub search: Option<String>,
}

/// Optional path-prefix filter for listing a SkillATI skill's resources.
#[derive(Debug, Deserialize, Default)]
pub struct SkillAtiResourcesQuery {
    #[serde(default)]
    pub prefix: Option<String>,
}

/// Required `path` query parameter identifying a file within a SkillATI skill.
#[derive(Debug, Deserialize)]
pub struct SkillAtiFileQuery {
    pub path: String,
}
207
208// --- Tool endpoint types ---
209
/// Optional query-string filters for `GET /tools`.
#[derive(Debug, Deserialize)]
pub struct ToolsQuery {
    /// Restrict the listing to one provider.
    #[serde(default)]
    pub provider: Option<String>,
    /// Free-text search term.
    #[serde(default)]
    pub search: Option<String>,
}
217
218// --- Handlers ---
219
220fn scopes_for_request(claims: Option<&TokenClaims>, state: &ProxyState) -> ScopeConfig {
221    match claims {
222        Some(claims) => ScopeConfig::from_jwt(claims),
223        None if state.jwt_config.is_none() => ScopeConfig::unrestricted(),
224        None => ScopeConfig {
225            scopes: Vec::new(),
226            sub: String::new(),
227            expires_at: 0,
228            rate_config: None,
229        },
230    }
231}
232
233fn visible_tools_for_scopes<'a>(
234    state: &'a ProxyState,
235    scopes: &ScopeConfig,
236) -> Vec<(&'a Provider, &'a Tool)> {
237    crate::core::scope::filter_tools_by_scope(state.registry.list_public_tools(), scopes)
238}
239
240fn visible_skill_names(
241    state: &ProxyState,
242    scopes: &ScopeConfig,
243) -> std::collections::HashSet<String> {
244    skill::visible_skills(&state.skill_registry, &state.registry, scopes)
245        .into_iter()
246        .map(|skill| skill.name.clone())
247        .collect()
248}
249
250async fn handle_call(
251    State(state): State<Arc<ProxyState>>,
252    req: HttpRequest<Body>,
253) -> impl IntoResponse {
254    // Extract JWT claims from request extensions (set by auth middleware)
255    let claims = req.extensions().get::<TokenClaims>().cloned();
256
257    // Parse request body
258    let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
259        Ok(b) => b,
260        Err(e) => {
261            return (
262                StatusCode::BAD_REQUEST,
263                Json(CallResponse {
264                    result: Value::Null,
265                    error: Some(format!("Failed to read request body: {e}")),
266                }),
267            );
268        }
269    };
270
271    let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
272        Ok(r) => r,
273        Err(e) => {
274            return (
275                StatusCode::UNPROCESSABLE_ENTITY,
276                Json(CallResponse {
277                    result: Value::Null,
278                    error: Some(format!("Invalid request: {e}")),
279                }),
280            );
281        }
282    };
283
284    tracing::debug!(
285        tool = %call_req.tool_name,
286        args = ?call_req.args,
287        "POST /call"
288    );
289
290    // Look up tool in registry.
291    // If not found, try converting underscore format (finnhub_quote) to colon (finnhub:quote).
292    let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
293        Some(pt) => pt,
294        None => {
295            // Try underscore → colon conversion at each underscore position.
296            // "finnhub_quote" → try "finnhub:quote"
297            // "test_api_get_data" → try "test:api_get_data", "test_api:get_data"
298            let mut resolved = None;
299            for (idx, _) in call_req.tool_name.match_indices('_') {
300                let candidate = format!(
301                    "{}:{}",
302                    &call_req.tool_name[..idx],
303                    &call_req.tool_name[idx + 1..]
304                );
305                if let Some(pt) = state.registry.get_tool(&candidate) {
306                    tracing::debug!(
307                        original = %call_req.tool_name,
308                        resolved = %candidate,
309                        "resolved underscore tool name to colon format"
310                    );
311                    resolved = Some(pt);
312                    break;
313                }
314            }
315
316            match resolved {
317                Some(pt) => pt,
318                None => {
319                    return (
320                        StatusCode::NOT_FOUND,
321                        Json(CallResponse {
322                            result: Value::Null,
323                            error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
324                        }),
325                    );
326                }
327            }
328        }
329    };
330
331    // Scope enforcement from JWT claims
332    if let Some(tool_scope) = &tool.scope {
333        let scopes = match &claims {
334            Some(c) => ScopeConfig::from_jwt(c),
335            None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), // Dev mode
336            None => {
337                return (
338                    StatusCode::FORBIDDEN,
339                    Json(CallResponse {
340                        result: Value::Null,
341                        error: Some("Authentication required — no JWT provided".into()),
342                    }),
343                );
344            }
345        };
346
347        if !scopes.is_allowed(tool_scope) {
348            return (
349                StatusCode::FORBIDDEN,
350                Json(CallResponse {
351                    result: Value::Null,
352                    error: Some(format!(
353                        "Access denied: '{}' is not in your scopes",
354                        tool.name
355                    )),
356                }),
357            );
358        }
359    }
360
361    // Rate limit check
362    {
363        let scopes = match &claims {
364            Some(c) => ScopeConfig::from_jwt(c),
365            None => ScopeConfig::unrestricted(),
366        };
367        if let Some(ref rate_config) = scopes.rate_config {
368            if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
369                return (
370                    StatusCode::TOO_MANY_REQUESTS,
371                    Json(CallResponse {
372                        result: Value::Null,
373                        error: Some(format!("{e}")),
374                    }),
375                );
376            }
377        }
378    }
379
380    // Build auth generator context from JWT claims
381    let gen_ctx = GenContext {
382        jwt_sub: claims
383            .as_ref()
384            .map(|c| c.sub.clone())
385            .unwrap_or_else(|| "dev".into()),
386        jwt_scope: claims
387            .as_ref()
388            .map(|c| c.scope.clone())
389            .unwrap_or_else(|| "*".into()),
390        tool_name: call_req.tool_name.clone(),
391        timestamp: crate::core::jwt::now_secs(),
392    };
393
394    // Execute tool call — dispatch based on handler type, with timing for audit
395    let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
396    let job_id = claims
397        .as_ref()
398        .and_then(|c| c.job_id.clone())
399        .unwrap_or_default();
400    let sandbox_id = claims
401        .as_ref()
402        .and_then(|c| c.sandbox_id.clone())
403        .unwrap_or_default();
404    tracing::info!(
405        tool = %call_req.tool_name,
406        agent = %agent_sub,
407        job_id = %job_id,
408        sandbox_id = %sandbox_id,
409        "tool call"
410    );
411    let start = std::time::Instant::now();
412
413    let response = match provider.handler.as_str() {
414        "mcp" => {
415            let args_map = call_req.args_as_map();
416            match mcp_client::execute_with_gen(
417                provider,
418                &call_req.tool_name,
419                &args_map,
420                &state.keyring,
421                Some(&gen_ctx),
422                Some(&state.auth_cache),
423            )
424            .await
425            {
426                Ok(result) => (
427                    StatusCode::OK,
428                    Json(CallResponse {
429                        result,
430                        error: None,
431                    }),
432                ),
433                Err(e) => (
434                    StatusCode::BAD_GATEWAY,
435                    Json(CallResponse {
436                        result: Value::Null,
437                        error: Some(format!("MCP error: {e}")),
438                    }),
439                ),
440            }
441        }
442        "cli" => {
443            let positional = call_req.args_as_positional();
444            match crate::core::cli_executor::execute_with_gen(
445                provider,
446                &positional,
447                &state.keyring,
448                Some(&gen_ctx),
449                Some(&state.auth_cache),
450            )
451            .await
452            {
453                Ok(result) => (
454                    StatusCode::OK,
455                    Json(CallResponse {
456                        result,
457                        error: None,
458                    }),
459                ),
460                Err(e) => (
461                    StatusCode::BAD_GATEWAY,
462                    Json(CallResponse {
463                        result: Value::Null,
464                        error: Some(format!("CLI error: {e}")),
465                    }),
466                ),
467            }
468        }
469        _ => {
470            let args_map = call_req.args_as_map();
471            let raw_response = match match provider.handler.as_str() {
472                "xai" => xai::execute_xai_tool(provider, tool, &args_map, &state.keyring).await,
473                _ => {
474                    http::execute_tool_with_gen(
475                        provider,
476                        tool,
477                        &args_map,
478                        &state.keyring,
479                        Some(&gen_ctx),
480                        Some(&state.auth_cache),
481                    )
482                    .await
483                }
484            } {
485                Ok(resp) => resp,
486                Err(e) => {
487                    let duration = start.elapsed();
488                    write_proxy_audit(
489                        &call_req,
490                        &agent_sub,
491                        claims.as_ref(),
492                        duration,
493                        Some(&e.to_string()),
494                    );
495                    return (
496                        StatusCode::BAD_GATEWAY,
497                        Json(CallResponse {
498                            result: Value::Null,
499                            error: Some(format!("Upstream API error: {e}")),
500                        }),
501                    );
502                }
503            };
504
505            let processed = match response::process_response(&raw_response, tool.response.as_ref())
506            {
507                Ok(p) => p,
508                Err(e) => {
509                    let duration = start.elapsed();
510                    write_proxy_audit(
511                        &call_req,
512                        &agent_sub,
513                        claims.as_ref(),
514                        duration,
515                        Some(&e.to_string()),
516                    );
517                    return (
518                        StatusCode::INTERNAL_SERVER_ERROR,
519                        Json(CallResponse {
520                            result: raw_response,
521                            error: Some(format!("Response processing error: {e}")),
522                        }),
523                    );
524                }
525            };
526
527            (
528                StatusCode::OK,
529                Json(CallResponse {
530                    result: processed,
531                    error: None,
532                }),
533            )
534        }
535    };
536
537    let duration = start.elapsed();
538    let error_msg = response.1.error.as_deref();
539    write_proxy_audit(&call_req, &agent_sub, claims.as_ref(), duration, error_msg);
540
541    response
542}
543
544async fn handle_help(
545    State(state): State<Arc<ProxyState>>,
546    claims: Option<Extension<TokenClaims>>,
547    Json(req): Json<HelpRequest>,
548) -> impl IntoResponse {
549    tracing::debug!(query = %req.query, tool = ?req.tool, "POST /help");
550
551    let claims = claims.map(|Extension(claims)| claims);
552    let scopes = scopes_for_request(claims.as_ref(), &state);
553
554    let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
555        Some(pt) => pt,
556        None => {
557            return (
558                StatusCode::SERVICE_UNAVAILABLE,
559                Json(HelpResponse {
560                    content: String::new(),
561                    error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
562                }),
563            );
564        }
565    };
566
567    let api_key = match llm_provider
568        .auth_key_name
569        .as_deref()
570        .and_then(|k| state.keyring.get(k))
571    {
572        Some(key) => key.to_string(),
573        None => {
574            return (
575                StatusCode::SERVICE_UNAVAILABLE,
576                Json(HelpResponse {
577                    content: String::new(),
578                    error: Some("LLM API key not found in keyring".into()),
579                }),
580            );
581        }
582    };
583
584    let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
585    let local_skills_section = if resolved_skills.is_empty() {
586        String::new()
587    } else {
588        format!(
589            "## Available Skills (methodology guides)\n{}",
590            skill::build_skill_context(&resolved_skills)
591        )
592    };
593    let remote_query = req
594        .tool
595        .as_ref()
596        .map(|tool| format!("{tool} {}", req.query))
597        .unwrap_or_else(|| req.query.clone());
598    let remote_skills_section =
599        build_remote_skillati_section(&state.keyring, &remote_query, 12).await;
600    let skills_section = merge_help_skill_sections(&[local_skills_section, remote_skills_section]);
601
602    // Build system prompt — scoped or unscoped
603    let visible_tools = visible_tools_for_scopes(&state, &scopes);
604    let system_prompt = if let Some(ref tool_name) = req.tool {
605        // Scoped mode: narrow tools to the specified tool or provider
606        match build_scoped_prompt(tool_name, &visible_tools, &skills_section) {
607            Some(prompt) => prompt,
608            None => {
609                return (
610                    StatusCode::FORBIDDEN,
611                    Json(HelpResponse {
612                        content: String::new(),
613                        error: Some(format!(
614                            "Scope '{tool_name}' is not visible in your current scopes."
615                        )),
616                    }),
617                );
618            }
619        }
620    } else {
621        let tools_context = build_tool_context(&visible_tools);
622        HELP_SYSTEM_PROMPT
623            .replace("{tools}", &tools_context)
624            .replace("{skills_section}", &skills_section)
625    };
626
627    let request_body = serde_json::json!({
628        "model": "zai-glm-4.7",
629        "messages": [
630            {"role": "system", "content": system_prompt},
631            {"role": "user", "content": req.query}
632        ],
633        "max_completion_tokens": 1536,
634        "temperature": 0.3
635    });
636
637    let client = reqwest::Client::new();
638    let url = format!(
639        "{}{}",
640        llm_provider.base_url.trim_end_matches('/'),
641        llm_tool.endpoint
642    );
643
644    let response = match client
645        .post(&url)
646        .bearer_auth(&api_key)
647        .json(&request_body)
648        .send()
649        .await
650    {
651        Ok(r) => r,
652        Err(e) => {
653            return (
654                StatusCode::BAD_GATEWAY,
655                Json(HelpResponse {
656                    content: String::new(),
657                    error: Some(format!("LLM request failed: {e}")),
658                }),
659            );
660        }
661    };
662
663    if !response.status().is_success() {
664        let status = response.status();
665        let body = response.text().await.unwrap_or_default();
666        return (
667            StatusCode::BAD_GATEWAY,
668            Json(HelpResponse {
669                content: String::new(),
670                error: Some(format!("LLM API error ({status}): {body}")),
671            }),
672        );
673    }
674
675    let body: Value = match response.json().await {
676        Ok(b) => b,
677        Err(e) => {
678            return (
679                StatusCode::INTERNAL_SERVER_ERROR,
680                Json(HelpResponse {
681                    content: String::new(),
682                    error: Some(format!("Failed to parse LLM response: {e}")),
683                }),
684            );
685        }
686    };
687
688    let content = body
689        .pointer("/choices/0/message/content")
690        .and_then(|c| c.as_str())
691        .unwrap_or("No response from LLM")
692        .to_string();
693
694    (
695        StatusCode::OK,
696        Json(HelpResponse {
697            content,
698            error: None,
699        }),
700    )
701}
702
703async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
704    let auth = if state.jwt_config.is_some() {
705        "jwt"
706    } else {
707        "disabled"
708    };
709
710    Json(HealthResponse {
711        status: "ok".into(),
712        version: env!("CARGO_PKG_VERSION").into(),
713        tools: state.registry.list_public_tools().len(),
714        providers: state.registry.list_providers().len(),
715        skills: state.skill_registry.skill_count(),
716        auth: auth.into(),
717    })
718}
719
720/// GET /.well-known/jwks.json — serves the public key for JWT validation.
721async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
722    match &state.jwks_json {
723        Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
724        None => (
725            StatusCode::NOT_FOUND,
726            Json(serde_json::json!({"error": "JWKS not configured"})),
727        ),
728    }
729}
730
731// ---------------------------------------------------------------------------
732// POST /mcp — MCP JSON-RPC proxy endpoint
733// ---------------------------------------------------------------------------
734
735async fn handle_mcp(
736    State(state): State<Arc<ProxyState>>,
737    claims: Option<Extension<TokenClaims>>,
738    Json(msg): Json<Value>,
739) -> impl IntoResponse {
740    let claims = claims.map(|Extension(claims)| claims);
741    let scopes = scopes_for_request(claims.as_ref(), &state);
742    let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
743    let id = msg.get("id").cloned();
744    tracing::info!(
745        %method,
746        agent = claims.as_ref().map(|c| c.sub.as_str()).unwrap_or(""),
747        job_id = claims.as_ref().and_then(|c| c.job_id.as_deref()).unwrap_or(""),
748        sandbox_id = claims.as_ref().and_then(|c| c.sandbox_id.as_deref()).unwrap_or(""),
749        "mcp call"
750    );
751
752    match method {
753        "initialize" => {
754            let result = serde_json::json!({
755                "protocolVersion": "2025-03-26",
756                "capabilities": {
757                    "tools": { "listChanged": false }
758                },
759                "serverInfo": {
760                    "name": "ati-proxy",
761                    "version": env!("CARGO_PKG_VERSION")
762                }
763            });
764            jsonrpc_success(id, result)
765        }
766
767        "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
768
769        "tools/list" => {
770            let visible_tools = visible_tools_for_scopes(&state, &scopes);
771            let mcp_tools: Vec<Value> = visible_tools
772                .iter()
773                .map(|(_provider, tool)| {
774                    serde_json::json!({
775                        "name": tool.name,
776                        "description": tool.description,
777                        "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
778                            "type": "object",
779                            "properties": {}
780                        }))
781                    })
782                })
783                .collect();
784
785            let result = serde_json::json!({
786                "tools": mcp_tools,
787            });
788            jsonrpc_success(id, result)
789        }
790
791        "tools/call" => {
792            let params = msg.get("params").cloned().unwrap_or(Value::Null);
793            let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
794            let arguments: HashMap<String, Value> = params
795                .get("arguments")
796                .and_then(|a| serde_json::from_value(a.clone()).ok())
797                .unwrap_or_default();
798
799            if tool_name.is_empty() {
800                return jsonrpc_error(id, -32602, "Missing tool name in params.name");
801            }
802
803            let (provider, _tool) = match state.registry.get_tool(tool_name) {
804                Some(pt) => pt,
805                None => {
806                    return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
807                }
808            };
809
810            if let Some(tool_scope) = &_tool.scope {
811                if !scopes.is_allowed(tool_scope) {
812                    return jsonrpc_error(
813                        id,
814                        -32001,
815                        &format!("Access denied: '{}' is not in your scopes", _tool.name),
816                    );
817                }
818            }
819
820            tracing::debug!(%tool_name, provider = %provider.name, "MCP tools/call");
821
822            let mcp_gen_ctx = GenContext {
823                jwt_sub: claims
824                    .as_ref()
825                    .map(|claims| claims.sub.clone())
826                    .unwrap_or_else(|| "dev".into()),
827                jwt_scope: claims
828                    .as_ref()
829                    .map(|claims| claims.scope.clone())
830                    .unwrap_or_else(|| "*".into()),
831                tool_name: tool_name.to_string(),
832                timestamp: crate::core::jwt::now_secs(),
833            };
834
835            let result = if provider.is_mcp() {
836                mcp_client::execute_with_gen(
837                    provider,
838                    tool_name,
839                    &arguments,
840                    &state.keyring,
841                    Some(&mcp_gen_ctx),
842                    Some(&state.auth_cache),
843                )
844                .await
845            } else if provider.is_cli() {
846                // Convert arguments map to CLI-style args for MCP passthrough
847                let raw: Vec<String> = arguments
848                    .iter()
849                    .flat_map(|(k, v)| {
850                        let val = match v {
851                            Value::String(s) => s.clone(),
852                            other => other.to_string(),
853                        };
854                        vec![format!("--{k}"), val]
855                    })
856                    .collect();
857                crate::core::cli_executor::execute_with_gen(
858                    provider,
859                    &raw,
860                    &state.keyring,
861                    Some(&mcp_gen_ctx),
862                    Some(&state.auth_cache),
863                )
864                .await
865                .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
866            } else {
867                match match provider.handler.as_str() {
868                    "xai" => {
869                        xai::execute_xai_tool(provider, _tool, &arguments, &state.keyring).await
870                    }
871                    _ => {
872                        http::execute_tool_with_gen(
873                            provider,
874                            _tool,
875                            &arguments,
876                            &state.keyring,
877                            Some(&mcp_gen_ctx),
878                            Some(&state.auth_cache),
879                        )
880                        .await
881                    }
882                } {
883                    Ok(val) => Ok(val),
884                    Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
885                }
886            };
887
888            match result {
889                Ok(value) => {
890                    let text = match &value {
891                        Value::String(s) => s.clone(),
892                        other => serde_json::to_string_pretty(other).unwrap_or_default(),
893                    };
894                    let mcp_result = serde_json::json!({
895                        "content": [{"type": "text", "text": text}],
896                        "isError": false,
897                    });
898                    jsonrpc_success(id, mcp_result)
899                }
900                Err(e) => {
901                    let mcp_result = serde_json::json!({
902                        "content": [{"type": "text", "text": format!("Error: {e}")}],
903                        "isError": true,
904                    });
905                    jsonrpc_success(id, mcp_result)
906                }
907            }
908        }
909
910        _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
911    }
912}
913
914fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
915    (
916        StatusCode::OK,
917        Json(serde_json::json!({
918            "jsonrpc": "2.0",
919            "id": id,
920            "result": result,
921        })),
922    )
923}
924
925fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
926    (
927        StatusCode::OK,
928        Json(serde_json::json!({
929            "jsonrpc": "2.0",
930            "id": id,
931            "error": {
932                "code": code,
933                "message": message,
934            }
935        })),
936    )
937}
938
939// ---------------------------------------------------------------------------
940// Tool endpoints
941// ---------------------------------------------------------------------------
942
943/// GET /tools — list available tools with optional filters.
944async fn handle_tools_list(
945    State(state): State<Arc<ProxyState>>,
946    claims: Option<Extension<TokenClaims>>,
947    axum::extract::Query(query): axum::extract::Query<ToolsQuery>,
948) -> impl IntoResponse {
949    tracing::debug!(
950        provider = ?query.provider,
951        search = ?query.search,
952        "GET /tools"
953    );
954
955    let claims = claims.map(|Extension(claims)| claims);
956    let scopes = scopes_for_request(claims.as_ref(), &state);
957    let all_tools = visible_tools_for_scopes(&state, &scopes);
958
959    let tools: Vec<Value> = all_tools
960        .iter()
961        .filter(|(provider, tool)| {
962            if let Some(ref p) = query.provider {
963                if provider.name != *p {
964                    return false;
965                }
966            }
967            if let Some(ref q) = query.search {
968                let q = q.to_lowercase();
969                let name_match = tool.name.to_lowercase().contains(&q);
970                let desc_match = tool.description.to_lowercase().contains(&q);
971                let tag_match = tool.tags.iter().any(|t| t.to_lowercase().contains(&q));
972                if !name_match && !desc_match && !tag_match {
973                    return false;
974                }
975            }
976            true
977        })
978        .map(|(provider, tool)| {
979            serde_json::json!({
980                "name": tool.name,
981                "description": tool.description,
982                "provider": provider.name,
983                "method": format!("{:?}", tool.method),
984                "tags": tool.tags,
985                "skills": provider.skills,
986                "input_schema": tool.input_schema,
987            })
988        })
989        .collect();
990
991    (StatusCode::OK, Json(Value::Array(tools)))
992}
993
994/// GET /tools/:name — get detailed info about a specific tool.
995async fn handle_tool_info(
996    State(state): State<Arc<ProxyState>>,
997    claims: Option<Extension<TokenClaims>>,
998    axum::extract::Path(name): axum::extract::Path<String>,
999) -> impl IntoResponse {
1000    tracing::debug!(tool = %name, "GET /tools/:name");
1001
1002    let claims = claims.map(|Extension(claims)| claims);
1003    let scopes = scopes_for_request(claims.as_ref(), &state);
1004
1005    match state
1006        .registry
1007        .get_tool(&name)
1008        .filter(|(_, tool)| match &tool.scope {
1009            Some(scope) => scopes.is_allowed(scope),
1010            None => true,
1011        }) {
1012        Some((provider, tool)) => {
1013            // Merge skills from manifest + SkillRegistry (tool binding + provider binding)
1014            let mut skills: Vec<String> = provider.skills.clone();
1015            for s in state.skill_registry.skills_for_tool(&tool.name) {
1016                if !skills.contains(&s.name) {
1017                    skills.push(s.name.clone());
1018                }
1019            }
1020            for s in state.skill_registry.skills_for_provider(&provider.name) {
1021                if !skills.contains(&s.name) {
1022                    skills.push(s.name.clone());
1023                }
1024            }
1025
1026            (
1027                StatusCode::OK,
1028                Json(serde_json::json!({
1029                    "name": tool.name,
1030                    "description": tool.description,
1031                    "provider": provider.name,
1032                    "method": format!("{:?}", tool.method),
1033                    "endpoint": tool.endpoint,
1034                    "tags": tool.tags,
1035                    "hint": tool.hint,
1036                    "skills": skills,
1037                    "input_schema": tool.input_schema,
1038                    "scope": tool.scope,
1039                })),
1040            )
1041        }
1042        None => (
1043            StatusCode::NOT_FOUND,
1044            Json(serde_json::json!({"error": format!("Tool '{name}' not found")})),
1045        ),
1046    }
1047}
1048
1049// ---------------------------------------------------------------------------
1050// Skill endpoints
1051// ---------------------------------------------------------------------------
1052
1053async fn handle_skills_list(
1054    State(state): State<Arc<ProxyState>>,
1055    claims: Option<Extension<TokenClaims>>,
1056    axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
1057) -> impl IntoResponse {
1058    tracing::debug!(
1059        category = ?query.category,
1060        provider = ?query.provider,
1061        tool = ?query.tool,
1062        search = ?query.search,
1063        "GET /skills"
1064    );
1065
1066    let claims = claims.map(|Extension(claims)| claims);
1067    let scopes = scopes_for_request(claims.as_ref(), &state);
1068    let visible_names = visible_skill_names(&state, &scopes);
1069
1070    let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
1071        state
1072            .skill_registry
1073            .search(search_query)
1074            .into_iter()
1075            .filter(|skill| visible_names.contains(&skill.name))
1076            .collect()
1077    } else if let Some(cat) = &query.category {
1078        state
1079            .skill_registry
1080            .skills_for_category(cat)
1081            .into_iter()
1082            .filter(|skill| visible_names.contains(&skill.name))
1083            .collect()
1084    } else if let Some(prov) = &query.provider {
1085        state
1086            .skill_registry
1087            .skills_for_provider(prov)
1088            .into_iter()
1089            .filter(|skill| visible_names.contains(&skill.name))
1090            .collect()
1091    } else if let Some(t) = &query.tool {
1092        state
1093            .skill_registry
1094            .skills_for_tool(t)
1095            .into_iter()
1096            .filter(|skill| visible_names.contains(&skill.name))
1097            .collect()
1098    } else {
1099        state
1100            .skill_registry
1101            .list_skills()
1102            .iter()
1103            .filter(|skill| visible_names.contains(&skill.name))
1104            .collect()
1105    };
1106
1107    let json: Vec<Value> = skills
1108        .iter()
1109        .map(|s| {
1110            serde_json::json!({
1111                "name": s.name,
1112                "version": s.version,
1113                "description": s.description,
1114                "tools": s.tools,
1115                "providers": s.providers,
1116                "categories": s.categories,
1117                "hint": s.hint,
1118            })
1119        })
1120        .collect();
1121
1122    (StatusCode::OK, Json(Value::Array(json)))
1123}
1124
1125async fn handle_skill_detail(
1126    State(state): State<Arc<ProxyState>>,
1127    claims: Option<Extension<TokenClaims>>,
1128    axum::extract::Path(name): axum::extract::Path<String>,
1129    axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
1130) -> impl IntoResponse {
1131    tracing::debug!(%name, meta = ?query.meta, refs = ?query.refs, "GET /skills/:name");
1132
1133    let claims = claims.map(|Extension(claims)| claims);
1134    let scopes = scopes_for_request(claims.as_ref(), &state);
1135    let visible_names = visible_skill_names(&state, &scopes);
1136
1137    let skill_meta = match state
1138        .skill_registry
1139        .get_skill(&name)
1140        .filter(|skill| visible_names.contains(&skill.name))
1141    {
1142        Some(s) => s,
1143        None => {
1144            return (
1145                StatusCode::NOT_FOUND,
1146                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1147            );
1148        }
1149    };
1150
1151    if query.meta.unwrap_or(false) {
1152        return (
1153            StatusCode::OK,
1154            Json(serde_json::json!({
1155                "name": skill_meta.name,
1156                "version": skill_meta.version,
1157                "description": skill_meta.description,
1158                "author": skill_meta.author,
1159                "tools": skill_meta.tools,
1160                "providers": skill_meta.providers,
1161                "categories": skill_meta.categories,
1162                "keywords": skill_meta.keywords,
1163                "hint": skill_meta.hint,
1164                "depends_on": skill_meta.depends_on,
1165                "suggests": skill_meta.suggests,
1166                "license": skill_meta.license,
1167                "compatibility": skill_meta.compatibility,
1168                "allowed_tools": skill_meta.allowed_tools,
1169                "format": skill_meta.format,
1170            })),
1171        );
1172    }
1173
1174    let content = match state.skill_registry.read_content(&name) {
1175        Ok(c) => c,
1176        Err(e) => {
1177            return (
1178                StatusCode::INTERNAL_SERVER_ERROR,
1179                Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
1180            );
1181        }
1182    };
1183
1184    let mut response = serde_json::json!({
1185        "name": skill_meta.name,
1186        "version": skill_meta.version,
1187        "description": skill_meta.description,
1188        "content": content,
1189    });
1190
1191    if query.refs.unwrap_or(false) {
1192        if let Ok(refs) = state.skill_registry.list_references(&name) {
1193            response["references"] = serde_json::json!(refs);
1194        }
1195    }
1196
1197    (StatusCode::OK, Json(response))
1198}
1199
1200/// GET /skills/:name/bundle — return all files in a skill directory.
1201/// Response: `{"name": "...", "files": {"SKILL.md": "...", "scripts/generate.sh": "...", ...}}`
1202/// Binary files are base64-encoded; text files are returned as-is.
1203async fn handle_skill_bundle(
1204    State(state): State<Arc<ProxyState>>,
1205    claims: Option<Extension<TokenClaims>>,
1206    axum::extract::Path(name): axum::extract::Path<String>,
1207) -> impl IntoResponse {
1208    tracing::debug!(skill = %name, "GET /skills/:name/bundle");
1209
1210    let claims = claims.map(|Extension(claims)| claims);
1211    let scopes = scopes_for_request(claims.as_ref(), &state);
1212    let visible_names = visible_skill_names(&state, &scopes);
1213    if !visible_names.contains(&name) {
1214        return (
1215            StatusCode::NOT_FOUND,
1216            Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1217        );
1218    }
1219
1220    let files = match state.skill_registry.bundle_files(&name) {
1221        Ok(f) => f,
1222        Err(_) => {
1223            return (
1224                StatusCode::NOT_FOUND,
1225                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1226            );
1227        }
1228    };
1229
1230    // Convert bytes to strings (UTF-8 text) or base64 for binary files
1231    let mut file_map = serde_json::Map::new();
1232    for (path, data) in &files {
1233        match std::str::from_utf8(data) {
1234            Ok(text) => {
1235                file_map.insert(path.clone(), Value::String(text.to_string()));
1236            }
1237            Err(_) => {
1238                // Binary file — base64 encode
1239                use base64::Engine;
1240                let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1241                file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1242            }
1243        }
1244    }
1245
1246    (
1247        StatusCode::OK,
1248        Json(serde_json::json!({
1249            "name": name,
1250            "files": file_map,
1251        })),
1252    )
1253}
1254
1255/// POST /skills/bundle — return all files for multiple skills in one response.
1256/// Request: `{"names": ["fal-generate", "compliance-screening"]}`
1257/// Response: `{"skills": {...}, "missing": [...]}`
1258async fn handle_skills_bundle_batch(
1259    State(state): State<Arc<ProxyState>>,
1260    claims: Option<Extension<TokenClaims>>,
1261    Json(req): Json<SkillBundleBatchRequest>,
1262) -> impl IntoResponse {
1263    const MAX_BATCH: usize = 50;
1264    if req.names.len() > MAX_BATCH {
1265        return (
1266            StatusCode::BAD_REQUEST,
1267            Json(
1268                serde_json::json!({"error": format!("batch size {} exceeds limit of {MAX_BATCH}", req.names.len())}),
1269            ),
1270        );
1271    }
1272
1273    tracing::debug!(names = ?req.names, "POST /skills/bundle");
1274
1275    let claims = claims.map(|Extension(claims)| claims);
1276    let scopes = scopes_for_request(claims.as_ref(), &state);
1277    let visible_names = visible_skill_names(&state, &scopes);
1278
1279    let mut result = serde_json::Map::new();
1280    let mut missing: Vec<String> = Vec::new();
1281
1282    for name in &req.names {
1283        if !visible_names.contains(name) {
1284            missing.push(name.clone());
1285            continue;
1286        }
1287        let files = match state.skill_registry.bundle_files(name) {
1288            Ok(f) => f,
1289            Err(_) => {
1290                missing.push(name.clone());
1291                continue;
1292            }
1293        };
1294
1295        let mut file_map = serde_json::Map::new();
1296        for (path, data) in &files {
1297            match std::str::from_utf8(data) {
1298                Ok(text) => {
1299                    file_map.insert(path.clone(), Value::String(text.to_string()));
1300                }
1301                Err(_) => {
1302                    use base64::Engine;
1303                    let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1304                    file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1305                }
1306            }
1307        }
1308
1309        result.insert(name.clone(), serde_json::json!({ "files": file_map }));
1310    }
1311
1312    (
1313        StatusCode::OK,
1314        Json(serde_json::json!({ "skills": result, "missing": missing })),
1315    )
1316}
1317
1318async fn handle_skills_resolve(
1319    State(state): State<Arc<ProxyState>>,
1320    claims: Option<Extension<TokenClaims>>,
1321    Json(req): Json<SkillResolveRequest>,
1322) -> impl IntoResponse {
1323    tracing::debug!(scopes = ?req.scopes, include_content = req.include_content, "POST /skills/resolve");
1324
1325    let include_content = req.include_content;
1326    let request_scopes = ScopeConfig {
1327        scopes: req.scopes,
1328        sub: String::new(),
1329        expires_at: 0,
1330        rate_config: None,
1331    };
1332    let claims = claims.map(|Extension(claims)| claims);
1333    let caller_scopes = scopes_for_request(claims.as_ref(), &state);
1334    let visible_names = visible_skill_names(&state, &caller_scopes);
1335
1336    let resolved: Vec<&skill::SkillMeta> =
1337        skill::resolve_skills(&state.skill_registry, &state.registry, &request_scopes)
1338            .into_iter()
1339            .filter(|skill| visible_names.contains(&skill.name))
1340            .collect();
1341
1342    let json: Vec<Value> = resolved
1343        .iter()
1344        .map(|s| {
1345            let mut entry = serde_json::json!({
1346                "name": s.name,
1347                "version": s.version,
1348                "description": s.description,
1349                "tools": s.tools,
1350                "providers": s.providers,
1351                "categories": s.categories,
1352            });
1353            if include_content {
1354                if let Ok(content) = state.skill_registry.read_content(&s.name) {
1355                    entry["content"] = Value::String(content);
1356                }
1357            }
1358            entry
1359        })
1360        .collect();
1361
1362    (StatusCode::OK, Json(Value::Array(json)))
1363}
1364
1365fn skillati_client(keyring: &Keyring) -> Result<SkillAtiClient, SkillAtiError> {
1366    match SkillAtiClient::from_env(keyring)? {
1367        Some(client) => Ok(client),
1368        None => Err(SkillAtiError::NotConfigured),
1369    }
1370}
1371
1372async fn handle_skillati_catalog(
1373    State(state): State<Arc<ProxyState>>,
1374    claims: Option<Extension<TokenClaims>>,
1375    Query(query): Query<SkillAtiCatalogQuery>,
1376) -> impl IntoResponse {
1377    tracing::debug!(search = ?query.search, "GET /skillati/catalog");
1378
1379    let client = match skillati_client(&state.keyring) {
1380        Ok(client) => client,
1381        Err(err) => return skillati_error_response(err),
1382    };
1383
1384    let claims = claims.map(|Extension(c)| c);
1385    let scopes = scopes_for_request(claims.as_ref(), &state);
1386    let visible_names = visible_skill_names(&state, &scopes);
1387
1388    match client.catalog().await {
1389        Ok(catalog) => {
1390            let mut skills: Vec<_> = catalog
1391                .into_iter()
1392                .filter(|s| visible_names.contains(&s.name))
1393                .collect();
1394            if let Some(search) = query.search.as_deref() {
1395                skills = SkillAtiClient::filter_catalog(&skills, search, 25);
1396            }
1397            (
1398                StatusCode::OK,
1399                Json(serde_json::json!({
1400                    "skills": skills,
1401                })),
1402            )
1403        }
1404        Err(err) => skillati_error_response(err),
1405    }
1406}
1407
1408async fn handle_skillati_read(
1409    State(state): State<Arc<ProxyState>>,
1410    claims: Option<Extension<TokenClaims>>,
1411    axum::extract::Path(name): axum::extract::Path<String>,
1412) -> impl IntoResponse {
1413    tracing::debug!(%name, "GET /skillati/:name");
1414
1415    let client = match skillati_client(&state.keyring) {
1416        Ok(client) => client,
1417        Err(err) => return skillati_error_response(err),
1418    };
1419
1420    let claims = claims.map(|Extension(c)| c);
1421    let scopes = scopes_for_request(claims.as_ref(), &state);
1422    let visible_names = visible_skill_names(&state, &scopes);
1423    if !visible_names.contains(&name) {
1424        return skillati_error_response(SkillAtiError::SkillNotFound(name));
1425    }
1426
1427    match client.read_skill(&name).await {
1428        Ok(activation) => (StatusCode::OK, Json(serde_json::json!(activation))),
1429        Err(err) => skillati_error_response(err),
1430    }
1431}
1432
1433async fn handle_skillati_resources(
1434    State(state): State<Arc<ProxyState>>,
1435    claims: Option<Extension<TokenClaims>>,
1436    axum::extract::Path(name): axum::extract::Path<String>,
1437    Query(query): Query<SkillAtiResourcesQuery>,
1438) -> impl IntoResponse {
1439    tracing::debug!(%name, prefix = ?query.prefix, "GET /skillati/:name/resources");
1440
1441    let client = match skillati_client(&state.keyring) {
1442        Ok(client) => client,
1443        Err(err) => return skillati_error_response(err),
1444    };
1445
1446    let claims = claims.map(|Extension(c)| c);
1447    let scopes = scopes_for_request(claims.as_ref(), &state);
1448    let visible_names = visible_skill_names(&state, &scopes);
1449    if !visible_names.contains(&name) {
1450        return skillati_error_response(SkillAtiError::SkillNotFound(name));
1451    }
1452
1453    match client.list_resources(&name, query.prefix.as_deref()).await {
1454        Ok(resources) => (
1455            StatusCode::OK,
1456            Json(serde_json::json!({
1457                "name": name,
1458                "prefix": query.prefix,
1459                "resources": resources,
1460            })),
1461        ),
1462        Err(err) => skillati_error_response(err),
1463    }
1464}
1465
1466async fn handle_skillati_file(
1467    State(state): State<Arc<ProxyState>>,
1468    claims: Option<Extension<TokenClaims>>,
1469    axum::extract::Path(name): axum::extract::Path<String>,
1470    Query(query): Query<SkillAtiFileQuery>,
1471) -> impl IntoResponse {
1472    tracing::debug!(%name, path = %query.path, "GET /skillati/:name/file");
1473
1474    let client = match skillati_client(&state.keyring) {
1475        Ok(client) => client,
1476        Err(err) => return skillati_error_response(err),
1477    };
1478
1479    let claims = claims.map(|Extension(c)| c);
1480    let scopes = scopes_for_request(claims.as_ref(), &state);
1481    let visible_names = visible_skill_names(&state, &scopes);
1482    if !visible_names.contains(&name) {
1483        return skillati_error_response(SkillAtiError::SkillNotFound(name));
1484    }
1485
1486    match client.read_path(&name, &query.path).await {
1487        Ok(file) => (StatusCode::OK, Json(serde_json::json!(file))),
1488        Err(err) => skillati_error_response(err),
1489    }
1490}
1491
1492async fn handle_skillati_refs(
1493    State(state): State<Arc<ProxyState>>,
1494    claims: Option<Extension<TokenClaims>>,
1495    axum::extract::Path(name): axum::extract::Path<String>,
1496) -> impl IntoResponse {
1497    tracing::debug!(%name, "GET /skillati/:name/refs");
1498
1499    let client = match skillati_client(&state.keyring) {
1500        Ok(client) => client,
1501        Err(err) => return skillati_error_response(err),
1502    };
1503
1504    let claims = claims.map(|Extension(c)| c);
1505    let scopes = scopes_for_request(claims.as_ref(), &state);
1506    let visible_names = visible_skill_names(&state, &scopes);
1507    if !visible_names.contains(&name) {
1508        return skillati_error_response(SkillAtiError::SkillNotFound(name));
1509    }
1510
1511    match client.list_references(&name).await {
1512        Ok(references) => (
1513            StatusCode::OK,
1514            Json(serde_json::json!({
1515                "name": name,
1516                "references": references,
1517            })),
1518        ),
1519        Err(err) => skillati_error_response(err),
1520    }
1521}
1522
1523async fn handle_skillati_ref(
1524    State(state): State<Arc<ProxyState>>,
1525    claims: Option<Extension<TokenClaims>>,
1526    axum::extract::Path((name, reference)): axum::extract::Path<(String, String)>,
1527) -> impl IntoResponse {
1528    tracing::debug!(%name, %reference, "GET /skillati/:name/ref/:reference");
1529
1530    let client = match skillati_client(&state.keyring) {
1531        Ok(client) => client,
1532        Err(err) => return skillati_error_response(err),
1533    };
1534
1535    let claims = claims.map(|Extension(c)| c);
1536    let scopes = scopes_for_request(claims.as_ref(), &state);
1537    let visible_names = visible_skill_names(&state, &scopes);
1538    if !visible_names.contains(&name) {
1539        return skillati_error_response(SkillAtiError::SkillNotFound(name));
1540    }
1541
1542    match client.read_reference(&name, &reference).await {
1543        Ok(content) => (
1544            StatusCode::OK,
1545            Json(serde_json::json!({
1546                "name": name,
1547                "reference": reference,
1548                "content": content,
1549            })),
1550        ),
1551        Err(err) => skillati_error_response(err),
1552    }
1553}
1554
1555fn skillati_error_response(err: SkillAtiError) -> (StatusCode, Json<Value>) {
1556    let status = match &err {
1557        SkillAtiError::NotConfigured
1558        | SkillAtiError::UnsupportedRegistry(_)
1559        | SkillAtiError::MissingCredentials(_)
1560        | SkillAtiError::ProxyUrlRequired => StatusCode::SERVICE_UNAVAILABLE,
1561        SkillAtiError::SkillNotFound(_) | SkillAtiError::PathNotFound { .. } => {
1562            StatusCode::NOT_FOUND
1563        }
1564        SkillAtiError::InvalidPath(_) => StatusCode::BAD_REQUEST,
1565        SkillAtiError::Gcs(_)
1566        | SkillAtiError::ProxyRequest(_)
1567        | SkillAtiError::ProxyResponse(_) => StatusCode::BAD_GATEWAY,
1568    };
1569
1570    (
1571        status,
1572        Json(serde_json::json!({
1573            "error": err.to_string(),
1574        })),
1575    )
1576}
1577
1578// --- Auth middleware ---
1579
1580/// JWT authentication middleware.
1581///
1582/// - /health and /.well-known/jwks.json → skip auth
1583/// - JWT configured → validate Bearer token, attach claims to request extensions
1584/// - No JWT configured → allow all (dev mode)
1585async fn auth_middleware(
1586    State(state): State<Arc<ProxyState>>,
1587    mut req: HttpRequest<Body>,
1588    next: Next,
1589) -> Result<Response, StatusCode> {
1590    let path = req.uri().path();
1591
1592    // Skip auth for public endpoints
1593    if path == "/health" || path == "/.well-known/jwks.json" {
1594        return Ok(next.run(req).await);
1595    }
1596
1597    // If no JWT configured, allow all (dev mode)
1598    let jwt_config = match &state.jwt_config {
1599        Some(c) => c,
1600        None => return Ok(next.run(req).await),
1601    };
1602
1603    // Extract Authorization: Bearer <token>
1604    let auth_header = req
1605        .headers()
1606        .get("authorization")
1607        .and_then(|v| v.to_str().ok());
1608
1609    let token = match auth_header {
1610        Some(header) if header.starts_with("Bearer ") => &header[7..],
1611        _ => return Err(StatusCode::UNAUTHORIZED),
1612    };
1613
1614    // Validate JWT
1615    match jwt::validate(token, jwt_config) {
1616        Ok(claims) => {
1617            tracing::debug!(sub = %claims.sub, scopes = %claims.scope, "JWT validated");
1618            req.extensions_mut().insert(claims);
1619            Ok(next.run(req).await)
1620        }
1621        Err(e) => {
1622            tracing::debug!(error = %e, "JWT validation failed");
1623            Err(StatusCode::UNAUTHORIZED)
1624        }
1625    }
1626}
1627
1628// --- Router builder ---
1629
1630/// Build the axum Router from a pre-constructed ProxyState.
1631pub fn build_router(state: Arc<ProxyState>) -> Router {
1632    Router::new()
1633        .route("/call", post(handle_call))
1634        .route("/help", post(handle_help))
1635        .route("/mcp", post(handle_mcp))
1636        .route("/tools", get(handle_tools_list))
1637        .route("/tools/{name}", get(handle_tool_info))
1638        .route("/skills", get(handle_skills_list))
1639        .route("/skills/resolve", post(handle_skills_resolve))
1640        .route("/skills/bundle", post(handle_skills_bundle_batch))
1641        .route("/skills/{name}", get(handle_skill_detail))
1642        .route("/skills/{name}/bundle", get(handle_skill_bundle))
1643        .route("/skillati/catalog", get(handle_skillati_catalog))
1644        .route("/skillati/{name}", get(handle_skillati_read))
1645        .route("/skillati/{name}/resources", get(handle_skillati_resources))
1646        .route("/skillati/{name}/file", get(handle_skillati_file))
1647        .route("/skillati/{name}/refs", get(handle_skillati_refs))
1648        .route("/skillati/{name}/ref/{reference}", get(handle_skillati_ref))
1649        .route("/health", get(handle_health))
1650        .route("/.well-known/jwks.json", get(handle_jwks))
1651        .layer(middleware::from_fn_with_state(
1652            state.clone(),
1653            auth_middleware,
1654        ))
1655        .with_state(state)
1656}
1657
1658// --- Server startup ---
1659
1660/// Start the proxy server.
1661pub async fn run(
1662    port: u16,
1663    bind_addr: Option<String>,
1664    ati_dir: PathBuf,
1665    _verbose: bool,
1666    env_keys: bool,
1667) -> Result<(), Box<dyn std::error::Error>> {
1668    // Load manifests
1669    let manifests_dir = ati_dir.join("manifests");
1670    let mut registry = ManifestRegistry::load(&manifests_dir)?;
1671    let provider_count = registry.list_providers().len();
1672
1673    // Load keyring
1674    let keyring_source;
1675    let keyring = if env_keys {
1676        // --env-keys: scan ATI_KEY_* environment variables
1677        let kr = Keyring::from_env();
1678        let key_names = kr.key_names();
1679        tracing::info!(
1680            count = key_names.len(),
1681            "loaded API keys from ATI_KEY_* env vars"
1682        );
1683        for name in &key_names {
1684            tracing::debug!(key = %name, "env key loaded");
1685        }
1686        keyring_source = "env-vars (ATI_KEY_*)";
1687        kr
1688    } else {
1689        // Cascade: keyring.enc (sealed) → keyring.enc (persistent) → credentials → empty
1690        let keyring_path = ati_dir.join("keyring.enc");
1691        if keyring_path.exists() {
1692            if let Ok(kr) = Keyring::load(&keyring_path) {
1693                keyring_source = "keyring.enc (sealed key)";
1694                kr
1695            } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
1696                keyring_source = "keyring.enc (persistent key)";
1697                kr
1698            } else {
1699                tracing::warn!("keyring.enc exists but could not be decrypted");
1700                keyring_source = "empty (decryption failed)";
1701                Keyring::empty()
1702            }
1703        } else {
1704            let creds_path = ati_dir.join("credentials");
1705            if creds_path.exists() {
1706                match Keyring::load_credentials(&creds_path) {
1707                    Ok(kr) => {
1708                        keyring_source = "credentials (plaintext)";
1709                        kr
1710                    }
1711                    Err(e) => {
1712                        tracing::warn!(error = %e, "failed to load credentials");
1713                        keyring_source = "empty (credentials error)";
1714                        Keyring::empty()
1715                    }
1716                }
1717            } else {
1718                tracing::warn!("no keyring.enc or credentials found — running without API keys");
1719                tracing::warn!("tools requiring authentication will fail");
1720                keyring_source = "empty (no auth)";
1721                Keyring::empty()
1722            }
1723        }
1724    };
1725
1726    // Discover MCP tools at startup so they appear in GET /tools.
1727    // Runs concurrently across providers with 30s per-provider timeout.
1728    mcp_client::discover_all_mcp_tools(&mut registry, &keyring).await;
1729
1730    let tool_count = registry.list_public_tools().len();
1731
1732    // Log MCP and OpenAPI providers
1733    let mcp_providers: Vec<(String, String)> = registry
1734        .list_mcp_providers()
1735        .iter()
1736        .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
1737        .collect();
1738    let mcp_count = mcp_providers.len();
1739    let openapi_providers: Vec<String> = registry
1740        .list_openapi_providers()
1741        .iter()
1742        .map(|p| p.name.clone())
1743        .collect();
1744    let openapi_count = openapi_providers.len();
1745
1746    // Load installed/local skill registry only.
1747    let skills_dir = ati_dir.join("skills");
1748    let skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
1749        tracing::warn!(error = %e, "failed to load skills");
1750        SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
1751    });
1752
1753    if let Ok(registry_url) = std::env::var("ATI_SKILL_REGISTRY") {
1754        if registry_url.strip_prefix("gcs://").is_some() {
1755            tracing::info!(
1756                registry = %registry_url,
1757                "SkillATI remote registry configured for lazy reads"
1758            );
1759        } else {
1760            tracing::warn!(url = %registry_url, "SkillATI only supports gcs:// registries");
1761        }
1762    }
1763
1764    let skill_count = skill_registry.skill_count();
1765
1766    // Load JWT config from environment
1767    let jwt_config = match jwt::config_from_env() {
1768        Ok(config) => config,
1769        Err(e) => {
1770            tracing::warn!(error = %e, "JWT config error");
1771            None
1772        }
1773    };
1774
1775    let auth_status = if jwt_config.is_some() {
1776        "JWT enabled"
1777    } else {
1778        "DISABLED (no JWT keys configured)"
1779    };
1780
1781    // Build JWKS for the endpoint
1782    let jwks_json = jwt_config.as_ref().and_then(|config| {
1783        config
1784            .public_key_pem
1785            .as_ref()
1786            .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
1787    });
1788
1789    let state = Arc::new(ProxyState {
1790        registry,
1791        skill_registry,
1792        keyring,
1793        jwt_config,
1794        jwks_json,
1795        auth_cache: AuthCache::new(),
1796    });
1797
1798    let app = build_router(state);
1799
1800    let addr: SocketAddr = if let Some(ref bind) = bind_addr {
1801        format!("{bind}:{port}").parse()?
1802    } else {
1803        SocketAddr::from(([127, 0, 0, 1], port))
1804    };
1805
1806    tracing::info!(
1807        version = env!("CARGO_PKG_VERSION"),
1808        %addr,
1809        auth = auth_status,
1810        ati_dir = %ati_dir.display(),
1811        tools = tool_count,
1812        providers = provider_count,
1813        mcp = mcp_count,
1814        openapi = openapi_count,
1815        skills = skill_count,
1816        keyring = keyring_source,
1817        "ATI proxy server starting"
1818    );
1819    for (name, transport) in &mcp_providers {
1820        tracing::info!(provider = %name, transport = %transport, "MCP provider");
1821    }
1822    for name in &openapi_providers {
1823        tracing::info!(provider = %name, "OpenAPI provider");
1824    }
1825
1826    let listener = tokio::net::TcpListener::bind(addr).await?;
1827    axum::serve(listener, app).await?;
1828
1829    Ok(())
1830}
1831
1832/// Write an audit entry from the proxy server. Failures are silently ignored.
1833fn write_proxy_audit(
1834    call_req: &CallRequest,
1835    agent_sub: &str,
1836    claims: Option<&TokenClaims>,
1837    duration: std::time::Duration,
1838    error: Option<&str>,
1839) {
1840    let entry = crate::core::audit::AuditEntry {
1841        ts: chrono::Utc::now().to_rfc3339(),
1842        tool: call_req.tool_name.clone(),
1843        args: crate::core::audit::sanitize_args(&call_req.args),
1844        status: if error.is_some() {
1845            crate::core::audit::AuditStatus::Error
1846        } else {
1847            crate::core::audit::AuditStatus::Ok
1848        },
1849        duration_ms: duration.as_millis() as u64,
1850        agent_sub: agent_sub.to_string(),
1851        job_id: claims.and_then(|c| c.job_id.clone()),
1852        sandbox_id: claims.and_then(|c| c.sandbox_id.clone()),
1853        error: error.map(|s| s.to_string()),
1854        exit_code: None,
1855    };
1856    let _ = crate::core::audit::append(&entry);
1857}
1858
1859// --- Helpers ---
1860
/// System prompt template for the proxy's general (unscoped) help assistant.
///
/// Contains two placeholders, `{tools}` and `{skills_section}`. Because this is a plain
/// `const &str` (not a `format!` literal), they must be substituted textually by the caller.
/// NOTE(review): presumably via `str::replace` — confirm at the call site.
///
/// The prompt tells the model to answer briefly, show `ati run` commands, point to
/// `ati skill show <name>` for relevant skills, and only recommend tools from the list.
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, suggest `ati skill show <name>` for the full methodology

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
1876
1877async fn build_remote_skillati_section(keyring: &Keyring, query: &str, limit: usize) -> String {
1878    let client = match SkillAtiClient::from_env(keyring) {
1879        Ok(Some(client)) => client,
1880        Ok(None) => return String::new(),
1881        Err(err) => {
1882            tracing::warn!(error = %err, "failed to initialize SkillATI catalog for proxy help");
1883            return String::new();
1884        }
1885    };
1886
1887    let catalog = match client.catalog().await {
1888        Ok(catalog) => catalog,
1889        Err(err) => {
1890            tracing::warn!(error = %err, "failed to load SkillATI catalog for proxy help");
1891            return String::new();
1892        }
1893    };
1894
1895    let matched = SkillAtiClient::filter_catalog(&catalog, query, limit);
1896    if matched.is_empty() {
1897        return String::new();
1898    }
1899
1900    render_remote_skillati_section(&matched, catalog.len())
1901}
1902
1903fn render_remote_skillati_section(skills: &[RemoteSkillMeta], total_catalog: usize) -> String {
1904    let mut section = String::from("## Remote Skills Available Via SkillATI\n\n");
1905    section.push_str(
1906        "These skills are available remotely from the SkillATI registry. They are not installed locally. Activate one on demand with `ati skillati read <name>`, inspect bundled paths with `ati skillati resources <name>`, and fetch specific files with `ati skillati cat <name> <path>`.\n\n",
1907    );
1908
1909    for skill in skills {
1910        section.push_str(&format!("- **{}**: {}\n", skill.name, skill.description));
1911    }
1912
1913    if total_catalog > skills.len() {
1914        section.push_str(&format!(
1915            "\nOnly the most relevant {} remote skills are shown here.\n",
1916            skills.len()
1917        ));
1918    }
1919
1920    section
1921}
1922
/// Join help-prompt skill sections into one string.
///
/// Each section is trimmed of surrounding whitespace, and sections that are empty after
/// trimming are dropped. The survivors are separated by a single blank line (`"\n\n"`).
/// Returns an empty string when no section has content.
fn merge_help_skill_sections(sections: &[String]) -> String {
    let mut kept: Vec<&str> = Vec::with_capacity(sections.len());
    for raw in sections {
        let body = raw.trim();
        if !body.is_empty() {
            kept.push(body);
        }
    }
    kept.join("\n\n")
}
1937
1938fn build_tool_context(
1939    tools: &[(
1940        &crate::core::manifest::Provider,
1941        &crate::core::manifest::Tool,
1942    )],
1943) -> String {
1944    let mut summaries = Vec::new();
1945    for (provider, tool) in tools {
1946        let mut summary = if let Some(cat) = &provider.category {
1947            format!(
1948                "- **{}** (provider: {}, category: {}): {}",
1949                tool.name, provider.name, cat, tool.description
1950            )
1951        } else {
1952            format!(
1953                "- **{}** (provider: {}): {}",
1954                tool.name, provider.name, tool.description
1955            )
1956        };
1957        if !tool.tags.is_empty() {
1958            summary.push_str(&format!("\n  Tags: {}", tool.tags.join(", ")));
1959        }
1960        // CLI tools: show passthrough usage
1961        if provider.is_cli() && tool.input_schema.is_none() {
1962            let cmd = provider.cli_command.as_deref().unwrap_or("?");
1963            summary.push_str(&format!(
1964                "\n  Usage: `ati run {} -- <args>`  (passthrough to `{}`)",
1965                tool.name, cmd
1966            ));
1967        } else if let Some(schema) = &tool.input_schema {
1968            if let Some(props) = schema.get("properties") {
1969                if let Some(obj) = props.as_object() {
1970                    let params: Vec<String> = obj
1971                        .iter()
1972                        .filter(|(_, v)| {
1973                            v.get("x-ati-param-location").is_none()
1974                                || v.get("description").is_some()
1975                        })
1976                        .map(|(k, v)| {
1977                            let type_str =
1978                                v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1979                            let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
1980                            format!("    --{k} ({type_str}): {desc}")
1981                        })
1982                        .collect();
1983                    if !params.is_empty() {
1984                        summary.push_str("\n  Parameters:\n");
1985                        summary.push_str(&params.join("\n"));
1986                    }
1987                }
1988            }
1989        }
1990        summaries.push(summary);
1991    }
1992    summaries.join("\n\n")
1993}
1994
1995/// Build a scoped system prompt for a specific tool or provider.
1996///
1997/// Returns None if the scope_name doesn't match any tool or provider.
1998fn build_scoped_prompt(
1999    scope_name: &str,
2000    visible_tools: &[(&Provider, &Tool)],
2001    skills_section: &str,
2002) -> Option<String> {
2003    // Check if scope_name is a tool
2004    if let Some((provider, tool)) = visible_tools
2005        .iter()
2006        .find(|(_, tool)| tool.name == scope_name)
2007    {
2008        let mut details = format!(
2009            "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
2010            tool.name, provider.name, provider.handler, tool.description
2011        );
2012        if let Some(cat) = &provider.category {
2013            details.push_str(&format!("**Category**: {}\n", cat));
2014        }
2015        if provider.is_cli() {
2016            let cmd = provider.cli_command.as_deref().unwrap_or("?");
2017            details.push_str(&format!(
2018                "\n**Usage**: `ati run {} -- <args>`  (passthrough to `{}`)\n",
2019                tool.name, cmd
2020            ));
2021        } else if let Some(schema) = &tool.input_schema {
2022            if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
2023                let required: Vec<String> = schema
2024                    .get("required")
2025                    .and_then(|r| r.as_array())
2026                    .map(|arr| {
2027                        arr.iter()
2028                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
2029                            .collect()
2030                    })
2031                    .unwrap_or_default();
2032                details.push_str("\n**Parameters**:\n");
2033                for (key, val) in props {
2034                    let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
2035                    let desc = val
2036                        .get("description")
2037                        .and_then(|d| d.as_str())
2038                        .unwrap_or("");
2039                    let req = if required.contains(key) {
2040                        " **(required)**"
2041                    } else {
2042                        ""
2043                    };
2044                    details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
2045                }
2046            }
2047        }
2048
2049        let prompt = format!(
2050            "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
2051            ## Tool Details\n{}\n\n{}\n\n\
2052            Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
2053            tool.name, details, skills_section
2054        );
2055        return Some(prompt);
2056    }
2057
2058    // Check if scope_name is a provider
2059    let tools: Vec<(&Provider, &Tool)> = visible_tools
2060        .iter()
2061        .copied()
2062        .filter(|(provider, _)| provider.name == scope_name)
2063        .collect();
2064    if !tools.is_empty() {
2065        let tools_context = build_tool_context(&tools);
2066        let prompt = format!(
2067            "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
2068            ## Tools in provider `{}`\n{}\n\n{}\n\n\
2069            Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
2070            scope_name, scope_name, tools_context, skills_section
2071        );
2072        return Some(prompt);
2073    }
2074
2075    None
2076}