Skip to main content

ati/proxy/
server.rs

//! ATI proxy server — holds API keys and executes tool calls on behalf of sandbox agents.
//!
//! Authentication: ES256-signed JWTs (with an HS256 fallback). Each JWT carries the
//! caller's identity, scopes, and expiry; static tokens and unsigned scope lists are
//! no longer supported.
//!
//! Usage: `ati proxy --port 8080 [--ati-dir ~/.ati]`
7use axum::{
8    body::Body,
9    extract::State,
10    http::{Request as HttpRequest, StatusCode},
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13    routing::{get, post},
14    Json, Router,
15};
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::path::PathBuf;
21use std::sync::Arc;
22
23use crate::core::auth_generator::{AuthCache, GenContext};
24use crate::core::http;
25use crate::core::jwt::{self, JwtConfig, TokenClaims};
26use crate::core::keyring::Keyring;
27use crate::core::manifest::ManifestRegistry;
28use crate::core::mcp_client;
29use crate::core::response;
30use crate::core::scope::ScopeConfig;
31use crate::core::skill::{self, SkillRegistry};
32use crate::core::xai;
33
34/// Shared state for the proxy server.
35pub struct ProxyState {
36    pub registry: ManifestRegistry,
37    pub skill_registry: SkillRegistry,
38    pub keyring: Keyring,
39    /// JWT validation config (None = auth disabled / dev mode).
40    pub jwt_config: Option<JwtConfig>,
41    /// Pre-computed JWKS JSON for the /.well-known/jwks.json endpoint.
42    pub jwks_json: Option<Value>,
43    /// Shared cache for dynamically generated auth credentials.
44    pub auth_cache: AuthCache,
45}
46
47// --- Request/Response types ---
48
49#[derive(Debug, Deserialize)]
50pub struct CallRequest {
51    pub tool_name: String,
52    /// Tool arguments — accepts a JSON object (key-value pairs) for HTTP/MCP/OpenAPI tools,
53    /// or a JSON array of strings / a single string for CLI tools.
54    /// The proxy auto-detects the handler type and routes accordingly.
55    #[serde(default = "default_args")]
56    pub args: Value,
57    /// Deprecated: use `args` with an array value instead.
58    /// Kept for backward compatibility — if present, takes precedence for CLI tools.
59    #[serde(default)]
60    pub raw_args: Option<Vec<String>>,
61}
62
63fn default_args() -> Value {
64    Value::Object(serde_json::Map::new())
65}
66
67impl CallRequest {
68    /// Extract args as a HashMap for HTTP/MCP/OpenAPI tools.
69    /// If `args` is a JSON object, returns its entries.
70    /// If `args` is something else (array, string), returns an empty map.
71    fn args_as_map(&self) -> HashMap<String, Value> {
72        match &self.args {
73            Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
74            _ => HashMap::new(),
75        }
76    }
77
78    /// Extract positional args for CLI tools.
79    /// Priority: explicit `raw_args` field > `args` array > `args` string > `args._positional` > empty.
80    fn args_as_positional(&self) -> Vec<String> {
81        // Backward compat: explicit raw_args wins
82        if let Some(ref raw) = self.raw_args {
83            return raw.clone();
84        }
85        match &self.args {
86            // ["pr", "list", "--repo", "X"]
87            Value::Array(arr) => arr
88                .iter()
89                .map(|v| match v {
90                    Value::String(s) => s.clone(),
91                    other => other.to_string(),
92                })
93                .collect(),
94            // "pr list --repo X"
95            Value::String(s) => s.split_whitespace().map(String::from).collect(),
96            // {"_positional": ["pr", "list"]} or {"--key": "value"} converted to CLI flags
97            Value::Object(map) => {
98                if let Some(Value::Array(pos)) = map.get("_positional") {
99                    return pos
100                        .iter()
101                        .map(|v| match v {
102                            Value::String(s) => s.clone(),
103                            other => other.to_string(),
104                        })
105                        .collect();
106                }
107                // Convert map entries to --key value pairs
108                let mut result = Vec::new();
109                for (k, v) in map {
110                    result.push(format!("--{k}"));
111                    match v {
112                        Value::String(s) => result.push(s.clone()),
113                        Value::Bool(true) => {} // flag, no value needed
114                        other => result.push(other.to_string()),
115                    }
116                }
117                result
118            }
119            _ => Vec::new(),
120        }
121    }
122}
123
124#[derive(Debug, Serialize)]
125pub struct CallResponse {
126    pub result: Value,
127    #[serde(skip_serializing_if = "Option::is_none")]
128    pub error: Option<String>,
129}
130
131#[derive(Debug, Deserialize)]
132pub struct HelpRequest {
133    pub query: String,
134    #[serde(default)]
135    pub tool: Option<String>,
136}
137
138#[derive(Debug, Serialize)]
139pub struct HelpResponse {
140    pub content: String,
141    #[serde(skip_serializing_if = "Option::is_none")]
142    pub error: Option<String>,
143}
144
145#[derive(Debug, Serialize)]
146pub struct HealthResponse {
147    pub status: String,
148    pub version: String,
149    pub tools: usize,
150    pub providers: usize,
151    pub skills: usize,
152    pub auth: String,
153}
154
155// --- Skill endpoint types ---
156
157#[derive(Debug, Deserialize)]
158pub struct SkillsQuery {
159    #[serde(default)]
160    pub category: Option<String>,
161    #[serde(default)]
162    pub provider: Option<String>,
163    #[serde(default)]
164    pub tool: Option<String>,
165    #[serde(default)]
166    pub search: Option<String>,
167}
168
169#[derive(Debug, Deserialize)]
170pub struct SkillDetailQuery {
171    #[serde(default)]
172    pub meta: Option<bool>,
173    #[serde(default)]
174    pub refs: Option<bool>,
175}
176
177#[derive(Debug, Deserialize)]
178pub struct SkillResolveRequest {
179    pub scopes: Vec<String>,
180    /// When true, include SKILL.md content in each resolved skill.
181    #[serde(default)]
182    pub include_content: bool,
183}
184
185#[derive(Debug, Deserialize)]
186pub struct SkillBundleBatchRequest {
187    pub names: Vec<String>,
188}
189
190// --- Handlers ---
191
192async fn handle_call(
193    State(state): State<Arc<ProxyState>>,
194    req: HttpRequest<Body>,
195) -> impl IntoResponse {
196    // Extract JWT claims from request extensions (set by auth middleware)
197    let claims = req.extensions().get::<TokenClaims>().cloned();
198
199    // Parse request body
200    let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
201        Ok(b) => b,
202        Err(e) => {
203            return (
204                StatusCode::BAD_REQUEST,
205                Json(CallResponse {
206                    result: Value::Null,
207                    error: Some(format!("Failed to read request body: {e}")),
208                }),
209            );
210        }
211    };
212
213    let call_req: CallRequest = match serde_json::from_slice(&body_bytes) {
214        Ok(r) => r,
215        Err(e) => {
216            return (
217                StatusCode::UNPROCESSABLE_ENTITY,
218                Json(CallResponse {
219                    result: Value::Null,
220                    error: Some(format!("Invalid request: {e}")),
221                }),
222            );
223        }
224    };
225
226    tracing::debug!(
227        tool = %call_req.tool_name,
228        args = ?call_req.args,
229        "POST /call"
230    );
231
232    // Look up tool in registry
233    let (provider, tool) = match state.registry.get_tool(&call_req.tool_name) {
234        Some(pt) => pt,
235        None => {
236            return (
237                StatusCode::NOT_FOUND,
238                Json(CallResponse {
239                    result: Value::Null,
240                    error: Some(format!("Unknown tool: '{}'", call_req.tool_name)),
241                }),
242            );
243        }
244    };
245
246    // Scope enforcement from JWT claims
247    if let Some(tool_scope) = &tool.scope {
248        let scopes = match &claims {
249            Some(c) => ScopeConfig::from_jwt(c),
250            None if state.jwt_config.is_none() => ScopeConfig::unrestricted(), // Dev mode
251            None => {
252                return (
253                    StatusCode::FORBIDDEN,
254                    Json(CallResponse {
255                        result: Value::Null,
256                        error: Some("Authentication required — no JWT provided".into()),
257                    }),
258                );
259            }
260        };
261
262        if let Err(e) = scopes.check_access(&call_req.tool_name, tool_scope) {
263            return (
264                StatusCode::FORBIDDEN,
265                Json(CallResponse {
266                    result: Value::Null,
267                    error: Some(format!("Access denied: {e}")),
268                }),
269            );
270        }
271    }
272
273    // Rate limit check
274    {
275        let scopes = match &claims {
276            Some(c) => ScopeConfig::from_jwt(c),
277            None => ScopeConfig::unrestricted(),
278        };
279        if let Some(ref rate_config) = scopes.rate_config {
280            if let Err(e) = crate::core::rate::check_and_record(&call_req.tool_name, rate_config) {
281                return (
282                    StatusCode::TOO_MANY_REQUESTS,
283                    Json(CallResponse {
284                        result: Value::Null,
285                        error: Some(format!("{e}")),
286                    }),
287                );
288            }
289        }
290    }
291
292    // Build auth generator context from JWT claims
293    let gen_ctx = GenContext {
294        jwt_sub: claims
295            .as_ref()
296            .map(|c| c.sub.clone())
297            .unwrap_or_else(|| "dev".into()),
298        jwt_scope: claims
299            .as_ref()
300            .map(|c| c.scope.clone())
301            .unwrap_or_else(|| "*".into()),
302        tool_name: call_req.tool_name.clone(),
303        timestamp: crate::core::jwt::now_secs(),
304    };
305
306    // Execute tool call — dispatch based on handler type, with timing for audit
307    let agent_sub = claims.as_ref().map(|c| c.sub.clone()).unwrap_or_default();
308    let start = std::time::Instant::now();
309
310    let response = match provider.handler.as_str() {
311        "mcp" => {
312            let args_map = call_req.args_as_map();
313            match mcp_client::execute_with_gen(
314                provider,
315                &call_req.tool_name,
316                &args_map,
317                &state.keyring,
318                Some(&gen_ctx),
319                Some(&state.auth_cache),
320            )
321            .await
322            {
323                Ok(result) => (
324                    StatusCode::OK,
325                    Json(CallResponse {
326                        result,
327                        error: None,
328                    }),
329                ),
330                Err(e) => (
331                    StatusCode::BAD_GATEWAY,
332                    Json(CallResponse {
333                        result: Value::Null,
334                        error: Some(format!("MCP error: {e}")),
335                    }),
336                ),
337            }
338        }
339        "cli" => {
340            let positional = call_req.args_as_positional();
341            match crate::core::cli_executor::execute_with_gen(
342                provider,
343                &positional,
344                &state.keyring,
345                Some(&gen_ctx),
346                Some(&state.auth_cache),
347            )
348            .await
349            {
350                Ok(result) => (
351                    StatusCode::OK,
352                    Json(CallResponse {
353                        result,
354                        error: None,
355                    }),
356                ),
357                Err(e) => (
358                    StatusCode::BAD_GATEWAY,
359                    Json(CallResponse {
360                        result: Value::Null,
361                        error: Some(format!("CLI error: {e}")),
362                    }),
363                ),
364            }
365        }
366        _ => {
367            let args_map = call_req.args_as_map();
368            let raw_response = match match provider.handler.as_str() {
369                "xai" => xai::execute_xai_tool(provider, tool, &args_map, &state.keyring).await,
370                _ => {
371                    http::execute_tool_with_gen(
372                        provider,
373                        tool,
374                        &args_map,
375                        &state.keyring,
376                        Some(&gen_ctx),
377                        Some(&state.auth_cache),
378                    )
379                    .await
380                }
381            } {
382                Ok(resp) => resp,
383                Err(e) => {
384                    let duration = start.elapsed();
385                    write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
386                    return (
387                        StatusCode::BAD_GATEWAY,
388                        Json(CallResponse {
389                            result: Value::Null,
390                            error: Some(format!("Upstream API error: {e}")),
391                        }),
392                    );
393                }
394            };
395
396            let processed = match response::process_response(&raw_response, tool.response.as_ref())
397            {
398                Ok(p) => p,
399                Err(e) => {
400                    let duration = start.elapsed();
401                    write_proxy_audit(&call_req, &agent_sub, duration, Some(&e.to_string()));
402                    return (
403                        StatusCode::INTERNAL_SERVER_ERROR,
404                        Json(CallResponse {
405                            result: raw_response,
406                            error: Some(format!("Response processing error: {e}")),
407                        }),
408                    );
409                }
410            };
411
412            (
413                StatusCode::OK,
414                Json(CallResponse {
415                    result: processed,
416                    error: None,
417                }),
418            )
419        }
420    };
421
422    let duration = start.elapsed();
423    let error_msg = response.1.error.as_deref();
424    write_proxy_audit(&call_req, &agent_sub, duration, error_msg);
425
426    response
427}
428
429async fn handle_help(
430    State(state): State<Arc<ProxyState>>,
431    Json(req): Json<HelpRequest>,
432) -> impl IntoResponse {
433    tracing::debug!(query = %req.query, tool = ?req.tool, "POST /help");
434
435    let (llm_provider, llm_tool) = match state.registry.get_tool("_chat_completion") {
436        Some(pt) => pt,
437        None => {
438            return (
439                StatusCode::SERVICE_UNAVAILABLE,
440                Json(HelpResponse {
441                    content: String::new(),
442                    error: Some("No _llm.toml manifest found. Proxy help requires a configured LLM provider.".into()),
443                }),
444            );
445        }
446    };
447
448    let api_key = match llm_provider
449        .auth_key_name
450        .as_deref()
451        .and_then(|k| state.keyring.get(k))
452    {
453        Some(key) => key.to_string(),
454        None => {
455            return (
456                StatusCode::SERVICE_UNAVAILABLE,
457                Json(HelpResponse {
458                    content: String::new(),
459                    error: Some("LLM API key not found in keyring".into()),
460                }),
461            );
462        }
463    };
464
465    let scopes = ScopeConfig::unrestricted();
466    let resolved_skills = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
467    let skills_section = if resolved_skills.is_empty() {
468        String::new()
469    } else {
470        format!(
471            "## Available Skills (methodology guides)\n{}",
472            skill::build_skill_context(&resolved_skills)
473        )
474    };
475
476    // Build system prompt — scoped or unscoped
477    let system_prompt = if let Some(ref tool_name) = req.tool {
478        // Scoped mode: narrow tools to the specified tool or provider
479        match build_scoped_prompt(tool_name, &state.registry, &skills_section) {
480            Some(prompt) => prompt,
481            None => {
482                // Fall back to unscoped if tool/provider not found
483                tracing::debug!(scope = %tool_name, "scope not found, falling back to unscoped");
484                let all_tools = state.registry.list_public_tools();
485                let tools_context = build_tool_context(&all_tools);
486                HELP_SYSTEM_PROMPT
487                    .replace("{tools}", &tools_context)
488                    .replace("{skills_section}", &skills_section)
489            }
490        }
491    } else {
492        let all_tools = state.registry.list_public_tools();
493        let tools_context = build_tool_context(&all_tools);
494        HELP_SYSTEM_PROMPT
495            .replace("{tools}", &tools_context)
496            .replace("{skills_section}", &skills_section)
497    };
498
499    let request_body = serde_json::json!({
500        "model": "zai-glm-4.7",
501        "messages": [
502            {"role": "system", "content": system_prompt},
503            {"role": "user", "content": req.query}
504        ],
505        "max_completion_tokens": 1536,
506        "temperature": 0.3
507    });
508
509    let client = reqwest::Client::new();
510    let url = format!(
511        "{}{}",
512        llm_provider.base_url.trim_end_matches('/'),
513        llm_tool.endpoint
514    );
515
516    let response = match client
517        .post(&url)
518        .bearer_auth(&api_key)
519        .json(&request_body)
520        .send()
521        .await
522    {
523        Ok(r) => r,
524        Err(e) => {
525            return (
526                StatusCode::BAD_GATEWAY,
527                Json(HelpResponse {
528                    content: String::new(),
529                    error: Some(format!("LLM request failed: {e}")),
530                }),
531            );
532        }
533    };
534
535    if !response.status().is_success() {
536        let status = response.status();
537        let body = response.text().await.unwrap_or_default();
538        return (
539            StatusCode::BAD_GATEWAY,
540            Json(HelpResponse {
541                content: String::new(),
542                error: Some(format!("LLM API error ({status}): {body}")),
543            }),
544        );
545    }
546
547    let body: Value = match response.json().await {
548        Ok(b) => b,
549        Err(e) => {
550            return (
551                StatusCode::INTERNAL_SERVER_ERROR,
552                Json(HelpResponse {
553                    content: String::new(),
554                    error: Some(format!("Failed to parse LLM response: {e}")),
555                }),
556            );
557        }
558    };
559
560    let content = body
561        .pointer("/choices/0/message/content")
562        .and_then(|c| c.as_str())
563        .unwrap_or("No response from LLM")
564        .to_string();
565
566    (
567        StatusCode::OK,
568        Json(HelpResponse {
569            content,
570            error: None,
571        }),
572    )
573}
574
575async fn handle_health(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
576    let auth = if state.jwt_config.is_some() {
577        "jwt"
578    } else {
579        "disabled"
580    };
581
582    Json(HealthResponse {
583        status: "ok".into(),
584        version: env!("CARGO_PKG_VERSION").into(),
585        tools: state.registry.list_public_tools().len(),
586        providers: state.registry.list_providers().len(),
587        skills: state.skill_registry.skill_count(),
588        auth: auth.into(),
589    })
590}
591
592/// GET /.well-known/jwks.json — serves the public key for JWT validation.
593async fn handle_jwks(State(state): State<Arc<ProxyState>>) -> impl IntoResponse {
594    match &state.jwks_json {
595        Some(jwks) => (StatusCode::OK, Json(jwks.clone())),
596        None => (
597            StatusCode::NOT_FOUND,
598            Json(serde_json::json!({"error": "JWKS not configured"})),
599        ),
600    }
601}
602
603// ---------------------------------------------------------------------------
604// POST /mcp — MCP JSON-RPC proxy endpoint
605// ---------------------------------------------------------------------------
606
607async fn handle_mcp(
608    State(state): State<Arc<ProxyState>>,
609    Json(msg): Json<Value>,
610) -> impl IntoResponse {
611    let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
612    let id = msg.get("id").cloned();
613
614    tracing::debug!(%method, "POST /mcp");
615
616    match method {
617        "initialize" => {
618            let result = serde_json::json!({
619                "protocolVersion": "2025-03-26",
620                "capabilities": {
621                    "tools": { "listChanged": false }
622                },
623                "serverInfo": {
624                    "name": "ati-proxy",
625                    "version": env!("CARGO_PKG_VERSION")
626                }
627            });
628            jsonrpc_success(id, result)
629        }
630
631        "notifications/initialized" => (StatusCode::ACCEPTED, Json(Value::Null)),
632
633        "tools/list" => {
634            let all_tools = state.registry.list_public_tools();
635            let mcp_tools: Vec<Value> = all_tools
636                .iter()
637                .map(|(_provider, tool)| {
638                    serde_json::json!({
639                        "name": tool.name,
640                        "description": tool.description,
641                        "inputSchema": tool.input_schema.clone().unwrap_or(serde_json::json!({
642                            "type": "object",
643                            "properties": {}
644                        }))
645                    })
646                })
647                .collect();
648
649            let result = serde_json::json!({
650                "tools": mcp_tools,
651            });
652            jsonrpc_success(id, result)
653        }
654
655        "tools/call" => {
656            let params = msg.get("params").cloned().unwrap_or(Value::Null);
657            let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
658            let arguments: HashMap<String, Value> = params
659                .get("arguments")
660                .and_then(|a| serde_json::from_value(a.clone()).ok())
661                .unwrap_or_default();
662
663            if tool_name.is_empty() {
664                return jsonrpc_error(id, -32602, "Missing tool name in params.name");
665            }
666
667            let (provider, _tool) = match state.registry.get_tool(tool_name) {
668                Some(pt) => pt,
669                None => {
670                    return jsonrpc_error(id, -32602, &format!("Unknown tool: '{tool_name}'"));
671                }
672            };
673
674            tracing::debug!(%tool_name, provider = %provider.name, "MCP tools/call");
675
676            let mcp_gen_ctx = GenContext {
677                jwt_sub: "dev".into(),
678                jwt_scope: "*".into(),
679                tool_name: tool_name.to_string(),
680                timestamp: crate::core::jwt::now_secs(),
681            };
682
683            let result = if provider.is_mcp() {
684                mcp_client::execute_with_gen(
685                    provider,
686                    tool_name,
687                    &arguments,
688                    &state.keyring,
689                    Some(&mcp_gen_ctx),
690                    Some(&state.auth_cache),
691                )
692                .await
693            } else if provider.is_cli() {
694                // Convert arguments map to CLI-style args for MCP passthrough
695                let raw: Vec<String> = arguments
696                    .iter()
697                    .flat_map(|(k, v)| {
698                        let val = match v {
699                            Value::String(s) => s.clone(),
700                            other => other.to_string(),
701                        };
702                        vec![format!("--{k}"), val]
703                    })
704                    .collect();
705                crate::core::cli_executor::execute_with_gen(
706                    provider,
707                    &raw,
708                    &state.keyring,
709                    Some(&mcp_gen_ctx),
710                    Some(&state.auth_cache),
711                )
712                .await
713                .map_err(|e| mcp_client::McpError::Transport(e.to_string()))
714            } else {
715                match match provider.handler.as_str() {
716                    "xai" => {
717                        xai::execute_xai_tool(provider, _tool, &arguments, &state.keyring).await
718                    }
719                    _ => {
720                        http::execute_tool_with_gen(
721                            provider,
722                            _tool,
723                            &arguments,
724                            &state.keyring,
725                            Some(&mcp_gen_ctx),
726                            Some(&state.auth_cache),
727                        )
728                        .await
729                    }
730                } {
731                    Ok(val) => Ok(val),
732                    Err(e) => Err(mcp_client::McpError::Transport(e.to_string())),
733                }
734            };
735
736            match result {
737                Ok(value) => {
738                    let text = match &value {
739                        Value::String(s) => s.clone(),
740                        other => serde_json::to_string_pretty(other).unwrap_or_default(),
741                    };
742                    let mcp_result = serde_json::json!({
743                        "content": [{"type": "text", "text": text}],
744                        "isError": false,
745                    });
746                    jsonrpc_success(id, mcp_result)
747                }
748                Err(e) => {
749                    let mcp_result = serde_json::json!({
750                        "content": [{"type": "text", "text": format!("Error: {e}")}],
751                        "isError": true,
752                    });
753                    jsonrpc_success(id, mcp_result)
754                }
755            }
756        }
757
758        _ => jsonrpc_error(id, -32601, &format!("Method not found: '{method}'")),
759    }
760}
761
762fn jsonrpc_success(id: Option<Value>, result: Value) -> (StatusCode, Json<Value>) {
763    (
764        StatusCode::OK,
765        Json(serde_json::json!({
766            "jsonrpc": "2.0",
767            "id": id,
768            "result": result,
769        })),
770    )
771}
772
773fn jsonrpc_error(id: Option<Value>, code: i64, message: &str) -> (StatusCode, Json<Value>) {
774    (
775        StatusCode::OK,
776        Json(serde_json::json!({
777            "jsonrpc": "2.0",
778            "id": id,
779            "error": {
780                "code": code,
781                "message": message,
782            }
783        })),
784    )
785}
786
787// ---------------------------------------------------------------------------
788// Skill endpoints
789// ---------------------------------------------------------------------------
790
791async fn handle_skills_list(
792    State(state): State<Arc<ProxyState>>,
793    axum::extract::Query(query): axum::extract::Query<SkillsQuery>,
794) -> impl IntoResponse {
795    tracing::debug!(
796        category = ?query.category,
797        provider = ?query.provider,
798        tool = ?query.tool,
799        search = ?query.search,
800        "GET /skills"
801    );
802
803    let skills: Vec<&skill::SkillMeta> = if let Some(search_query) = &query.search {
804        state.skill_registry.search(search_query)
805    } else if let Some(cat) = &query.category {
806        state.skill_registry.skills_for_category(cat)
807    } else if let Some(prov) = &query.provider {
808        state.skill_registry.skills_for_provider(prov)
809    } else if let Some(t) = &query.tool {
810        state.skill_registry.skills_for_tool(t)
811    } else {
812        state.skill_registry.list_skills().iter().collect()
813    };
814
815    let json: Vec<Value> = skills
816        .iter()
817        .map(|s| {
818            serde_json::json!({
819                "name": s.name,
820                "version": s.version,
821                "description": s.description,
822                "tools": s.tools,
823                "providers": s.providers,
824                "categories": s.categories,
825                "hint": s.hint,
826            })
827        })
828        .collect();
829
830    (StatusCode::OK, Json(Value::Array(json)))
831}
832
833async fn handle_skill_detail(
834    State(state): State<Arc<ProxyState>>,
835    axum::extract::Path(name): axum::extract::Path<String>,
836    axum::extract::Query(query): axum::extract::Query<SkillDetailQuery>,
837) -> impl IntoResponse {
838    tracing::debug!(%name, meta = ?query.meta, refs = ?query.refs, "GET /skills/:name");
839
840    let skill_meta = match state.skill_registry.get_skill(&name) {
841        Some(s) => s,
842        None => {
843            return (
844                StatusCode::NOT_FOUND,
845                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
846            );
847        }
848    };
849
850    if query.meta.unwrap_or(false) {
851        return (
852            StatusCode::OK,
853            Json(serde_json::json!({
854                "name": skill_meta.name,
855                "version": skill_meta.version,
856                "description": skill_meta.description,
857                "author": skill_meta.author,
858                "tools": skill_meta.tools,
859                "providers": skill_meta.providers,
860                "categories": skill_meta.categories,
861                "keywords": skill_meta.keywords,
862                "hint": skill_meta.hint,
863                "depends_on": skill_meta.depends_on,
864                "suggests": skill_meta.suggests,
865                "license": skill_meta.license,
866                "compatibility": skill_meta.compatibility,
867                "allowed_tools": skill_meta.allowed_tools,
868                "format": skill_meta.format,
869            })),
870        );
871    }
872
873    let content = match state.skill_registry.read_content(&name) {
874        Ok(c) => c,
875        Err(e) => {
876            return (
877                StatusCode::INTERNAL_SERVER_ERROR,
878                Json(serde_json::json!({"error": format!("Failed to read skill: {e}")})),
879            );
880        }
881    };
882
883    let mut response = serde_json::json!({
884        "name": skill_meta.name,
885        "version": skill_meta.version,
886        "description": skill_meta.description,
887        "content": content,
888    });
889
890    if query.refs.unwrap_or(false) {
891        if let Ok(refs) = state.skill_registry.list_references(&name) {
892            response["references"] = serde_json::json!(refs);
893        }
894    }
895
896    (StatusCode::OK, Json(response))
897}
898
899/// GET /skills/:name/bundle — return all files in a skill directory.
900/// Response: `{"name": "...", "files": {"SKILL.md": "...", "scripts/generate.sh": "...", ...}}`
901/// Binary files are base64-encoded; text files are returned as-is.
902async fn handle_skill_bundle(
903    State(state): State<Arc<ProxyState>>,
904    axum::extract::Path(name): axum::extract::Path<String>,
905) -> impl IntoResponse {
906    tracing::debug!(skill = %name, "GET /skills/:name/bundle");
907
908    let files = match state.skill_registry.bundle_files(&name) {
909        Ok(f) => f,
910        Err(_) => {
911            return (
912                StatusCode::NOT_FOUND,
913                Json(serde_json::json!({"error": format!("Skill '{name}' not found")})),
914            );
915        }
916    };
917
918    // Convert bytes to strings (UTF-8 text) or base64 for binary files
919    let mut file_map = serde_json::Map::new();
920    for (path, data) in &files {
921        match std::str::from_utf8(data) {
922            Ok(text) => {
923                file_map.insert(path.clone(), Value::String(text.to_string()));
924            }
925            Err(_) => {
926                // Binary file — base64 encode
927                use base64::Engine;
928                let encoded = base64::engine::general_purpose::STANDARD.encode(data);
929                file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
930            }
931        }
932    }
933
934    (
935        StatusCode::OK,
936        Json(serde_json::json!({
937            "name": name,
938            "files": file_map,
939        })),
940    )
941}
942
943/// POST /skills/bundle — return all files for multiple skills in one response.
944/// Request: `{"names": ["fal-generate", "compliance-screening"]}`
945/// Response: `{"skills": {...}, "missing": [...]}`
946async fn handle_skills_bundle_batch(
947    State(state): State<Arc<ProxyState>>,
948    Json(req): Json<SkillBundleBatchRequest>,
949) -> impl IntoResponse {
950    const MAX_BATCH: usize = 50;
951    if req.names.len() > MAX_BATCH {
952        return (
953            StatusCode::BAD_REQUEST,
954            Json(
955                serde_json::json!({"error": format!("batch size {} exceeds limit of {MAX_BATCH}", req.names.len())}),
956            ),
957        );
958    }
959
960    tracing::debug!(names = ?req.names, "POST /skills/bundle");
961
962    let mut result = serde_json::Map::new();
963    let mut missing: Vec<String> = Vec::new();
964
965    for name in &req.names {
966        let files = match state.skill_registry.bundle_files(name) {
967            Ok(f) => f,
968            Err(_) => {
969                missing.push(name.clone());
970                continue;
971            }
972        };
973
974        let mut file_map = serde_json::Map::new();
975        for (path, data) in &files {
976            match std::str::from_utf8(data) {
977                Ok(text) => {
978                    file_map.insert(path.clone(), Value::String(text.to_string()));
979                }
980                Err(_) => {
981                    use base64::Engine;
982                    let encoded = base64::engine::general_purpose::STANDARD.encode(data);
983                    file_map.insert(path.clone(), serde_json::json!({"base64": encoded}));
984                }
985            }
986        }
987
988        result.insert(name.clone(), serde_json::json!({ "files": file_map }));
989    }
990
991    (
992        StatusCode::OK,
993        Json(serde_json::json!({ "skills": result, "missing": missing })),
994    )
995}
996
997async fn handle_skills_resolve(
998    State(state): State<Arc<ProxyState>>,
999    Json(req): Json<SkillResolveRequest>,
1000) -> impl IntoResponse {
1001    tracing::debug!(scopes = ?req.scopes, include_content = req.include_content, "POST /skills/resolve");
1002
1003    let include_content = req.include_content;
1004    let scopes = ScopeConfig {
1005        scopes: req.scopes,
1006        sub: String::new(),
1007        expires_at: 0,
1008        rate_config: None,
1009    };
1010
1011    let resolved = skill::resolve_skills(&state.skill_registry, &state.registry, &scopes);
1012
1013    let json: Vec<Value> = resolved
1014        .iter()
1015        .map(|s| {
1016            let mut entry = serde_json::json!({
1017                "name": s.name,
1018                "version": s.version,
1019                "description": s.description,
1020                "tools": s.tools,
1021                "providers": s.providers,
1022                "categories": s.categories,
1023            });
1024            if include_content {
1025                if let Ok(content) = state.skill_registry.read_content(&s.name) {
1026                    entry["content"] = Value::String(content);
1027                }
1028            }
1029            entry
1030        })
1031        .collect();
1032
1033    (StatusCode::OK, Json(Value::Array(json)))
1034}
1035
1036// --- Auth middleware ---
1037
1038/// JWT authentication middleware.
1039///
1040/// - /health and /.well-known/jwks.json → skip auth
1041/// - JWT configured → validate Bearer token, attach claims to request extensions
1042/// - No JWT configured → allow all (dev mode)
1043async fn auth_middleware(
1044    State(state): State<Arc<ProxyState>>,
1045    mut req: HttpRequest<Body>,
1046    next: Next,
1047) -> Result<Response, StatusCode> {
1048    let path = req.uri().path();
1049
1050    // Skip auth for public endpoints
1051    if path == "/health" || path == "/.well-known/jwks.json" {
1052        return Ok(next.run(req).await);
1053    }
1054
1055    // If no JWT configured, allow all (dev mode)
1056    let jwt_config = match &state.jwt_config {
1057        Some(c) => c,
1058        None => return Ok(next.run(req).await),
1059    };
1060
1061    // Extract Authorization: Bearer <token>
1062    let auth_header = req
1063        .headers()
1064        .get("authorization")
1065        .and_then(|v| v.to_str().ok());
1066
1067    let token = match auth_header {
1068        Some(header) if header.starts_with("Bearer ") => &header[7..],
1069        _ => return Err(StatusCode::UNAUTHORIZED),
1070    };
1071
1072    // Validate JWT
1073    match jwt::validate(token, jwt_config) {
1074        Ok(claims) => {
1075            tracing::debug!(sub = %claims.sub, scopes = %claims.scope, "JWT validated");
1076            req.extensions_mut().insert(claims);
1077            Ok(next.run(req).await)
1078        }
1079        Err(e) => {
1080            tracing::debug!(error = %e, "JWT validation failed");
1081            Err(StatusCode::UNAUTHORIZED)
1082        }
1083    }
1084}
1085
1086// --- Router builder ---
1087
1088/// Build the axum Router from a pre-constructed ProxyState.
1089pub fn build_router(state: Arc<ProxyState>) -> Router {
1090    Router::new()
1091        .route("/call", post(handle_call))
1092        .route("/help", post(handle_help))
1093        .route("/mcp", post(handle_mcp))
1094        .route("/skills", get(handle_skills_list))
1095        .route("/skills/resolve", post(handle_skills_resolve))
1096        .route("/skills/bundle", post(handle_skills_bundle_batch))
1097        .route("/skills/{name}", get(handle_skill_detail))
1098        .route("/skills/{name}/bundle", get(handle_skill_bundle))
1099        .route("/health", get(handle_health))
1100        .route("/.well-known/jwks.json", get(handle_jwks))
1101        .layer(middleware::from_fn_with_state(
1102            state.clone(),
1103            auth_middleware,
1104        ))
1105        .with_state(state)
1106}
1107
1108// --- Server startup ---
1109
1110/// Start the proxy server.
1111pub async fn run(
1112    port: u16,
1113    bind_addr: Option<String>,
1114    ati_dir: PathBuf,
1115    _verbose: bool,
1116    env_keys: bool,
1117) -> Result<(), Box<dyn std::error::Error>> {
1118    // Load manifests
1119    let manifests_dir = ati_dir.join("manifests");
1120    let registry = ManifestRegistry::load(&manifests_dir)?;
1121
1122    let tool_count = registry.list_public_tools().len();
1123    let provider_count = registry.list_providers().len();
1124
1125    // Load keyring
1126    let keyring_source;
1127    let keyring = if env_keys {
1128        // --env-keys: scan ATI_KEY_* environment variables
1129        let kr = Keyring::from_env();
1130        let key_names = kr.key_names();
1131        tracing::info!(
1132            count = key_names.len(),
1133            "loaded API keys from ATI_KEY_* env vars"
1134        );
1135        for name in &key_names {
1136            tracing::debug!(key = %name, "env key loaded");
1137        }
1138        keyring_source = "env-vars (ATI_KEY_*)";
1139        kr
1140    } else {
1141        // Cascade: keyring.enc (sealed) → keyring.enc (persistent) → credentials → empty
1142        let keyring_path = ati_dir.join("keyring.enc");
1143        if keyring_path.exists() {
1144            if let Ok(kr) = Keyring::load(&keyring_path) {
1145                keyring_source = "keyring.enc (sealed key)";
1146                kr
1147            } else if let Ok(kr) = Keyring::load_local(&keyring_path, &ati_dir) {
1148                keyring_source = "keyring.enc (persistent key)";
1149                kr
1150            } else {
1151                tracing::warn!("keyring.enc exists but could not be decrypted");
1152                keyring_source = "empty (decryption failed)";
1153                Keyring::empty()
1154            }
1155        } else {
1156            let creds_path = ati_dir.join("credentials");
1157            if creds_path.exists() {
1158                match Keyring::load_credentials(&creds_path) {
1159                    Ok(kr) => {
1160                        keyring_source = "credentials (plaintext)";
1161                        kr
1162                    }
1163                    Err(e) => {
1164                        tracing::warn!(error = %e, "failed to load credentials");
1165                        keyring_source = "empty (credentials error)";
1166                        Keyring::empty()
1167                    }
1168                }
1169            } else {
1170                tracing::warn!("no keyring.enc or credentials found — running without API keys");
1171                tracing::warn!("tools requiring authentication will fail");
1172                keyring_source = "empty (no auth)";
1173                Keyring::empty()
1174            }
1175        }
1176    };
1177
1178    // Log MCP and OpenAPI providers
1179    let mcp_providers: Vec<(String, String)> = registry
1180        .list_mcp_providers()
1181        .iter()
1182        .map(|p| (p.name.clone(), p.mcp_transport_type().to_string()))
1183        .collect();
1184    let mcp_count = mcp_providers.len();
1185    let openapi_providers: Vec<String> = registry
1186        .list_openapi_providers()
1187        .iter()
1188        .map(|p| p.name.clone())
1189        .collect();
1190    let openapi_count = openapi_providers.len();
1191
1192    // Load skill registry (local + optional GCS)
1193    let skills_dir = ati_dir.join("skills");
1194    let mut skill_registry = SkillRegistry::load(&skills_dir).unwrap_or_else(|e| {
1195        tracing::warn!(error = %e, "failed to load skills");
1196        SkillRegistry::load(std::path::Path::new("/nonexistent-fallback")).unwrap()
1197    });
1198
1199    // Load GCS skills if ATI_SKILL_REGISTRY is set
1200    if let Ok(registry_url) = std::env::var("ATI_SKILL_REGISTRY") {
1201        if let Some(bucket) = registry_url.strip_prefix("gcs://") {
1202            let cred_key = "gcp_credentials";
1203            if let Some(cred_json) = keyring.get(cred_key) {
1204                match crate::core::gcs::GcsClient::new(bucket.to_string(), cred_json) {
1205                    Ok(client) => match crate::core::gcs::GcsSkillSource::load(&client).await {
1206                        Ok(gcs_source) => {
1207                            let gcs_count = gcs_source.skill_count();
1208                            skill_registry.merge(gcs_source);
1209                            tracing::info!(
1210                                bucket = %bucket,
1211                                skills = gcs_count,
1212                                "loaded skills from GCS registry"
1213                            );
1214                        }
1215                        Err(e) => {
1216                            tracing::warn!(error = %e, bucket = %bucket, "failed to load GCS skills");
1217                        }
1218                    },
1219                    Err(e) => {
1220                        tracing::warn!(error = %e, "failed to init GCS client");
1221                    }
1222                }
1223            } else {
1224                tracing::warn!(
1225                    key = %cred_key,
1226                    "ATI_SKILL_REGISTRY set but GCS credentials not found in keyring"
1227                );
1228            }
1229        } else {
1230            tracing::warn!(
1231                url = %registry_url,
1232                "unsupported skill registry scheme (only gcs:// is supported)"
1233            );
1234        }
1235    }
1236
1237    let skill_count = skill_registry.skill_count();
1238
1239    // Load JWT config from environment
1240    let jwt_config = match jwt::config_from_env() {
1241        Ok(config) => config,
1242        Err(e) => {
1243            tracing::warn!(error = %e, "JWT config error");
1244            None
1245        }
1246    };
1247
1248    let auth_status = if jwt_config.is_some() {
1249        "JWT enabled"
1250    } else {
1251        "DISABLED (no JWT keys configured)"
1252    };
1253
1254    // Build JWKS for the endpoint
1255    let jwks_json = jwt_config.as_ref().and_then(|config| {
1256        config
1257            .public_key_pem
1258            .as_ref()
1259            .and_then(|pem| jwt::public_key_to_jwks(pem, config.algorithm, "ati-proxy-1").ok())
1260    });
1261
1262    let state = Arc::new(ProxyState {
1263        registry,
1264        skill_registry,
1265        keyring,
1266        jwt_config,
1267        jwks_json,
1268        auth_cache: AuthCache::new(),
1269    });
1270
1271    let app = build_router(state);
1272
1273    let addr: SocketAddr = if let Some(ref bind) = bind_addr {
1274        format!("{bind}:{port}").parse()?
1275    } else {
1276        SocketAddr::from(([127, 0, 0, 1], port))
1277    };
1278
1279    tracing::info!(
1280        version = env!("CARGO_PKG_VERSION"),
1281        %addr,
1282        auth = auth_status,
1283        ati_dir = %ati_dir.display(),
1284        tools = tool_count,
1285        providers = provider_count,
1286        mcp = mcp_count,
1287        openapi = openapi_count,
1288        skills = skill_count,
1289        keyring = keyring_source,
1290        "ATI proxy server starting"
1291    );
1292    for (name, transport) in &mcp_providers {
1293        tracing::info!(provider = %name, transport = %transport, "MCP provider");
1294    }
1295    for name in &openapi_providers {
1296        tracing::info!(provider = %name, "OpenAPI provider");
1297    }
1298
1299    let listener = tokio::net::TcpListener::bind(addr).await?;
1300    axum::serve(listener, app).await?;
1301
1302    Ok(())
1303}
1304
1305/// Write an audit entry from the proxy server. Failures are silently ignored.
1306fn write_proxy_audit(
1307    call_req: &CallRequest,
1308    agent_sub: &str,
1309    duration: std::time::Duration,
1310    error: Option<&str>,
1311) {
1312    let entry = crate::core::audit::AuditEntry {
1313        ts: chrono::Utc::now().to_rfc3339(),
1314        tool: call_req.tool_name.clone(),
1315        args: crate::core::audit::sanitize_args(&call_req.args),
1316        status: if error.is_some() {
1317            crate::core::audit::AuditStatus::Error
1318        } else {
1319            crate::core::audit::AuditStatus::Ok
1320        },
1321        duration_ms: duration.as_millis() as u64,
1322        agent_sub: agent_sub.to_string(),
1323        error: error.map(|s| s.to_string()),
1324        exit_code: None,
1325    };
1326    let _ = crate::core::audit::append(&entry);
1327}
1328
1329// --- Helpers ---
1330
/// System prompt template for the general (unscoped) help assistant.
///
/// Contains the literal placeholders `{tools}` and `{skills_section}`. Because this is a
/// plain `const` rather than a `format!` literal, the placeholders are presumably filled
/// by string replacement in the caller — TODO confirm (the caller is not shown here).
const HELP_SYSTEM_PROMPT: &str = r#"You are a helpful assistant for an AI agent that uses external tools via the `ati` CLI.

## Available Tools
{tools}

{skills_section}

Answer the agent's question naturally, like a knowledgeable colleague would. Keep it short but useful:

- Explain which tools to use and why, with `ati run` commands showing realistic parameter values
- If multiple steps are needed, walk through them briefly in order
- Mention important gotchas or parameter choices that matter
- If skills are relevant, suggest `ati skill show <name>` for the full methodology

Keep your answer concise — a few short paragraphs with embedded code blocks. Only recommend tools from the list above."#;
1346
1347fn build_tool_context(
1348    tools: &[(
1349        &crate::core::manifest::Provider,
1350        &crate::core::manifest::Tool,
1351    )],
1352) -> String {
1353    let mut summaries = Vec::new();
1354    for (provider, tool) in tools {
1355        let mut summary = if let Some(cat) = &provider.category {
1356            format!(
1357                "- **{}** (provider: {}, category: {}): {}",
1358                tool.name, provider.name, cat, tool.description
1359            )
1360        } else {
1361            format!(
1362                "- **{}** (provider: {}): {}",
1363                tool.name, provider.name, tool.description
1364            )
1365        };
1366        if !tool.tags.is_empty() {
1367            summary.push_str(&format!("\n  Tags: {}", tool.tags.join(", ")));
1368        }
1369        // CLI tools: show passthrough usage
1370        if provider.is_cli() && tool.input_schema.is_none() {
1371            let cmd = provider.cli_command.as_deref().unwrap_or("?");
1372            summary.push_str(&format!(
1373                "\n  Usage: `ati run {} -- <args>`  (passthrough to `{}`)",
1374                tool.name, cmd
1375            ));
1376        } else if let Some(schema) = &tool.input_schema {
1377            if let Some(props) = schema.get("properties") {
1378                if let Some(obj) = props.as_object() {
1379                    let params: Vec<String> = obj
1380                        .iter()
1381                        .filter(|(_, v)| {
1382                            v.get("x-ati-param-location").is_none()
1383                                || v.get("description").is_some()
1384                        })
1385                        .map(|(k, v)| {
1386                            let type_str =
1387                                v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1388                            let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
1389                            format!("    --{k} ({type_str}): {desc}")
1390                        })
1391                        .collect();
1392                    if !params.is_empty() {
1393                        summary.push_str("\n  Parameters:\n");
1394                        summary.push_str(&params.join("\n"));
1395                    }
1396                }
1397            }
1398        }
1399        summaries.push(summary);
1400    }
1401    summaries.join("\n\n")
1402}
1403
1404/// Build a scoped system prompt for a specific tool or provider.
1405///
1406/// Returns None if the scope_name doesn't match any tool or provider.
1407fn build_scoped_prompt(
1408    scope_name: &str,
1409    registry: &ManifestRegistry,
1410    skills_section: &str,
1411) -> Option<String> {
1412    // Check if scope_name is a tool
1413    if let Some((provider, tool)) = registry.get_tool(scope_name) {
1414        let mut details = format!(
1415            "**Name**: `{}`\n**Provider**: {} (handler: {})\n**Description**: {}\n",
1416            tool.name, provider.name, provider.handler, tool.description
1417        );
1418        if let Some(cat) = &provider.category {
1419            details.push_str(&format!("**Category**: {}\n", cat));
1420        }
1421        if provider.is_cli() {
1422            let cmd = provider.cli_command.as_deref().unwrap_or("?");
1423            details.push_str(&format!(
1424                "\n**Usage**: `ati run {} -- <args>`  (passthrough to `{}`)\n",
1425                tool.name, cmd
1426            ));
1427        } else if let Some(schema) = &tool.input_schema {
1428            if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) {
1429                let required: Vec<String> = schema
1430                    .get("required")
1431                    .and_then(|r| r.as_array())
1432                    .map(|arr| {
1433                        arr.iter()
1434                            .filter_map(|v| v.as_str().map(|s| s.to_string()))
1435                            .collect()
1436                    })
1437                    .unwrap_or_default();
1438                details.push_str("\n**Parameters**:\n");
1439                for (key, val) in props {
1440                    let type_str = val.get("type").and_then(|t| t.as_str()).unwrap_or("string");
1441                    let desc = val
1442                        .get("description")
1443                        .and_then(|d| d.as_str())
1444                        .unwrap_or("");
1445                    let req = if required.contains(key) {
1446                        " **(required)**"
1447                    } else {
1448                        ""
1449                    };
1450                    details.push_str(&format!("- `--{key}` ({type_str}{req}): {desc}\n"));
1451                }
1452            }
1453        }
1454
1455        let prompt = format!(
1456            "You are an expert assistant for the `{}` tool, accessed via the `ati` CLI.\n\n\
1457            ## Tool Details\n{}\n\n{}\n\n\
1458            Answer the agent's question about this specific tool. Provide exact commands, explain flags and options, and give practical examples. Be concise and actionable.",
1459            tool.name, details, skills_section
1460        );
1461        return Some(prompt);
1462    }
1463
1464    // Check if scope_name is a provider
1465    if registry.has_provider(scope_name) {
1466        let tools = registry.tools_by_provider(scope_name);
1467        if tools.is_empty() {
1468            return None;
1469        }
1470        let tools_context = build_tool_context(&tools);
1471        let prompt = format!(
1472            "You are an expert assistant for the `{}` provider's tools, accessed via the `ati` CLI.\n\n\
1473            ## Tools in provider `{}`\n{}\n\n{}\n\n\
1474            Answer the agent's question about these tools. Provide exact `ati run` commands, explain parameters, and give practical examples. Be concise and actionable.",
1475            scope_name, scope_name, tools_context, skills_section
1476        );
1477        return Some(prompt);
1478    }
1479
1480    None
1481}