Skip to main content

ati/proxy/
server.rs

//! ATI proxy server — holds API keys and executes tool calls on behalf of sandbox agents.
//!
//! Authentication: ES256-signed JWT (or HS256 fallback). The JWT carries the caller's
//! identity, scopes, and expiry; static tokens and unsigned scope lists are no longer accepted.
//!
//! Usage: `ati proxy --port 8080 [--ati-dir ~/.ati]`
7use axum::{
8    body::Body,
9    extract::{Extension, Query, State},
10    http::{Request as HttpRequest, StatusCode},
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13    routing::{get, post},
14    Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::{ManifestRegistry, Provider, Tool};
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::skill::{self, SkillRegistry};
32use crate::core::skillati::{RemoteSkillMeta, SkillAtiClient, SkillAtiError};
33use crate::core::xai;
34
35/// Shared state for the proxy server.
36pub struct ProxyState {
37    pub registry: ManifestRegistry,
38    pub skill_registry: SkillRegistry,
39    pub keyring: Keyring,
40    /// JWT validation config (None = auth disabled / dev mode).
41    pub jwt_config: Option<JwtConfig>,
42    /// Pre-computed JWKS JSON for the /.well-known/jwks.json endpoint.
43    pub jwks_json: Option<Value>,
44    /// Shared cache for dynamically generated auth credentials.
45    pub auth_cache: AuthCache,
46}
47
48// --- Request/Response types ---
49
50#[derive(Debug, Deserialize)]
51pub struct CallRequest {
52    pub tool_name: String,
53    /// Tool arguments — accepts a JSON object (key-value pairs) for HTTP/MCP/OpenAPI tools,
54    /// or a JSON array of strings / a single string for CLI tools.
55    /// The proxy auto-detects the handler type and routes accordingly.
56    #[serde(default = "default_args")]
57    pub args: Value,
58    /// Deprecated: use `args` with an array value instead.
59    /// Kept for backward compatibility — if present, takes precedence for CLI tools.
60    #[serde(default)]
61    pub raw_args: Option<Vec<String>>,
62}
63
64fn default_args() -> Value {
65    Value::Object(serde_json::Map::new())
66}
67
68impl CallRequest {
69    /// Extract args as a HashMap for HTTP/MCP/OpenAPI tools.
70    /// If `args` is a JSON object, returns its entries.
71    /// If `args` is something else (array, string), returns an empty map.
72    fn args_as_map(&self) -> HashMap<String, Value> {
73        match &self.args {
74            Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
75            _ => HashMap::new(),
76        }
77    }
78
79    /// Extract positional args for CLI tools.
80    /// Priority: explicit `raw_args` field > `args` array > `args` string > `args._positional` > empty.
81    fn args_as_positional(&self) -> Vec<String> {
82        // Backward compat: explicit raw_args wins
83        if let Some(ref raw) = self.raw_args {
84            return raw.clone();
85        }
86        match &self.args {
87            // ["pr", "list", "--repo", "X"]
88            Value::Array(arr) => arr
89                .iter()
90                .map(|v| match v {
91                    Value::String(s) => s.clone(),
92                    other => other.to_string(),
93                })
94                .collect(),
95            // "pr list --repo X"
96            Value::String(s) => s.split_whitespace().map(String::from).collect(),
97            // {"_positional": ["pr", "list"]} or {"--key": "value"} converted to CLI flags
98            Value::Object(map) => {
99                if let Some(Value::Array(pos)) = map.get("_positional") {
100                    return pos
101                        .iter()
102                        .map(|v| match v {
103                            Value::String(s) => s.clone(),
104                            other => other.to_string(),
105                        })
106                        .collect();
107                }
108                // Convert map entries to --key value pairs
109                let mut result = Vec::new();
110                for (k, v) in map {
111                    result.push(format!("--{k}"));
112                    match v {
113                        Value::String(s) => result.push(s.clone()),
114                        Value::Bool(true) => {} // flag, no value needed
115                        other => result.push(other.to_string()),
116                    }
117                }
118                result
119            }
120            _ => Vec::new(),
121        }
122    }
123}
124
125#[derive(Debug, Serialize)]
126pub struct CallResponse {
127    pub result: Value,
128    #[serde(skip_serializing_if = "Option::is_none")]
129    pub error: Option<String>,
130}
131
132#[derive(Debug, Deserialize)]
133pub struct HelpRequest {
134    pub query: String,
135    #[serde(default)]
136    pub tool: Option<String>,
137}
138
139#[derive(Debug, Serialize)]
140pub struct HelpResponse {
141    pub content: String,
142    #[serde(skip_serializing_if = "Option::is_none")]
143    pub error: Option<String>,
144}
145
146#[derive(Debug, Serialize)]
147pub struct HealthResponse {
148    pub status: String,
149    pub version: String,
150    pub tools: usize,
151    pub providers: usize,
152    pub skills: usize,
153    pub auth: String,
154}
155
156// --- Skill endpoint types ---
157
158#[derive(Debug, Deserialize)]
159pub struct SkillsQuery {
160    #[serde(default)]
161    pub category: Option<String>,
162    #[serde(default)]
163    pub provider: Option<String>,
164    #[serde(default)]
165    pub tool: Option<String>,
166    #[serde(default)]
167    pub search: Option<String>,
168}
169
170#[derive(Debug, Deserialize)]
171pub struct SkillDetailQuery {
172    #[serde(default)]
173    pub meta: Option<bool>,
174    #[serde(default)]
175    pub refs: Option<bool>,
176}
177
178#[derive(Debug, Deserialize)]
179pub struct SkillResolveRequest {
180    pub scopes: Vec<String>,
181    /// When true, include SKILL.md content in each resolved skill.
182    #[serde(default)]
183    pub include_content: bool,
184}
185
186#[derive(Debug, Deserialize)]
187pub struct SkillBundleBatchRequest {
188    pub names: Vec<String>,
189}
190
191#[derive(Debug, Deserialize, Default)]
192pub struct SkillAtiCatalogQuery {
193    #[serde(default)]
194    pub search: Option<String>,
195}
196
197#[derive(Debug, Deserialize, Default)]
198pub struct SkillAtiResourcesQuery {
199    #[serde(default)]
200    pub prefix: Option<String>,
201}
202
203#[derive(Debug, Deserialize)]
204pub struct SkillAtiFileQuery {
205    pub path: String,
206}
207
208// --- Tool endpoint types ---
209
210#[derive(Debug, Deserialize)]
211pub struct ToolsQuery {
212    #[serde(default)]
213    pub provider: Option<String>,
214    #[serde(default)]
215    pub search: Option<String>,
216}
217
218// --- Handlers ---
219
220fn scopes_for_request(claims: Option<&TokenClaims>, state: &ProxyState) -> ScopeConfig {
221    match claims {
222        Some(claims) => ScopeConfig::from_jwt(claims),
223        None if state.jwt_config.is_none() => ScopeConfig::unrestricted(),
224        None => ScopeConfig {
225            scopes: Vec::new(),
226            sub: String::new(),
227            expires_at: 0,
228            rate_config: None,
229        },
230    }
231}
232
233fn visible_tools_for_scopes<'a>(
234    state: &'a ProxyState,
235    scopes: &ScopeConfig,
236) -> Vec<(&'a Provider, &'a Tool)> {
237    crate::core::scope::filter_tools_by_scope(state.registry.list_public_tools(), scopes)
238}
239
240fn visible_skill_names(
241    state: &ProxyState,
242    scopes: &ScopeConfig,
243) -> std::collections::HashSet<String> {
244    skill::visible_skills(&state.skill_registry, &state.registry, scopes)
245        .into_iter()
246        .map(|skill| skill.name.clone())
247        .collect()
248}
249
250async fn handle_call(
251    State(state): State<Arc<ProxyState>>,
252    req: HttpRequest<Body>,
253) -> impl IntoResponse {
254    // Extract JWT claims from request extensions (set by auth middleware)
255    let claims = req.extensions().get::<TokenClaims>().cloned();
256
257    // Parse request body
258    let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
259        Ok(b) => b,
260        Err(e) => {
261            return (
262                StatusCode::BAD_REQUEST,
263                Json(CallResponse {
264                    result: Value::Null,
265                    error: Some(format!("Failed to read request body: {e}")),
266                }),
267            );
268        }
269    };
270
271    let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
272        Ok(r) => r,
273        Err(e) => {
274            return (
275                StatusCode::UNPROCESSABLE_ENTITY,
276                Json(CallResponse {
277                    result: Value::Null,
278                    error: Some(format!("Invalid request: {e}")),
279                }),
280            );
281        }
282    };
283
284    tracing::debug!(
285        tool = %call_req.tool_name,
286        args = ?call_req.args,
287        "POST /call"
288    );
289
290    // Look up tool in registry.
291    // If not found, try converting underscore format (finnhub_quote) to colon (finnhub:quote).
292    let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
293        Some(pt) => pt,
294        None => {
295            // Try underscore → colon conversion at each underscore position.
296            // "finnhub_quote" → try "finnhub:quote"
297            // "test_api_get_data" → try "test:api_get_data", "test_api:get_data"
298            let mut resolved = None;
299            for (idx, _) in call_req.tool_name.match_indices('_') {
300                let candidate = format!(
301                    "{}:{}",
302                    &call_req.tool_name[..idx],
303                    &call_req.tool_name[idx + 1..]
304                );
305                if let Some(pt) = state.registry.get_tool(&candidate) {
306                    tracing::debug!(
307                        original = %call_req.tool_name,
308                        resolved = %candidate,
309                        "resolved underscore tool name to colon format"
310                    );
311                    resolved = Some(pt);
312                    break;
313                }
314            }
315
316            match resolved {
317                Some(pt) => pt,
318                None => {
319                    return (
320                        StatusCode::NOT_FOUND,
321                        Json(CallResponse {
322                            result: Value::Null,
323                            error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
324                        }),
325                    );
326                }
327            }
328        }
329    };
330
331    // Scope enforcement from JWT claims
332    if let Some(tool_scope) = &tool.scope {
333        let scopes = match &claims {
334            Some(c) => ScopeConfig::from_jwt(c),
335            None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), // Dev mode
336            None => {
337                return (
338                    StatusCode::FORBIDDEN,
339                    Json(CallResponse {
340                        result: Value::Null,
341                        error: Some("Authentication required — no JWT provided".into()),
342                    }),
343                );
344            }
345        };
346
347        if !scopes.is_allowed(tool_scope) {
348            return (
349                StatusCode::FORBIDDEN,
350                Json(CallResponse {
351                    result: Value::Null,
352                    error: Some(format!(
353                        "Access denied: '{}' is not in your scopes",
354                        tool.name
355                    )),
356                }),
357            );
358        }
359    }
360
361    // Rate limit check
362    {
363        let scopes = match &claims {
364            Some(c) => ScopeConfig::from_jwt(c),
365            None => ScopeConfig::unrestricted(),
366        };
367        if let Some(ref rate_config) = scopes.rate_config {
368            if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
369                return (
370                    StatusCode::TOO_MANY_REQUESTS,
371                    Json(CallResponse {
372                        result: Value::Null,
373                        error: Some(format!("{e}")),
374                    }),
375                );
376            }
377        }
378    }
379
380    // Build auth generator context from JWT claims
381    let gen_ctx = GenContext {
382        jwt_sub: claims
383            .as_ref()
384            .map(|c| c.sub.clone())
385            .unwrap_or_else(|| "dev".into()),
386        jwt_scope: claims
387            .as_ref()
388            .map(|c| c.scope.clone())
389            .unwrap_or_else(|| "*".into()),
390        tool_name: call_req.tool_name.clone(),
391        timestamp: crate::core::jwt::now_secs(),
392    };
393
394    // Execute tool call — dispatch based on handler type, with timing for audit
395    let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
396    let start = std::time::Instant::now();
397
398    let response = match provider.handler.as_str() {
399        "mcp" => {
400            let args_map = call_req.args_as_map();
401            match mcp_client::execute_with_gen(
402                provider,
403                &call_req.tool_name,
404                &args_map,
405                &state.keyring,
406                Some(&gen_ctx),
407                Some(&state.auth_cache),
408            )
409            .await
410            {
411                Ok(result) => (
412                    StatusCode::OK,
413                    Json(CallResponse {
414                        result,
415                        error: None,
416                    }),
417                ),
418                Err(e) => (
419                    StatusCode::BAD_GATEWAY,
420                    Json(CallResponse {
421                        result: Value::Null,
422                        error: Some(format!("MCP error: {e}")),
423                    }),
424                ),
425            }
426        }
427        "cli" => {
428            let positional = call_req.args_as_positional();
429            match crate::core::cli_executor::execute_with_gen(
430                provider,
431                &positional,
432                &state.keyring,
433                Some(&gen_ctx),
434                Some(&state.auth_cache),
435            )
436            .await
437            {
438                Ok(result) => (
439                    StatusCode::OK,
440                    Json(CallResponse {
441                        result,
442                        error: None,
443                    }),
444                ),
445                Err(e) => (
446                    StatusCode::BAD_GATEWAY,
447                    Json(CallResponse {
448                        result: Value::Null,
449                        error: Some(format!("CLI error: {e}")),
450                    }),
451                ),
452            }
453        }
454        _ => {
455            let args_map = call_req.args_as_map();
456            let raw_response = match match provider.handler.as_str() {
457                "xai" => xai::execute_xai_tool(provider, tool, &args_map, &state.keyring).await,
458                _ => {
459                    http::execute_tool_with_gen(
460                        provider,
461                        tool,
462                        &args_map,
463                        &state.keyring,
464                        Some(&gen_ctx),
465                        Some(&state.auth_cache),
466                    )
467                    .await
468                }
469            } {
470                Ok(resp) => resp,
471                Err(e) => {
472                    let duration = start.elapsed();
473                    write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
474                    return (
475                        StatusCode::BAD_GATEWAY,
476                        Json(CallResponse {
477                            result: Value::Null,
478                            error: Some(format!("Upstream API error: {e}")),
479                        }),
480                    );
481                }
482            };
483
484            let processed = match response::process_response(&raw_response, tool.response.as_ref())
485            {
486                Ok(p) => p,
487                Err(e) => {
488                    let duration = start.elapsed();
489                    write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
490                    return (
491                        StatusCode::INTERNAL_SERVER_ERROR,
492                        Json(CallResponse {
493                            result: raw_response,
494                            error: Some(format!("Response processing error: {e}")),
495                        }),
496                    );
497                }
498            };
499
500            (
501                StatusCode::OK,
502                Json(CallResponse {
503                    result: processed,
504                    error: None,
505                }),
506            )
507        }
508    };
509
510    let duration = start.elapsed();
511    let error_msg = response.1.error.as_deref();
512    write_proxy_audit(&call_req, &agent_sub, duration, error_msg);
513
514    response
515}
516
517async fn handle_help(
518    State(state): State<Arc<ProxyState>>,
519    claims: Option<Extension<TokenClaims>>,
520    Json(req): Json<HelpRequest>,
521) -> impl IntoResponse {
522    tracing::debug!(query = %req.query, tool = ?req.tool, "POST /help");
523
524    let claims = claims.map(|Extension(claims)| claims);
525    let scopes = scopes_for_request(claims.as_ref(), &state);
526    if !scopes.help_enabled() {
527        return (
528            StatusCode::FORBIDDEN,
529            Json(HelpResponse {
530                content: String::new(),
531                error: Some("Help is not enabled in your scopes.".into()),
532            }),
533        );
534    }
535
536    let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
537        Some(pt) => pt,
538        None => {
539            return (
540                StatusCode::SERVICE_UNAVAILABLE,
541                Json(HelpResponse {
542                    content: String::new(),
543                    error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
544                }),
545            );
546        }
547    };
548
549    let api_key = match llm_provider
550        .auth_key_name
551        .as_deref()
552        .and_then(|k| state.keyring.get(k))
553    {
554        Some(key) => key.to_string(),
555        None => {
556            return (
557                StatusCode::SERVICE_UNAVAILABLE,
558                Json(HelpResponse {
559                    content: String::new(),
560                    error: Some("LLM API key not found in keyring".into()),
561                }),
562            );
563        }
564    };
565
566    let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
567    let local_skills_section = if resolved_skills.is_empty() {
568        String::new()
569    } else {
570        format!(
571            "## Available Skills (methodology guides)\n{}",
572            skill::build_skill_context(&resolved_skills)
573        )
574    };
575    let remote_query = req
576        .tool
577        .as_ref()
578        .map(|tool| format!("{tool} {}", req.query))
579        .unwrap_or_else(|| req.query.clone());
580    let remote_skills_section =
581        build_remote_skillati_section(&state.keyring, &remote_query, 12).await;
582    let skills_section = merge_help_skill_sections(&[local_skills_section, remote_skills_section]);
583
584    // Build system prompt — scoped or unscoped
585    let visible_tools = visible_tools_for_scopes(&state, &scopes);
586    let system_prompt = if let Some(ref tool_name) = req.tool {
587        // Scoped mode: narrow tools to the specified tool or provider
588        match build_scoped_prompt(tool_name, &visible_tools, &skills_section) {
589            Some(prompt) => prompt,
590            None => {
591                return (
592                    StatusCode::FORBIDDEN,
593                    Json(HelpResponse {
594                        content: String::new(),
595                        error: Some(format!(
596                            "Scope '{tool_name}' is not visible in your current scopes."
597                        )),
598                    }),
599                );
600            }
601        }
602    } else {
603        let tools_context = build_tool_context(&visible_tools);
604        HELP_SYSTEM_PROMPT
605            .replace("{tools}", &tools_context)
606            .replace("{skills_section}", &skills_section)
607    };
608
609    let request_body = serde_json::json!({
610        "model": "zai-glm-4.7",
611        "messages": [
612            {"role": "system", "content": system_prompt},
613            {"role": "user", "content": req.query}
614        ],
615        "max_completion_tokens": 1536,
616        "temperature": 0.3
617    });
618
619    let client = reqwest::Client::new();
620    let url = format!(
621        "{}{}",
622        llm_provider.base_url.trim_end_matches('/'),
623        llm_tool.endpoint
624    );
625
626    let response = match client
627        .post(&url)
628        .bearer_auth(&api_key)
629        .json(&request_body)
630        .send()
631        .await
632    {
633        Ok(r) => r,
634        Err(e) => {
635            return (
636                StatusCode::BAD_GATEWAY,
637                Json(HelpResponse {
638                    content: String::new(),
639                    error: Some(format!("LLM request failed: {e}")),
640                }),
641            );
642        }
643    };
644
645    if !response.status().is_success() {
646        let status = response.status();
647        let body = response.text().await.unwrap_or_default();
648        return (
649            StatusCode::BAD_GATEWAY,
650            Json(HelpResponse {
651                content: String::new(),
652                error: Some(format!("LLM API error ({status}): {body}")),
653            }),
654        );
655    }
656
657    let body: Value = match response.json().await {
658        Ok(b) => b,
659        Err(e) => {
660            return (
661                StatusCode::INTERNAL_SERVER_ERROR,
662                Json(HelpResponse {
663                    content: String::new(),
664                    error: Some(format!("Failed to parse LLM response: {e}")),
665                }),
666            );
667        }
668    };
669
670    let content = body
671        .pointer("/choices/0/message/content")
672        .and_then(|c| c.as_str())
673        .unwrap_or("No response from LLM")
674        .to_string();
675
676    (
677        StatusCode::OK,
678        Json(HelpResponse {
679            content,
680            error: None,
681        }),
682    )
683}
684
685async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
686    let auth = if state.jwt_config.is_some() {
687        "jwt"
688    } else {
689        "disabled"
690    };
691
692    Json(HealthResponse {
693        status: "ok".into(),
694        version: env!("CARGO_PKG_VERSION").into(),
695        tools: state.registry.list_public_tools().len(),
696        providers: state.registry.list_providers().len(),
697        skills: state.skill_registry.skill_count(),
698        auth: auth.into(),
699    })
700}
701
702/// GET /.well-known/jwks.json — serves the public key for JWT validation.
703async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
704    match &state.jwks_json {
705        Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
706        None => (
707            StatusCode::NOT_FOUND,
708            Json(serde_json::json!({"error": "JWKS not configured"})),
709        ),
710    }
711}
712
713// ---------------------------------------------------------------------------
714// POST /mcp — MCP JSON-RPC proxy endpoint
715// ---------------------------------------------------------------------------
716
717async fn handle_mcp(
718    State(state): State<Arc<ProxyState>>,
719    claims: Option<Extension<TokenClaims>>,
720    Json(msg): Json<Value>,
721) -> impl IntoResponse {
722    let claims = claims.map(|Extension(claims)| claims);
723    let scopes = scopes_for_request(claims.as_ref(), &state);
724    let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
725    let id = msg.get("id").cloned();
726
727    tracing::debug!(%method, "POST /mcp");
728
729    match method {
730        "initialize" => {
731            let result = serde_json::json!({
732                "protocolVersion": "2025-03-26",
733                "capabilities": {
734                    "tools": { "listChanged": false }
735                },
736                "serverInfo": {
737                    "name": "ati-proxy",
738                    "version": env!("CARGO_PKG_VERSION")
739                }
740            });
741            jsonrpc_success(id, result)
742        }
743
744        "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
745
746        "tools/list" => {
747            let visible_tools = visible_tools_for_scopes(&state, &scopes);
748            let mcp_tools: Vec<Value> = visible_tools
749                .iter()
750                .map(|(_provider, tool)| {
751                    serde_json::json!({
752                        "name": tool.name,
753                        "description": tool.description,
754                        "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
755                            "type": "object",
756                            "properties": {}
757                        }))
758                    })
759                })
760                .collect();
761
762            let result = serde_json::json!({
763                "tools": mcp_tools,
764            });
765            jsonrpc_success(id, result)
766        }
767
768        "tools/call" => {
769            let params = msg.get("params").cloned().unwrap_or(Value::Null);
770            let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
771            let arguments: HashMap<String, Value> = params
772                .get("arguments")
773                .and_then(|a| serde_json::from_value(a.clone()).ok())
774                .unwrap_or_default();
775
776            if tool_name.is_empty() {
777                return jsonrpc_error(id, -32602, "Missing tool name in params.name");
778            }
779
780            let (provider, _tool) = match state.registry.get_tool(tool_name) {
781                Some(pt) => pt,
782                None => {
783                    return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
784                }
785            };
786
787            if let Some(tool_scope) = &_tool.scope {
788                if !scopes.is_allowed(tool_scope) {
789                    return jsonrpc_error(
790                        id,
791                        -32001,
792                        &format!("Access denied: '{}' is not in your scopes", _tool.name),
793                    );
794                }
795            }
796
797            tracing::debug!(%tool_name, provider = %provider.name, "MCP tools/call");
798
799            let mcp_gen_ctx = GenContext {
800                jwt_sub: claims
801                    .as_ref()
802                    .map(|claims| claims.sub.clone())
803                    .unwrap_or_else(|| "dev".into()),
804                jwt_scope: claims
805                    .as_ref()
806                    .map(|claims| claims.scope.clone())
807                    .unwrap_or_else(|| "*".into()),
808                tool_name: tool_name.to_string(),
809                timestamp: crate::core::jwt::now_secs(),
810            };
811
812            let result = if provider.is_mcp() {
813                mcp_client::execute_with_gen(
814                    provider,
815                    tool_name,
816                    &arguments,
817                    &state.keyring,
818                    Some(&mcp_gen_ctx),
819                    Some(&state.auth_cache),
820                )
821                .await
822            } else if provider.is_cli() {
823                // Convert arguments map to CLI-style args for MCP passthrough
824                let raw: Vec<String> = arguments
825                    .iter()
826                    .flat_map(|(k, v)| {
827                        let val = match v {
828                            Value::String(s) => s.clone(),
829                            other => other.to_string(),
830                        };
831                        vec![format!("--{k}"), val]
832                    })
833                    .collect();
834                crate::core::cli_executor::execute_with_gen(
835                    provider,
836                    &raw,
837                    &state.keyring,
838                    Some(&mcp_gen_ctx),
839                    Some(&state.auth_cache),
840                )
841                .await
842                .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
843            } else {
844                match match provider.handler.as_str() {
845                    "xai" => {
846                        xai::execute_xai_tool(provider, _tool, &arguments, &state.keyring).await
847                    }
848                    _ => {
849                        http::execute_tool_with_gen(
850                            provider,
851                            _tool,
852                            &arguments,
853                            &state.keyring,
854                            Some(&mcp_gen_ctx),
855                            Some(&state.auth_cache),
856                        )
857                        .await
858                    }
859                } {
860                    Ok(val) => Ok(val),
861                    Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
862                }
863            };
864
865            match result {
866                Ok(value) => {
867                    let text = match &value {
868                        Value::String(s) => s.clone(),
869                        other => serde_json::to_string_pretty(other).unwrap_or_default(),
870                    };
871                    let mcp_result = serde_json::json!({
872                        "content": [{"type": "text", "text": text}],
873                        "isError": false,
874                    });
875                    jsonrpc_success(id, mcp_result)
876                }
877                Err(e) => {
878                    let mcp_result = serde_json::json!({
879                        "content": [{"type": "text", "text": format!("Error: {e}")}],
880                        "isError": true,
881                    });
882                    jsonrpc_success(id, mcp_result)
883                }
884            }
885        }
886
887        _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
888    }
889}
890
891fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
892    (
893        StatusCode::OK,
894        Json(serde_json::json!({
895            "jsonrpc": "2.0",
896            "id": id,
897            "result": result,
898        })),
899    )
900}
901
902fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
903    (
904        StatusCode::OK,
905        Json(serde_json::json!({
906            "jsonrpc": "2.0",
907            "id": id,
908            "error": {
909                "code": code,
910                "message": message,
911            }
912        })),
913    )
914}
915
916// ---------------------------------------------------------------------------
917// Tool endpoints
918// ---------------------------------------------------------------------------
919
920/// GET /tools — list available tools with optional filters.
921async fn handle_tools_list(
922    State(state): State<Arc<ProxyState>>,
923    claims: Option<Extension<TokenClaims>>,
924    axum::extract::Query(query): axum::extract::Query<ToolsQuery>,
925) -> impl IntoResponse {
926    tracing::debug!(
927        provider = ?query.provider,
928        search = ?query.search,
929        "GET /tools"
930    );
931
932    let claims = claims.map(|Extension(claims)| claims);
933    let scopes = scopes_for_request(claims.as_ref(), &state);
934    let all_tools = visible_tools_for_scopes(&state, &scopes);
935
936    let tools: Vec<Value> = all_tools
937        .iter()
938        .filter(|(provider, tool)| {
939            if let Some(ref p) = query.provider {
940                if provider.name != *p {
941                    return false;
942                }
943            }
944            if let Some(ref q) = query.search {
945                let q = q.to_lowercase();
946                let name_match = tool.name.to_lowercase().contains(&q);
947                let desc_match = tool.description.to_lowercase().contains(&q);
948                let tag_match = tool.tags.iter().any(|t| t.to_lowercase().contains(&q));
949                if !name_match && !desc_match && !tag_match {
950                    return false;
951                }
952            }
953            true
954        })
955        .map(|(provider, tool)| {
956            serde_json::json!({
957                "name": tool.name,
958                "description": tool.description,
959                "provider": provider.name,
960                "method": format!("{:?}", tool.method),
961                "tags": tool.tags,
962                "input_schema": tool.input_schema,
963            })
964        })
965        .collect();
966
967    (StatusCode::OK, Json(Value::Array(tools)))
968}
969
970/// GET /tools/:name — get detailed info about a specific tool.
971async fn handle_tool_info(
972    State(state): State<Arc<ProxyState>>,
973    claims: Option<Extension<TokenClaims>>,
974    axum::extract::Path(name): axum::extract::Path<String>,
975) -> impl IntoResponse {
976    tracing::debug!(tool = %name, "GET /tools/:name");
977
978    let claims = claims.map(|Extension(claims)| claims);
979    let scopes = scopes_for_request(claims.as_ref(), &state);
980
981    match state
982        .registry
983        .get_tool(&name)
984        .filter(|(_, tool)| match &tool.scope {
985            Some(scope) => scopes.is_allowed(scope),
986            None => true,
987        }) {
988        Some((provider, tool)) => (
989            StatusCode::OK,
990            Json(serde_json::json!({
991                "name": tool.name,
992                "description": tool.description,
993                "provider": provider.name,
994                "method": format!("{:?}", tool.method),
995                "endpoint": tool.endpoint,
996                "tags": tool.tags,
997                "hint": tool.hint,
998                "input_schema": tool.input_schema,
999                "scope": tool.scope,
1000            })),
1001        ),
1002        None => (
1003            StatusCode::NOT_FOUND,
1004            Json(serde_json::json!({"error": format!("Tool '{name}' not found")})),
1005        ),
1006    }
1007}
1008
1009// ---------------------------------------------------------------------------
1010// Skill endpoints
1011// ---------------------------------------------------------------------------
1012
1013async fn handle_skills_list(
1014    State(state): State<Arc<ProxyState>>,
1015    claims: Option<Extension<TokenClaims>>,
1016    axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
1017) -> impl IntoResponse {
1018    tracing::debug!(
1019        category = ?query.category,
1020        provider = ?query.provider,
1021        tool = ?query.tool,
1022        search = ?query.search,
1023        "GET /skills"
1024    );
1025
1026    let claims = claims.map(|Extension(claims)| claims);
1027    let scopes = scopes_for_request(claims.as_ref(), &state);
1028    let visible_names = visible_skill_names(&state, &scopes);
1029
1030    let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
1031        state
1032            .skill_registry
1033            .search(search_query)
1034            .into_iter()
1035            .filter(|skill| visible_names.contains(&skill.name))
1036            .collect()
1037    } else if let Some(cat) = &query.category {
1038        state
1039            .skill_registry
1040            .skills_for_category(cat)
1041            .into_iter()
1042            .filter(|skill| visible_names.contains(&skill.name))
1043            .collect()
1044    } else if let Some(prov) = &query.provider {
1045        state
1046            .skill_registry
1047            .skills_for_provider(prov)
1048            .into_iter()
1049            .filter(|skill| visible_names.contains(&skill.name))
1050            .collect()
1051    } else if let Some(t) = &query.tool {
1052        state
1053            .skill_registry
1054            .skills_for_tool(t)
1055            .into_iter()
1056            .filter(|skill| visible_names.contains(&skill.name))
1057            .collect()
1058    } else {
1059        state
1060            .skill_registry
1061            .list_skills()
1062            .iter()
1063            .filter(|skill| visible_names.contains(&skill.name))
1064            .collect()
1065    };
1066
1067    let json: Vec<Value> = skills
1068        .iter()
1069        .map(|s| {
1070            serde_json::json!({
1071                "name": s.name,
1072                "version": s.version,
1073                "description": s.description,
1074                "tools": s.tools,
1075                "providers": s.providers,
1076                "categories": s.categories,
1077                "hint": s.hint,
1078            })
1079        })
1080        .collect();
1081
1082    (StatusCode::OK, Json(Value::Array(json)))
1083}
1084
1085async fn handle_skill_detail(
1086    State(state): State<Arc<ProxyState>>,
1087    claims: Option<Extension<TokenClaims>>,
1088    axum::extract::Path(name): axum::extract::Path<String>,
1089    axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
1090) -> impl IntoResponse {
1091    tracing::debug!(%name, meta = ?query.meta, refs = ?query.refs, "GET /skills/:name");
1092
1093    let claims = claims.map(|Extension(claims)| claims);
1094    let scopes = scopes_for_request(claims.as_ref(), &state);
1095    let visible_names = visible_skill_names(&state, &scopes);
1096
1097    let skill_meta = match state
1098        .skill_registry
1099        .get_skill(&name)
1100        .filter(|skill| visible_names.contains(&skill.name))
1101    {
1102        Some(s) => s,
1103        None => {
1104            return (
1105                StatusCode::NOT_FOUND,
1106                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1107            );
1108        }
1109    };
1110
1111    if query.meta.unwrap_or(false) {
1112        return (
1113            StatusCode::OK,
1114            Json(serde_json::json!({
1115                "name": skill_meta.name,
1116                "version": skill_meta.version,
1117                "description": skill_meta.description,
1118                "author": skill_meta.author,
1119                "tools": skill_meta.tools,
1120                "providers": skill_meta.providers,
1121                "categories": skill_meta.categories,
1122                "keywords": skill_meta.keywords,
1123                "hint": skill_meta.hint,
1124                "depends_on": skill_meta.depends_on,
1125                "suggests": skill_meta.suggests,
1126                "license": skill_meta.license,
1127                "compatibility": skill_meta.compatibility,
1128                "allowed_tools": skill_meta.allowed_tools,
1129                "format": skill_meta.format,
1130            })),
1131        );
1132    }
1133
1134    let content = match state.skill_registry.read_content(&name) {
1135        Ok(c) => c,
1136        Err(e) => {
1137            return (
1138                StatusCode::INTERNAL_SERVER_ERROR,
1139                Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
1140            );
1141        }
1142    };
1143
1144    let mut response = serde_json::json!({
1145        "name": skill_meta.name,
1146        "version": skill_meta.version,
1147        "description": skill_meta.description,
1148        "content": content,
1149    });
1150
1151    if query.refs.unwrap_or(false) {
1152        if let Ok(refs) = state.skill_registry.list_references(&name) {
1153            response["references"] = serde_json::json!(refs);
1154        }
1155    }
1156
1157    (StatusCode::OK, Json(response))
1158}
1159
1160/// GET /skills/:name/bundle — return all files in a skill directory.
1161/// Response: `{"name": "...", "files": {"SKILL.md": "...", "scripts/generate.sh": "...", ...}}`
1162/// Binary files are base64-encoded; text files are returned as-is.
1163async fn handle_skill_bundle(
1164    State(state): State<Arc<ProxyState>>,
1165    claims: Option<Extension<TokenClaims>>,
1166    axum::extract::Path(name): axum::extract::Path<String>,
1167) -> impl IntoResponse {
1168    tracing::debug!(skill = %name, "GET /skills/:name/bundle");
1169
1170    let claims = claims.map(|Extension(claims)| claims);
1171    let scopes = scopes_for_request(claims.as_ref(), &state);
1172    let visible_names = visible_skill_names(&state, &scopes);
1173    if !visible_names.contains(&name) {
1174        return (
1175            StatusCode::NOT_FOUND,
1176            Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1177        );
1178    }
1179
1180    let files = match state.skill_registry.bundle_files(&name) {
1181        Ok(f) => f,
1182        Err(_) => {
1183            return (
1184                StatusCode::NOT_FOUND,
1185                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
1186            );
1187        }
1188    };
1189
1190    // Convert bytes to strings (UTF-8 text) or base64 for binary files
1191    let mut file_map = serde_json::Map::new();
1192    for (path, data) in &files {
1193        match std::str::from_utf8(data) {
1194            Ok(text) => {
1195                file_map.insert(path.clone(), Value::String(text.to_string()));
1196            }
1197            Err(_) => {
1198                // Binary file — base64 encode
1199                use base64::Engine;
1200                let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1201                file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1202            }
1203        }
1204    }
1205
1206    (
1207        StatusCode::OK,
1208        Json(serde_json::json!({
1209            "name": name,
1210            "files": file_map,
1211        })),
1212    )
1213}
1214
1215/// POST /skills/bundle — return all files for multiple skills in one response.
1216/// Request: `{"names": ["fal-generate", "compliance-screening"]}`
1217/// Response: `{"skills": {...}, "missing": [...]}`
1218async fn handle_skills_bundle_batch(
1219    State(state): State<Arc<ProxyState>>,
1220    claims: Option<Extension<TokenClaims>>,
1221    Json(req): Json<SkillBundleBatchRequest>,
1222) -> impl IntoResponse {
1223    const MAX_BATCH: usize = 50;
1224    if req.names.len() > MAX_BATCH {
1225        return (
1226            StatusCode::BAD_REQUEST,
1227            Json(
1228                serde_json::json!({"error": format!("batch size {} exceeds limit of {MAX_BATCH}", req.names.len())}),
1229            ),
1230        );
1231    }
1232
1233    tracing::debug!(names = ?req.names, "POST /skills/bundle");
1234
1235    let claims = claims.map(|Extension(claims)| claims);
1236    let scopes = scopes_for_request(claims.as_ref(), &state);
1237    let visible_names = visible_skill_names(&state, &scopes);
1238
1239    let mut result = serde_json::Map::new();
1240    let mut missing: Vec<String> = Vec::new();
1241
1242    for name in &req.names {
1243        if !visible_names.contains(name) {
1244            missing.push(name.clone());
1245            continue;
1246        }
1247        let files = match state.skill_registry.bundle_files(name) {
1248            Ok(f) => f,
1249            Err(_) => {
1250                missing.push(name.clone());
1251                continue;
1252            }
1253        };
1254
1255        let mut file_map = serde_json::Map::new();
1256        for (path, data) in &files {
1257            match std::str::from_utf8(data) {
1258                Ok(text) => {
1259                    file_map.insert(path.clone(), Value::String(text.to_string()));
1260                }
1261                Err(_) => {
1262                    use base64::Engine;
1263                    let encoded = base64::engine::general_purpose::STANDARD.encode(data);
1264                    file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
1265                }
1266            }
1267        }
1268
1269        result.insert(name.clone(), serde_json::json!({ "files": file_map }));
1270    }
1271
1272    (
1273        StatusCode::OK,
1274        Json(serde_json::json!({ "skills": result, "missing": missing })),
1275    )
1276}
1277
1278async fn handle_skills_resolve(
1279    State(state): State<Arc<ProxyState>>,
1280    claims: Option<Extension<TokenClaims>>,
1281    Json(req): Json<SkillResolveRequest>,
1282) -> impl IntoResponse {
1283    tracing::debug!(scopes = ?req.scopes, include_content = req.include_content, "POST /skills/resolve");
1284
1285    let include_content = req.include_content;
1286    let request_scopes = ScopeConfig {
1287        scopes: req.scopes,
1288        sub: String::new(),
1289        expires_at: 0,
1290        rate_config: None,
1291    };
1292    let claims = claims.map(|Extension(claims)| claims);
1293    let caller_scopes = scopes_for_request(claims.as_ref(), &state);
1294    let visible_names = visible_skill_names(&state, &caller_scopes);
1295
1296    let resolved: Vec<&skill::SkillMeta> =
1297        skill::resolve_skills(&state.skill_registry, &state.registry, &request_scopes)
1298            .into_iter()
1299            .filter(|skill| visible_names.contains(&skill.name))
1300            .collect();
1301
1302    let json: Vec<Value> = resolved
1303        .iter()
1304        .map(|s| {
1305            let mut entry = serde_json::json!({
1306                "name": s.name,
1307                "version": s.version,
1308                "description": s.description,
1309                "tools": s.tools,
1310                "providers": s.providers,
1311                "categories": s.categories,
1312            });
1313            if include_content {
1314                if let Ok(content) = state.skill_registry.read_content(&s.name) {
1315                    entry["content"] = Value::String(content);
1316                }
1317            }
1318            entry
1319        })
1320        .collect();
1321
1322    (StatusCode::OK, Json(Value::Array(json)))
1323}
1324
1325fn skillati_client(keyring: &Keyring) -> Result<SkillAtiClient, SkillAtiError> {
1326    match SkillAtiClient::from_env(keyring)? {
1327        Some(client) => Ok(client),
1328        None => Err(SkillAtiError::NotConfigured),
1329    }
1330}
1331
1332async fn handle_skillati_catalog(
1333    State(state): State<Arc<ProxyState>>,
1334    claims: Option<Extension<TokenClaims>>,
1335    Query(query): Query<SkillAtiCatalogQuery>,
1336) -> impl IntoResponse {
1337    tracing::debug!(search = ?query.search, "GET /skillati/catalog");
1338
1339    let client = match skillati_client(&state.keyring) {
1340        Ok(client) => client,
1341        Err(err) => return skillati_error_response(err),
1342    };
1343
1344    let claims = claims.map(|Extension(c)| c);
1345    let scopes = scopes_for_request(claims.as_ref(), &state);
1346    let visible_names = visible_skill_names(&state, &scopes);
1347
1348    match client.catalog().await {
1349        Ok(catalog) => {
1350            let mut skills: Vec<_> = catalog
1351                .into_iter()
1352                .filter(|s| visible_names.contains(&s.name))
1353                .collect();
1354            if let Some(search) = query.search.as_deref() {
1355                skills = SkillAtiClient::filter_catalog(&skills, search, 25);
1356            }
1357            (
1358                StatusCode::OK,
1359                Json(serde_json::json!({
1360                    "skills": skills,
1361                })),
1362            )
1363        }
1364        Err(err) => skillati_error_response(err),
1365    }
1366}
1367
1368async fn handle_skillati_read(
1369    State(state): State<Arc<ProxyState>>,
1370    claims: Option<Extension<TokenClaims>>,
1371    axum::extract::Path(name): axum::extract::Path<String>,
1372) -> impl IntoResponse {
1373    tracing::debug!(%name, "GET /skillati/:name");
1374
1375    let client = match skillati_client(&state.keyring) {
1376        Ok(client) => client,
1377        Err(err) => return skillati_error_response(err),
1378    };
1379
1380    let claims = claims.map(|Extension(c)| c);
1381    let scopes = scopes_for_request(claims.as_ref(), &state);
1382    let visible_names = visible_skill_names(&state, &scopes);
1383    if !visible_names.contains(&name) {
1384        return skillati_error_response(SkillAtiError::SkillNotFound(name));
1385    }
1386
1387    match client.read_skill(&name).await {
1388        Ok(activation) => (StatusCode::OK, Json(serde_json::json!(activation))),
1389        Err(err) => skillati_error_response(err),
1390    }
1391}
1392
1393async fn handle_skillati_resources(
1394    State(state): State<Arc<ProxyState>>,
1395    claims: Option<Extension<TokenClaims>>,
1396    axum::extract::Path(name): axum::extract::Path<String>,
1397    Query(query): Query<SkillAtiResourcesQuery>,
1398) -> impl IntoResponse {
1399    tracing::debug!(%name, prefix = ?query.prefix, "GET /skillati/:name/resources");
1400
1401    let client = match skillati_client(&state.keyring) {
1402        Ok(client) => client,
1403        Err(err) => return skillati_error_response(err),
1404    };
1405
1406    let claims = claims.map(|Extension(c)| c);
1407    let scopes = scopes_for_request(claims.as_ref(), &state);
1408    let visible_names = visible_skill_names(&state, &scopes);
1409    if !visible_names.contains(&name) {
1410        return skillati_error_response(SkillAtiError::SkillNotFound(name));
1411    }
1412
1413    match client.list_resources(&name, query.prefix.as_deref()).await {
1414        Ok(resources) => (
1415            StatusCode::OK,
1416            Json(serde_json::json!({
1417                "name": name,
1418                "prefix": query.prefix,
1419                "resources": resources,
1420            })),
1421        ),
1422        Err(err) => skillati_error_response(err),
1423    }
1424}
1425
1426async fn handle_skillati_file(
1427    State(state): State<Arc<ProxyState>>,
1428    claims: Option<Extension<TokenClaims>>,
1429    axum::extract::Path(name): axum::extract::Path<String>,
1430    Query(query): Query<SkillAtiFileQuery>,
1431) -> impl IntoResponse {
1432    tracing::debug!(%name, path = %query.path, "GET /skillati/:name/file");
1433
1434    let client = match skillati_client(&state.keyring) {
1435        Ok(client) => client,
1436        Err(err) => return skillati_error_response(err),
1437    };
1438
1439    let claims = claims.map(|Extension(c)| c);
1440    let scopes = scopes_for_request(claims.as_ref(), &state);
1441    let visible_names = visible_skill_names(&state, &scopes);
1442    if !visible_names.contains(&name) {
1443        return skillati_error_response(SkillAtiError::SkillNotFound(name));
1444    }
1445
1446    match client.read_path(&name, &query.path).await {
1447        Ok(file) => (StatusCode::OK, Json(serde_json::json!(file))),
1448        Err(err) => skillati_error_response(err),
1449    }
1450}
1451
1452async fn handle_skillati_refs(
1453    State(state): State<Arc<ProxyState>>,
1454    claims: Option<Extension<TokenClaims>>,
1455    axum::extract::Path(name): axum::extract::Path<String>,
1456) -> impl IntoResponse {
1457    tracing::debug!(%name, "GET /skillati/:name/refs");
1458
1459    let client = match skillati_client(&state.keyring) {
1460        Ok(client) => client,
1461        Err(err) => return skillati_error_response(err),
1462    };
1463
1464    let claims = claims.map(|Extension(c)| c);
1465    let scopes = scopes_for_request(claims.as_ref(), &state);
1466    let visible_names = visible_skill_names(&state, &scopes);
1467    if !visible_names.contains(&name) {
1468        return skillati_error_response(SkillAtiError::SkillNotFound(name));
1469    }
1470
1471    match client.list_references(&name).await {
1472        Ok(references) => (
1473            StatusCode::OK,
1474            Json(serde_json::json!({
1475                "name": name,
1476                "references": references,
1477            })),
1478        ),
1479        Err(err) => skillati_error_response(err),
1480    }
1481}
1482
1483async fn handle_skillati_ref(
1484    State(state): State<Arc<ProxyState>>,
1485    claims: Option<Extension<TokenClaims>>,
1486    axum::extract::Path((name, reference)): axum::extract::Path<(String, String)>,
1487) -> impl IntoResponse {
1488    tracing::debug!(%name, %reference, "GET /skillati/:name/ref/:reference");
1489
1490    let client = match skillati_client(&state.keyring) {
1491        Ok(client) => client,
1492        Err(err) => return skillati_error_response(err),
1493    };
1494
1495    let claims = claims.map(|Extension(c)| c);
1496    let scopes = scopes_for_request(claims.as_ref(), &state);
1497    let visible_names = visible_skill_names(&state, &scopes);
1498    if !visible_names.contains(&name) {
1499        return skillati_error_response(SkillAtiError::SkillNotFound(name));
1500    }
1501
1502    match client.read_reference(&name, &reference).await {
1503        Ok(content) => (
1504            StatusCode::OK,
1505            Json(serde_json::json!({
1506                "name": name,
1507                "reference": reference,
1508                "content": content,
1509            })),
1510        ),
1511        Err(err) => skillati_error_response(err),
1512    }
1513}
1514
1515fn skillati_error_response(err: SkillAtiError) -> (StatusCode, Json<Value>) {
1516    let status = match &err {
1517        SkillAtiError::NotConfigured
1518        | SkillAtiError::UnsupportedRegistry(_)
1519        | SkillAtiError::MissingCredentials(_)
1520        | SkillAtiError::ProxyUrlRequired => StatusCode::SERVICE_UNAVAILABLE,
1521        SkillAtiError::SkillNotFound(_) | SkillAtiError::PathNotFound { .. } => {
1522            StatusCode::NOT_FOUND
1523        }
1524        SkillAtiError::InvalidPath(_) => StatusCode::BAD_REQUEST,
1525        SkillAtiError::Gcs(_)
1526        | SkillAtiError::ProxyRequest(_)
1527        | SkillAtiError::ProxyResponse(_) => StatusCode::BAD_GATEWAY,
1528    };
1529
1530    (
1531        status,
1532        Json(serde_json::json!({
1533            "error": err.to_string(),
1534        })),
1535    )
1536}
1537
1538// --- Auth middleware ---
1539
1540/// JWT authentication middleware.
1541///
1542/// - /health and /.well-known/jwks.json → skip auth
1543/// - JWT configured → validate Bearer token, attach claims to request extensions
1544/// - No JWT configured → allow all (dev mode)
1545async fn auth_middleware(
1546    State(state): State<Arc<ProxyState>>,
1547    mut req: HttpRequest<Body>,
1548    next: Next,
1549) -> Result<Response, StatusCode> {
1550    let path = req.uri().path();
1551
1552    // Skip auth for public endpoints
1553    if path == "/health" || path == "/.well-known/jwks.json" {
1554        return Ok(next.run(req).await);
1555    }
1556
1557    // If no JWT configured, allow all (dev mode)
1558    let jwt_config = match &state.jwt_config {
1559        Some(c) => c,
1560        None => return Ok(next.run(req).await),
1561    };
1562
1563    // Extract Authorization: Bearer <token>
1564    let auth_header = req
1565        .headers()
1566        .get("authorization")
1567        .and_then(|v| v.to_str().ok());
1568
1569    let token = match auth_header {
1570        Some(header) if header.starts_with("Bearer ") => &header[7..],
1571        _ => return Err(StatusCode::UNAUTHORIZED),
1572    };
1573
1574    // Validate JWT
1575    match jwt::validate(token, jwt_config) {
1576        Ok(claims) => {
1577            tracing::debug!(sub = %claims.sub, scopes = %claims.scope, "JWT validated");
1578            req.extensions_mut().insert(claims);
1579            Ok(next.run(req).await)
1580        }
1581        Err(e) => {
1582            tracing::debug!(error = %e, "JWT validation failed");
1583            Err(StatusCode::UNAUTHORIZED)
1584        }
1585    }
1586}
1587
1588// --- Router builder ---
1589
1590/// Build the axum Router from a pre-constructed ProxyState.
1591pub fn build_router(state: Arc<ProxyState>) -> Router {
1592    Router::new()
1593        .route("/call", post(handle_call))
1594        .route("/help", post(handle_help))
1595        .route("/mcp", post(handle_mcp))
1596        .route("/tools", get(handle_tools_list))
1597        .route("/tools/{name}", get(handle_tool_info))
1598        .route("/skills", get(handle_skills_list))
1599        .route("/skills/resolve", post(handle_skills_resolve))
1600        .route("/skills/bundle", post(handle_skills_bundle_batch))
1601        .route("/skills/{name}", get(handle_skill_detail))
1602        .route("/skills/{name}/bundle", get(handle_skill_bundle))
1603        .route("/skillati/catalog", get(handle_skillati_catalog))
1604        .route("/skillati/{name}", get(handle_skillati_read))
1605        .route("/skillati/{name}/resources", get(handle_skillati_resources))
1606        .route("/skillati/{name}/file", get(handle_skillati_file))
1607        .route("/skillati/{name}/refs", get(handle_skillati_refs))
1608        .route("/skillati/{name}/ref/{reference}", get(handle_skillati_ref))
1609        .route("/health", get(handle_health))
1610        .route("/.well-known/jwks.json", get(handle_jwks))
1611        .layer(middleware::from_fn_with_state(
1612            state.clone(),
1613            auth_middleware,
1614        ))
1615        .with_state(state)
1616}
1617
1618// --- Server startup ---
1619
1620/// Start the proxy server.
1621pub async fn run(
1622    port: u16,
1623    bind_addr: Option<String>,
1624    ati_dir: PathBuf,
1625    _verbose: bool,
1626    env_keys: bool,
1627) -> Result<(), Box<dyn std::error::Error>> {
1628    // Load manifests
1629    let manifests_dir = ati_dir.join("manifests");
1630    let mut registry = ManifestRegistry::load(&manifests_dir)?;
1631    let provider_count = registry.list_providers().len();
1632
1633    // Load keyring
1634    let keyring_source;
1635    let keyring = if env_keys {
1636        // --env-keys: scan ATI_KEY_* environment variables
1637        let kr = Keyring::from_env();
1638        let key_names = kr.key_names();
1639        tracing::info!(
1640            count = key_names.len(),
1641            "loaded API keys from ATI_KEY_* env vars"
1642        );
1643        for name in &key_names {
1644            tracing::debug!(key = %name, "env key loaded");
1645        }
1646        keyring_source = "env-vars (ATI_KEY_*)";
1647        kr
1648    } else {
1649        // Cascade: keyring.enc (sealed) → keyring.enc (persistent) → credentials → empty
1650        let keyring_path = ati_dir.join("keyring.enc");
1651        if keyring_path.exists() {
1652            if let Ok(kr) = Keyring::load(&keyring_path) {
1653                keyring_source = "keyring.enc (sealed key)";
1654                kr
1655            } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
1656                keyring_source = "keyring.enc (persistent key)";
1657                kr
1658            } else {
1659                tracing::warn!("keyring.enc exists but could not be decrypted");
1660                keyring_source = "empty (decryption failed)";
1661                Keyring::empty()
1662            }
1663        } else {
1664            let creds_path = ati_dir.join("credentials");
1665            if creds_path.exists() {
1666                match Keyring::load_credentials(&creds_path) {
1667                    Ok(kr) => {
1668                        keyring_source = "credentials (plaintext)";
1669                        kr
1670                    }
1671                    Err(e) => {
1672                        tracing::warn!(error = %e, "failed to load credentials");
1673                        keyring_source = "empty (credentials error)";
1674                        Keyring::empty()
1675                    }
1676                }
1677            } else {
1678                tracing::warn!("no keyring.enc or credentials found — running without API keys");
1679                tracing::warn!("tools requiring authentication will fail");
1680                keyring_source = "empty (no auth)";
1681                Keyring::empty()
1682            }
1683        }
1684    };
1685
1686    // Discover MCP tools at startup so they appear in GET /tools.
1687    // Runs concurrently across providers with 30s per-provider timeout.
1688    mcp_client::discover_all_mcp_tools(&mut registry, &keyring).await;
1689
1690    let tool_count = registry.list_public_tools().len();
1691
1692    // Log MCP and OpenAPI providers
1693    let mcp_providers: Vec<(String, String)> = registry
1694        .list_mcp_providers()
1695        .iter()
1696        .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
1697        .collect();
1698    let mcp_count = mcp_providers.len();
1699    let openapi_providers: Vec<String> = registry
1700        .list_openapi_providers()
1701        .iter()
1702        .map(|p| p.name.clone())
1703        .collect();
1704    let openapi_count = openapi_providers.len();
1705
1706    // Load installed/local skill registry only.
1707    let skills_dir = ati_dir.join("skills");
1708    let skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
1709        tracing::warn!(error = %e, "failed to load skills");
1710        SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
1711    });
1712
1713    if let Ok(registry_url) = std::env::var("ATI_SKILL_REGISTRY") {
1714        if registry_url.strip_prefix("gcs://").is_some() {
1715            tracing::info!(
1716                registry = %registry_url,
1717                "SkillATI remote registry configured for lazy reads"
1718            );
1719        } else {
1720            tracing::warn!(url = %registry_url, "SkillATI only supports gcs:// registries");
1721        }
1722    }
1723
1724    let skill_count = skill_registry.skill_count();
1725
1726    // Load JWT config from environment
1727    let jwt_config = match jwt::config_from_env() {
1728        Ok(config) => config,
1729        Err(e) => {
1730            tracing::warn!(error = %e, "JWT config error");
1731            None
1732        }
1733    };
1734
1735    let auth_status = if jwt_config.is_some() {
1736        "JWT enabled"
1737    } else {
1738        "DISABLED (no JWT keys configured)"
1739    };
1740
1741    // Build JWKS for the endpoint
1742    let jwks_json = jwt_config.as_ref().and_then(|config| {
1743        config
1744            .public_key_pem
1745            .as_ref()
1746            .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
1747    });
1748
1749    let state = Arc::new(ProxyState {
1750        registry,
1751        skill_registry,
1752        keyring,
1753        jwt_config,
1754        jwks_json,
1755        auth_cache: AuthCache::new(),
1756    });
1757
1758    let app = build_router(state);
1759
1760    let addr: SocketAddr = if let Some(ref bind) = bind_addr {
1761        format!("{bind}:{port}").parse()?
1762    } else {
1763        SocketAddr::from(([127, 0, 0, 1], port))
1764    };
1765
1766    tracing::info!(
1767        version = env!("CARGO_PKG_VERSION"),
1768        %addr,
1769        auth = auth_status,
1770        ati_dir = %ati_dir.display(),
1771        tools = tool_count,
1772        providers = provider_count,
1773        mcp = mcp_count,
1774        openapi = openapi_count,
1775        skills = skill_count,
1776        keyring = keyring_source,
1777        "ATI proxy server starting"
1778    );
1779    for (name, transport) in &mcp_providers {
1780        tracing::info!(provider = %name, transport = %transport, "MCP provider");
1781    }
1782    for name in &openapi_providers {
1783        tracing::info!(provider = %name, "OpenAPI provider");
1784    }
1785
1786    let listener = tokio::net::TcpListener::bind(addr).await?;
1787    axum::serve(listener, app).await?;
1788
1789    Ok(())
1790}
1791
1792/// Write an audit entry from the proxy server. Failures are silently ignored.
1793fn write_proxy_audit(
1794    call_req: &CallRequest,
1795    agent_sub: &str,
1796    duration: std::time::Duration,
1797    error: Option<&str>,
1798) {
1799    let entry = crate::core::audit::AuditEntry {
1800        ts: chrono::Utc::now().to_rfc3339(),
1801        tool: call_req.tool_name.clone(),
1802        args: crate::core::audit::sanitize_args(&call_req.args),
1803        status: if error.is_some() {
1804            crate::core::audit::AuditStatus::Error
1805        } else {
1806            crate::core::audit::AuditStatus::Ok
1807        },
1808        duration_ms: duration.as_millis() as u64,
1809        agent_sub: agent_sub.to_string(),
1810        error: error.map(|s| s.to_string()),
1811        exit_code: None,
1812    };
1813    let _ = crate::core::audit::append(&entry);
1814}
1815
1816// --- Helpers ---
1817
/// System-prompt template for the `ati` help assistant.
///
/// Contains two literal placeholders, `{tools}` and `{skills_section}`. This is a plain
/// raw string, not a `format!` format string, so the placeholders are presumably
/// substituted by the caller via string replacement.
/// NOTE(review): confirm the substitution mechanism against the help handler.
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, suggest `ati skill show <name>` for the full methodology

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
1833
1834async fn build_remote_skillati_section(keyring: &Keyring, query: &str, limit: usize) -> String {
1835    let client = match SkillAtiClient::from_env(keyring) {
1836        Ok(Some(client)) => client,
1837        Ok(None) => return String::new(),
1838        Err(err) => {
1839            tracing::warn!(error = %err, "failed to initialize SkillATI catalog for proxy help");
1840            return String::new();
1841        }
1842    };
1843
1844    let catalog = match client.catalog().await {
1845        Ok(catalog) => catalog,
1846        Err(err) => {
1847            tracing::warn!(error = %err, "failed to load SkillATI catalog for proxy help");
1848            return String::new();
1849        }
1850    };
1851
1852    let matched = SkillAtiClient::filter_catalog(&catalog, query, limit);
1853    if matched.is_empty() {
1854        return String::new();
1855    }
1856
1857    render_remote_skillati_section(&matched, catalog.len())
1858}
1859
1860fn render_remote_skillati_section(skills: &[RemoteSkillMeta], total_catalog: usize) -> String {
1861    let mut section = String::from("## Remote Skills Available Via SkillATI\n\n");
1862    section.push_str(
1863        "These skills are available remotely from the SkillATI registry. They are not installed locally. Activate one on demand with `ati skillati read <name>`, inspect bundled paths with `ati skillati resources <name>`, and fetch specific files with `ati skillati cat <name> <path>`.\n\n",
1864    );
1865
1866    for skill in skills {
1867        section.push_str(&format!("- **{}**: {}\n", skill.name, skill.description));
1868    }
1869
1870    if total_catalog > skills.len() {
1871        section.push_str(&format!(
1872            "\nOnly the most relevant {} remote skills are shown here.\n",
1873            skills.len()
1874        ));
1875    }
1876
1877    section
1878}
1879
/// Merge the help prompt's skill sections into a single block of text.
///
/// Each section is trimmed of surrounding whitespace. Sections that are empty after
/// trimming are dropped, and the remaining ones are separated by one blank line
/// (`"\n\n"`). Returns an empty string when every section is blank.
fn merge_help_skill_sections(sections: &[String]) -> String {
    let mut merged = String::new();
    for piece in sections.iter().map(|s| s.trim()).filter(|s| !s.is_empty()) {
        // Every kept piece is non-empty, so `merged` is empty only before the first.
        if !merged.is_empty() {
            merged.push_str("\n\n");
        }
        merged.push_str(piece);
    }
    merged
}
1894
1895fn build_tool_context(
1896    tools: &[(
1897        &crate::core::manifest::Provider,
1898        &crate::core::manifest::Tool,
1899    )],
1900) -> String {
1901    let mut summaries = Vec::new();
1902    for (provider, tool) in tools {
1903        let mut summary = if let Some(cat) = &provider.category {
1904            format!(
1905                "- **{}** (provider: {}, category: {}): {}",
1906                tool.name, provider.name, cat, tool.description
1907            )
1908        } else {
1909            format!(
1910                "- **{}** (provider: {}): {}",
1911                tool.name, provider.name, tool.description
1912            )
1913        };
1914        if !tool.tags.is_empty() {
1915            summary.push_str(&format!("\n  Tags: {}", tool.tags.join(", ")));
1916        }
1917        // CLI tools: show passthrough usage
1918        if provider.is_cli() && tool.input_schema.is_none() {
1919            let cmd = provider.cli_command.as_deref().unwrap_or("?");
1920            summary.push_str(&format!(
1921                "\n  Usage: `ati run {} -- <args>`  (passthrough to `{}`)",
1922                tool.name, cmd
1923            ));
1924        } else if let Some(schema) = &tool.input_schema {
1925            if let Some(props) = schema.get("properties") {
1926                if let Some(obj) = props.as_object() {
1927                    let params: Vec<String> = obj
1928                        .iter()
1929                        .filter(|(_, v)| {
1930                            v.get("x-ati-param-location").is_none()
1931                                || v.get("description").is_some()
1932                        })
1933                        .map(|(k, v)| {
1934                            let type_str =
1935                                v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1936                            let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
1937                            format!("    --{k} ({type_str}): {desc}")
1938                        })
1939                        .collect();
1940                    if !params.is_empty() {
1941                        summary.push_str("\n  Parameters:\n");
1942                        summary.push_str(&params.join("\n"));
1943                    }
1944                }
1945            }
1946        }
1947        summaries.push(summary);
1948    }
1949    summaries.join("\n\n")
1950}
1951
1952/// Build a scoped system prompt for a specific tool or provider.
1953///
1954/// Returns None if the scope_name doesn't match any tool or provider.
1955fn build_scoped_prompt(
1956    scope_name: &str,
1957    visible_tools: &[(&Provider, &Tool)],
1958    skills_section: &str,
1959) -> Option<String> {
1960    // Check if scope_name is a tool
1961    if let Some((provider, tool)) = visible_tools
1962        .iter()
1963        .find(|(_, tool)| tool.name == scope_name)
1964    {
1965        let mut details = format!(
1966            "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
1967            tool.name, provider.name, provider.handler, tool.description
1968        );
1969        if let Some(cat) = &provider.category {
1970            details.push_str(&format!("**Category**: {}\n", cat));
1971        }
1972        if provider.is_cli() {
1973            let cmd = provider.cli_command.as_deref().unwrap_or("?");
1974            details.push_str(&format!(
1975                "\n**Usage**: `ati run {} -- <args>`  (passthrough to `{}`)\n",
1976                tool.name, cmd
1977            ));
1978        } else if let Some(schema) = &tool.input_schema {
1979            if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
1980                let required: Vec<String> = schema
1981                    .get("required")
1982                    .and_then(|r| r.as_array())
1983                    .map(|arr| {
1984                        arr.iter()
1985                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
1986                            .collect()
1987                    })
1988                    .unwrap_or_default();
1989                details.push_str("\n**Parameters**:\n");
1990                for (key, val) in props {
1991                    let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1992                    let desc = val
1993                        .get("description")
1994                        .and_then(|d| d.as_str())
1995                        .unwrap_or("");
1996                    let req = if required.contains(key) {
1997                        " **(required)**"
1998                    } else {
1999                        ""
2000                    };
2001                    details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
2002                }
2003            }
2004        }
2005
2006        let prompt = format!(
2007            "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
2008            ## Tool Details\n{}\n\n{}\n\n\
2009            Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
2010            tool.name, details, skills_section
2011        );
2012        return Some(prompt);
2013    }
2014
2015    // Check if scope_name is a provider
2016    let tools: Vec<(&Provider, &Tool)> = visible_tools
2017        .iter()
2018        .copied()
2019        .filter(|(provider, _)| provider.name == scope_name)
2020        .collect();
2021    if !tools.is_empty() {
2022        let tools_context = build_tool_context(&tools);
2023        let prompt = format!(
2024            "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
2025            ## Tools in provider `{}`\n{}\n\n{}\n\n\
2026            Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
2027            scope_name, scope_name, tools_context, skills_section
2028        );
2029        return Some(prompt);
2030    }
2031
2032    None
2033}